In [1]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
import os
from pathlib import Path
from time import perf_counter, time
from typing import Any, NamedTuple

import numpy as np

# enable JIT compilation - must be done before loading torch!
os.environ["PYTORCH_JIT"] = "1"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is usually a debugging aid that makes CUDA
# kernel launches synchronous - confirm it is intended for full training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# logging.basicConfig(level=logging.INFO)

In [2]:
import torch
import torchinfo

# Run configuration.
# BATCH_SIZE: batch size for the task and for the train / eval data loaders.
# TARGET:     name of the target passed to the task.
# SPLIT:      first component of the ``(SPLIT, "train" | "test")`` dataloader key.
BATCH_SIZE = 128
TARGET = "OD600"
SPLIT = 0

# Used as a path component of the checkpoint and TensorBoard directories.
RUN_NAME = f"{TARGET}-{SPLIT}-More_params"  # | input("enter name for run")

In [27]:
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas
from linodenet.models import LinODE, LinODECell, LinODEnet
from linodenet.models.filters import SequentialFilter
from linodenet.projections.functional import skew_symmetric, symmetric
from numpy.typing import NDArray
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
from torch import Tensor, jit, nn, tensor
from torch.optim import SGD, Adam, AdamW
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm, trange
from typing import Union, Optional
from collections.abc import *

import tsdm
from tsdm.datasets import DATASETS
from tsdm.encoders.functional import time2float
from tsdm.encoders.modular import *
from tsdm.logutils import (
    log_kernel_information,
    log_metrics,
    log_model_state,
    log_optimizer_state,
)
from tsdm.losses import LOSSES
from tsdm.random.samplers import *
from tsdm.tasks import KIWI_FINAL_PRODUCT
from tsdm.utils import grad_norm, multi_norm
from tsdm.utils.strings import *

np.set_printoptions(4, linewidth=80)

In [4]:
# Disable cuDNN benchmarking (autotuning): the inputs are variable sized, so the
# flag must be False to match the stated intent (it was previously set to True).
torch.backends.cudnn.benchmark = False

# Initialize Task

In [5]:
# Prefer the GPU when one is available; DTYPE is later passed to ``model.to``.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32

# The task is configured with the run-level batch size and target.
task = KIWI_FINAL_PRODUCT(
    train_batch_size=BATCH_SIZE,
    eval_batch_size=2048,
    target=TARGET,
)

# Shorthands for the task's dataset, time series table and metadata table.
DATASET = task.dataset
ts = task.timeseries
md = task.metadata
# ``ts`` is two-dimensional: rows x columns.
NUM_PTS, NUM_DIM = ts.shape

In [6]:
import pandas as pd
from pandas.core.indexes.frozen import FrozenList

# Reuse the already configured ``task`` here: re-instantiating it as
# ``KIWI_FINAL_PRODUCT()`` would silently drop ``target=TARGET`` and the batch
# sizes, while later cells still rely on ``task.target`` and ``task.get_dataloader``.
ts = task.timeseries.sort_index(axis="index").sort_index(axis="columns")

# Fraction of non-missing observations per channel, in ascending order.
channel_freq = pd.notna(ts).mean().sort_values()

# Channels observed in at least 10% of the rows are "fast", the rest are "slow".
fast_channels = FrozenList(channel_freq[channel_freq >= 0.1].index)
slow_channels = FrozenList(channel_freq[channel_freq < 0.1].index)

# Per-group views without the rows in which every channel of the group is missing.
FAST = ts[fast_channels].dropna(how="all")
SLOW = ts[slow_channels].dropna(how="all")
groups = {"fast": fast_channels, "slow": slow_channels}

