In [None]:
from importlib import reload

import numpy as np
from matplotlib import pyplot as plt
from scipy.integrate import solve_ivp

import L96

# `import L96` is a no-op once the module is cached, so re-running this cell
# would not pick up edits to L96.py; reload re-executes it in place and
# returns the same (re-initialised) module object.
L96 = reload(L96)

# Alias for the NumPy array type, used in the type hints of later cells.
ndarray = np.ndarray

In [None]:
def evolve(
    system: L96.System,
    tlim: tuple[float, float],
    U0: ndarray,
    V0: ndarray,
    U0_sim: ndarray,
    V0_sim: ndarray,
    **solver_kwargs,
):
    """Evolve the true and simulated system from tlim[0] to tlim[1].

    The true system is integrated first with dense output. Its continuous
    solution ``sol.sol`` is then passed as an extra argument to
    ``system.ode_sim``, presumably so the simulated system can be nudged
    toward the true trajectory.

    Parameters
    ----------
    system
        Instance of `L96.System`
    tlim
        (start time, stop time)
    U0
        The initial state of the true large-scale system
        shape (I,)
    V0
        The initial states of the true small-scale systems
        shape (I, J)
    U0_sim, V0_sim
        The initial states of the simulated large- and small-scale systems
        shapes (I,) and (I, J_sim), where J_sim is the number of
        small-scale variables per large-scale variable in the simulated
        model (it may differ from J)
    **solver_kwargs
        Extra keyword arguments forwarded to both `scipy.integrate.solve_ivp`
        calls (e.g. ``method``, ``rtol``, ``atol``). Must not contain
        ``args`` or ``dense_output``; `evolve` sets those itself.

    Returns
    -------
    sol, sim
        The `solve_ivp` results for the true and simulated systems. Each
        carries a dense-output interpolant in ``.sol``.

    Raises
    ------
    RuntimeError
        If either integration reports failure (``success`` is False).
    """

    state0 = L96.together(U0, V0)
    state0_sim = L96.together(U0_sim, V0_sim)

    # Evolve the true system first: its dense output drives the simulation.
    sol = solve_ivp(
        system.ode_true,
        tlim,
        state0,
        dense_output=True,
        **solver_kwargs,
    )
    # A failed run yields an interpolant covering only part of tlim. The
    # simulated ODE would then silently evaluate it out of range, so stop here.
    if not sol.success:
        raise RuntimeError(f"True-system integration failed: {sol.message}")

    sim = solve_ivp(
        system.ode_sim,
        tlim,
        state0_sim,
        args=(sol.sol,),
        dense_output=True,
        **solver_kwargs,
    )
    if not sim.success:
        raise RuntimeError(
            f"Simulated-system integration failed: {sim.message}"
        )

    return sol, sim

In [None]:
# Dimensions: I large-scale variables, each with J small-scale variables in
# the true system and J_sim small-scale variables in the simulated system.
I, J = 10, 3
J_sim = J - 2

# System evolution parameters: (start, stop) time, also unpacked as t0, tf
tlim = t0, tf = 0, 2

# True system parameters.
# Float literals make these float64 arrays. With int arrays, an in-place
# float update (such as the planned gradient-descent step) would raise a
# casting error.
ds = np.full(I, 1.0)
γs = np.full((I, J), 1.0)
ds2 = np.full((I, J), 1.0)
γs2 = np.full(I, 1.0)
F = -1

# Nudging parameter and simulated system parameters
μ = 10
ds_sim = np.full(I, 1.0)
γs_sim = np.full((I, J_sim), 1.0)
ds2_sim = np.full((I, J_sim), 1.0)
γs2_sim = np.full(I, 1.0)

# Initial true state
U0 = np.full(I, 1.0)
V0 = np.full((I, J), 2.0)

# Initial simulation state: large-scale state offset by 1 from the truth
U0_sim = U0 + 1
V0_sim = np.full((I, J_sim), 2.0)

system = L96.System(
    I, J, J_sim, ds, γs, ds2, γs2, ds_sim, γs_sim, ds2_sim, γs2_sim, F, μ
)

def unpack_states(states: ndarray, I: int, J: int) -> tuple[ndarray, ndarray]:
    """Split a time series of packed states into large- and small-scale parts.

    Parameters
    ----------
    states
        Packed states sampled from an `OdeSolution`, one column per time
        sample, shape (n_state, n_times)
    I, J
        Dimensions passed through to `L96.apart`

    Returns
    -------
    Us, Vs
        The two per-time outputs of `L96.apart`, each stacked along a new
        leading time axis (presumably shapes (n_times, I) and
        (n_times, I, J) — confirm against `L96.apart`).
    """
    Us_t, Vs_t = zip(*(L96.apart(state, I, J) for state in states.T))
    return np.stack(Us_t), np.stack(Vs_t)


sol, sim = evolve(system, tlim, U0, V0, U0_sim, V0_sim)

# Sample both dense solutions at tn evenly spaced times over tlim.
tn = 100
tls = np.linspace(*tlim, tn)

# Unpack true and simulated states
states = sol.sol(tls)
Us, Vs = unpack_states(states, I, J)

states_sim = sim.sol(tls)
Us_sim, Vs_sim = unpack_states(states_sim, I, J_sim)

# TODO: Perform gradient descent update, then simulate again with new
# parameters.

In [None]:
fig, ax = plt.subplots()

# Compare the first large-scale variable of the true and simulated systems.
ax.plot(tls, Us.T[0], label="true", color="blue")
ax.plot(tls, Us_sim.T[0], label="sim", color="red", linestyle="--")

ax.set(
    xlabel="t",
    ylabel="$U_0$",
    title="Large-scale variable $U_0$: true vs. simulated",
)
ax.legend()
plt.show()