In [None]:
import torch
from torch import Tensor, jit, nn
from torch.optim import AdamW
from torchinfo import summary
from tqdm.autonotebook import tqdm, trange

from linodenet.models import LatentStateSpaceModel as LSSM
from linodenet.models.embeddings import ConcatEmbedding, ConcatProjection
from linodenet.models.encoders.invertible_layers import (
    LinearContraction,
    NaiveLinearContraction,
    iResNetBlock,
    iSequential,
)
from linodenet.models.filters import LinearFilter, NonLinearFilter, SequentialFilter
from linodenet.models.system import LinODECell
from linodenet.utils import ReZeroCell

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp


# Define the Lotka-Volterra equations
def lotka_volterra(t, y, a, b, c, d):
    """Right-hand side of the Lotka-Volterra predator-prey system.

    Args:
        t: Current time. Not used (the system is autonomous), but ``solve_ivp``
            always passes it.
        y: Current state as a pair ``(prey, predator)``.
        a: Prey growth rate.
        b: Predation rate.
        c: Predator death rate.
        d: Conversion factor of prey to predator.

    Returns:
        ``[d(prey)/dt, d(predator)/dt]`` as a list.
    """
    prey, predator = y
    prey_rate = a * prey - b * prey * predator
    predator_rate = d * prey * predator - c * predator
    return [prey_rate, predator_rate]


# Set the parameters
a = 3.0  # prey growth rate
b = 2.0  # predation rate
c = 3.0  # predator death rate
d = 1.0  # conversion factor of prey to predator

# Set the initial conditions
x0 = 2.0  # initial prey population
y0 = 1.0  # initial predator population

# Set the time span
T_MIN = 0.0
T_MAX = 30.0
num_points = 1000

# Integration tolerances. solve_ivp's defaults (rtol=1e-3, atol=1e-6) are loose
# for an oscillatory system integrated over many periods. The dense solution is
# later re-sampled at arbitrary times and treated as ground truth, so it is
# computed accurately here.
RTOL = 1e-8
ATOL = 1e-8

# Solve the equations using solve_ivp
sol = solve_ivp(
    lotka_volterra,
    [T_MIN, T_MAX],
    [x0, y0],
    args=(a, b, c, d),
    dense_output=True,
    rtol=RTOL,
    atol=ATOL,
)
assert sol.success, sol.message

# Generate time points for evaluation
t_eval = np.linspace(T_MIN, T_MAX, num_points)

# Evaluate the solution at the time points; shape (2, num_points): prey, predator
sol_eval = sol.sol(t_eval)

In [None]:
N = 1000
SEED = 0

# Seeded generator so the noisy observations are reproducible across runs.
rng = np.random.default_rng(SEED)

# Multiplicative noise: one Gamma(shape=20, scale=1/20) factor (mean 1) per
# sample, clipped to [0.5, 1.5].
# NOTE(review): size=(N, 1) applies the same factor to prey and predator at a
# given time; use size=(N, 2) if independent per-species noise is intended.
noise = rng.gamma(shape=20, scale=1 / 20, size=(N, 1)).clip(0.5, 1.5)

# Irregular observation times, sorted so the series is chronological.
T = np.sort(rng.uniform(T_MIN, T_MAX, N))
X = noise * sol.sol(T).T  # shape (N, 2): prey, predator

fig, ax = plt.subplots(figsize=(16, 4), constrained_layout=True)

# Plot the exact populations and the noisy observations in matching colours.
(prey_line,) = ax.plot(t_eval, sol_eval[0], label="Prey")
(predator_line,) = ax.plot(t_eval, sol_eval[1], label="Predator")
ax.plot(T, X[:, 0], ".", color=prey_line.get_color(), label="Prey (observed)")
ax.plot(
    T, X[:, 1], ".", color=predator_line.get_color(), label="Predator (observed)"
)
ax.set_xlabel("Time")
ax.set_ylabel("Population")
ax.set_title("Lotka-Volterra Equations")
ax.legend()
ax.grid(True)

In [None]:
# Rescale time to [0, 1].
T = (T - T.min()) / (T.max() - T.min())

