In [1]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
from datetime import datetime

# Enable TorchScript JIT compilation (used by the @jit.script-decorated helper below).
# The variable is read when torch is imported, so it must be set before `import torch`.
# NOTE: "1" is already PyTorch's default; setting it explicitly only guards against an
# inherited PYTORCH_JIT=0 in the environment.
os.environ["PYTORCH_JIT"] = "1"

In [3]:
import tsdm
import torch
import pandas
import numpy as np
from pathlib import Path

from tqdm.auto import trange, tqdm
from torch import tensor, Tensor, jit
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter

from tsdm.losses import LOSSES
from tsdm.util import grad_norm, multi_norm
from tsdm.datasets import Electricity
from tsdm.encoders import time2float

from linodenet.models import LinODEnet, LinODECell, LinODE
from linodenet.projections import symmetric, skew_symmetric

In [4]:
# Device / precision used for the model and all tensors created in this notebook.
# Falls back to the CPU when no CUDA device is available instead of crashing here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
# NOTE(review): BATCH_SIZE is not used below — the training loader is built with batch_size=64.
BATCH_SIZE = 16
PRD_HORIZON = 24  # number of future time steps to forecast
OBS_HORIZON = 96  # number of leading time steps that are observed by the model
SEQLEN = PRD_HORIZON + OBS_HORIZON  # total window length per sample

In [5]:
from tsdm.tasks import ETDatasetInformer

# Forecasting task on ETTm1. The horizons come from the configuration constants so
# that they always agree with the slicing done in prep_batch / the training loop.
TASK = ETDatasetInformer(
    dataset="ETTm1",
    forecasting_horizon=PRD_HORIZON,
    observation_horizon=OBS_HORIZON,
    test_metric="MSE",
    time_encoder="time2float",
)
DATASET = TASK.dataset

# The underlying data is unpacked as a 2D shape (rows, columns).
# NOTE(review): presumably rows = time points and columns = channels — confirm.
NUM_PTS, NUM_DIM = DATASET.dataset.shape

### Utility functions

In [6]:
@jit.script
def prep_batch(batch: tuple[Tensor, Tensor, Tensor], observation_horizon: int):
    """Split a ``(T, X, Y)`` batch into model inputs and forecasting targets.

    Every time index from ``observation_horizon`` onwards is treated as the future:
    it is copied out of ``Y`` as the targets and replaced by NaN in both ``X`` and
    ``Y`` before the two are concatenated along the last axis.

    Args:
        batch: ``(T, X, Y)``. Time runs along the last axis of ``Y`` and the
            second-to-last axis of ``X`` (that is how both are sliced below).
        observation_horizon: number of leading time steps that stay observed.

    Returns:
        ``(T, inputs, targets)`` where ``inputs = cat([X, Y[..., None]], dim=-1)``
        with the future masked by NaN, and ``targets = Y[..., observation_horizon:]``.

    The tensors inside ``batch`` are not modified; masking is applied to copies.
    """
    T, X, Y = batch
    targets = Y[..., observation_horizon:].clone()
    # Mask copies so the caller's batch tensors are not changed in place.
    X = X.clone()
    Y = Y.clone()
    Y[..., observation_horizon:] = float("nan")  # mask future
    X[..., observation_horizon:, :] = float("nan")  # mask future
    inputs = torch.cat([X, Y.unsqueeze(-1)], dim=-1)
    return T, inputs, targets


