In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Set before `torch` is imported in the next cell.
# NOTE(review): PYTORCH_JIT=0 disables TorchScript; "1" presumably just makes the
# default (scripting enabled) explicit — confirm this is still needed.
os.environ["PYTORCH_JIT"] = "1"

In [None]:
import tsdm
import torch
import pandas
import numpy as np
from datetime import datetime
from tqdm import trange, tqdm
from torch.utils.tensorboard import SummaryWriter
from torch import tensor, Tensor, jit
from tsdm.datasets import Electricity
from tsdm.encoders import time2float
import matplotlib.pyplot as plt

# Short alias for +infinity; used as the `ord` of the ±∞ matrix norms in the training loop.
ℵ = np.inf

In [None]:
from linodenet.projections import symmetric, skew_symmetric
from tsdm.utils.data import SliceSampler
from torch.utils.data import BatchSampler, DataLoader


def now():
    """Return the current local time as an ISO-8601 string with second precision.

    Used to name the TensorBoard run directory.
    """
    current = datetime.now()
    return current.isoformat(sep="T", timespec="seconds")


def plot_spectrum(kernel):
    """Scatter-plot the eigenvalues of ``kernel`` in the complex plane.

    The view is fixed to the square [-2.5, 2.5] x [-2.5, 2.5], so eigenvalues
    outside it are not visible.

    Returns:
        The newly created matplotlib Figure.
    """
    eigenvalues = torch.linalg.eigvals(kernel).detach().cpu()
    figure, axes = plt.subplots(figsize=(12, 6), tight_layout=True)
    limits = [-2.5, +2.5]
    axes.set(
        xlim=limits,
        ylim=limits,
        aspect="equal",
        xlabel="real part",
        ylabel="imag part",
    )
    axes.scatter(eigenvalues.real, eigenvalues.imag)
    return figure


def symmpart(kernel):
    """Share of the kernel's mean-square magnitude carried by ``symmetric(kernel)``.

    Returns ``mean(symmetric(kernel)**2) / mean(kernel**2)`` as a 0-dim tensor.
    NOTE(review): presumably ``symmetric`` returns (A + Aᵀ)/2 — confirm in
    linodenet.projections.
    """
    symmetric_energy = symmetric(kernel).pow(2).mean()
    total_energy = kernel.pow(2).mean()
    return symmetric_energy / total_energy


def skewpart(kernel):
    """Share of the kernel's mean-square magnitude carried by ``skew_symmetric(kernel)``.

    Returns ``mean(skew_symmetric(kernel)**2) / mean(kernel**2)`` as a 0-dim tensor.
    NOTE(review): presumably ``skew_symmetric`` returns (A - Aᵀ)/2 — confirm in
    linodenet.projections.
    """
    # The parameter must be named `kernel`: the body reads `kernel`, and a misspelled
    # parameter would make it silently fall back to a global of that name.
    return torch.mean(skew_symmetric(kernel) ** 2) / torch.mean(kernel**2)


def collate_tensor(tensors: list[Tensor]) -> Tensor:
    """Stack equally-shaped tensors along a new leading (batch) dimension.

    Used to turn the list of samples yielded by the BatchSampler into one batch tensor.
    ``torch.stack`` raises if the list is empty or the shapes differ.
    """
    return torch.stack(tensors, dim=0)

In [None]:
# Raw Electricity dataset (the config cell below also binds `Electricity.dataset` to DATASET).
ds = Electricity.dataset

In [None]:
# Fall back to the CPU when no GPU is available so the notebook still runs end to end.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# NaN fill value that `prep` writes into the unobserved (to-be-predicted) part of each window.
NAN = tensor(float("nan"), dtype=DTYPE, device=DEVICE)
DATASET = Electricity.dataset
# Shape of the raw table (before resampling): rows, columns.
NPTS, NDIM = DATASET.shape
SEQLEN = 48  # total window length: observed + predicted steps
PRD_HORIZON = 24  # trailing steps of each window that are masked and predicted
OBS_HORIZON = SEQLEN - PRD_HORIZON  # leading steps of each window the model observes
BATCH_SIZE = 16

