In [None]:
import torch
from torch import Tensor, jit, nn
from torch.optim import AdamW
from tqdm.autonotebook import tqdm

from linodenet.models import LatentStateSpaceModel as LSSM
from linodenet.models.embeddings import ConcatEmbedding, ConcatProjection
from linodenet.models.encoders.invertible_layers import (
    LinearContraction,
    NaiveLinearContraction,
    iResNetBlock,
    iSequential,
)
from linodenet.models.filters import LinearFilter, NonLinearFilter, SequentialFilter
from linodenet.models.system import LinODECell
from linodenet.utils import ReZeroCell

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp


# Define the Lotka-Volterra equations
def lotka_volterra(t, y, a, b, c, d):
    """Evaluate the right-hand side of the Lotka-Volterra equations.

    Args:
        t: Current time. Not used in the computation; it is part of the
            signature because the function is handed to ``solve_ivp``.
        y: Two-element state ``(prey, predator)``.
        a: Coefficient of the linear prey term.
        b: Coefficient of the prey-predator product in the prey equation.
        c: Coefficient of the linear predator term.
        d: Coefficient of the prey-predator product in the predator equation.

    Returns:
        A two-element list ``[d(prey)/dt, d(predator)/dt]``.
    """
    prey, predator = y
    prey_rate = a * prey - b * prey * predator
    predator_rate = -c * predator + d * prey * predator
    return [prey_rate, predator_rate]


# Model coefficients
a = 3.0  # prey growth rate
b = 1.5  # predation rate
c = 3.0  # predator death rate
d = 1.0  # conversion factor of prey to predator

# Starting populations
x0 = 10.0  # prey
y0 = 5.0  # predator

# Integration window and size of the uniform evaluation grid
t_start, t_end = 0.0, 30.0
num_points = 1000

# Integrate the system; dense_output=True makes the continuous
# interpolant available as `sol.sol`.
sol = solve_ivp(
    fun=lotka_volterra,
    t_span=(t_start, t_end),
    y0=(x0, y0),
    args=(a, b, c, d),
    dense_output=True,
)

# Sample the interpolant on a uniform grid spanning the integration window
t_eval = np.linspace(t_start, t_end, num=num_points)
sol_eval = sol.sol(t_eval)

In [None]:
fig, ax = plt.subplots(figsize=(16, 10), constrained_layout=True)

# One curve per state component, drawn over the uniform time grid
for row, species in enumerate(("Prey", "Predator")):
    ax.plot(t_eval, sol_eval[row], label=species)

ax.set(xlabel="Time", ylabel="Population", title="Lotka-Volterra Equations")
ax.legend()
ax.grid(True)

In [None]:
N = 1000  # number of irregularly spaced observations
SEED = 0  # fixed so that the synthetic dataset is identical on every run

rng = np.random.default_rng(SEED)

# Multiplicative noise with mean shape * scale = 1. It has shape (N, 1), so the
# same factor is broadcast across the state components of each observation.
noise = rng.gamma(shape=10, scale=1 / 10, size=(N, 1))

# Sorted, irregular observation times and the noisy states at those times
T = np.sort(rng.uniform(t_start, t_end, N))
X = noise * sol.sol(T).T

# Split at the midpoint of the window. `(t_end - t_start) / 2` alone is the
# half-length, which only coincides with the midpoint when t_start == 0.
t_split = t_start + (t_end - t_start) / 2
m_train = T < t_split
# Complement of the training mask, so no observation is left out of both sets.
m_test = ~m_train
T_train = T[m_train]
T_test = T[m_test]
X_train = X[m_train]
X_test = X[m_test]

In [None]:
# Visual check of the noisy training observations (column 0: prey, column 1: predator)
fig, ax = plt.subplots()
ax.plot(T_train, X_train[:, 0], ".b", label="Prey")
ax.plot(T_train, X_train[:, 1], ".r", label="Predator")
ax.set(xlabel="Time", ylabel="Population", title="Noisy training observations")
ax.legend();

