In [68]:
import jax
import jax.numpy as jnp
import pandas as pd
import diffrax
import matplotlib.pyplot as plt
import jax.random as jr
import equinox as eqx

from utils import plots, train_utils
from models import DataLoader, AutoregressiveCDE, RNN, SingleSolveCDE, CDEODE

In [69]:
# Build the project data loader from the saved array file; the fixed key
# makes whatever randomness the loader uses repeatable across runs.
dataloader = DataLoader.DataLoader(
    "data/data_large.npy",
    key=jr.key(0),
)


In [70]:
def create_missing_mask(
    key, 
    length,
    missing_ratio
):
    """
    Create a random observation mask whose first entry is always observed.

    Every position after the first is dropped independently with probability
    ``missing_ratio * length / (length - 1)`` (clipped to [0, 1]). Because
    index 0 can never be missing, this rescaling keeps the *expected* number
    of missing points at ``missing_ratio * length`` whenever no clipping
    occurs. The realised missing fraction is random, not exact.

    Args:
        key: JAX random key
        length: Length of time series (Python int, must be >= 1)
        missing_ratio: Target fraction of points to mask (0.3, 0.5, 0.7)

    Returns:
        Boolean array of shape (length,) where True = observed, False = missing
        First element is guaranteed to be True

    Raises:
        ValueError: If ``length`` is smaller than 1
    """
    if length < 1:
        raise ValueError(f"length must be at least 1, got {length}")

    first_observed = jnp.array([True])
    if length == 1:
        # Only the always-observed first point exists; the rescaling below
        # would otherwise divide by zero.
        return first_observed

    rand_vals = jr.uniform(key, shape=(length - 1,))

    # Spread the requested number of missing points over the length - 1
    # positions that are allowed to be missing.
    adjusted_ratio = (missing_ratio * length) / (length - 1)
    adjusted_ratio = jnp.clip(adjusted_ratio, 0.0, 1.0)

    # A position stays observed when its uniform draw exceeds the drop rate.
    middle_mask = rand_vals > adjusted_ratio

    return jnp.concatenate([first_observed, middle_mask])



def apply_mask(data, mask):
    """
    Blank out unobserved time steps by replacing them with NaN.

    Args:
        data: Time series data, shape (T,) or (T, features)
        mask: Boolean mask, shape (T,); True keeps the value, False drops it

    Returns:
        Data with NaN wherever the mask is False
    """
    # For data with more than one axis, give the mask a second axis so that
    # one flag per time step broadcasts across the feature columns.
    keep = jnp.expand_dims(mask, axis=1) if data.ndim > 1 else mask
    return jnp.where(keep, data, jnp.nan)

def mask_dataset(ys, length, missing_ratio, seed=0, series_length=100):
    """
    Randomly mask every series in ``ys`` and fill the gaps by linear interpolation.

    Each series gets its own mask (first point always observed) drawn from a
    key chain started at ``jr.key(seed)``. Masked entries are set to NaN and
    then filled with ``diffrax.linear_interpolation``; trailing NaNs are
    forward-filled.

    Args:
        ys: Batch of series, iterated along its first axis; each item must
            have ``length`` time steps along its own first axis
        length: Number of time steps in each series of ``ys``
        missing_ratio: Target fraction of points to mask
        seed: Integer seed for the JAX random key
        series_length: Number of points in the full time grid on [0, 1];
            the first ``length`` of them are used as time stamps

    Returns:
        Array stacking the interpolated series in the same order as ``ys``

    Raises:
        ValueError: If ``length`` exceeds ``series_length``
    """
    if length > series_length:
        raise ValueError(
            f"length ({length}) cannot exceed series_length ({series_length})"
        )

    # The time stamps are identical for every series, so build them once
    # instead of on every loop iteration.
    ts = jnp.linspace(0, 1, series_length)[:length]

    new_data = []
    key = jr.key(seed)
    for i in range(len(ys)):
        # Advance the key chain so each series gets an independent mask.
        key, subkey = jr.split(key)
        mask = create_missing_mask(subkey, length, missing_ratio)
        masked_array = apply_mask(ys[i], mask)
        interp_ys = diffrax.linear_interpolation(
            ts=ts,
            ys=masked_array,
            fill_forward_nans_at_end=True
        )
        new_data.append(interp_ys)
    return jnp.array(new_data)

