In [None]:
from importlib import reload

import numpy as np
from matplotlib import pyplot as plt
from scipy.integrate import solve_ivp

import L96
reload(L96)

Josh's paper assumes the two parameters to be estimated are

$$
\begin{aligned}
\gamma_1 &= \gamma_{i, j} = \gamma_i \quad \text{for all } i, j, \\
\gamma_2 &= \bar d_i \quad \text{for all } i.
\end{aligned}
$$

The first parameter $\gamma_1$ corresponds to the entries of `γs` and `γs2`, which are all set to the same value.

The second parameter $\gamma_2$ corresponds to the entries of `ds`, which are likewise all set to the same value.

Thus the sensitivities in Josh's paper are taken with respect to these two parameters, and the code below reflects this.
That is, while `L96.ode` is general to the Lorenz '96 system, the code that follows is specialized to this two-parameter case.

In [None]:
# Dimensions: U has I components and V has I x J components (see U0, V0
# below). The simulated model carries only J_sim = J - 2 columns of V.
I, J = 10, 3
J_sim = J - 2

# System evolution parameters: integrate over [t0, tf] and report the
# solution at tn evenly spaced times.
t0, tf = 0, 2
tlim = (t0, tf)
tn = 100
tls = np.linspace(t0, tf, tn)

# True system parameters. All entries are equal, matching the
# two-parameter setup described above. They are floats (np.full(..., 1)
# would give int64) so the planned gradient-descent step can update them
# in place without a float -> int casting error.
ds = np.ones(I)
γs = np.ones((I, J))
ds2 = np.ones((I, J))
γs2 = np.ones(I)
F = -1.0

# Nudging strength and parameters of the simulated (reduced) system.
μ = 10.0
γs_sim = np.ones((I, J_sim))
ds2_sim = np.ones((I, J_sim))

# Initial true state.
U0 = np.ones(I)
V0 = np.full((I, J), 2.0)
state0 = L96.together(U0, V0)

# Initial simulated state: U starts 1 away from the truth, V has only
# J_sim columns.
U0_sim = U0 + 1
V0_sim = np.full((I, J_sim), 2.0)
state0_sim = L96.together(U0_sim, V0_sim)

# Evolve the true system. dense_output=True makes the continuous
# interpolant `sol.sol` available; it is passed to the simulated run below.
sol = solve_ivp(
    L96.ode,
    tlim,
    state0,
    args=(I, J, ds, γs, ds2, γs2, F),
    t_eval=tls,
    dense_output=True,
)
# A failed integration returns a truncated `y`; stop here rather than
# silently unpacking and plotting partial data.
assert sol.success, f"true-system integration failed: {sol.message}"

# Evolve the simulated system. The extra args (μ, sol.sol, J) presumably
# let L96.ode nudge the simulation toward the true trajectory, with J
# needed to unpack the true state -- NOTE(review): confirm in L96.ode.
sim = solve_ivp(
    L96.ode,
    tlim,
    state0_sim,
    args=(I, J_sim, ds, γs_sim, ds2_sim, γs2, F, μ, sol.sol, J),
    t_eval=tls,
)
assert sim.success, f"simulated-system integration failed: {sim.message}"


def unstack_states(states, I, J):
    """Split every column of a `solve_ivp` solution into its (U, V) parts.

    Parameters
    ----------
    states : ndarray
        Solution array as found in `OdeResult.y`; each column is the packed
        state at one output time.
    I, J : int
        Dimensions forwarded to `L96.apart` to unpack a single state.

    Returns
    -------
    Us, Vs : ndarray
        The `L96.apart` outputs stacked along a new leading (time) axis.
    """
    Us, Vs = zip(*(L96.apart(state, I, J) for state in states.T))
    return np.stack(Us), np.stack(Vs)


# Unpack true and simulated states.
states = sol.y
Us, Vs = unstack_states(states, I, J)

states_sim = sim.y
Us_sim, Vs_sim = unstack_states(states_sim, I, J_sim)

# TODO: Perform gradient descent update, then simulate again with new
# parameters.

In [None]:
# Compare the first component of U over time for the true system and the
# nudged simulation.
fig, ax = plt.subplots(1, 1)

ax.plot(sol.t, Us.T[0], label="true", color="blue")
ax.plot(sim.t, Us_sim.T[0], label="sim", color="red", linestyle="--")

# Title and axis labels so the figure reads on its own when skimmed.
ax.set(
    title="True vs. nudged simulation: first component of $U$",
    xlabel="$t$",
    ylabel="$U_0$",
)
ax.legend()
plt.show()