# Setup Model

In [None]:
latent_size = 16
input_size = 2

# Random probes consumed by the sanity-check assertions on the components
x = torch.randn(input_size)
z = torch.randn(latent_size)
dta = torch.rand(1)
dtb = torch.rand(1)

# Rebind the dataset names to float32 tensors
T, X, T_train, T_test, X_train, X_test = (
    torch.tensor(array, dtype=torch.float32)
    for array in (T, X, T_train, T_test, X_train, X_test)
)

## Initialize Encoder

In [None]:
# An embedding from input_size to latent_size followed by two identically
# configured iResNetBlock stages (each wraps a LinearContraction + ReZeroCell).
Encoder = iSequential(
    ConcatEmbedding(input_size, latent_size),
    *(
        iResNetBlock(
            nn.Sequential(
                LinearContraction(latent_size, latent_size, L=0.95),
                ReZeroCell(),
            )
        )
        for _ in range(2)
    ),
)

# Round-trip check: decoding the encoding must reproduce x within tolerance
reconstruction = Encoder.decode(Encoder.encode(x))
assert torch.allclose(x, reconstruction, atol=1e-3, rtol=1e-3)

## Initialize Filter

In [None]:
# Two filter stages applied in sequence, both with autoregressive=True
linear_stage = LinearFilter(input_size, autoregressive=True)
nonlinear_stage = NonLinearFilter(input_size, autoregressive=True)
Filter = SequentialFilter(linear_stage, nonlinear_stage)

# Sanity check: passing x as both arguments must give back x
filtered = Filter(x, x)
assert torch.allclose(x, filtered)

## Initialize System

In [None]:
System = LinODECell(latent_size)

# Sanity check: a single step of size dta + dtb must agree with a dtb step
# followed by a dta step.
single_step = System(dta + dtb, z)
two_steps = System(dta, System(dtb, z))
assert torch.allclose(single_step, two_steps)

## Initialize Model

In [None]:
model = LSSM(encoder=Encoder, system=System, decoder=Encoder.inverse, filter=Filter)
model = model.to(device="cpu")

# Gather the encoder submodules whose class is named "LinearContraction"
contractions = []
for module in model.encoder.modules():
    if module.__class__.__name__ == "LinearContraction":
        contractions.append(module)

for layer in contractions:
    layer.reset_cache()

# Sanity check on the first 100 observations: the model output must match them
prediction = model(T[:100], X[:100])
assert torch.allclose(X[:100], prediction)

## Train Model

In [None]:
# Prefer the GPU when one is available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Script the model, move it to the chosen device and attach the optimizer
model = torch.jit.script(model).to(device)
optim = AdamW(model.parameters())

# After scripting, submodules are matched through `original_name`
contractions = []
for submodule in model.encoder.modules():
    if submodule.original_name == "LinearContraction":
        contractions.append(submodule)

for layer in contractions:
    layer.reset_cache()

In [None]:
from torch.utils.data import DataLoader

from tsdm.random.samplers import SlidingWindowSampler

# Both samplers share the same window settings; only the training one shuffles
window_options = {"horizons": (5,), "stride": 0.1}
train_sampler = SlidingWindowSampler(T_train, shuffle=True, **window_options)
test_sampler = SlidingWindowSampler(T_test, **window_options)

In [None]:
from torch.nn.utils.rnn import pad_sequence

n_forecast = 10  # trailing observations per window hidden from the model input


