<a href="https://colab.research.google.com/github/yingzibu/ODE/blob/main/learn/ODE_adjoint_sensitivities_of_a_linear_systems.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://www.youtube.com/watch?v=dKqoXFULsbQ&list=PLISXH-iEM4Jk27AmSvISooRRKH4WtlWKP&index=3

**Linear system**

Solve $Ax = b$ for $x$, then measure the misfit against a reference solution $x_r$ with a loss function, for example $J = \frac{1}{2} (x - x_r)^\top (x-x_r)$



Now let the right-hand side depend on parameters $\theta$: $A x = b (\theta)$

Question: what are the sensitivities $\frac{dJ}{d\theta}$?

Here:

$A \in \mathbf{R}^{3 \times 3}$, fixed

$b$ is variable and parameter dependent, $b = [\theta_0, \theta_1, \theta_2]$; the true parameters are unknown, so we start from a guess $\hat{b}$


Solve "classical" problem

1. Solve $Ax = \hat{b}$ for $x$ (numerically, e.g. with `np.linalg.solve`)

   evaluate $J = \frac{1}{2} (x - x_r)^\top (x-x_r)$

2. Obtain the gradient $\frac{dJ}{d\theta}$, which is a row vector



a. finite differences

   $\frac{dJ}{d\theta} = [\frac{dJ}{d\theta_0}, \frac{dJ}{d\theta_1},\frac{dJ}{d\theta_2}]$

   for $\frac{dJ}{d\theta_0}$: set $\tilde{b} = \hat{b} + \epsilon e_0$ and solve $A\tilde{x} = \tilde{b}$ for $\tilde{x}$

   Evaluate $\tilde{J} = \frac{1}{2} (\tilde{x} - x_r)^\top (\tilde{x}-x_r)$

   $\frac{dJ}{d\theta_0} \approx \frac{\tilde{J}-J}{\epsilon}$; repeat with $e_1$ and $e_2$ for the remaining components



b. forward sensitivities
   
   solve $A\frac{dx}{d\theta} = \frac{db}{d\theta} - \frac{dA}{d\theta} x$

   here: $\frac{db}{d\theta} = I_3$

   $\frac{dA}{d\theta}$: written this way the shape is not right (it is a third-order tensor, so the product with $x$ is loose notation), but our $A$ is fixed, so this term is 0

   solve $A\frac{dx}{d\theta} = I_3$ for $\frac{dx}{d\theta}$

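   then, since $J$ does not depend on $\theta$ explicitly, $\frac{\partial J}{\partial \theta} = 0^\top$

   $\frac{\partial J} {\partial x} = (x - x_r)^\top$

   so $\frac{dJ}{d\theta} = \frac{\partial J}{\partial \theta} + \frac{\partial J}{\partial x} \frac{dx}{d\theta} = (x - x_r)^\top \frac{dx}{d\theta}$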





c. adjoint (backward) sensitivities

   solve $A^\top \lambda = (\frac{\partial J}{\partial x})^\top = x - x_r$ for $\lambda$

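   then $\frac{dJ}{d\theta} = \frac{\partial J}{\partial \theta} + \lambda^\top \frac{db}{d\theta} = \lambda^\top$, since $\frac{\partial J}{\partial \theta} = 0^\top$ and $\frac{db}{d\theta} = I_3$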
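
The forward and adjoint formulas both follow from differentiating the linear system. The short derivation below is a sketch under the same assumptions as these notes ($A$ invertible and independent of $\theta$); the grouping of terms into $\lambda^\top$ is the only step added here.

$$
\begin{aligned}
A x = b(\theta) \;&\Rightarrow\; A \frac{dx}{d\theta} = \frac{db}{d\theta} \quad (A \text{ fixed}) \\
\frac{dJ}{d\theta} &= \frac{\partial J}{\partial \theta} + \frac{\partial J}{\partial x} \frac{dx}{d\theta}
 = \frac{\partial J}{\partial \theta} + \underbrace{\frac{\partial J}{\partial x} A^{-1}}_{\lambda^\top} \frac{db}{d\theta} \\
\lambda^\top = \frac{\partial J}{\partial x} A^{-1} \;&\Leftrightarrow\; A^\top \lambda = \left(\frac{\partial J}{\partial x}\right)^\top = x - x_r
\end{aligned}
$$

With $\frac{\partial J}{\partial \theta} = 0^\top$ and $\frac{db}{d\theta} = I_3$ this reduces to $\frac{dJ}{d\theta} = \lambda^\top$. The forward route solves one system per parameter (one per column of $\frac{db}{d\theta}$), whereas the adjoint route solves a single system with $A^\top$ regardless of the number of parameters.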

In [None]:
import numpy as np

In [None]:
A = np.array([
    [10,2,1],
    [2,5,1],
    [1,1,3]
])

### Creating a reference solution
b_true = np.array([5,4,3])
x_ref = np.linalg.solve(A, b_true) # Ax = b, solve for x

### [A] Solve the classical system
b_guess = np.ones(3)

x = np.linalg.solve(A, b_guess)
J = 0.5 * (x - x_ref).T @ (x-x_ref)

### [B] Obtaining gradients

## [3] Adjoint sensitivities

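del_J__del_theta = np.zeros((1, 3))  # J has no explicit dependence on theta
del_J__del_x = (x - x_ref).T  # x is 1D, so .T is a no-op; kept to mirror the row-vector notation

d_b__d_theta = np.eye(3)  # b = theta, so db/dtheta is the identity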

# Solve adjoint system

adjoint_variable = np.linalg.solve(A.T, del_J__del_x.T)

# plug in
d_J__d_theta__adjoint = del_J__del_theta + adjoint_variable.T @ d_b__d_theta

print(d_J__d_theta__adjoint)

[[-0.00415401 -0.05307211 -0.12790626]]


In [None]:
### [1] finite differences

eps = 1.0e-6

d_J__d_theta_finite_difference = np.empty((1, 3))

for i in range(3):
    b_augmented = b_guess.copy()
    b_augmented[i] += eps

    x_augmented = np.linalg.solve(A, b_augmented)
    J_augmented = 0.5 * (x_augmented - x_ref).T @ (x_augmented - x_ref)

    d_J__d_theta_finite_difference[0,i] = (J_augmented - J) / eps

print(d_J__d_theta_finite_difference)


[[-0.004154   -0.05307208 -0.12790619]]
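
The finite-difference result differs from the adjoint result in the last printed digits. The sketch below compares the one-sided difference used above with a central difference. The central difference is not part of the original notebook, and the names `loss`, `grad_exact`, `err_forward` and `err_central` are introduced here purely for illustration. Because $J$ is quadratic in $b$, the one-sided error should scale with `eps`, while the central difference should be limited only by round-off.

```python
import numpy as np

A = np.array([[10.0, 2.0, 1.0],
              [2.0, 5.0, 1.0],
              [1.0, 1.0, 3.0]])
x_ref = np.linalg.solve(A, np.array([5.0, 4.0, 3.0]))
b_guess = np.ones(3)


def loss(b):
    """Hypothetical helper: solve A x = b and return J = 0.5 * ||x - x_ref||^2."""
    x = np.linalg.solve(A, b)
    return 0.5 * (x - x_ref) @ (x - x_ref)


# Reference gradient from the adjoint solve A^T lambda = x - x_ref
grad_exact = np.linalg.solve(A.T, np.linalg.solve(A, b_guess) - x_ref)

eps = 1.0e-6
grad_one_sided = np.empty(3)
grad_central = np.empty(3)
for i, e_i in enumerate(np.eye(3)):
    J_plus = loss(b_guess + eps * e_i)
    J_minus = loss(b_guess - eps * e_i)
    grad_one_sided[i] = (J_plus - loss(b_guess)) / eps   # as in the cell above
    grad_central[i] = (J_plus - J_minus) / (2.0 * eps)   # assumption: central variant

err_forward = np.max(np.abs(grad_one_sided - grad_exact))
err_central = np.max(np.abs(grad_central - grad_exact))
print(err_forward, err_central)
```

The central variant costs two solves per parameter instead of one, so it trades extra work for accuracy; the adjoint approach avoids the step-size question altogether.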


In [None]:
### [2] forward sensitivities
# solve forward system

d_x__d_theta = np.linalg.solve(A, d_b__d_theta)

d_J__d_theta_forward = del_J__del_theta + del_J__del_x @ d_x__d_theta

print(d_J__d_theta_forward)

[[-0.00415401 -0.05307211 -0.12790626]]
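
All three printed gradients agree to roughly seven digits. A minimal, self-contained sketch of that cross-check follows. It recomputes the three gradients with plain 1D arrays and compares them with `np.allclose`. The helper `loss`, the `grad_*` names and the tolerance `atol=1e-6` are assumptions made for this example, not part of the original notebook.

```python
import numpy as np

A = np.array([[10.0, 2.0, 1.0],
              [2.0, 5.0, 1.0],
              [1.0, 1.0, 3.0]])
x_ref = np.linalg.solve(A, np.array([5.0, 4.0, 3.0]))
b_guess = np.ones(3)


def loss(b):
    """Hypothetical helper: solve A x = b and return J = 0.5 * ||x - x_ref||^2."""
    x = np.linalg.solve(A, b)
    return 0.5 * (x - x_ref) @ (x - x_ref)


x = np.linalg.solve(A, b_guess)

# Adjoint: one extra solve with A^T; db/dtheta = I, so the gradient is lambda itself
grad_adjoint = np.linalg.solve(A.T, x - x_ref)

# Forward: one solve with three right-hand sides (the columns of db/dtheta = I)
d_x__d_theta = np.linalg.solve(A, np.eye(3))
grad_forward = (x - x_ref) @ d_x__d_theta

# Finite differences: one extra solve per parameter
eps = 1.0e-6
J0 = loss(b_guess)
grad_fd = np.array([(loss(b_guess + eps * e_i) - J0) / eps for e_i in np.eye(3)])

print(np.allclose(grad_adjoint, grad_forward))        # exact methods should match
print(np.allclose(grad_adjoint, grad_fd, atol=1e-6))  # FD matches up to O(eps)
```

The looser tolerance on the finite-difference comparison reflects its truncation error; the forward and adjoint results are both exact up to round-off.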