# Chronological split on the rescaled time axis. The test mask is the exact
# complement of the train mask, so every observation lands in exactly one split
# (the former `T > 0.6` silently dropped points with T == 0.6).
TRAIN_FRACTION = 0.6
m_train = T < TRAIN_FRACTION
m_test = ~m_train

# Standardize each channel with training-split statistics only, so nothing from
# the test period leaks into preprocessing.
X_mean = X[m_train].mean(axis=0)
X_std = X[m_train].std(axis=0)
X = (X - X_mean) / X_std

T_train = T[m_train]
T_test = T[m_test]
X_train = X[m_train]
X_test = X[m_test]

fig, ax = plt.subplots(figsize=(16, 4), constrained_layout=True)
ax.plot(T_train, X_train, ".")
ax.set_xlabel("Time (rescaled)")
ax.set_ylabel("Population (standardized)")
ax.set_title("Training split");

# Set Up the Model

In [None]:
latent_size = 64
input_size = 2

# Seed torch's global RNG so the probe tensors below and later random draws
# (e.g. parameter initialisation, warmup batches, loader shuffling) are reproducible.
torch.manual_seed(0)

# Random probe inputs for the sanity-check assertions on encoder, filter and system.
x = torch.randn(input_size)
z = torch.randn(latent_size)
dta = torch.rand(1)
dtb = torch.rand(1)


def _as_float32(array):
    """Copy ``array`` (e.g. a NumPy array) into a new float32 torch tensor."""
    return torch.tensor(array, dtype=torch.float32)


# Convert the full series and both splits to float32 tensors.
T, X = _as_float32(T), _as_float32(X)
T_train, T_test = _as_float32(T_train), _as_float32(T_test)
X_train, X_test = _as_float32(X_train), _as_float32(X_test)

## Initialize Encoder and Decoder

In [None]:
def make_residual_block(size: int, lipschitz: float = 0.99) -> iResNetBlock:
    """Build one invertible residual block of width ``size``.

    The residual branch is two ``LinearContraction(size, size, L=lipschitz)``
    layers followed by a ``ReZeroCell``, wrapped in an ``iResNetBlock``.
    """
    return iResNetBlock(
        nn.Sequential(
            LinearContraction(size, size, L=lipschitz),
            LinearContraction(size, size, L=lipschitz),
            ReZeroCell(),
        )
    )


N_BLOCKS = 3

# Encoder: lift the input into the latent space, then apply residual blocks.
# Modules are built in the same order as before (embedding first).
encoder_embedding = ConcatEmbedding(input_size, latent_size)
encoder_blocks = [make_residual_block(latent_size) for _ in range(N_BLOCKS)]
Encoder = iSequential(encoder_embedding, *encoder_blocks)

# Decoder: residual blocks in latent space, then project back to the input size.
decoder_blocks = [make_residual_block(latent_size) for _ in range(N_BLOCKS)]
Decoder = iSequential(*decoder_blocks, ConcatProjection(latent_size, input_size))

# Both stacks must be numerically invertible (round trip on the probe x).
assert torch.allclose(x, Encoder.decode(Encoder.encode(x)), atol=1e-3, rtol=1e-3)
assert torch.allclose(x, Decoder.encode(Decoder.decode(x)), atol=1e-3, rtol=1e-3)

## Initialize Filter

In [None]:
filter_layers = [
    LinearFilter(input_size, autoregressive=True),
    NonLinearFilter(input_size, autoregressive=True),
    NonLinearFilter(input_size, autoregressive=True),
]
Filter = SequentialFilter(*filter_layers)

# Sanity check: called with the probe x as both arguments, the filter must
# return x unchanged.
assert torch.allclose(x, Filter(x, x))

## Initialize System

In [None]:
System = LinODECell(latent_size)

# Semigroup check: a single step of length dta + dtb must equal a step of dtb
# followed by a step of dta.
z_joint = System(dta + dtb, z)
z_split = System(dta, System(dtb, z))
assert torch.allclose(z_joint, z_split)

