# First, a few installations for packages that aren't natively on Colab

In [None]:

# Install OpenMDAO and dymos only when they cannot be imported. The `!` lines
# are IPython shell escapes, so this cell only runs inside a notebook.
# NOTE: an already-installed (possibly older) version is left untouched because
# the install is only triggered by ImportError.
try:
    import openmdao  # noqa: F401
except ImportError:
    # Development version of OpenMDAO required.
    !python -m pip install git+https://github.com/OpenMDAO/OpenMDAO

try:
    import dymos  # noqa: F401
except ImportError:
    # This particular branch of dymos is required for the timeseries
    # plots to display the S state correctly.
    !python -m pip install git+https://github.com/robfalck/dymos@seywald_doc

# Imports that we'll need

In [1]:
import scipy.sparse as sp
import numpy as np

import jax
import jax.numpy as jnp

import openmdao.api as om

import dymos as dm

# Turn on JAX's 64-bit mode (the "jax_enable_x64" flag) before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Jax Functions

These functions are defined with Jax's math libraries to allow for automatic differentiation.  Unlike "typical" dymos Equations of Motion, we need the derivatives of the Zermelo ODE with respect to its inputs to get the time derivative of the S matrix.

In [2]:
def f_zermelo(x1, x2, c, u):
    """
    Evaluate the Zermelo state rates.

    Parameters
    ----------
    x1 : array-like
        First state. It does not appear in any rate; it is accepted so the
        signature lists every state.
    x2 : array-like
        Second state; it scales the ``c`` contribution to the x1 rate.
    c : array-like
        State multiplying ``x2`` in the x1 rate. Its own rate is zero.
    u : array-like
        Control angle whose cosine and sine drive the x1 and x2 rates.

    Returns
    -------
    tuple
        ``(x1_dot, x2_dot, c_dot)``, where ``c_dot`` is all zeros with the
        shape and dtype of ``c``.
    """
    cos_u = jnp.cos(u)
    sin_u = jnp.sin(u)
    shear = c * x2
    return cos_u + shear, sin_u, jnp.zeros_like(c)

# Array-in / array-out adapter around f_zermelo for vectorized computation of outputs
def f_zermelo_wrapped(inputs):
    """
    Evaluate f_zermelo on a single stacked input array.

    ``inputs`` is transposed and unpacked into exactly four pieces
    (x1, x2, c, u), so its last axis must have length 4. The three
    returned rates are stacked as columns of one array.
    """
    pos_1, pos_2, coeff, ctrl = inputs.T
    rates = f_zermelo(pos_1, pos_2, coeff, ctrl)
    # column_stack puts each rate in its own column, which also handles batched outputs
    return jnp.column_stack(rates)

# Compute the Jacobian of the wrapped function (outputs wrt the stacked inputs).
# jax.vmap maps the per-row jacobian over axis 0 of its single argument, so calling
# jac_zermelo on an (nn, 4) array yields one jacobian per row; callers reshape the result.
jac_zermelo = jax.vmap(jax.jacobian(f_zermelo_wrapped), in_axes=(0,))  # Vectorized over the batch

Next is the definition of the ZermeloODE component for dymos.

Dymos is a pseudospectral optimal control tool built in OpenMDAO.
OpenMDAO allows the user to build complex systems using function-like blocks (Components) which can be arranged into more complex systems (Groups).

OpenMDAO allows for complex implicit behavior and will assemble the total derivatives correctly using the Implicit Function Theorem through a generalized framework called the Modular Architecture for Unified Derivatives.

Typically, components must define the partial derivatives of their outputs wrt their inputs, either by finite-difference/complex-step approximation or with analytic (user-defined) derivatives. Analytic derivatives can be provided as sparse partials in a `compute_partials` method, or as JVP/VJPs in a `compute_jacvec_product` method.

The Desensitized Optimal Control functionality is a bit unique in that we need the derivatives of the ODE to generate rates for the sensitivity matrix. OpenMDAO does not currently do second derivatives.

Fortunately, we can use machine learning tools like Jax or PyTorch to get code that can be algorithmically/automatically differentiated.

In OpenMDAO (and you need a recent version to do this), `JaxExplicitComponent` allows you to define only the "primal" calculation. It will leverage Jax under the hood to do the compute partials for you.

