In [1]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# logging.basicConfig(level=logging.INFO)

In [2]:
import os

# Enable TorchScript JIT compilation - must be done before loading torch!
# (torch is imported in the next cell.)
os.environ["PYTORCH_JIT"] = "1"
# Make CUDA kernel launches synchronous so errors are reported at the failing call.
# This is a debugging aid that slows down GPU execution considerably.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Experiment name; used in the checkpoint/log directory names and figure filenames.
RUN_NAME = "SurrLoss+Sequential_Filter"  # | input("enter name for run")

In [3]:
from pathlib import Path
from time import perf_counter, time
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas
import torch
import torchinfo
from linodenet.models import LinODE, LinODECell, LinODEnet
from linodenet.projections.functional import skew_symmetric, symmetric
from linodenet.models.filters import SequentialFilter
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
from torch import Tensor, jit, nn, tensor
from torch.optim import SGD, Adam, AdamW
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

import tsdm
from tsdm.datasets import DATASETS
from tsdm.encoders.functional import time2float
from tsdm.logutils import (
    log_kernel_information,
    log_metrics,
    log_model_state,
    log_optimizer_state,
)
from tsdm.losses import LOSSES
from tsdm.tasks import KIWI_RUNS_TASK
from tsdm.util import grad_norm, multi_norm

In [4]:
# Let cuDNN benchmark candidate algorithms and cache the fastest one per input shape.
torch.backends.cudnn.benchmark = True

# Initialize Task

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# NaN scalar on DEVICE; used to blank out the forecast window of the model inputs.
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
BATCH_SIZE = 256
# True: train on all channels over the full window.
# False: train only on the target channels of the forecast window.
FORECAST_ALL = True

# On average there are ca. 30s between timestamps, i.e. 2 observations ≈ 1 min.
# Observation horizon: 240 steps ≈ 2h; prediction horizon: 120 steps ≈ 1h.
PRD_HORIZON = 120
OBS_HORIZON = 240
HORIZON = SEQLEN = OBS_HORIZON + PRD_HORIZON

In [6]:
# Forecasting task: observe OBS_HORIZON steps, then forecast the next PRD_HORIZON steps.
task = KIWI_RUNS_TASK(
    forecasting_horizon=PRD_HORIZON,
    observation_horizon=OBS_HORIZON,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=2048,
)

DATASET = task.dataset
ts = task.timeseries
md = task.metadata
# The timeseries is 2-D; presumably rows are time points and columns are channels.
NUM_PTS, NUM_DIM = ts.shape

## Initialize Loss & Metrics

In [7]:
# Task-specific test metric; also reported as "WRMSE" below.
TASK_LOSS = task.test_metric  # .to(device=DEVICE)

# Instantiate the generic metrics and add the task metric.
# (A previous class-valued `metrics` dict was immediately overwritten and has been removed.)
metrics = {key: LOSSES[key]() for key in ("ND", "NRMSE", "MSE", "MAE")} | {
    "WRMSE": TASK_LOSS
}

# Train on NRMSE rather than on the task loss (TASK_LOSS).
LOSS = metrics["NRMSE"]

In [8]:
# Display the loss weights defined by the task.
task.loss_weights

## Initialize DataLoaders

In [9]:
# Loaders are keyed by tuples such as (0, "train"); presumably (fold, split) — confirm in tsdm.
TRAINLOADERS = task.batchloaders
TRAINLOADER = TRAINLOADERS[(0, "train")]
# Evaluation loaders, indexed with (0, "train") / (0, "test") in the evaluation loops below.
EVALLOADERS = task.dataloaders

In [12]:
task.timeseries.columns

## Hyperparameters

In [10]:
def join_dicts(d: dict[str, Any]) -> dict[str, Any]:
    """Flatten a nested dict into one level, joining nested keys with '/'.

    Example: ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a/b": 1, "c": 2}``.
    Empty sub-dicts contribute no entries.
    """
    flat: dict[str, Any] = {}
    for outer_key, value in d.items():
        if not isinstance(value, dict):
            flat[outer_key] = value
            continue
        prefixed = {f"{outer_key}/{inner_key}": inner for inner_key, inner in value.items()}
        flat.update(join_dicts(prefixed))
    return flat


