In [1]:
# Display the value of a cell's final expression or simple assignment without print().
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# autoreload mode 2: re-import modules before running a cell so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from datetime import datetime

# Switch JIT compilation on. The flag is read from the environment, so it has
# to be in place before torch gets imported.
os.environ.update(PYTORCH_JIT="1")

In [3]:
import tsdm
import torch
import pandas
import numpy as np

from tqdm.auto import trange, tqdm
from torch import tensor, Tensor, jit
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter

from tsdm.losses.functional import nrmse, nd
from tsdm.util import grad_norm, multi_norm
from tsdm.datasets import Electricity
from tsdm.encoders import time2float

from linodenet.models import LinODEnet, LinODECell, LinODE
from linodenet.projections import symmetric, skew_symmetric

In [4]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# Scalar NaN created with the same dtype/device the model is moved to.
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
BATCH_SIZE = 16
# Each window consists of OBS_HORIZON observed steps followed by PRD_HORIZON
# steps that have to be forecast.
PRD_HORIZON = 24
OBS_HORIZON = 96
SEQLEN = PRD_HORIZON + OBS_HORIZON

In [5]:
from tsdm.tasks import ETDatasetInformer

# Build the forecasting task from the shared horizon constants so that the
# task's windows and the OBS_HORIZON-based slicing used elsewhere in the
# notebook cannot drift apart.
TASK = ETDatasetInformer(
    dataset="ETTh1",
    forecasting_horizon=PRD_HORIZON,
    observation_horizon=OBS_HORIZON,
    test_metric="MSE",
    time_encoder="time2float",
)
DATASET = TASK.dataset

# The underlying table is two-dimensional; presumably (time points, channels)
# -- TODO confirm the axis meaning against the dataset definition.
NUM_PTS, NUM_DIM = DATASET.dataset.shape

### Utility functions

In [6]:
@jit.script
def prep_batch(
    batch: tuple[Tensor, Tensor, Tensor], observation_horizon: int
) -> tuple[Tensor, Tensor, Tensor]:
    """Split a raw batch into time stamps, masked model inputs and targets.

    Parameters
    ----------
    batch:
        Triple ``(T, X, Y)``. ``Y`` is appended to ``X`` as one extra trailing
        channel, so ``Y`` must have one dimension less than ``X``.
    observation_horizon:
        Number of leading positions along the last axis of ``Y`` that the
        model is allowed to observe.

    Returns
    -------
    ``(T, inputs, targets)`` where ``targets`` is a copy of ``Y`` past the
    observation horizon and ``inputs`` is ``X`` with the masked ``Y`` appended
    as last channel (NaN past the observation horizon).

    The tensors inside ``batch`` are left unmodified.
    """
    T, X, Y = batch
    targets = Y[..., observation_horizon:].clone()
    # Mask a copy: writing NaN into Y itself would silently destroy the
    # caller's ground truth if the batch is reused.
    masked = Y.clone()
    masked[..., observation_horizon:] = float("nan")
    inputs = torch.cat([X, masked.unsqueeze(-1)], dim=-1)
    return T, inputs, targets


def evaluate_model(model, task):
    """Compute the loss of ``model`` over the whole test split of ``task``.

    Every batch of ``task.get_dataloader("test", ...)`` is run through the
    model without gradient tracking; forecasts and targets are collected and
    the module-level ``LOSS`` is evaluated once over all of them.

    Parameters
    ----------
    model:
        Callable invoked as ``model(times, inputs)``; its output is indexed as
        ``outputs[:, OBS_HORIZON:, -1]``.
    task:
        Object providing ``get_dataloader("test", batch_size=...)`` whose
        batches are accepted by ``prep_batch``.

    Returns
    -------
    ``LOSS(predictions, targets)`` over all test samples.
    """
    all_targets = []
    all_predics = []
    with torch.no_grad():
        for batch in tqdm(task.get_dataloader("test", batch_size=2760)):
            times, inputs, targets = prep_batch(batch, OBS_HORIZON)
            outputs = model(times, inputs)
            # Forecast window of the last channel, i.e. the channel that
            # prep_batch appended and masked with NaN.
            all_predics.append(outputs[:, OBS_HORIZON:, -1])
            all_targets.append(targets)

    # Stack along the batch axis. Concatenating along the last (horizon) axis
    # fails as soon as the final batch is smaller than the others.
    YS = torch.cat(all_targets, dim=0)
    YHATS = torch.cat(all_predics, dim=0)
    # Same argument order as in the training loop: (prediction, target).
    return LOSS(YHATS, YS)

### Logging utilities

