Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 25 additions & 52 deletions examples/rosenbrock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
#include <chrono>

// Here's the Rosenbrock banana function
FLT banana (FLT x, FLT y) {
FLT a = 1.0;
FLT b = 100.0;
FLT rtn = ((a-x)*(a-x)) + (b * (y-(x*x)) * (y-(x*x)));
template<typename T>
T banana (T x, T y)
{
constexpr T a = T{1};
constexpr T b = T{100};
T rtn = ((a - x) * (a - x)) + (b * (y - (x * x)) * (y - (x * x)));
return rtn;
}

Expand All @@ -30,14 +32,15 @@ int main()
mplot::Visual v(2600, 1800, "Rosenbrock bananas");
v.zNear = 0.001;
v.zFar = 100000;
v.fov=60;
v.fov = 60;
v.showCoordArrows (true);
v.lightingEffects (true);

// Initialise the vertices
sm::vvec<FLT> v1 = { 0.7, 0.0 };
sm::vvec<FLT> v2 = { 0.0, 0.6 };
sm::vvec<FLT> v3 = { -0.6, -1.0 };
sm::rand_uniform<FLT> rng(-3, 3);
sm::vvec<FLT> v1 = { rng.get(), rng.get() };
sm::vvec<FLT> v2 = { rng.get(), rng.get() };
sm::vvec<FLT> v3 = { rng.get(), rng.get() };
sm::vvec<sm::vvec<FLT>> i_vertices = { v1, v2, v3 };

// Add a 'triangle visual' to be visualised as three rods
Expand All @@ -58,15 +61,15 @@ int main()
auto tfvp = v.addVisualModel (tfv);

// Check banana function
FLT test = banana (1.0, 1.0);
FLT test = banana<FLT> (1.0, 1.0);
std::cout << "test point on banana function = " << test << " (should be 0).\n";

// Evaluate banana function and plot
sm::hexgrid hg (0.01, 10, 0);
hg.setCircularBoundary (2.5);
std::vector<FLT> banana_vals(hg.num(), 0.0f);
for (size_t i = 0; i < hg.num(); ++i) {
banana_vals[i] = banana (hg.d_x[i], hg.d_y[i]);
banana_vals[i] = banana<FLT> (hg.d_x[i], hg.d_y[i]);
}
sm::range<FLT> mm = sm::range<FLT>::get_from (banana_vals);
std::cout << "Banana surface range: " << mm << std::endl;
Expand All @@ -87,11 +90,7 @@ int main()
simp.termination_threshold = std::numeric_limits<FLT>::epsilon();
// You can prevent the algo getting stuck if termination_threshold is too small
simp.too_many_operations = 10000;

// Temporary variable
FLT val = FLT{0};

sm::rand_uniform<float> rng(-3, 3);
simp.objective = [](sm::vvec<FLT> x) { return banana<FLT>(x[0], x[1]); }; // objective defined as lambda

// This is the same as the NM_Simplex::run function, but it is reproduced here to *visualize*
// the Simplex as it descends the surface. For a more compact way to write your NM_Simplex, see
Expand All @@ -103,60 +102,34 @@ int main()
std::chrono::steady_clock::time_point lastrender = std::chrono::steady_clock::now();
std::chrono::steady_clock::time_point lastoptstep = std::chrono::steady_clock::now();

// Now do the business
unsigned int lcount = 0;
// Now step until the algorithm is ready to finish
while (simp.state != sm::nm_simplex_state::ReadyToStop && !v.readyToFinish()) {

// Perform optimisation steps slowly
std::chrono::steady_clock::duration sinceoptstep = std::chrono::steady_clock::now() - lastoptstep;
if (std::chrono::duration_cast<std::chrono::milliseconds>(sinceoptstep).count() > 50) {
lcount++;
if (simp.state == sm::nm_simplex_state::NeedToComputeThenOrder) {
// 1. apply objective to each vertex
for (unsigned int i = 0; i <= simp.n; ++i) {
simp.values[i] = banana (simp.vertices[i][0], simp.vertices[i][1]);
}
simp.order();

} else if (simp.state == sm::nm_simplex_state::NeedToOrder) {
simp.order();

} else if (simp.state == sm::nm_simplex_state::NeedToComputeReflection) {
val = banana (simp.xr[0], simp.xr[1]);
simp.apply_reflection (val);

} else if (simp.state == sm::nm_simplex_state::NeedToComputeExpansion) {
val = banana (simp.xe[0], simp.xe[1]);
simp.apply_expansion (val);

} else if (simp.state == sm::nm_simplex_state::NeedToComputeContraction) {
val = banana (simp.xc[0], simp.xc[1]);
simp.apply_contraction (val);
}
// Step the NM simplex optimizer process once
simp.step();
lastoptstep = std::chrono::steady_clock::now();
}

// Visualise the triangle defined by simp.vertices
// Copy data out from NM_Simplex
// Visualize at about 60 Hz
std::chrono::steady_clock::duration sincerender = std::chrono::steady_clock::now() - lastrender;
if (std::chrono::duration_cast<std::chrono::milliseconds>(sincerender).count() > 17) { // 17 is about 60 Hz
// Copy data out from NM_Simplex to update the triangle visualization
for (unsigned int i = 0; i <= simp.n; ++i) {
tri_coords[i] = { simp.vertices[i][0], simp.vertices[i][1], 0.0 };
tri_values[i] = simp.values[i];
}
tfvp->reinit();

lastoptstep = std::chrono::steady_clock::now();
}

std::chrono::steady_clock::duration sincerender = std::chrono::steady_clock::now() - lastrender;
if (std::chrono::duration_cast<std::chrono::milliseconds>(sincerender).count() > 17) { // 17 is about 60 Hz
v.poll();
v.render();
lastrender = std::chrono::steady_clock::now();
}
}
std::vector<FLT> thebest = simp.best_vertex();
FLT bestval = simp.best_value();
std::cout << "FINISHED! lcount=" << lcount
<< ". Best approximation: (" << thebest[0] << "," << thebest[1]
<< ") has value " << bestval << std::endl;
std::cout << "Finished in " << simp.operation_count << " operations. Best approximation at: ("
<< thebest[0] << "," << thebest[1] << ") has value " << simp.best_value() << std::endl;

// Randomly set the next start position
v1 = { rng.get(), rng.get() };
Expand Down
Loading