## Initialize Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assemble the latent state-space model (kept on the CPU in this cell).
model = LSSM(
    encoder=Encoder,
    decoder=Decoder,
    system=System,
    filter=Filter,
)
model = model.to(device="cpu")

n_named_modules = sum(1 for _ in model.named_modules())
print(f"Number of named submodules: {n_named_modules}")
summary(model)

In [None]:
# Every LinearContraction layer of the (eager, not yet scripted) encoder and
# decoder, encoder layers first.
contractions = [
    module
    for part in (model.encoder, model.decoder)
    for module in part.modules()
    if type(module).__name__ == "LinearContraction"
]


def reset_all_caches():
    """Call ``reset_cache()`` on every layer in the global ``contractions`` list.

    The list is looked up at call time, so rebinding ``contractions`` changes
    which layers are reset.
    """
    for contraction in contractions:
        contraction.reset_cache()


print(contractions)
reset_all_caches()

## Warmup (eager model)

In [None]:
model = model.to(device)
optim = AdamW(model.parameters(), lr=0.0001)
model.zero_grad(set_to_none=True)
reset_all_caches()

# A few optimisation steps on random data exercise forward, backward and the
# optimizer before real training; caches are reset after every update.
N_WARMUP_STEPS = 3
for _ in range(N_WARMUP_STEPS):
    t = torch.rand(100, device=device).sort().values
    x = torch.randn(100, input_size, device=device)
    r = model(t, x)
    r.norm().backward()
    optim.step()
    reset_all_caches()
    model.zero_grad(set_to_none=True)

## JIT Compilation (TorchScript)

In [None]:
# Compile the model with TorchScript; all following cells use the scripted module.
model = jit.script(model)

n_named_modules = sum(1 for _ in model.named_modules())
print(f"Number of named submodules: {n_named_modules}")
summary(model)

In [None]:
def iter_modules(module):
    """Yield ``module`` and its descendants in depth-first pre-order.

    Helper function needed because named_modules returns wrong results.
    Any child registered under the name ``"inverse"`` is skipped together with
    its whole subtree (NOTE(review): presumably because invertible layers keep a
    back-reference to their inverse — confirm).
    """
    yield module
    forward_children = (
        child
        for child_name, child in module.named_children()
        if child_name != "inverse"
    )
    for child in forward_children:
        yield from iter_modules(child)


# Collect the LinearContraction layers again, this time from the scripted
# encoder and decoder (encoder first). The list gathered before scripting holds
# the eager modules, presumably distinct from the scripted submodules.
# Scripted modules report their Python class via `original_name`. `iter_modules`
# is used instead of `named_modules()` and skips children named "inverse".
contractions = [
    module
    for part in (model.encoder, model.decoder)
    for module in iter_modules(part)
    if module.original_name == "LinearContraction"
]

# `reset_all_caches` was defined before scripting and looks up the global
# `contractions` at call time, so rebinding the list above is sufficient;
# it is deliberately not redefined here.
print(contractions)
reset_all_caches()

## Warmup (scripted model)

In [None]:
model = model.to(device)
optim = AdamW(model.parameters(), lr=0.0001)
model.zero_grad(set_to_none=True)
reset_all_caches()

# Short warmup of the scripted model on random data (forward, backward and an
# optimizer step), resetting the contraction caches after each update.
N_SCRIPTED_WARMUP_STEPS = 2
for _ in range(N_SCRIPTED_WARMUP_STEPS):
    t = torch.rand(100, device=device).sort().values
    x = torch.randn(100, input_size, device=device)
    r = model(t, x)
    r.norm().backward()
    optim.step()
    reset_all_caches()
    model.zero_grad(set_to_none=True)

## Train Model

In [None]:
from torch.utils.data import DataLoader

from tsdm.random.samplers import SlidingWindowSampler

horizon = 1 / 16  # sliding-window horizon, in rescaled time units
stride = 1 / 128  # sliding-window stride, in rescaled time units

