In [2]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import os
from datetime import datetime

# enable JIT compilation - must be done before loading torch!
os.environ["PYTORCH_JIT"] = "1"

In [4]:
import tsdm
import torch
import pandas
import numpy as np
from pathlib import Path
from pandas import Timedelta, Timestamp, Index, Series, DataFrame
from tqdm.auto import trange, tqdm

from torch import tensor, Tensor, jit
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam

import torchinfo

from tsdm.datasets import DATASETS
from tsdm.encoders.functional import time2float
from tsdm.losses import LOSSES
from tsdm.tasks import KIWI_RUNS_TASK
from tsdm.util import grad_norm, multi_norm
from tsdm.logutils import (
    log_optimizer_state,
    log_kernel_information,
    log_model_state,
    log_metrics,
)

from linodenet.models import LinODEnet, LinODECell, LinODE
from linodenet.projections.functional import symmetric, skew_symmetric

In [5]:
# Fall back to the CPU when no GPU is visible so the notebook still runs end-to-end.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# Scalar NaN created with the working dtype on the working device.
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
BATCH_SIZE = 64
PRD_HORIZON = 24  # presumably the prediction (forecasting) horizon
OBS_HORIZON = 96  # observation horizon; everything after it is masked in prep_batch
SEQLEN = PRD_HORIZON + OBS_HORIZON

#### Initialize Task

### Controls

- Cumulated_feed_volume_glucose
- Cumulated_feed_volume_medium
- InducerConcentration
- StirringSpeed
- Flow_Air
- Temperature
- Probe_Volume

In [20]:
TASK.timeseries["Flow_Air"].unique()

In [19]:
pandas.isna(TASK.timeseries).mean()

In [6]:
# Horizons and the training batch size come from the shared constants so the task
# cannot drift from the values used later for masking and slicing.
TASK = KIWI_RUNS_TASK(
    forecasting_horizon=PRD_HORIZON,
    observation_horizon=OBS_HORIZON,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=1024,
)

### Encoding / decoding

In [16]:
TASK.metadata

In [11]:
from tsdm.encoders.modular import DateTimeEncoder

# Pick one series out of the task's timeseries frame.
# NOTE(review): 439 / 15325 are hard-coded keys (presumably run and experiment ids) — confirm.
run = TASK.timeseries.loc[439, 15325]
# Move the index into columns and keep only the "measurement_time" column.
time = run.reset_index()["measurement_time"]

In [12]:
# Manual round trip: timestamps -> seconds since the first timestamp -> timestamps.
enc = (time - time[0]) / Timedelta(1, "s")
dec = enc * Timedelta(1, "s") + time[0]
# Raises AssertionError if the decoded series differs from the original one.
pandas.testing.assert_series_equal(time, dec)

In [None]:
# TODO: unfinished cell — the original `time = ` had no right-hand side and raised a
# SyntaxError. `time` keeps the value assigned from `run` in the earlier cell.

In [15]:
# Round trip through DateTimeEncoder: fit on the timestamps, encode, then decode.
encoder = DateTimeEncoder()
encoder.fit(time)
encoded = encoder.encode(time)
decoded = encoder.decode(encoded)
# Raises AssertionError unless decoding reproduces the input series.
pandas.testing.assert_series_equal(time, decoded)

In [59]:
decoded = pandas.to_timedelta(Tensor(enc), unit="s") + time[0]


In [46]:
enc

In [44]:
encoded

In [41]:
pandas.testing.assert_series_equal(time, decoded)

In [30]:
time

In [31]:
decoded

In [29]:
pandas.testing.assert_series_equal(pandas.Series(time), pandas.Series(decoded))

In [None]:
TRAINLOADER = TASK.batchloader
EVALLOADERS = TASK.dataloaders

#### Initialize Loss

In [14]:
loss_weights = {
    "Base": 200,
    "DOT": 100,
    "Glucose": 10,
    "OD600": 20,
}

In [15]:
# Map the positional index of each weighted column of TASK.timeseries to the
# reciprocal of its entry in `loss_weights`.
indices = {
    TASK.timeseries.columns.get_loc(key) : 1/w for key, w in loss_weights.items()
}