In [71]:
def test_loss_with_missing_data(
    model,
    ys,
    missing_ratio,
    control_lengths=(10, 20, 30, 40, 50, 60, 70, 80, 86),
    horizon=14,
):
    """ 
        Test loss under randomly missing input observations

        For every input length in ``control_lengths`` the first
        ``control_until`` steps of ``ys`` are masked and re-interpolated, and
        the model is scored with ``train_utils.batch_loss`` on a forecast that
        runs ``horizon`` steps past the end of the input. Each per-length loss
        and the overall mean are printed.

        Params: 
            model:  The model which is being called
            ys: Used to calculate the loss (shape: (B,100,) per the original
                notes -- TODO confirm against the data loader)
            missing_ratio: Target fraction of input points to mask
            control_lengths: Input lengths to evaluate
            horizon: Number of steps forecast beyond each input length

        Returns:
            Mean of the per-input-length losses
    """
    losses = []

    for control_until in control_lengths:
        control_ys = ys[:, :control_until]

        # Prevent data leakage by masking after slicing: interpolation can
        # then only use values inside the control window. Seeding with the
        # window length gives each input length its own reproducible masks.
        masked_control_ys = mask_dataset(control_ys, control_until, missing_ratio, seed=control_until)

        train_until = control_until + horizon

        loss = train_utils.batch_loss(model, masked_control_ys, ys, control_until, train_until)
        losses.append(loss)

        print(f"Loss for input length of {control_until}: {loss}")
    total_mean_test_loss =  jnp.mean(jnp.array(losses))
    print(f"Mean Test Loss: {total_mean_test_loss}")
    return total_mean_test_loss

In [61]:
# Restore the saved RNN forecaster: the freshly constructed RNNForecaster only
# supplies the tree structure that the serialised leaves are loaded into.
model = eqx.tree_deserialise_leaves(
    "serialised_models/rnn_14_day_ahead.eqx",
    like=RNN.RNNForecaster(
        input_size=2,
        hidden_size=64,
        output_size=2,
        key=jr.key(0),
    ),
)
# Evaluate on the test split at three missing ratios; the value of the last
# call is what the notebook shows as this cell's output.
test_loss_with_missing_data(model, dataloader.test_data, 0.3)
test_loss_with_missing_data(model, dataloader.test_data, 0.5)
test_loss_with_missing_data(model, dataloader.test_data, 0.7)

Loss for input length of 10: 1577.3223876953125
Loss for input length of 20: 3246.05419921875
Loss for input length of 30: 2604.49755859375
Loss for input length of 40: 1722.998779296875
Loss for input length of 50: 1363.8818359375
Loss for input length of 60: 1249.0345458984375
Loss for input length of 70: 1429.1473388671875
Loss for input length of 80: 1260.7310791015625
Loss for input length of 86: 1510.4169921875
Mean Test Loss: 1773.7872314453125
Loss for input length of 10: 1583.743896484375
Loss for input length of 20: 3566.453857421875
Loss for input length of 30: 2707.0654296875
Loss for input length of 40: 1844.0050048828125
Loss for input length of 50: 1566.8040771484375
Loss for input length of 60: 1436.0711669921875
Loss for input length of 70: 1608.1708984375
Loss for input length of 80: 1355.90380859375
Loss for input length of 86: 1543.21044921875
Mean Test Loss: 1912.381103515625
Loss for input length of 10: 1594.9560546875
Loss for input length of 20: 4303.89013671875

Array(2514.406, dtype=float32)

In [72]:
# Restore the saved autoregressive CDE: the freshly constructed model only
# supplies the tree structure that the serialised leaves are loaded into.
model = eqx.tree_deserialise_leaves(
    "serialised_models/autoregressive_cde_14_day_ahead.eqx",
    like=AutoregressiveCDE.AutoregressiveCDE(
        data_size=3,
        hidden_size=3,
        width_size=64,
        depth=3,
        key=jr.key(2),
    ),
)
# Evaluate on the test split at three missing ratios; the value of the last
# call is what the notebook shows as this cell's output.
test_loss_with_missing_data(model, dataloader.test_data, 0.3)
test_loss_with_missing_data(model, dataloader.test_data, 0.5)
test_loss_with_missing_data(model, dataloader.test_data, 0.7)

