In [None]:
# Display the value of a trailing expression *or* a trailing assignment in every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically re-import edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from datetime import datetime

# enable JIT compilation - must be done before loading torch!
# The variable is set in its own cell so that it is already part of the environment
# when the following cell executes `import torch`.
os.environ["PYTORCH_JIT"] = "1"

In [None]:
import tsdm
import torch
import pandas
import numpy as np
from pathlib import Path

from tqdm.auto import trange, tqdm
from torch import tensor, Tensor, jit
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter

from tsdm.losses import LOSSES
from tsdm.util import grad_norm, multi_norm
from tsdm.datasets import Electricity
from tsdm.encoders import time2float

from linodenet.models import LinODEnet, LinODECell, LinODE
from linodenet.projections import symmetric, skew_symmetric

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so that the
# notebook can still be executed on machines without CUDA (a hard-coded "cuda"
# device makes the `NAN` tensor below fail to allocate there).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# scalar NaN placed on the working device with the working dtype
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
# NOTE(review): BATCH_SIZE is not referenced in the visible cells; the dataloaders
# below pass literal batch sizes.
BATCH_SIZE = 16
PRD_HORIZON = 24  # length of the forecasting window (time steps)
OBS_HORIZON = 96  # length of the observation window (time steps)
# total window length: observed part followed by the part to forecast
SEQLEN = PRD_HORIZON + OBS_HORIZON

In [None]:
from tsdm.tasks import ETDatasetInformer

# ETDatasetInformer task on the "ETTh1" dataset, scored with MSE.
# The horizons reuse the notebook-level constants, so the task definition and the
# batch preparation (which slices at OBS_HORIZON) cannot silently drift apart.
TASK = ETDatasetInformer(
    dataset="ETTh1",
    forecasting_horizon=PRD_HORIZON,
    observation_horizon=OBS_HORIZON,
    test_metric="MSE",
    time_encoder="time2float",
)
DATASET = TASK.dataset

# The underlying data table is two-dimensional: (number of points, number of channels).
NUM_PTS, NUM_DIM = DATASET.dataset.shape

### Utility functions

In [None]:
@jit.script
def prep_batch(batch: tuple[Tensor, Tensor, Tensor], observation_horizon: int):
    """Split a raw batch into model inputs and forecasting targets.

    Parameters
    ----------
    batch:
        Tuple ``(T, X, Y)``. Time is the last axis of ``Y`` and the
        second-to-last axis of ``X``.
    observation_horizon:
        Number of leading time steps that stay visible to the model.

    Returns
    -------
    T:
        The time tensor, passed through unchanged.
    inputs:
        ``X`` with ``Y`` appended as an extra last channel; every entry after
        the observation horizon is NaN.
    targets:
        Copy of the values of ``Y`` after the observation horizon.

    The tensors inside ``batch`` are not modified: masking is applied to clones.
    """
    T, X, Y = batch
    targets = Y[..., observation_horizon:].clone()
    # Mask copies, so the caller's tensors are not overwritten with NaNs.
    X = X.clone()
    Y = Y.clone()
    Y[..., observation_horizon:] = float("nan")  # mask future
    X[..., observation_horizon:, :] = float("nan")  # mask future
    inputs = torch.cat([X, Y.unsqueeze(-1)], dim=-1)
    return T, inputs, targets