In [7]:
def log_all(i, model, writer, optimizer):
    """Log kernel and optimizer diagnostics for step ``i`` to ``writer``.

    The kernel is read from ``model.system.kernel`` and handed to
    ``log_kernel_information`` as a detached CPU copy; the optimizer is passed
    to ``log_optimizer_state``. Histograms are enabled for both.
    """
    kernel_snapshot = model.system.kernel.clone().detach().cpu()
    log_kernel_information(i, writer, kernel_snapshot, histograms=True)
    log_optimizer_state(i, writer, optimizer, histograms=True)

### Plotting Kernel Utility Function

In [15]:
from torch.optim import Adam
from linodenet.models import LinODEnet
from tsdm.util.logging import log_optimizer_state, log_kernel_information

# Model class under test; its name also becomes part of the log directory.
MODEL = LinODEnet
model = MODEL(
    input_size=NUM_DIM,
    hidden_size=32,
    embedding_type="concat",
)
model.to(dtype=DTYPE, device=DEVICE)
# TensorBoard run directory built from model class name, dataset name and
# the current time stamp.
writer = SummaryWriter(
    log_dir=f"runs/{MODEL.__name__}/{DATASET.__name__}{tsdm.util.now()}"
)
# Loss used both for training and for evaluation on the test split.
LOSS = TASK.test_metric()

In [16]:
# Adam over all model parameters with a fixed learning rate.
optimizer = Adam(params=model.parameters(), lr=0.01)
# NOTE(review): BATCH_SIZE (= 16) is defined in the constants cell, but the
# loader is built with a literal 32 -- confirm which value is intended.
TRAINLOADER = TASK.get_dataloader(
    "joint",
    batch_size=32,
)

In [17]:
# Train for 10 epochs; after every epoch the model is scored on the test split.
for epoch in (epochs := trange(10)):

    #     log_all(epochs.n, model, writer, optimizer)

    for batch in (batches := tqdm(TRAINLOADER)):
        # Forward and backward pass on the forecast window of the last channel.
        model.zero_grad()
        times, inputs, targets = prep_batch(batch, OBS_HORIZON)
        outputs = model(times, inputs)
        predics = outputs[:, OBS_HORIZON:, -1]
        loss = LOSS(predics, targets)
        loss.backward()

        # Show current loss and gradient norm in the progress bar.
        lval = loss.clone().detach().cpu().numpy()
        gradient_norm = grad_norm(list(model.parameters()))
        gval = gradient_norm.clone().detach().cpu().numpy()
        batches.set_postfix(loss=lval, gnorm=gval)

        # Abort before the optimizer can write NaNs into the parameters.
        if loss.isnan().any():
            raise RuntimeError("NaN-value encountered!!")
        optimizer.step()

    TEST_LOSS = evaluate_model(model, TASK)
    print(epoch, TEST_LOSS)

In [33]:
from pathlib import Path

# Make sure a "models" directory exists below the current working directory,
# then store the model there.
path = Path.cwd() / "models"
path.mkdir(exist_ok=True)
model.save("models/full_model.pt")

In [40]:
# Serialize the model with TorchScript into the "models" directory.
torch.jit.save(model, path / "full_model.pt")

In [35]:
# Load the file written by the save cells; they store "full_model.pt", so the
# extension has to be part of the file name here as well.
model2 = torch.jit.load("models/full_model.pt")

In [None]:
# NOTE(review): `sampler`, `collate_tensor`, `prep` and `plot_spectrum` are not
# defined in the visible part of this notebook -- confirm they exist before
# running this cell.
for batch in (pbar := tqdm(sampler)):
    batch = collate_tensor(batch)
    t, x, x_obs = prep(batch)

    # Clear the gradients of the previous iteration; backward() adds to
    # whatever is already stored, so without this they accumulate.
    model.zero_grad()
    x_hat = model(t, x_obs)
    loss = torch.mean(nd(x_hat, x))
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        pbar.set_postfix(loss=f"{loss:.2e}")
        # Detached CPU snapshot of the kernel, shared by all kernel logs below.
        kernel = model.system.kernel.detach().cpu()
        writer.add_scalar(
            "optim/grad_norm", grad_norm(list(model.parameters())), pbar.n
        )
        writer.add_scalar("train/loss:nd", loss, pbar.n)
        writer.add_scalar("train/loss:nrmse", torch.mean(nrmse(x_hat, x)), pbar.n)

        # plot kernel data
        writer.add_histogram("kernel/histogram", kernel, pbar.n)
        writer.add_image(
            "kernel/values", torch.tanh(kernel), pbar.n, dataformats="HW"
        )
        writer.add_figure("kernel/spectrum", plot_spectrum(kernel), pbar.n)