In [28]:
class FrameSplitter(BaseEncoder, Mapping):
    r"""Split a DataFrame into multiple groups.

    The special value ``...`` (:class:`Ellipsis`) can be used to indicate
    that all other columns belong to this group.

    This function can be used on index columns as well.

    The encoder itself behaves like a read-only mapping ``group key -> columns``.

    Notes
    -----
    :meth:`decode` restores the original column order with ``inverse_permutation``.
    This is only a true inverse when the groups cover every column of the
    index-reset frame (index levels included) exactly once, e.g. by adding an
    ``...`` group.
    """

    column_columns: Index
    column_dtypes: Series
    column_indices: list[int]

    index_columns: Index
    index_dtypes: Series
    index_indices: list[int]

    # FIXME: Union[types.EllipsisType, set[Hashable]] in 3.10
    groups: dict[Hashable, Union[Hashable, list[Hashable]]]
    group_indices: dict[Hashable, list[int]]

    # position -> column name of the index-reset frame seen during ``fit``
    indices: dict[int, Hashable]
    has_ellipsis: bool = False
    ellipsis: Optional[Hashable] = None

    permutation: list[int]
    inverse_permutation: list[int]

    def __init__(self, groups: Iterable[Hashable]) -> None:
        r"""Store the group specification.

        Parameters
        ----------
        groups
            Either a mapping ``key -> columns`` or a plain iterable of column
            specifications, which is then keyed by position. Each specification
            is a single column name, an iterable of column names, or ``...``.
        """
        super().__init__()

        if not isinstance(groups, Mapping):
            groups = dict(enumerate(groups))

        self.groups = {}
        for key, obj in groups.items():
            if obj is Ellipsis:
                # resolved to concrete column names during ``fit``
                self.groups[key] = obj
                self.ellipsis = key
                self.has_ellipsis = True
            elif isinstance(obj, str) or not isinstance(obj, Iterable):
                # a single column name (strings are iterable but mean one column)
                self.groups[key] = [obj]
            else:
                self.groups[key] = list(obj)

    def __repr__(self):
        r"""Return a string representation of the object."""
        return repr_mapping(self)

    def __len__(self):
        r"""Return the number of groups."""
        return len(self.groups)

    def __iter__(self):
        r"""Iterate over the groups."""
        return iter(self.groups)

    def __getitem__(self, item):
        r"""Return the group."""
        return self.groups[item]

    def fit(self, data: DataFrame, /) -> None:
        r"""Fit the encoder.

        Records the dtypes and names of the data columns and of the index levels,
        resolves the ``...`` group, and computes the column permutation induced
        by the groups together with its inverse.

        Raises
        ------
        AssertionError
            If an index level and a data column share the same name.
        KeyError
            If a group refers to a column that is not present in ``data``.
        """
        index = data.index.to_frame()
        self.column_dtypes = data.dtypes
        self.column_columns = data.columns
        self.index_columns = index.columns
        self.index_dtypes = index.dtypes

        assert not (
            j := set(self.index_columns) & set(self.column_columns)
        ), f"index columns and data columns must be disjoint {j}!"

        # ``reset_index`` already returns a new frame; index levels come first.
        data = data.reset_index()

        def get_idx(cols: Any) -> list[int]:
            return [data.columns.get_loc(i) for i in cols]

        self.indices = dict(enumerate(data.columns))
        self.group_indices = {}
        self.column_indices = get_idx(self.column_columns)
        self.index_indices = get_idx(self.index_columns)

        # replace ellipsis indices
        if self.has_ellipsis:
            # FIXME EllipsisType in 3.10
            # Exclude the ellipsis group by *key*: after a previous ``fit`` it holds
            # a concrete column list, and counting that as "fixed" would leave the
            # ellipsis group empty on every refit.
            fixed_cols = set().union(
                *(
                    set(cols)  # type: ignore[arg-type]
                    for key, cols in self.groups.items()
                    if key != self.ellipsis and cols is not Ellipsis
                )
            )
            ellipsis_columns = [c for c in data.columns if c not in fixed_cols]
            self.groups[self.ellipsis] = ellipsis_columns

        # set column indices
        self.permutation = []
        for group, columns in self.groups.items():
            if columns is Ellipsis:
                continue
            self.group_indices[group] = get_idx(columns)
            self.permutation += self.group_indices[group]
        self.inverse_permutation = np.argsort(self.permutation).tolist()

    def encode(self, data: DataFrame, /) -> tuple[DataFrame, ...]:
        r"""Encode the data.

        Returns one entry per group, in group order. A group consisting of a
        single column is squeezed to a :class:`~pandas.Series`.

        Raises
        ------
        AssertionError
            If ``data`` contains columns that were not seen during ``fit``.
        """
        # ``reset_index`` returns a new frame with the index prepended as columns.
        data = data.reset_index()
        data_columns = set(data.columns)
        # compare against column *names* (the values), not the integer positions
        known_columns = set(self.indices.values())

        assert data_columns <= known_columns, (
            f"Unknown columns {data_columns - known_columns}. "
            "If you want to encode unknown columns add a group ``...`` (Ellipsis)."
        )

        return tuple(
            data[columns].squeeze(axis="columns") for columns in self.groups.values()
        )

    def decode(self, data: tuple[DataFrame, ...], /) -> DataFrame:
        r"""Decode the data.

        Concatenates the groups column-wise, restores the column order recorded
        in ``fit``, casts back to the recorded dtypes and re-installs the index.
        """
        frames = tuple(DataFrame(x) for x in data)
        joined = pd.concat(frames, axis="columns")

        # unshuffle the columns, restoring original order
        joined = joined.iloc[:, self.inverse_permutation]

        # Assemble the columns. They are kept as a DataFrame even for a single
        # column, because ``set_index`` below is a DataFrame-only method.
        columns = joined.iloc[:, self.column_indices]
        columns.columns = self.column_columns
        columns = columns.astype(self.column_dtypes)

        # assemble the index
        index = joined.iloc[:, self.index_indices]
        index.columns = self.index_columns
        index = index.astype(self.index_dtypes)
        index = index.squeeze(axis="columns")

        if isinstance(index, Series):
            # single index level
            decoded = columns.set_index(index)
        else:
            decoded = columns.set_index(pd.MultiIndex.from_frame(index))
        return decoded

