In [1]:
# Display the value of the last expression *or* last single-name assignment of each cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload edited (non-builtin) modules before each cell executes.
%load_ext autoreload
%autoreload 2
# NOTE(review): none of the visible cells produce plots; kept for interactive use.
%matplotlib inline

In [2]:
# Name of this run; it becomes part of the checkpoint and TensorBoard log paths.
# Interactive alternative: RUN_NAME = input("enter name for run")
RUN_NAME: str = "skew-init"

In [3]:
import os
from datetime import datetime

# Turn on TorchScript JIT compilation. The variable must be set before
# `torch` is imported (next cell), otherwise it has no effect.
os.environ.update({"PYTORCH_JIT": "1"})

In [4]:
from pathlib import Path
from time import time
from typing import Any

import numpy as np
import pandas
import torch
import torchinfo
from linodenet.models import LinODE, LinODECell, LinODEnet
from linodenet.projections.functional import skew_symmetric, symmetric
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
from torch import Tensor, jit, tensor
from torch.optim import SGD, Adam, AdamW
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm, trange

import tsdm
from tsdm.datasets import DATASETS
from tsdm.encoders.functional import time2float
from tsdm.logutils import (
    log_kernel_information,
    log_metrics,
    log_model_state,
    log_optimizer_state,
)
from tsdm.losses import LOSSES
from tsdm.tasks import KIWI_RUNS_TASK
from tsdm.util import grad_norm, multi_norm

# Initialize Task

In [5]:
# Compute device: CUDA when available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Floating-point precision the model is cast to.
DTYPE = torch.float32
# 0-dim NaN tensor used to blank out the forecast window of the model inputs.
NAN = torch.full((), float("nan"), dtype=DTYPE, device=DEVICE)

BATCH_SIZE = 128
OBS_HORIZON = 90  # length of the observed part of each sample window
PRD_HORIZON = 30  # length of the part to be forecast
HORIZON = SEQLEN = OBS_HORIZON + PRD_HORIZON

In [6]:
# Forecasting task: OBS_HORIZON observed steps followed by PRD_HORIZON
# steps to predict; evaluation uses much larger batches than training.
TASK = KIWI_RUNS_TASK(
    observation_horizon=OBS_HORIZON,
    forecasting_horizon=PRD_HORIZON,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=2048,
)

DATASET, ts, md = TASK.dataset, TASK.timeseries, TASK.metadata
# `ts` is 2-D; its second axis size is the model's input size (NUM_DIM).
# Presumably rows are time points and columns channels — TODO confirm.
NUM_PTS, NUM_DIM = ts.shape

## Initialize Loss

In [7]:
# Training objective: the task's test metric, moved to the compute device.
LOSS = TASK.test_metric.to(device=DEVICE)

# Display the task's loss weights.
TASK.loss_weights

## Initialize DataLoaders

In [8]:
# Training batches; the training loop iterates TRAINLOADER[0] once per epoch.
TRAINLOADER = TASK.batchloader
# Evaluation loaders keyed by tuples such as (0, "train") / (0, "test")
# (presumably (fold index, partition name) — TODO confirm).
EVALLOADERS = TASK.dataloaders

## Hyperparameters

In [9]:
def join_dicts(d: dict[str, Any]) -> dict[str, Any]:
    """Flatten a nested dictionary into a single level.

    Keys of nested dictionaries are joined to their parent key with ``'/'``,
    e.g. ``{"a": {"b": 1}}`` becomes ``{"a/b": 1}``. Insertion order is kept,
    later duplicates overwrite earlier ones, and nested empty dictionaries
    contribute no entries.
    """

    def _flatten(items, prefix=None):
        # Top-level keys are kept as-is; nested keys get "<parent>/" prepended.
        for name, value in items:
            path = name if prefix is None else f"{prefix}/{name}"
            if isinstance(value, dict):
                yield from _flatten(value.items(), f"{path}")
            else:
                yield path, value

    return dict(_flatten(d.items()))