In [None]:
# Preprocessing following the N-BEATS paper.
# Start from DATASET rather than `ds`, so re-running this cell is idempotent:
# re-resampling the already-resampled `ds` with label="right" shifts its index each time.
# Resample to hourly resolution, summing the values inside each hour.
ds = DATASET.resample(pandas.Timedelta("1h"), label="right").sum()
# Remove the first year: keep data from 2012-01-01 onward.
ds = ds.loc[pandas.Timestamp("2012-01-01") :]
# Hold out the last PRD_HORIZON hourly steps as the test set.
ds_train, ds_test = ds.iloc[:-PRD_HORIZON], ds.iloc[-PRD_HORIZON:]
t_train, t_test = time2float(ds_train.index), time2float(ds_test.index)
# Scale both splits by the same factor so they share one time axis; dividing each
# split by its own max would map the last point of both splits to 1.0.
t_scale = t_train.max()
t_train, t_test = t_train / t_scale, t_test / t_scale

In [None]:
# Column 0: normalized timestamps; remaining columns: the channel values.
train = torch.cat(
    (
        tensor(t_train, device=DEVICE, dtype=DTYPE)[:, None],
        tensor(ds_train.values, device=DEVICE, dtype=DTYPE),
    ),
    dim=-1,
)

In [None]:
# Group the slices drawn from `train` into batches of BATCH_SIZE, dropping the last
# incomplete batch.
# NOTE(review): assumes SliceSampler yields length-SEQLEN slices of `train`; the keyword
# `slice_sampler=SEQLEN` looks unusual — confirm against tsdm's SliceSampler signature.
sampler = BatchSampler(
    SliceSampler(train, slice_sampler=SEQLEN),
    batch_size=BATCH_SIZE,
    drop_last=True,
)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE
from tsdm.metrics.functional import nrmse, nd
from torch.optim import Adam


# NOTE(review): 512 is presumably the latent/hidden size of LinODEnet — confirm.
model = LinODEnet(NDIM, 512, embedding_type="concat").to(device=DEVICE, dtype=DTYPE)
optimizer = Adam(model.parameters(), lr=0.001)

# One sample batch, split into time column and values (e.g. for shape checks / tracing).
batch = collate_tensor(next(iter(sampler)))
t, x = batch[:, :, 0], batch[:, :, 1:]
# One TensorBoard run directory per launch, named by the current timestamp.
writer = SummaryWriter(f"runs/LinODEnet/{now()}")

In [None]:
# batch = collate_tensor(next(iter(sampler)))
# t, x = batch[:, :, 0], batch[:, :, 1:]
# writer.add_graph(model, (t, x), verbose=True)

In [None]:
@jit.script
def grad_norm(tensors: list[Tensor]) -> Tensor:
    r"""Mean over ``tensors`` of each tensor's squared gradient L2 norm.

    Returns ``sum_i sum(tensors[i].grad ** 2) / len(tensors)``. Despite the name, no
    square root is taken. The input list is not modified. An empty list yields
    ``tensor(0.)``.

    NOTE(review): every tensor must have a populated ``.grad`` (e.g. after ``backward()``).
    """
    if len(tensors) == 0:
        return torch.tensor(0.0)
    # Start from the first gradient so the accumulator inherits its dtype and device.
    s = torch.sum(tensors[0].grad ** 2)
    for x in tensors[1:]:
        s = s + torch.sum(x.grad**2)
    return s / len(tensors)