def add_prefix(d: dict[str, Any], /, prefix: str) -> dict[str, Any]:
    """Return a new dict whose keys are ``"<prefix>/<key>"``; values are unchanged."""
    prefixed: dict[str, Any] = {}
    for key, item in d.items():
        prefixed[f"{prefix}/{key}"] = item
    return prefixed


# Optimizer configuration. The "__name__" entry presumably selects the optimizer class
# that `initialize_from(OPTIMIZERS, ...)` instantiates further below.
# Inactive alternative: plain SGD.
# OPTIMIZER_CONFIG = {
#     "__name__": "SGD",
#     "lr": 0.001,
#     "momentum": 0,
#     "dampening": 0,
#     "weight_decay": 0,
#     "nesterov": False,
# }

# Inactive alternative: Adam.
# OPTIMIZER_CONFIG = {
#     "__name__": "Adam",
#     "lr": 0.01,
#     "betas": (0.9, 0.999),
#     "eps": 1e-08,
#     "weight_decay": 0,
#     "amsgrad": False,
# }


# Active configuration: AdamW (decoupled weight decay).
OPTIMIZER_CONFIG = {
    "__name__": "AdamW",
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 0.001,
    "amsgrad": False,
}


# Sub-module configurations. Entries whose value is the type `int` are placeholders.
# NOTE(review): presumably they are filled in from the model's own sizes — confirm in linodenet.
SYSTEM = {
    "__name__": "LinODECell",
    "input_size": int,
    "kernel_initialization": "skew-symmetric",
}

EMBEDDING = {
    "__name__": "ConcatEmbedding",
    "input_size": int,
    "hidden_size": int,
}
# NOTE(review): FILTER is not referenced by MODEL_CONFIG below, which uses SequentialFilter.HP.
FILTER = {
    "__name__": "SequentialFilter",
    "input_size": int,
    "hidden_size": int,
    "autoregressive": True,
}

# Inactive alternative: recurrent filter built around a GRUCell.
# FILTER = {
#     "__name__": "RecurrentCellFilter",
#     "concat": True,
#     "input_size": int,
#     "hidden_size": int,
#     "autoregressive": True,
#     "Cell": {
#         "__name__": "GRUCell",
#         "input_size": int,
#         "hidden_size": int,
#         "bias": True,
#         "device": None,
#         "dtype": None,
#     },
# }
from linodenet.models.encoders import ResNet, iResNet

# Inactive explicit encoder/decoder configs; MODEL_CONFIG uses ResNet.HP instead.
# ENCODER = {"__name__": "ResNet", "__module__": "linodenet.models.encoders","input_size": int, "nblocks": 5, "rezero": True}
# DECODER = {"__name__": "ResNet", "__module__": "linodenet.models.encoders","input_size": int, "nblocks": 5, "rezero": True}


# ReduceLROnPlateau configuration (argument descriptions copied from the PyTorch docs).
# NOTE(review): the cell that would create the scheduler is commented out, so these
# settings currently have no effect. Newer PyTorch releases deprecate `verbose` — confirm.
LR_SCHEDULER_CONFIG = {
    "__name__": "ReduceLROnPlateau",
    "mode": "min",
    # (str) – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
    "factor": 0.1,
    # (float) – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
    "patience": 10,
    # (int) – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10.
    "threshold": 0.0001,
    # (float) – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
    "threshold_mode": "rel",
    # (str) – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’.
    "cooldown": 0,
    # (int) – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
    "min_lr": 1e-08,
    # (float or list) – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
    "eps": 1e-08,
    # (float) – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
    "verbose": True
    # (bool) – If True, prints a message to stdout for each update. Default: False.
}

# LinODEnet configuration. Filter/Encoder/Decoder take the classes' `HP` dicts;
# the FILTER dict defined above is not used here.
MODEL_CONFIG = {
    "__name__": "LinODEnet",
    "input_size": NUM_DIM,
    "hidden_size": 128,
    "embedding_type": "concat",
    "Filter": SequentialFilter.HP,
    "System": SYSTEM,
    "Encoder": ResNet.HP,
    "Decoder": ResNet.HP,
    "Embedding": EMBEDDING,
}