def add_prefix(d: dict[str, Any], /, prefix: str) -> dict[str, Any]:
    """Return a new dict whose keys are ``"{prefix}/{key}"``; values unchanged."""
    prefixed = {}
    for name, value in d.items():
        prefixed[f"{prefix}/{name}"] = value
    return prefixed


# Alternative optimizer configuration (plain SGD):
# OPTIMIZER_CONIFG = {
#     "__name__": "SGD",
#     "lr": 0.001,
#     "momentum": 0,
#     "dampening": 0,
#     "weight_decay": 0,
#     "nesterov": False,
# }

# Optimizer settings; "__name__" presumably selects the optimizer class
# via `initialize_from(OPTIMIZERS, ...)` below — TODO confirm.
OPTIMIZER_CONIFG = dict(
    __name__="Adam",
    lr=0.01,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

# Keyword arguments for the model constructor (`MODEL(**MODEL_CONFIG)`).
MODEL_CONFIG = dict(
    __name__="LinODEnet",
    input_size=NUM_DIM,
    hidden_size=128,
    embedding_type="concat",
    Encoder_cfg=dict(nblocks=10),
    Decoder_cfg=dict(nblocks=10),
    System_cfg=dict(
        # NOTE(review): hyphenated here but underscored in kernel_parametrization;
        # confirm LinODEnet accepts both spellings.
        kernel_initialization="skew-symmetric",
        kernel_parametrization="skew_symmetric",
        scale=0.1,
    ),
)

# Flat "Section/key" view of all hyperparameters, e.g. "Optimizer/lr".
HPARAMS = join_dicts({"Optimizer": OPTIMIZER_CONIFG, "Model": MODEL_CONFIG})

## Initialize Model

In [10]:
# Build the model from its config; `.to` moves it (in place) to the compute
# device and precision and returns the same module.
MODEL = LinODEnet
model = MODEL(**MODEL_CONFIG).to(device=DEVICE, dtype=DTYPE)
# Layer / parameter summary.
torchinfo.summary(model)

## Initialize Optimizer

In [11]:
from tsdm.optimizers import OPTIMIZERS
from tsdm.util import initialize_from

In [12]:
# Supply the model parameters at call time instead of writing them into the
# global OPTIMIZER_CONIFG: the previous in-place `|=` left a one-shot
# `model.parameters()` generator inside the hyperparameter dict (presumably
# exhausted once the optimizer consumes it), so any later reuse of the config
# would silently see no parameters.
optimizer = initialize_from(
    OPTIMIZERS, **(OPTIMIZER_CONIFG | {"params": model.parameters()})
)

## Utility functions

In [13]:
# Scratch check on a single training batch, using the same slicing as
# `prep_batch` in a later cell.
batch = next(iter(TRAINLOADER[0]))
T, X = batch
# Targets: target channels for the time steps from OBS_HORIZON onward.
targets = X[..., OBS_HORIZON:, TASK.targets.index].clone()
# assert targets.shape == (BATCH_SIZE, PRD_HORIZON, len(TASK.targets))

# Inputs: copy of X with target and observable channels set to NaN from
# OBS_HORIZON onward, so the forecast window is hidden from the model.
inputs = X.clone()
inputs[:, OBS_HORIZON:, TASK.targets.index] = NAN
inputs[:, OBS_HORIZON:, TASK.observables.index] = NAN
# assert inputs.shape == (BATCH_SIZE, HORIZON, NUM_DIM)

In [14]:
# Recompute `targets` exactly as in the previous cell and display its shape.
targets = X[..., OBS_HORIZON:, TASK.targets.index].clone()
targets.shape

In [15]:
def prep_batch(batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor, Tensor]:
    """Split a ``(T, X)`` batch into model inputs and forecasting targets.

    Args:
        batch: pair ``(T, X)`` of time stamps and observations as yielded by
            the task's dataloaders. ``X`` is indexed as ``X[:, step, channel]``
            below (presumably ``(batch, time, channel)`` — TODO confirm).

    Returns:
        ``(T, inputs, targets)``, all on ``DEVICE``:

        - ``targets``: copy of the target channels (``TASK.targets.index``)
          for the time steps from ``OBS_HORIZON`` onward.
        - ``inputs``: copy of ``X`` in which target and observable channels
          (``TASK.observables.index``) are set to NaN from ``OBS_HORIZON``
          onward, hiding the forecast window; other channels are untouched.
    """
    T, X = batch
    # The model and the NAN fill value live on DEVICE; move the batch there too
    # (a no-op when the loader already yields tensors on DEVICE).
    T = T.to(device=DEVICE)
    X = X.to(device=DEVICE)

    targets = X[..., OBS_HORIZON:, TASK.targets.index].clone()
    # assert targets.shape == (BATCH_SIZE, PRD_HORIZON, len(TASK.targets))

    # Clone so the loader's tensor is never modified in place.
    inputs = X.clone()
    inputs[:, OBS_HORIZON:, TASK.targets.index] = NAN
    inputs[:, OBS_HORIZON:, TASK.observables.index] = NAN
    # assert inputs.shape == (BATCH_SIZE, HORIZON, NUM_DIM)
    return T, inputs, targets


def get_all_preds(model, dataloader, max_batches=5) -> tuple[Tensor, Tensor]:
    """Collect forecasts and matching targets from (the start of) a dataloader.

    Each batch is prepared with ``prep_batch`` and fed through
    ``model(times, inputs)`` without gradient tracking. The forecast slice
    (time steps from ``OBS_HORIZON`` onward, channels ``TASK.targets.index``)
    is kept together with the targets.

    Args:
        model: forecasting model, called as ``model(times, inputs)``.
        dataloader: iterable of ``(T, X)`` batches accepted by ``prep_batch``.
        max_batches: stop after this many batches (must be positive), or
            ``None`` to consume the whole dataloader. It replaces the former
            ``pbar.n == 5`` check. tqdm refreshes ``pbar.n`` only when the bar
            redraws, so that check could skip 5 and never stop.

    Returns:
        ``(targets, predics)`` concatenated along the batch dimension. Every
        position where the target is NaN is set to 0 in both tensors, so NaNs
        do not propagate into downstream metrics.

    Raises:
        RuntimeError: from ``torch.cat`` if the dataloader yields no batches.
    """
    collected_targets: list[Tensor] = []
    collected_predics: list[Tensor] = []

    # Using tqdm as a context manager closes the bar even when we stop early.
    with torch.no_grad(), tqdm(dataloader, leave=False) as pbar:
        for num_seen, batch in enumerate(pbar, start=1):
            times, inputs, targets = prep_batch(batch)
            outputs = model(times, inputs)
            collected_targets.append(targets)
            collected_predics.append(outputs[:, OBS_HORIZON:, TASK.targets.index])
            if max_batches is not None and num_seen >= max_batches:
                break

    targets = torch.cat(collected_targets, dim=0)
    predics = torch.cat(collected_predics, dim=0)

    mask = torch.isnan(targets)
    targets[mask] = 0.0
    predics[mask] = 0.0
    # scale = 1/torch.mean(mask.to(dtype=torch.float32))
    # targets *= scale
    # predics *= scale
    return targets, predics

## Logging Utilities

In [16]:
from tsdm.logutils import compute_metrics


def log_all(i, model, writer, optimizer):
    """Log kernel and optimizer diagnostics, with histograms, at step ``i``.

    The kernel is logged from a detached CPU copy of ``model.system.kernel``,
    so logging never touches the live parameter or its autograd graph.
    """
    kernel_copy = model.system.kernel.clone().detach().cpu()
    for log_fn, payload in (
        (log_kernel_information, kernel_copy),
        (log_optimizer_state, optimizer),
    ):
        log_fn(i, writer, payload, histograms=True)


def log_hparams(i, writer, *, metric_dict, hparam_dict):
    """Write one hyperparameter/metric record via ``writer.add_hparams``.

    Args:
        i: epoch number; recorded as an extra ``"epoch"`` hyperparameter.
        writer: TensorBoard writer (anything providing ``add_hparams``).
        metric_dict: metric name -> value; keys are logged as ``"hparam/<name>"``.
        hparam_dict: hyperparameter name -> value. It is NOT modified.
    """
    # Build a new dict: the previous `hparam_dict |= ...` mutated the caller's
    # dict (e.g. the global HPARAMS) in place, leaking "epoch" into it.
    hparam_dict = hparam_dict | {"epoch": i}
    metric_dict = add_prefix(metric_dict, "hparam")
    writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)


