In [1]:
try:
    from openmdao.utils.notebook_utils import notebook_mode  # noqa: F401
except ImportError:
    !python -m pip install openmdao[notebooks]

# Computing Post-Optimality Sensitivities of a Constrained Optimization Problem

Let's consider a problem in which we have an active bound and an active equality constraint.

\begin{align*}
\min_{\theta_0,\, \theta_1} \quad & f(\theta_0, \theta_1; \mathbf{p}) = (\theta_0 - p_0)^2 + \theta_0 \theta_1 + (\theta_1 + p_1)^2 - p_2 \\
\text{where} \quad \mathbf{p} &= \begin{bmatrix} 3 \\ 4 \\ 3 \end{bmatrix} \in \mathbb{R}^3 \\
\text{bounds:} \quad \theta_0 &\le 6 \\
\text{equality constraints:} \quad \theta_0 + \theta_1 &= 0
\end{align*}
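As a quick check on the problem statement, the optimum and its multipliers can be recovered directly from the KKT conditions. This is a sketch, not part of the original derivation. Because $f$ is quadratic and the constraints are linear, the KKT conditions for a guessed active set form a linear system. The helper `solve_kkt` and the variable names are illustrative, and the multipliers are added to $f$ in the Lagrangian, as in the JAX code below.

```python
import numpy as np

p = np.array([3.0, 4.0, 3.0])

def objective(theta, p):
    return (theta[0] - p[0])**2 + theta[0] * theta[1] + (theta[1] + p[1])**2 - p[2]

# f is quadratic, so grad f(theta) = H @ theta + c(p)
H = np.array([[2.0, 1.0],
              [1.0, 2.0]])
c = np.array([-2.0 * p[0], 2.0 * p[1]])

def solve_kkt(A, b):
    """Solve grad f + A^T lam = 0 and A @ theta = b for one assumed active set."""
    m = A.shape[0]
    K = np.block([[H, A.T], [A, np.zeros((m, m))]])
    sol = np.linalg.solve(K, np.concatenate([-c, b]))
    return sol[:2], sol[2:]

# Equality constraint alone: the minimizer lands at theta_0 = 7, violating theta_0 <= 6
theta_eq, _ = solve_kkt(np.array([[1.0, 1.0]]), np.array([0.0]))

# So the bound is active: impose theta_0 = 6 together with theta_0 + theta_1 = 0
A_active = np.array([[1.0, 0.0],   # Jacobian of the bound function theta_0
                     [1.0, 1.0]])  # Jacobian of theta_0 + theta_1
theta_opt, lam = solve_kkt(A_active, np.array([6.0, 0.0]))

print("equality only :", theta_eq)
print("theta*        :", theta_opt)
print("lambda (b, g) :", lam)
print("f*            :", objective(theta_opt, p))
```

This should reproduce $\theta^* = (6, -6)$, $f^* = -26$, $\lambda_{\theta} = 2$ and $\lambda_g = -2$. The positive $\lambda_{\theta}$ is consistent with an active upper bound.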

We want to know the sensitivities of the optimization outputs with respect to the optimization inputs.

In this context, consider the outputs of the optimization to be the objective and any other functions of interest, $f$.

The design variables $\theta$ and Lagrange multipliers $\lambda$ are effectively the implicit outputs of the optimization.

The _inputs_ to the optimization process consist of:
- any independent parameters, $\bar{p}$
- the bounding values of any **active** design variables, $\bar{b}_{\theta}$
- the bounding values of any **active** constraints, $\bar{b}_{g}$

In our case we have:

\begin{align*}
    \bar{p} &= \begin{bmatrix} p_0 \\ p_1 \\ p_2 \end{bmatrix} \\
    \bar{b}_{\theta} &= \begin{bmatrix} \theta_0^{ub} \end{bmatrix} \\
    \bar{b}_g &= \begin{bmatrix} g_0^{eq} \end{bmatrix}
\end{align*}

<!-- If active, we can treat the bound on $\theta_0$ as just another equality constraint.

\begin{align*}
  \bar{\mathcal{G}}(\bar{\theta}, \bar{p}) &= \begin{bmatrix}
                                   \theta_0 + \theta_1 \\
                                   \theta_0 - p_3
                                \end{bmatrix} = \bar 0
\end{align*}

**How will my system design ($\bar{\theta}^*$) respond to changes in my assumptions and system inputs ($\bar{p}$)?** -->

### The Universal Derivatives Equation

The UDE is:

\begin{align*}
  \left[ \frac{\partial \mathcal{R}}{\partial u} \right] \left[ \frac{d u}{d \mathcal{R}} \right]
  &=
  \left[ I \right]
  =
  \left[ \frac{\partial \mathcal{R}}{\partial u} \right]^T \left[ \frac{d u}{d \mathcal{R}} \right]^T
\end{align*}

Here, the residuals are the primal and dual residuals of the optimization process; they are written out explicitly in the next section.

## Applying the UDE to solving post-optimality sensitivities

In our case, the unknowns vector consists of
- the optimization parameters ($\bar{p}$)
- the bounding values of any active design variables ($\bar{b}_{\theta}$)
- the bounding values of any active constraints ($\bar{b}_{g}$)
- the design variables of the optimization ($\bar{\theta}$)
- the Lagrange multipliers associated with the active design variables ($\bar{\lambda}_{\theta}$)
- the Lagrange multipliers associated with the active constraints ($\bar{\lambda}_{g}$)
- the objective value **as well as** any other outputs for which we want the sensitivities ($f$)

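The total size of the unknowns vector is $N_p + N_{\theta} + 2N_{\lambda \theta} + 2N_{\lambda g} + N_{f}$, because each active bound and each active constraint contributes both a bounding value and a Lagrange multiplier ($N_{b\theta} = N_{\lambda \theta}$ and $N_{bg} = N_{\lambda g}$).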

\begin{align*}
  \hat{u} &=
  \begin{bmatrix}
    \bar{p} \\
    \bar{b}_{\theta} \\
    \bar{b}_{g} \\
    \bar{\theta} \\
    \bar{\lambda}_{\theta} \\
    \bar{\lambda}_{g} \\
    \bar{f}
  \end{bmatrix}
\end{align*}
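For the example problem, with three parameters, two design variables, one active bound, one active equality constraint, and the objective as the only output of interest, the size of $\hat{u}$ works out to:

\begin{align*}
N_p + N_{\theta} + 2N_{\lambda \theta} + 2N_{\lambda g} + N_{f} = 3 + 2 + 2(1) + 2(1) + 1 = 10
\end{align*}

This is the value of `total_size` computed in the JAX code below.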

Under the UDE, the corresponding residual equations for these unknowns are
- the implicit form of the parameter values
- the implicit form of the active design variable values
- the implicit form of the active constraint values
- the stationarity condition
- the active design variable residuals
- the active constraint residuals
- the implicit form of the explicit calculations of $f$

\begin{align*}
\bar{\mathcal{R}}
&=
\begin{bmatrix}
\bar{\mathcal{R}}_p \\
\bar{\mathcal{R}}_{b \theta} \\
\bar{\mathcal{R}}_{b g} \\
\bar{\mathcal{R}}_{\theta} \\
\bar{\mathcal{R}}_{\lambda \theta} \\
\bar{\mathcal{R}}_{\lambda g} \\
\bar{\mathcal{R}}_{f}
\end{bmatrix}
&=
\begin{bmatrix}
  \bar{p} - \check{p} \\[1.1ex]
  \hline \\
  \bar{b}_{\theta} - \check{b}_{\theta} \\[1.1ex]
  \hline \\
  \bar{b}_{g} - \check{b}_g \\[1.1ex]
  \hline \\
  \bar{r}_{\theta} - \left[ \nabla_{\bar{\theta}} \check{f} (\bar{\theta}, \bar{p}) + \nabla_{\bar{\theta}} \check{g}_{\mathcal{A}} (\bar{\theta}, \bar{p})^T \bar{\lambda}_g + \nabla_{\bar{\theta}} \check{\theta}_{\mathcal{A}} (\bar{\theta}, \bar{p})^T \bar{\lambda}_{\theta} \right] \\[1.1ex]
  \hline \\
  \bar{r}_{\lambda \theta} - \left[ \check{\theta}_{\mathcal{A}} \left( \bar{\theta} \right) - \bar{b}_{\theta} \right] \\[1.1ex]
  \hline \\
  \bar{r}_{\lambda g} - \left[ \check{g}_{\mathcal{A}} \left( \bar{\theta}, \bar{p} \right) - \bar{b}_g \right] \\[1.1ex]
  \hline \\
  \bar{f} - \check{f}\left(\bar{\theta}, \bar{p} \right) 
\end{bmatrix}
&= 
\begin{bmatrix}
  \bar{p} - \check{p} \\[1.1ex]
  \hline \\
  \bar{b}_{\theta} - \check{b}_{\theta} \\[1.1ex]
  \hline \\
  \bar{b}_{g} - \check{b}_g \\[1.1ex]
  \hline \\
  \bar{r}_{\theta} - \nabla_{\theta} \check{\mathcal{L}} \left( \bar{\theta}, \bar{p} \right) \\[1.1ex]
  \hline \\
  \bar{r}_{\lambda \theta} - \left[ \check{\theta}_{\mathcal{A}} \left( \bar{\theta} \right) - \bar{b}_{\theta} \right] \\[1.1ex]
  \hline \\
  \bar{r}_{\lambda g} - \left[ \check{g}_{\mathcal{A}} \left( \bar{\theta}, \bar{p} \right) - \bar{b}_g \right] \\[1.1ex]
  \hline \\
  \bar{f} - \check{f}\left(\bar{\theta}, \bar{p} \right) 
\end{bmatrix}
&=
\bar 0
\end{align*}
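As a sketch, these residuals can be evaluated numerically for the example. This assumes the perturbations $\bar{r}$ are zero, takes the nominal values $\check{p} = (3, 4, 3)$, $\check{b}_{\theta} = 6$ and $\check{b}_g = 0$ from the problem statement, and packs the unknowns in the order of $\hat{u}$. The function name `residuals` is illustrative. At the optimum every residual should vanish.

```python
import numpy as np

p_nom = np.array([3.0, 4.0, 3.0])   # \check{p}
b_theta_nom = np.array([6.0])       # \check{b}_theta
b_g_nom = np.array([0.0])           # \check{b}_g

def f(theta, p):
    return (theta[0] - p[0])**2 + theta[0] * theta[1] + (theta[1] + p[1])**2 - p[2]

def grad_f(theta, p):
    return np.array([2.0 * (theta[0] - p[0]) + theta[1],
                     theta[0] + 2.0 * (theta[1] + p[1])])

# Jacobians of the active bound function theta_0 and the active constraint theta_0 + theta_1
dtheta_A = np.array([[1.0, 0.0]])
dg_A = np.array([[1.0, 1.0]])

def residuals(u):
    """R(u) with r = 0; unknowns ordered [p, b_theta, b_g, theta, lam_theta, lam_g, f]."""
    p, b_th, b_g = u[0:3], u[3:4], u[4:5]
    theta, lam_th, lam_g, f_bar = u[5:7], u[7:8], u[8:9], u[9:10]
    R_p = p - p_nom
    R_bth = b_th - b_theta_nom
    R_bg = b_g - b_g_nom
    R_theta = -(grad_f(theta, p) + dg_A.T @ lam_g + dtheta_A.T @ lam_th)  # stationarity
    R_lth = -(dtheta_A @ theta - b_th)                                     # active bound
    R_lg = -(dg_A @ theta - b_g)                                           # active constraint
    R_f = f_bar - f(theta, p)                                              # objective
    return np.concatenate([R_p, R_bth, R_bg, R_theta, R_lth, R_lg, R_f])

u_opt = np.array([3.0, 4.0, 3.0, 6.0, 0.0, 6.0, -6.0, 2.0, -2.0, -26.0])
print(residuals(u_opt))
```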

To find the total derivatives that we seek ($\frac{d f^*}{d \bar{p}}$ and $\frac{d \bar{\theta}^*}{d \bar{p}}$, along with their counterparts with respect to the active bounding values), we need the partial derivatives $\frac{\partial \bar{\mathcal{R}}}{\partial \bar{u}}$.

In this case the optimizer has served as the nonlinear solver: it has computed the values of the unknowns $\bar{\theta}$, $\bar{\lambda}$, and $\bar{f}$ such that the residuals are satisfied.

\begin{align*}
\frac{\partial \bar{\mathcal{R}}}{\partial \bar{u}}
&=
\begin{bmatrix}
\frac{\partial \bar{\mathcal{R}_p}}{\partial \bar{p}} & 0 & 0 & 0 & 0 & 0 & 0 \\[1.1ex]
0 & \frac{\partial \bar{\mathcal{R}_{\bar{b} \theta}}}{\partial \bar{b}_{\theta}} & 0 & 0 & 0 & 0 & 0 \\[1.1ex]
0 & 0 & \frac{\partial \bar{\mathcal{R}_{\bar{b} g}}}{\partial \bar{b}_{g}} & 0 & 0 & 0 & 0 \\[1.1ex]
\frac{\partial \bar{\mathcal{R}_{\theta}}}{\partial \bar{p}} & 0 & 0 & \frac{\partial \bar{\mathcal{R}_{\theta}}}{\partial \bar{\theta}} & \frac{\partial \bar{\mathcal{R}_{\theta}}}{\partial \bar{\lambda_{\theta}}} & \frac{\partial \bar{\mathcal{R}_{\theta}}}{\partial \bar{\lambda_g}} & 0 \\[1.1ex]
0 & \frac{\partial \bar{\mathcal{R}_{\lambda \theta}}}{\partial \bar{b}_{\theta}} & 0 & \frac{\partial \bar{\mathcal{R}_{\lambda \theta}}}{\partial \bar{\theta}} & 0 & 0 & 0 \\[1.1ex]
\frac{\partial \bar{\mathcal{R}}_{\lambda g}}{\partial \bar{p}} & 0 & \frac{\partial \bar{\mathcal{R}}_{\lambda g}}{\partial \bar{b}_g} & \frac{\partial \bar{\mathcal{R}}_{\lambda g}}{\partial \bar{\theta}} & 0 & 0 & 0 \\[1.1ex]
\frac{\partial \bar{\mathcal{R}_f}}{\partial \bar{p}} & 0 & 0 & \frac{\partial \bar{\mathcal{R}_f}}{\partial \bar{\theta}} & 0 & 0 & \frac{\partial \bar{\mathcal{R}_f}}{\partial f}
\end{bmatrix}
&=
\begin{bmatrix}
\left[ I_p \right] & 0 & 0 & 0 & 0 & 0 & 0 \\[1.1ex]
0 & \left[ I_{b\theta} \right] & 0 & 0 & 0 & 0 & 0 \\[1.1ex]
0 & 0 & \left[ I_{bg} \right] & 0 & 0 & 0 & 0 \\[1.1ex]
-\frac{\partial \nabla_{\theta} \bar{\mathcal{L}}}{\partial \bar{p}} & 0 & 0 & - \nabla_{\theta}^2 \check{\mathcal{L}} & - \nabla_{\theta} \check{\theta}_{\mathcal{A}}^T & - \nabla_{\theta} \check{g}_{\mathcal{A}}^T & 0 \\[1.1ex]
0 & \left[ I_{b\theta} \right] & 0 & -\nabla_{\theta} \check{\theta}_{\mathcal{A}} & 0 & 0 & 0 \\[1.1ex]
-\frac{\partial \check{g}_{\mathcal{A}}}{\partial \bar{p}} & 0 & \left[ I_{bg} \right] & -\nabla_{\theta} \check{g}_{\mathcal{A}} & 0 & 0 & 0 \\[1.1ex]
-\frac{\partial \check{f}}{\partial \bar{p}} & 0 & 0 & -\frac{\partial \check{f}}{\partial \bar{\theta}} & 0 & 0 & \left[ I_f \right]
\end{bmatrix}
\end{align*}
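Evaluated at the optimum of the example ($\theta^* = (6, -6)$, $\bar{p} = (3, 4, 3)$), every block above is a small constant matrix. The sketch below assembles the $10 \times 10$ Jacobian block by block in the same layout, using hand-derived partials of $f$ and of the linear constraints (the variable names are illustrative), and checks that the matrix is nonsingular.

```python
import numpy as np

theta = np.array([6.0, -6.0])
p = np.array([3.0, 4.0, 3.0])

# Hand-derived partial derivatives at the optimum. The constraints are linear,
# so the Hessian of the Lagrangian equals the Hessian of f.
hess_L = np.array([[2.0, 1.0], [1.0, 2.0]])
d_gradL_dp = np.array([[-2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
dthetaA_dtheta = np.array([[1.0, 0.0]])
dgA_dtheta = np.array([[1.0, 1.0]])
dgA_dp = np.zeros((1, 3))
df_dp = np.array([[-2.0 * (theta[0] - p[0]), 2.0 * (theta[1] + p[1]), -1.0]])
df_dtheta = np.array([[2.0 * (theta[0] - p[0]) + theta[1],
                       theta[0] + 2.0 * (theta[1] + p[1])]])

def Z(rows, cols):
    return np.zeros((rows, cols))

I1 = np.eye(1)
# Block columns: p (3), b_theta (1), b_g (1), theta (2), lam_theta (1), lam_g (1), f (1)
dR_du = np.block([
    [np.eye(3), Z(3, 1), Z(3, 1), Z(3, 2), Z(3, 1), Z(3, 1), Z(3, 1)],
    [Z(1, 3), I1, Z(1, 1), Z(1, 2), Z(1, 1), Z(1, 1), Z(1, 1)],
    [Z(1, 3), Z(1, 1), I1, Z(1, 2), Z(1, 1), Z(1, 1), Z(1, 1)],
    [-d_gradL_dp, Z(2, 1), Z(2, 1), -hess_L, -dthetaA_dtheta.T, -dgA_dtheta.T, Z(2, 1)],
    [Z(1, 3), I1, Z(1, 1), -dthetaA_dtheta, Z(1, 1), Z(1, 1), Z(1, 1)],
    [-dgA_dp, Z(1, 1), I1, -dgA_dtheta, Z(1, 1), Z(1, 1), Z(1, 1)],
    [-df_dp, Z(1, 1), Z(1, 1), -df_dtheta, Z(1, 1), Z(1, 1), I1],
])

print(dR_du.shape, np.linalg.matrix_rank(dR_du))
```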

This nomenclature can be a bit confusing.

**The _partial_ derivatives of the post-optimality residuals are the _total_ derivatives of the analysis.**

For the stationarity residuals $\bar{\mathcal{R}}_{\theta}$, which already contain _total_ derivatives of the analysis (the objective and constraint gradients), this means the partial derivatives are _second_ derivatives of the analysis.
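For the example, the stationarity residual contains $\nabla_{\theta} f$ plus constraint terms that are linear in $\bar{\theta}$ and independent of $\bar{p}$, so its partials with respect to $\bar{\theta}$ and $\bar{p}$ are second derivatives of $f$. The sketch below estimates them by central differences of the hand-coded gradient. The helper names are illustrative, and finite differences are only a stand-in here; automatic differentiation would give the same matrices without truncation error.

```python
import numpy as np

def grad_f(theta, p):
    """Hand-coded gradient of f w.r.t. theta (a total derivative of the analysis)."""
    return np.array([2.0 * (theta[0] - p[0]) + theta[1],
                     theta[0] + 2.0 * (theta[1] + p[1])])

def central_jacobian(fun, x, h=1e-6):
    """Jacobian of fun at x by central differences, one column per input."""
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = h
        cols.append((fun(x + e) - fun(x - e)) / (2.0 * h))
    return np.column_stack(cols)

theta = np.array([6.0, -6.0])
p = np.array([3.0, 4.0, 3.0])

d2f_dtheta2 = central_jacobian(lambda t: grad_f(t, p), theta)    # Hessian block
d2f_dtheta_dp = central_jacobian(lambda q: grad_f(theta, q), p)  # mixed block
print(d2f_dtheta2)     # approximately [[2, 1], [1, 2]]
print(d2f_dtheta_dp)   # approximately [[-2, 0, 0], [0, 2, 0]]
```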

The corresponding total derivatives that we need to solve for are:

\begin{align}
\frac{d \bar{u}}{d \bar{\mathcal{R}}}
&=
\begin{bmatrix}
  \frac{d \bar{p}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{p}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{p}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{p}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{p}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{p}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{p}}{d \bar{\mathcal{R}_f}}
\\[1.1ex]
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{b}_{\theta}}{d \bar{\mathcal{R}_f}}
\\[1.1ex]
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{b}_{g}}{d \bar{\mathcal{R}_f}}
\\[1.1ex]
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_f}}
\\[1.1ex]
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_f}}
\\[1.1ex]
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_f}}
\\[1.1ex]
  \frac{d \bar{f}}{d \bar{\mathcal{R}_p}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{b\theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{bg}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_f}}
\end{bmatrix}
&=
\begin{bmatrix}
  \left[ I_p \right] &
  0 &
  0 &
  0 &
  0 &
  0 &
  0
\\[1.1ex]
  0 &
  \left[ I_{b\theta} \right] &
  0 &
  0 &
  0 &
  0 &
  0
\\[1.1ex]
  0 &
  0 &
  \left[ I_{bg} \right] &
  0 &
  0 &
  0 &
  0
\\[1.1ex]
  \mathbf{\frac{d \bar{\theta}}{d \bar{p}}} &
  \mathbf{\frac{d \bar{\theta}}{d \bar{b_{\theta}}}} &
  \mathbf{\frac{d \bar{\theta}}{d \bar{b_g}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  0
\\[1.1ex]
  \frac{d \bar{\lambda_{\theta}}}{d \bar{p}} &
  \frac{d \bar{\lambda_{\theta}}}{d \bar{b_{\theta}}} &
  \frac{d \bar{\lambda_{\theta}}}{d \bar{b_g}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  0
\\[1.1ex]
  \frac{d \bar{\lambda_{g}}}{d \bar{p}} &
  \frac{d \bar{\lambda_{g}}}{d \bar{b_{\theta}}} &
  \frac{d \bar{\lambda_{g}}}{d \bar{b_g}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda g}}} &
  0
\\[1.1ex]
  \mathbf{\frac{d \bar{f}}{d \bar{p}}} &
  \mathbf{\frac{d \bar{f}}{d \bar{b_{\theta}}}} &
  \mathbf{\frac{d \bar{f}}{d \bar{b_g}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \left[ I_f \right]
\end{bmatrix}
\end{align}

The sensitivities of the objective and the design variables with respect to the inputs of the optimization (the parameters and the active bounding values) are highlighted.

In this case, we can compute them with five linear solves of the forward system (one per input) or with three solves of the reverse system (one per design variable and output of interest).
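For this small example the highlighted entries can also be worked out by hand, which gives a reference for checking the numerical solves. With both the bound and the equality constraint active, the design variables are pinned by the bounding values alone, $\theta_0^* = \theta_0^{ub}$ and $\theta_1^* = g_0^{eq} - \theta_0^{ub}$, so they do not move with $\bar{p}$. Using $\partial f / \partial \bar{\theta} = (0, 2)$ at $\theta^* = (6, -6)$ and the multipliers $\lambda_{\theta} = 2$, $\lambda_g = -2$:

\begin{align*}
\frac{d \bar{\theta}}{d \bar{p}} &= 0, \qquad
\frac{d \bar{\theta}}{d \bar{b}_{\theta}} = \begin{bmatrix} 1 \\ -1 \end{bmatrix}, \qquad
\frac{d \bar{\theta}}{d \bar{b}_g} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} \\
\frac{d f}{d \bar{p}} &= \frac{\partial f}{\partial \bar{p}} + \frac{\partial f}{\partial \bar{\theta}} \frac{d \bar{\theta}}{d \bar{p}}
= \begin{bmatrix} -2(\theta_0 - p_0) & 2(\theta_1 + p_1) & -1 \end{bmatrix}
= \begin{bmatrix} -6 & -4 & -1 \end{bmatrix} \\
\frac{d f}{d \bar{b}_{\theta}} &= \frac{\partial f}{\partial \bar{\theta}} \frac{d \bar{\theta}}{d \bar{b}_{\theta}} = -2 = -\lambda_{\theta}, \qquad
\frac{d f}{d \bar{b}_g} = \frac{\partial f}{\partial \bar{\theta}} \frac{d \bar{\theta}}{d \bar{b}_g} = 2 = -\lambda_g
\end{align*}

The last line is the familiar shadow-price interpretation: the derivative of the optimal objective with respect to an active bounding value is the negative of its Lagrange multiplier.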

TODO: Need to explain how du/dRf becomes du/df.

The UDE for this case, in forward form, is

\begin{align*}
\begin{bmatrix}
\left[ I_p \right] & 0 & 0 & 0 & 0 & 0 & 0 \\[1.1ex]
0 & \left[ I_{b\theta} \right] & 0 & 0 & 0 & 0 & 0 \\[1.1ex]
0 & 0 & \left[ I_{bg} \right] & 0 & 0 & 0 & 0 \\[1.1ex]
-\frac{\partial \nabla_{\theta} \bar{\mathcal{L}}}{\partial \bar{p}} & 0 & 0 & - \nabla_{\theta}^2 \check{\mathcal{L}} & - \nabla_{\theta} \check{\theta}_{\mathcal{A}}^T & - \nabla_{\theta} \check{g}_{\mathcal{A}}^T & 0 \\[1.1ex]
0 & \left[ I_{b\theta} \right] & 0 & -\nabla_{\theta} \check{\theta}_{\mathcal{A}} & 0 & 0 & 0 \\[1.1ex]
-\frac{\partial \check{g}_{\mathcal{A}}}{\partial \bar{p}} & 0 & \left[ I_{bg} \right] & -\nabla_{\theta} \check{g}_{\mathcal{A}} & 0 & 0 & 0 \\[1.1ex]
-\frac{\partial \check{f}}{\partial \bar{p}} & 0 & 0 & -\frac{\partial \check{f}}{\partial \bar{\theta}} & 0 & 0 & \left[ I_f \right]
\end{bmatrix}
\begin{bmatrix}
  \left[ I_p \right] &
  0 &
  0 &
  0 &
  0 &
  0 &
  0
\\[1.1ex]
  0 &
  \left[ I_{b\theta} \right] &
  0 &
  0 &
  0 &
  0 &
  0
\\[1.1ex]
  0 &
  0 &
  \left[ I_{bg} \right] &
  0 &
  0 &
  0 &
  0
\\[1.1ex]
  \mathbf{\frac{d \bar{\theta}}{d \bar{p}}} &
  \mathbf{\frac{d \bar{\theta}}{d \bar{b_{\theta}}}} &
  \mathbf{\frac{d \bar{\theta}}{d \bar{b_g}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  0
\\[1.1ex]
  \frac{d \bar{\lambda_{\theta}}}{d \bar{p}} &
  \frac{d \bar{\lambda_{\theta}}}{d \bar{b_{\theta}}} &
  \frac{d \bar{\lambda_{\theta}}}{d \bar{b_g}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda g}}} &
  0
\\[1.1ex]
  \frac{d \bar{\lambda_{g}}}{d \bar{p}} &
  \frac{d \bar{\lambda_{g}}}{d \bar{b_{\theta}}} &
  \frac{d \bar{\lambda_{g}}}{d \bar{b_g}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda g}}} &
  0
\\[1.1ex]
  \mathbf{\frac{d \bar{f}}{d \bar{p}}} &
  \mathbf{\frac{d \bar{f}}{d \bar{b_{\theta}}}} &
  \mathbf{\frac{d \bar{f}}{d \bar{b_g}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda \theta}}} &
  \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda g}}} &
  \left[ I_f \right]
\end{bmatrix}
&=
\begin{bmatrix}
    \left[ I_p \right] & 0 & 0 & 0 & 0 & 0 & 0\\[1.1ex]
    0 & \left[ I_{b\theta} \right] & 0 & 0 & 0 & 0 & 0\\[1.1ex]
    0 & 0 & \left[ I_{bg} \right] & 0 & 0 & 0 & 0\\[1.1ex]
    0 & 0 & 0 & \left[ I_{\theta} \right] & 0 & 0 & 0\\[1.1ex]
    0 & 0 & 0 & 0 & \left[ I_{\lambda \theta} \right] & 0 & 0\\[1.1ex]
    0 & 0 & 0 & 0 & 0 & \left[ I_{\lambda g} \right] & 0\\[1.1ex]
    0 & 0 & 0 & 0 & 0 & 0 & \left[ I_f \right]
\end{bmatrix}
\end{align*}

As before, the sensitivities of the objective and the design variables with respect to the optimization inputs are highlighted.

The linear solve of this system can proceed one column of $\left[ \frac{d \bar{u}}{d \mathcal{R}} \right]$ at a time (the forward solve).

In this case we would need one solve for each column associated with an input, $N_p + N_{b\theta} + N_{bg}$ solves in total.

In our example optimization with three parameters, an active bound and an active equality constraint, this would be five solves.
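A dense sketch of the five forward solves for the example, assuming the hand-derived Jacobian entries at the optimum and the unknown ordering $[p_0, p_1, p_2, \theta_0^{ub}, g_0^{eq}, \theta_0, \theta_1, \lambda_{\theta}, \lambda_g, f]$. Each right-hand side is the identity column for one input, so each solve returns the matching column of $\frac{d \bar{u}}{d \bar{\mathcal{R}}}$. Calling `np.linalg.solve` on an explicit matrix is only practical for toy sizes; the JAX code below performs equivalent solves matrix-free with GMRES.

```python
import numpy as np

# dR/du at theta* = (6, -6), p = (3, 4, 3); ordering
# [p0, p1, p2, b_theta, b_g, theta0, theta1, lam_theta, lam_g, f]
J = np.eye(10)
J[5:7, 0:3] = -np.array([[-2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])  # -d(grad_theta L)/dp
J[5:7, 5:7] = -np.array([[2.0, 1.0], [1.0, 2.0]])             # -Hessian of L
J[5:7, 7] = -np.array([1.0, 0.0])                             # -(d theta_A/d theta)^T
J[5:7, 8] = -np.array([1.0, 1.0])                             # -(d g_A/d theta)^T
J[7, 7] = 0.0
J[7, 3] = 1.0                                                 # +I on b_theta
J[7, 5:7] = -np.array([1.0, 0.0])                             # -d theta_A/d theta
J[8, 8] = 0.0
J[8, 4] = 1.0                                                 # +I on b_g
J[8, 5:7] = -np.array([1.0, 1.0])                             # -d g_A/d theta
J[9, 0:3] = -np.array([-6.0, -4.0, -1.0])                     # -df/dp
J[9, 5:7] = -np.array([0.0, 2.0])                             # -df/dtheta

# Forward mode: one solve per input (p0, p1, p2, b_theta, b_g); each right-hand
# side is a column of the identity and each solution is a column of du/dR
inputs = range(5)
du_dR_cols = np.column_stack([np.linalg.solve(J, np.eye(10)[:, j]) for j in inputs])

dtheta_d_inputs = du_dR_cols[5:7, :]   # rows for theta0, theta1
df_d_inputs = du_dR_cols[9, :]         # row for f
print(dtheta_d_inputs)
print(df_d_inputs)
```

The $\theta$ rows are zero in the parameter columns because the two active constraints pin both design variables.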


Alternatively, we could transpose this system and solve it in reverse mode. Solving for one column of the transposed solution (one row of $\frac{d \bar{u}}{d \bar{\mathcal{R}}}$) at a time means solving once for each design variable and each output of interest. In our example optimization this would be three solves: two design variables plus the objective.

\begin{align*}
\begin{bmatrix}
\left[ I_p \right]                                    & 0                    & 0                    & -\frac{\partial \nabla_{\theta} \bar{\mathcal{L}}}{\partial \bar{p}}^T & 0                    & -\frac{\partial \check{g}_{\mathcal{A}}}{\partial \bar{p}}^T & -\frac{\partial \check{f}}{\partial \bar{p}}^T \\[1.1ex]
0                                                     & \left[ I_{b\theta} \right] & 0                    & 0                                                                & \left[ I_{b\theta} \right] & 0                                                    & 0 \\[1.1ex]
0                                                     & 0                    & \left[ I_{bg} \right]      & 0                                                                & 0                    & \left[ I_{bg} \right]                                      & 0 \\[1.1ex]
0                                                     & 0                    & 0                    & - \nabla_{\theta}^2 \check{\mathcal{L}}                        & -\nabla_{\theta} \check{\theta}_{\mathcal{A}}^T  & -\nabla_{\theta} \check{g}_{\mathcal{A}}^T              & -\frac{\partial \check{f}}{\partial \bar{\theta}}^T \\[1.1ex]
0                                                     & 0                    & 0                    & - \nabla_{\theta} \check{\theta}_{\mathcal{A}}                 & 0                    & 0                                                    & 0 \\[1.1ex]
0                                                     & 0                    & 0                    & - \nabla_{\theta} \check{g}_{\mathcal{A}}                      & 0                    & 0                                                    & 0 \\[1.1ex]
0                                                     & 0                    & 0                    & 0                                                                & 0                    & 0                                                    & \left[ I_f \right]
\end{bmatrix}
\begin{bmatrix}
\left[ I_p \right]                                    & 0                         & 0                       & \mathbf{\frac{d \bar{\theta}}{d \bar{p}}}^T                    & \frac{d \bar{\lambda_{\theta}}}{d \bar{p}}^T              & \frac{d \bar{\lambda_{g}}}{d \bar{p}}^T                 & \mathbf{\frac{d \bar{f}}{d \bar{p}}}^T \\[1.1ex]
0                                                     & \left[ I_{b\theta} \right]      & 0                       & \mathbf{\frac{d \bar{\theta}}{d \bar{b_{\theta}}}}^T           & \frac{d \bar{\lambda_{\theta}}}{d \bar{b_{\theta}}}^T     & \frac{d \bar{\lambda_{g}}}{d \bar{b_{\theta}}}^T        & \mathbf{\frac{d \bar{f}}{d \bar{b_{\theta}}}}^T \\[1.1ex]
0                                                     & 0                         & \left[ I_{bg} \right]         & \mathbf{\frac{d \bar{\theta}}{d \bar{b_g}}}^T                  & \frac{d \bar{\lambda_{\theta}}}{d \bar{b_g}}^T            & \frac{d \bar{\lambda_{g}}}{d \bar{b_g}}^T                & \mathbf{\frac{d \bar{f}}{d \bar{b_g}}}^T \\[1.1ex]
0                                                     & 0                         & 0                       & \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\theta}}}^T          & \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\theta}}}^T    & \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\theta}}}^T       & \frac{d \bar{f}}{d \bar{\mathcal{R}_{\theta}}}^T \\[1.1ex]
0                                                     & 0                         & 0                       & \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda \theta}}}^T  & \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda \theta}}}^T & \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda \theta}}}^T & \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda \theta}}}^T \\[1.1ex]
0                                                     & 0                         & 0                       & \frac{d \bar{\theta}}{d \bar{\mathcal{R}_{\lambda g}}}^T       & \frac{d \bar{\lambda \theta}}{d \bar{\mathcal{R}_{\lambda g}}}^T    & \frac{d \bar{\lambda g}}{d \bar{\mathcal{R}_{\lambda g}}}^T       & \frac{d \bar{f}}{d \bar{\mathcal{R}_{\lambda g}}}^T \\[1.1ex]
0                                                     & 0                         & 0                       & 0                                                                & 0                                                       & 0                                                         & \left[ I_f \right]
\end{bmatrix}
&=
\begin{bmatrix}
    \left[ I_p \right] & 0 & 0 & 0 & 0 & 0 & 0\\[1.1ex]
    0 & \left[ I_{b\theta} \right] & 0 & 0 & 0 & 0 & 0\\[1.1ex]
    0 & 0 & \left[ I_{bg} \right] & 0 & 0 & 0 & 0\\[1.1ex]
    0 & 0 & 0 & \left[ I_{\theta} \right] & 0 & 0 & 0\\[1.1ex]
    0 & 0 & 0 & 0 & \left[ I_{\lambda \theta} \right] & 0 & 0\\[1.1ex]
    0 & 0 & 0 & 0 & 0 & \left[ I_{\lambda g} \right] & 0\\[1.1ex]
    0 & 0 & 0 & 0 & 0 & 0 & \left[ I_f \right]
\end{bmatrix}
\end{align*}
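The reverse counterpart, under the same assumptions (hand-derived Jacobian at the optimum of the example, unknowns ordered $[p_0, p_1, p_2, \theta_0^{ub}, g_0^{eq}, \theta_0, \theta_1, \lambda_{\theta}, \lambda_g, f]$), is sketched below. It performs three solves with $\left[\frac{\partial \bar{\mathcal{R}}}{\partial \bar{u}}\right]^T$, one per output of interest ($\theta_0$, $\theta_1$, $f$), each returning one row of $\frac{d \bar{u}}{d \bar{\mathcal{R}}}$. The first five entries of each row are the highlighted sensitivities with respect to $\bar{p}$, $\bar{b}_{\theta}$ and $\bar{b}_g$, and they agree with what the forward solves produce.

```python
import numpy as np

# dR/du at theta* = (6, -6), p = (3, 4, 3); ordering
# [p0, p1, p2, b_theta, b_g, theta0, theta1, lam_theta, lam_g, f]
J = np.eye(10)
J[5:7, 0:3] = -np.array([[-2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])  # -d(grad_theta L)/dp
J[5:7, 5:7] = -np.array([[2.0, 1.0], [1.0, 2.0]])             # -Hessian of L
J[5:7, 7] = -np.array([1.0, 0.0])                             # -(d theta_A/d theta)^T
J[5:7, 8] = -np.array([1.0, 1.0])                             # -(d g_A/d theta)^T
J[7, 7] = 0.0
J[7, 3] = 1.0                                                 # +I on b_theta
J[7, 5:7] = -np.array([1.0, 0.0])                             # -d theta_A/d theta
J[8, 8] = 0.0
J[8, 4] = 1.0                                                 # +I on b_g
J[8, 5:7] = -np.array([1.0, 1.0])                             # -d g_A/d theta
J[9, 0:3] = -np.array([-6.0, -4.0, -1.0])                     # -df/dp
J[9, 5:7] = -np.array([0.0, 2.0])                             # -df/dtheta

# Reverse mode: one solve with J^T per output of interest (theta0, theta1, f);
# each solution is the corresponding row of du/dR
outputs = (5, 6, 9)
du_dR_rows = np.vstack([np.linalg.solve(J.T, np.eye(10)[:, i]) for i in outputs])

# First five columns: sensitivities w.r.t. [p0, p1, p2, b_theta, b_g]
sens = du_dR_rows[:, :5]
print(sens)
```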


## The All-JVP Way

In [68]:
"""
Post-optimality sensitivity analysis using JAX and the UDE approach
with Jacobian-Vector Products (JVPs) - forward mode sensitivity computation.
"""

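import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse.linalg import LinearOperator, gmres

# JAX defaults to float32; double precision lets the JVP-based matrix-vector
# products keep up with the tight GMRES tolerance (rtol=1e-10) used below.
jax.config.update("jax_enable_x64", True)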

# Define the optimization problem functions
def f(Θ, p):
    """Objective function"""
    f_val = (Θ[0] - p[0])**2 + Θ[0] * Θ[1] + (Θ[1] + p[1])**2 - p[2]
    return jnp.array([f_val])

def Θ_active(Θ):
    """Active bound constraint on Θ[0]"""
    return jnp.array([Θ[0]])

def g_active(Θ, p):
    """Active equality constraint"""
    return jnp.array([Θ[0] + Θ[1]])

def compute_sensitivities_jax():
    """
    Compute post-optimality sensitivities using JAX and the UDE approach.
    Uses Jacobian-Vector Products (JVPs) for forward mode sensitivity computation.
    Solves: ∂R/∂u @ (du/dp) = -∂R/∂p
    """

    # Known optimal solution
    Θ_opt = jnp.array([6.0, -6.0])
    p_opt = jnp.array([3.0, 4.0, 3.0])
    b_theta_opt = jnp.array([6.0])
    b_g_opt = jnp.array([0.0])
    λ_theta_opt = jnp.array([2.0])
    λ_g_opt = jnp.array([-2.0])

    # Problem dimensions
    n_p = 3  # parameters
    n_btheta = 1  # active bounds on design vars
    n_bg = 1  # active constraint bounds
    n_theta = 2  # design variables
    n_lambda_theta = 1  # multipliers for active bounds
    n_lambda_g = 1  # multipliers for active constraints
    n_f = 1  # original objective

    total_size = n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g + n_f

    # Define Lagrangian and its gradient
    def lagrangian(Θ, p, λ_theta, λ_g):
        """Lagrangian function"""
        return f(Θ, p)[0] + λ_theta @ Θ_active(Θ) + λ_g @ g_active(Θ, p)

    def lagrangian_grad(Θ, p, λ_theta, λ_g):
        """Gradient of Lagrangian w.r.t. Θ"""
        return jax.grad(lagrangian, argnums=0)(Θ, p, λ_theta, λ_g)

    # Hessian-vector product of the Lagrangian w.r.t. Θ, computed exactly as a JVP
    # of the Lagrangian gradient (forward-over-reverse differentiation). A central
    # finite difference with h=1e-8 fails in JAX's default float32: Θ_opt + h*v
    # rounds back to Θ_opt, so the estimated product collapses to zero.
    def hessian_vector_product(v_theta):
        """
        Compute H @ v where H is the Hessian of the Lagrangian w.r.t. Θ
        Uses a JVP of the gradient: H @ v = d/dε ∇L(Θ + ε*v) evaluated at ε = 0
        """
        if not np.any(v_theta):
            return np.zeros(n_theta)

        v_theta_jax = jnp.array(v_theta)

        # Directional derivative of ∇_Θ L along v_theta
        _, hvp = jax.jvp(lambda theta: lagrangian_grad(theta, p_opt, λ_theta_opt, λ_g_opt),
                         (Θ_opt,), (v_theta_jax,))

        return np.array(hvp)

    # JVP functions for matrix-vector products
    def jvp_d2L_dtheta_dp(v_p):
        """
        Compute [∂²L/∂θ∂p] @ v_p using JVP
        """
        if not np.any(v_p):
            return np.zeros(n_theta)

        v_p_jax = jnp.array(v_p)

        # JVP of the gradient function w.r.t. p
        _, jvp_result = jax.jvp(lambda p: lagrangian_grad(Θ_opt, p, λ_theta_opt, λ_g_opt),
                                (p_opt,), (v_p_jax,))

        return np.array(jvp_result)

    def jvp_dg_active_dp(v_p):
        """
        Compute [∂g_active/∂p] @ v_p using JVP
        """
        if not np.any(v_p):
            return np.zeros(n_lambda_g)

        v_p_jax = jnp.array(v_p)

        # JVP of g_active w.r.t. p
        _, jvp_result = jax.jvp(lambda p: g_active(Θ_opt, p), (p_opt,), (v_p_jax,))

        return np.array(jvp_result)

    def jvp_df_dp(v_p):
        """
        Compute [∂f/∂p] @ v_p using JVP
        """
        if not np.any(v_p):
            return np.zeros(n_f)

        v_p_jax = jnp.array(v_p)

        # JVP of f w.r.t. p
        _, jvp_result = jax.jvp(lambda p: f(Θ_opt, p), (p_opt,), (v_p_jax,))

        return np.array(jvp_result)

    def jvp_dtheta_active_dtheta(v_theta):
        """
        Compute [∂Θ_active/∂θ] @ v_theta using JVP
        """
        if not np.any(v_theta):
            return np.zeros(n_lambda_theta)

        v_theta_jax = jnp.array(v_theta)

        _, jvp_result = jax.jvp(Θ_active, (Θ_opt,), (v_theta_jax,))
        return np.array(jvp_result)

    def jvp_dg_active_dtheta(v_theta):
        """
        Compute [∂g_active/∂θ] @ v_theta using JVP
        """
        if not np.any(v_theta):
            return np.zeros(n_lambda_g)

        v_theta_jax = jnp.array(v_theta)

        _, jvp_result = jax.jvp(lambda theta: g_active(theta, p_opt), (Θ_opt,), (v_theta_jax,))
        return np.array(jvp_result)

    def jvp_df_dtheta(v_theta):
        """
        Compute [∂f/∂θ] @ v_theta using JVP
        """
        if not np.any(v_theta):
            return np.zeros(n_f)

        v_theta_jax = jnp.array(v_theta)

        _, jvp_result = jax.jvp(lambda theta: f(theta, p_opt), (Θ_opt,), (v_theta_jax,))
        return np.array(jvp_result)

    def matvec_forward(v):
        """
        Compute matrix-vector product for [∂R/∂u] @ v
        This implements the forward UDE Jacobian matrix.
        Uses JVPs for all matrix-vector operations.
        """
        # Split v into blocks
        idx = 0
        v_p = v[idx:idx+n_p]
        idx += n_p
        v_btheta = v[idx:idx+n_btheta]
        idx += n_btheta
        v_bg = v[idx:idx+n_bg]
        idx += n_bg
        v_theta = v[idx:idx+n_theta]
        idx += n_theta
        v_lambda_theta = v[idx:idx+n_lambda_theta]
        idx += n_lambda_theta
        v_lambda_g = v[idx:idx+n_lambda_g]
        idx += n_lambda_g
        v_f = v[idx:idx+n_f]

        # print(f"Input vector norm: {np.linalg.norm(v):.6f}")

        # Initialize result
        result = np.zeros(total_size)

        # Block 1: Parameter residual equation
        # R_p = p - (actual parameters) = 0
        # ∂R_p/∂u @ v = I @ v_p = v_p
        result[:n_p] = v_p

        # Block 2: Active bound residual equation
        # R_btheta = b_theta - (actual bounds) = 0
        # ∂R_btheta/∂u @ v = I @ v_btheta = v_btheta
        idx = n_p
        result[idx:idx+n_btheta] = v_btheta

        # Block 3: Active constraint bound residual equation
        # R_bg = b_g - (actual constraint bounds) = 0
        # ∂R_bg/∂u @ v = I @ v_bg = v_bg
        idx += n_btheta
        result[idx:idx+n_bg] = v_bg

        # Block 4: Optimality condition (stationarity)
        # R_theta = ∇_θ L = 0
        # ∂R_theta/∂u @ v = ∂²L/∂θ∂p @ v_p + H @ v_theta + [∂Θ_active/∂θ]^T @ v_lambda_theta + [∂g_active/∂θ]^T @ v_lambda_g
        idx += n_bg
        result_theta = np.zeros(n_theta)
        result_theta += jvp_d2L_dtheta_dp(v_p)
        if np.any(v_theta):
            # print('using hessian-vector product!')
            result_theta += hessian_vector_product(v_theta)
        # For transpose operations, we use the fact that [A]^T @ v = A^T @ v
        # But we need to be careful about the shapes
        if np.any(v_lambda_theta):
            # [∂Θ_active/∂θ]^T is (n_theta x n_lambda_theta), so we need to handle this properly
            # This is equivalent to computing the VJP, but we'll use the transpose relationship
            _, vjp_fun = jax.vjp(Θ_active, Θ_opt)
            result_theta += np.array(vjp_fun(jnp.array(v_lambda_theta))[0])
        if np.any(v_lambda_g):
            # [∂g_active/∂θ]^T @ v_lambda_g
            _, vjp_fun = jax.vjp(lambda theta: g_active(theta, p_opt), Θ_opt)
            result_theta += np.array(vjp_fun(jnp.array(v_lambda_g))[0])
        result[idx:idx+n_theta] = result_theta

        # Block 5: Active bound constraints
        # R_lambda_theta = Θ_active(θ) - b_theta = 0
        # ∂R_lambda_theta/∂u @ v = [∂Θ_active/∂θ] @ v_theta - I @ v_btheta
        idx += n_theta
        result[idx:idx+n_lambda_theta] = jvp_dtheta_active_dtheta(v_theta) - v_btheta

        # Block 6: Active equality constraints
        # R_lambda_g = g_active(θ, p) - b_g = 0
        # ∂R_lambda_g/∂u @ v = [∂g_active/∂θ] @ v_theta + [∂g_active/∂p] @ v_p - I @ v_bg
        idx += n_lambda_theta
        result[idx:idx+n_lambda_g] = (jvp_dg_active_dtheta(v_theta) +
                                      jvp_dg_active_dp(v_p) - v_bg)

        # Block 7: Objective function
        # R_f = f(θ, p) - f_bar = 0  (f_bar is the objective value stored in the unknowns vector)
        # ∂R_f/∂u @ v = [∂f/∂θ] @ v_theta + [∂f/∂p] @ v_p - I @ v_f
        idx += n_lambda_g
        result[idx:idx+n_f] = (jvp_df_dtheta(v_theta) +
                               jvp_df_dp(v_p) - v_f)

        return result

    # Create LinearOperator for the forward system
    A_forward = LinearOperator((total_size, total_size), matvec=matvec_forward)

    # Storage for sensitivities
    sensitivities = {}

    # For each parameter, solve the forward system
    for i in range(n_p):
        # Create RHS vector: -∂R/∂p_i
        rhs = np.zeros(total_size)

        # The stationarity, constraint and objective residuals depend on the parameters
        # Block 4 (theta): -∂²L/∂θ∂p_i
        rhs[n_p + n_btheta + n_bg:n_p + n_btheta + n_bg + n_theta] = -jvp_d2L_dtheta_dp(np.eye(n_p)[i])

        # Block 6 (lambda_g): -∂g_active/∂p_i
        rhs[n_p + n_btheta + n_bg + n_theta + n_lambda_theta:
            n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g] = -jvp_dg_active_dp(np.eye(n_p)[i])

        # Block 7 (f): -∂f/∂p_i
        rhs[n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g:
            n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g + n_f] = -jvp_df_dp(np.eye(n_p)[i])

        # Solve the system
        print(f"\nSolving for sensitivities w.r.t. p[{i}]...")
        solution, info = gmres(A_forward, rhs, rtol=1e-10, maxiter=1000)

        if info == 0:
            # Extract sensitivities from the solution
            dp_dp = solution[:n_p]
            dbtheta_dp = solution[n_p:n_p+n_btheta]
            dbg_dp = solution[n_p+n_btheta:n_p+n_btheta+n_bg]
            dtheta_dp = solution[n_p+n_btheta+n_bg:n_p+n_btheta+n_bg+n_theta]
            dlambda_theta_dp = solution[n_p+n_btheta+n_bg+n_theta:n_p+n_btheta+n_bg+n_theta+n_lambda_theta]
            dlambda_g_dp = solution[n_p+n_btheta+n_bg+n_theta+n_lambda_theta:n_p+n_btheta+n_bg+n_theta+n_lambda_theta+n_lambda_g]
            df_dp = solution[n_p+n_btheta+n_bg+n_theta+n_lambda_theta+n_lambda_g:]

            print(f"  df/dp[{i}] = {df_dp[0]:+.6f}")
            print(f"  dθ₀/dp[{i}] = {dtheta_dp[0]:+.6f}")
            print(f"  dθ₁/dp[{i}] = {dtheta_dp[1]:+.6f}")

            # Store results
            if i == 0:
                sensitivities['f'] = {'wrt_p': np.zeros(n_p), 'wrt_btheta': np.zeros(n_btheta), 'wrt_bg': np.zeros(n_bg)}
                sensitivities['θ₀'] = {'wrt_p': np.zeros(n_p), 'wrt_btheta': np.zeros(n_btheta), 'wrt_bg': np.zeros(n_bg)}
                sensitivities['θ₁'] = {'wrt_p': np.zeros(n_p), 'wrt_btheta': np.zeros(n_btheta), 'wrt_bg': np.zeros(n_bg)}

            sensitivities['f']['wrt_p'][i] = df_dp[0]
            sensitivities['θ₀']['wrt_p'][i] = dtheta_dp[0]
            sensitivities['θ₁']['wrt_p'][i] = dtheta_dp[1]

        else:
            print(f"  Warning: GMRES did not converge (info={info})")

    # Solve for sensitivities w.r.t. b_theta
    print(f"\nSolving for sensitivities w.r.t. b_theta...")
    rhs = np.zeros(total_size)
    # Block 5 (lambda_theta): -(-I) = +I
    rhs[n_p + n_btheta + n_bg + n_theta:n_p + n_btheta + n_bg + n_theta + n_lambda_theta] = 1.0

    solution, info = gmres(A_forward, rhs, rtol=1e-10, maxiter=1000)
    if info == 0:
        dtheta_dbtheta = solution[n_p+n_btheta+n_bg:n_p+n_btheta+n_bg+n_theta]
        df_dbtheta = solution[n_p+n_btheta+n_bg+n_theta+n_lambda_theta+n_lambda_g:]

        sensitivities['f']['wrt_btheta'][0] = df_dbtheta[0]
        sensitivities['θ₀']['wrt_btheta'][0] = dtheta_dbtheta[0]
        sensitivities['θ₁']['wrt_btheta'][0] = dtheta_dbtheta[1]

        print(f"  df/db_θ₀ = {df_dbtheta[0]:+.6f}")
        print(f"  dθ₀/db_θ₀ = {dtheta_dbtheta[0]:+.6f}")
        print(f"  dθ₁/db_θ₀ = {dtheta_dbtheta[1]:+.6f}")
    else:
        print(f"  Warning: GMRES did not converge (info={info})")

    # Solve for sensitivities w.r.t. b_g
    print(f"\nSolving for sensitivities w.r.t. b_g...")
    rhs = np.zeros(total_size)
    # Block 6 (lambda_g): -(-I) = +I
    rhs[n_p + n_btheta + n_bg + n_theta + n_lambda_theta:n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g] = 1.0

    solution, info = gmres(A_forward, rhs, rtol=1e-10, maxiter=1000)
    if info == 0:
        dtheta_dbg = solution[n_p+n_btheta+n_bg:n_p+n_btheta+n_bg+n_theta]
        df_dbg = solution[n_p+n_btheta+n_bg+n_theta+n_lambda_theta+n_lambda_g:]

        sensitivities['f']['wrt_bg'][0] = df_dbg[0]
        sensitivities['θ₀']['wrt_bg'][0] = dtheta_dbg[0]
        sensitivities['θ₁']['wrt_bg'][0] = dtheta_dbg[1]

        print(f"  df/db_g₀ = {df_dbg[0]:+.6f}")
        print(f"  dθ₀/db_g₀ = {dtheta_dbg[0]:+.6f}")
        print(f"  dθ₁/db_g₀ = {dtheta_dbg[1]:+.6f}")
    else:
        print(f"  Warning: GMRES did not converge (info={info})")

    return sensitivities

def verify_with_finite_differences(sensitivities, h=1e-6):
    """
    Verify sensitivities using finite differences.
    """
    print("\n" + "="*60)
    print("Verification with Finite Differences")
    print("="*60)

    # Base values
    Θ_base = jnp.array([6.0, -6.0])
    p_base = jnp.array([3.0, 4.0, 3.0])
    b_theta_base = 6.0

    def solve_optimization(p_val, b_theta_val):
        """
        Simplified optimization solve assuming active constraints remain active.
        For the active set: θ₀ = b_theta_val, θ₁ = -b_theta_val
        """
        θ_opt = jnp.array([b_theta_val, -b_theta_val])
        f_val = f(θ_opt, p_val)[0]
        return f_val, θ_opt

    # df/dp₀ by finite differences
    p_plus = p_base.at[0].set(p_base[0] + h)
    f_plus, _ = solve_optimization(p_plus, b_theta_base)
    p_minus = p_base.at[0].set(p_base[0] - h)
    f_minus, _ = solve_optimization(p_minus, b_theta_base)
    df_dp0_fd = (f_plus - f_minus) / (2 * h)

    print(f"\ndf/dp₀:")
    print(f"  UDE:    {sensitivities['f']['wrt_p'][0]:+.6f}")
    print(f"  FD:     {df_dp0_fd:+.6f}")
    print(f"  Error:  {abs(sensitivities['f']['wrt_p'][0] - df_dp0_fd):.2e}")

    # dθ₀/db_θ₀ by finite differences (should be 1.0 since θ₀ = b_θ₀ at optimum)
    _, theta_plus = solve_optimization(p_base, b_theta_base + h)
    _, theta_minus = solve_optimization(p_base, b_theta_base - h)
    dtheta0_dbtheta_fd = (theta_plus[0] - theta_minus[0]) / (2 * h)

    print(f"\ndθ₀/db_θ₀:")
    print(f"  UDE:    {sensitivities['θ₀']['wrt_btheta'][0]:+.6f}")
    print(f"  FD:     {dtheta0_dbtheta_fd:+.6f}")
    print(f"  Error:  {abs(sensitivities['θ₀']['wrt_btheta'][0] - dtheta0_dbtheta_fd):.2e}")

    # dθ₁/db_θ₀ by finite differences (should be -1.0 due to constraint θ₀ + θ₁ = 0)
    dtheta1_dbtheta_fd = (theta_plus[1] - theta_minus[1]) / (2 * h)

    print(f"\ndθ₁/db_θ₀:")
    print(f"  UDE:    {sensitivities['θ₁']['wrt_btheta'][0]:+.6f}")
    print(f"  FD:     {dtheta1_dbtheta_fd:+.6f}")
    print(f"  Error:  {abs(sensitivities['θ₁']['wrt_btheta'][0] - dtheta1_dbtheta_fd):.2e}")

if __name__ == "__main__":
    print("="*60)
    print("Post-Optimality Sensitivity Analysis using JAX")
    print("with JVPs - Forward Mode Sensitivity Computation")
    print("="*60)

    # Compute sensitivities
    sensitivities = compute_sensitivities_jax()

    # Verify with finite differences
    verify_with_finite_differences(sensitivities)

    print("\n" + "="*60)
    print("Summary of Key Results")
    print("="*60)
    print("\nNote: Since θ₀ is bounded at 6.0 and the constraint θ₀ + θ₁ = 0 is active:")
    print("- Changes in b_θ₀ directly affect θ₀ (dθ₀/db_θ₀ = 1)")
    print("- Changes in b_θ₀ inversely affect θ₁ (dθ₁/db_θ₀ = -1)")
    print("- The objective is most sensitive to p₀ (df/dp₀ = -6)")
    print("\nForward mode approach:")
    print("- Solves ∂R/∂u @ (du/dp) = -∂R/∂p directly")
    print("- Uses JVPs for all matrix-vector operations")
    print("- Requires one solve per parameter (less efficient for many parameters)")
    print("- More intuitive: directly computes how variables change with parameters")

Post-Optimality Sensitivity Analysis using JAX
with JVPs - Forward Mode Sensitivity Computation

Solving for sensitivities w.r.t. p[0]...
  df/dp[0] = -6.000000
  dθ₀/dp[0] = +0.000000
  dθ₁/dp[0] = -0.000000

Solving for sensitivities w.r.t. p[1]...
  df/dp[1] = -4.000000
  dθ₀/dp[1] = -0.000000
  dθ₁/dp[1] = +0.000000

Solving for sensitivities w.r.t. p[2]...
  df/dp[2] = -1.000000
  dθ₀/dp[2] = +0.000000
  dθ₁/dp[2] = +0.000000

Solving for sensitivities w.r.t. b_theta...
  df/db_θ₀ = -2.000000
  dθ₀/db_θ₀ = +1.000000
  dθ₁/db_θ₀ = -1.000000

Solving for sensitivities w.r.t. b_g...
  df/db_g₀ = +2.000000
  dθ₀/db_g₀ = +0.000000
  dθ₁/db_g₀ = +1.000000

Verification with Finite Differences

df/dp₀:
  UDE:    -6.000000
  FD:     -6.000000
  Error:  8.39e-10

dθ₀/db_θ₀:
  UDE:    +1.000000
  FD:     +1.000000
  Error:  1.40e-10

dθ₁/db_θ₀:
  UDE:    -1.000000
  FD:     -1.000000
  Error:  1.40e-10

Summary of Key Results

Note: Since θ₀ is bounded at 6.0 and the constraint θ₀ + θ₁ = 0 
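For a problem this small, the matrix-free GMRES solves can be cross-checked against a dense solve. The sketch below is an illustration, not part of the original workflow. It assumes the active set found above ($\theta_0$ at its bound, $\theta_0 + \theta_1 = 0$), the multipliers $\lambda_\theta = 2$ and $\lambda_g = -2$ used in the code, and residuals $R = [\nabla_\theta L;\ c(\theta) - b]$. It builds the KKT matrix with `jax.hessian` and `jax.jacobian` and solves for the derivatives with respect to all five inputs $q = (p_0, p_1, p_2, b_\theta, b_g)$ at once.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for clean comparisons

def f(theta, p):
    """Objective from the problem statement."""
    return (theta[0] - p[0])**2 + theta[0] * theta[1] + (theta[1] + p[1])**2 - p[2]

def c_active(theta):
    """Active constraints: the bound on theta0 and the equality theta0 + theta1."""
    return jnp.array([theta[0], theta[0] + theta[1]])

def lagrangian(theta, lam, p):
    # The bound values b enter only through c_active(theta) - b, which is linear in b,
    # so they affect neither the Hessian nor the cross-derivative w.r.t. p.
    return f(theta, p) + lam @ c_active(theta)

theta = jnp.array([6.0, -6.0])
lam = jnp.array([2.0, -2.0])      # [lambda_theta, lambda_g] at the optimum
p = jnp.array([3.0, 4.0, 3.0])

H = jax.hessian(lagrangian, argnums=0)(theta, lam, p)                               # (2, 2)
A = jax.jacobian(c_active)(theta)                                                   # (2, 2)
L_theta_p = jax.jacobian(jax.grad(lagrangian, argnums=0), argnums=2)(theta, lam, p)  # (2, 3)

# KKT matrix dR/du for u = [theta, lam]
K = jnp.block([[H, A.T], [A, jnp.zeros((2, 2))]])

# -dR/dq for q = [p0, p1, p2, b_theta, b_g]
rhs = jnp.block([[-L_theta_p, jnp.zeros((2, 2))],
                 [jnp.zeros((2, 3)), jnp.eye(2)]])

du_dq = jnp.linalg.solve(K, rhs)   # one column per input, like one GMRES solve each
dtheta_dq = du_dq[:2]

# Total derivative of the objective: df/dq = ∂f/∂q + (∂f/∂θ) dθ/dq
df_dq_partial = jnp.concatenate([jax.grad(f, argnums=1)(theta, p), jnp.zeros(2)])
df_dq = df_dq_partial + jax.grad(f, argnums=0)(theta, p) @ dtheta_dq

print("dθ/dq =\n", dtheta_dq)
print("df/dq =", df_dq)
```

Each column of `du_dq` is the dense counterpart of one GMRES solve in the forward loop. The last two entries of `df_dq` are the bound sensitivities; they equal $-\lambda$ (here $-2$ and $+2$), as the envelope theorem predicts for $L = f + \lambda^T (c(\theta) - b)$.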

## The reverse-mode (VJP) way
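Before the matrix-free implementation, here is the reverse (adjoint) idea in dense form. A minimal sketch, assuming the active set $\theta_0 = 6$, $\theta_0 + \theta_1 = 0$ and the multipliers $\lambda_\theta = 2$, $\lambda_g = -2$ used throughout this notebook. Instead of one linear solve per input, a single adjoint solve $K^T \psi = (\partial f / \partial u)^T$ yields the sensitivities of $f$ to every input through $\frac{df}{dq} = \frac{\partial f}{\partial q} - \psi^T \frac{\partial R}{\partial q}$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(theta, p):
    """Objective from the problem statement."""
    return (theta[0] - p[0])**2 + theta[0] * theta[1] + (theta[1] + p[1])**2 - p[2]

def c_active(theta):
    """Active constraints: the bound on theta0 and the equality theta0 + theta1."""
    return jnp.array([theta[0], theta[0] + theta[1]])

def lagrangian(theta, lam, p):
    return f(theta, p) + lam @ c_active(theta)

theta = jnp.array([6.0, -6.0])
lam = jnp.array([2.0, -2.0])
p = jnp.array([3.0, 4.0, 3.0])

# dR/du for u = [theta, lam], with R = [grad_theta L; c_active(theta) - b]
H = jax.hessian(lagrangian, argnums=0)(theta, lam, p)
A = jax.jacobian(c_active)(theta)
K = jnp.block([[H, A.T], [A, jnp.zeros((2, 2))]])

# dR/dq for q = [p0, p1, p2, b_theta, b_g]
L_theta_p = jax.jacobian(jax.grad(lagrangian, argnums=0), argnums=2)(theta, lam, p)
dR_dq = jnp.block([[L_theta_p, jnp.zeros((2, 2))],
                   [jnp.zeros((2, 3)), -jnp.eye(2)]])

# Adjoint solve for the output f: K^T psi = (df/du)^T, where df/du = [df/dtheta, 0]
df_du = jnp.concatenate([jax.grad(f, argnums=0)(theta, p), jnp.zeros(2)])
psi = jnp.linalg.solve(K.T, df_du)

# One solve gives df/dq for all five inputs
df_dq = jnp.concatenate([jax.grad(f, argnums=1)(theta, p), jnp.zeros(2)]) - psi @ dR_dq

print("psi   =", psi)
print("df/dq =", df_dq)
```

The matrix-free cell below follows the same pattern with GMRES on the transposed operator: one solve per output ($f$, $\theta_0$, $\theta_1$) rather than one per input.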

In [67]:
"""
Post-optimality sensitivity analysis using JAX and the UDE approach
with Vector-Jacobian Products (VJPs) in reverse (adjoint) mode: one linear solve per output (f, θ₀, θ₁).
"""

import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse.linalg import LinearOperator, gmres
from functools import partial

# Define the optimization problem functions
def f(Θ, p):
    """Objective function"""
    f_val = (Θ[0] - p[0])**2 + Θ[0] * Θ[1] + (Θ[1] + p[1])**2 - p[2]
    return jnp.array([f_val])

def Θ_active(Θ):
    """Active bound constraint on Θ[0]"""
    return jnp.array([Θ[0]])

def g_active(Θ, p):
    """Active equality constraint"""
    return jnp.array([Θ[0] + Θ[1]])

def compute_sensitivities_jax():
    """
    Compute post-optimality sensitivities using JAX and the UDE approach.
    Solves the transposed UDE system with GMRES once per output (f, θ₀, θ₁); each solve
    returns that output's sensitivities to every input (p, b_θ, b_g).
    """

    # Known optimal solution
    Θ_opt = jnp.array([6.0, -6.0])
    p_opt = jnp.array([3.0, 4.0, 3.0])
    b_theta_opt = jnp.array([6.0])
    b_g_opt = jnp.array([0.0])
    λ_theta_opt = jnp.array([2.0])
    λ_g_opt = jnp.array([-2.0])

    # Problem dimensions
    n_p = 3  # parameters
    n_btheta = 1  # active bounds on design vars
    n_bg = 1  # active constraint bounds
    n_theta = 2  # design variables
    n_lambda_theta = 1  # multipliers for active bounds
    n_lambda_g = 1  # multipliers for active constraints
    n_f = 1  # original objective

    total_size = n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g + n_f

    # Define Lagrangian and its gradient
    def lagrangian(Θ, p, λ_theta, λ_g):
        """Lagrangian function"""
        return f(Θ, p)[0] + λ_theta @ Θ_active(Θ) + λ_g @ g_active(Θ, p)

    def lagrangian_grad(Θ, p, λ_theta, λ_g):
        """Gradient of Lagrangian w.r.t. Θ"""
        return jax.grad(lagrangian, argnums=0)(Θ, p, λ_theta, λ_g)

    # Function to compute Hessian-vector product using finite differences
    def hessian_vector_product(v_theta, h=1e-8):
        """
        Compute H @ v where H is the Hessian of the Lagrangian w.r.t. Θ
        Uses finite differences: H @ v ≈ (∇L(Θ + h*v) - ∇L(Θ - h*v)) / (2*h)
        """
        if not np.any(v_theta):
            return np.zeros(n_theta)

        v_theta_jax = jnp.array(v_theta)

        # Compute gradients at perturbed points
        grad_plus = lagrangian_grad(Θ_opt + h * v_theta_jax, p_opt, λ_theta_opt, λ_g_opt)
        grad_minus = lagrangian_grad(Θ_opt - h * v_theta_jax, p_opt, λ_theta_opt, λ_g_opt)

        # Finite difference approximation of Hessian-vector product
        hvp = (grad_plus - grad_minus) / (2 * h)

        return np.array(hvp)

    # VJP functions for matrix-transpose-vector products
    def vjp_d2L_dtheta_dp_T(v_theta):
        """
        Compute [∂²L/∂θ∂p]^T @ v_theta using VJP
        """
        if not np.any(v_theta):
            return np.zeros(n_p)

        v_theta_jax = jnp.array(v_theta)

        # Get VJP function for the gradient w.r.t. θ
        _, vjp_fun = jax.vjp(lambda p: lagrangian_grad(Θ_opt, p, λ_theta_opt, λ_g_opt), p_opt)

        # Apply VJP with v_theta
        result = vjp_fun(v_theta_jax)[0]
        return np.array(result)

    def vjp_dg_active_dp_T(v_lambda_g):
        """
        Compute [∂g_active/∂p]^T @ v_lambda_g using VJP
        """
        if not np.any(v_lambda_g):
            return np.zeros(n_p)

        v_lambda_g_jax = jnp.array(v_lambda_g)

        # Get VJP function for g_active w.r.t. p
        _, vjp_fun = jax.vjp(lambda p: g_active(Θ_opt, p), p_opt)

        # Apply VJP with v_lambda_g
        result = vjp_fun(v_lambda_g_jax)[0]
        return np.array(result)

    def vjp_df_dp_T(v_f):
        """
        Compute [∂f/∂p]^T @ v_f using VJP
        """
        if not np.any(v_f):
            return np.zeros(n_p)

        v_f_jax = jnp.array(v_f)

        # Get VJP function for f w.r.t. p
        _, vjp_fun = jax.vjp(lambda p: f(Θ_opt, p), p_opt)

        # Apply VJP with v_f
        result = vjp_fun(v_f_jax)[0]
        return np.array(result)

    def vjp_dtheta_active_dtheta_T(v_lambda_theta):
        """
        Compute [∂Θ_active/∂θ]^T @ v_lambda_theta using VJP
        """
        if not np.any(v_lambda_theta):
            return np.zeros(n_theta)

        v_lambda_theta_jax = jnp.array(v_lambda_theta)

        # Get VJP function for Θ_active w.r.t. θ
        _, vjp_fun = jax.vjp(Θ_active, Θ_opt)

        # Apply VJP with v_lambda_theta
        result = vjp_fun(v_lambda_theta_jax)[0]
        return np.array(result)

    def vjp_dg_active_dtheta_T(v_lambda_g):
        """
        Compute [∂g_active/∂θ]^T @ v_lambda_g using VJP
        """
        if not np.any(v_lambda_g):
            return np.zeros(n_theta)

        v_lambda_g_jax = jnp.array(v_lambda_g)

        # Get VJP function for g_active w.r.t. θ
        _, vjp_fun = jax.vjp(lambda theta: g_active(theta, p_opt), Θ_opt)

        # Apply VJP with v_lambda_g
        result = vjp_fun(v_lambda_g_jax)[0]
        return np.array(result)

    def vjp_df_dtheta_T(v_f):
        """
        Compute [∂f/∂θ]^T @ v_f using VJP
        """
        if not np.any(v_f):
            return np.zeros(n_theta)

        v_f_jax = jnp.array(v_f)

        # Get VJP function for f w.r.t. θ
        _, vjp_fun = jax.vjp(lambda theta: f(theta, p_opt), Θ_opt)

        # Apply VJP with v_f
        result = vjp_fun(v_f_jax)[0]
        return np.array(result)

    # Regular JVP functions for forward matrix-vector products
    def jvp_dtheta_active_dtheta(v_theta):
        """
        Compute [∂Θ_active/∂θ] @ v_theta using JVP
        """
        if not np.any(v_theta):
            return np.zeros(n_lambda_theta)

        v_theta_jax = jnp.array(v_theta)

        _, jvp_result = jax.jvp(Θ_active, (Θ_opt,), (v_theta_jax,))
        return np.array(jvp_result)

    def jvp_dg_active_dtheta(v_theta):
        """
        Compute [∂g_active/∂θ] @ v_theta using JVP
        """
        if not np.any(v_theta):
            return np.zeros(n_lambda_g)

        v_theta_jax = jnp.array(v_theta)

        _, jvp_result = jax.jvp(lambda theta: g_active(theta, p_opt), (Θ_opt,), (v_theta_jax,))
        return np.array(jvp_result)

    def matvec_transpose(v):
        """
        Compute matrix-vector product for [∂R/∂u]^T @ v
        This implements the transpose of the UDE Jacobian matrix.
        Uses VJPs for transpose operations and JVPs for forward operations.
        """
        # Split v into blocks
        idx = 0
        v_p = v[idx:idx+n_p]
        idx += n_p
        v_btheta = v[idx:idx+n_btheta]
        idx += n_btheta
        v_bg = v[idx:idx+n_bg]
        idx += n_bg
        v_theta = v[idx:idx+n_theta]
        idx += n_theta
        v_lambda_theta = v[idx:idx+n_lambda_theta]
        idx += n_lambda_theta
        v_lambda_g = v[idx:idx+n_lambda_g]
        idx += n_lambda_g
        v_f = v[idx:idx+n_f]

        # print(f"Input vector norm: {np.linalg.norm(v):.6f}")

        # Initialize result
        result = np.zeros(total_size)

        # Block 1: Effect on p
        result[:n_p] = v_p
        result[:n_p] -= vjp_d2L_dtheta_dp_T(v_theta)
        result[:n_p] -= vjp_dg_active_dp_T(v_lambda_g)
        result[:n_p] -= vjp_df_dp_T(v_f)

        # Block 2: Effect on b_theta
        idx = n_p
        result[idx:idx+n_btheta] = v_btheta + v_lambda_theta

        # Block 3: Effect on b_g
        idx += n_btheta
        result[idx:idx+n_bg] = v_bg + v_lambda_g

        # Block 4: Effect on theta
        idx += n_bg
        result_theta = np.zeros(n_theta)
        if np.any(v_theta):
            # print('using hessian-vector product!')
            # Use Hessian-vector product for second derivatives
            result_theta -= hessian_vector_product(v_theta)
        result_theta -= vjp_dtheta_active_dtheta_T(v_lambda_theta)
        result_theta -= vjp_dg_active_dtheta_T(v_lambda_g)
        result_theta -= vjp_df_dtheta_T(v_f)
        result[idx:idx+n_theta] = result_theta

        # Block 5: Effect on lambda_theta
        idx += n_theta
        result[idx:idx+n_lambda_theta] = -jvp_dtheta_active_dtheta(v_theta)

        # Block 6: Effect on lambda_g
        idx += n_lambda_theta
        result[idx:idx+n_lambda_g] = -jvp_dg_active_dtheta(v_theta)

        # Block 7: Effect on f
        idx += n_lambda_g
        result[idx:idx+n_f] = v_f

        return result

    # Create LinearOperator
    A_transpose = LinearOperator((total_size, total_size), matvec=matvec_transpose)

    # Storage for sensitivities
    sensitivities = {}

    # Solve for each output: f, θ₀, θ₁
    output_info = [
        ('f', n_p + n_btheta + n_bg + n_theta + n_lambda_theta + n_lambda_g, 'objective'),
        ('θ₀', n_p + n_btheta + n_bg, 'design variable 0'),
        ('θ₁', n_p + n_btheta + n_bg + 1, 'design variable 1')
    ]

    for name, rhs_idx, description in output_info:
        # Create RHS vector with 1 in appropriate position
        rhs = np.zeros(total_size)
        rhs[rhs_idx] = 1.0

        # Solve the system
        print(f"\nSolving for sensitivities of {name} ({description})...")
        solution, info = gmres(A_transpose, rhs, rtol=1e-10, maxiter=1000)

        if info == 0:
            # Extract sensitivities
            sens_p = solution[:n_p]
            sens_btheta = solution[n_p:n_p+n_btheta]
            sens_bg = solution[n_p+n_btheta:n_p+n_btheta+n_bg]

            sensitivities[name] = {
                'wrt_p': sens_p,
                'wrt_btheta': sens_btheta,
                'wrt_bg': sens_bg
            }

            print(f"  d{name}/dp₀ = {sens_p[0]:+.6f}")
            print(f"  d{name}/dp₁ = {sens_p[1]:+.6f}")
            print(f"  d{name}/dp₂ = {sens_p[2]:+.6f}")
            print(f"  d{name}/db_θ₀ = {sens_btheta[0]:+.6f}")
            print(f"  d{name}/db_g₀ = {sens_bg[0]:+.6f}")
        else:
            print(f"  Warning: GMRES did not converge (info={info})")

    return sensitivities

def verify_with_finite_differences(sensitivities, h=1e-6):
    """
    Verify sensitivities using finite differences.
    """
    print("\n" + "="*60)
    print("Verification with Finite Differences")
    print("="*60)

    # Base values
    Θ_base = jnp.array([6.0, -6.0])
    p_base = jnp.array([3.0, 4.0, 3.0])
    b_theta_base = 6.0

    # Verify df/dp₀, dθ₀/db_θ₀, and dθ₁/db_θ₀

    def solve_optimization(p_val, b_theta_val):
        """
        Simplified optimization solve assuming active constraints remain active.
        For the active set: θ₀ = b_theta_val, θ₁ = -b_theta_val
        """
        θ_opt = jnp.array([b_theta_val, -b_theta_val])
        f_val = f(θ_opt, p_val)[0]
        return f_val, θ_opt

    # df/dp₀ by finite differences
    p_plus = p_base.at[0].set(p_base[0] + h)
    f_plus, _ = solve_optimization(p_plus, b_theta_base)
    p_minus = p_base.at[0].set(p_base[0] - h)
    f_minus, _ = solve_optimization(p_minus, b_theta_base)
    df_dp0_fd = (f_plus - f_minus) / (2 * h)

    print(f"\ndf/dp₀:")
    print(f"  UDE:    {sensitivities['f']['wrt_p'][0]:+.6f}")
    print(f"  FD:     {df_dp0_fd:+.6f}")
    print(f"  Error:  {abs(sensitivities['f']['wrt_p'][0] - df_dp0_fd):.2e}")

    # dθ₀/db_θ₀ by finite differences (should be 1.0 since θ₀ = b_θ₀ at optimum)
    _, theta_plus = solve_optimization(p_base, b_theta_base + h)
    _, theta_minus = solve_optimization(p_base, b_theta_base - h)
    dtheta0_dbtheta_fd = (theta_plus[0] - theta_minus[0]) / (2 * h)

    print(f"\ndθ₀/db_θ₀:")
    print(f"  UDE:    {sensitivities['θ₀']['wrt_btheta'][0]:+.6f}")
    print(f"  FD:     {dtheta0_dbtheta_fd:+.6f}")
    print(f"  Error:  {abs(sensitivities['θ₀']['wrt_btheta'][0] - dtheta0_dbtheta_fd):.2e}")

    # dθ₁/db_θ₀ by finite differences (should be -1.0 due to constraint θ₀ + θ₁ = 0)
    dtheta1_dbtheta_fd = (theta_plus[1] - theta_minus[1]) / (2 * h)

    print(f"\ndθ₁/db_θ₀:")
    print(f"  UDE:    {sensitivities['θ₁']['wrt_btheta'][0]:+.6f}")
    print(f"  FD:     {dtheta1_dbtheta_fd:+.6f}")
    print(f"  Error:  {abs(sensitivities['θ₁']['wrt_btheta'][0] - dtheta1_dbtheta_fd):.2e}")

if __name__ == "__main__":
    print("="*60)
    print("Post-Optimality Sensitivity Analysis using JAX")
    print("with VJPs - Direct Theta Sensitivity Computation")
    print("="*60)

    # Compute sensitivities
    sensitivities = compute_sensitivities_jax()

    # Verify with finite differences
    verify_with_finite_differences(sensitivities)

    print("\n" + "="*60)
    print("Summary of Key Results")
    print("="*60)
    print("\nNote: Since θ₀ is bounded at 6.0 and the constraint θ₀ + θ₁ = 0 is active:")
    print("- Changes in b_θ₀ directly affect θ₀ (dθ₀/db_θ₀ = 1)")
    print("- Changes in b_θ₀ inversely affect θ₁ (dθ₁/db_θ₀ = -1)")
    print("- The objective is most sensitive to p₀ (df/dp₀ = -6)")
    print("\nReverse (adjoint) mode approach:")
    print("- Solves the transposed system [∂R/∂u]^T once per output of interest (f, θ₀, θ₁)")
    print("- Each solve returns that output's sensitivities to every input (p, b_θ, b_g)")
    print("- Uses VJPs/JVPs for the off-diagonal blocks and a finite-difference Hessian-vector product")
    print("- Preferable to forward mode when inputs outnumber outputs")

Post-Optimality Sensitivity Analysis using JAX
with VJPs - Direct Theta Sensitivity Computation

Solving for sensitivities of f (objective)...
  df/dp₀ = -6.000000
  df/dp₁ = -4.000000
  df/dp₂ = -1.000000
  df/db_θ₀ = -2.000000
  df/db_g₀ = +2.000000

Solving for sensitivities of θ₀ (design variable 0)...
  dθ₀/dp₀ = -0.000000
  dθ₀/dp₁ = +0.000000
  dθ₀/dp₂ = +0.000000
  dθ₀/db_θ₀ = +1.000000
  dθ₀/db_g₀ = -0.000000

Solving for sensitivities of θ₁ (design variable 1)...
  dθ₁/dp₀ = -0.000000
  dθ₁/dp₁ = -0.000000
  dθ₁/dp₂ = +0.000000
  dθ₁/db_θ₀ = -1.000000
  dθ₁/db_g₀ = +1.000000

Verification with Finite Differences

df/dp₀:
  UDE:    -6.000000
  FD:     -6.000000
  Error:  8.39e-10

dθ₀/db_θ₀:
  UDE:    +1.000000
  FD:     +1.000000
  Error:  1.40e-10

dθ₁/db_θ₀:
  UDE:    -1.000000
  FD:     -1.000000
  Error:  1.40e-10

Summary of Key Results

Note: Since θ₀ is bounded at 6.0 and the constraint θ₀ + θ₁ = 0 is active:
- Changes in b_θ₀ directly affect θ₀ (dθ₀/db_θ₀ = 1)
- Chang
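The `hessian_vector_product` helper in the VJP cell approximates $H v$ with central differences of the JAX gradient. JAX can form the same product to machine precision with forward-over-reverse differentiation (`jax.jvp` applied to `jax.grad`), which removes the step-size choice. A minimal sketch comparing the two on this problem's Lagrangian; `hvp_exact` and `hvp_fd` are hypothetical names, and double precision is assumed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(theta, p):
    return (theta[0] - p[0])**2 + theta[0] * theta[1] + (theta[1] + p[1])**2 - p[2]

def lagrangian(theta, p, lam_theta, lam_g):
    # Active bound on theta0 and active equality theta0 + theta1, as in the notebook
    return f(theta, p) + lam_theta * theta[0] + lam_g * (theta[0] + theta[1])

theta_opt = jnp.array([6.0, -6.0])
p_opt = jnp.array([3.0, 4.0, 3.0])
lam_theta, lam_g = 2.0, -2.0

grad_L = jax.grad(lagrangian, argnums=0)

def hvp_exact(v):
    """H @ v as the directional derivative of grad_L along v (forward-over-reverse)."""
    _, hv = jax.jvp(lambda th: grad_L(th, p_opt, lam_theta, lam_g), (theta_opt,), (v,))
    return hv

def hvp_fd(v, h=1e-8):
    """Central-difference approximation, mirroring the notebook's helper."""
    g_plus = grad_L(theta_opt + h * v, p_opt, lam_theta, lam_g)
    g_minus = grad_L(theta_opt - h * v, p_opt, lam_theta, lam_g)
    return (g_plus - g_minus) / (2 * h)

# Analytic Hessian of L w.r.t. theta (the constraint terms are linear in theta)
H = jnp.array([[2.0, 1.0], [1.0, 2.0]])

for v in jnp.eye(2):
    print(v, hvp_exact(v), hvp_fd(v))
```

Because this Lagrangian is quadratic, the two agree closely here; for general problems the forward-over-reverse product avoids the truncation versus round-off trade-off of picking $h$.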

## A different optimization problem

### The G13 problem from _Evolutionary Computation with Biogeography-based Optimization_ by Ma and Simon
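For reference, the problem as implemented in `G13Comp` and the model setup below, with the constants $a$, $b$, $c$ exposed as component inputs (set to 10, 5, and 1):

\begin{align*}
\min_{\mathbf{x}} \quad & f(\mathbf{x}) = \exp(x_1 x_2 x_3 x_4 x_5) \\
\text{equality constraints:} \quad h_1(\mathbf{x}; a) &= x_1^2 + x_2^2 + x_3^2 + x_4^2 + x_5^2 - a = 0 \\
h_2(\mathbf{x}; b) &= x_2 x_3 - b\, x_4 x_5 = 0 \\
h_3(\mathbf{x}; c) &= x_1^3 + x_2^3 + c = 0 \\
\text{bounds:} \quad -2.3 \le x_i &\le 2.3, \quad i \in \{1, 2\} \\
-3.2 \le x_i &\le 3.2, \quad i \in \{3, 4, 5\} \\
\text{where} \quad (a, b, c) &= (10, 5, 1)
\end{align*}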



In [43]:
import openmdao.api as om
import jax.numpy as jnp

class G13Comp(om.JaxExplicitComponent):

    def setup(self):
        self.add_input('x1')
        self.add_input('x2')
        self.add_input('x3')
        self.add_input('x4')
        self.add_input('x5')
        self.add_input('a')
        self.add_input('b')
        self.add_input('c')
        self.add_output('f')
        self.add_output('h1')
        self.add_output('h2')
        self.add_output('h3')

    def compute_primal(self, x1, x2, x3, x4, x5, a=10., b=5., c=1.):
        f = jnp.exp(x1 * x2 * x3 * x4 * x5)
        h1 = x1 ** 2 + x2 ** 2 + x3 ** 2 + x4 ** 2 + x5 ** 2 - a
        h2 = x2 * x3 - b * x4 * x5
        h3 = x1 ** 3 + x2 ** 3 + c
        return f, h1, h2, h3



In [44]:
import openmdao.api as om

prob = om.Problem()

prob.model.add_subsystem('g13', G13Comp(), promotes=['*'])


for i in [1, 2]:
    prob.model.add_design_var(f'x{i}', lower=-2.3, upper=2.3)

for i in [3, 4, 5]:
    prob.model.add_design_var(f'x{i}', lower=-3.2, upper=3.2)

prob.model.add_objective('f')

for i in [1, 2, 3]:
    prob.model.add_constraint(f'h{i}', equals=0)


prob.driver = om.ScipyOptimizeDriver()
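# ScipyOptimizeDriver forwards opt_settings to scipy as solver options, and 'ATOL'
# is not an SLSQP option; the convergence tolerance is the driver's 'tol' option.
prob.driver.options['tol'] = 1.0E-8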



In [45]:
prob.setup()

prob.set_val('a', 10)
prob.set_val('b', 5)
prob.set_val('c', 1)

# prob.set_val('x1', -1.717143)
# prob.set_val('x2', 1.595709)
# prob.set_val('x3', 1.827247)
# prob.set_val('x4', -0.7636413)
# prob.set_val('x5', -0.763645)

prob.run_driver()

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.053949845295294487
            Iterations: 20
            Function evaluations: 26
            Gradient evaluations: 20
Optimization Complete
-----------------------------------


Problem: problem10
Driver:  ScipyOptimizeDriver
  success     : True
  iterations  : 27
  runtime     : 1.6207E-01 s
  model_evals : 27
  model_time  : 2.1283E-02 s
  deriv_evals : 20
  deriv_time  : 1.3072E-01 s
  exit_status : SUCCESS

In [46]:
prob.list_driver_vars();

----------------
Design Variables
----------------
name  val            size  lower  upper  ref  ref0  indices  adder  scaler  parallel_deriv_color  cache_linear_solution  units  
----  -------------  ----  -----  -----  ---  ----  -------  -----  ------  --------------------  ---------------------  ----- 
x1    [-1.71714333]  1     -2.3   2.3    1.0  0.0   None     None   None    None                  False                  None   
x2    [1.59570942]   1     -2.3   2.3    1.0  0.0   None     None   None    None                  False                  None   
x3    [1.8272462]    1     -3.2   3.2    1.0  0.0   None     None   None    None                  False                  None   
x4    [0.76363602]   1     -3.2   3.2    1.0  0.0   None     None   None    None                  False                  None   
x5    [0.76365019]   1     -3.2   3.2    1.0  0.0   None     None   None    None                  False                  None   

-----------
Constraints
-----------
name  val 

In [47]:
from openmdao.utils.assert_utils import assert_near_equal

# Compare against the published best-known G13 objective; the returned value is the relative error
assert_near_equal(prob.get_val('f'), 0.053941514041898, tolerance=1.0E-3)

np.float64(0.0001544497507061727)
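Post-optimality sensitivities for G13 also need the Lagrange multipliers at the optimum. A minimal sketch, independent of OpenMDAO, under these assumptions: the design values are copied (rounded) from the `list_driver_vars()` output above; no bounds are active (every $x_i$ lies strictly inside its bounds), so only $h_1, h_2, h_3$ carry multipliers; and the active set does not change for small perturbations of $a$, $b$, $c$. The multipliers are estimated by least squares on the stationarity condition $\nabla f + J_h^T \lambda = 0$, and the envelope theorem then gives $df^*/dq = \partial L / \partial q$ for $q = (a, b, c)$. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def g13_f(x):
    """Objective: exp(x1*x2*x3*x4*x5)."""
    return jnp.exp(jnp.prod(x))

def g13_h(x, q):
    """Equality constraints h1..h3 with parameters q = [a, b, c]."""
    a, b, c = q
    x1, x2, x3, x4, x5 = x
    return jnp.array([
        jnp.sum(x**2) - a,
        x2 * x3 - b * x4 * x5,
        x1**3 + x2**3 + c,
    ])

def lagrangian(x, q, lam):
    return g13_f(x) + lam @ g13_h(x, q)

q = jnp.array([10.0, 5.0, 1.0])
# Design variables copied from the list_driver_vars() output above (8 digits)
x_star = jnp.array([-1.71714333, 1.59570942, 1.8272462, 0.76363602, 0.76365019])

grad_f = jax.grad(g13_f)(x_star)                  # (5,)
J_h = jax.jacobian(g13_h, argnums=0)(x_star, q)   # (3, 5)

# Least-squares multipliers from stationarity: grad_f + J_h^T lam = 0
lam, *_ = jnp.linalg.lstsq(J_h.T, -grad_f)
kkt_residual = jnp.linalg.norm(grad_f + J_h.T @ lam)

# Envelope theorem: df*/dq = dL/dq at (x*, lam), valid while the active set is unchanged
df_dq = jax.grad(lagrangian, argnums=1)(x_star, q, lam)

print("h(x*)          =", g13_h(x_star, q))
print("lambda         =", lam)
print("KKT residual   =", kkt_residual)
print("df*/d[a, b, c] =", df_dq)
```

A KKT residual that is small relative to $\|\nabla f\|$ suggests the reported point is close enough to a KKT point for these first-order estimates to be meaningful.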