# Flattened "Section/key" hyperparameters; passed to log_hparams in the
# (currently commented-out) call inside the training loop.
# NOTE(review): the learning rate recorded here is OPTIMIZER_CONFIG["lr"], which a later
# cell overrides directly in the optimizer's param groups.
HPARAMS = join_dicts(
    {
        "Optimizer": OPTIMIZER_CONFIG,
        "LR_Scheduler": LR_SCHEDULER_CONFIG,
        "Model": MODEL_CONFIG,
    }
)

In [11]:
# Inspect a stand-alone ResNet block.
# NOTE(review): input_size=12 is hard-coded, and this binds `model`, which the next
# section overwrites with the LinODEnet.
model = ResNet(input_size=12, rezero=True)
torchinfo.summary(model, depth=4)

## Initialize Model

In [13]:
# Build the LinODEnet from MODEL_CONFIG and move it to the configured device/dtype.
MODEL = LinODEnet
model = MODEL(**MODEL_CONFIG)
model.to(device=DEVICE, dtype=DTYPE)
torchinfo.summary(model)

### Initial kernel statistics

In [14]:
# Compare matrix norms of the kernel A with those of its matrix exponential exp(A).
# `np.infty` was removed in NumPy 2.0; `np.inf` works on all versions.
expA = torch.matrix_exp(model.kernel)
for order in (-np.inf, -2, -1, 1, 2, np.inf, "fro", "nuc"):
    val = torch.linalg.matrix_norm(model.kernel, ord=order).item()
    val2 = torch.linalg.matrix_norm(expA, ord=order).item()
    o = str(order)  # string form for the self-documenting `o=` field below
    print(f"{o=:6s}\t {val=:10.6f} \t {val2=:10.6f}")

## Initialize Optimizer

In [15]:
from tsdm.optimizers import OPTIMIZERS, LR_SCHEDULERS
from tsdm.util import initialize_from

In [16]:
# Build the optimizer without mutating OPTIMIZER_CONFIG. Storing the one-shot
# `model.parameters()` generator in the config left an exhausted generator behind,
# which the config display below would then show.
optimizer = initialize_from(
    OPTIMIZERS, **(OPTIMIZER_CONFIG | {"params": model.parameters()})
)

In [17]:
optimizer

In [18]:
OPTIMIZER_CONFIG

In [19]:
# NOTE(review): this passes the config dict as the scheduler's optimizer; presumably it
# should be {"optimizer": optimizer}. The training loop never calls lr_scheduler.step().
# lr_scheduler = initialize_from(
#     LR_SCHEDULERS, LR_SCHEDULER_CONFIG | {"optimizer": OPTIMIZER_CONFIG}
# )

## Utility functions

In [20]:
# Inspect one training batch using the same preprocessing as `prep_batch` below,
# but without moving tensors to the device.
batch = next(iter(TRAINLOADER))
T, X = batch
# Targets: target channels over the forecast window.
targets = X[..., OBS_HORIZON:, task.targets.index].clone()
# assert targets.shape == (BATCH_SIZE, PRD_HORIZON, len(TASK.targets))

# Inputs: copy of X with target and observable channels set to NaN over the forecast window.
inputs = X.clone()
inputs[:, OBS_HORIZON:, task.targets.index] = NAN
inputs[:, OBS_HORIZON:, task.observables.index] = NAN
# assert inputs.shape == (BATCH_SIZE, HORIZON, NUM_DIM)

In [21]:
# Same target slice as above, shown for its shape.
targets = X[..., OBS_HORIZON:, task.targets.index].clone()
targets.shape