# Metrics reported during evaluation; the task's own test metric (LOSS) is
# included as "WRMSE". (A previous assignment of the uninstantiated metric
# classes was immediately overwritten here and has been removed.)
metrics = {key: LOSSES[key]() for key in ("ND", "NRMSE", "MSE", "MAE")} | {
    "WRMSE": LOSS
}
# Optional sanity check that the test metric is one of the standard losses:
# assert any(isinstance(TASK.test_metric, LOSSES[key]) for key in ("ND", "NRMSE", "MSE", "MAE"))

print("WARMUP")
# One forward/backward pass on random data before the timed training loop
# (presumably to trigger JIT compilation early — PYTORCH_JIT is enabled above).
# `x` holds a single observation, so `t` needs one time stamp; it previously
# had NUM_DIM entries, which did not match the number of observations.
t = torch.randn(1, device=DEVICE)
x = torch.randn(1, NUM_DIM, device=DEVICE)
y = model(t, x)
torch.linalg.norm(y).backward()
model.zero_grad()

In [None]:
RUN_START = tsdm.util.now()
# TensorBoard logs and checkpoints share the layout <model>/<dataset>/<run>/<start>.
LOGGING_DIR = f"runs/{MODEL.__name__}/{DATASET.__name__}/{RUN_NAME}/{RUN_START}"
CHECKPOINTDIR = Path(
    "checkpoints", MODEL.__name__, DATASET.__name__, RUN_NAME, f"{RUN_START}"
)
CHECKPOINTDIR.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(LOGGING_DIR)