In [7]:
enc = FrameSplitter({"a": fast_channels, "b": slow_channels})

In [19]:
torch.tensor(ts.DOT.values)

In [16]:
torch.tensor(ts.DOT)

In [12]:
# ``encode`` reads state that is only created by ``fit``, so fit the splitter first.
enc.fit(ts)
enc.encode(ts)[0]

In [None]:
# Sketch of a per-group encoder: fast channels are standardized, slow channels are
# min-max scaled, and each named index level gets its own encoder.
# Note that this rebinds ``enc``.
enc = FrameEncoder(
    column_encoders={
        fast_channels: Standardizer(),
        slow_channels: MinMaxScaler(),
    },
    index_encoders={
        "run_id": IntEncoder(),
        "experiment_id": IntEncoder(),
        "measurement_time": TimeDeltaEncoder(),
    },
)

In [16]:
# Encoding pipeline used by the collate function and for decoding predictions.
# NOTE(review): the order in which ChainedEncoder applies its arguments is not
# visible here - presumably the last argument is applied first; confirm.
encoder = ChainedEncoder(
    # convert to tensors
    TensorEncoder(),
    # split into "value", "measurement_time" and one group with all remaining columns
    FrameSplitter(["value", "measurement_time", ...]),
    # FrameSplitter([...]),
    TripletEncoder(),
    # standardize the columns; the index is converted with a seconds-based
    # TimeDeltaEncoder composed with a MinMaxScaler
    FrameEncoder(
        Standardizer(),
        index_encoders=MinMaxScaler() @ TimeDeltaEncoder(unit="s"),
    ),
)

## Encoded SetFuncTS

- We partition the channels $P = \{pₖ∣k=1:K\}$, where $pₖ≠∅$, $pᵢ∩pⱼ = ∅$ and $⋃ₖpₖ = \{1,…, n\}$

- The model receives a separate tensor per group and encodes each of them into a shared dimension.

- Special attention to the case when a group has a single channel: do we still pass the (useless) one-hot indicator?

- When using weight sharing, all indicator variables need to be passed!



In [12]:
demo = ts.reset_index([0, 1], drop=True)

In [9]:
encoder.fit(demo)

In [10]:
encoded = encoder.encode(demo)

In [11]:
decoded = encoder.decode(encoded)

In [12]:
# pandas.testing.assert_frame_equal(demo, decoded, atol=1e-5)

In [13]:
encoder[-1].column_encoders[0]

## Define Batching Function

In [14]:
# Position of the target channel among the columns of ``ts``. ``ts`` was re-sorted
# by column name in an earlier cell, so this position refers to the sorted order.
target_index = ts.columns.get_loc(task.target)
# Column encoder of the last pipeline stage that belongs to the target channel;
# it is used on its own to encode targets and to decode predictions.
target_encoder = encoder[-1].column_encoders[target_index]

In [15]:
class Batch(NamedTuple):
    r"""A single collated batch as produced by ``mycollate``."""

    # one encoded time series per sample
    # NOTE(review): the element type depends on what ``encoder.encode`` returns - confirm.
    timeseries: list[Tensor]
    # raw targets, stacked with ``np.stack``
    targets: NDArray
    # encoded targets, built with ``torch.tensor`` and later moved with ``.to(device=...)``
    encoded_targets: Tensor