def get_all_preds(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and collect forecasts.

    Each batch is prepared with ``prep_batch`` (future masked with NaN), fed through
    the model, and the last output channel over the forecasting window
    (time steps ``OBS_HORIZON:``) is taken as the prediction.

    Args:
        model: forecasting model called as ``model(times, inputs)``; the first
            element of its return value is used as the output sequence.
        dataloader: iterable of ``(T, X, Y)`` batches accepted by ``prep_batch``.

    Returns:
        ``(targets, predictions)``: the per-batch targets / predictions, each
        concatenated along dim 0.
    """
    all_targets: list[Tensor] = []
    all_predics: list[Tensor] = []
    # Evaluation never needs gradients; one no_grad context covers the whole pass.
    # Gradient state of the model is deliberately left untouched here.
    with torch.no_grad():
        for batch in tqdm(dataloader, leave=False):
            times, inputs, targets = prep_batch(batch, OBS_HORIZON)
            outputs, _ = model(times, inputs)
            all_targets.append(targets)
            all_predics.append(outputs[:, OBS_HORIZON:, -1])

    return torch.cat(all_targets, dim=0), torch.cat(all_predics, dim=0)

### Logging utilities

In [7]:
def log_all(i, model, writer, optimizer):
    """Write kernel and optimizer information (including histograms) at step ``i``.

    A detached CPU copy of ``model.system.kernel`` is logged, so the live
    parameter is neither moved nor tracked by autograd.
    """
    step = i
    kernel_snapshot = model.system.kernel.detach().clone().to("cpu")
    log_kernel_information(step, writer, kernel_snapshot, histograms=True)
    log_optimizer_state(step, writer, optimizer, histograms=True)

### Model and loss setup

In [8]:
from torch.optim import Adam
from linodenet.models import LinODEnet
from tsdm.util.logging import (
    log_optimizer_state,
    log_kernel_information,
    log_model_state,
    log_metrics,
)

MODEL = LinODEnet

# Build the model and move it to the configured device / precision in one go
# (nn.Module.to returns the module itself).
model = MODEL(input_size=NUM_DIM, hidden_size=32, embedding_type="concat").to(
    device=DEVICE, dtype=DTYPE
)

# Training loss = the task's test metric.
LOSS = TASK.test_metric()

In [9]:
optimizer = Adam(model.parameters(), lr=0.0001)

# Loader used for the optimisation steps.
# NOTE(review): BATCH_SIZE (16) from the config cell is not used here — confirm 64 is intended.
TRAINLOADER = TASK.get_dataloader("train", batch_size=64)

# Unshuffled, large-batch loaders used to evaluate every split in full.
eval_loaders = dict(
    (split_name, TASK.get_dataloader(split_name, batch_size=1024, shuffle=False))
    for split_name in ("train", "valid", "test")
)

In [10]:
# Warm-up: one forward/backward pass on random data populates parameter gradients,
# which are then cleared with zero_grad().
# NOTE(review): since PyTorch 2.0, zero_grad() defaults to set_to_none=True, so the
# gradients end up as None rather than zero tensors — confirm which is intended.
# NOTE(review): the time vector has NUM_DIM entries while inputs has one time step — confirm.
y, yhat = model(
    torch.randn(NUM_DIM).to(device=DEVICE, dtype=DTYPE),
    torch.randn(1, NUM_DIM).to(device=DEVICE, dtype=DTYPE),
)
torch.linalg.norm(y).backward()
model.zero_grad()

In [11]:
RUN_START = tsdm.util.now()

# One checkpoint directory per run, keyed by the run's start timestamp.
CHECKPOINTDIR = Path("checkpoints") / f"{RUN_START}"
CHECKPOINTDIR.mkdir(exist_ok=True, parents=True)

# TensorBoard log directory: runs/<model>/<dataset><timestamp>
# (dataset name and timestamp are concatenated without a separator).
writer = SummaryWriter(f"runs/{MODEL.__name__}/{DATASET.__name__}{RUN_START}")

metrics = dict((key, LOSSES[key]) for key in ("ND", "NRMSE", "MSE", "MAE"))
# The task's test metric must be among the tracked metrics.
assert TASK.test_metric in metrics.values()

In [12]:
NUM_EPOCHS = 100


def _log_epoch_state(step: int) -> None:
    """Log optimizer state, kernel information and full-split metrics at ``step``.

    The optimizer state is logged before the evaluation passes so that it is not
    affected by them.
    """
    with torch.no_grad():
        log_optimizer_state(step, writer, optimizer, histograms=True)
        log_kernel_information(step, writer, model.system.kernel, histograms=True)

        for name, dataloader in eval_loaders.items():
            y, ŷ = get_all_preds(model, dataloader)
            log_metrics(step, writer, y, ŷ, metrics, prefix=name)


def _save_checkpoint(epoch: int, step: int) -> None:
    """Save the model and optimizer bookkeeping for ``epoch`` into CHECKPOINTDIR."""
    # NOTE(review): torch.jit.save expects a ScriptModule — confirm the model is scripted.
    torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epoch}"))
    torch.save(
        {
            "optimizer": optimizer,
            "epoch": epoch,
            "batch": step,
        },
        # Optimizer *instances* have no __name__ attribute; use the class name.
        CHECKPOINTDIR.joinpath(f"{type(optimizer).__name__}-{epoch}"),
    )


i = -1  # global batch counter, incremented exactly once per optimisation step

# Log the untrained state once, then the state after each epoch at step epoch + 1,
# so every epoch-level step corresponds to exactly one model state.
_log_epoch_state(0)

for epoch in trange(NUM_EPOCHS):
    for batch in (batches := tqdm(TRAINLOADER)):
        i += 1
        # Optimization step
        model.zero_grad()
        times, inputs, targets = prep_batch(batch, OBS_HORIZON)
        outputs, _ = model(times, inputs)
        predics = outputs[:, OBS_HORIZON:, -1]
        loss = LOSS(predics, targets)

        # Abort before a NaN loss can propagate into the parameters via step().
        if torch.any(torch.isnan(loss)):
            raise RuntimeError("NaN-value encountered!!")

        loss.backward()
        optimizer.step()

        # batch logging
        with torch.no_grad():
            log_metrics(i, writer, targets, predics, metrics, prefix="batch")
            log_optimizer_state(i, writer, optimizer, prefix="batch")

            lval = loss.clone().detach().cpu().numpy()
            gval = grad_norm(list(model.parameters())).clone().detach().cpu().numpy()
            batches.set_postfix(loss=lval, gnorm=gval)

    _log_epoch_state(epoch + 1)
    _save_checkpoint(epoch, i)

In [None]:
import tensorboard as tb
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

In [None]:
# Inspect the events of the run started above. To look at an older run instead,
# set EVENT_PATH to that run's directory or to a single events file.
writer.flush()  # make sure buffered events of the current run are written to disk
EVENT_PATH = writer.get_logdir()
ea = EventAccumulator(
    EVENT_PATH,
    # size_guidance: max. number of items kept in memory per tag for each event
    # category; 0 means "keep all of them".
    size_guidance={
        event_accumulator.COMPRESSED_HISTOGRAMS: 500,
        event_accumulator.IMAGES: 4,
        event_accumulator.AUDIO: 4,
        event_accumulator.SCALARS: 0,
        event_accumulator.HISTOGRAMS: 1,
    },
)

In [None]:
# (Re)load events from disk into the accumulator; tags can be queried only after this.
ea.Reload()

In [None]:
# Scalar events logged under the tag "train:metrics/MSE", shown as a table
# (presumably one row per event with wall_time / step / value — confirm).
pandas.DataFrame(ea.Scalars("train:metrics/MSE"))