class WRMSE:
    """Weighted RMSE loss.

    TODO: unfinished stub. The original cell stopped after the first docstring
    line (missing ``:`` after the class header and no closing quotes), so it
    failed with a SyntaxError. No loss computation is implemented yet.
    """





In [None]:
# Broadcasting aligns trailing dimensions, so a (3, 4) factor cannot be multiplied
# with a (3, 4, 5, 6) tensor directly (6 vs. 4 mismatch raises a RuntimeError).
# Two trailing singleton axes line the factor up with the two leading dimensions.
# NOTE(review): assumes the factor was meant to apply along the leading axes — confirm.
torch.randn(3, 4, 5, 6) * torch.randn(3, 4)[:, :, None, None]

In [22]:
TASK.preprocessor.inverse_transform(X.cpu().numpy())

In [23]:
X.shape

In [None]:
LOSS = TASK.test_metric

#### Initialize Model

In [None]:
MODEL = LinODEnet
# NOTE(review): NUM_DIM is not assigned anywhere in the visible notebook; unless it is
# defined elsewhere this line raises NameError on a fresh kernel.
model = MODEL(input_size=NUM_DIM, hidden_size=32, embedding_type="concat")
# Put the model on the configured device with the configured dtype.
model.to(device=DEVICE, dtype=DTYPE)
# Last expression of the cell: displays the torchinfo summary of the model.
torchinfo.summary(model)

#### Initialize Optimizer

In [None]:
optimizer = Adam(model.parameters(), lr=0.001)

#### Initialize Dataloaders

In [None]:
# dataloader for training; batch size comes from the shared BATCH_SIZE constant
TRAINLOADER = TASK.get_dataloader("train", batch_size=BATCH_SIZE)
# dataloaders for evaluation: one unshuffled loader per split, keyed by split name
eval_loaders = {
    split: TASK.get_dataloader(split, batch_size=1024, shuffle=False)
    for split in ("train", "valid", "test")
}

### Utility functions

In [None]:
@jit.script
def prep_batch(batch: tuple[Tensor, Tensor, Tensor], observation_horizon: int):
    """Split a batch into model inputs and forecasting targets.

    Parameters
    ----------
    batch: tuple[Tensor, Tensor, Tensor]
        ``(T, X, Y)``. ``X`` is sliced along its second-to-last axis and ``Y`` along
        its last axis, so presumably ``X`` is ``(..., time, channels)`` and ``Y`` is
        ``(..., time)`` — TODO confirm against the dataloader.
    observation_horizon: int
        Number of leading steps that stay visible; everything after is masked.

    Returns
    -------
    T, inputs, targets
        ``T`` is passed through unchanged. ``inputs`` is ``X`` with ``Y`` appended as
        the last channel, with every entry from ``observation_horizon`` onwards set
        to NaN. ``targets`` holds the unmasked ``Y`` values from
        ``observation_horizon`` onwards.
    """
    T, X, Y = batch
    targets = Y[..., observation_horizon:].clone()
    # Mask copies rather than the caller's tensors: writing NaN in place would
    # corrupt the batch and any storage it shares with the underlying dataset.
    covariates = X.clone()
    observed = Y.clone()
    observed[..., observation_horizon:] = float("nan")  # mask future
    covariates[..., observation_horizon:, :] = float("nan")  # mask future
    inputs = torch.cat([covariates, observed.unsqueeze(-1)], dim=-1)
    return T, inputs, targets


def get_all_preds(model, dataloader):
    """Collect targets and predictions over every batch of ``dataloader``.

    Each batch is prepared with ``prep_batch`` using ``OBS_HORIZON``; the prediction
    is the last channel of the first model output from ``OBS_HORIZON`` onwards.
    Everything runs under ``torch.no_grad()``.

    Parameters
    ----------
    model
        Callable taking ``(times, inputs)`` and returning a pair whose first element
        is the output tensor.
    dataloader
        Iterable of batches accepted by ``prep_batch``.

    Returns
    -------
    tuple[Tensor, Tensor]
        Targets and predictions, each concatenated along dimension 0.

    Notes
    -----
    ``model.zero_grad()`` is kept as a side effect of the original implementation.
    No loss is computed here; the original evaluated ``LOSS`` per batch and
    discarded the result.
    """
    Y, Ŷ = [], []
    with torch.no_grad():
        model.zero_grad()
        for batch in tqdm(dataloader, leave=False):
            times, inputs, targets = prep_batch(batch, OBS_HORIZON)
            outputs, _ = model(times, inputs)
            predics = outputs[:, OBS_HORIZON:, -1]
            Y.append(targets)
            Ŷ.append(predics)

    return torch.cat(Y, dim=0), torch.cat(Ŷ, dim=0)