@torch.no_grad()
def mycollate(samples: list[tuple]) -> tuple[list[Tensor], NDArray, NDArray]:
    timeseries = []
    targets = []
    encoded_targets = []

    for idx, (ts_data, md_data), target, originals in samples:
        timeseries.append(encoder.encode(ts_data))
        targets.append(target)
        encoded_targets.append(target_encoder.encode(target))

    # timeseries = torch.cat(timeseries)
    targets = np.stack(targets)
    encoded_targets = torch.tensor(encoded_targets)

    return Batch(timeseries, targets, encoded_targets)

In [16]:
# Smoke test: draw one small batch from a loader that has no custom ``collate_fn``
# and run ``mycollate`` on it by hand.
# NOTE(review): assumes such a loader yields a plain list of samples - confirm.
dloader = task.get_dataloader((SPLIT, "train"), batch_size=6, shuffle=True)
batch = next(iter(dloader))
sample = batch[0]
mycollate(batch)

## Initialize Loss & Metrics

In [17]:
# Training objective: the task's own test metric, moved to the compute device.
LOSS = task.test_metric.to(device=DEVICE)
# Additional metrics for logging: each loss class is instantiated and scripted.
metrics = {key: jit.script(LOSSES[key]()) for key in ("RMSE", "MSE", "MAE")}

## Initialize DataLoaders

In [18]:
task.final_product_times

In [19]:
md["target_value"]

In [20]:
# Options shared by both loaders; only the split, shuffling and ``drop_last`` differ.
_LOADER_OPTIONS = dict(
    batch_size=BATCH_SIZE,
    collate_fn=mycollate,
    pin_memory=True,
    num_workers=8,
    # num_workers=os.cpu_count() // 4,
    persistent_workers=True,
)

# Training split: shuffled, incomplete final batch dropped.
TRAINLOADER = task.get_dataloader(
    (SPLIT, "train"), drop_last=True, shuffle=True, **_LOADER_OPTIONS
)

# Test split: fixed order, every sample kept.
EVALLOADER = task.get_dataloader(
    (SPLIT, "test"), drop_last=False, shuffle=False, **_LOADER_OPTIONS
)

## Hyperparameters

In [21]:
# Optimizer hyperparameters. "__name__" is the key under which the optimizer class
# is looked up; the remaining entries are passed on as keyword arguments.
OPTIMIZER_CONFIG = {
    "__name__": "AdamW",
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 0.001,
    "amsgrad": False,
}

# Learning-rate scheduler hyperparameters (same "__name__" convention).
# NOTE(review): no scheduler is created from this config in the visible cells.
LR_SCHEDULER_CONFIG = {
    "__name__": "ReduceLROnPlateau",
    "mode": "min",
    # (str) One of min, max. In min mode the lr is reduced when the monitored
    # quantity has stopped decreasing; in max mode when it has stopped increasing.
    # Default: 'min'.
    "factor": 0.1,
    # (float) Factor by which the learning rate is reduced: new_lr = lr * factor.
    # Default: 0.1.
    "patience": 10,
    # (int) Number of epochs with no improvement after which the learning rate is
    # reduced. E.g. with patience = 2 the first 2 epochs without improvement are
    # ignored and the lr is only decreased after the 3rd such epoch. Default: 10.
    "threshold": 0.0001,
    # (float) Threshold for measuring the new optimum, to only focus on
    # significant changes. Default: 1e-4.
    "threshold_mode": "rel",
    # (str) One of rel, abs. In rel mode, dynamic_threshold = best * (1 + threshold)
    # in 'max' mode or best * (1 - threshold) in 'min' mode. In abs mode,
    # dynamic_threshold = best + threshold in 'max' mode or best - threshold in
    # 'min' mode. Default: 'rel'.
    "cooldown": 0,
    # (int) Number of epochs to wait before resuming normal operation after the lr
    # has been reduced. Default: 0.
    "min_lr": 1e-08,
    # (float or list) A scalar or a list of scalars. A lower bound on the learning
    # rate of all param groups or of each group respectively. Default: 0.
    "eps": 1e-08,
    # (float) Minimal decay applied to lr. If the difference between new and old lr
    # is smaller than eps, the update is ignored. Default: 1e-8.
    "verbose": True
    # (bool) If True, prints a message to stdout for each update. Default: False.
}

## Model

In [22]:
from tsdm.models import SetFuncTS