# Both samplers share horizon and stride; only the training windows are shuffled.
window_kwargs = {"horizons": (horizon,), "stride": stride}
train_sampler = SlidingWindowSampler(T_train, shuffle=True, **window_kwargs)
test_sampler = SlidingWindowSampler(T_test, **window_kwargs)

## Initialize Loss

In [None]:
from tsdm.metrics import TimeSeriesMSE

# Time-series MSE used for training and for `test_score`. It is called below as
# loss(y, yhat): full target window first, model prediction second.
# NOTE(review): presumably NaN entries in y (masked/padded steps) are ignored by
# TimeSeriesMSE — confirm in tsdm.
loss = TimeSeriesMSE()

In [None]:
from torch.nn.utils.rnn import pad_sequence

n_forecast = 30


def make_forecast_samples(times, values, sampler):
    """Cut ``(t, x, y)`` forecasting samples along the windows of ``sampler``.

    For every index/mask ``window`` yielded by ``sampler``:

    * ``t`` is ``times[window]``;
    * ``y`` is a copy of ``values[window]`` (the full target window);
    * ``x`` is a copy of ``y`` whose last ``n_forecast`` rows are NaN, so those
      steps must be forecast.

    Both ``y`` and ``x`` are explicit copies. If ``window`` is a slice,
    ``values[window]`` is a view, and writing NaNs into it would silently corrupt
    ``values`` and every overlapping window built afterwards.
    The loop variable is deliberately not called ``horizon``, which would
    clobber the window-length setting defined for the samplers.
    """
    samples = []
    for window in sampler:
        t = times[window]
        y = values[window].clone()
        x = y.clone()
        x[-n_forecast:] = float("nan")
        samples.append((t, x, y))
    return samples


train_samples = make_forecast_samples(T_train, X_train, train_sampler)
test_samples = make_forecast_samples(T_test, X_test, test_sampler)


def collate_fn(
    samples: list[tuple[Tensor, Tensor, Tensor]]
) -> tuple[Tensor, Tensor, Tensor]:
    """Batch variable-length ``(t, x, y)`` samples by right-padding.

    Args:
        samples: list of ``(t, x, y)`` triples. Each tensor's first dimension is
            that sample's number of time steps, which may differ between samples.

    Returns:
        ``(t, x, y)`` stacked along a new leading batch dimension. Times are padded
        with 0.0; ``x`` and ``y`` are padded with NaN.

    The NaN padding value is a plain Python float. ``pad_sequence`` only needs a
    scalar, and building a tensor on the (possibly CUDA) training device for
    every batch would force a device sync and fail inside DataLoader worker
    processes.
    NOTE(review): padded time steps are 0.0, so padded batches are not monotone
    in time; confirm the model tolerates this at NaN-padded positions.
    """
    nan = float("nan")
    t_list, x_list, y_list = zip(*samples)
    return (
        pad_sequence(t_list, batch_first=True),
        pad_sequence(x_list, batch_first=True, padding_value=nan),
        pad_sequence(y_list, batch_first=True, padding_value=nan),
    )

In [None]:
BATCH_SIZE = 128