def get_all_preds(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and collect the results.

    Parameters
    ----------
    model:
        Callable invoked as ``model(times, inputs)``; the first element of its
        return value is indexed as ``outputs[:, OBS_HORIZON:, -1]``.
    dataloader:
        Iterable of batches accepted by ``prep_batch``.

    Returns
    -------
    tuple of two tensors
        All targets and all predictions, each concatenated along dimension 0.

    Evaluation has no side effects on the model: no loss is computed and the
    gradients are left untouched.
    """
    Y, Ŷ = [], []
    # A single no-grad context around the whole loop.
    with torch.no_grad():
        for batch in tqdm(dataloader, leave=False):
            times, inputs, targets = prep_batch(batch, OBS_HORIZON)
            outputs, _ = model(times, inputs)
            Y.append(targets)
            # Forecast window of the last channel, i.e. the target channel that
            # `prep_batch` appends to the inputs.
            Ŷ.append(outputs[:, OBS_HORIZON:, -1])

    return torch.cat(Y, dim=0), torch.cat(Ŷ, dim=0)


def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Intended to be used as a dataloader ``worker_init_fn`` (``worker_id`` is
    accepted for that signature but not used).

    The seed is ``torch.initial_seed()`` reduced modulo ``2**32``.
    """
    # Imported locally: the notebook imports NumPy only under the alias `np`
    # and never imports `random`, so the bare names raised NameError before.
    import random

    import numpy

    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

# logging utilities

def log_all(i, model, writer, optimizer):
    """Log kernel and optimizer diagnostics, including histograms, at step ``i``."""
    # Work on a detached CPU copy of the kernel rather than the live parameter.
    kernel_snapshot = model.system.kernel.detach().clone().cpu()
    log_kernel_information(i, writer, kernel_snapshot, histograms=True)
    log_optimizer_state(i, writer, optimizer, histograms=True)

### Plotting Kernel Utility Function

In [None]:
from dataclasses import dataclass, field


@dataclass
class Logger:
    """Scratch dataclass exploring a field that is excluded from ``__init__``.

    ``model``, ``task`` and ``optimizer`` are regular constructor arguments
    (``optimizer`` defaults to 3); ``writer`` is filled in by ``__post_init__``.
    All annotations are ``int`` placeholders.
    """

    # not an __init__ argument and no default: must be assigned in __post_init__
    writer: int = field(init=False)
    model: int
    task: int
    optimizer: int = 3

    def __post_init__(self):
        # Previously this only *read* `self.writer`, which had never been set,
        # so constructing a Logger always raised AttributeError.
        # NOTE(review): 0 is a placeholder value — replace with the real writer.
        self.writer = 0

In [None]:
# scratch: build a Logger from positional arguments and read its `writer` attribute
Logger(1, 2, 4).writer

In [None]:
from torch.optim import Adam
from linodenet.models import LinODEnet
from tsdm.util.logging import (
    log_optimizer_state,
    log_kernel_information,
    log_model_state,
    log_metrics,
    compute_metrics,
)

# Model class of this experiment; its name is reused for run and checkpoint paths.
MODEL = LinODEnet
# Build the network and move it to the working device / dtype in one expression.
model = MODEL(input_size=NUM_DIM, hidden_size=32, embedding_type="concat").to(
    device=DEVICE, dtype=DTYPE
)
# Instance of the task's test metric; used as the training loss.
LOSS = TASK.test_metric()

In [None]:
# `torch.seed()` re-seeds torch's global RNG non-deterministically and returns that seed;
# it is reduced modulo 1000 and later used to seed the dataloader generator.
# NOTE(review): runs are only reproducible if this value is recorded somewhere — confirm.
SEED = torch.seed() % 1000

In [None]:
# scratch: tuple of string labels, used to subscript `Literal` in the following cell
l = ("a", "b", "c")

In [None]:
# `Literal` is not imported in any visible cell, so import it here; without it
# this cell raises NameError.
from typing import Literal

# Subscripting with a tuple is the same as listing its items: Literal["a", "b", "c"].
type(Literal[l])

In [None]:
# Dedicated RNG for the training dataloader; `manual_seed` returns the generator itself.
GENERATOR = torch.Generator().manual_seed(SEED)

In [None]:
# Adam over all model parameters.
optimizer = Adam(model.parameters(), lr=1e-4)

# Training batches are drawn with the seeded generator.
TRAINLOADER = TASK.get_dataloader("train", generator=GENERATOR, batch_size=64)

# One unshuffled loader per split, used for evaluation.
eval_loaders = dict(
    (split, TASK.get_dataloader(split, shuffle=False, batch_size=1024))
    for split in ("train", "valid", "test")
)

In [None]:
# warmup - one forward/backward pass on dummy data, then clear the gradients again.
# The dummy tensors are created directly with DEVICE / DTYPE (the placement the model
# was moved to) instead of `.cuda()`, which ignored the configured device.
y, yhat = model(
    torch.randn(NUM_DIM, device=DEVICE, dtype=DTYPE),
    torch.randn(1, NUM_DIM, device=DEVICE, dtype=DTYPE),
)
torch.linalg.norm(y).backward()
model.zero_grad()

In [None]:
# Identifier of this run (presumably a timestamp — see `tsdm.util.now`); it is reused
# for both the checkpoint directory and the tensorboard run directory.
RUN_START = tsdm.util.now()
CHECKPOINTDIR = Path(f"checkpoints/{RUN_START}/")
CHECKPOINTDIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): the dataset name and RUN_START are concatenated without a separator — confirm intended.
writer = SummaryWriter(f"runs/{MODEL.__name__}/{DATASET.__name__}{RUN_START}")
# Metrics evaluated during logging, looked up by name in `LOSSES`.
metrics = {key: LOSSES[key] for key in ("ND", "NRMSE", "MSE", "MAE")}
# Sanity check: the metric the task is scored on must be among the logged ones.
assert TASK.test_metric in metrics.values()

