In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2

In [None]:
import os
from datetime import datetime

# Force TorchScript on: with PYTORCH_JIT=0 every `@jit.script` decorator becomes
# a no-op. The variable must be set before torch is imported (next cell).
os.environ.update(PYTORCH_JIT="1")

In [None]:
import tsdm
import torch
import pandas
import numpy as np

from tqdm import trange, tqdm
from torch import tensor, Tensor, jit
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter

from tsdm.metrics.functional import nrmse, nd
from tsdm.utils.dataloaders import SliceSampler
from tsdm.utils import grad_norm
from tsdm.datasets import Electricity
from tsdm.encoders import time2float

from linodenet.models import LinODEnet, LinODECell, LinODE
from linodenet.projections import symmetric, skew_symmetric

### Utility functions: run timestamps and kernel symmetry measures

In [None]:
def now():
    r"""Return the current local time as an ISO-8601 string with whole seconds."""
    stamp = datetime.now().replace(microsecond=0)
    # With microsecond == 0, isoformat()'s default 'auto' timespec prints seconds only.
    return stamp.isoformat()


def symmpart(kernel):
    r"""Ratio of the mean squared entries of ``symmetric(kernel)`` to those of ``kernel``.

    Presumably a value in ``[0, 1]`` when ``symmetric`` returns ``(K + K^T) / 2``
    (TODO confirm against linodenet.projections). A zero kernel yields NaN.
    """
    numerator = symmetric(kernel).pow(2).mean()
    denominator = kernel.pow(2).mean()
    return numerator / denominator


def skewpart(kernel):
    r"""Ratio of the mean squared entries of ``skew_symmetric(kernel)`` to those of ``kernel``.

    Parameters
    ----------
    kernel: Tensor
        Matrix to analyse. It is the argument, not any global ``kernel``, that
        is used.

    Returns
    -------
    Tensor
        0-dim tensor ``mean(skew_symmetric(kernel)**2) / mean(kernel**2)``;
        NaN for an all-zero kernel.
    """
    return torch.mean(skew_symmetric(kernel) ** 2) / torch.mean(kernel**2)

### Kernel spectrum plotting utility

In [None]:
import matplotlib.pyplot as plt


def plot_spectrum(kernel, lim: float = 2.5):
    r"""Scatter-plot the eigenvalues of a square matrix in the complex plane.

    Parameters
    ----------
    kernel: Tensor
        Square matrix. It is detached and moved to the CPU before the
        eigendecomposition.
    lim: float, default 2.5
        Both axes span ``[-lim, +lim]``. Eigenvalues outside this window are
        not visible.

    Returns
    -------
    Figure
        The matplotlib figure. It is not closed here; the caller owns it.
    """
    # Detach/move first: the eigendecomposition then runs on the CPU and does
    # not record an autograd graph.
    eigs = torch.linalg.eigvals(kernel.detach().cpu())
    fig, ax = plt.subplots(figsize=(12, 6), tight_layout=True)
    ax.set_xlim([-lim, +lim])
    ax.set_ylim([-lim, +lim])
    ax.set_aspect("equal")
    ax.set_xlabel("real part")
    ax.set_ylabel("imag part")
    ax.scatter(eigs.real, eigs.imag)
    return fig

### Setup: device, dtype, dataset shape and forecasting horizons

In [None]:
# Fall back to the CPU when no GPU is present, instead of failing on the first
# tensor allocated on "cuda".
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# Fill value written into hidden (forecast) positions of the model input;
# created on DEVICE to match the batches it is written into.
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
DATASET = Electricity.dataset
# NOTE: shape of the *raw* frame, before the hourly resampling below.
# NDIM (number of series) is unaffected by it; NPTS is not the training length.
NPTS, NDIM = DATASET.shape
SEQLEN = 48  # time steps per sampled window (observed + forecast)
PRD_HORIZON = 24  # trailing steps of each window that are forecast
OBS_HORIZON = SEQLEN - PRD_HORIZON  # leading steps visible to the model
BATCH_SIZE = 16

### Dataset preprocessing

In [None]:
# preprocessing cf. NBEATS-paper
ds = Electricity.dataset

# resample hourly
ds = ds.resample(pandas.Timedelta("1h"), label="right").sum()

# remove first year
ds = ds.loc[pandas.Timestamp("2012-01-01") :]

# Encode timestamps on ONE common scale *before* splitting. Normalizing each
# split by its own maximum would map the last train time and the last test time
# both to 1.0, so the two time axes would not be comparable.
t_all = np.asarray(time2float(ds.index))
t_all = t_all / t_all.max()

# train-test split: hold out the last PRD_HORIZON (hourly) steps
ds_train, ds_test = ds.iloc[:-PRD_HORIZON], ds.iloc[-PRD_HORIZON:]
t_train, t_test = t_all[:-PRD_HORIZON], t_all[-PRD_HORIZON:]

# rows are time steps; columns are [t, x_1, ..., x_NDIM]
train = torch.cat(
    [
        tensor(t_train, device=DEVICE, dtype=DTYPE).unsqueeze(-1),
        tensor(ds_train.values, device=DEVICE, dtype=DTYPE),
    ],
    dim=-1,
)

### Data loading: batch sampler, collation and input masking