In [22]:
def prep_batch(batch: tuple[Tensor, Tensor]):
    """Split a batch into model inputs and forecast targets.

    Parameters
    ----------
    batch
        Tuple ``(T, X)`` of time stamps and observations.

    Returns
    -------
    T
        Time stamps on ``DEVICE``.
    inputs
        Copy of ``X`` in which the target and observable channels of the
        forecast window (time index ``OBS_HORIZON`` onwards) are set to NaN.
    targets
        Copy of the target channels of ``X`` over the forecast window.
    originals
        ``X`` on ``DEVICE`` itself (not a copy); in-place edits by the caller
        modify this tensor.
    """
    T, X = batch
    # Move to DEVICE instead of hard-coding `.cuda()`, so the CPU fallback works.
    T = T.to(device=DEVICE, non_blocking=True)
    X = X.to(device=DEVICE, non_blocking=True)
    originals = X
    targets = X[..., OBS_HORIZON:, task.targets.index].clone()
    # assert targets.shape == (BATCH_SIZE, PRD_HORIZON, len(task.targets))

    inputs = X.clone()
    # Hide the values the model has to forecast.
    inputs[:, OBS_HORIZON:, task.targets.index] = NAN
    inputs[:, OBS_HORIZON:, task.observables.index] = NAN
    # assert inputs.shape == (BATCH_SIZE, HORIZON, NUM_DIM)
    return T, inputs, targets, originals


def get_all_preds(model, dataloader, max_batches=6):
    """Collect targets and predictions of ``model`` over ``dataloader``.

    Parameters
    ----------
    model
        Forecasting model, called as ``model(times, inputs)``.
    dataloader
        Iterable of ``(T, X)`` batches; each is preprocessed with ``prep_batch``.
    max_batches
        Number of batches to evaluate (a cheap evaluation subset); ``None``
        evaluates the whole loader. The previous version tried to stop after
        6 batches via ``pbar.n == 5``. tqdm only refreshes ``n`` periodically,
        so that cut-off depended on timing.

    Returns
    -------
    targets, predics
        CPU tensors, concatenated along dim 0, holding the targets and the
        predictions for the target channels of the forecast window. Positions
        where the target is NaN are set to 0 in both.
    """
    y, yhat = [], []
    with torch.no_grad():
        for n_batch, batch in enumerate(tqdm(dataloader, leave=False)):
            if max_batches is not None and n_batch >= max_batches:
                break
            times, inputs, targets, _ = prep_batch(batch)
            outputs = model(times, inputs)
            predics = outputs[:, OBS_HORIZON:, task.targets.index]
            y.append(targets.detach().cpu())
            yhat.append(predics.detach().cpu())

    # torch.cat allocates new tensors, so the in-place masking below is safe.
    targets, predics = torch.cat(y, dim=0), torch.cat(yhat, dim=0)
    # Zero out missing targets (and the matching predictions) so they do not enter metrics.
    mask = torch.isnan(targets)
    targets[mask] = 0.0
    predics[mask] = 0.0
    return targets, predics

## Logging Utilities

In [23]:
from tsdm.logutils import compute_metrics


def log_all(i, model, writer, optimizer):
    """Write kernel and optimizer diagnostics (with histograms) for step ``i``."""
    kernel_snapshot = model.system.kernel.clone().detach().cpu()
    for log_fn, payload in (
        (log_kernel_information, kernel_snapshot),
        (log_optimizer_state, optimizer),
    ):
        log_fn(i, writer, payload, histograms=True)


def log_hparams(i, writer, *, metric_dict, hparam_dict):
    """Write hyperparameters and metrics to the writer's HParams view.

    ``hparam_dict`` is copied before ``"epoch": i`` is added, so the caller's
    dict (e.g. HPARAMS) is not mutated from call to call. Metric names are
    prefixed with ``"hparam/"``.
    """
    hparam_dict = hparam_dict | {"epoch": i}
    metric_dict = add_prefix(metric_dict, "hparam")
    writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)


# Warm-up: one forward and backward pass on random data; the gradients are discarded.
print("WARMUP")

# One time stamp per observation: x has a single time step (axis -2), so t has one entry.
# Previously t had shape (3, NUM_DIM), which did not match x's time axis.
t = torch.randn(3, 1, device=DEVICE)
x = torch.randn(3, 1, NUM_DIM, device=DEVICE)
y = model(t, x)
torch.linalg.norm(y).backward()
model.zero_grad()

In [24]:
# Per-run output locations, keyed by model, dataset, run name and start time:
# checkpoints under checkpoints/..., TensorBoard event files under runs/...
RUN_START = tsdm.util.now()
CHECKPOINTDIR = Path(
    f"checkpoints/{MODEL.__name__}/{DATASET.name}/{RUN_NAME}/{RUN_START}"
)
CHECKPOINTDIR.mkdir(parents=True, exist_ok=True)
LOGGING_DIR = f"runs/{MODEL.__name__}/{DATASET.name}/{RUN_NAME}/{RUN_START}"
writer = SummaryWriter(LOGGING_DIR)