# logging utilities

In [None]:
def log_all(i, model, writer, optimizer):
    """Write kernel and optimizer diagnostics (with histograms) for step ``i``.

    The kernel is read from ``model.system.kernel`` and handed to the logger as a
    detached CPU copy; the optimizer is passed through as is.
    """
    kernel_snapshot = model.system.kernel.clone().detach().cpu()
    log_kernel_information(i, writer, kernel_snapshot, histograms=True)
    log_optimizer_state(i, writer, optimizer, histograms=True)

In [None]:
# warmup - one throw-away forward/backward pass, then reset the gradients.
# Inputs are created with DEVICE and DTYPE (the same settings as `model.to(...)`)
# instead of hard-coding `.cuda()`.
y, yhat = model(
    torch.randn(NUM_DIM, device=DEVICE, dtype=DTYPE),
    torch.randn(1, NUM_DIM, device=DEVICE, dtype=DTYPE),
)
torch.linalg.norm(y).backward()
model.zero_grad()

In [None]:
# Start time of this run; it names both the checkpoint directory and the log directory.
RUN_START = tsdm.util.now()
CHECKPOINTDIR = Path(f"checkpoints/{RUN_START}/")
CHECKPOINTDIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): DATASET is not defined anywhere in the visible notebook; unless it is
# set elsewhere this line raises NameError on a fresh kernel.
writer = SummaryWriter(f"runs/{MODEL.__name__}/{DATASET.__name__}{RUN_START}")
# First pass: look up the LOSSES entries without calling them, so isinstance can be used.
metrics = {key: LOSSES[key] for key in ("ND", "NRMSE", "MSE", "MAE")}
# Sanity check: the task's test metric must be an instance of one of the logged metrics.
assert any(isinstance(TASK.test_metric, metric) for metric in metrics.values())
# Second pass: call each entry with no arguments; these results are what gets logged.
metrics = {key: LOSSES[key]() for key in ("ND", "NRMSE", "MSE", "MAE")}

### Training Start

In [None]:
# Global batch counter used as the step for per-batch logging.
# Starts at -1 so that the first batch is logged at step 0.
i = -1

for epoch in (epochs := trange(100)):
    # log the state at the start of the epoch
    with torch.no_grad():
        # log optimizer state first !!!
        log_optimizer_state(epoch, writer, optimizer, histograms=True)
        log_kernel_information(epoch, writer, model.system.kernel, histograms=True)

        for name, dataloader in eval_loaders.items():
            y, ŷ = get_all_preds(model, dataloader)
            log_metrics(epoch, writer, y, ŷ, metrics, prefix=name)

    for batch in (batches := tqdm(TRAINLOADER)):
        i += 1
        # Optimization step
        model.zero_grad()
        times, inputs, targets = prep_batch(batch, OBS_HORIZON)
        outputs, _ = model(times, inputs)
        predics = outputs[:, OBS_HORIZON:, -1]
        loss = LOSS(predics, targets)
        loss.backward()
        optimizer.step()

        # batch logging
        with torch.no_grad():
            # `i` was already advanced at the top of the loop body; a second
            # increment here made the batch steps skip every other value.
            log_metrics(i, writer, targets, predics, metrics, prefix="batch")
            log_optimizer_state(i, writer, optimizer, prefix="batch")

            lval = loss.clone().detach().cpu().numpy()
            gval = grad_norm(list(model.parameters())).clone().detach().cpu().numpy()
            batches.set_postfix(loss=lval, gnorm=gval)

            if torch.any(torch.isnan(loss)):
                raise RuntimeError("NaN-value encountered!!")

    # NOTE(review): this end-of-epoch block logs with the same step (`epoch`) as the
    # start-of-epoch block above, so each step receives two entries — confirm intent.
    with torch.no_grad():
        # log optimizer state first !!!
        log_optimizer_state(epoch, writer, optimizer, histograms=True)
        log_kernel_information(epoch, writer, model.system.kernel, histograms=True)

        for name, dataloader in eval_loaders.items():
            y, ŷ = get_all_preds(model, dataloader)
            log_metrics(epoch, writer, y, ŷ, metrics, prefix=name)

        # Model Checkpoint
        # NOTE(review): torch.jit.save expects a scripted/traced module — confirm that
        # `model` is one.
        torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epochs.n}"))
        torch.save(
            {
                "optimizer": optimizer,
                "epoch": epoch,
                "batch": i,
            },
            # Optimizer instances have no `__name__`; take the name from the class.
            CHECKPOINTDIR.joinpath(f"{type(optimizer).__name__}-{epochs.n}"),
        )