In [None]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class Logger:
    """Bundle everything needed to log training progress.

    Attributes
    ----------
    writer:
        Tensorboard writer passed to the ``log_*`` helpers.
    model:
        Model being trained; ``model.system.kernel`` is logged at epoch end.
    optimizer:
        Optimizer whose state is logged.
    dataloaders:
        Evaluation dataloaders, keyed by split name.
    metrics:
        Metrics handed to ``compute_metrics``, keyed by metric name.
    epoch, batch:
        Counters of finished epochs / batches; ``None`` means start at 0.
    history:
        One DataFrame of metric values per key (``"batch"`` plus every
        dataloader key); missing entries are created empty.
    """

    # forward reference: only used as an annotation
    writer: "SummaryWriter"
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    dataloaders: dict[str, DataLoader]
    # metric name -> metric; the original annotation referred to undefined names
    metrics: dict[str, type]
    epoch: Optional[int] = None
    batch: Optional[int] = None
    history: Optional[dict[str, pandas.DataFrame]] = None

    def __post_init__(self):
        """Initialise the counters and make sure every history table exists."""
        self.KEYS = list(self.dataloaders)
        self.batch = 0 if self.batch is None else self.batch
        # initialised from `epoch` itself (it was mistakenly derived from `batch`)
        self.epoch = 0 if self.epoch is None else self.epoch

        # Create the container before filling it; it may legitimately be None.
        if self.history is None:
            self.history = {}
        for key in ("batch", *self.KEYS):
            if key not in self.history:
                self.history[key] = pandas.DataFrame(
                    columns=list(self.metrics), dtype=float
                )

    @torch.no_grad()
    def log_at_batch_end(self, *, targets: Tensor, predics: Tensor):
        """Advance the batch counter and log metrics and optimizer state."""
        self.batch += 1
        hist = compute_metrics(targets=targets, predics=predics, metrics=self.metrics)
        log_metrics(self.batch, self.writer, hist, prefix="batch")
        log_optimizer_state(self.batch, self.writer, self.optimizer, prefix="batch")
        self._record("batch", hist)

    @torch.no_grad()
    def log_at_epoch_end(
        self, *, targets: Optional[Tensor] = None, predics: Optional[Tensor] = None
    ):
        """Advance the epoch counter and log per-split metrics and histograms.

        ``targets`` and ``predics`` are accepted for symmetry with
        ``log_at_batch_end`` but are not used: predictions are recomputed for
        every dataloader in ``self.dataloaders``.
        """
        self.epoch += 1

        # Everything is taken from the instance, not from notebook globals.
        log_optimizer_state(self.epoch, self.writer, self.optimizer, histograms=True)
        log_kernel_information(
            self.epoch, self.writer, self.model.system.kernel, histograms=True
        )

        for key, dataloader in self.dataloaders.items():
            y, ŷ = get_all_preds(self.model, dataloader)
            hist = compute_metrics(targets=y, predics=ŷ, metrics=self.metrics)
            log_metrics(self.epoch, self.writer, hist, prefix=key)
            self._record(key, hist)

    def _record(self, key: str, scalars: dict[str, Tensor]) -> None:
        """Append one row of metric values to the history table ``key``."""
        frame = self.history[key]
        # `.loc` with a new label enlarges the frame in place; the old
        # `DataFrame.append` returned a new frame that was then discarded.
        frame.loc[len(frame)] = pandas.Series(self._to_cpu(scalars), dtype=float)

    @staticmethod
    def _to_cpu(scalar_dict: dict[str, Tensor]) -> dict[str, float]:
        """Convert a dict of one-element tensors into a dict of Python numbers."""
        return {key: scalar.item() for key, scalar in scalar_dict.items()}

In [None]:
# scratch: Python type of a single metric value after `.item()`
# NOTE(review): `batch_hist` is only assigned inside the training loop further down,
# so this cell fails on a fresh kernel.
type(batch_hist["ND"].item())

In [None]:
# scratch: inspect the class of the tensorboard writer
type(writer)

In [None]:
# scratch: empty DataFrame whose columns are taken from `metrics`
df = pandas.DataFrame(columns=metrics)