# Training

In [25]:
# Global optimization-step counter; incremented before each batch, so the first batch is step 0.
i = -1
# First training epoch; the training loop starts from this value.
epoch = 1

# Baseline evaluation of the untrained model. It is logged at step `epoch - 1` (= 0):
# logging it at `epoch` collided with the end-of-epoch-1 evaluation, which the
# training loop also logs at step 1.
baseline_step = epoch - 1
with torch.no_grad():
    # log_optimizer_state(baseline_step, writer, optimizer, histograms=True)
    log_kernel_information(
        baseline_step, writer=writer, kernel=model.system.kernel, histograms=True
    )

    for key in ((0, "train"), (0, "test")):
        dataloader = EVALLOADERS[key]
        y, ŷ = get_all_preds(model, dataloader)
        # get_all_preds zero-fills missing targets, so everything must be finite.
        assert torch.isfinite(y).all()
        log_metrics(
            baseline_step,
            writer=writer,
            metrics=metrics,
            targets=y,
            predics=ŷ,
            prefix=key[1],
        )

In [26]:
# Override the learning rate of every parameter group.
# NOTE(review): this silently replaces OPTIMIZER_CONFIG["lr"] (0.001), which is the
# value recorded in HPARAMS.
for g in optimizer.param_groups:
    g["lr"] = 0.0001

In [48]:
# Preprocessing pipeline of the task (pickled below).
encoder = task.preprocessor

# Column position of the OD600 channel in the timeseries.
target_idx = ts.columns.get_loc("OD600")

In [50]:
# NOTE(review): assumes the last pipeline stage holds per-column encoders and that
# column_encoders[0] is indexed by column position — confirm against tsdm.
target_encoder = encoder[-1].column_encoders[0][target_idx]

In [51]:
import pickle

# Persist the preprocessing pipeline and the encoder selected for OD600 above
# to the current working directory.
for filename, obj in (
    ("encoder.pickle", task.preprocessor),
    ("target_encoder.pickle", target_encoder),
):
    with open(filename, "wb") as file:
        pickle.dump(obj, file)

In [27]:
for epoch in (epochs := trange(epoch, 100)):
    batching_time = perf_counter()
    for batch in (batches := tqdm(TRAINLOADER, leave=False)):
        batching_time = perf_counter() - batching_time
        i += 1
        # Optimization step
        model.zero_grad(set_to_none=True)
        times, inputs, targets, originals = prep_batch(batch)

        forward_time = perf_counter()
        outputs = model(times, inputs)
        forward_time = perf_counter() - forward_time
        predics = outputs[:, OBS_HORIZON:, task.targets.index]

        # Zero-fill missing values in both reference and prediction so NaNs do not
        # enter the loss.
        mask = torch.isnan(targets)
        targets[mask] = 0.0
        predics[mask] = 0.0
        mask = torch.isnan(originals)
        originals[mask] = 0.0
        outputs[mask] = 0.0

        if not FORECAST_ALL:
            # Loss on the target channels of the forecast window only.
            loss = LOSS(targets, predics)
        else:
            # Loss on all channels over the full window.
            loss = LOSS(originals, outputs)

        # Check before backward/step. Raising after optimizer.step() (as before) left
        # the parameters already updated with NaN gradients.
        if torch.isnan(loss).any():
            raise RuntimeError("NaN-value encountered!!")

        backward_time = perf_counter()
        loss.backward()
        backward_time = perf_counter() - backward_time

        optimizer.step()

        # batch logging
        with torch.no_grad():
            # perf_counter, like the other timers (previously wall-clock time()).
            logging_time = perf_counter()
            log_metrics(
                i,
                writer=writer,
                metrics=metrics,
                targets=targets.clone(),
                predics=predics.clone(),
                prefix="batch",
            )
            log_optimizer_state(i, writer=writer, optimizer=optimizer, prefix="batch")
            logging_time = perf_counter() - logging_time

        batches.set_postfix(
            Δt_forward=f"{forward_time:.1f}",
            Δt_backward=f"{backward_time:.1f}",
            Δt_logging=f"{logging_time:.1f}",
            Δt_batching=f"{batching_time:.1f}",
        )
        batching_time = perf_counter()

    # End-of-epoch evaluation and checkpointing.
    with torch.no_grad():
        # log optimizer state first !!!
        # (was `optimizer=optimzier`, a NameError that aborted training after epoch 1)
        log_optimizer_state(epoch, writer=writer, optimizer=optimizer, histograms=True)
        log_kernel_information(
            epoch, writer=writer, kernel=model.system.kernel, histograms=True
        )

        for key in ((0, "train"), (0, "test")):
            dataloader = EVALLOADERS[key]
            y, ŷ = get_all_preds(model, dataloader)
            metric_values = compute_metrics(metrics, targets=y, predics=ŷ)
            log_metrics(
                epoch,
                writer=writer,
                metrics=metrics,
                values=metric_values,
                prefix=key[1],
            )
            # log_hparams(epoch, writer, metric_dict=metric_values, hparam_dict=HPARAMS)

        # Model Checkpoint. NOTE(review): torch.jit.save assumes `model` is a
        # ScriptModule, and the whole optimizer object (not its state_dict) is pickled.
        torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epoch}"))
        torch.save(
            {
                "optimizer": optimizer,
                "epoch": epoch,
                "batch": i,
            },
            CHECKPOINTDIR.joinpath(f"{optimizer.__class__.__name__}-{epoch}"),
        )