In [None]:
# Evaluation pass over every split after training.
# NOTE: `epoch` is the loop variable left over from the training loop, so these
# metrics are written at the step of the last epoch that ran.
for name, dataloader in eval_loaders.items():
    print(name)
    y, ŷ = get_all_preds(model, dataloader)
    log_metrics(epoch, writer, y, ŷ, metrics, prefix=name)

In [None]:
import tensorboard as tb
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

In [None]:
# Load one TensorBoard event file for offline inspection.
# NOTE(review): the path is absolute and machine-specific; it only resolves on the
# workstation where this run was recorded.
ea = EventAccumulator(
    "/home/rscholz/Projects/KIWI/tsdm/dev/experiments/runs/LinODEnet/ETTh12021-09-29T02:57:42/events.out.tfevents.1632877062.workstation.373922.0",
    size_guidance={  # per-category limits; presumably 0 means "keep everything" — confirm in the TensorBoard docs
        event_accumulator.COMPRESSED_HISTOGRAMS: 500,
        event_accumulator.IMAGES: 4,
        event_accumulator.AUDIO: 4,
        event_accumulator.SCALARS: 0,
        event_accumulator.HISTOGRAMS: 1,
    },
)

In [None]:
ea.Reload()

In [None]:
pandas.DataFrame(ea.Scalars("train:metrics/MSE"))

In [None]:
from typing import TypeVar, Union, Sequence, Generic, Callable

T = TypeVar('T')
S = TypeVar('S')

SingleOrSequence = Union[T, Sequence[T], list[S]]

In [None]:
SingleOrSequence[dict[str, int], float]

In [None]:
SingleOrSequence

In [None]:
class A:
    """Scratch helper: subscripting yields ``Union[item, Sequence[item]]``."""

    def __class_getitem__(cls, item):
        """Handle ``A[item]`` on the class itself.

        Subscripting a class is dispatched to ``__class_getitem__``. A classmethod
        named ``__getitem__`` is not consulted for ``A[item]``, which is why the
        original definition raised ``TypeError`` there.
        """
        return Union[item, Sequence[item]]

    @classmethod
    def __getitem__(cls, item):
        """Return the same alias when an instance is subscripted (``A()[item]``)."""
        return cls.__class_getitem__(item)

In [None]:
A[int]

In [None]:
Callable[..., int]

In [None]:
ObjectType = TypeVar("ObjectType")
r"""Generic type hint for instances."""

ClassType = TypeVar("ClassType")
r"""Generic type hint for classes."""

FunctionType = TypeVar("FunctionType")
r"""Generic type hint for function."""

ReturnType = TypeVar("ReturnType")
r"""Generic type hint for return type."""


# Generic aliases for string-keyed lookup tables, parameterised by the type variables above.
LookupTable = dict[str, ObjectType]
ClassLookupTable = LookupTable[type[ObjectType]]
FunctionLookupTable = LookupTable[FunctionType]
CallableLookupTable = LookupTable[Callable[..., ReturnType]]
# Either a table of functions, a table of classes, or a table mixing both.
CombinedLookupTable = Union[
    dict[str, FunctionType],
    dict[str, type[ObjectType]],
    dict[str, Union[FunctionType, type[ObjectType]]],
]

In [None]:
CallableLookupTable[int]