# Get Trajectory graphs

In [20]:
%%capture

%load_ext autoreload
%autoreload 2

!pip install filterpy

import numpy as np
import pandas as pd
import os
import sys
from tqdm import tqdm
from filterpy.kalman import EnsembleKalmanFilter as EnKF
import matplotlib.pyplot as plt
import scipy.integrate as integrate
from scipy.interpolate import griddata
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from jax import jit, random, lax
from jax.scipy.linalg import sqrtm
from functools import partial
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from scipy import integrate
from jax.tree_util import Partial
from functools import partial

if os.path.isdir('./ACM270_Model_Error/') == False:
    ! git clone https://github.com/sreemanti-dey/ACM270_Model_Error.git
sys.path.append('./ACM270_Model_Error/')

! cd ACM270_Model_Error

"""# True Trajectory and Noisy Observations with 2-Scale Lorenz96"""

# Two-scale Lorenz-96 configuration.
K, J = 36, 10            # large-scale variables; small-scale variables per large one
h, c, b = 0.25, 10, 10   # coupling constants (enter the model as h * c / b and c * b)
F = 10                   # forcing

dt = 0.1                 # integration time step
noise_const = 0.1        # variance on the diagonal of the noise covariances

# 200 steps at the reference step 0.05, rescaled so that the simulated time
# span stays the same for any dt.
num_steps = time_steps = int(200 * (0.05 / dt))

@jit
def L96_2(xy):
    """Return the time derivative of the two-scale Lorenz-96 state.

    Args:
        xy: flat state vector; the first K entries are the large-scale
            variables, the remaining K * J entries are the small-scale
            variables stored row by row (one row of J per large-scale
            variable).

    Returns:
        Flat array with the same layout as ``xy`` holding d(xy)/dt.

    The large-scale variables are cyclic over K.  The small-scale variables
    are cyclic *within each row of J*, exactly as in the element-wise loop
    this replaces.  Using ``jnp.roll`` avoids unrolling K * (J + 1)
    individual scatter updates into the traced program.
    """
    x = xy[0:K]
    y = xy[K:].reshape(K, J)
    coupling = h * c / b

    # jnp.roll(v, s)[i] == v[i - s]
    dx = (-jnp.roll(x, 1) * (jnp.roll(x, 2) - jnp.roll(x, -1))
          - x + F - coupling * jnp.sum(y, axis=1))
    dy = (-c * b * jnp.roll(y, -1, axis=1)
          * (jnp.roll(y, -2, axis=1) - jnp.roll(y, 1, axis=1))
          - c * y + coupling * x[:, None])

    return jnp.concatenate([dx, dy.flatten()])

@jit
def rk4_step_lorenz96_2(x):
    """Advance the two-scale state ``x`` by one classical RK4 step of size dt.

    ``dt`` is read from the enclosing module scope, so its value is fixed
    when jit traces this function.
    """
    s1 = dt * L96_2(x)
    s2 = dt * L96_2(x + s1/2)
    s3 = dt * L96_2(x + s2/2)
    s4 = dt * L96_2(x + s3)
    return x + 1/6 * (s1 + 2 * s2 + 2 * s3 + s4)

keys = iter(random.split(random.PRNGKey(0), 10))
key = random.PRNGKey(0)
# Deterministic initial condition: a linear ramp on the large-scale
# variables (a random start is left commented out), zeros on the small scale.
x0 = np.linspace(0, 1, K)  # np.random.rand(K)
y0 = np.zeros((K, J))
xy0 = np.concatenate([x0, y0.flatten()])

# Warm-up call; the result is discarded.
rk4_step_lorenz96_2(xy0)

trajectory = jnp.zeros((time_steps, len(xy0)))
noisy_obs = jnp.zeros((time_steps, len(xy0)))
noise_covar = jnp.eye(len(xy0)) * noise_const

xy = xy0
for i in tqdm(range(time_steps)):
    xy = rk4_step_lorenz96_2(xy)
    trajectory = trajectory.at[i].set(xy)
    # Split first and sample with the fresh subkey: a JAX key must not be
    # used both to draw numbers and to derive the next key.
    key, noise_key = random.split(key)
    # Add noise to the observations
    noise = random.multivariate_normal(noise_key, jnp.zeros(len(xy0)), noise_covar)
    noisy_obs = noisy_obs.at[i].set(xy + noise)

"""## True Trajectory

## Noisy Observations
"""

true_states = trajectory

"""# EnKF for 1-Scale Assimilation"""

@jit
def rk4_step_lorenz96_1(x):
    """Advance the single-scale Lorenz-96 state ``x`` by one RK4 step of size dt.

    ``dt`` and ``F`` are read from the enclosing module scope.
    """
    # NOTE(review): with jnp.roll(v, s)[k] == v[k - s] the advection term
    # below is (v[k-1] - v[k+2]) * v[k+1], whereas L96_2 uses
    # v[k-1] * (v[k+1] - v[k-2]) for the large-scale variables, i.e. the
    # index-mirrored stencil.  Confirm this is intended before changing it;
    # previously generated training data may rely on the current form.
    def tendency(v):
        return (jnp.roll(v, 1) - jnp.roll(v, -2)) * jnp.roll(v, -1) - v + F

    s1 = dt * tendency(x)
    s2 = dt * tendency(x + s1/2)
    s3 = dt * tendency(x + s2/2)
    s4 = dt * tendency(x + s3)
    return x + 1/6 * (s1 + 2 * s2 + 2 * s3 + s4)

# Wrap both steppers in jax.tree_util.Partial so they can be handed to
# jitted functions as ordinary (non-static) arguments.
lorenz1, lorenz2 = Partial(rk4_step_lorenz96_1), Partial(rk4_step_lorenz96_2)

@jit
def ledoit_wolf(P, shrinkage):
    """Shrink a covariance matrix toward a scaled identity.

    Returns ``(1 - a) * P + a * (trace(P) / n) * I`` where ``a`` is
    ``shrinkage`` limited to the interval [0, 1].  The intensity is supplied
    by the caller; it is not estimated from data here.

    Args:
        P: square (n, n) matrix.
        shrinkage: blend weight.  Values outside [0, 1] are clipped, because
            a weight above 1 would give ``P`` a negative coefficient and the
            result would no longer be a valid covariance.

    Returns:
        The shrunk (n, n) matrix.
    """
    weight = jnp.clip(shrinkage, 0.0, 1.0)
    n = P.shape[0]
    target = (jnp.trace(P) / n) * jnp.eye(n)
    return (1 - weight) * P + weight * target

@jit
def ensrf_step(ensemble, y, H, Q, R, key, inflation = 1.2):
    """Run one square-root ensemble Kalman analysis step.

    Args:
        ensemble: forecast ensemble; members are the columns (the member
            count is taken from axis 1).
        y: observation vector.
        H: observation operator matrix.
        Q: matrix added to the sample covariance.
        R: observation-noise covariance.
        key: accepted for interface compatibility; not used in the body.
        inflation: factor applied to the anomalies before the update.

    Returns:
        ``(ensemble, C_pred, updated_ensemble, updated_P)``, all cast to
        float32: the input ensemble, the shrunk forecast covariance, the
        analysis ensemble and the shrunk analysis covariance.
    """
    members = ensemble.shape[1]
    mean = jnp.mean(ensemble, axis=1)
    anomalies = ensemble - mean.reshape((-1, 1))

    # Forecast covariance reported to the caller (computed before inflation).
    forecast_cov = ledoit_wolf((anomalies @ anomalies.T) / (members - 1) + Q, noise_const)

    # Covariance used for the gain is built from the inflated anomalies.
    anomalies = anomalies * inflation
    P = (anomalies @ anomalies.T) / (members - 1) + Q
    gain = P @ H.T @ jnp.linalg.inv(H @ P @ H.T + R)
    mean = mean + gain @ (y - H @ mean)

    # Transform the anomalies with the symmetric square root of (I - gain H).
    transform = sqrt_m(jnp.eye(mean.shape[0]) - gain @ H)
    analysis_anomalies = transform @ anomalies
    analysis_ensemble = mean.reshape((-1, 1)) + analysis_anomalies
    analysis_cov = ledoit_wolf(
        (analysis_anomalies @ analysis_anomalies.T) / (members - 1), noise_const
    )

    return (
        ensemble.astype(jnp.float32),
        forecast_cov.astype(jnp.float32),
        analysis_ensemble.astype(jnp.float32),
        analysis_cov.astype(jnp.float32),
    )

@jit
def sqrt_m(M):
    """Return a matrix square root of ``M`` from its ``eigh`` decomposition.

    The result is ``V diag(sqrt(w)) V^T`` with ``w, V = jnp.linalg.eigh(M)``;
    only the real part is returned.
    """
    evals, evecs = jnp.linalg.eigh(M)
    root = evecs @ jnp.diag(jnp.sqrt(evals)) @ evecs.T
    return root.real

@partial(jit, static_argnums=(3))
def ensrf_steps(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key):
    """Run ``num_steps`` forecast/analysis cycles with ``jax.lax.scan``.

    Each cycle propagates every ensemble member (column) with
    ``state_transition_function`` and then assimilates ``observations[t]``
    through ``ensrf_step``.  ``n_ensemble`` and ``observation_interval`` are
    accepted but not used in the body.

    Returns:
        ``(ensemble_forecast, C_forecast, ensemble_analysis, C_analysis)``,
        each stacked along a leading axis of length ``num_steps``.
    """
    def step_member(member):
        return state_transition_function(member)

    propagate = jax.vmap(step_member, in_axes=1, out_axes=1)

    # One subkey per cycle; ensrf_step receives it as its `key` argument.
    key, *subkeys = random.split(key, num=num_steps + 1)
    subkeys = jnp.array(subkeys)

    def cycle(carry, t):
        members, _ = carry
        predicted = propagate(members)
        _, cov_f, analysed, cov_a = ensrf_step(predicted, observations[t, :], H, Q, R, subkeys[t])
        return (analysed, cov_a), (predicted, cov_f, analysed, cov_a)

    dim = len(Q[0])
    initial_carry = (ensemble_init, jnp.zeros((dim, dim)))
    _, stacked = jax.lax.scan(cycle, initial_carry, jnp.arange(num_steps))
    ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = stacked

    return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

class IncrementCorrectionModel(nn.Module):
    """MLP with three ReLU hidden layers that maps a state to a correction.

    The output has the same trailing dimension as the input, so the result
    can be added to the state it was computed from.
    """
    # Width of each of the three hidden layers.
    features: int

    @nn.compact
    def __call__(self, x):
        # Record the input width before `x` is overwritten by hidden
        # activations; otherwise the output layer would be sized by
        # `features` instead of by the input.
        out_dim = x.shape[-1]
        for _ in range(3):
            x = nn.relu(nn.Dense(self.features)(x))
        return nn.Dense(out_dim)(x)  # Output layer matches input dimension

def create_model(key, input_shape, features):
    """Build an IncrementCorrectionModel and initialise its parameters.

    Args:
        key: PRNG key passed to ``model.init``.
        input_shape: shape of the all-ones dummy input used for initialisation.
        features: hidden width passed to the model.

    Returns:
        ``(model, params)`` where ``params`` is the ``'params'`` collection.
    """
    net = IncrementCorrectionModel(features)
    variables = net.init(key, jnp.ones(input_shape))
    return net, variables['params']

def loss_fn(params, apply_fn, x, y):
    """Mean squared error between ``apply_fn({'params': params}, x)`` and ``y``."""
    residual = apply_fn({'params': params}, x) - y
    return jnp.mean(residual ** 2)

@jax.jit
def train_step(state, forecast_states, increments):
    """Apply one gradient update of the MSE loss and return the new state.

    The loss compares the model output on ``forecast_states`` with
    ``increments``; gradients are taken with respect to ``state.params``.
    """
    grads = jax.grad(
        lambda p: loss_fn(p, state.apply_fn, forecast_states, increments)
    )(state.params)
    return state.apply_gradients(grads=grads)

def train_nn(state, forecast_states, analysis_states, num_epochs=10):
    """Fit the network to the analysis increments for ``num_epochs`` steps.

    The regression target is ``analysis_states - forecast_states``.  Each
    epoch is a single full-batch ``train_step`` followed by a printout of
    the loss evaluated with the updated parameters.

    Returns:
        The train state after the last epoch.
    """
    targets = analysis_states - forecast_states

    for epoch in range(num_epochs):
        state = train_step(state, forecast_states, targets)
        # Re-evaluated after the update, purely for logging.
        current_loss = loss_fn(state.params, state.apply_fn, forecast_states, targets)
        print(f'Epoch {epoch+1}, Loss: {current_loss}')

    return state

# Filter configuration: identity observation operator on the K large-scale
# variables, with the same diagonal covariance for Q and R.
Q = noise_const * jnp.eye(K)
H = jnp.eye(K)
R = jnp.eye(K) * noise_const
n_ensemble = 10
num_steps = time_steps
#the ensrf (ENKF) will run single scale Lorenz on the first K variables
observations = noisy_obs[:,:K]
initial_state = np.zeros((1, K))
# Transposed so that ensemble members are the columns, which is the layout
# ensrf_step expects (it reads the member count from axis 1).
ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

time_steps = ensemble_forecast.shape[0]
time = jnp.arange(time_steps)

forecast_means = jnp.mean(ensemble_forecast, axis=2)
analysis_means = jnp.mean(ensemble_analysis, axis=2)

# Pre-generated forecast/analysis pairs used as NN training data.  The
# archive is opened in a `with` block so the file handle is released once
# both arrays have been read.
file = f"ACM270_Model_Error/long_training_data_noise_{noise_const}_dt_{dt}.npz"
with np.load(file) as data:
    forecast_state = data["forecast"]
    analysis_state = data["analysis"]

key = jax.random.PRNGKey(0)
input_shape = (K,)
features = K
model, params = create_model(key, input_shape, features)

learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

num_epochs = 15
# Train on one slice of the last axis at a time, carrying the optimiser
# state across slices.  NOTE: this reuses the names forecast_means /
# analysis_means, overwriting the filter means computed above.
_,_,n = forecast_state.shape
for i in range(n):
  forecast_means = forecast_state[:,:,i]
  analysis_means = analysis_state[:,:,i]
  state = train_nn(state, forecast_means, analysis_means, num_epochs=num_epochs)

#@partial(jax.jit, static_argnums=(3))
def ensrf_steps_nn(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key, nn_params, model):
    """Same cycle as ``ensrf_steps`` but with an NN-corrected forecast model.

    Every member is first advanced with ``state_transition_function``; the
    network output evaluated on that advanced state is then added to it.
    ``n_ensemble`` and ``observation_interval`` are accepted but not used
    in the body.

    Returns:
        ``(ensemble_forecast, C_forecast, ensemble_analysis, C_analysis)``,
        each stacked along a leading axis of length ``num_steps``.
    """
    def corrected_state_transition(v):
        stepped = state_transition_function(v)
        return stepped + model.apply({'params': nn_params}, stepped)

    propagate = jax.vmap(corrected_state_transition, in_axes=1, out_axes=1)

    # One subkey per cycle; ensrf_step receives it as its `key` argument.
    key, *subkeys = jax.random.split(key, num=num_steps + 1)
    subkeys = jnp.array(subkeys)

    def cycle(carry, t):
        members, _ = carry
        predicted = propagate(members)
        _, cov_f, analysed, cov_a = ensrf_step(predicted, observations[t, :], H, Q, R, subkeys[t])
        return (analysed, cov_a), (predicted, cov_f, analysed, cov_a)

    dim = len(Q[0])
    initial_carry = (ensemble_init, jnp.zeros((dim, dim)))
    _, stacked = jax.lax.scan(cycle, initial_carry, jnp.arange(num_steps))
    ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = stacked

    return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