In [None]:
# Guard cell: raises deliberately, presumably so that "Run All" stops before the
# post-training analysis below.
raise StopIteration

# Post Training Analysis

In [None]:
# Names of all buffers registered on the model.
buffers = dict(model.named_buffers())
set(buffers.keys())

In [None]:
# Detached CPU copies of the time deltas and the pre/post states stored on the model.
(
    timedeltas,
    xhat_pre,
    xhat_post,
    zhat_pre,
    zhat_post,
) = (
    getattr(model, attr).detach().cpu()
    for attr in ("timedeltas", "xhat_pre", "xhat_post", "zhat_pre", "zhat_post")
)
tuple(h.shape for h in (xhat_pre, xhat_post, zhat_pre, zhat_post))

## Relative size change xhat_pre ⟶ xhat_post

In [None]:
%matplotlib inline
plt.style.use("bmh")

# Shape of the recorded states; presumably (batch, time, channels) — TODO confirm.
BATCH_DIM, LEN, DIM = tuple(xhat_pre.shape)
# Model input size and hidden (latent) size.
n, m = model.input_size, model.hidden_size


def gmean(x, dim=(), p=2):
    """Geometric mean of ``|x|`` along ``dim``.

    Computed as ``exp(mean(log(|x|**p)))**(1/p) = exp(mean(log(|x|**p)) / p)``,
    which equals ``exp(mean(log|x|))`` for any ``p > 0``. ``dim=()`` reduces
    over all elements. Zeros in ``x`` give a result of 0.

    The previous version applied ``** (1 / p)`` to the mean of the logs inside
    ``exp``. That returned NaN whenever the mean log was negative, i.e. whenever
    the geometric mean of ``|x|**p`` was below 1.
    """
    log_pow = torch.log(torch.abs(x) ** p)
    mean_log = log_pow.mean() if dim == () else log_pow.mean(dim=dim)
    return torch.exp(mean_log / p)


def _total_mag(h):
    """Batch mean of ||h[:, t]|| / ||h[:, 0]|| for every time index t (norm over last axis)."""
    return torch.mean(
        torch.linalg.norm(h, dim=-1) / torch.linalg.norm(h[:, [0]], dim=-1),
        dim=0,
    ).squeeze()


def _step_mag(h):
    """Batch mean of ||h[..., t+1, :]|| / ||h[..., t, :]|| for consecutive indices t."""
    return torch.mean(
        torch.linalg.norm(h[..., 1:, :], dim=-1)
        / torch.linalg.norm(h[..., :-1, :], dim=-1),
        dim=0,
    )