In [None]:
# `DataFrame.append` was removed in pandas 2.0: build a one-row frame from the
# metric values and concatenate it instead.
# NOTE(review): `batch_hist` is only assigned inside the training loop further down.
df = pandas.concat(
    [df, pandas.DataFrame([{k: v.item() for k, v in batch_hist.items()}])],
    ignore_index=True,
)

In [None]:
# scratch: try building a DataFrame directly from the metric tensors (moved to the CPU)
# NOTE(review): `batch_hist` is only assigned inside the training loop further down.
pandas.DataFrame.from_dict(
    {k: v.cpu() for k, v in batch_hist.items()}, orient="columns"
)

In [None]:
# Global batch counter; used as the tensorboard step of every per-batch log entry.
i = 0

for epoch in (epochs := trange(100)):
    for batch in (batches := tqdm(TRAINLOADER)):
        # Optimization step
        model.zero_grad()
        times, inputs, targets = prep_batch(batch, OBS_HORIZON)
        outputs, _ = model(times, inputs)
        # forecast window of the last channel
        predics = outputs[:, OBS_HORIZON:, -1]
        loss = LOSS(predics, targets)

        # Abort *before* the update, so a NaN loss cannot be propagated into
        # the parameters by `optimizer.step()`.
        if torch.any(torch.isnan(loss)):
            raise RuntimeError("NaN-value encountered!!")

        loss.backward()
        optimizer.step()

        # batch logging
        with torch.no_grad():
            i += 1
            batch_hist = compute_metrics(
                targets=targets, predics=predics, metrics=metrics
            )
            # The step is the batch counter `i`, not `epoch`: with `epoch`, every
            # batch of an epoch was written to the same tensorboard step.
            log_metrics(i, writer, batch_hist, prefix="batch")
            log_optimizer_state(i, writer, optimizer, prefix="batch")

            lval = loss.clone().detach().cpu().numpy()
            gval = grad_norm(list(model.parameters())).clone().detach().cpu().numpy()
            batches.set_postfix(loss=lval, gnorm=gval)

    with torch.no_grad():
        # log optimizer state first !!!
        log_optimizer_state(epoch, writer, optimizer, histograms=True)
        log_kernel_information(epoch, writer, model.system.kernel, histograms=True)

        for name, dataloader in eval_loaders.items():
            y, ŷ = get_all_preds(model, dataloader)
            hist = compute_metrics(targets=y, predics=ŷ, metrics=metrics)
            log_metrics(epoch, writer, hist, prefix=name)

        # Model Checkpoint
        torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epochs.n}"))
        # `torch.save` needs a target file; it was previously called without one,
        # which raised TypeError at the end of the first epoch.
        # NOTE(review): "batch" stores the tensors of the last batch and "optimizer" /
        # "trainloader" pickle whole objects — consider the batch counter and
        # `optimizer.state_dict()` instead.
        torch.save(
            {
                "optimizer": optimizer,
                "epoch": epoch,
                "batch": batch,
                "generator": GENERATOR.get_state(),
                "trainloader": TRAINLOADER,
            },
            CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epochs.n}.state"),
        )

In [None]:
import tensorboard as tb
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

In [None]:
# Open the tensorboard event file of an earlier run for offline inspection.
# NOTE(review): absolute, machine-specific path — this cell only works where that file exists.
ea = EventAccumulator(
    "/home/rscholz/Projects/KIWI/tsdm/dev/experiments/runs/LinODEnet/ETTh12021-09-29T02:57:42/events.out.tfevents.1632877062.workstation.373922.0",
    size_guidance={  # per-event-type limits; presumably 0 means "keep everything" — verify against the tensorboard docs
        event_accumulator.COMPRESSED_HISTOGRAMS: 500,
        event_accumulator.IMAGES: 4,
        event_accumulator.AUDIO: 4,
        event_accumulator.SCALARS: 0,
        event_accumulator.HISTOGRAMS: 1,
    },
)

In [None]:
# Load the events from the file into the accumulator (presumably required before querying it).
ea.Reload()

In [None]:
# scratch: check that a fresh torch.Generator can be serialised with `torch.save`
# (writes the file "generator.torch" into the current working directory)
g = torch.Generator()
torch.save(g, "generator.torch")

In [None]:
# scratch: inspect the state of torch's global RNG
torch.get_rng_state()

In [None]:
# Tabulate the scalar events stored under the tag "train:metrics/MSE".
pandas.DataFrame(ea.Scalars("train:metrics/MSE"))

In [None]:
# display the model's repr
model