### Training Start

In [None]:
# Global optimisation-step counter (TensorBoard step for per-batch logs).
# It must be initialised here: with this assignment commented out, a fresh
# kernel raised NameError at the first `i += 1`.
i = 0
# Epoch counter (TensorBoard step for per-epoch logs and checkpoint tag);
# epoch 1 is the evaluation of the untrained model below.
epoch = 1

with torch.no_grad():
    # log optimizer state first !!!
    # log_optimizer_state(epoch, writer, optimizer, histograms=True)
    log_kernel_information(epoch, writer, model.system.kernel, histograms=True)

    # Baseline metrics before any training step.
    for key in ((0, "train"), (0, "test")):
        dataloader = EVALLOADERS[key]
        y, ŷ = get_all_preds(model, dataloader)
        assert torch.isfinite(y).all()
        log_metrics(
            epoch, writer, metrics=metrics, targets=y, predics=ŷ, prefix=key[1]
        )


for _ in (epochs := trange(100)):
    epoch += 1
    for batch in (batches := tqdm(TRAINLOADER[0])):
        i += 1
        # Optimization step
        model.zero_grad()
        times, inputs, targets = prep_batch(batch)

        forward_time = time()
        outputs = model(times, inputs)
        forward_time = time() - forward_time

        predics = outputs[:, OBS_HORIZON:, TASK.targets.index]

        # Get rid of NaN values in the targets: zero those positions in both
        # tensors so the NaNs do not propagate into the loss.
        mask = torch.isnan(targets)
        targets[mask] = 0.0
        predics[mask] = 0.0

        # # compensate NaN-Value with upscaling
        # scale = 1/torch.mean(mask.to(dtype=torch.float32))
        # targets *= scale
        # predics *= scale

        loss = LOSS(targets, predics)

        # Fail fast *before* backward/step, so the weights are not updated with
        # NaN gradients and the model can still be inspected afterwards.
        if torch.isnan(loss).any():
            raise RuntimeError("NaN-value encountered!!")

        backward_time = time()
        loss.backward()
        backward_time = time() - backward_time

        optimizer.step()

        # batch logging
        logging_time = time()
        with torch.no_grad():
            log_metrics(
                i,
                writer,
                metrics=metrics,
                targets=targets,
                predics=predics,
                prefix="batch",
            )
            log_optimizer_state(i, writer, optimizer, prefix="batch")

            lval = loss.item()
            gval = grad_norm(list(model.parameters())).item()
        logging_time = time() - logging_time

        batches.set_postfix(
            loss=f"{lval:.2e}",
            gnorm=f"{gval:.2e}",
            Δt_forward=f"{forward_time:.1f}",
            Δt_backward=f"{backward_time:.1f}",
            Δt_logging=f"{logging_time:.1f}",
        )

    with torch.no_grad():
        # log optimizer state first !!!
        log_optimizer_state(epoch, writer, optimizer, histograms=True)
        log_kernel_information(epoch, writer, model.system.kernel, histograms=True)

        for key in ((0, "train"), (0, "test")):
            dataloader = EVALLOADERS[key]
            y, ŷ = get_all_preds(model, dataloader)
            metric_values = compute_metrics(metrics, targets=y, predics=ŷ)
            log_metrics(
                epoch, writer, metrics=metrics, values=metric_values, prefix=key[1]
            )
            # log_hparams(epoch, writer, metric_dict=metric_values, hparam_dict=HPARAMS)

        # Model checkpoint, tagged with the same `epoch` used as the TensorBoard
        # step. `epochs.n` was used before, but tqdm only refreshes that counter
        # when the bar redraws, so checkpoint names could repeat and overwrite.
        torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epoch}"))
        # NOTE(review): this pickles the whole optimizer object; saving
        # `optimizer.state_dict()` is the portable convention. Kept as-is so
        # existing checkpoint loaders keep working.
        torch.save(
            {
                "optimizer": optimizer,
                "epoch": epoch,
                "batch": i,
            },
            CHECKPOINTDIR.joinpath(f"{optimizer.__class__.__name__}-{epoch}"),
        )