In [None]:
@jit.script
def prep(
    batch: Tensor, OBS_HORIZON: int = OBS_HORIZON, NAN: Tensor = NAN
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Split a batch into time stamps, targets and a masked model input.

    Index 0 of the last axis is the time stamp; the remaining indices are the values.
    ``x_obs`` is a detached copy of ``x`` whose steps from ``OBS_HORIZON`` onward (along
    axis 1) are overwritten with ``NAN``.

    Returns:
        ``(t, x, x_obs)``

    NOTE(review): the indexing requires a rank-3 batch, presumably
    (batch, time, 1 + channels) — confirm.
    """
    t, x = batch[:, :, 0], batch[:, :, 1:]
    # Copy before masking so the NaN write does not propagate into `x` / `batch`.
    x_obs = x.detach().clone()
    x_obs[:, OBS_HORIZON:, :] = NAN
    return t, x, x_obs

In [None]:
for batch in (pbar := tqdm(sampler)):
    batch = collate_tensor(batch)
    t, x, x_obs = prep(batch)

    # Clear last iteration's gradients: .backward() accumulates into .grad, so without
    # this every optimizer step would apply the running sum of all past gradients.
    optimizer.zero_grad()
    x_hat = model(t, x_obs)
    loss = torch.mean(nd(x_hat, x))
    loss.backward()
    optimizer.step()

    # Logging only; no autograd graph needed.
    with torch.no_grad():
        step = pbar.n  # tqdm's iteration counter, used as the TensorBoard step
        pbar.set_postfix(loss=f"{loss:.2e}")
        kernel = model.system.kernel.detach().cpu()
        writer.add_scalar("train/grad", grad_norm(list(model.parameters())), step)
        writer.add_scalar("train/loss", loss, step)
        writer.add_scalar("loss/nrmse", torch.mean(nrmse(x_hat, x)), step)
        writer.add_histogram("kernel/histogram", model.system.kernel, step)
        writer.add_image("kernel/values", model.system.kernel, step, dataformats="HW")
        writer.add_figure("kernel/spectrum", plot_spectrum(kernel), step)
        writer.add_scalar("kernel/skewpart", skewpart(kernel), step)
        writer.add_scalar("kernel/symmpart", symmpart(kernel), step)
        writer.add_scalar("kernel/det", torch.linalg.det(kernel), step)
        writer.add_scalar("kernel/rank", torch.linalg.matrix_rank(kernel), step)
        writer.add_scalar("kernel/trace", torch.trace(kernel), step)
        writer.add_scalar("kernel/cond", torch.linalg.cond(kernel), step)
        # slogdet returns (sign, log|det|); log the second entry.
        writer.add_scalar("kernel/logdet", torch.linalg.slogdet(kernel)[-1], step)
        # Matrix norms of the kernel: tag -> `ord` argument of matrix_norm
        # (the "-"/"+" in the tag is the sign of `ord`).
        for tag, order in (
            ("kernel/norm-fro", "fro"),
            ("kernel/norm-nuc", "nuc"),
            ("kernel/norm-∞", -ℵ),
            ("kernel/norm-2", -2),
            ("kernel/norm-1", -1),
            ("kernel/norm+1", +1),
            ("kernel/norm+2", +2),
            ("kernel/norm+∞", +ℵ),
        ):
            writer.add_scalar(tag, torch.linalg.matrix_norm(kernel, ord=order), step)

# Plots

- scatter plot of the kernel's spectrum (eigenvalues in the complex plane)
- ratios norm(A)/norm((A+Aᵀ)/2) and norm(A)/norm((A-Aᵀ)/2), to measure how symmetric / skew-symmetric the kernel is


In [None]:
@jit.script
def f(x: list[int]) -> int:
    """Return the sum of the integers in ``x`` (TorchScript list-annotation experiment)."""
    total = 0
    for value in x:
        total += value
    return total


# Experiment: does TorchScript accept a variadic-tuple annotation?
# NOTE(review): TorchScript is not expected to support `tuple[int, ...]`, so this
# decoration will likely raise. It also redefines (shadows) the list-based `f` above.
@jit.script
def f(x: tuple[int, ...]) -> int:
    return sum(x)

In [None]:
import torch
from typing import Optional, Union


def scaled_norm(
    x: torch.Tensor,
    dim: Optional[Union[int, tuple[int, ...]]] = None,
    p: float = 2.0,
    keepdim: bool = False,
) -> torch.Tensor:
    r"""Size-normalized p-norm: ``mean(|x| ** p, dim) ** (1 / p)``.

    The p-th powers are averaged rather than summed, so the result does not grow
    with the number of reduced elements.

    Args:
        x: input tensor.
        dim: dimension(s) to reduce; ``None`` reduces over all dimensions.
        p: order of the norm (assumed > 0).
        keepdim: keep reduced dimensions with size 1.
    """
    if dim is None:
        # Spell out "all dimensions" so that keepdim=True is honoured as well.
        dim = tuple(range(x.ndim))
    # abs() is required for a norm: without it odd/fractional p misbehave on negatives.
    return torch.mean(torch.abs(x) ** p, dim=dim, keepdim=keepdim) ** (1 / p)

In [None]:
def scaled_norm(
    xs: tuple[torch.Tensor, ...],
    p: float = 2.0,
) -> torch.Tensor:
    r"""Combined size-normalized p-norm of several tensors.

    Computes ``(sum_i mean(|xs[i]| ** p)) ** (1 / p)``: each tensor's p-th powers are
    averaged over its own elements, and those means are summed across tensors.
    An empty ``xs`` gives ``tensor(0.)``.
    """
    xs = tuple(xs)
    if len(xs) == 0:
        return torch.tensor(0.0)
    total = sum(torch.mean(torch.abs(x) ** p) for x in xs)
    # Parenthesize the exponent: `** 1 / p` would parse as `(total ** 1) / p`.
    return total ** (1 / p)

In [None]:
from typing import Union, Optional


@jit.script
def _torch_scaled_norm(
    x: Tensor,
    axis: tuple[int, ...] == (),
    p: float = 2,
    keepdims: bool = False,
) -> Tensor:
    axis = () if axis is None else axis

    #     if not _torch_is_float_dtype(x):
    #         x = x.to(dtype=torch.float)
    #     x = torch.abs(x)

    #     if p == 0:
    #         # https://math.stackexchange.com/q/282271/99220
    #         return torch.exp(torch.mean(torch.log(x), dim=axis, keepdim=keepdims))
    #     if p == 1:
    #         return torch.mean(x, dim=axis, keepdim=keepdims)
    #     if p == 2:
    #         return torch.sqrt(torch.mean(x ** 2, dim=axis, keepdim=keepdims))
    #     if p == float("inf"):
    #         return torch.amax(x, dim=axis, keepdim=keepdims)
    #     # other p
    return torch.mean(x**p, dim=axis, keepdim=keepdims) ** (1 / p)

In [None]:
@jit.script
def summed(x: list[Tensor]) -> Tensor:
    r"""Sum of all elements of all tensors in ``x``.

    Each tensor is flattened before concatenation, so the inputs may have any shape
    (including 0-dim) and may differ from each other. An empty list sums to ``tensor(0.)``.
    """
    if len(x) == 0:
        return torch.tensor(0.0)
    return torch.sum(torch.cat([t.reshape(-1) for t in x]))

In [None]:
# NOTE(review): `multi_scaled_norm` is not defined in any cell shown here, so this cell
# raises NameError on a fresh kernel — presumably meant to exercise a multi-tensor
# scaled norm such as the tuple-taking `scaled_norm` defined above.
multi_scaled_norm([torch.randn(2, 5) for _ in range(5)])

In [None]:
# Element count of a 1x2x3x4 tensor (= 24); `nelement` is an alias of `numel`.
torch.randn(1, 2, 3, 4).nelement()

In [None]:
@jit.script
def torch_scaled_norm(
    x: Tensor,
    axis: list[int],
    p: float = 2.0,
) -> Tensor:
    r"""TorchScript size-normalized p-norm over ``axis``: ``mean(|x| ** p) ** (1 / p)``.

    ``abs`` is applied first, so odd or fractional ``p`` behave correctly on negatives.
    """
    return torch.mean(torch.abs(x) ** p, dim=axis) ** (1 / p)

In [None]:
@jit.script
def torch_scaled_norm(
    x: Tensor,
    p: float = 2,
    #     axis: Optional[Union[int, tuple[int, ...]]] = None,
    axis: list[int] = (),
    keepdims: bool = False,
) -> Tensor:
    #     axis = () if axis is None else axis

    #     if not _torch_is_float_dtype(x):
    #         x = x.to(dtype=torch.float)
    #     x = torch.abs(x)

    #     if p == 0:
    #         # https://math.stackexchange.com/q/282271/99220
    #         return torch.exp(torch.mean(torch.log(x), dim=axis, keepdim=keepdims))
    #     if p == 1:
    #         return torch.mean(x, dim=axis, keepdim=keepdims)
    #     if p == 2:
    #         return torch.sqrt(torch.mean(x ** 2, dim=axis, keepdim=keepdims))
    #     if p == float("inf"):
    #         return torch.amax(x, dim=axis, keepdim=keepdims)
    # other p
    return torch.mean(x**p, dim=axis, keepdim=keepdims)

In [None]:
import tsdm

In [None]:
# Compare against the library implementation in tsdm.utils.
# NOTE(review): presumably reduces over all dimensions by default — confirm in tsdm.
tsdm.utils.scaled_norm(torch.randn(2, 3, 4, 5))

In [None]:
%%timeit
# Benchmark: sum of per-parameter gradient L2 norms (not squared).
# NOTE(review): IPython runs %%timeit code inside a function, so this assignment should
# stay local and not overwrite the scripted `grad_norm`; fails if any .grad is None.
grad_norm = sum(w.grad.detach().norm(p=2) for w in model.parameters())

In [None]:
%%timeit
# Benchmark: sum over all parameters of squared gradient entries (squared L2 norm).
# NOTE(review): fails if any parameter's .grad is None.
grad_norm = sum(torch.sum(w.grad**2) for w in model.parameters())

In [None]:
from typing import Iterable

In [None]:
@jit.script
def average_grad_norm(tensors: list[Tensor]) -> Tensor:
    r"""Mean over ``tensors`` of each tensor's squared gradient L2 norm.

    The accumulator takes its device and dtype from the first gradient (nothing is
    hard-coded to CUDA/float32). An empty list yields ``tensor(0.)``.

    NOTE(review): every tensor must have a populated ``.grad``.
    """
    if len(tensors) == 0:
        return torch.tensor(0.0)
    s = torch.sum(tensors[0].grad ** 2)
    for x in tensors[1:]:
        s = s + torch.sum(x.grad**2)
    return s / len(tensors)

In [None]:
@jit.script
def m_norm(tensors: list[Tensor]) -> Tensor:
    r"""Mean over ``tensors`` of each tensor's sum of squares.

    Returns ``sum_i sum(tensors[i] ** 2) / len(tensors)``. The input list is not
    modified. An empty list yields ``tensor(0.)``.
    """
    if len(tensors) == 0:
        return torch.tensor(0.0)
    # Start from the first tensor so the accumulator inherits its dtype and device.
    s = torch.sum(tensors[0] ** 2)
    for x in tensors[1:]:
        s = s + torch.sum(x**2)
    return s / len(tensors)

In [None]:
%%timeit
# Benchmark m_norm on the current gradients.
# NOTE(review): parameters without a gradient contribute None here, which m_norm cannot take.
m_norm([x.grad for x in model.parameters()])

In [None]:
@jit.script
def grad_norm(tensors: list[Tensor]) -> Tensor:
    r"""Mean over ``tensors`` of each tensor's squared gradient L2 norm.

    Returns ``sum_i sum(tensors[i].grad ** 2) / len(tensors)``. Despite the name, no
    square root is taken. The input list is not modified. An empty list yields
    ``tensor(0.)``. This cell redefines the earlier ``grad_norm`` with the same contract.

    NOTE(review): every tensor must have a populated ``.grad`` (e.g. after ``backward()``).
    """
    if len(tensors) == 0:
        return torch.tensor(0.0)
    # Start from the first gradient so the accumulator inherits its dtype and device.
    s = torch.sum(tensors[0].grad ** 2)
    for x in tensors[1:]:
        s = s + torch.sum(x.grad**2)
    return s / len(tensors)

In [None]:
%%timeit
# Benchmark average_grad_norm over all model parameters (requires populated .grad).
average_grad_norm(list(model.parameters()))

In [None]:
# NOTE(review): `gen_MSE` is not defined in any cell shown here, so this cell raises
# NameError on a fresh kernel.
gen_MSE([tensor(2) for _ in range(3)])