Skip to content

Commit

Permalink
add new test to test callback functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdsharpe committed Jan 18, 2021
1 parent b339105 commit 98487ce
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
7 changes: 6 additions & 1 deletion aerosandbox/optimization/opti.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import casadi as cas
from typing import Union, List, Dict
from typing import Union, List, Dict, Callable
import numpy as np
import pytest
import json
Expand Down Expand Up @@ -260,6 +260,7 @@ def get_solution_dict_from_cache(self):
def solve(self,
parameter_mapping: Dict[cas.MX, float] = None,
max_iter: int = 3000,
callback: Callable = None, # A Callable that takes in the argument opti
solver: str = 'ipopt'
) -> cas.OptiSol:
"""
Expand Down Expand Up @@ -332,6 +333,10 @@ def solve(self,
s_opts["mu_strategy"] = "adaptive"
self.solver(solver, p_opts, s_opts) # Default to IPOPT solver

# Set the callback
if callback is not None:
self.callback(callback)

# Do the actual solve
sol = super().solve()

Expand Down
83 changes: 83 additions & 0 deletions aerosandbox/optimization/test_opti_hanging_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import aerosandbox as asb
from aerosandbox import cas
import pytest
import matplotlib.pyplot as plt

"""
Hanging Chain problem from https://web.casadi.org/blog/opti/
Next, we will visit the hanging chain problem. We consider
N point masses, connected by springs, hung from two fixed points at (-2, 1) and (2, 1), subject to gravity.
We seek the rest position of the system, obtained by minimizing the total energy in the system.
"""


def test_opti_hanging_chain_with_callback():
N = 40
m = 40 / N
D = 70 * N
g = 9.81
L = 1

opti = asb.Opti()

x = opti.variable(
n_vars=N,
init_guess=cas.linspace(-2, 2, N)
)
y = opti.variable(
n_vars=N,
init_guess=1
)

distance = cas.sqrt( # Distance from one node to the next
cas.diff(x) ** 2 + cas.diff(y) ** 2
)

potential_energy_spring = 0.5 * D * cas.sumsqr(distance - L / N)
potential_energy_gravity = g * m * cas.sum1(y)
potential_energy = potential_energy_spring + potential_energy_gravity

opti.minimize(potential_energy)

# Add end point constraints
opti.subject_to([
x[0] == -2,
y[0] == 1,
x[-1] == 2,
y[-1] == 1
])

# Add a ground constraint
opti.subject_to(
y >= cas.cos(0.1 * x) - 0.5
)

# Add a callback

def plot(iter: int):
plt.plot(
opti.debug.value(x),
opti.debug.value(y),
".-",
label=f"Iter {iter}"
)

fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8), dpi=200)

sol = opti.solve(
callback=plot
)

plt.legend()
plt.show()

assert sol.value(potential_energy) == pytest.approx(626.462, abs=1e-3)


if __name__ == '__main__':
pytest.main()

0 comments on commit 98487ce

Please sign in to comment.