# Change relative to the first time step.
xpretotalmag = _total_mag(xhat_pre)
xpsttotalmag = _total_mag(xhat_post)
zpretotalmag = _total_mag(zhat_pre)
zpsttotalmag = _total_mag(zhat_post)

# Change between consecutive time steps.
xpremag = _step_mag(xhat_pre)
xpstmag = _step_mag(xhat_post)
zpremag = _step_mag(zhat_pre)
zpstmag = _step_mag(zhat_post)

# System step: index t holds ||zhat_pre[t]|| / ||zhat_post[t-1]||; index 0 has no
# predecessor and is set to 1.
system_mag = torch.linalg.norm(zhat_pre[:, 1:], dim=-1) / torch.linalg.norm(
    zhat_post[:, :-1], dim=-1
)
system_mag = torch.cat([torch.ones(BATCH_DIM, 1), system_mag], dim=-1)
combine_mag = torch.linalg.norm(zhat_post, dim=-1) / torch.linalg.norm(zhat_pre, dim=-1)
# Per-time-step ratios of geometric means taken over the last axis.
decoder_mag = gmean(xhat_pre, dim=-1) / gmean(zhat_pre, dim=-1)
filter_mag = gmean(xhat_post, dim=-1) / gmean(xhat_pre, dim=-1)
encoder_mag = gmean(zhat_post, dim=-1) / gmean(xhat_post, dim=-1)

# Average over the batch dimension.
filter_mag = torch.mean(filter_mag, dim=0)
system_mag = torch.mean(system_mag, dim=0)
combine_mag = torch.mean(combine_mag, dim=0)
decoder_mag = torch.mean(decoder_mag, dim=0)
encoder_mag = torch.mean(encoder_mag, dim=0)

fig, ax = plt.subplots(
    ncols=4, nrows=3, figsize=(12, 8), sharey="row", constrained_layout=True
)

# (row, col) -> (curve, title).
# Row 0: index t compares step t with step 0, hence x_t (not x_{t+1}) in the titles.
# Row 1: consecutive steps. Row 2: per-component change.
panels = {
    (0, 0): (xpretotalmag, r"Rel. Magnitude change $\hat{x}_0  \rightarrow \hat{x}_t  $"),
    (0, 1): (xpsttotalmag, r"Rel. Magnitude change $\hat{x}_0' \rightarrow \hat{x}_t' $"),
    (0, 2): (zpretotalmag, r"Rel. Magnitude change $\hat{z}_0  \rightarrow \hat{z}_t  $"),
    (0, 3): (zpsttotalmag, r"Rel. Magnitude change $\hat{z}_0' \rightarrow \hat{z}_t' $"),
    (1, 0): (xpremag, r"Rel. Magnitude change $\hat{x}_t  \rightarrow \hat{x}_{t+1}  $"),
    (1, 1): (xpstmag, r"Rel. Magnitude change $\hat{x}_t' \rightarrow \hat{x}_{t+1}' $"),
    (1, 2): (zpremag, r"Rel. Magnitude change $\hat{z}_t  \rightarrow \hat{z}_{t+1}  $"),
    (1, 3): (zpstmag, r"Rel. Magnitude change $\hat{z}_t' \rightarrow \hat{z}_{t+1}' $"),
    (2, 0): (decoder_mag, r"Rel. magnitude change $\hat{z}_t  \rightarrow \hat{x}_t$"),
    (2, 1): (filter_mag, r"Relative magnitude change $\hat{x}_t  \rightarrow \hat{x}_t'$"),
    (2, 2): (encoder_mag, r"Rel. magnitude change $\hat{x}_t' \rightarrow \hat{z}_t'$"),
    # Previously a copy of the encoder title; system_mag compares zhat'_{t-1} with zhat_t.
    (2, 3): (system_mag, r"Rel. magnitude change $\hat{z}_{t-1}' \rightarrow \hat{z}_t$"),
}
for (row, col), (curve, title) in panels.items():
    ax[row, col].semilogy(curve)
    ax[row, col].set_title(title)

fig.savefig(f"{RUN_NAME}_encoder_stats_post_training.pdf")

# Distribution plots