nn_params = state.params

# Positional arguments shared by the corrected and the uncorrected run.
_filter_args = (lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

ensemble_forecast_NN, C_forecast_NN, ensemble_analysis_NN, C_analysis_NN = ensrf_steps_nn(
    *_filter_args, nn_params, model
)
ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(*_filter_args)

time_steps = ensemble_forecast.shape[0]
time = jnp.arange(time_steps)

true_states = trajectory

# Ensemble means: average over the member axis (last axis).
forecast_means, analysis_means = (
    jnp.mean(members, axis=2) for members in (ensemble_forecast, ensemble_analysis)
)
forecast_means_NN, analysis_means_NN = (
    jnp.mean(members, axis=2) for members in (ensemble_forecast_NN, ensemble_analysis_NN)
)


def _squared_error_ratio(estimate):
    """Sum of squared (observations - estimate) over sum of squared observations."""
    residual = (observations - estimate).flatten()
    reference = observations.flatten()
    return np.dot(residual, residual) / np.dot(reference, reference)


# NOTE: despite the NRMSE_* names these are ratios of squared sums (no
# square root is taken), measured against the noisy observations.
NRMSE_fm = _squared_error_ratio(forecast_means)
NRMSE_as = _squared_error_ratio(analysis_means)
NRMSE_fm_NN = _squared_error_ratio(forecast_means_NN)
NRMSE_as_NN = _squared_error_ratio(analysis_means_NN)
model_error = observations - analysis_means_NN

fontsize = 25
res = 500

# Physical time axis: step index times dt.
time = np.arange(0, time_steps)*dt

plt.figure(figsize=(14, 8))

i = 0  # Index of the state variable to plot
# (data, legend label, line style, marker, colour) for each curve.
for series, label, linestyle, marker, color in (
    (true_states, f"True state {i+1}", "-", "o", 'k'),
    (forecast_means, f"Forecast mean - {i+1}", "--", "x", 'b'),
    (analysis_means, f"Analysis mean - {i+1}", ":", "d", 'green'),
):
    plt.plot(time, series[:, i], label=label, linestyle=linestyle, marker=marker,
             color=color, linewidth=2.5, markersize=7)
plt.scatter(time, observations[:, i], label=f"Observations - {i+1}", marker="s", color='gray')

plt.xlabel("Time", fontsize=fontsize)
plt.ylabel("State values", fontsize=fontsize)
plt.title("EnKF prediction with True State and Observations", fontsize=fontsize)
plt.legend(fontsize=fontsize/1.5)
plt.grid(True)
plt.xticks(fontsize=fontsize / 1.25)
plt.yticks(fontsize=fontsize / 1.25)
plt.savefig(f'uncorrected_noise_{noise_const}_dt_{dt}.png', bbox_inches="tight", dpi=res)


plt.figure(figsize=(14, 8))

i = 0  # Index of the state variable to plot
# Same layout as the uncorrected figure, using the NN-corrected filter means.
for series, label, linestyle, marker, color in (
    (true_states, f"True state {i+1}", "-", "o", 'k'),
    (forecast_means_NN, f"Forecast mean - {i+1}", "--", "x", 'b'),
    (analysis_means_NN, f"Analysis mean - {i+1}", ":", "d", 'green'),
):
    plt.plot(time, series[:, i], label=label, linestyle=linestyle, marker=marker,
             color=color, linewidth=2.5, markersize=7)
plt.scatter(time, observations[:, i], label=f"Observations {i+1}", marker="s", color='gray')

plt.xlabel("Time", fontsize=fontsize)
plt.ylabel("State values", fontsize=fontsize)
plt.title("Corrected EnKF prediction with True State and Observations", fontsize=fontsize)
plt.legend(fontsize=fontsize/1.5)
plt.grid(True)
plt.xticks(fontsize=fontsize / 1.25)
plt.yticks(fontsize=fontsize / 1.25)
plt.savefig(f'corrected_noise_{noise_const}_dt_{dt}.png', bbox_inches="tight", dpi=res)

plt.figure(figsize=(14, 8))

i = 0  # Index of the state variable to plot
# Compare the uncorrected filter's analysis mean with the NN-corrected
# forecast mean.  The legend labels name the series that is actually
# plotted (previously they read "Forecast mean" / "Analysis mean", which
# did not match the data passed in).
plt.plot(time, analysis_means[:, i], label=f"EnKF analysis mean - {i+1}", linestyle="--", marker="x",color='b', linewidth=2.5,markersize=7)
plt.plot(time, forecast_means_NN[:, i], label=f"NN-corrected forecast mean - {i+1}", linestyle=":", marker="d",color='green', linewidth=2.5,markersize=7)

plt.xlabel("Time",fontsize=fontsize)
plt.ylabel("State values",fontsize=fontsize)
plt.title("EnKF analysis and NN prediction",fontsize=fontsize)
plt.legend(fontsize=fontsize/1.5)
plt.grid(True)
plt.xticks(fontsize=fontsize / 1.25)
plt.yticks(fontsize=fontsize / 1.25)
plt.savefig(f'NN_noise_{noise_const}_dt_{dt}.png', bbox_inches="tight", dpi=res)


# MSE vs. dt/noise

In [None]:
%%capture

%load_ext autoreload
%autoreload 2

!pip install filterpy

import numpy as np
import pandas as pd
import os
import sys
from tqdm import tqdm
from filterpy.kalman import EnsembleKalmanFilter as EnKF
import matplotlib.pyplot as plt
import scipy.integrate as integrate
from scipy.interpolate import griddata
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from jax import jit, random, lax
from jax.scipy.linalg import sqrtm
from functools import partial
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from scipy import integrate
from jax.tree_util import Partial
from functools import partial

# Always start from a fresh clone: remove any existing checkout, then clone
# the project repository again (the conditional clone is left disabled).
! rm -rf ACM270_Model_Error
# if os.path.isdir('./ACM270_Model_Error/') == False:
! git clone https://github.com/sreemanti-dey/ACM270_Model_Error.git
sys.path.append('./ACM270_Model_Error/')

# NOTE: a `!` command runs in its own subshell, so this `cd` does not change
# the notebook's working directory; the data paths used below are written
# relative to the original directory.
! cd ACM270_Model_Error

"""# True Trajectory and Noisy Observations with 2-Scale Lorenz96"""

# Values swept over: noise variance (rows) and time step (columns).
noise_consts = [0.1, 0.5, 1, 5]
dts = [0.01, 0.05, 0.1]

# One result table per metric, indexed [noise index, dt index].
_grid_shape = (len(noise_consts), len(dts))
full_NRMSE_as, full_NRMSE_as_NN, full_NRMSE_fm, full_NRMSE_fm_NN = (
    np.zeros(_grid_shape) for _ in range(4)
)

for idx_n, noise_const in enumerate(noise_consts):
  for idx_dt, dt in enumerate(dts):
    # NOTE: the jitted functions below read `dt` / `noise_const` from the
    # enclosing scope, and jit fixes such values when it traces a function.
    # They are therefore (re)defined on every iteration so that each
    # (noise, dt) pair gets freshly traced code.

    # parameters for two-scale L96
    K = 36                      # number of large scale vars
    J = 10                      # number of small scale vars per large var
    h = 0.25                   # part of coupling
    c = 10                     # part of coupling
    b = 10                     # part of coupling
    F = 10                     # forcing
    # dt = 0.1                  # time step
    # noise_const = 0.1
    # 200 steps at the reference step 0.05, rescaled so that the simulated
    # time span is the same for every dt.
    num_steps = 200
    num_steps = int(num_steps * (0.05 / dt))
    time_steps = num_steps

    @jit
    def L96_2(xy):
        """Return the time derivative of the two-scale Lorenz-96 state.

        The first K entries of ``xy`` are the large-scale variables; the
        remaining K * J entries are the small-scale variables, one row of J
        per large-scale variable.  Small-scale neighbours wrap around
        within each row.  ``jnp.roll`` is used instead of an element-wise
        loop so that K * (J + 1) scatter updates are not unrolled into the
        traced program on every re-definition.
        """
        x = xy[0:K]
        y = xy[K:].reshape(K, J)
        coupling = h * c / b

        # jnp.roll(v, s)[i] == v[i - s]
        dx = (-jnp.roll(x, 1) * (jnp.roll(x, 2) - jnp.roll(x, -1))
              - x + F - coupling * jnp.sum(y, axis=1))
        dy = (-c * b * jnp.roll(y, -1, axis=1)
              * (jnp.roll(y, -2, axis=1) - jnp.roll(y, 1, axis=1))
              - c * y + coupling * x[:, None])

        return jnp.concatenate([dx, dy.flatten()])

    @jit
    def rk4_step_lorenz96_2(x):
        """Advance the two-scale state by one classical RK4 step of size dt."""
        f = lambda y: L96_2(y)
        k1 = dt * f(x)
        k2 = dt * f(x + k1/2)
        k3 = dt * f(x + k2/2)
        k4 = dt * f(x + k3)
        return x + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4)

    keys = iter(random.split(random.PRNGKey(0), 10))
    key = random.PRNGKey(0)
    # Deterministic initial condition: a linear ramp on the large-scale
    # variables (a random start is left commented out), zeros on the small
    # scale.
    x0 = np.linspace(0, 1, K)  # np.random.rand(K)
    y0 = np.zeros((K, J))
    xy0 = np.concatenate([x0, y0.flatten()])

    # Warm-up call; the result is discarded.
    rk4_step_lorenz96_2(xy0)

    trajectory = jnp.zeros((time_steps, len(xy0)))
    noisy_obs = jnp.zeros((time_steps, len(xy0)))
    noise_covar = jnp.eye(len(xy0)) * noise_const

    xy = xy0
    for i in tqdm(range(time_steps)):
        xy = rk4_step_lorenz96_2(xy)
        trajectory = trajectory.at[i].set(xy)
        # Split first and sample with the fresh subkey: a JAX key must not
        # be used both to draw numbers and to derive the next key.
        key, noise_key = random.split(key)
        # Add noise to the observations
        noise = random.multivariate_normal(noise_key, jnp.zeros(len(xy0)), noise_covar)
        noisy_obs = noisy_obs.at[i].set(xy + noise)

    """## True Trajectory

    ## Noisy Observations
    """

    true_states = trajectory

    """# EnKF for 1-Scale Assimilation"""

    @jit
    def rk4_step_lorenz96_1(x):
        """Advance the single-scale Lorenz-96 state by one RK4 step of size dt."""
        # NOTE(review): with jnp.roll(v, s)[k] == v[k - s] the advection
        # term is (v[k-1] - v[k+2]) * v[k+1], the index-mirror of the
        # large-scale stencil in L96_2.  Left unchanged because previously
        # generated training data may rely on it; confirm the intent.
        f = lambda y: (jnp.roll(y, 1) - jnp.roll(y, -2)) * jnp.roll(y, -1) - y + F
        k1 = dt * f(x)
        k2 = dt * f(x + k1/2)
        k3 = dt * f(x + k2/2)
        k4 = dt * f(x + k3)
        return x + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4)

    # Partial wrappers let the steppers be passed to jitted functions as
    # ordinary arguments.
    lorenz1 = Partial(rk4_step_lorenz96_1)
    lorenz2 = Partial(rk4_step_lorenz96_2)

    @jit
    def ledoit_wolf(P, shrinkage):
        """Blend P with (trace(P) / n) * I using a caller-supplied weight.

        The weight is clipped to [0, 1]: callers pass ``noise_const``, which
        reaches 1 and 5 in this sweep, and a weight above 1 would give P a
        negative coefficient.
        """
        weight = jnp.clip(shrinkage, 0.0, 1.0)
        n = P.shape[0]
        target = (jnp.trace(P) / n) * jnp.eye(n)
        return (1 - weight) * P + weight * target

    @jit
    def ensrf_step(ensemble, y, H, Q, R, key, inflation = 1.2):
        """Run one square-root ensemble Kalman analysis step.

        ``ensemble`` holds members as columns.  ``key`` is accepted but not
        used.  Returns the input ensemble, the shrunk forecast covariance,
        the analysis ensemble and the shrunk analysis covariance, all cast
        to float32.
        """
        n_ensemble = ensemble.shape[1]
        x_m = jnp.mean(ensemble, axis=1)
        A = ensemble - x_m.reshape((-1, 1))
        # Forecast covariance reported to the caller (before inflation).
        C_pred = (A @ A.T) / (n_ensemble - 1) + Q
        C_pred = ledoit_wolf(C_pred, noise_const)
        A = A * inflation
        P =  (A @ A.T) / (n_ensemble - 1) + Q
        # Named `gain` so the Kalman gain does not shadow the global K.
        gain = P @ H.T @ jnp.linalg.inv(H @ P @ H.T + R)
        x_m += gain @ (y - H @ x_m)
        # Anomalies are transformed by the square root of (I - gain H).
        M_sqrt = sqrt_m(jnp.eye(x_m.shape[0]) - gain @ H)
        updated_A = M_sqrt @ A
        updated_ensemble = x_m.reshape((-1, 1)) + updated_A
        updated_P = (updated_A @ updated_A.T) / (n_ensemble - 1)
        updated_P = ledoit_wolf(updated_P, noise_const)

        ensemble = ensemble.astype(jnp.float32)
        C_pred = C_pred.astype(jnp.float32)
        updated_ensemble = updated_ensemble.astype(jnp.float32)
        updated_P = updated_P.astype(jnp.float32)
        return ensemble, C_pred, updated_ensemble, updated_P

    @jit
    def sqrt_m(M):
        """Matrix square root V diag(sqrt(w)) V^T from ``jnp.linalg.eigh``."""
        eigenvalues, eigenvectors = jnp.linalg.eigh(M)
        sqrt_eigenvalues = jnp.sqrt(eigenvalues)
        M_sqrt = eigenvectors @ jnp.diag(sqrt_eigenvalues) @ eigenvectors.T
        return M_sqrt.real

    @partial(jit, static_argnums=(3))
    def ensrf_steps(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key):
        """Run ``num_steps`` forecast/analysis cycles with ``jax.lax.scan``.

        ``n_ensemble`` and ``observation_interval`` are accepted but not
        used.  Returns forecast ensembles, forecast covariances, analysis
        ensembles and analysis covariances stacked over the steps.
        """
        model_vmap = jax.vmap(lambda v: state_transition_function(v), in_axes=1, out_axes=1)
        key, *subkeys = random.split(key, num=num_steps + 1)
        subkeys = jnp.array(subkeys)

        def inner(carry, t):
            ensemble, covar = carry
            ensemble_predicted = model_vmap(ensemble)
            _, C_forecast, ensemble_analysis, C_analysis = ensrf_step(ensemble_predicted, observations[t, :], H, Q, R, subkeys[t])
            return (ensemble_analysis, C_analysis), (ensemble_predicted, C_forecast, ensemble_analysis, C_analysis)

        n = len(Q[0])
        covariance_init = jnp.zeros((n, n))
        _, (ensemble_forecast, C_forecast, ensemble_analysis, C_analysis) = jax.lax.scan(inner, (ensemble_init, covariance_init), jnp.arange(num_steps))

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    class IncrementCorrectionModel(nn.Module):
        """MLP with three ReLU hidden layers; output width equals input width."""
        # Width of each hidden layer.
        features: int

        @nn.compact
        def __call__(self, x):
            # Record the input width before `x` is overwritten by hidden
            # activations, so the output layer really matches the input.
            out_dim = x.shape[-1]
            for _ in range(3):
                x = nn.relu(nn.Dense(self.features)(x))
            return nn.Dense(out_dim)(x)  # Output layer matches input dimension

    def create_model(key, input_shape, features):
        """Build the correction network and initialise it on an all-ones input."""
        model = IncrementCorrectionModel(features)
        params = model.init(key, jnp.ones(input_shape))['params']
        return model, params

    def loss_fn(params, apply_fn, x, y):
        """Mean squared error between ``apply_fn({'params': params}, x)`` and ``y``."""
        predictions = apply_fn({'params': params}, x)
        loss = jnp.mean((predictions - y) ** 2)
        return loss

    @jax.jit
    def train_step(state, forecast_states, increments):
        """Apply one gradient update of the MSE loss and return the new state."""
        def loss_fn_wrapper(params):
            return loss_fn(params, state.apply_fn, forecast_states, increments)

        grads = jax.grad(loss_fn_wrapper)(state.params)
        state = state.apply_gradients(grads=grads)
        return state

    def train_nn(state, forecast_states, analysis_states, num_epochs=10):
        """Fit the network to ``analysis_states - forecast_states``.

        One full-batch update per epoch; the loss printed is evaluated with
        the updated parameters.
        """
        increments = analysis_states - forecast_states

        for epoch in range(num_epochs):
            state = train_step(state, forecast_states, increments)
            current_loss = loss_fn(state.params, state.apply_fn, forecast_states, increments)
            print(f'Epoch {epoch+1}, Loss: {current_loss}')

        return state

    # Filter configuration: identity observation operator on the K
    # large-scale variables, same diagonal covariance for Q and R.
    Q = noise_const * jnp.eye(K)
    H = jnp.eye(K)
    R = jnp.eye(K) * noise_const
    n_ensemble = 10
    num_steps = time_steps
    #the ensrf (ENKF) will run single scale Lorenz on the first K variables
    observations = noisy_obs[:,:K]
    initial_state = np.zeros((1, K))
    # Transposed so that members are the columns, as ensrf_step expects.
    ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
    ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

    time_steps = ensemble_forecast.shape[0]
    time = jnp.arange(time_steps)

    forecast_means = jnp.mean(ensemble_forecast, axis=2)
    analysis_means = jnp.mean(ensemble_analysis, axis=2)

    # Pre-generated forecast/analysis pairs used as NN training data.  The
    # archive is opened in a `with` block so its file handle is released.
    file = f"ACM270_Model_Error/long_training_data_noise_{noise_const}_dt_{dt}.npz"
    with np.load(file) as data:
      forecast_state = data["forecast"]
      analysis_state = data["analysis"]

    key = jax.random.PRNGKey(0)
    input_shape = (K,)
    features = K
    model, params = create_model(key, input_shape, features)

    learning_rate = 1e-3
    optimizer = optax.adam(learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    num_epochs = 15
    # Train on one slice of the last axis at a time, carrying the optimiser
    # state across slices.
    _,_,n = forecast_state.shape
    for i in range(n):
      forecast_means = forecast_state[:,:,i]
      analysis_means = analysis_state[:,:,i]
      state = train_nn(state, forecast_means, analysis_means, num_epochs=num_epochs)

    #@partial(jax.jit, static_argnums=(3))
    def ensrf_steps_nn(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key, nn_params, model):
        """Same cycle as ``ensrf_steps`` with the NN correction added to each forecast."""
        def corrected_state_transition(v):
            transition = state_transition_function(v)
            correction = model.apply({'params': nn_params}, transition)
            return transition + correction

        model_vmap = jax.vmap(corrected_state_transition, in_axes=1, out_axes=1)
        key, *subkeys = jax.random.split(key, num=num_steps + 1)
        subkeys = jnp.array(subkeys)

        def inner(carry, t):
            ensemble, covar = carry
            ensemble_predicted = model_vmap(ensemble)
            _, C_forecast, ensemble_analysis, C_analysis = ensrf_step(ensemble_predicted, observations[t, :], H, Q, R, subkeys[t])
            return (ensemble_analysis, C_analysis), (ensemble_predicted, C_forecast, ensemble_analysis, C_analysis)

        n = len(Q[0])
        covariance_init = jnp.zeros((n, n))
        _, (ensemble_forecast, C_forecast, ensemble_analysis, C_analysis) = jax.lax.scan(inner, (ensemble_init, covariance_init), jnp.arange(num_steps))

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    # Held-out test trajectories.  For dt == 0.01 the file is read from the
    # working directory instead of the cloned repository.
    if dt == 0.01:
      file = f"testing_data_noise_{noise_const}_dt_{dt}.npz"
    else:
      file = f"ACM270_Model_Error/testing_data_noise_{noise_const}_dt_{dt}.npz"
    with np.load(file) as data:
      noisy_trajectory = data["noisy_trajectory"]
      # The truth is stored under one of two key names depending on the
      # file.  Only a missing key triggers the fallback; any other error
      # is allowed to surface.
      try:
        trajectory=data['true_trajectory']
      except KeyError:
        trajectory=data['trajectory']

    noisy_trajectory =  jnp.asarray(noisy_trajectory)
    trajectory = jnp.asarray(trajectory)

    NRMSE_fm = []
    NRMSE_as = []
    NRMSE_fm_NN = []
    NRMSE_as_NN = []

    _,_,n = noisy_trajectory.shape

    # NOTE(review): the slicing below presumably treats the last axis as the
    # index of independent test trajectories - confirm against the data files.
    for i in range(n):
      observations = noisy_trajectory[:,:,i]

      nn_params = state.params
      ensemble_forecast_NN, C_forecast_NN, ensemble_analysis_NN, C_analysis_NN = ensrf_steps_nn(
          lorenz1,
          n_ensemble,
          ensemble_init,
          time_steps,
          observations,
          1,
          H,
          Q,
          R,
          key,
          nn_params,
          model
      )

      ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(
          lorenz1,
          n_ensemble,
          ensemble_init,
          time_steps,
          observations,
          1,
          H,
          Q,
          R,
          key
      )

      time_steps = ensemble_forecast.shape[0]
      time = jnp.arange(time_steps)

      true_states = trajectory[:,:,i]
      forecast_means = jnp.mean(ensemble_forecast, axis=2)
      analysis_means = jnp.mean(ensemble_analysis, axis=2)

      forecast_means_NN = jnp.mean(ensemble_forecast_NN, axis=2)
      analysis_means_NN = jnp.mean(ensemble_analysis_NN, axis=2)

      # NOTE: despite the NRMSE_* names these are ratios of squared sums
      # (no square root is taken), measured against the true states.
      model_error = true_states - forecast_means
      NRMSE_fm.append(np.dot(model_error.flatten(), model_error.flatten()) / np.dot(true_states.flatten(), true_states.flatten()))
      model_error = true_states - analysis_means
      NRMSE_as.append(np.dot(model_error.flatten(), model_error.flatten()) / np.dot(true_states.flatten(), true_states.flatten()))
      model_error = true_states - forecast_means_NN
      NRMSE_fm_NN.append(np.dot(model_error.flatten(), model_error.flatten()) / np.dot(true_states.flatten(), true_states.flatten()))
      model_error = true_states - analysis_means_NN
      NRMSE_as_NN.append(np.dot(model_error.flatten(), model_error.flatten()) / np.dot(true_states.flatten(), true_states.flatten()))

    # Average over test trajectories and store in the [noise, dt] tables.
    full_NRMSE_as[idx_n,idx_dt] = np.mean(NRMSE_as)
    full_NRMSE_as_NN[idx_n,idx_dt] = np.mean(NRMSE_as_NN)
    full_NRMSE_fm[idx_n,idx_dt] = np.mean(NRMSE_fm)
    full_NRMSE_fm_NN[idx_n,idx_dt] = np.mean(NRMSE_fm_NN)

fontsize = 20
res = 500
cols = ['k','b','orange','gray']
plt.figure(figsize=(8, 8))
# One colour per dt: solid line for the plain filter, dashed for the
# NN-corrected one.
for i, _ in enumerate(dts):
  shared = dict(color=cols[i], linewidth=2.5, markersize=7)
  plt.plot(noise_consts, full_NRMSE_as[:, i], label=f"EnKF: dt={dts[i]}", linestyle="-", marker="o", **shared)
  plt.plot(noise_consts, full_NRMSE_as_NN[:, i], label=f"NN: dt={dts[i]}", linestyle="--", marker="x", **shared)

tick_size = fontsize / 1.25
plt.xlabel("Noise constant", fontsize=fontsize)
plt.ylabel("Total analysis NRMSE", fontsize=fontsize)
plt.title("Total analysis NRMSE vs noise levels", fontsize=fontsize)
# red_patch = mpatches.Patch(color='red', label='The red data')
# plt.legend(handles=[red_patch])
plt.legend(fontsize=fontsize/1.25)
plt.grid(True)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.savefig("analysis_MSE_vs_dt.png", bbox_inches="tight", dpi=res)

# np.savez('MSE_results.npz', full_NRMSE_as=full_NRMSE_as,full_NRMSE_as_NN=full_NRMSE_as_NN, full_NRMSE_fm=full_NRMSE_fm, full_NRMSE_fm_NN=full_NRMSE_fm_NN)



# MSE vs. time

In [19]:
%%capture

%load_ext autoreload
%autoreload 2

!pip install filterpy

import numpy as np
import pandas as pd
import os
import sys
from tqdm import tqdm
from filterpy.kalman import EnsembleKalmanFilter as EnKF
import scipy.integrate as integrate
from scipy.interpolate import griddata
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from jax import jit, random, lax
from jax.scipy.linalg import sqrtm
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from scipy import integrate
from jax.tree_util import Partial
from functools import partial

! rm -rf ACM270_Model_Error
# if os.path.isdir('./ACM270_Model_Error/') == False:
! git clone https://github.com/sreemanti-dey/ACM270_Model_Error.git
sys.path.append('./ACM270_Model_Error/')

! cd ACM270_Model_Error
# Sweep grid: every (noise level, integration step) pair below gets its own
# data generation, training and evaluation pass.
noise_consts = [0.1, 0.5]
dts = [0.05, 0.1]

for idx_n, noise_const in enumerate(noise_consts):
  for idx_dt, dt in enumerate(dts):

    """# True Trajectory and Noisy Observations with 2-Scale Lorenz96"""
    # parameters for two-scale L96
    K = 36                      # number of large scale vars
    J = 10                      # number of small scale vars per large var
    h = 0.25                   # enters the coupling coefficient h * c / b
    c = 10                     # enters the coupling coefficient and the fast-variable terms
    b = 10                     # enters the coupling coefficient and the fast-variable terms
    F = 10                     # forcing
    # dt and noise_const are supplied by the two enclosing sweep loops.
    # Rescale the step count so num_steps * dt stays at 200 * 0.05 for every dt.
    num_steps = 200
    num_steps = int(num_steps * (0.05 / dt))
    time_steps = num_steps

    @jit
    def L96_2(xy):
        """Return the time derivative of the stacked two-scale Lorenz-96 state.

        Args:
            xy: 1-D array of length K + K*J; the first K entries are the slow
                variables, the remaining K*J entries are the fast variables,
                reshaped row-major to (K, J).

        Returns:
            1-D array of the same length holding [dx, dy.flatten()].

        The slow variables are cyclic over K; each row of fast variables is
        cyclic over its own J entries (indices wrap within the row).
        """
        x = xy[0:K]
        y = xy[K:].reshape(K, J)
        coupling = h * c / b

        # jnp.roll(v, s)[k] == v[(k - s) % len(v)], so roll(+1) is the left
        # neighbour and roll(-1) the right neighbour.  This replaces the
        # K*J element-wise scatter updates with a handful of array ops.
        dx = (-1 * jnp.roll(x, 1) * (jnp.roll(x, 2) - jnp.roll(x, -1))
              - x + F - coupling * jnp.sum(y, axis=1))

        dy = (-1 * c * b * jnp.roll(y, -1, axis=1)
              * (jnp.roll(y, -2, axis=1) - jnp.roll(y, 1, axis=1))
              - c * y + coupling * x[:, None])

        return jnp.concatenate([dx, dy.flatten()])

    @jit
    def rk4_step_lorenz96_2(x):
        """Advance the two-scale state `x` by one classical RK4 step of size `dt`.

        The tendency comes from `L96_2`; `dt` is read from the enclosing scope.
        """
        stage1 = dt * L96_2(x)
        stage2 = dt * L96_2(x + stage1/2)
        stage3 = dt * L96_2(x + stage2/2)
        stage4 = dt * L96_2(x + stage3)
        weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
        return x + 1/6 * weighted_sum

    # NOTE(review): `keys` is not consumed anywhere in the visible code.
    keys = iter(random.split(random.PRNGKey(0), 10))
    key = random.PRNGKey(0)
    # deterministic ramp (0..1) as starting point for large scale, 0 for small scale
    x0 = np.linspace(0, 1, K)  # np.random.rand(K)
    y0 = np.zeros((K, J))
    xy0 = np.concatenate([x0, y0.flatten()])

    # Result is discarded: this call only forces the jitted stepper to be
    # traced/compiled before the timed loop.
    rk4_step_lorenz96_2(xy0)

    trajectory = jnp.zeros((time_steps, len(xy0)))
    noisy_obs = jnp.zeros((time_steps, len(xy0)))
    noise_covar = jnp.eye(len(xy0)) * noise_const

    xy = xy0
    for i in tqdm(range(time_steps)):
        xy = rk4_step_lorenz96_2(xy)
        trajectory = trajectory.at[i].set(xy)
        # Split first and sample from the fresh subkey.  Sampling from `key`
        # and then splitting that same `key` reuses it, which JAX warns
        # against because the outputs are no longer independent.
        key, noise_key = random.split(key)
        noise = random.multivariate_normal(noise_key, jnp.zeros(len(xy0)), noise_covar)
        noisy_obs = noisy_obs.at[i].set(xy + noise)

    true_states = trajectory

    @jit
    def rk4_step_lorenz96_1(x):
        """Advance the single-scale Lorenz-96 state `x` by one RK4 step of size `dt`.

        `dt` and the forcing `F` are read from the enclosing scope.
        """
        def tendency(state):
            # Cyclic neighbour products written with jnp.roll, then damping and forcing.
            return (jnp.roll(state, 1) - jnp.roll(state, -2)) * jnp.roll(state, -1) - state + F

        stage1 = dt * tendency(x)
        stage2 = dt * tendency(x + stage1/2)
        stage3 = dt * tendency(x + stage2/2)
        stage4 = dt * tendency(x + stage3)
        weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
        return x + 1/6 * weighted_sum

    # Wrap the steppers in jax.tree_util.Partial so they can be passed as
    # regular arguments into the jitted filter drivers.
    # NOTE(review): lorenz2 is not used in the visible code.
    lorenz1 = Partial(rk4_step_lorenz96_1)
    lorenz2 = Partial(rk4_step_lorenz96_2)

    @jit
    def ledoit_wolf(P, shrinkage):
        """Blend `P` with a scaled identity using a fixed weight.

        Returns (1 - shrinkage) * P + shrinkage * (trace(P) / n) * I for an
        (n, n) matrix `P`.  The weight is supplied by the caller; it is not
        estimated from data.
        """
        dim = P.shape[0]
        identity_weight = shrinkage * jnp.trace(P) / dim
        return (1 - shrinkage) * P + identity_weight * jnp.eye(dim)

    @jit
    def ensrf_step(ensemble, y, H, Q, R, key, inflation = 1.2):
        """Run one analysis update on a forecast ensemble.

        Args:
            ensemble: forecast ensemble with members along axis 1.
            y: observation vector for this step.
            H, Q, R: observation operator, additive model covariance and
                observation covariance.
            key: accepted for interface compatibility; not used here.
            inflation: factor applied to the anomalies before the update.

        Returns:
            (forecast ensemble, shrunk forecast covariance, analysis ensemble,
            shrunk analysis covariance), all cast to float32.
        """
        dof = ensemble.shape[1] - 1
        forecast_mean = jnp.mean(ensemble, axis=1)
        anomalies = ensemble - forecast_mean.reshape((-1, 1))

        # Reported forecast covariance uses the un-inflated anomalies.
        C_pred = ledoit_wolf((anomalies @ anomalies.T) / dof + Q, noise_const)

        # The gain is built from the inflated anomalies.
        inflated = anomalies * inflation
        P = (inflated @ inflated.T) / dof + Q
        gain = P @ H.T @ jnp.linalg.inv(H @ P @ H.T + R)
        analysis_mean = forecast_mean + gain @ (y - H @ forecast_mean)

        # Anomalies are transformed by the matrix square root of (I - gain H).
        transform = sqrt_m(jnp.eye(analysis_mean.shape[0]) - gain @ H)
        analysis_anomalies = transform @ inflated
        updated_ensemble = analysis_mean.reshape((-1, 1)) + analysis_anomalies
        updated_P = ledoit_wolf((analysis_anomalies @ analysis_anomalies.T) / dof, noise_const)

        outputs = (ensemble, C_pred, updated_ensemble, updated_P)
        return tuple(item.astype(jnp.float32) for item in outputs)

    @jit
    def sqrt_m(M):
        """Return the symmetric square root of `M` via an eigendecomposition.

        `M` is passed to `jnp.linalg.eigh`, so it is treated as symmetric.
        Eigenvalues are clamped at zero before the square root: a value that
        is slightly negative through round-off would otherwise become NaN and
        contaminate every later filter step.
        """
        eigenvalues, eigenvectors = jnp.linalg.eigh(M)
        sqrt_eigenvalues = jnp.sqrt(jnp.maximum(eigenvalues, 0))
        # Scaling the eigenvector columns equals V @ diag(s) without the extra matrix.
        M_sqrt = (eigenvectors * sqrt_eigenvalues) @ eigenvectors.T
        return M_sqrt.real

    @partial(jit, static_argnums=(3))
    def ensrf_steps(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key):
        """Run forecast/analysis cycles for `num_steps` steps with `lax.scan`.

        Each step propagates every ensemble member (axis 1) through
        `state_transition_function` and assimilates `observations[t]` with
        `ensrf_step`.  `n_ensemble` and `observation_interval` are accepted
        but not read.  `num_steps` is a static jit argument.

        Returns:
            Per-step stacks (leading axis = time) of the forecast ensemble,
            forecast covariance, analysis ensemble and analysis covariance.
        """
        def advance_member(member):
            return state_transition_function(member)

        model_vmap = jax.vmap(advance_member, in_axes=1, out_axes=1)
        # Row 0 of the split is set aside; rows 1..num_steps are the per-step keys.
        step_keys = random.split(key, num=num_steps + 1)[1:]

        def inner(carry, t):
            current_ensemble, _ = carry
            predicted = model_vmap(current_ensemble)
            _, cov_forecast, analysed, cov_analysis = ensrf_step(predicted, observations[t, :], H, Q, R, step_keys[t])
            return (analysed, cov_analysis), (predicted, cov_forecast, analysed, cov_analysis)

        state_dim = Q.shape[1]
        initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
        _, history = jax.lax.scan(inner, initial_carry, jnp.arange(num_steps))
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = history

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    class IncrementCorrectionModel(nn.Module):
        """MLP with three ReLU hidden layers that maps a state to a same-sized increment."""
        features: int  # width of every hidden Dense layer

        @nn.compact
        def __call__(self, x):
            # Record the input width up front.  Reading x.shape[-1] after the
            # hidden layers would return `features`, so the output would only
            # match the input when features happens to equal the input size.
            output_dim = x.shape[-1]
            for _ in range(3):
                x = nn.relu(nn.Dense(self.features)(x))
            return nn.Dense(output_dim)(x)  # Output layer matches input dimension

    def create_model(key, input_shape, features):
        """Build an IncrementCorrectionModel and initialise its parameters.

        Initialisation uses an all-ones dummy input of shape `input_shape`.
        Returns the module together with its 'params' collection.
        """
        model = IncrementCorrectionModel(features)
        dummy_input = jnp.ones(input_shape)
        variables = model.init(key, dummy_input)
        return model, variables['params']

    def loss_fn(params, apply_fn, x, y):
        """Mean squared error between `apply_fn({'params': params}, x)` and `y`."""
        residual = apply_fn({'params': params}, x) - y
        return jnp.mean(residual ** 2)

    @jax.jit
    def train_step(state, forecast_states, increments):
        """Apply one full-batch gradient update and return the new train state.

        The loss is the MSE between the network output on `forecast_states`
        and the target `increments`.
        """
        objective = lambda candidate: loss_fn(candidate, state.apply_fn, forecast_states, increments)
        gradients = jax.grad(objective)(state.params)
        return state.apply_gradients(grads=gradients)

    def train_nn(state, forecast_states, analysis_states, num_epochs=10):
        """Fit the network to predict (analysis - forecast) from the forecast.

        Runs `num_epochs` full-batch updates, printing the loss after each,
        and returns the updated train state.
        """
        targets = analysis_states - forecast_states

        epoch = 0
        while epoch < num_epochs:
            state = train_step(state, forecast_states, targets)
            # Loss is re-evaluated with the updated parameters for reporting only.
            current_loss = loss_fn(state.params, state.apply_fn, forecast_states, targets)
            print(f'Epoch {epoch+1}, Loss: {current_loss}')
            epoch += 1

        return state

    # Filter configuration: model and observation covariances are both
    # noise_const * I, and every one of the K slow variables is observed (H = I).
    Q = noise_const * jnp.eye(K)
    H = jnp.eye(K)
    R = jnp.eye(K) * noise_const
    n_ensemble = 10
    num_steps = time_steps
    # The filter runs the single-scale Lorenz model on the first K variables only.
    observations = noisy_obs[:,:K]
    initial_state = np.zeros((1, K))
    # Draw n_ensemble members around zero with covariance Q; transpose so the
    # members run along axis 1, as ensrf_step expects.
    ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
    ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

    time_steps = ensemble_forecast.shape[0]
    time = jnp.arange(time_steps)

    # Average over the ensemble axis (axis 2 of the time-stacked outputs).
    forecast_means = jnp.mean(ensemble_forecast, axis=2)
    analysis_means = jnp.mean(ensemble_analysis, axis=2)

    # Pick the pre-computed training archive for this (noise, dt) pair.
    # NOTE(review): the visible sweep only uses dt in [0.05, 0.1], so the
    # dt == 0.01 branch is never taken here.
    if dt == 0.01:
      file = f"long_training_data_noise_{noise_const}_dt_{dt}.npz"
    else:
      file = f"ACM270_Model_Error/long_training_data_noise_{noise_const}_dt_{dt}.npz"
    data = np.load(file)
    forecast_state = data["forecast"]
    analysis_state = data["analysis"]

    # Fresh network and optimiser for every sweep point (fixed seed 0).
    key = jax.random.PRNGKey(0)
    input_shape = (K,)
    features = K
    model, params = create_model(key, input_shape, features)

    learning_rate = 1e-3
    optimizer = optax.adam(learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    num_epochs = 15
    # The last axis of the stored arrays indexes the training trajectory.
    _,_,n = forecast_state.shape
    for i in range(n):
      # These overwrite the forecast/analysis means computed from the filter run above.
      forecast_means = forecast_state[:,:,i]
      analysis_means = analysis_state[:,:,i]
      state = train_nn(state, forecast_means, analysis_means, num_epochs=num_epochs)

    # Not jitted: the decorator used on ensrf_steps is left disabled for this variant.
    #@partial(jax.jit, static_argnums=(3))
    def ensrf_steps_nn(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key, nn_params, model):
        """Same cycling as `ensrf_steps`, with a learned correction on each forecast.

        Every member is propagated by `state_transition_function`, then the
        output of `model.apply({'params': nn_params}, ...)` evaluated on that
        propagated state is added to it before assimilation.  `n_ensemble` and
        `observation_interval` are accepted but not read.

        Returns:
            Per-step stacks (leading axis = time) of the forecast ensemble,
            forecast covariance, analysis ensemble and analysis covariance.
        """
        def corrected_state_transition(member):
            propagated = state_transition_function(member)
            return propagated + model.apply({'params': nn_params}, propagated)

        model_vmap = jax.vmap(corrected_state_transition, in_axes=1, out_axes=1)
        # Row 0 of the split is set aside; rows 1..num_steps are the per-step keys.
        step_keys = jax.random.split(key, num=num_steps + 1)[1:]

        def inner(carry, t):
            current_ensemble, _ = carry
            predicted = model_vmap(current_ensemble)
            _, cov_forecast, analysed, cov_analysis = ensrf_step(predicted, observations[t, :], H, Q, R, step_keys[t])
            return (analysed, cov_analysis), (predicted, cov_forecast, analysed, cov_analysis)

        state_dim = Q.shape[1]
        initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
        _, history = jax.lax.scan(inner, initial_carry, jnp.arange(num_steps))
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = history

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    fontsize = 20
    res = 500
    # Held-out test archive for this (noise, dt) pair.
    file = f"ACM270_Model_Error/testing_data_noise_{noise_const}_dt_{dt}.npz"
    data = np.load(file)
    noisy_trajectory = data["noisy_trajectory"]
    # The truth is stored under either of two names.  Catch only the
    # missing-key case so unrelated failures (corrupt file, interrupt) are
    # not silently redirected to the fallback.
    try:
      trajectory=data['true_trajectory']
    except KeyError:
      trajectory=data['trajectory']

    noisy_trajectory =  jnp.asarray(noisy_trajectory)
    trajectory = jnp.asarray(trajectory)

    # m indexes time steps, n indexes test trajectories.
    m,_,n = noisy_trajectory.shape

    # Running sums of the per-time-step relative error, averaged after the loop.
    NRMSE_fm = np.zeros((m,))
    NRMSE_as = np.zeros((m,))
    NRMSE_fm_NN = np.zeros((m,))
    NRMSE_as_NN = np.zeros((m,))

    def _relative_sq_error(truth, estimate):
        # Per-time-step ||truth - estimate||^2 / ||truth||^2.  Row-wise sums
        # give the same values as np.diag(err @ err.T) without building the
        # full (m, m) product just to read its diagonal.
        truth = np.asarray(truth)
        err = truth - np.asarray(estimate)
        return np.sum(err * err, axis=1) / np.sum(truth * truth, axis=1)

    nn_params = state.params  # identical for every test trajectory

    for i in range(n):
      observations = noisy_trajectory[:,:,i]

      # Filter with the learned forecast correction.
      ensemble_forecast_NN, C_forecast_NN, ensemble_analysis_NN, C_analysis_NN = ensrf_steps_nn(
          lorenz1,
          n_ensemble,
          ensemble_init,
          time_steps,
          observations,
          1,
          H,
          Q,
          R,
          key,
          nn_params,
          model
      )

      # Baseline filter on the same observations, initial ensemble and key.
      ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(
          lorenz1,
          n_ensemble,
          ensemble_init,
          time_steps,
          observations,
          1,
          H,
          Q,
          R,
          key
      )

      true_states = trajectory[:,:,i]
      time_steps = ensemble_forecast.shape[0]
      time = jnp.arange(time_steps)

      forecast_means = jnp.mean(ensemble_forecast, axis=2)
      analysis_means = jnp.mean(ensemble_analysis, axis=2)

      forecast_means_NN = jnp.mean(ensemble_forecast_NN, axis=2)
      analysis_means_NN = jnp.mean(ensemble_analysis_NN, axis=2)

      # Note: these are squared-error ratios; no square root is taken.
      NRMSE_fm += _relative_sq_error(true_states, forecast_means)
      NRMSE_as += _relative_sq_error(true_states, analysis_means)
      NRMSE_fm_NN += _relative_sq_error(true_states, forecast_means_NN)
      NRMSE_as_NN += _relative_sq_error(true_states, analysis_means_NN)

    # Turn the running sums into averages over the n test trajectories.
    NRMSE_fm /= n
    NRMSE_as /= n
    NRMSE_fm_NN /= n
    NRMSE_as_NN /= n

    # Physical time axis for this run.
    time = np.arange(0, time_steps)*dt
    fontsize = 25

    plt.figure(figsize=(14, 8))

    # (values, legend label, line style, marker, colour) for each curve:
    # black = plain filter, blue = NN-corrected filter.
    curves = (
        (NRMSE_fm, "Forecast (uncorrected)", "-", "o", 'k'),
        (NRMSE_as, "Analysis (uncorrected)", "-", "x", 'k'),
        (NRMSE_fm_NN, "Forecast (corrected)", "--", "o", 'b'),
        (NRMSE_as_NN, "Analysis (corrected)", "--", "x", 'b'),
    )
    for values, legend_label, line_style, symbol, colour in curves:
        plt.plot(time, values, label=legend_label, linestyle=line_style, marker=symbol, color=colour, linewidth=2.5, markersize=7)

    plt.xlabel("Time",fontsize=fontsize)
    plt.ylabel("Average NRMSE",fontsize=fontsize)
    plt.title("Average NRMSE vs time",fontsize=fontsize)
    plt.legend(fontsize=fontsize/1.5)
    plt.grid(True)
    plt.xticks(fontsize=fontsize / 1.25)
    plt.yticks(fontsize=fontsize / 1.25)
    # NOTE(review): one figure per sweep point stays open; consider closing
    # it after saving if memory becomes a concern.
    plt.savefig(f'long_MSE_vs_time_noise_{noise_const}_dt_{dt}.png', bbox_inches="tight", dpi=res)

# MSE vs. training data

In [None]:
%%capture

%load_ext autoreload
%autoreload 2

!pip install filterpy

import numpy as np
import pandas as pd
import os
import sys
from tqdm import tqdm
from filterpy.kalman import EnsembleKalmanFilter as EnKF
import matplotlib.pyplot as plt
import scipy.integrate as integrate
from scipy.interpolate import griddata
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from jax import jit, random, lax
from jax.scipy.linalg import sqrtm
from functools import partial
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from scipy import integrate
from jax.tree_util import Partial
from functools import partial

! rm -rf ACM270_Model_Error
# if os.path.isdir('./ACM270_Model_Error/') == False:
! git clone https://github.com/sreemanti-dey/ACM270_Model_Error.git
sys.path.append('./ACM270_Model_Error/')

! cd ACM270_Model_Error

"""# True Trajectory and Noisy Observations with 2-Scale Lorenz96"""

# One aggregate error per sweep point is appended to each of these lists
# (fm = forecast mean, as = analysis mean, _NN = with the learned correction).
full_NRMSE_fm = []
full_NRMSE_as = []
full_NRMSE_fm_NN = []
full_NRMSE_as_NN = []

# 15 integer trajectory counts between 1 and 50.
training_data = np.linspace(1,50,15, dtype=int)

for idx_n, train_data in enumerate(training_data):
    # parameters for two-scale L96
    K = 36                      # number of large scale vars
    J = 10                      # number of small scale vars per large var
    h = 0.25                   # enters the coupling coefficient h * c / b
    c = 10                     # enters the coupling coefficient and the fast-variable terms
    b = 10                     # enters the coupling coefficient and the fast-variable terms
    F = 10                     # forcing
    dt = 0.1                  # time step
    noise_const = 0.1
    # Rescale the step count so num_steps * dt stays at 200 * 0.05.
    num_steps = 200
    num_steps = int(num_steps * (0.05 / dt))
    time_steps = num_steps

    @jit
    def L96_2(xy):
        """Return the time derivative of the stacked two-scale Lorenz-96 state.

        Args:
            xy: 1-D array of length K + K*J; the first K entries are the slow
                variables, the remaining K*J entries are the fast variables,
                reshaped row-major to (K, J).

        Returns:
            1-D array of the same length holding [dx, dy.flatten()].

        The slow variables are cyclic over K; each row of fast variables is
        cyclic over its own J entries (indices wrap within the row).
        """
        x = xy[0:K]
        y = xy[K:].reshape(K, J)
        coupling = h * c / b

        # jnp.roll(v, s)[k] == v[(k - s) % len(v)], so roll(+1) is the left
        # neighbour and roll(-1) the right neighbour.  This replaces the
        # K*J element-wise scatter updates with a handful of array ops.
        dx = (-1 * jnp.roll(x, 1) * (jnp.roll(x, 2) - jnp.roll(x, -1))
              - x + F - coupling * jnp.sum(y, axis=1))

        dy = (-1 * c * b * jnp.roll(y, -1, axis=1)
              * (jnp.roll(y, -2, axis=1) - jnp.roll(y, 1, axis=1))
              - c * y + coupling * x[:, None])

        return jnp.concatenate([dx, dy.flatten()])

    @jit
    def rk4_step_lorenz96_2(x):
        """Advance the two-scale state `x` by one classical RK4 step of size `dt`.

        The tendency comes from `L96_2`; `dt` is read from the enclosing scope.
        """
        stage1 = dt * L96_2(x)
        stage2 = dt * L96_2(x + stage1/2)
        stage3 = dt * L96_2(x + stage2/2)
        stage4 = dt * L96_2(x + stage3)
        weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
        return x + 1/6 * weighted_sum

    # NOTE(review): `keys` is not consumed anywhere in the visible code.
    keys = iter(random.split(random.PRNGKey(0), 10))
    key = random.PRNGKey(0)
    # deterministic ramp (0..1) as starting point for large scale, 0 for small scale
    x0 = np.linspace(0, 1, K)  # np.random.rand(K)
    y0 = np.zeros((K, J))
    xy0 = np.concatenate([x0, y0.flatten()])

    # Result is discarded: this call only forces the jitted stepper to be
    # traced/compiled before the timed loop.
    rk4_step_lorenz96_2(xy0)

    trajectory = jnp.zeros((time_steps, len(xy0)))
    noisy_obs = jnp.zeros((time_steps, len(xy0)))
    noise_covar = jnp.eye(len(xy0)) * noise_const

    xy = xy0
    for i in tqdm(range(time_steps)):
        xy = rk4_step_lorenz96_2(xy)
        trajectory = trajectory.at[i].set(xy)
        # Split first and sample from the fresh subkey.  Sampling from `key`
        # and then splitting that same `key` reuses it, which JAX warns
        # against because the outputs are no longer independent.
        key, noise_key = random.split(key)
        noise = random.multivariate_normal(noise_key, jnp.zeros(len(xy0)), noise_covar)
        noisy_obs = noisy_obs.at[i].set(xy + noise)

    # `trajectory` is the true trajectory; `noisy_obs` holds the noisy observations.
    true_states = trajectory

    """# EnKF for 1-Scale Assimilation"""

    @jit
    def rk4_step_lorenz96_1(x):
        """Advance the single-scale Lorenz-96 state `x` by one RK4 step of size `dt`.

        `dt` and the forcing `F` are read from the enclosing scope.
        """
        def tendency(state):
            # Cyclic neighbour products written with jnp.roll, then damping and forcing.
            return (jnp.roll(state, 1) - jnp.roll(state, -2)) * jnp.roll(state, -1) - state + F

        stage1 = dt * tendency(x)
        stage2 = dt * tendency(x + stage1/2)
        stage3 = dt * tendency(x + stage2/2)
        stage4 = dt * tendency(x + stage3)
        weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
        return x + 1/6 * weighted_sum

    # Wrap the steppers in jax.tree_util.Partial so they can be passed as
    # regular arguments into the jitted filter drivers.
    # NOTE(review): lorenz2 is not used in the visible code.
    lorenz1 = Partial(rk4_step_lorenz96_1)
    lorenz2 = Partial(rk4_step_lorenz96_2)

    @jit
    def ledoit_wolf(P, shrinkage):
        """Blend `P` with a scaled identity using a fixed weight.

        Returns (1 - shrinkage) * P + shrinkage * (trace(P) / n) * I for an
        (n, n) matrix `P`.  The weight is supplied by the caller; it is not
        estimated from data.
        """
        dim = P.shape[0]
        identity_weight = shrinkage * jnp.trace(P) / dim
        return (1 - shrinkage) * P + identity_weight * jnp.eye(dim)

    @jit
    def ensrf_step(ensemble, y, H, Q, R, key, inflation = 1.2):
        """Run one analysis update on a forecast ensemble.

        Args:
            ensemble: forecast ensemble with members along axis 1.
            y: observation vector for this step.
            H, Q, R: observation operator, additive model covariance and
                observation covariance.
            key: accepted for interface compatibility; not used here.
            inflation: factor applied to the anomalies before the update.

        Returns:
            (forecast ensemble, shrunk forecast covariance, analysis ensemble,
            shrunk analysis covariance), all cast to float32.
        """
        dof = ensemble.shape[1] - 1
        forecast_mean = jnp.mean(ensemble, axis=1)
        anomalies = ensemble - forecast_mean.reshape((-1, 1))

        # Reported forecast covariance uses the un-inflated anomalies.
        C_pred = ledoit_wolf((anomalies @ anomalies.T) / dof + Q, noise_const)

        # The gain is built from the inflated anomalies.
        inflated = anomalies * inflation
        P = (inflated @ inflated.T) / dof + Q
        gain = P @ H.T @ jnp.linalg.inv(H @ P @ H.T + R)
        analysis_mean = forecast_mean + gain @ (y - H @ forecast_mean)

        # Anomalies are transformed by the matrix square root of (I - gain H).
        transform = sqrt_m(jnp.eye(analysis_mean.shape[0]) - gain @ H)
        analysis_anomalies = transform @ inflated
        updated_ensemble = analysis_mean.reshape((-1, 1)) + analysis_anomalies
        updated_P = ledoit_wolf((analysis_anomalies @ analysis_anomalies.T) / dof, noise_const)

        outputs = (ensemble, C_pred, updated_ensemble, updated_P)
        return tuple(item.astype(jnp.float32) for item in outputs)

    @jit
    def sqrt_m(M):
        """Return the symmetric square root of `M` via an eigendecomposition.

        `M` is passed to `jnp.linalg.eigh`, so it is treated as symmetric.
        Eigenvalues are clamped at zero before the square root: a value that
        is slightly negative through round-off would otherwise become NaN and
        contaminate every later filter step.
        """
        eigenvalues, eigenvectors = jnp.linalg.eigh(M)
        sqrt_eigenvalues = jnp.sqrt(jnp.maximum(eigenvalues, 0))
        # Scaling the eigenvector columns equals V @ diag(s) without the extra matrix.
        M_sqrt = (eigenvectors * sqrt_eigenvalues) @ eigenvectors.T
        return M_sqrt.real

    @partial(jit, static_argnums=(3))
    def ensrf_steps(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key):
        """Run forecast/analysis cycles for `num_steps` steps with `lax.scan`.

        Each step propagates every ensemble member (axis 1) through
        `state_transition_function` and assimilates `observations[t]` with
        `ensrf_step`.  `n_ensemble` and `observation_interval` are accepted
        but not read.  `num_steps` is a static jit argument.

        Returns:
            Per-step stacks (leading axis = time) of the forecast ensemble,
            forecast covariance, analysis ensemble and analysis covariance.
        """
        def advance_member(member):
            return state_transition_function(member)

        model_vmap = jax.vmap(advance_member, in_axes=1, out_axes=1)
        # Row 0 of the split is set aside; rows 1..num_steps are the per-step keys.
        step_keys = random.split(key, num=num_steps + 1)[1:]

        def inner(carry, t):
            current_ensemble, _ = carry
            predicted = model_vmap(current_ensemble)
            _, cov_forecast, analysed, cov_analysis = ensrf_step(predicted, observations[t, :], H, Q, R, step_keys[t])
            return (analysed, cov_analysis), (predicted, cov_forecast, analysed, cov_analysis)

        state_dim = Q.shape[1]
        initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
        _, history = jax.lax.scan(inner, initial_carry, jnp.arange(num_steps))
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = history

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    class IncrementCorrectionModel(nn.Module):
        """MLP with three ReLU hidden layers that maps a state to a same-sized increment."""
        features: int  # width of every hidden Dense layer

        @nn.compact
        def __call__(self, x):
            # Record the input width up front.  Reading x.shape[-1] after the
            # hidden layers would return `features`, so the output would only
            # match the input when features happens to equal the input size.
            output_dim = x.shape[-1]
            for _ in range(3):
                x = nn.relu(nn.Dense(self.features)(x))
            return nn.Dense(output_dim)(x)  # Output layer matches input dimension

    def create_model(key, input_shape, features):
        """Build an IncrementCorrectionModel and initialise its parameters.

        Initialisation uses an all-ones dummy input of shape `input_shape`.
        Returns the module together with its 'params' collection.
        """
        model = IncrementCorrectionModel(features)
        dummy_input = jnp.ones(input_shape)
        variables = model.init(key, dummy_input)
        return model, variables['params']

    def loss_fn(params, apply_fn, x, y):
        """Mean squared error between `apply_fn({'params': params}, x)` and `y`."""
        residual = apply_fn({'params': params}, x) - y
        return jnp.mean(residual ** 2)

    @jax.jit
    def train_step(state, forecast_states, increments):
        """Apply one full-batch gradient update and return the new train state.

        The loss is the MSE between the network output on `forecast_states`
        and the target `increments`.
        """
        objective = lambda candidate: loss_fn(candidate, state.apply_fn, forecast_states, increments)
        gradients = jax.grad(objective)(state.params)
        return state.apply_gradients(grads=gradients)

    def train_nn(state, forecast_states, analysis_states, num_epochs=10):
        """Fit the network to predict (analysis - forecast) from the forecast.

        Runs `num_epochs` full-batch updates, printing the loss after each,
        and returns the updated train state.
        """
        targets = analysis_states - forecast_states

        epoch = 0
        while epoch < num_epochs:
            state = train_step(state, forecast_states, targets)
            # Loss is re-evaluated with the updated parameters for reporting only.
            current_loss = loss_fn(state.params, state.apply_fn, forecast_states, targets)
            print(f'Epoch {epoch+1}, Loss: {current_loss}')
            epoch += 1

        return state

    # Filter configuration: model and observation covariances are both
    # noise_const * I, and every one of the K slow variables is observed (H = I).
    Q = noise_const * jnp.eye(K)
    H = jnp.eye(K)
    R = jnp.eye(K) * noise_const
    n_ensemble = 10
    num_steps = time_steps
    # The filter runs the single-scale Lorenz model on the first K variables only.
    observations = noisy_obs[:,:K]
    initial_state = np.zeros((1, K))
    # Draw n_ensemble members around zero with covariance Q; transpose so the
    # members run along axis 1, as ensrf_step expects.
    ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
    ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

    time_steps = ensemble_forecast.shape[0]
    time = jnp.arange(time_steps)

    # Average over the ensemble axis (axis 2 of the time-stacked outputs).
    forecast_means = jnp.mean(ensemble_forecast, axis=2)
    analysis_means = jnp.mean(ensemble_analysis, axis=2)

    # Pre-computed training archive for the fixed (noise, dt) pair of this cell.
    file = f"ACM270_Model_Error/long_training_data_noise_{noise_const}_dt_{dt}.npz"
    data = np.load(file)
    forecast_state = data["forecast"]
    analysis_state = data["analysis"]

    # Fresh network and optimiser for every sweep point (fixed seed 0), so
    # each point trains from the same initial weights.
    key = jax.random.PRNGKey(0)
    input_shape = (K,)
    features = K
    model, params = create_model(key, input_shape, features)

    learning_rate = 1e-3
    optimizer = optax.adam(learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    num_epochs = 15
    # The last axis of the stored arrays indexes the training trajectory.
    _,_,n = forecast_state.shape
    # Train on the first `train_data` trajectories only.
    # NOTE(review): nothing bounds train_data by n; confirm the archive holds
    # at least max(training_data) trajectories.
    for i in range(train_data):
      forecast_means = forecast_state[:,:,i]
      analysis_means = analysis_state[:,:,i]
      state = train_nn(state, forecast_means, analysis_means, num_epochs=num_epochs)

    # Not jitted: the decorator used on ensrf_steps is left disabled for this variant.
    #@partial(jax.jit, static_argnums=(3))
    def ensrf_steps_nn(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key, nn_params, model):
        """Same cycling as `ensrf_steps`, with a learned correction on each forecast.

        Every member is propagated by `state_transition_function`, then the
        output of `model.apply({'params': nn_params}, ...)` evaluated on that
        propagated state is added to it before assimilation.  `n_ensemble` and
        `observation_interval` are accepted but not read.

        Returns:
            Per-step stacks (leading axis = time) of the forecast ensemble,
            forecast covariance, analysis ensemble and analysis covariance.
        """
        def corrected_state_transition(member):
            propagated = state_transition_function(member)
            return propagated + model.apply({'params': nn_params}, propagated)

        model_vmap = jax.vmap(corrected_state_transition, in_axes=1, out_axes=1)
        # Row 0 of the split is set aside; rows 1..num_steps are the per-step keys.
        step_keys = jax.random.split(key, num=num_steps + 1)[1:]

        def inner(carry, t):
            current_ensemble, _ = carry
            predicted = model_vmap(current_ensemble)
            _, cov_forecast, analysed, cov_analysis = ensrf_step(predicted, observations[t, :], H, Q, R, step_keys[t])
            return (analysed, cov_analysis), (predicted, cov_forecast, analysed, cov_analysis)

        state_dim = Q.shape[1]
        initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
        _, history = jax.lax.scan(inner, initial_carry, jnp.arange(num_steps))
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = history

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    # Held-out test archive for the fixed (noise, dt) pair of this cell.
    file = f"ACM270_Model_Error/testing_data_noise_{noise_const}_dt_{dt}.npz"
    data = np.load(file)
    noisy_trajectory = data["noisy_trajectory"]
    # The truth is stored under either of two names.  Catch only the
    # missing-key case so unrelated failures (corrupt file, interrupt) are
    # not silently redirected to the fallback.
    try:
      trajectory=data['true_trajectory']
    except KeyError:
      trajectory=data['trajectory']

    noisy_trajectory =  jnp.asarray(noisy_trajectory)
    trajectory = jnp.asarray(trajectory)

    # One relative-error value per test trajectory is collected in each list.
    NRMSE_fm = []
    NRMSE_as = []
    NRMSE_fm_NN = []
    NRMSE_as_NN = []

    # The last axis indexes the test trajectory.
    _,_,n = noisy_trajectory.shape

    for i in range(n):
      observations = noisy_trajectory[:,:,i]

      nn_params = state.params
      # Positional arguments shared by the corrected and the baseline filter.
      shared_args = (lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

      ensemble_forecast_NN, C_forecast_NN, ensemble_analysis_NN, C_analysis_NN = ensrf_steps_nn(
          *shared_args, nn_params, model
      )
      ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(*shared_args)

      time_steps = ensemble_forecast.shape[0]
      time = jnp.arange(time_steps)

      true_states = trajectory[:,:,i]
      forecast_means = jnp.mean(ensemble_forecast, axis=2)
      analysis_means = jnp.mean(ensemble_analysis, axis=2)

      forecast_means_NN = jnp.mean(ensemble_forecast_NN, axis=2)
      analysis_means_NN = jnp.mean(ensemble_analysis_NN, axis=2)

      # Squared norm of the whole true trajectory: the common denominator.
      truth_flat = true_states.flatten()
      truth_energy = np.dot(truth_flat, truth_flat)

      # Total squared error over all steps and variables, relative to the truth.
      score_targets = (
          (NRMSE_fm, forecast_means),
          (NRMSE_as, analysis_means),
          (NRMSE_fm_NN, forecast_means_NN),
          (NRMSE_as_NN, analysis_means_NN),
      )
      for score_list, estimate in score_targets:
        model_error = true_states - estimate
        error_flat = model_error.flatten()
        score_list.append(np.dot(error_flat, error_flat) / truth_energy)

    # Average over the test trajectories and record one value per sweep point.
    for totals, score_list in (
        (full_NRMSE_fm, NRMSE_fm),
        (full_NRMSE_as, NRMSE_as),
        (full_NRMSE_fm_NN, NRMSE_fm_NN),
        (full_NRMSE_as_NN, NRMSE_as_NN),
    ):
      totals.append(np.mean(score_list))

# Summary figure: aggregate error against the number of training trajectories.
fontsize = 20
res = 500
tick_size = fontsize / 1.25

plt.figure(figsize=(8, 8))
# (values, legend label, per-curve style); black = plain filter, blue = NN-corrected.
curve_specs = [
    (full_NRMSE_as, "Analysis (uncorrected)", {"linestyle": "-", "color": 'k'}),
    (full_NRMSE_as_NN, "Analysis (corrected)", {"linestyle": "-", "marker": "o", "color": 'b'}),
    (full_NRMSE_fm, "Forecast (uncorrected)", {"linestyle": "--", "color": 'k'}),
    (full_NRMSE_fm_NN, "Forecast (corrected)", {"linestyle": "--", "marker": "x", "color": 'b'}),
]
for values, legend_label, look in curve_specs:
    plt.plot(training_data, values, label=legend_label, linewidth=2.5, markersize=7, **look)

plt.xlabel("Trajectories used in training", fontsize=fontsize)
plt.ylabel("Total NRMSE", fontsize=fontsize)
plt.title("Total NRMSE vs training trajectories", fontsize=fontsize)
plt.legend(fontsize=tick_size)
plt.grid(True)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.savefig(f"analysis_MSE_vs_traindata_noise_{noise_const}_dt_{dt}.png", bbox_inches="tight", dpi=res)



# MSE vs. Epoch

In [None]:
%%capture

%load_ext autoreload
%autoreload 2

!pip install filterpy

import numpy as np
import pandas as pd
import os
import sys
from tqdm import tqdm
from filterpy.kalman import EnsembleKalmanFilter as EnKF
import matplotlib.pyplot as plt
import scipy.integrate as integrate
from scipy.interpolate import griddata
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from jax import jit, random, lax
from jax.scipy.linalg import sqrtm
from functools import partial
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from scipy import integrate
from jax.tree_util import Partial
from functools import partial

# Start from a fresh checkout on every run: remove any existing copy of the
# repository and clone it again.
! rm -rf ACM270_Model_Error
# if os.path.isdir('./ACM270_Model_Error/') == False:
! git clone https://github.com/sreemanti-dey/ACM270_Model_Error.git
# Put the checkout on the module search path.
sys.path.append('./ACM270_Model_Error/')

# NOTE(review): `!` commands run in their own shell, so this `cd` presumably
# does not change the notebook's working directory; the data files loaded
# further down are addressed with an explicit `ACM270_Model_Error/` prefix.
! cd ACM270_Model_Error

"""# True Trajectory and Noisy Observations with 2-Scale Lorenz96"""

# One aggregated error value is appended to each list per entry of `epochs`.
full_NRMSE_fm, full_NRMSE_as = [], []
full_NRMSE_fm_NN, full_NRMSE_as_NN = [], []

# 15 evenly spaced epoch counts between 1 and 50, stored as integers.
epochs = np.linspace(1, 50, num=15, dtype=int)

for idx_n, train_epoch in enumerate(epochs):
    # Two-scale Lorenz-96 configuration
    K, J = 36, 10               # large-scale variables, small-scale variables per large one
    h, c, b = 0.25, 10, 10      # coupling constants
    F = 10                      # forcing

    dt = 0.05                   # time step
    noise_const = 0.1

    # 200 steps at the reference step of 0.05, rescaled for the chosen dt.
    num_steps = int(200 * (0.05 / dt))
    time_steps = num_steps

    @jit
    def L96_2(xy):
        """Evaluate the two-scale Lorenz-96 tendency at the flat state ``xy``.

        The first ``K`` entries of ``xy`` are the large-scale variables; the
        remaining ``K * J`` entries are the small-scale variables, ``J``
        consecutive values per large-scale variable.  ``K``, ``J``, ``h``,
        ``c``, ``b`` and ``F`` are read from the enclosing scope.

        Returns the time derivative in the same flat layout as ``xy``.

        The expressions are evaluated on whole arrays instead of filling the
        result one element at a time, which keeps the traced program small.
        """
        x = xy[0:K]
        y = xy[K:].reshape(K, J)

        # roll(x, s)[k] == x[(k - s) % K], i.e. periodic neighbours in k.
        dx = (-1 * jnp.roll(x, 1) * (jnp.roll(x, 2) - jnp.roll(x, -1))
              - x + F - (h * c / b) * jnp.sum(y, axis=1))

        # Small-scale neighbours wrap around inside their own row (same k),
        # so every shift is along axis 1 only.
        dy = (-1 * c * b * jnp.roll(y, -1, axis=1)
              * (jnp.roll(y, -2, axis=1) - jnp.roll(y, 1, axis=1))
              - c * y + (h * c / b) * x[:, None])

        return jnp.concatenate([dx, dy.flatten()])

    @jit
    def rk4_step_lorenz96_2(x):
        """Advance the two-scale state ``x`` by one classical RK4 step of size ``dt``."""
        stage1 = dt * L96_2(x)
        stage2 = dt * L96_2(x + stage1/2)
        stage3 = dt * L96_2(x + stage2/2)
        stage4 = dt * L96_2(x + stage3)
        weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
        return x + 1/6 * weighted_sum

    # Generate a reference two-scale trajectory and noisy observations of it.
    # `keys` is not referenced again in this loop body.
    keys = iter(random.split(random.PRNGKey(0), 10))
    key = random.PRNGKey(0)
    # deterministic ramp from 0 to 1 for the large scale, 0 for the small scale
    x0 = np.linspace(0, 1, K)  # np.random.rand(K)
    y0 = np.zeros((K, J))
    xy0 = np.concatenate([x0, y0.flatten()])

    # Result is discarded; presumably a warm-up call so the stepper is compiled
    # before the timed loop below.
    rk4_step_lorenz96_2(xy0)

    # One row per time step; columns follow the flat layout of xy0.
    trajectory = jnp.zeros((time_steps, len(xy0)))
    noisy_obs = jnp.zeros((time_steps, len(xy0)))
    # Isotropic observation-noise covariance: noise_const on the diagonal.
    noise_covar = jnp.eye(len(xy0)) * noise_const

    xy = xy0
    for i in tqdm(range(time_steps)):
        # Row i holds the state after i + 1 steps (the initial state is not stored).
        xy = rk4_step_lorenz96_2(xy)
        trajectory = trajectory.at[i].set(xy)
        # Add noise to the observations
        noise = random.multivariate_normal(key, jnp.zeros(len(xy0)), noise_covar)
        noisy_obs = noisy_obs.at[i].set(xy + noise)
        key, _ = random.split(key)  # Update key for the next iteration

    """## True Trajectory

    ## Noisy Observations
    """

    true_states = trajectory

    """# EnKF for 1-Scale Assimilation"""

    @jit
    def rk4_step_lorenz96_1(x):
        """Advance the single-scale Lorenz-96 state ``x`` by one RK4 step of size ``dt``.

        NOTE(review): the neighbour indices used here are (k-1, k+2, k+1),
        whereas ``L96_2`` uses (k-1, k-2, k+1) for its large-scale advection
        term; the two orientations are mirror images -- confirm this is intended.
        """
        def tendency(state):
            advection = (jnp.roll(state, 1) - jnp.roll(state, -2)) * jnp.roll(state, -1)
            return advection - state + F

        stage1 = dt * tendency(x)
        stage2 = dt * tendency(x + stage1/2)
        stage3 = dt * tendency(x + stage2/2)
        stage4 = dt * tendency(x + stage3)
        weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
        return x + 1/6 * weighted_sum

    # Wrap the steppers in jax.tree_util.Partial; `lorenz1` is later passed as
    # the first (non-static) argument of the jitted `ensrf_steps`.
    # `lorenz2` is not referenced again in this loop body.
    lorenz1 = Partial(rk4_step_lorenz96_1)
    lorenz2 = Partial(rk4_step_lorenz96_2)

    @jit
    def ledoit_wolf(P, shrinkage):
        """Shrink the square matrix ``P`` towards a scaled identity.

        Returns ``(1 - shrinkage) * P + shrinkage * (trace(P) / n) * I`` where
        ``n`` is the size of ``P``.  The shrinkage weight is supplied by the
        caller; it is not estimated from data.
        """
        dim = P.shape[0]
        identity_target = shrinkage * jnp.trace(P) / dim * jnp.eye(dim)
        return (1 - shrinkage) * P + identity_target

    @jit
    def ensrf_step(ensemble, y, H, Q, R, key, inflation = 1.2):
        """Run one square-root analysis update on a forecast ensemble.

        ``ensemble`` holds one member per column.  ``y`` is the observation,
        ``H`` the observation operator, ``Q`` an additive covariance term and
        ``R`` the observation-noise covariance.  ``key`` is accepted but not
        used.  ``inflation`` scales the anomalies before the gain is computed.

        Returns ``(ensemble, C_pred, updated_ensemble, updated_P)`` cast to
        float32: the input ensemble, its shrunk forecast covariance (computed
        from the un-inflated anomalies), the analysis ensemble and its shrunk
        covariance.
        """
        n_members = ensemble.shape[1]
        mean = jnp.mean(ensemble, axis=1)
        anomalies = ensemble - mean.reshape((-1, 1))

        # Forecast covariance is taken before inflation is applied.
        C_pred = ledoit_wolf((anomalies @ anomalies.T) / (n_members - 1) + Q, noise_const)

        anomalies = anomalies * inflation
        P = (anomalies @ anomalies.T) / (n_members - 1) + Q
        innovation_cov = H @ P @ H.T + R
        gain = P @ H.T @ jnp.linalg.inv(innovation_cov)
        mean = mean + gain @ (y - H @ mean)

        # Transform the anomalies with the symmetric square root of (I - gain H).
        updated_A = sqrt_m(jnp.eye(mean.shape[0]) - gain @ H) @ anomalies
        updated_ensemble = mean.reshape((-1, 1)) + updated_A
        updated_P = ledoit_wolf((updated_A @ updated_A.T) / (n_members - 1), noise_const)

        return (
            ensemble.astype(jnp.float32),
            C_pred.astype(jnp.float32),
            updated_ensemble.astype(jnp.float32),
            updated_P.astype(jnp.float32),
        )

    @jit
    def sqrt_m(M):
        """Return the symmetric square root of the symmetric matrix ``M``.

        ``M`` is decomposed with ``jnp.linalg.eigh`` and rebuilt with the
        square roots of its eigenvalues.  Eigenvalues are clamped at zero
        first: a positive semi-definite matrix can come out of floating-point
        arithmetic with a tiny negative eigenvalue, and taking its square root
        would fill the whole result with NaN.
        """
        eigenvalues, eigenvectors = jnp.linalg.eigh(M)
        sqrt_eigenvalues = jnp.sqrt(jnp.maximum(eigenvalues, 0))
        M_sqrt = eigenvectors @ jnp.diag(sqrt_eigenvalues) @ eigenvectors.T
        return M_sqrt.real

    @partial(jit, static_argnums=(3))
    def ensrf_steps(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key):
        """Run ``num_steps`` forecast/analysis cycles with ``jax.lax.scan``.

        Every cycle pushes each ensemble member (one per column) through
        ``state_transition_function`` and then assimilates ``observations[t]``
        with ``ensrf_step``.  ``n_ensemble`` and ``observation_interval`` are
        accepted but not used.

        Returns the per-step stacks
        ``(ensemble_forecast, C_forecast, ensemble_analysis, C_analysis)``.
        """
        propagate = jax.vmap(lambda member: state_transition_function(member), in_axes=1, out_axes=1)
        key, *step_keys = random.split(key, num=num_steps + 1)
        step_keys = jnp.array(step_keys)

        def cycle(carry, t):
            members, _ = carry
            forecast = propagate(members)
            _, forecast_cov, analysis, analysis_cov = ensrf_step(forecast, observations[t, :], H, Q, R, step_keys[t])
            return (analysis, analysis_cov), (forecast, forecast_cov, analysis, analysis_cov)

        state_dim = len(Q[0])
        initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
        _, stacked = jax.lax.scan(cycle, initial_carry, jnp.arange(num_steps))
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = stacked

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    class IncrementCorrectionModel(nn.Module):
        """MLP with three ReLU hidden layers that maps a state to a same-sized correction."""
        features: int  # width of each hidden layer

        @nn.compact
        def __call__(self, x):
            # Record the input width up front: `x` is overwritten by the hidden
            # activations below, so reading x.shape[-1] at the end would give
            # the hidden width instead of the input width.
            output_dim = x.shape[-1]
            for _ in range(3):
                x = nn.relu(nn.Dense(self.features)(x))
            return nn.Dense(output_dim)(x)  # Output layer matches the input dimension

    def create_model(key, input_shape, features):
        """Build the correction network and initialise it on an all-ones input.

        Returns ``(model, params)`` where ``params`` is the ``'params'`` entry
        of the initialised variables.
        """
        network = IncrementCorrectionModel(features)
        dummy_input = jnp.ones(input_shape)
        variables = network.init(key, dummy_input)
        return network, variables['params']

    def loss_fn(params, apply_fn, x, y):
        """Mean squared error between ``apply_fn({'params': params}, x)`` and ``y``."""
        residual = apply_fn({'params': params}, x) - y
        return jnp.mean(residual ** 2)

    @jax.jit
    def train_step(state, forecast_states, increments):
        """Apply one gradient update of the MSE loss and return the new train state."""
        def objective(candidate_params):
            return loss_fn(candidate_params, state.apply_fn, forecast_states, increments)

        gradients = jax.grad(objective)(state.params)
        return state.apply_gradients(grads=gradients)

    def train_nn(state, forecast_states, analysis_states, num_epochs=10):
        """Fit the network to predict the increment ``analysis - forecast``.

        Runs ``num_epochs`` full-batch updates with ``train_step``, printing
        the loss after each one, and returns the updated train state.
        """
        # Regression target: what the analysis added to the forecast.
        targets = analysis_states - forecast_states

        for epoch in range(num_epochs):
            state = train_step(state, forecast_states, targets)
            # Loss is re-evaluated with the freshly updated parameters.
            current_loss = loss_fn(state.params, state.apply_fn, forecast_states, targets)
            print(f'Epoch {epoch+1}, Loss: {current_loss}')

        return state

    # Filter configuration: every matrix is K x K, so the filter works on the
    # large-scale variables only.
    Q = noise_const * jnp.eye(K)   # additive covariance term used inside ensrf_step
    H = jnp.eye(K)                 # identity observation operator
    R = jnp.eye(K) * noise_const   # observation-noise covariance
    n_ensemble = 10
    num_steps = time_steps
    #the ensrf (ENKF) will run single scale Lorenz on the first K variables
    observations = noisy_obs[:,:K]
    initial_state = np.zeros((1, K))
    # Draw the initial members around zero with covariance Q, then transpose so
    # that members lie along axis 1 (ensrf_step reads the member count from shape[1]).
    ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
    # The same `key` is passed on to the filter.
    ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

    time_steps = ensemble_forecast.shape[0]
    time = jnp.arange(time_steps)

    # Ensemble means per time step; both names are reassigned by the training
    # loop further down in this loop body.
    forecast_means = jnp.mean(ensemble_forecast, axis=2)
    analysis_means = jnp.mean(ensemble_analysis, axis=2)

    # Load the stored forecast/analysis pairs used as training data.  The
    # archive is opened in a `with` block so its file handle is closed as soon
    # as both arrays have been read.
    file = f"ACM270_Model_Error/long_training_data_noise_{noise_const}_dt_{dt}.npz"
    with np.load(file) as data:
        forecast_state = data["forecast"]
        analysis_state = data["analysis"]

    # Fresh network and optimizer for this epoch setting.
    key = jax.random.PRNGKey(0)
    input_shape = (K,)
    features = K
    model, params = create_model(key, input_shape, features)

    learning_rate = 1e-3
    optimizer = optax.adam(learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    num_epochs = train_epoch
    _,_,n = forecast_state.shape
    # Train on the first 50 stored trajectories, one after another; the bound
    # is capped at the number actually present so a smaller file does not
    # raise an IndexError.
    for i in range(min(50, n)):
      forecast_means = forecast_state[:,:,i]
      analysis_means = analysis_state[:,:,i]
      state = train_nn(state, forecast_means, analysis_means, num_epochs=num_epochs)

    #@partial(jax.jit, static_argnums=(3))
    def ensrf_steps_nn(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key, nn_params, model):
        """Same cycling as ``ensrf_steps`` but with an NN-corrected forecast.

        Each member is advanced with ``state_transition_function`` and the
        output of ``model`` (evaluated with ``nn_params`` on that forecast) is
        added to it before the analysis.  ``n_ensemble`` and
        ``observation_interval`` are accepted but not used.

        Returns the per-step stacks
        ``(ensemble_forecast, C_forecast, ensemble_analysis, C_analysis)``.
        """
        def corrected_state_transition(member):
            raw_forecast = state_transition_function(member)
            return raw_forecast + model.apply({'params': nn_params}, raw_forecast)

        propagate = jax.vmap(corrected_state_transition, in_axes=1, out_axes=1)
        key, *step_keys = jax.random.split(key, num=num_steps + 1)
        step_keys = jnp.array(step_keys)

        def cycle(carry, t):
            members, _ = carry
            forecast = propagate(members)
            _, forecast_cov, analysis, analysis_cov = ensrf_step(forecast, observations[t, :], H, Q, R, step_keys[t])
            return (analysis, analysis_cov), (forecast, forecast_cov, analysis, analysis_cov)

        state_dim = len(Q[0])
        initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
        _, stacked = jax.lax.scan(cycle, initial_carry, jnp.arange(num_steps))
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = stacked

        return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

    # Load the held-out test trajectories.  The archive is closed as soon as
    # the arrays have been read.
    file = f"ACM270_Model_Error/testing_data_noise_{noise_const}_dt_{dt}.npz"
    with np.load(file) as data:
        noisy_trajectory = data["noisy_trajectory"]
        # The true trajectory is stored under one of two names.  Only a
        # missing key triggers the fallback; any other failure propagates.
        try:
          trajectory=data['true_trajectory']
        except KeyError:
          trajectory=data['trajectory']

    noisy_trajectory =  jnp.asarray(noisy_trajectory)
    trajectory = jnp.asarray(trajectory)

    # Per-trajectory errors for this epoch setting.
    NRMSE_fm = []
    NRMSE_as = []
    NRMSE_fm_NN = []
    NRMSE_as_NN = []

    # The last axis indexes the test trajectories.
    _,_,n = noisy_trajectory.shape

    # Run the corrected and the uncorrected filter on every test trajectory and
    # record the squared error of each mean relative to the squared truth.
    for i in range(n):
      observations = noisy_trajectory[:,:,i]

      nn_params = state.params
      # Arguments shared by both filter variants.
      filter_args = (lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

      ensemble_forecast_NN, C_forecast_NN, ensemble_analysis_NN, C_analysis_NN = ensrf_steps_nn(
          *filter_args, nn_params, model
      )
      ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(*filter_args)

      time_steps = ensemble_forecast.shape[0]
      time = jnp.arange(time_steps)

      true_states = trajectory[:,:,i]
      forecast_means = jnp.mean(ensemble_forecast, axis=2)
      analysis_means = jnp.mean(ensemble_analysis, axis=2)

      forecast_means_NN = jnp.mean(ensemble_forecast_NN, axis=2)
      analysis_means_NN = jnp.mean(ensemble_analysis_NN, axis=2)

      # Normaliser: sum of squares of the true states.
      truth_flat = true_states.flatten()
      truth_energy = np.dot(truth_flat, truth_flat)
      for estimate, bucket in (
          (forecast_means, NRMSE_fm),
          (analysis_means, NRMSE_as),
          (forecast_means_NN, NRMSE_fm_NN),
          (analysis_means_NN, NRMSE_as_NN),
      ):
        model_error = true_states - estimate
        bucket.append(np.dot(model_error.flatten(), model_error.flatten()) / truth_energy)

    # Average over the test trajectories for this setting.
    for totals, per_trajectory in (
        (full_NRMSE_fm, NRMSE_fm),
        (full_NRMSE_as, NRMSE_as),
        (full_NRMSE_fm_NN, NRMSE_fm_NN),
        (full_NRMSE_as_NN, NRMSE_as_NN),
    ):
      totals.append(np.mean(per_trajectory))

# Figure: aggregated error of the raw and the NN-corrected filter against the
# number of training epochs.
fontsize = 20
res = 500

plt.figure(figsize=(8, 8))
plt.plot(epochs, full_NRMSE_as, label="Analysis (uncorrected)", linestyle="-",color='k', linewidth=2.5,markersize=7)
plt.plot(epochs, full_NRMSE_as_NN, label="Analysis (corrected)", linestyle="-", marker="o",color='b', linewidth=2.5,markersize=7)
# Same line width as the other three curves (was 7, which made this baseline
# far thicker than in the sibling figures).
plt.plot(epochs, full_NRMSE_fm, label="Forecast (uncorrected)", linestyle="--",color='k', linewidth=2.5,markersize=7)
plt.plot(epochs, full_NRMSE_fm_NN, label="Forecast (corrected)", linestyle="--", marker="x",color='b', linewidth=2.5,markersize=7)

plt.xlabel("Number of Epochs",fontsize=fontsize)
plt.ylabel("Total NRMSE",fontsize=fontsize)
plt.title("Total NRMSE vs Epochs",fontsize=fontsize)
plt.legend(fontsize=fontsize/1.25)
plt.grid(True)
plt.xticks(fontsize=fontsize / 1.25)
plt.yticks(fontsize=fontsize / 1.25)
plt.savefig(f"MSE_vs_epoch_noise_{noise_const}_dt_{dt}.png", bbox_inches="tight", dpi=res)


# MSE vs. learning rate

In [21]:
%%capture

%load_ext autoreload
%autoreload 2

!pip install filterpy

import numpy as np
import pandas as pd
import os
import sys
from tqdm import tqdm
from filterpy.kalman import EnsembleKalmanFilter as EnKF
import matplotlib.pyplot as plt
import scipy.integrate as integrate
from scipy.interpolate import griddata
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from jax import jit, random, lax
from jax.scipy.linalg import sqrtm
from functools import partial
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from scipy import integrate
from jax.tree_util import Partial
from functools import partial

# Start from a fresh checkout on every run: remove any existing copy of the
# repository and clone it again.
! rm -rf ACM270_Model_Error
# if os.path.isdir('./ACM270_Model_Error/') == False:
! git clone https://github.com/sreemanti-dey/ACM270_Model_Error.git
# Put the checkout on the module search path.
sys.path.append('./ACM270_Model_Error/')

# NOTE(review): `!` commands run in their own shell, so this `cd` presumably
# does not change the notebook's working directory; the data files loaded
# further down are addressed with an explicit `ACM270_Model_Error/` prefix.
! cd ACM270_Model_Error

"""# True Trajectory and Noisy Observations with 2-Scale Lorenz96"""
for iii in range(2):
  # parameters for two-scale L96 (identical in both passes of the outer loop)
  K = 36                      # number of large scale vars
  J = 10                      # number of small scale vars per large var
  h = 0.25                   # part of coupling
  c = 10                     # part of coupling
  b = 10                     # part of coupling
  F = 10                     # forcing
  # The time step is the only setting that differs between the two passes.
  dt = 0.1 if iii == 0 else 0.05
  noise_const = 0.1
  num_steps = 200
  # Rescale the step count relative to the reference step of 0.05.
  num_steps = int(num_steps * (0.05 / dt))
  time_steps = num_steps

  # One aggregated error value is appended to each list per learning rate.
  full_NRMSE_fm = []
  full_NRMSE_as = []
  full_NRMSE_fm_NN = []
  full_NRMSE_as_NN = []

  # 15 learning rates, log-spaced from 1e-5 to 1e-1.
  LRs = np.logspace(-5,-1,15)

  for idx_n, LR in enumerate(LRs):
      @jit
      def L96_2(xy):
          """Evaluate the two-scale Lorenz-96 tendency at the flat state ``xy``.

          The first ``K`` entries of ``xy`` are the large-scale variables; the
          remaining ``K * J`` entries are the small-scale variables, ``J``
          consecutive values per large-scale variable.  ``K``, ``J``, ``h``,
          ``c``, ``b`` and ``F`` are read from the enclosing scope.

          Returns the time derivative in the same flat layout as ``xy``.

          The expressions are evaluated on whole arrays instead of filling the
          result one element at a time, which keeps the traced program small.
          """
          x = xy[0:K]
          y = xy[K:].reshape(K, J)

          # roll(x, s)[k] == x[(k - s) % K], i.e. periodic neighbours in k.
          dx = (-1 * jnp.roll(x, 1) * (jnp.roll(x, 2) - jnp.roll(x, -1))
                - x + F - (h * c / b) * jnp.sum(y, axis=1))

          # Small-scale neighbours wrap around inside their own row (same k),
          # so every shift is along axis 1 only.
          dy = (-1 * c * b * jnp.roll(y, -1, axis=1)
                * (jnp.roll(y, -2, axis=1) - jnp.roll(y, 1, axis=1))
                - c * y + (h * c / b) * x[:, None])

          return jnp.concatenate([dx, dy.flatten()])

      @jit
      def rk4_step_lorenz96_2(x):
          """Advance the two-scale state ``x`` by one classical RK4 step of size ``dt``."""
          stage1 = dt * L96_2(x)
          stage2 = dt * L96_2(x + stage1/2)
          stage3 = dt * L96_2(x + stage2/2)
          stage4 = dt * L96_2(x + stage3)
          weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
          return x + 1/6 * weighted_sum

      # Generate a reference two-scale trajectory and noisy observations of it.
      # `keys` is not referenced again in this loop body.
      keys = iter(random.split(random.PRNGKey(0), 10))
      key = random.PRNGKey(0)
      # deterministic ramp from 0 to 1 for the large scale, 0 for the small scale
      x0 = np.linspace(0, 1, K)  # np.random.rand(K)
      y0 = np.zeros((K, J))
      xy0 = np.concatenate([x0, y0.flatten()])

      # Result is discarded; presumably a warm-up call so the stepper is
      # compiled before the timed loop below.
      rk4_step_lorenz96_2(xy0)

      # One row per time step; columns follow the flat layout of xy0.
      trajectory = jnp.zeros((time_steps, len(xy0)))
      noisy_obs = jnp.zeros((time_steps, len(xy0)))
      # Isotropic observation-noise covariance: noise_const on the diagonal.
      noise_covar = jnp.eye(len(xy0)) * noise_const

      xy = xy0
      for i in tqdm(range(time_steps)):
          # Row i holds the state after i + 1 steps (the initial state is not stored).
          xy = rk4_step_lorenz96_2(xy)
          trajectory = trajectory.at[i].set(xy)
          # Add noise to the observations
          noise = random.multivariate_normal(key, jnp.zeros(len(xy0)), noise_covar)
          noisy_obs = noisy_obs.at[i].set(xy + noise)
          key, _ = random.split(key)  # Update key for the next iteration

      """## True Trajectory

      ## Noisy Observations
      """

      true_states = trajectory

      """# EnKF for 1-Scale Assimilation"""

      @jit
      def rk4_step_lorenz96_1(x):
          """Advance the single-scale Lorenz-96 state ``x`` by one RK4 step of size ``dt``.

          NOTE(review): the neighbour indices used here are (k-1, k+2, k+1),
          whereas ``L96_2`` uses (k-1, k-2, k+1) for its large-scale advection
          term; the two orientations are mirror images -- confirm this is intended.
          """
          def tendency(state):
              advection = (jnp.roll(state, 1) - jnp.roll(state, -2)) * jnp.roll(state, -1)
              return advection - state + F

          stage1 = dt * tendency(x)
          stage2 = dt * tendency(x + stage1/2)
          stage3 = dt * tendency(x + stage2/2)
          stage4 = dt * tendency(x + stage3)
          weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
          return x + 1/6 * weighted_sum

      # Wrap the steppers in jax.tree_util.Partial; `lorenz1` is later passed
      # as the first (non-static) argument of the jitted `ensrf_steps`.
      # `lorenz2` is not referenced again in this loop body.
      lorenz1 = Partial(rk4_step_lorenz96_1)
      lorenz2 = Partial(rk4_step_lorenz96_2)

      @jit
      def ledoit_wolf(P, shrinkage):
          """Shrink the square matrix ``P`` towards a scaled identity.

          Returns ``(1 - shrinkage) * P + shrinkage * (trace(P) / n) * I``
          where ``n`` is the size of ``P``.  The shrinkage weight is supplied
          by the caller; it is not estimated from data.
          """
          dim = P.shape[0]
          identity_target = shrinkage * jnp.trace(P) / dim * jnp.eye(dim)
          return (1 - shrinkage) * P + identity_target

      @jit
      def ensrf_step(ensemble, y, H, Q, R, key, inflation = 1.2):
          """Run one square-root analysis update on a forecast ensemble.

          ``ensemble`` holds one member per column.  ``y`` is the observation,
          ``H`` the observation operator, ``Q`` an additive covariance term and
          ``R`` the observation-noise covariance.  ``key`` is accepted but not
          used.  ``inflation`` scales the anomalies before the gain is computed.

          Returns ``(ensemble, C_pred, updated_ensemble, updated_P)`` cast to
          float32: the input ensemble, its shrunk forecast covariance (computed
          from the un-inflated anomalies), the analysis ensemble and its shrunk
          covariance.
          """
          n_members = ensemble.shape[1]
          mean = jnp.mean(ensemble, axis=1)
          anomalies = ensemble - mean.reshape((-1, 1))

          # Forecast covariance is taken before inflation is applied.
          C_pred = ledoit_wolf((anomalies @ anomalies.T) / (n_members - 1) + Q, noise_const)

          anomalies = anomalies * inflation
          P = (anomalies @ anomalies.T) / (n_members - 1) + Q
          innovation_cov = H @ P @ H.T + R
          gain = P @ H.T @ jnp.linalg.inv(innovation_cov)
          mean = mean + gain @ (y - H @ mean)

          # Transform the anomalies with the symmetric square root of (I - gain H).
          updated_A = sqrt_m(jnp.eye(mean.shape[0]) - gain @ H) @ anomalies
          updated_ensemble = mean.reshape((-1, 1)) + updated_A
          updated_P = ledoit_wolf((updated_A @ updated_A.T) / (n_members - 1), noise_const)

          return (
              ensemble.astype(jnp.float32),
              C_pred.astype(jnp.float32),
              updated_ensemble.astype(jnp.float32),
              updated_P.astype(jnp.float32),
          )

      @jit
      def sqrt_m(M):
          """Return the symmetric square root of the symmetric matrix ``M``.

          ``M`` is decomposed with ``jnp.linalg.eigh`` and rebuilt with the
          square roots of its eigenvalues.  Eigenvalues are clamped at zero
          first: a positive semi-definite matrix can come out of
          floating-point arithmetic with a tiny negative eigenvalue, and
          taking its square root would fill the whole result with NaN.
          """
          eigenvalues, eigenvectors = jnp.linalg.eigh(M)
          sqrt_eigenvalues = jnp.sqrt(jnp.maximum(eigenvalues, 0))
          M_sqrt = eigenvectors @ jnp.diag(sqrt_eigenvalues) @ eigenvectors.T
          return M_sqrt.real

      @partial(jit, static_argnums=(3))
      def ensrf_steps(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key):
          """Run ``num_steps`` forecast/analysis cycles with ``jax.lax.scan``.

          Every cycle pushes each ensemble member (one per column) through
          ``state_transition_function`` and then assimilates
          ``observations[t]`` with ``ensrf_step``.  ``n_ensemble`` and
          ``observation_interval`` are accepted but not used.

          Returns the per-step stacks
          ``(ensemble_forecast, C_forecast, ensemble_analysis, C_analysis)``.
          """
          propagate = jax.vmap(lambda member: state_transition_function(member), in_axes=1, out_axes=1)
          key, *step_keys = random.split(key, num=num_steps + 1)
          step_keys = jnp.array(step_keys)

          def cycle(carry, t):
              members, _ = carry
              forecast = propagate(members)
              _, forecast_cov, analysis, analysis_cov = ensrf_step(forecast, observations[t, :], H, Q, R, step_keys[t])
              return (analysis, analysis_cov), (forecast, forecast_cov, analysis, analysis_cov)

          state_dim = len(Q[0])
          initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
          _, stacked = jax.lax.scan(cycle, initial_carry, jnp.arange(num_steps))
          ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = stacked

          return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

      class IncrementCorrectionModel(nn.Module):
          """MLP with three ReLU hidden layers that maps a state to a same-sized correction."""
          features: int  # width of each hidden layer

          @nn.compact
          def __call__(self, x):
              # Record the input width up front: `x` is overwritten by the
              # hidden activations below, so reading x.shape[-1] at the end
              # would give the hidden width instead of the input width.
              output_dim = x.shape[-1]
              for _ in range(3):
                  x = nn.relu(nn.Dense(self.features)(x))
              return nn.Dense(output_dim)(x)  # Output layer matches the input dimension

      def create_model(key, input_shape, features):
          """Build the correction network and initialise it on an all-ones input.

          Returns ``(model, params)`` where ``params`` is the ``'params'``
          entry of the initialised variables.
          """
          network = IncrementCorrectionModel(features)
          dummy_input = jnp.ones(input_shape)
          variables = network.init(key, dummy_input)
          return network, variables['params']

      def loss_fn(params, apply_fn, x, y):
          """Mean squared error between ``apply_fn({'params': params}, x)`` and ``y``."""
          residual = apply_fn({'params': params}, x) - y
          return jnp.mean(residual ** 2)

      @jax.jit
      def train_step(state, forecast_states, increments):
          """Apply one gradient update of the MSE loss and return the new train state."""
          def objective(candidate_params):
              return loss_fn(candidate_params, state.apply_fn, forecast_states, increments)

          gradients = jax.grad(objective)(state.params)
          return state.apply_gradients(grads=gradients)

      def train_nn(state, forecast_states, analysis_states, num_epochs=10):
          """Fit the network to predict the increment ``analysis - forecast``.

          Runs ``num_epochs`` full-batch updates with ``train_step``, printing
          the loss after each one, and returns the updated train state.
          """
          # Regression target: what the analysis added to the forecast.
          targets = analysis_states - forecast_states

          for epoch in range(num_epochs):
              state = train_step(state, forecast_states, targets)
              # Loss is re-evaluated with the freshly updated parameters.
              current_loss = loss_fn(state.params, state.apply_fn, forecast_states, targets)
              print(f'Epoch {epoch+1}, Loss: {current_loss}')

          return state

      # Filter configuration: every matrix is K x K, so the filter works on the
      # large-scale variables only.
      Q = noise_const * jnp.eye(K)   # additive covariance term used inside ensrf_step
      H = jnp.eye(K)                 # identity observation operator
      R = jnp.eye(K) * noise_const   # observation-noise covariance
      n_ensemble = 10
      num_steps = time_steps
      #the ensrf (ENKF) will run single scale Lorenz on the first K variables
      observations = noisy_obs[:,:K]
      initial_state = np.zeros((1, K))
      # Draw the initial members around zero with covariance Q, then transpose
      # so that members lie along axis 1 (ensrf_step reads the member count
      # from shape[1]).
      ensemble_init = random.multivariate_normal(key, initial_state, Q, (n_ensemble,)).T
      # The same `key` is passed on to the filter.
      ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

      time_steps = ensemble_forecast.shape[0]
      time = jnp.arange(time_steps)

      # Ensemble means per time step; both names are reassigned by the
      # training loop further down in this loop body.
      forecast_means = jnp.mean(ensemble_forecast, axis=2)
      analysis_means = jnp.mean(ensemble_analysis, axis=2)

      # Load the stored forecast/analysis pairs used as training data.  The
      # archive is opened in a `with` block so its file handle is closed as
      # soon as both arrays have been read.
      file = f"ACM270_Model_Error/long_training_data_noise_{noise_const}_dt_{dt}.npz"
      with np.load(file) as data:
          forecast_state = data["forecast"]
          analysis_state = data["analysis"]

      # Fresh network and optimizer for this learning rate.
      key = jax.random.PRNGKey(0)
      input_shape = (K,)
      features = K
      model, params = create_model(key, input_shape, features)

      learning_rate = LR
      optimizer = optax.adam(learning_rate)
      state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

      num_epochs = 15
      _,_,n = forecast_state.shape
      # Train on the first 50 stored trajectories, one after another; the
      # bound is capped at the number actually present so a smaller file does
      # not raise an IndexError.
      for i in range(min(50, n)):
        forecast_means = forecast_state[:,:,i]
        analysis_means = analysis_state[:,:,i]
        state = train_nn(state, forecast_means, analysis_means, num_epochs=num_epochs)

      #@partial(jax.jit, static_argnums=(3))
      def ensrf_steps_nn(state_transition_function, n_ensemble, ensemble_init, num_steps, observations, observation_interval, H, Q, R, key, nn_params, model):
          """Same cycling as ``ensrf_steps`` but with an NN-corrected forecast.

          Each member is advanced with ``state_transition_function`` and the
          output of ``model`` (evaluated with ``nn_params`` on that forecast)
          is added to it before the analysis.  ``n_ensemble`` and
          ``observation_interval`` are accepted but not used.

          Returns the per-step stacks
          ``(ensemble_forecast, C_forecast, ensemble_analysis, C_analysis)``.
          """
          def corrected_state_transition(member):
              raw_forecast = state_transition_function(member)
              return raw_forecast + model.apply({'params': nn_params}, raw_forecast)

          propagate = jax.vmap(corrected_state_transition, in_axes=1, out_axes=1)
          key, *step_keys = jax.random.split(key, num=num_steps + 1)
          step_keys = jnp.array(step_keys)

          def cycle(carry, t):
              members, _ = carry
              forecast = propagate(members)
              _, forecast_cov, analysis, analysis_cov = ensrf_step(forecast, observations[t, :], H, Q, R, step_keys[t])
              return (analysis, analysis_cov), (forecast, forecast_cov, analysis, analysis_cov)

          state_dim = len(Q[0])
          initial_carry = (ensemble_init, jnp.zeros((state_dim, state_dim)))
          _, stacked = jax.lax.scan(cycle, initial_carry, jnp.arange(num_steps))
          ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = stacked

          return ensemble_forecast, C_forecast, ensemble_analysis, C_analysis

      # Load the held-out test trajectories.  The archive is closed as soon as
      # the arrays have been read.
      file = f"ACM270_Model_Error/testing_data_noise_{noise_const}_dt_{dt}.npz"
      with np.load(file) as data:
          noisy_trajectory = data["noisy_trajectory"]
          # The true trajectory is stored under one of two names.  Only a
          # missing key triggers the fallback; any other failure propagates.
          try:
            trajectory=data['true_trajectory']
          except KeyError:
            trajectory=data['trajectory']

      noisy_trajectory =  jnp.asarray(noisy_trajectory)
      trajectory = jnp.asarray(trajectory)

      # Per-trajectory errors for this learning rate.
      NRMSE_fm = []
      NRMSE_as = []
      NRMSE_fm_NN = []
      NRMSE_as_NN = []

      # The last axis indexes the test trajectories.
      _,_,n = noisy_trajectory.shape

      # Run the corrected and the uncorrected filter on every test trajectory
      # and record the squared error of each mean relative to the squared truth.
      for i in range(n):
        observations = noisy_trajectory[:,:,i]

        nn_params = state.params
        # Arguments shared by both filter variants.
        filter_args = (lorenz1, n_ensemble, ensemble_init, time_steps, observations, 1, H, Q, R, key)

        ensemble_forecast_NN, C_forecast_NN, ensemble_analysis_NN, C_analysis_NN = ensrf_steps_nn(
            *filter_args, nn_params, model
        )
        ensemble_forecast, C_forecast, ensemble_analysis, C_analysis = ensrf_steps(*filter_args)

        time_steps = ensemble_forecast.shape[0]
        time = jnp.arange(time_steps)

        true_states = trajectory[:,:,i]
        forecast_means = jnp.mean(ensemble_forecast, axis=2)
        analysis_means = jnp.mean(ensemble_analysis, axis=2)

        forecast_means_NN = jnp.mean(ensemble_forecast_NN, axis=2)
        analysis_means_NN = jnp.mean(ensemble_analysis_NN, axis=2)

        # Normaliser: sum of squares of the true states.
        truth_flat = true_states.flatten()
        truth_energy = np.dot(truth_flat, truth_flat)
        for estimate, bucket in (
            (forecast_means, NRMSE_fm),
            (analysis_means, NRMSE_as),
            (forecast_means_NN, NRMSE_fm_NN),
            (analysis_means_NN, NRMSE_as_NN),
        ):
          model_error = true_states - estimate
          bucket.append(np.dot(model_error.flatten(), model_error.flatten()) / truth_energy)

      # Average over the test trajectories for this setting.
      for totals, per_trajectory in (
          (full_NRMSE_fm, NRMSE_fm),
          (full_NRMSE_as, NRMSE_as),
          (full_NRMSE_fm_NN, NRMSE_fm_NN),
          (full_NRMSE_as_NN, NRMSE_as_NN),
      ):
        totals.append(np.mean(per_trajectory))

  # Figure: aggregated error of the raw and the NN-corrected filter against
  # the learning rate, on a logarithmic x axis.
  fontsize = 20
  res = 500

  plt.figure(figsize=(8, 8))

  # (series, legend label, per-curve style); the line width is shared.
  curves = [
      (full_NRMSE_as, "Analysis (uncorrected)", {"linestyle": "-", "color": 'k', "markersize": 2}),
      (full_NRMSE_as_NN, "Analysis (corrected)", {"linestyle": "-", "marker": "o", "color": 'b', "markersize": 7}),
      (full_NRMSE_fm, "Forecast (uncorrected)", {"linestyle": "--", "color": 'k', "markersize": 7}),
      (full_NRMSE_fm_NN, "Forecast (corrected)", {"linestyle": "--", "marker": "x", "color": 'b', "markersize": 7}),
  ]
  for series, legend_label, line_style in curves:
      plt.semilogx(LRs, series, label=legend_label, linewidth=2.5, **line_style)

  plt.xlabel("Learning rate", fontsize=fontsize)
  plt.ylabel("Total NRMSE", fontsize=fontsize)
  plt.title("Total NRMSE vs learning rate", fontsize=fontsize)
  plt.legend(fontsize=fontsize/1.25)
  plt.grid(True)
  for set_ticks in (plt.xticks, plt.yticks):
      set_ticks(fontsize=fontsize / 1.25)
  plt.savefig(f"MSE_vs_LR_noise_{noise_const}_dt_{dt}.png", bbox_inches="tight", dpi=res)