train_loader = DataLoader(
    train_samples, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(test_samples, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# DataLoader's default batch size is 1: with shuffling, each fresh iterator
# yields a single random test window.
infer_loader = DataLoader(test_samples, shuffle=True, collate_fn=collate_fn)


@torch.no_grad()
def test_score(model, dloader, reduction: str = "sum"):
    """Evaluate the global ``loss`` on every batch of ``dloader`` without autograd.

    Args:
        model: called as ``model(t, x)``; its output is compared against ``y``.
        dloader: iterable of ``(t, x, y)`` batches; each tensor is moved to the
            global ``device``.
        reduction: ``"sum"`` (default, the original behaviour) returns the sum of
            the per-batch losses. ``"mean"`` divides by the number of batches, so
            the score does not scale with how many batches the loader yields.

    Returns:
        The reduced score as a Python float (0.0 for an empty loader).

    Raises:
        ValueError: if ``reduction`` is neither ``"sum"`` nor ``"mean"``.
        AssertionError: if any per-batch loss is not finite.
    """
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")

    total = torch.tensor(0.0, device=device)
    n_batches = 0
    for t, x, y in dloader:
        t = t.to(device)
        x = x.to(device)
        y = y.to(device)
        yhat = model(t, x)
        r = loss(y, yhat)
        assert r.isfinite()
        total += r
        n_batches += 1

    if reduction == "mean" and n_batches > 0:
        total = total / n_batches
    return total.item()

In [None]:
# Test score of the model before training, kept for comparison with the scores
# reported during training.
baseline_score = test_score(model, test_loader)
baseline_score

In [None]:
@torch.no_grad()
def grad_norm(model):
    """Average size-normalised gradient norm over ``model``'s parameters.

    For every parameter with a gradient, ``||p.grad|| / p.numel()`` is computed.
    These terms are summed and divided by the total number of parameters;
    parameters whose ``grad`` is None contribute 0 but still count in the
    denominator (same definition as before).

    Returns:
        A Python float; 0.0 if the model has no parameters or no gradients
        (previously a parameterless model raised ZeroDivisionError).
    """
    params = list(model.parameters())
    terms = [p.grad.norm() / p.numel() for p in params if p.grad is not None]
    if not terms:
        return 0.0
    # Accumulate on the gradients' own device rather than a global `device`;
    # a single host sync (.item()) happens at the end.
    return (torch.stack(terms).sum() / len(params)).item()

In [None]:
# Fresh optimizer for the actual training run, using AdamW's default
# hyperparameters (lr=1e-3) instead of the warmup's lr=1e-4.
optim = AdamW(params=model.parameters())
reset_all_caches()

In [None]:
N_EPOCHS = 10_000
# Guard: with no batches, loss_post/loss_pre/grad would be undefined below.
assert len(train_loader) > 0, "train_loader yields no batches"

for _ in (outer := trange(N_EPOCHS)):
    for t, x, y in train_loader:
        # Plain device transfer. The batches carry no autograd history, so no
        # torch.no_grad() context is needed here.
        t, x, y = t.to(device), x.to(device), y.to(device)
        yhat = model(t, x)
        # Supervise both the returned prediction and the one the model exposes
        # as `xhat_pre` after the forward pass (presumably the estimate before
        # the filter update — TODO confirm in LatentStateSpaceModel).
        loss_post = loss(y, yhat)
        loss_pre = loss(y, model.xhat_pre)
        total = loss_pre + loss_post
        assert total.isfinite()
        total.backward()
        # Gradient diagnostic; the last batch's value is shown in the progress bar.
        grad = grad_norm(model)
        optim.step()
        model.zero_grad(set_to_none=True)
        # Reset the LinearContraction caches after every weight update.
        reset_all_caches()
    score = test_score(model, test_loader)
    outer.set_postfix(
        loss_post=f"{loss_post.item():.4f}",
        loss_pre=f"{loss_pre.item():.4f}",
        score=f"{score:.4f}",
        grad=f"{grad:.4f}",
    )

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 4), constrained_layout=True)

for ax in axes:
    # infer_loader has batch size 1 and shuffles, so a fresh iterator yields a
    # random test window each time.
    t, x, y = next(iter(infer_loader))
    t, x, y = t[0], x[0], y[0]

    # The forecast region starts at the first masked (NaN) input step. If
    # nothing is masked, shade the whole window.
    is_masked = torch.isnan(x[:, 0])
    i_forecast = is_masked.nonzero()[0].item() if is_masked.any() else len(t) - 1

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        yhat = model(t.to(device), x.to(device)).cpu()

    # Shaded span: the observed part of the window before the forecast starts.
    ax.axvspan(float(t.min()), float(t[i_forecast]), alpha=0.2)
    ax.plot(t, y, ".", t, yhat)
    ax.set_xlabel("Time (rescaled)")
    ax.set_ylabel("Population (standardized)")
    ax.set_title(f"{len(t)} steps, {int(is_masked.sum())} forecast")

In [None]:
# Inspect the trained `weight` parameter of the LinODE system cell
# (presumably its system matrix — confirm in LinODECell).
system_weight = model.system.weight
system_weight