In [None]:
# Windows of SEQLEN rows drawn from `train` (presumably consecutive rows —
# see tsdm's SliceSampler), grouped into batches of BATCH_SIZE index lists.
# drop_last=True discards a trailing incomplete batch.
sampler = BatchSampler(
    SliceSampler(train, slice_sampler=SEQLEN),
    batch_size=BATCH_SIZE,
    drop_last=True,
)


def collate_tensor(tensors: list[Tensor]) -> Tensor:
    r"""Combine a list of equally-shaped tensors into a batch.

    Parameters
    ----------
    tensors: list[Tensor]
        Tensors that all share the same shape ``S``.

    Returns
    -------
    Tensor
        Stacked tensor of shape ``(len(tensors), *S)``.
    """
    return torch.stack(tensors, dim=0)


@jit.script
def prep(
    batch: Tensor, OBS_HORIZON: int = OBS_HORIZON, NAN: Tensor = NAN
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Split a batch into time stamps, targets and masked model input.

    Parameters
    ----------
    batch: Tensor
        Rank-3 batch whose last axis is ``[t, x_1, ..., x_d]``; presumably
        ``(batch, seq, 1 + d)``.
    OBS_HORIZON: int
        Number of leading time steps that stay visible to the model.
    NAN: Tensor
        Fill value written into the hidden (forecast) time steps.

    Returns
    -------
    t: Tensor
        Time stamps, ``batch[:, :, 0]``.
    x: Tensor
        Full targets, ``batch[:, :, 1:]``.
    x_obs: Tensor
        Detached copy of ``x`` with every step from ``OBS_HORIZON`` on set to ``NAN``.
    """
    t, x = batch[:, :, 0], batch[:, :, 1:]
    # Copy so that masking does not overwrite the targets in `x`.
    x_obs = x.detach().clone()
    x_obs[:, OBS_HORIZON:, :] = NAN
    return t, x, x_obs

In [None]:
from torch.optim import Adam

# Model over NDIM input channels; 512 is presumably the hidden/latent size
# (TODO confirm against the LinODEnet signature).
model = LinODEnet(NDIM, 512, embedding_type="concat")
model.to(device=DEVICE, dtype=DTYPE)

optimizer = Adam(model.parameters(), lr=0.001)

# Smoke test: draw one batch to check that sampling and collation work.
# `batch`, `t` and `x` are overwritten by the training loop.
batch = collate_tensor(next(iter(sampler)))
t, x = batch[:, :, 0], batch[:, :, 1:]
# TensorBoard run directory named by the (second-resolution) start time.
writer = SummaryWriter(f"runs/LinODEnet/{now()}")

In [None]:
def _log_kernel_stats(writer: "SummaryWriter", kernel: Tensor, step: int) -> None:
    r"""Log histogram, image, spectrum plot and matrix statistics of ``kernel``.

    Parameters
    ----------
    writer: SummaryWriter
        TensorBoard writer to log to.
    kernel: Tensor
        Detached CPU copy of the system matrix.
    step: int
        Global step attached to every logged value.
    """
    writer.add_histogram("kernel/histogram", kernel, step)
    # add_image expects float pixels in [0, 1] and clips everything else; map
    # tanh's range [-1, 1] onto [0, 1] so negative entries are not all black.
    writer.add_image(
        "kernel/values", (1 + torch.tanh(kernel)) / 2, step, dataformats="HW"
    )
    writer.add_figure("kernel/spectrum", plot_spectrum(kernel), step)
    writer.add_scalar("kernel/skewpart", skewpart(kernel), step)
    writer.add_scalar("kernel/symmpart", symmpart(kernel), step)
    writer.add_scalar("kernel/det", torch.linalg.det(kernel), step)
    writer.add_scalar("kernel/rank", torch.linalg.matrix_rank(kernel), step)
    writer.add_scalar("kernel/trace", torch.trace(kernel), step)
    writer.add_scalar("kernel/cond", torch.linalg.cond(kernel), step)
    # slogdet returns (sign, log|det|); log the magnitude part.
    writer.add_scalar("kernel/logdet", torch.linalg.slogdet(kernel)[-1], step)

    # tag suffix -> `ord` argument of torch.linalg.matrix_norm
    matrix_norms = {
        "fro": "fro",
        "nuc": "nuc",
        "-∞": -np.inf,
        "-2": -2,
        "-1": -1,
        "+1": +1,
        "+2": +2,
        "+∞": +np.inf,
    }
    for label, order in matrix_norms.items():
        writer.add_scalar(
            f"kernel/norm:{label}", torch.linalg.matrix_norm(kernel, ord=order), step
        )


for batch in (pbar := tqdm(sampler)):
    batch = collate_tensor(batch)
    t, x, x_obs = prep(batch)

    # backward() accumulates into `.grad`; without resetting, every step would
    # apply the sum of all previous gradients.
    optimizer.zero_grad()
    x_hat = model(t, x_obs)
    loss = torch.mean(nd(x_hat, x))
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        step = pbar.n
        pbar.set_postfix(loss=f"{loss:.2e}")
        # Gradients of this step; they are reset only at the start of the next one.
        writer.add_scalar(
            "optim/grad_norm", grad_norm(list(model.parameters())), step
        )
        writer.add_scalar("train/loss:nd", loss, step)
        writer.add_scalar("train/loss:nrmse", torch.mean(nrmse(x_hat, x)), step)

        kernel = model.system.kernel.detach().cpu()
        _log_kernel_stats(writer, kernel, step)