# ``MODEL`` keeps the class itself; its ``__name__`` is used for run directories.
MODEL = SetFuncTS
# NOTE(review): 17 and 1 are presumably the input and output sizes - confirm, and
# consider deriving the input size from the data instead of hard-coding it.
model = MODEL(17, 1, latent_size=256, dim_keys=128, dim_vals=128, dim_deepset=128)
model.to(device=DEVICE, dtype=DTYPE)
# Display the torchinfo summary as the cell output.
summary(model)

### Warmup - test forward / backward pass

In [23]:
# Run a single forward and backward pass on one training batch to check that the
# model works end to end, then discard the gradients again.
batch = next(iter(TRAINLOADER))
y = model.forward_batch(batch.timeseries)
# The norm of the output serves as a scalar to backpropagate from.
torch.linalg.norm(y).backward()
model.zero_grad()

## Initialize Optimizer

In [24]:
from tsdm.optimizers import LR_SCHEDULERS, OPTIMIZERS
from tsdm.utils import initialize_from

In [25]:
# Pass the parameters as a separate keyword instead of merging them into
# OPTIMIZER_CONFIG: the config stays a plain record of hyperparameters and does
# not end up holding a single-use parameter generator.
optimizer = initialize_from(
    OPTIMIZERS, **OPTIMIZER_CONFIG, params=model.parameters()
)

## Load Existing Model instead

In [26]:
# # load existing model & optimizer
# from pathlib import Path

# model_path = Path(
#     "/home/rscholz/Projects/KIWI/tsdm/examples/"
#     "checkpoints/SetFuncTS/KIWI_RUNS/OD600-0-More_params/2022-02-07T00:55:25"
# )
# model_name = "SetFuncTS"
# optim_name = "AdamW"
# model_versions = {}
# optim_versions = {}

# for file in model_path.iterdir():
#     name, version = file.stem.split("-")
#     if name == model_name:
#         model_versions[int(version)] = file
#     if name == optim_name:
#         optim_versions[int(version)] = file

# MODEL = model_versions[max(model_versions)]
# model = jit.load(MODEL)
# MODEL = SetFuncTS
# OPTIM = optim_versions[max(optim_versions)]
# optim = torch.load(OPTIM)

# epoch = optim["epoch"]
# batch_num = optim["batch"]
# optim = optim["optimizer"]

## Logging Utilities

In [27]:
from tsdm.logutils import compute_metrics

# def log_all(i, model, writer, optimizer):
#     kernel = model.system.kernel.clone().detach().cpu()
#     log_kernel_information(i, writer, kernel, histograms=True)
#     log_optimizer_state(i, writer, optimizer, histograms=True)


def log_hparams(i, writer, *, metric_dict, hparam_dict):
    r"""Log hyperparameters together with their metrics.

    Parameters
    ----------
    i
        Current epoch; stored in the logged hyperparameters under ``"epoch"``.
    writer
        Object providing ``add_hparams(hparam_dict=..., metric_dict=...)``.
    metric_dict
        Metrics to log; passed through ``add_prefix(..., "hparam")`` first.
    hparam_dict
        Hyperparameters to log. The caller's dictionary is left unmodified.
    """
    # Build a new dict instead of using ``|=``, which would write the "epoch"
    # entry into the dictionary owned by the caller.
    hparams = {**hparam_dict, "epoch": i}
    prefixed_metrics = add_prefix(metric_dict, "hparam")
    writer.add_hparams(hparam_dict=hparams, metric_dict=prefixed_metrics)

In [28]:
# Start time of this run; it makes the output directories unique per run.
RUN_START = tsdm.utils.now()

# Checkpoints and TensorBoard logs share the same model/dataset/run/start layout.
_RUN_SUBPATH = f"{MODEL.__name__}/{DATASET.name}/{RUN_NAME}/{RUN_START}"

CHECKPOINTDIR = Path(f"checkpoints/{_RUN_SUBPATH}")
CHECKPOINTDIR.mkdir(parents=True, exist_ok=True)

LOGGING_DIR = f"runs/{_RUN_SUBPATH}"
writer = SummaryWriter(LOGGING_DIR)

# Training

In [29]:
@torch.no_grad()
def get_all_predictions(model, dataloader):
    ys = []
    yhats = []
    for batch in tqdm(dataloader, leave=False):
        # ts = batch.timeseries
        # inputs = [(t.to(device=DEVICE),v.to(device=DEVICE), m.to(device=DEVICE)) for t,v,m in ts]
        # yhats.append(model.batch_forward(inputs))
        yhats.append(model.forward_batch(batch.timeseries))
        ys.append(batch.encoded_targets.to(device=DEVICE))
    y = torch.cat(ys).cpu().numpy()
    yhat = torch.cat(yhats).cpu().numpy()
    y = torch.tensor(target_encoder.decode(y))
    yhat = torch.tensor(target_encoder.decode(yhat))
    return y, yhat