def _make_forecast_samples(sampler, times, values, n_forecast):
    """Build one ``(t, x, y)`` triple per window produced by ``sampler``.

    Args:
        sampler: Iterable whose items can be used to index ``times`` and ``values``.
        times: Tensor of observation times.
        values: Tensor of observations aligned with ``times``.
        n_forecast: Number of trailing rows of each window to blank out in ``x``.

    Returns:
        List of ``(t, x, y)`` tuples, where ``y`` holds the window's observations
        and ``x`` is a copy of ``y`` whose last ``n_forecast`` rows are NaN.
    """
    samples = []
    for horizon in sampler:
        t = times[horizon]
        # Clone before masking: if the sampler yields slices, indexing returns a
        # view, and writing NaN through it would overwrite `values` itself and
        # leak into every later window that overlaps this one.
        y = values[horizon].clone()
        x = y.clone()
        x[-n_forecast:] = float("nan")
        samples.append((t, x, y))
    return samples


train_samples = _make_forecast_samples(train_sampler, T_train, X_train, n_forecast)
test_samples = _make_forecast_samples(test_sampler, T_test, X_test, n_forecast)


def collate_fn(
    samples: list[tuple[Tensor, Tensor, Tensor]]
) -> tuple[Tensor, Tensor, Tensor]:
    """Collate variable-length ``(t, x, y)`` samples into NaN-padded batches.

    Args:
        samples: Sequence of ``(t, x, y)`` tensor triples; sequences may differ
            in length along their first dimension.

    Returns:
        Tuple ``(t, x, y)`` of batch-first tensors in which every sequence is
        padded with NaN up to the longest sequence in the batch.
    """
    # `padding_value` is a plain float; using one keeps this function
    # independent of any globally selected device.
    padding = float("nan")
    t_seqs, x_seqs, y_seqs = zip(*samples)

    return (
        pad_sequence(t_seqs, batch_first=True, padding_value=padding),
        pad_sequence(x_seqs, batch_first=True, padding_value=padding),
        pad_sequence(y_seqs, batch_first=True, padding_value=padding),
    )

In [None]:
# Batches of 8 windows; sequences of unequal length are NaN-padded by collate_fn
train_loader = DataLoader(train_samples, collate_fn=collate_fn, batch_size=8)
# Evaluation must use the held-out windows, not the training ones
test_loader = DataLoader(test_samples, collate_fn=collate_fn, batch_size=8)


def test_score(model, dloader):
    """Return the mean per-batch squared error of ``model`` over ``dloader``.

    Args:
        model: Callable invoked as ``model(t, x)`` that returns predictions
            comparable to ``y``.
        dloader: Iterable of ``(t, x, y)`` batches that also supports ``len``.

    Returns:
        Scalar tensor on ``device``: the per-batch losses summed and divided
        by the number of batches in ``dloader``.

    Raises:
        AssertionError: If the loss of any batch is not finite.
    """
    with torch.no_grad():
        total = torch.tensor(0.0, device=device)

        for t, x, y in (pbar := tqdm(dloader, leave=False)):
            t = t.to(device)
            x = x.to(device)
            y = y.to(device)
            yhat = model(t, x)
            # nanmean skips NaN entries of the squared error
            loss = torch.nanmean((y - yhat).pow(2))
            assert loss.isfinite()
            pbar.set_postfix(loss=f"{float(loss.item()):.4f}")
            total += loss

        # Normalise by the loader that was actually evaluated rather than by
        # whatever the global `test_loader` happens to be.
        total /= len(dloader)
        return total

In [None]:
# Baseline: score the model on test_loader before any optimizer step is taken
score = test_score(model, test_loader)
print(score)

In [None]:
for epoch in range(10):
    progress = tqdm(train_loader)
    for batch in progress:
        # Move the whole batch to the training device
        t, x, y = (tensor.to(device) for tensor in batch)

        model.zero_grad(set_to_none=True)
        yhat = model(t, x)
        squared_error = (y - yhat).pow(2)
        # nanmean skips NaN entries of the squared error
        loss = torch.nanmean(squared_error)
        loss.backward()

        # Refuse to update the weights from a non-finite loss
        assert loss.isfinite()
        optim.step()
        progress.set_postfix(loss=f"{float(loss.item()):.4f}")

    # Report the score on test_loader after every pass over the training data
    score = test_score(model, test_loader)
    print(score)