In [None]:
# Name -> tensor mapping of every buffer registered on the model (displayed).
buffers = {name: buf for name, buf in model.named_buffers()}

In [None]:
# Short aliases for the model's pre/post buffers (the meaning of "pre"/"post"
# is defined by the model — presumably before/after an update step, TODO confirm).
zhat_pre, zhat_post = buffers["zhat_pre"], buffers["zhat_post"]
xhat_pre = buffers["xhat_pre"]
xhat_post = buffers["xhat_post"]

In [None]:
# Shapes of the four aliased buffers. This previously referenced the
# undefined name `xhat` and raised NameError.
zhat_post.shape, zhat_pre.shape, xhat_post.shape, xhat_pre.shape

In [None]:
# Shape of the `zhat_pre` buffer.
zhat_pre.shape

In [None]:
# Largest absolute difference between xhat_post and xhat_pre: maxima over
# dims 0 and 2, then the overall maximum of what remains.
torch.amax(torch.abs(xhat_post - xhat_pre), dim=(0, 2)).max()

In [None]:
# Largest absolute difference between zhat_pre and zhat_post: maxima over
# dims 0 and 2, then the overall maximum of what remains.
torch.amax(torch.abs(zhat_pre - zhat_post), dim=(0, 2)).max()

In [None]:
# Maximum of `xhat_pre` at each index along dim 1 (presumably time), one value
# per index. Iterate over the buffer's actual length instead of the hard-coded
# 120 (= HORIZON), which raised IndexError or silently truncated on mismatch.
[
    buffers["xhat_pre"][:, step, :].max()
    for step in range(buffers["xhat_pre"].shape[1])
]

In [None]:
# Maximum of `xhat_post` at the last index along dim 1.
torch.max(buffers["xhat_post"][:, -1, :])