In [30]:
# Global counters: ``i`` counts optimizer steps (batches), ``epoch`` counts epochs.
i = 0  # batch_num
epoch = 0  # epoch

# Baseline: log the metrics of the untrained model on both splits at step ``epoch``.
with torch.no_grad():
    for key, dloader in {"train": TRAINLOADER, "test": EVALLOADER}.items():
        y, ŷ = get_all_predictions(model, dloader)
        # sanity check on the decoded targets
        assert torch.isfinite(y).all()
        log_metrics(
            epoch, writer=writer, metrics=metrics, targets=y, predics=ŷ, prefix=key
        )

In [None]:
# Main training loop: 1000 epochs, continuing from the current value of ``epoch``.
for epoch in (epochs := range(epoch, epoch + 1000)):
    # NOTE(review): when starting from epoch 0 the range ends at 999, so this
    # manual learning-rate drop only triggers on a resumed run - confirm intent.
    if epoch == 1000:
        for g in optimizer.param_groups:
            g["lr"] = 0.0001
    batching_time = perf_counter()
    for batch in (batches := tqdm(TRAINLOADER, leave=False)):
        # time spent waiting for the data loader
        batching_time = perf_counter() - batching_time
        i += 1
        # Optimization step
        model.zero_grad(set_to_none=True)
        targets = batch.encoded_targets.to(device=DEVICE)

        # forward
        forward_time = perf_counter()
        predics = model.forward_batch(batch.timeseries)
        loss = LOSS(targets, predics)
        forward_time = perf_counter() - forward_time

        # Abort *before* backward / step: checking only after ``optimizer.step()``
        # would let a non-finite loss update the parameters first.
        if torch.any(~torch.isfinite(loss)):
            raise RuntimeError("NaN/INF-value encountered!!")

        # backward
        backward_time = perf_counter()
        loss.backward()
        backward_time = perf_counter() - backward_time

        # step
        optimizer.step()

        # batch logging
        with torch.no_grad():
            # same clock as the other timers (was ``time()``)
            logging_time = perf_counter()

            log_metrics(
                i,
                writer=writer,
                metrics=metrics,
                targets=targets.clone(),
                predics=predics.clone(),
                prefix="batch",
            )
            log_optimizer_state(i, writer=writer, optimizer=optimizer, prefix="batch")

            # lval = loss.clone().detach().cpu().numpy()
            # gval = grad_norm(list(model.parameters())).clone().detach().cpu().numpy()
            logging_time = perf_counter() - logging_time

        batches.set_postfix(
            # loss=f"{lval:.2e}",
            # gnorm=f"{gval:.2e}",
            epoch=epoch,
            Δt_forward=f"{forward_time:.1f}",
            Δt_backward=f"{backward_time:.1f}",
            Δt_logging=f"{logging_time:.1f}",
            Δt_batching=f"{batching_time:.1f}",
        )
        # restart the batching timer for the next iteration
        batching_time = perf_counter()

    with torch.no_grad():
        # end-of-epoch logging
        for key, dloader in {"train": TRAINLOADER, "test": EVALLOADER}.items():
            y, ŷ = get_all_predictions(model, dloader)
            assert torch.isfinite(y).all()
            log_metrics(
                epoch, writer=writer, metrics=metrics, targets=y, predics=ŷ, prefix=key
            )

        # Model Checkpoint: one model file and one optimizer file per epoch.
        torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epoch}"))
        torch.save(
            {
                "optimizer": optimizer,
                "optimizer_state": optimizer.state_dict(),
                "epoch": epoch,
                "batch": i,
            },
            CHECKPOINTDIR.joinpath(f"{optimizer.__class__.__name__}-{epoch}"),
        )

In [None]:
predics

In [None]:
raise StopIteration

# Post Training Analysis

# Profiling

In [None]:
from torch.profiler import ProfilerActivity, profile, record_function

In [None]:
# Profile a single forward pass on CPU and CUDA, recording input shapes and
# memory usage.
# NOTE(review): ``times`` and ``inputs`` are not defined in the visible part of the
# notebook - presumably left over from an earlier version; confirm before running.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    model(times, inputs)

In [None]:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))