-
Notifications
You must be signed in to change notification settings - Fork 0
/
NelderMead.hpp
100 lines (97 loc) · 2.48 KB
/
NelderMead.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/* Copyright (c) 2019 Evgeniy Vodolazskiy (waterlaz) */
#pragma once
#include <set>
#include <vector>
#include <assert.h>
/**
 * Nelder-Mead (downhill simplex) minimizer.
 *
 * Usage: derive from this class and implement f(), seed the simplex with
 * setSimplex() and/or addToSimplex(), then call iterate() repeatedly.
 * current() returns the best vertex found so far.
 *
 * @tparam D point type. Must be copyable and support `D += D`, `D /= int`,
 *           `D + D`, `D - D` and `R * D`.
 * @tparam R objective-value type. Vertices are ordered by `<` on R; the
 *           step coefficients below are also stored as R.
 */
template<typename D, typename R>
class NelderMead {
private:
    // Orders simplex vertices by objective value only, so xs.begin() is the
    // best vertex and the last element is the worst.
    class CompareFirst {
    public:
        bool operator()(const std::pair<R, D>& a,
                        const std::pair<R, D>& b) const {
            return a.first < b.first;
        }
    };
public:
    // Simplex vertices stored as (f(x), x), ascending by f(x). Equal values are kept.
    std::multiset<std::pair<R, D>, CompareFirst > xs;
public:
    NelderMead() = default;
    NelderMead(const NelderMead&) = default;
    NelderMead(NelderMead&&) = default;
    NelderMead& operator=(const NelderMead&) = default;
    NelderMead& operator=(NelderMead&&) = default;
    // Virtual because this is a polymorphic base (f is pure virtual): deleting
    // a derived optimizer through a NelderMead* must be well-defined.
    virtual ~NelderMead() = default;

    // Objective function to minimize; supplied by the derived class.
    virtual R f(const D& x) = 0;
    R alpha = 1.0;  // reflection coefficient
    R gamma = 2.0;  // expansion coefficient
    R ro = 0.5;     // contraction coefficient
    R delta = 0.5;  // shrink coefficient
    // Number of simplex vertices minus one, as set by setSimplex()/addToSimplex().
    int n = 0;

    // Best vertex found so far, as (f(x), x). Requires a non-empty simplex.
    const std::pair<R, D>& current() const {
        assert( !xs.empty() );
        return *xs.begin();
    }

    // Replaces the simplex with the given vertices (at least two are
    // required), evaluating f once per vertex.
    void setSimplex(const std::vector<D>& _xs){
        assert( _xs.size() >= 2 );
        n = static_cast<int>(_xs.size()) - 1;
        // "set" semantics: drop any previous vertices so n matches xs.
        xs.clear();
        for(const auto& x : _xs){
            xs.insert( std::pair<R, D>(f(x), x) );
        }
    }

    // Adds one more vertex (evaluating f once) and updates n accordingly.
    void addToSimplex(const D& x){
        xs.insert( std::pair<R, D>(f(x), x) );
        n = static_cast<int>(xs.size()) - 1;
    }

    // Performs one Nelder-Mead step: reflection, expansion, contraction or
    // shrink. Requires at least two vertices in the simplex.
    void iterate(){
        assert( xs.size() >= 2 );
        auto best = xs.begin();
        auto worst = xs.end();
        --worst;
        auto secondWorst = worst;
        --secondWorst;

        // Centroid of every vertex except the worst. Divide by the number of
        // vertices actually summed rather than the cached n, so the centroid
        // stays correct even if xs (public) was modified directly.
        auto it = xs.begin();
        D c = it->second;
        int count = 1;
        for( ++it; it != worst; ++it ){
            c += it->second;
            ++count;
        }
        c /= count;

        // Reflection: accept r if it is no better than the best but better
        // than the second worst.
        D r = c + alpha*( c - worst->second );
        R rv = f(r);
        if( rv<secondWorst->first && rv>=best->first ){
            xs.erase( worst );
            xs.insert( std::pair<R, D>(rv, r) );
            return;
        }

        // Expansion: r beat the best, so try going further in that direction
        // and keep whichever of e and r is better.
        if( rv<best->first ){
            D e = c + gamma*(r-c);
            R ev = f(e);
            xs.erase(worst);
            if(ev<rv){
                xs.insert( std::pair<R, D>(ev, e) );
            } else {
                xs.insert( std::pair<R, D>(rv, r) );
            }
            return;
        }

        // Contraction toward the worst vertex; accept if it improves on it.
        D e = c + ro*(worst->second - c);
        R ev = f(e);
        if( ev<worst->first ){
            xs.erase(worst);
            xs.insert( std::pair<R, D>(ev, e) );
            return;
        }

        // Shrink: keep the best vertex and pull every other vertex toward it.
        std::multiset<std::pair<R, D>, CompareFirst > ys;
        ys.insert(*best);
        it = best;
        for( ++it; it != xs.end(); ++it ){
            D x = best->second + delta*(it->second - best->second);
            ys.insert(std::pair<R, D>(f(x), x));
        }
        // Swap instead of copy-assign: ys is discarded anyway.
        xs.swap(ys);
    }
};