Note that we still have a setup-partials method to provide the sparsity pattern of the derivatives (we can't do this automatically, yet).

So here our `compute_primal` method computes the rates of `x1`, `x2`, and `c` using the `f_zermelo` function.

Then we ask Jax to compute the derivatives
of that function wrt the states and controls, which we use to assemble $\dot{S}$.



In [3]:
class ZermeloODE(om.JaxExplicitComponent):
    """
    Dymos ODE component for the Zermelo problem augmented with the matrix state S.

    The outputs are the rates of x1, x2 and c (from f_zermelo) together with
    S_dot, which is assembled from the jacobian of f_zermelo with respect to
    the states and the control u, the matrix S, and the input K.
    """

    def initialize(self):
        """
        All Dymos ODE systems are required to have an option "num_nodes",
        which is the number of points at which the ODE is simultaneously evaluated.

        This will be set by the Phase during setup once the transcription details are known.

        A boolean "matrix_free" option (default False) is declared as well; setup
        copies it onto ``self.matrix_free``.
        """
        self.options.declare('num_nodes', types=(int,))
        self.options.declare('matrix_free', types=(bool,), default=False)

    def setup(self):
        """
        In setup, we add inputs and outputs.

        The first dimension is assumed to pertain to the index of the node.

        An input that's a scalar at each node should have a shape of
        (num_nodes, 1) or just (num_nodes,).

        For vectors or matrices, it's just the shape of the matrix at each
        node prepended with num_nodes.

        We provide units for the scalars, but OpenMDAO doesn't do unit conversion on an index-by-index basis,
        so we just assume that no unit conversion should be done for the S matrix and K vector.
        """
        nn = self.options['num_nodes']

        self.matrix_free = self.options['matrix_free']

        # ODE inputs
        # NOTE: compute_primal takes these in the same order (x1, x2, c, u, S, K).
        self.add_input('x1', shape=(nn,), units='m')
        self.add_input('x2', shape=(nn,), units='m')
        self.add_input('c', shape=(nn,), units='1/s')
        self.add_input('u', shape=(nn,), units='rad')
        self.add_input('S', shape=(nn, 3, 3), units=None)
        self.add_input('K', shape=(nn, 1, 3), units=None)

        # State rates
        # The 'dymos.state_rate_source:<state>' tags name the state each output is the rate of.
        self.add_output('x1_dot', shape=(nn,), units='m/s', tags=['dymos.state_rate_source:x1'])
        self.add_output('x2_dot', shape=(nn,), units='m/s', tags=['dymos.state_rate_source:x2'])
        self.add_output('c_dot', shape=(nn,), units='1/s**2', tags=['dymos.state_rate_source:c'])
        self.add_output('S_dot', shape=(nn, 3, 3), units='1/s', tags=['dymos.state_rate_source:S'])

    # Because our compute_primal output depends on static variables, in this case
    # self.options['num_nodes'], we must define a get_self_statics method. This method must
    # return a tuple of all static variables. Their order in the tuple doesn't matter.  If your
    # component happens to have discrete inputs, do NOT return them here. Discrete inputs are passed
    # into the compute_primal function individually, after the continuous variables.
    def get_self_statics(self):
        # return value must be hashable
        return self.options['num_nodes'],

    def setup_partials(self):
        """
        Declare the sparsity pattern of the partial derivatives.

        Only the structurally nonzero partials are declared: x1_dot depends on
        x2, c and u; x2_dot depends on u; c_dot is identically zero and gets
        no partials.
        """
        nn = self.options['num_nodes']
        ar = jnp.arange(nn, dtype=int)
        # Scalar-per-node outputs wrt scalar-per-node inputs: diagonal (node i depends only on node i).
        self.declare_partials('x1_dot', 'x2', rows=ar, cols=ar)
        self.declare_partials('x1_dot', 'c', rows=ar, cols=ar)
        self.declare_partials('x1_dot', 'u', rows=ar, cols=ar)
        self.declare_partials('x2_dot', 'u', rows=ar, cols=ar)

        # S_dot is 3x3 at each node, so we expect 9 scalar nonzeros at each node for each "scalar at each node" input
        rs = jnp.arange(nn * 3 * 3, dtype=int)
        cs = jnp.repeat(jnp.arange(nn, dtype=int), 3 * 3)
        self.declare_partials('S_dot', ['x1', 'x2', 'c', 'u'], rows=rs, cols=cs)

        # For S_dot wrt S, we conservatively have a block diagonal of nn 9x9 blocks.
        rs, cs = sp.block_diag(nn * [np.ones((9, 9))]).nonzero()
        self.declare_partials('S_dot', 'S', rows=rs, cols=cs)

        # For S_dot wrt K, we conservatively have a block diagonal of nn 9x3 blocks
        # (9 entries of S_dot by 3 entries of K at each node).
        rs, cs = sp.block_diag(nn * [np.ones((9, 3))]).nonzero()
        self.declare_partials('S_dot', 'K', rows=rs, cols=cs)

    def compute_primal(self, x1, x2, c, u, S, K):
        """
        This method does the "primal" computation in jax, and then OpenMDAO may
        differentiate it under-the-hood in order to get the partial derivatives
        through the component.

        Because we're using Jax's AD to get the derivatives here, everything within
        this method, and those functions it calls, need to be Jax-composed functions.

        Returns the tuple (x1_dot, x2_dot, c_dot, S_dot), matching the order in
        which the outputs are added in setup.
        """
        nn = self.options['num_nodes']
        # Stack the scalar-per-node inputs into an (nn, 4) array for the jacobian evaluation.
        vec_inputs = jnp.column_stack([x1, x2, c, u])

        # vec_outputs = f_zermelo_wrapped(vec_inputs)
        # x1_dot, x2_dot, c_dot = vec_outputs.T
        x1_dot, x2_dot, c_dot = f_zermelo(x1, x2, c, u)

        # The jacobian df_dxu is a 3x4 jacobian matrix at each node.
        df_dxu = jac_zermelo(vec_inputs).reshape((nn, 3, 4))

        df_dx = df_dxu[:, :, :-1] # Extract the sensitivities wrt x: shape (nn, 3, 3).
        df_du = df_dxu[:, :, -1:] # Extract the sensitivities wrt u: shape (nn, 3, 1).

        # Just matmult to get a matrix-matrix product at each node:
        # (nn, 3, 3) @ (nn, 3, 3) + (nn, 3, 1) @ (nn, 1, 3) -> (nn, 3, 3)
        S_dot = jnp.matmul(df_dx, S) + jnp.matmul(df_du, K)

        return x1_dot, x2_dot, c_dot, S_dot

In [4]:
def solve_zermelo_open_loop(matrix_free=False, num_segments=10):
    """
    Build and solve the open-loop Zermelo problem with dymos.

    Parameters
    ----------
    matrix_free : bool
        Forwarded to ZermeloODE through the phase's ``ode_init_kwargs``.
    num_segments : int
        Number of segments given to the Radau transcription (order 3).

    Returns
    -------
    om.Problem
        The problem after ``dm.run_problem`` has been called on it
        (with simulation, plots and up to 10 grid refinement iterations).
    """

    # Create a standard OpenMDAO problem.
    p = om.Problem(name='zermelo_open_loop')

    # Trajectory is a special OpenMDAO Group defined by dymos.
    traj = dm.Trajectory()

    # The transcription, which defines how to convert the continuous optimal control
    # problem into a discrete NLP problem.
    # This contains things like information regarding the grid segmentation, and
    # defines what specific OpenMDAO systems are needed for the numerics.
    tx = dm.Radau(num_segments=num_segments, order=3)

    # Add a Phase to the trajectory.
    # Phase is also a special OpenMDAO Group defined by dymos which combines
    # the ODE system and the transcription.
    phase = traj.add_phase('phase', dm.Phase(ode_class=ZermeloODE,
                                             ode_init_kwargs={'matrix_free': matrix_free},
                                             transcription=tx))

    # Set the time options in the phase. (Whether the initial value and final value are fixed, the time units)
    phase.set_time_options(fix_initial=True, fix_duration=True, units='s')

    # The options for each of our state variables.
    # OpenMDAO does unit conversions internally.  The units defined here are the units
    # in which the design variables of the optimizer are defined.
    phase.set_state_options('x1', fix_initial=True, fix_final=False, units='m')
    phase.set_state_options('x2', fix_initial=True, fix_final=True, units='m')
    phase.set_state_options('c', fix_initial=True, fix_final=False, units='1/s')
    phase.set_state_options('S', fix_initial=True, fix_final=False, units=None)

    # Tell dymos that 'u' is a dynamic control variable. Control variable discretization node
    # values are design variables by default (option `opt=True`).
    # 'K' is added the same way here, so it is also left at the default `opt` setting.
    phase.add_control('u', units='rad')
    phase.add_control('K', units=None)

    # Add the trajectory group to our OpenMDAO model.
    p.model.add_subsystem('traj', traj)

    # Add an optimization driver to the problem.
    # For this problem, ScipyOptimizeDriver (which uses scipy SLSQP by default) is fine.
    p.driver = om.ScipyOptimizeDriver()
    # Alternative driver configuration (IPOPT via pyOptSparse), currently disabled:
    # p.driver = om.pyOptSparseDriver(optimizer='IPOPT')
    # p.driver.opt_settings['print_level'] = 5
    # p.driver.opt_settings['mu_init'] = 1e-3
    # # p.driver.opt_settings['max_iter'] = 500
    # # p.driver.opt_settings['acceptable_tol'] = 1e-5
    # # p.driver.opt_settings['constr_viol_tol'] = 1e-6
    # # p.driver.opt_settings['compl_inf_tol'] = 1e-6
    # p.driver.opt_settings['tol'] = 1e-5
    # p.driver.opt_settings['nlp_scaling_method'] = 'gradient-based'  # for faster convergence
    # # p.driver.opt_settings['alpha_for_y'] = 'safer-min-dual-infeas'
    # p.driver.opt_settings['mu_strategy'] = 'monotone'
    # p.driver.opt_settings['bound_mult_init_method'] = 'mu-based'
    
    # Use an automated graph "coloring" algorithm to determine how to most efficiently
    # compute the total derivatives for the optimizer.
    p.driver.declare_coloring()

    # Tell the phase that x1 is our objective, and that we are maximizing it.
    # ref is the value that the optimizer sees as 1.0.
    # A negative sign means we are minimizing the negative.
    phase.add_objective('x1', ref=-1.0)

    # Call setup, which is a bit analogous to "compiling" the model in OpenMDAO.
    # OpenMDAO will determine all connections for data in the model, allocate memory, etc.
    p.setup()

    # The set_***_val methods on dymos phases allow us to provide values
    # for times, states, controls, and parameters.
    # States and controls are interpolated, linearly by default, from the first value at the start
    # of the phase, to the second value at the end of the phase.
    # If we fixed the initial or final values, then these are the initial or final values that will
    # be set and not changed while solving the problem.
    phase.set_time_val(initial=0.0, duration=1.0)
    phase.set_state_val('x1', vals=[0.0, 1])
    phase.set_state_val('x2', vals=[0, 0])
    phase.set_state_val('c', vals=[10, 10])

    # S is initially an identity matrix at each node. We interpolate this to be constant throughout the phase
    # as the initial guess.
    S0 = jnp.eye(3)
    phase.set_state_val('S', [S0, S0])

    # K is interpolated as zero across the phase as the initial guess.
    K = jnp.zeros((3,))
    phase.set_control_val('K', vals=[K, K])

    # Control u is interpolated as 0 across the phase as the initial guess.
    phase.set_control_val('u', vals=[0, 0])

    # The run_problem function in dymos runs an OpenMDAO problem that involves dymos.
    # The simulate argument means that we will explicitly integrate the final control solution
    # as an IVP so that we can see how the implicit solution compares to the explicit simulation.
    # Make plots will automatically generate plots in zermelo_doc_out/reports/traj_results.
    # NOTE(review): the problem above is named 'zermelo_open_loop', so the output directory
    # name may differ from 'zermelo_doc_out' - confirm.
    dm.run_problem(p, simulate=True, make_plots=True, refine_iteration_limit=10)

    return p

In [5]:
# Solve the open-loop problem starting from a 5-segment grid; later cells read results from zermelo_prob.
zermelo_prob = solve_zermelo_open_loop(num_segments=5, matrix_free=False)


--- Constraint Report [traj] ---
    --- phase ---
        None





Full total jacobian for problem 'zermelo_open_loop' was computed 3 times, taking 0.9715524578932673 seconds.
Total jacobian shape: (213, 239) 


Jacobian shape: (213, 239)  (5.55% nonzero)
FWD solves: 45   REV solves: 0
Total colors vs. total size: 45 vs 239  (81.17% improvement)

Sparsity computed using tolerance: 1e-25
Time to compute sparsity:   0.9716 sec
Time to compute coloring:   0.1239 sec
Memory to compute coloring:   0.7656 MB
Coloring created on: 2024-12-14 07:51:35
Optimization terminated successfully    (Exit mode 0)
            Current function value: -2.777269921301414
            Iterations: 31
            Function evaluations: 33
            Gradient evaluations: 31
Optimization Complete
-----------------------------------


          Grid Refinement - Iteration 1           
--------------------------------------------------
    Phase: traj.phases.phase
        Refinement Options:
            Allow Refinement = True
            Tolerance = 0.0001
            Min Order 



Full total jacobian for problem 'zermelo_open_loop' was computed 3 times, taking 1.2881130001042038 seconds.
Total jacobian shape: (393, 479) 


Jacobian shape: (393, 479)  (3.62% nonzero)
FWD solves: 0   REV solves: 78
Total colors vs. total size: 78 vs 393  (80.15% improvement)

Sparsity computed using tolerance: 1e-25
Time to compute sparsity:   1.2881 sec
Time to compute coloring:   0.3144 sec
Memory to compute coloring:   1.5312 MB
Coloring created on: 2024-12-14 07:51:40
Optimization terminated successfully    (Exit mode 0)
            Current function value: -2.78073348754754
            Iterations: 20
            Function evaluations: 21
            Gradient evaluations: 20
Optimization Complete
-----------------------------------


          Grid Refinement - Iteration 2           
--------------------------------------------------
    Phase: traj.phases.phase
        Refinement Options:
            Allow Refinement = True
            Tolerance = 0.0001
            Min Order =




Simulating trajectory traj
Done simulating trajectory traj


In [None]:
# Pull the timeseries outputs of the open-loop solution into module-level names
# that the plotting cell below uses.
x1 = zermelo_prob.get_val('traj.phase.timeseries.x1')
x2 = zermelo_prob.get_val('traj.phase.timeseries.x2')
t = zermelo_prob.get_val('traj.phase.timeseries.time')
c = zermelo_prob.get_val('traj.phase.timeseries.c')
u = zermelo_prob.get_val('traj.phase.timeseries.u')
S = zermelo_prob.get_val('traj.phase.timeseries.S')

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# 13 stacked axes sharing the time axis: 4 for x1, x2, u, c plus 9 for the
# entries of the 3x3 S matrix.
fig, axes = plt.subplots(13, 1, figsize=(10, 15), sharex=True)
axes[0].plot(t, x1, 'o')
axes[0].grid()
axes[0].set_ylabel('x1')

axes[1].plot(t, x2, 'o')
axes[1].grid()
axes[1].set_ylabel('x2')

axes[2].plot(t, u, 'o')
axes[2].grid()
axes[2].set_ylabel('u')

axes[3].plot(t, c, 'o')
axes[3].grid()
axes[3].set_ylabel('c')

# The S entries fill axes 4..12 in row-major order.
a_i = 4
for s_i in range(3):
    for s_j in range(3):
        axes[a_i].plot(t, S[:, s_i, s_j], 'o')
        axes[a_i].grid()
        axes[a_i].set_ylabel(f'S_[{s_i},{s_j}]')
        a_i += 1
plt.tight_layout()
plt.show()

# Last expression of the cell: displays S at the final timeseries node.
S[-1, ...]


Term $S$ represents the sensitivity of the states x at any given point wrt the initial state.

\begin{align}
S &= \frac{dx}{dx_0}
\end{align}

Therefore at the final time,

\begin{align}
S(t_f) &= \frac{dx_f}{dx_0}
\end{align}

Inverting S(t), we get

\begin{align}
S^{-1}(t) &= \frac{dx_0}{dx}
\end{align}

Therefore, we can obtain the sensitivities due to perturbations
along the trajectory (not just the initial states), using

\begin{align}
S(t_f)S^{-1}(t) &= \frac{dx_f}{dx_0} \frac{dx_0}{dx} = \frac{dx_f}{dx}
\end{align}

We can integrate this term to obtain a total "drift" in a state due to uncertainties along the way.

Collocation techniques are particularly useful here because we always have an estimate for $S(t_f)$.

Furthermore, we don't need to treat $S(t_f)S^{-1}(t)$ as an integrated state. We have the S matrix at all points, which we can invert at all points, multiply by S at the final point, and then use an explicit quadrature to obtain its integral.

And for efficiency, we can determine the sensitivity $\frac{dx_f}{dx}$ using a linear solve rather than actually inverting the $S$ matrix at each node in the trajectory.

\begin{align}
    S(t_f)S^{-1}(t) &= \left[ \frac{dx_f}{dx} \right] \\
    S(t_f) &= \left[ \frac{dx_f}{dx} \right] S(t)
\end{align}

So therefore at each node $i$ in the trajectory we have the linear solve

\begin{align}
    \left[ \frac{dx_f}{dx_i} \right]^T = \mathrm{solve}(S_i^T, S_f^T)
\end{align}

# More Jax Functions for the composite objective.

In [None]:
# Need to update dymos to save the node weights in GridData
# NOTE(review): these weights are read from the open-loop problem's transcription grid.
# Any quadrature that uses them must be evaluated on a grid with the same nodes;
# the open-loop solve runs grid refinement, so confirm the grids actually match.
w = zermelo_prob.model.traj.phases.phase.options['transcription'].grid_data.node_weight

def quadrature(f, dt_dstau, weights=None):
    """
    Integrate nodal values with a weighted sum over the node (first) axis.

    Computes ``sum_i dt_dstau[i] * weights[i] * f[i, ...]``.

    Parameters
    ----------
    f : array
        Integrand values; axis 0 indexes the nodes.
    dt_dstau : array
        Per-node scale factor; reshaped to broadcast against ``f``.
    weights : array, optional
        Per-node quadrature weights. When omitted, the module-level ``w`` is
        used, which preserves the original behavior. Passing the weights
        explicitly avoids silently using weights from a different grid.

    Returns
    -------
    array
        ``f`` summed over axis 0, i.e. with shape ``f.shape[1:]``.
    """
    node_weights = w if weights is None else weights
    # Append singleton axes so the per-node factors broadcast over f's trailing dims.
    bcast_shape = (-1,) + (1,) * (len(f.shape) - 1)
    _w = jnp.reshape(node_weights, bcast_shape)
    _dt_dstau = jnp.reshape(dt_dstau, bcast_shape)
    return jnp.sum(_dt_dstau * _w * f, axis=0)

In [None]:
def solve_right(A, B):
    """
    Solve ``X[i] @ A[i] = B`` for every matrix ``A[i]`` in a batch.

    Each system is transposed into ``A[i].T @ X[i].T = B.T``, solved with
    ``jnp.linalg.solve``, and the solution is transposed back.

    Parameters:
    A: jnp.ndarray, shape (batch_size, m, m)
        The square matrices on the right side of the unknown.
    B: jnp.ndarray, shape (m, m)
        The single right-hand side shared by every system in the batch.

    Returns:
    X: jnp.ndarray, shape (batch_size, m, m)
        The solutions, one per batch entry.
    """
    rhs_t = jnp.transpose(B)

    def _solve_one(mat):
        # Left-solve the transposed system, then undo the transpose.
        return jnp.linalg.solve(mat.T, rhs_t).T

    # Only A carries a batch axis; B is closed over and shared across the batch.
    return jax.vmap(_solve_one)(A)

In [None]:
class ObjectiveComp(om.JaxExplicitComponent):
    """
    Composite objective evaluated downstream of the trajectory.

    Computes J = -x1_f + J1 + J2, where J1 is a q-weighted quadrature of squared
    entries of S_f S^{-1} and J2 is an r-weighted quadrature of the squared
    entries of K (see compute_primal).
    """

    # def initialize(self):
    #     """
    #     All Dymos ODE systems are required to have an option "num_nodes",
    #     which is the number of points at which the ODE is simultaneously evaluated.

    #     This will be set by the Phase during setup once the transcription details are known.
    #     """
    #     self.options.declare('num_nodes', types=(int,))

    def setup(self):
        """
        In setup, we add inputs and outputs.

        S, K and dt_dstau take their shapes from whatever they are connected to
        (shape_by_conn), so this component has no num_nodes option. Their first
        dimension is assumed to pertain to the index of the node.

        q and r are scalar weights, x1_f has shape (1,), and the single output J
        is a scalar. No units are declared on any variable here.
        """
        # nn = self.options['num_nodes']

        # Inputs
        # NOTE: compute_primal takes these in the same order.
        self.add_input('q', val=0.0, units=None)
        self.add_input('r', val=1.0, units=None)
        self.add_input('S', shape_by_conn=True, units=None)
        self.add_input('K', shape_by_conn=True, units=None)
        self.add_input('x1_f', shape=(1,), units=None)
        self.add_input('dt_dstau', shape_by_conn=True, units=None)

        # Outputs
        self.add_output('J', val=1.0, units=None)

    # def get_self_statics(self):
    #     # return value must be hashable
    #     return self.options['num_nodes'], self.options['node_weight']

    def setup_partials(self):
        # J is a function of all inputs
        # J is a scalar so we don't need to specify the sparsity.
        self.declare_partials('J', ['*'])

    def compute_primal(self, q, r, S, K, x1_f, dt_dstau):
        """
        This method does the "primal" computation in jax, and then OpenMDAO may
        differentiate it under-the-hood in order to get the partial derivatives
        through the component.

        Because we're using Jax's AD to get the derivatives here, everything within
        this method, and those functions it calls, need to be Jax-composed functions.

        Returns J = -x1_f + J1 + J2.
        """
        # S at the last node, then X_i such that X_i @ S_i = S_f at every node i.
        S_f = S[-1, ...]
        dxf_dx = solve_right(S, S_f)

        # J1: q-weighted integral of the squared [0, 2] and [1, 2] entries of dxf_dx.
        # NOTE(review): with states ordered (x1, x2, c) these are presumably the
        # sensitivities of final x1 and x2 to c - confirm.
        J1 = quadrature(q * (dxf_dx[:, 0, 2][:, jnp.newaxis]**2 + dxf_dx[:, 1, 2][:, jnp.newaxis]**2),
                        dt_dstau)
        # J2: r-weighted integral of the sum of squares of the three entries of K at each node.
        J2 = quadrature(r * (K[:, 0, 0][:, jnp.newaxis]**2 + K[:, 0, 1][:, jnp.newaxis]**2 + K[:, 0, 2][:, jnp.newaxis]**2),
                        dt_dstau)

        return -x1_f + J1 + J2

# Now let's build the model again with our objective in the loop.

In [None]:
def solve_zermelo_closed_loop(q=0.0, r=1.0, run_driver=False, opt_K=False, K_lb=-10, K_ub=0):
    """
    Build and run the Zermelo problem with the composite objective in the loop.

    Parameters
    ----------
    q : float
        Value set on ``obj_comp.q`` (weight of the sensitivity term of J).
    r : float
        Value set on ``obj_comp.r`` (weight of the K term of J).
    run_driver : bool
        Forwarded to ``dm.run_problem``.
    opt_K : bool
        Passed as ``opt`` when adding control K.
    K_lb, K_ub : float
        Lower and upper bounds passed when adding control K.

    Returns
    -------
    om.Problem
        The problem after ``dm.run_problem`` has been called on it.
    """

    # Create a standard OpenMDAO problem.
    p = om.Problem()

    # Trajectory is a special OpenMDAO Group defined by dymos.
    traj = dm.Trajectory()

    # The transcription, which defines how to convert the continuous optimal control
    # problem into a discrete NLP problem.
    # This contains things like information regarding the grid segmentation, and
    # defines what specific OpenMDAO systems are needed for the numerics.
    tx = dm.Radau(num_segments=10, order=3)

    # Add a Phase to the trajectory.
    # Phase is also a special OpenMDAO Group defined by dymos which combines
    # the ODE system and the transcription.
    phase = traj.add_phase('phase', dm.Phase(ode_class=ZermeloODE, transcription=tx))

    # Set the time options in the phase. (Whether the initial value and final value are fixed, the time units)
    phase.set_time_options(fix_initial=True, fix_duration=True, units='s')

    # The options for each of our state variables.
    # OpenMDAO does unit conversions internally.  The units defined here are the units
    # in which the design variables of the optimizer are defined.
    phase.set_state_options('x1', fix_initial=True, fix_final=False, units='m')
    phase.set_state_options('x2', fix_initial=True, fix_final=True, units='m')
    phase.set_state_options('c', fix_initial=True, fix_final=False, units='1/s')
    phase.set_state_options('S', fix_initial=True, fix_final=False, units=None)

    # Tell dymos that 'u' is a dynamic control variable. Control variable discretization node
    # values are design variables by default (option `opt=True`).
    phase.add_control('u', units='rad')
    # K is only a design variable when opt_K is True; its bounds come from K_lb / K_ub.
    phase.add_control('K', units=None, opt=opt_K, lower=K_lb, upper=K_ub)

    # Add the trajectory group to our OpenMDAO model.
    p.model.add_subsystem('traj', traj)

    # Add the objective calculation
    obj_comp = p.model.add_subsystem('obj_comp', ObjectiveComp())

    p.model.connect('traj.phase.timeseries.S', 'obj_comp.S')
    p.model.connect('traj.phase.timeseries.K', 'obj_comp.K')
    # Only the last timeseries value of x1 feeds the objective.
    p.model.connect('traj.phase.timeseries.x1', 'obj_comp.x1_f', src_indices=om.slicer[-1, ...])
    p.model.connect('traj.phase.dt_dstau', 'obj_comp.dt_dstau')

    # Add an optimization driver to the problem.
    # For this problem, ScipyOptimizeDriver (which uses scipy SLSQP by default) is fine.
    p.driver = om.ScipyOptimizeDriver()
    # Use an automated graph "coloring" algorithm to determine how to most efficiently
    # compute the total derivatives for the optimizer.
    p.driver.declare_coloring()

    # Now we're no longer adding the objective to an output of the phase itself, but
    # a downstream calculation of the objective.
    # J = -x1_f + J1 + J2 already carries the minus sign on x1_f and is a cost to be
    # minimized. A negative ref would flip the sign again and make the driver maximize J
    # (minimize x1_f and maximize the penalty terms), so the ref must be positive.
    obj_comp.add_objective('J', ref=1.0)

    # Call setup, which is a bit analogous to "compiling" the model in OpenMDAO.
    # OpenMDAO will determine all connections for data in the model, allocate memory, etc.
    p.setup()

    # The set_***_val methods on dymos phases allow us to provide values
    # for times, states, controls, and parameters.
    # States and controls are interpolated, linearly by default, from the first value at the start
    # of the phase, to the second value at the end of the phase.
    # If we fixed the initial or final values, then these are the initial or final values that will
    # be set and not changed while solving the problem.
    phase.set_time_val(initial=0.0, duration=1.0)
    phase.set_state_val('x1', vals=[0.0, 1])
    phase.set_state_val('x2', vals=[0, 0])
    phase.set_state_val('c', vals=[10, 10])

    # S is initially an identity matrix at each node. We interpolate this to be constant throughout the phase
    # as the initial guess.
    S0 = jnp.eye(3)
    phase.set_state_val('S', [S0, S0])

    # K is interpolated as zero across the phase as the initial guess.
    K = jnp.zeros((3,))
    phase.set_control_val('K', vals=[K, K])

    # Control u is interpolated as 0 across the phase as the initial guess.
    phase.set_control_val('u', vals=[0, 0])

    # Weights of the sensitivity (q) and gain (r) terms in the composite objective.
    p.set_val('obj_comp.q', q)
    p.set_val('obj_comp.r', r)

    # The run_problem function in dymos runs an OpenMDAO problem that involves dymos.
    # The simulate argument means that we will explicitly integrate the final control solution
    # as an IVP so that we can see how the implicit solution compares to the explicit simulation.
    # Make plots will automatically generate plots in the problem's reports directory.
    dm.run_problem(p, run_driver=run_driver, simulate=True, make_plots=True)
    # p.run_model()

    return p

In [None]:
# NOTE(review): `restart` points at the open-loop solution database, but it is not passed
# to solve_zermelo_closed_loop (which takes no restart argument) in the visible cells -
# confirm whether the closed-loop solve was meant to be warm-started from it.
restart = zermelo_prob.get_outputs_dir() / 'dymos_solution.db'
# Solve with the gain K as a bounded design variable (q=0 disables the sensitivity term of J).
zermelo_closed_loop_prob = solve_zermelo_closed_loop(q=0, r=10, run_driver=True, opt_K=True, K_lb=-1.0E-1, K_ub=0.0)

In [None]:
# Display the reports directory of the closed-loop problem.
zermelo_closed_loop_prob.get_reports_dir()

In [None]:
# Pull the timeseries outputs of the closed-loop solution for plotting.
# These were previously read from `zermelo_prob` (the open-loop problem), which
# would make the plots below repeat the open-loop results.
x1 = zermelo_closed_loop_prob.get_val('traj.phase.timeseries.x1')
x2 = zermelo_closed_loop_prob.get_val('traj.phase.timeseries.x2')
t = zermelo_closed_loop_prob.get_val('traj.phase.timeseries.time')
c = zermelo_closed_loop_prob.get_val('traj.phase.timeseries.c')
u = zermelo_closed_loop_prob.get_val('traj.phase.timeseries.u')
S = zermelo_closed_loop_prob.get_val('traj.phase.timeseries.S')

In [None]:
# Same 13-axis layout as the earlier plotting cell: x1, x2, u, c, then the nine
# entries of the 3x3 S matrix in row-major order. Uses `plt` imported earlier
# in the notebook and the arrays extracted in the preceding cell.
fig, axes = plt.subplots(13, 1, figsize=(10, 15), sharex=True)
axes[0].plot(t, x1, 'o')
axes[0].grid()
axes[0].set_ylabel('x1')

axes[1].plot(t, x2, 'o')
axes[1].grid()
axes[1].set_ylabel('x2')

axes[2].plot(t, u, 'o')
axes[2].grid()
axes[2].set_ylabel('u')

axes[3].plot(t, c, 'o')
axes[3].grid()
axes[3].set_ylabel('c')

# The S entries fill axes 4..12.
a_i = 4
for s_i in range(3):
    for s_j in range(3):
        axes[a_i].plot(t, S[:, s_i, s_j], 'o')
        axes[a_i].grid()
        axes[a_i].set_ylabel(f'S_[{s_i},{s_j}]')
        a_i += 1
plt.tight_layout()