In [None]:
def _axis_stats(h):
    """Mean and std over the last axis (presumably channels), each averaged over dim 0."""
    return torch.mean(h, dim=-1).mean(dim=0), torch.std(h, dim=-1).mean(dim=0)


xhat_pre_mean, xhat_pre_stdv = _axis_stats(xhat_pre)
xhat_post_mean, xhat_post_stdv = _axis_stats(xhat_post)
zhat_pre_mean, zhat_pre_stdv = _axis_stats(zhat_pre)
zhat_post_mean, zhat_post_stdv = _axis_stats(zhat_post)

# (panel title, mean curve, std curve)
tuples = [
    (r"$\hat{x}$", xhat_pre_mean, xhat_pre_stdv),
    (r"$\hat{x}'$", xhat_post_mean, xhat_post_stdv),
    (r"$\hat{z}$", zhat_pre_mean, zhat_pre_stdv),
    # Previously mislabelled r"$\hat{x}'$".
    (r"$\hat{z}'$", zhat_post_mean, zhat_post_stdv),
]

S = np.arange(len(xhat_pre_mean));

In [None]:
fig, axes = plt.subplots(
    nrows=2, ncols=2, constrained_layout=True, figsize=(8, 5), sharex=True, sharey=True
)

for ax, (key, mean, std) in zip(axes.flatten(), tuples):
    # Draw the ±std band in the line's own colour. The previous approach read the
    # colour from `ax._get_lines.prop_cycler`, a private attribute that matplotlib 3.8 removed.
    (line,) = ax.plot(S, mean)
    ax.fill_between(S, mean + std, mean - std, color=line.get_color(), alpha=0.3)
    ax.set_title(key)
    ax.set_yscale("symlog")

In [None]:
# Spot-check xhat_pre for the first batch element at indices 0, 1, 2 and -1 of axis 1
# (presumably time steps).
xhat_pre[0, 0]

In [None]:
xhat_pre[0, 1]

In [None]:
xhat_pre[0, 2]

In [None]:
xhat_pre[0, -1]

In [None]:
# Empirical gain of the encoder: histogram of ||encoder(v)|| / ||v|| over random inputs v
# of width m (= model.hidden_size). Both norms are divided by m, which cancels in the ratio.
# Uses DEVICE (not hard-coded "cuda") and no_grad, since no gradients are needed.
with torch.no_grad():
    dummy = torch.randn(10_000, m, device=DEVICE)
    dummy2 = model.encoder(dummy)
    dummy1 = torch.linalg.norm(dummy, dim=-1) / m
    dummy2 = torch.linalg.norm(dummy2, dim=-1) / m
chg = (dummy2 / dummy1).clone().detach().cpu().numpy()
plt.hist(chg, bins="auto");

In [None]:
# Post-training: compare matrix norms of the kernel A with those of exp(A).
# `np.infty` was removed in NumPy 2.0; `np.inf` works on all versions.
expA = torch.matrix_exp(model.kernel)

for order in (-np.inf, -2, -1, 1, 2, np.inf, "fro", "nuc"):
    val = torch.linalg.matrix_norm(model.kernel, ord=order).item()
    val2 = torch.linalg.matrix_norm(expA, ord=order).item()
    o = str(order)  # string form for the self-documenting `o=` field below
    print(f"{o=:6s}\t {val=:10.6f} \t {val2=:10.6f}")

In [None]:
from matplotlib import cm

# Heat map of the kernel: standardized and shifted to 0.5, so that ±3 standard deviations
# span [0, 1] of the diverging colormap; values outside get the colormap's end colours.
mat = model.kernel.clone().detach().cpu()
mat = 0.5 + (mat - mat.mean()) / (6 * mat.std())
# mat = kernel.clip(0, 1)
# `cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9; pyplot's version remains.
colormap = plt.get_cmap("seismic")
mat = colormap(mat)
plt.imshow(mat)

# Profiling

In [None]:
from torch.profiler import ProfilerActivity, profile, record_function

In [None]:
# Profile one forward pass on the `times`/`inputs` left over from the last training
# batch, recording CPU and CUDA activity, tensor shapes and memory.
# Autograd is enabled here, so graph construction is part of the profile.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    model(times, inputs)

In [None]:
# Top 10 operators by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# Top 10 operators by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))