Loss for input length of 10: 1554.40478515625
Loss for input length of 20: 3305.204345703125
Loss for input length of 30: 2697.40087890625
Loss for input length of 40: 1736.1083984375
Loss for input length of 50: 1516.241943359375
Loss for input length of 60: 1701.8223876953125
Loss for input length of 70: 1842.950927734375
Loss for input length of 80: 1752.9615478515625
Loss for input length of 86: 1852.677734375
Mean Test Loss: 1995.5303955078125
Loss for input length of 10: 1572.1829833984375
Loss for input length of 20: 4141.50830078125
Loss for input length of 30: 3678.3935546875
Loss for input length of 40: 2223.578857421875
Loss for input length of 50: 2076.552734375
Loss for input length of 60: 2101.6396484375
Loss for input length of 70: 2394.879150390625
Loss for input length of 80: 2229.93701171875
Loss for input length of 86: 2315.21337890625
Mean Test Loss: 2525.987548828125
Loss for input length of 10: 1591.5745849609375
Loss for input length of 20: 5088.45458984375
Loss 

Array(3346.5552, dtype=float32)

In [67]:
# Restore the saved single-solve CDE: the freshly constructed model only
# supplies the tree structure that the serialised leaves are loaded into.
model = eqx.tree_deserialise_leaves(
    "serialised_models/single_solve_cde_14_day_ahead.eqx",
    like=SingleSolveCDE.SingleSolveCDE(
        data_size=3,
        hidden_size=3,
        width_size=64,
        depth=3,
        key=jr.key(2),
    ),
)
# Evaluate on the test split at three missing ratios; the value of the last
# call is what the notebook shows as this cell's output.
test_loss_with_missing_data(model, dataloader.test_data, 0.3)
test_loss_with_missing_data(model, dataloader.test_data, 0.5)
test_loss_with_missing_data(model, dataloader.test_data, 0.7)

Loss for input length of 10: 1540.343017578125
Loss for input length of 20: 3496.313232421875
Loss for input length of 30: 4003.48583984375
Loss for input length of 40: 1729.0379638671875
Loss for input length of 50: 1102.8875732421875
Loss for input length of 60: 1194.04931640625
Loss for input length of 70: 1189.257080078125
Loss for input length of 80: 913.565673828125
Loss for input length of 86: 1138.0421142578125
Mean Test Loss: 1811.8868408203125
Loss for input length of 10: 1563.6043701171875
Loss for input length of 20: 4364.42529296875
Loss for input length of 30: 4797.14208984375
Loss for input length of 40: 2285.07763671875
Loss for input length of 50: 1805.1517333984375
Loss for input length of 60: 1678.18994140625
Loss for input length of 70: 1890.8519287109375
Loss for input length of 80: 1419.5989990234375
Loss for input length of 86: 1569.4749755859375
Mean Test Loss: 2374.835205078125
Loss for input length of 10: 1588.0032958984375
Loss for input length of 20: 5286.46

Array(3235.6438, dtype=float32)

In [66]:
# Restore the saved CDE-ODE model: the freshly constructed model only
# supplies the tree structure that the serialised leaves are loaded into.
model = eqx.tree_deserialise_leaves(
    "serialised_models/cde_ode_14_day_ahead.eqx",
    like=CDEODE.CDEODE(
        data_size=3,
        hidden_size=3,
        width_size=64,
        depth=3,
        key=jr.key(2),
    ),
)
# Evaluate on the test split at three missing ratios; the value of the last
# call is what the notebook shows as this cell's output.
test_loss_with_missing_data(model, dataloader.test_data, 0.3)
test_loss_with_missing_data(model, dataloader.test_data, 0.5)
test_loss_with_missing_data(model, dataloader.test_data, 0.7)

Loss for input length of 10: 1595.3050537109375
Loss for input length of 20: 5232.30029296875
Loss for input length of 30: 2787.12451171875
Loss for input length of 40: 2524.214599609375
Loss for input length of 50: 1641.1046142578125
Loss for input length of 60: 1675.4755859375
Loss for input length of 70: 2532.91748046875
Loss for input length of 80: 3380.387939453125
Loss for input length of 86: 3866.750732421875
Mean Test Loss: 2803.95361328125
Loss for input length of 10: 1599.722412109375
Loss for input length of 20: 5430.73486328125
Loss for input length of 30: 3249.871826171875
Loss for input length of 40: 2551.370849609375
Loss for input length of 50: 1848.8719482421875
Loss for input length of 60: 1865.743896484375
Loss for input length of 70: 2724.832763671875
Loss for input length of 80: 3548.67236328125
Loss for input length of 86: 3995.2841796875
Mean Test Loss: 2979.456298828125
Loss for input length of 10: 1605.9593505859375
Loss for input length of 20: 5764.5205078125


Array(3399.7415, dtype=float32)