In [1]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
import os
from pathlib import Path
from time import perf_counter, time
from typing import Any, NamedTuple

import numpy as np

# Environment flags. They are set in this cell, ahead of the `import torch` in
# the next cell, because they must be in place before torch is loaded.
# Enable JIT compilation.
os.environ["PYTORCH_JIT"] = "1"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous,
# which is a debugging aid and usually slows training - confirm it is intended.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# logging.basicConfig(level=logging.INFO)

In [2]:
import torch
import torchinfo

BATCH_SIZE = 128  # batch size handed to the task and to the train / eval dataloaders
TARGET = "OD600"  # target handed to the KIWI_FINAL_PRODUCT task
SPLIT = 0  # key of the split used with task.splits and task.get_dataloader

# Name of this run; it becomes part of the checkpoint and tensorboard paths.
RUN_NAME = f"CONVNET-ReZero-{TARGET}-{SPLIT}"  # | input("enter name for run")

In [3]:
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas
import torch
import pandas as pd
from pandas.core.indexes.frozen import FrozenList

from linodenet.models import LinODE, LinODECell, LinODEnet
from linodenet.models.filters import SequentialFilter
from linodenet.projections.functional import skew_symmetric, symmetric
from numpy.typing import NDArray
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
from torch import Tensor, jit, nn, tensor
from torch.optim import SGD, Adam, AdamW
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm, trange
from tsdm.logutils import compute_metrics

import tsdm
from tsdm.datasets import DATASETS
from tsdm.encoders.functional import time2float
from tsdm.encoders.modular import *
from tsdm.logutils import (
    log_kernel_information,
    log_metrics,
    log_model_state,
    log_optimizer_state,
)
from tsdm.losses import LOSSES
from tsdm.random.samplers import *
from tsdm.tasks import KIWI_FINAL_PRODUCT
from tsdm.utils import grad_norm, multi_norm
from tsdm.utils.strings import *

np.set_printoptions(4, linewidth=80)

In [4]:
# Disable cuDNN benchmarking for variable sized input: with benchmarking on,
# cuDNN re-runs its algorithm search for every new input shape.
torch.backends.cudnn.benchmark = False

# Allow TF32 on matmul. Recent PyTorch releases (>= 1.12) default this flag to
# False, so it is switched on explicitly here.
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Initialize Task

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32

# The task object provides the dataset, the splits and the dataloaders used below.
task = KIWI_FINAL_PRODUCT(
    train_batch_size=BATCH_SIZE,
    eval_batch_size=2048,
    target=TARGET,
)

# Handles on the task's dataset and its complete timeseries / metadata tables.
# NOTE: `ts` and `md` are rebound to the training split in the following cell;
# NUM_PTS / NUM_DIM are taken from the complete table before that happens.
DATASET = task.dataset
ts = task.timeseries
md = task.metadata
NUM_PTS, NUM_DIM = ts.shape

In [6]:
# Rebind `ts` / `md` to the "train" part of the selected split; the channel
# statistics and the encoder fit in the following cells are computed from it.
ts, md = task.splits[SPLIT, "train"]

In [7]:
# Fraction of non-missing entries per channel, sorted in ascending order.
channel_freq = ts.notna().mean().sort_values()

# Channels observed in at least 10% of the rows are "fast", all others "slow".
fast_channels = FrozenList(channel_freq.loc[channel_freq >= 0.1].index)
slow_channels = FrozenList(channel_freq.loc[channel_freq < 0.1].index)
FAST_DIM, SLOW_DIM = len(fast_channels), len(slow_channels)
groups = {"slow": slow_channels, "fast": fast_channels}

# Per-group tables without the rows that are empty in every channel of the group.
FAST = ts[fast_channels].dropna(how="all")
SLOW = ts[slow_channels].dropna(how="all")

# Displayed as the cell output.
fast_channels, slow_channels

In [8]:
# Encoding pipeline that turns one time series frame into the model inputs
# (its output is unpacked into the model call in the warmup cell).
# NOTE(review): the target encoder is later taken from `encoder[-1]`, i.e. the
# Standardizer, which suggests the chain is applied from the last entry to the
# first - confirm against ChainedEncoder.
encoder = ChainedEncoder(
    TensorEncoder(),
    # one encoder per channel group (`|` combines the two) - presumably slow | fast; verify
    ValueEncoder("float32") | ValueEncoder("float32"),
    TripletEncoder() | TripletEncoder(),
    # split the frame into the "slow" and "fast" channel groups defined above
    FrameSplitter(groups, dropna=True),
    FrameEncoder(
        index_encoders={
            # time index: TimeDeltaEncoder with unit seconds composed with a MinMaxScaler
            "measurement_time": MinMaxScaler() @ TimeDeltaEncoder(unit="s")
        },
    ),
    Standardizer(),
)

In [9]:
# Fit the pipeline on the training series with index levels 0 and 1 dropped.
encoder.fit(ts.reset_index([0, 1], drop=True))
# Position of the target column within the complete timeseries table.
target_idx = task.timeseries.columns.get_loc(task.target)
# Encoder for the target alone: the target's entry of the last chain element
# (the Standardizer) composed with a TensorEncoder.
target_encoder = TensorEncoder() @ encoder[-1][target_idx]

In [10]:
# Round-trip sanity check on one hard-coded series, key (439, 15325):
# decoding the encoded frame must reproduce the original up to atol=1e-3.
encoded = encoder.encode(ts.loc[439, 15325])
decoded = encoder.decode(encoded)
pd.testing.assert_frame_equal(ts.loc[439, 15325], decoded, atol=1e-3)

## Define Batching Function

In [11]:
class Batch(NamedTuple):
    """A collated mini-batch as produced by `mycollate`."""

    # index of every sample in the batch
    idx: list
    # one `encoder.encode` result per sample
    timeseries: list[Tensor]
    # raw targets, stacked with `np.stack`
    targets: NDArray
    # encoded targets, converted with `torch.tensor`
    encoded_targets: Tensor
    # the original (un-encoded) time series of every sample
    ts_data: Any


@torch.no_grad()
def mycollate(samples: list[tuple]) -> Batch:
    """Collate a list of task samples into a `Batch`.

    Parameters
    ----------
    samples:
        Tuples of the form ``(idx, (ts_data, md_data), target, _)``; the
        metadata and the last entry are not used.

    Returns
    -------
    Batch
        The sample indices, the encoded time series (kept as a list, they are
        not concatenated), the stacked raw targets, the encoded targets as a
        tensor and the original time series.
    """
    index = []
    timeseries = []
    targets = []
    encoded_targets = []
    originals = []

    for idx, (ts_data, _md_data), target, _ in samples:
        index.append(idx)
        timeseries.append(encoder.encode(ts_data))
        targets.append(target)
        encoded_targets.append(target_encoder.encode(target))
        originals.append(ts_data)

    targets = np.stack(targets)
    encoded_targets = torch.tensor(encoded_targets)

    # Return the collected indices of ALL samples, not the loop variable `idx`
    # (which only holds the index of the last sample).
    return Batch(index, timeseries, targets, encoded_targets, originals)

In [12]:
# Smoke test: run the collate function on the first training batch.
# Uses BATCH_SIZE (instead of a second hard-coded 128) so that this check
# always matches the batch size of the real dataloaders.
dloader = task.get_dataloader((SPLIT, "train"), batch_size=BATCH_SIZE, shuffle=False)
batch = next(iter(dloader))
mycollate(batch).timeseries

## Initialize Loss & Metrics

In [13]:
# Training loss: the task's test metric, moved to the training device.
LOSS = task.test_metric.to(device=DEVICE)
# Additional metrics for logging: TorchScript-compiled instances from the LOSSES table.
metrics = {key: jit.script(LOSSES[key]()) for key in ("RMSE", "MSE", "MAE")}

## Initialize DataLoaders

In [14]:
# Training loader: shuffled batches of the "train" part of the split, collated
# with `mycollate`; an incomplete last batch is dropped.
TRAINLOADER = task.get_dataloader(
    (SPLIT, "train"),
    batch_size=BATCH_SIZE,
    collate_fn=mycollate,
    pin_memory=True,
    drop_last=True,
    shuffle=True,
    num_workers=8,
    # num_workers=os.cpu_count() // 4,
    persistent_workers=True,
)


# Evaluation loader: the "test" part of the split in fixed order, keeping the
# incomplete last batch. It uses BATCH_SIZE, not the eval_batch_size given to the task.
EVALLOADER = task.get_dataloader(
    (SPLIT, "test"),
    batch_size=BATCH_SIZE,
    collate_fn=mycollate,
    pin_memory=True,
    drop_last=False,
    shuffle=False,
    num_workers=8,
    # num_workers=os.cpu_count() // 4,
    persistent_workers=True,
)

## Hyperparameters

In [15]:
# Hyper-parameters of the optimizer. "__name__" selects the optimizer class
# (looked up in the OPTIMIZERS table when the optimizer is built); the other
# entries are forwarded as keyword arguments.
OPTIMIZER_CONFIG = {
    "__name__": "AdamW",
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 0.001,
    "amsgrad": False,
}

# Settings for a ReduceLROnPlateau learning-rate scheduler; each comment
# describes the key directly below it.
# NOTE(review): no scheduler is created from this dict in the visible cells -
# confirm whether it is still needed.
LR_SCHEDULER_CONFIG = {
    "__name__": "ReduceLROnPlateau",
    # (str) - One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: 'min'.
    "mode": "min",
    # (float) - Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
    "factor": 0.1,
    # (int) - Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn't improved then. Default: 10.
    "patience": 10,
    # (float) - Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
    "threshold": 0.0001,
    # (str) - One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in 'max' mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: 'rel'.
    "threshold_mode": "rel",
    # (int) - Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
    "cooldown": 0,
    # (float or list) - A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
    "min_lr": 1e-08,
    # (float) - Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
    "eps": 1e-08,
    # (bool) - If True, prints a message to stdout for each update. Default: False.
    "verbose": True
}

## Model

In [16]:
from tsdm.encoders.modular.torch import Time2Vec
from tsdm.models.generic import MLP
from tsdm.models.generic.rezero import ReZero, ReZeroMLP
from tsdm.models.generic.conv1d import ConvBlock
from tsdm.models.set_function_for_timeseries import GroupedSetFuncTS


# Size of the time embedding; also handed to the model as `dim_time`.
dim_time = 6
# Channel count used for every ConvBlock of the fast encoder.
# NOTE(review): the "+ 2 + dim_time - 1" offset is presumably dictated by how the
# model assembles its encoder inputs - confirm against GroupedSetFuncTS.
dim_conv = FAST_DIM + 2 + dim_time - 1

# Encoder for the slow channel group (second argument 16 matches `input_size` below).
slow_encoder = ReZeroMLP(SLOW_DIM + 2 + dim_time - 1, 16)

# Encoder for the fast channel group: three ConvBlocks wrapped in ReZero,
# followed by average pooling with kernel size 16.
fast_encoder = nn.Sequential(
    # nn.Conv1d(
    #     16, 16, kernel_size=8, padding="same", padding_mode="replicate"
    # ),
    ReZero(ConvBlock(dim_conv), ConvBlock(dim_conv), ConvBlock(dim_conv)),
    nn.AvgPool1d(16),
)

# MODEL keeps the class (its __name__ is used in the checkpoint / log paths),
# `model` is the instance that gets trained.
MODEL = GroupedSetFuncTS
model = GroupedSetFuncTS(
    input_size=16,
    output_size=1,
    fast_encoder=fast_encoder,
    slow_encoder=slow_encoder,
    latent_size=128,
    dim_keys=64,
    dim_vals=64,
    dim_deepset=64,
    dim_time=dim_time,
)
model.to(device=DEVICE, dtype=DTYPE)
# Displayed as the cell output: torchinfo summary of the model.
summary(model)

### Warmup - test forward / backward pass

In [17]:
# One forward and backward pass on the first 100 rows of a hard-coded series
# (key (439, 15325)) to check that the model runs end to end.
data = encoder.encode(ts.loc[439, 15325].iloc[:100])
model.zero_grad()
y = model(*data)
# Any scalar works for the backward pass; the norm of the output is used here.
y.norm().backward()

## Initialize Optimizer

In [18]:
from tsdm.optimizers import LR_SCHEDULERS, OPTIMIZERS
from tsdm.utils import initialize_from

# Build the optimizer from the config plus the model parameters. The merged
# dict is a temporary: OPTIMIZER_CONFIG itself is left untouched, so it keeps
# holding only plain hyper-parameters and this cell can be re-run safely.
optimizer = initialize_from(
    OPTIMIZERS, **(OPTIMIZER_CONFIG | {"params": model.parameters()})
)

## Load Existing Model instead

In [19]:
# # load existing model & optimizer
# from pathlib import Path

# model_path = Path(
#     "/home/rscholz/Projects/KIWI/tsdm/examples/"
#     "checkpoints/SetFuncTS/KIWI_RUNS/OD600-0-More_params/2022-02-07T00:55:25"
# )
# model_name = "SetFuncTS"
# optim_name = "AdamW"
# model_versions = {}
# optim_versions = {}

# for file in model_path.iterdir():
#     name, version = file.stem.split("-")
#     if name == model_name:
#         model_versions[int(version)] = file
#     if name == optim_name:
#         optim_versions[int(version)] = file

# MODEL = model_versions[max(model_versions)]
# model = jit.load(MODEL)
# MODEL = SetFuncTS
# OPTIM = optim_versions[max(optim_versions)]
# optim = torch.load(OPTIM)

# epoch = optim["epoch"]
# batch_num = optim["batch"]
# optim = optim["optimizer"]

## Logging Utilities

In [20]:
@torch.no_grad()
def get_all_predictions(model, dataloader):
    """Collect decoded targets and decoded predictions over a whole dataloader.

    Every batch is run through ``model.forward_batch``; the encoded targets
    and the predictions are concatenated over all batches and then mapped
    back with ``target_encoder.decode``.

    Returns
    -------
    tuple
        ``(y, yhat)``: decoded targets and decoded predictions as tensors.
    """
    encoded_true = []
    encoded_pred = []
    for minibatch in tqdm(dataloader, leave=False):
        encoded_pred.append(model.forward_batch(minibatch.timeseries))
        encoded_true.append(minibatch.encoded_targets.to(device=DEVICE))

    # Concatenate over batches, then undo the target encoding.
    all_true = torch.cat(encoded_true)
    all_pred = torch.cat(encoded_pred)
    y = torch.tensor(target_encoder.decode(all_true))
    yhat = torch.tensor(target_encoder.decode(all_pred))
    return y, yhat

In [21]:
# def log_all(i, model, writer, optimizer):
#     kernel = model.system.kernel.clone().detach().cpu()
#     log_kernel_information(i, writer, kernel, histograms=True)
#     log_optimizer_state(i, writer, optimizer, histograms=True)


def log_hparams(i, writer, *, metric_dict, hparam_dict):
    """Write hyper-parameters together with their metrics to the summary writer.

    Parameters
    ----------
    i:
        Epoch number; logged as the additional hyper-parameter ``"epoch"``.
    writer:
        Object providing ``add_hparams(hparam_dict=..., metric_dict=...)``.
    metric_dict:
        Metrics to log; the keys are prefixed with ``"hparam"`` via ``add_prefix``.
    hparam_dict:
        Hyper-parameters to log. The mapping is NOT modified.
    """
    # Work on a copy so the caller's config does not silently gain an "epoch" key.
    hparam_dict = {**hparam_dict, "epoch": i}
    metric_dict = add_prefix(metric_dict, "hparam")
    writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)

# Training

In [22]:
# Timestamp of this run; it makes the checkpoint and log directories unique per run.
RUN_START = tsdm.utils.now()
# Model / optimizer checkpoints go to checkpoints/<model>/<dataset>/<run>/<start>.
CHECKPOINTDIR = Path(
    f"checkpoints/{MODEL.__name__}/{DATASET.name}/{RUN_NAME}/{RUN_START}"
)
CHECKPOINTDIR.mkdir(parents=True, exist_ok=True)
# Tensorboard logs go to the parallel runs/<model>/<dataset>/<run>/<start> directory.
LOGGING_DIR = f"runs/{MODEL.__name__}/{DATASET.name}/{RUN_NAME}/{RUN_START}"
writer = SummaryWriter(LOGGING_DIR)

In [23]:
# Global counters read and advanced by the training loop; re-running this cell resets them.
i = 0  # batch_num: number of processed training batches, used as step for the batch logs
epoch = 0  # epoch: first epoch of the next training run

# with torch.no_grad():
#     for key, dloader in {"train": TRAINLOADER, "test": EVALLOADER}.items():
#         y, ŷ = get_all_predictions(model, dloader)
#         assert torch.isfinite(y).all()
#         log_metrics(epoch, writer=writer, metrics=metrics, targets=y, predics=ŷ, prefix=key)

In [24]:
# Main training loop: 1000 epochs starting at the current value of `epoch`.
# Per batch: forward, finite-loss check, backward, optimizer step, batch logging.
# Per epoch: metrics on the train and test loaders, then a checkpoint.
for epoch in (epochs := range(epoch, epoch + 1000)):
    # Manual learning-rate drop at epoch 1000. Starting from epoch 0 the range
    # ends at 999, so this only triggers when the cell is run again (or when
    # `epoch` starts above 0).
    if epoch == 1000:
        for g in optimizer.param_groups:
            g["lr"] = 0.0001
    batching_time = perf_counter()
    for batch in (batches := tqdm(TRAINLOADER, leave=False)):
        # time spent waiting for the dataloader to deliver this batch
        batching_time = perf_counter() - batching_time
        i += 1
        # Optimization step
        model.zero_grad(set_to_none=True)
        targets = batch.encoded_targets.to(device=DEVICE)

        # forward
        forward_time = perf_counter()
        predics = model.forward_batch(batch.timeseries)
        loss = LOSS(targets, predics)
        forward_time = perf_counter() - forward_time

        # Abort BEFORE backward / step: once the optimizer has stepped on a
        # non-finite loss the parameters are already corrupted.
        if torch.any(~torch.isfinite(loss)):
            raise RuntimeError("NaN/INF-value encountered!!")

        # backward
        backward_time = perf_counter()
        loss.backward()
        backward_time = perf_counter() - backward_time

        # step
        optimizer.step()

        # batch logging
        with torch.no_grad():
            # perf_counter for consistency with the other timers
            logging_time = perf_counter()

            log_metrics(
                i,
                writer=writer,
                metrics=metrics,
                targets=targets.clone(),
                predics=predics.clone(),
                prefix="batch",
            )
            log_optimizer_state(i, writer=writer, optimizer=optimizer, prefix="batch")

            # lval = loss.clone().detach().cpu().numpy()
            # gval = grad_norm(list(model.parameters())).clone().detach().cpu().numpy()
            logging_time = perf_counter() - logging_time

        batches.set_postfix(
            # gnorm=f"{gval:.2e}",
            epoch=epoch,
            loss=f"{loss.item():.2e}",
            # Δt_forward=f"{forward_time:.1f}",
            # Δt_backward=f"{backward_time:.1f}",
            # Δt_logging=f"{logging_time:.1f}",
            # Δt_batching=f"{batching_time:.1f}",
        )
        # restart the timer that measures the wait for the next batch
        batching_time = perf_counter()

    with torch.no_grad():
        # end-of-epoch logging
        for key, dloader in {"train": TRAINLOADER, "test": EVALLOADER}.items():
            y, ŷ = get_all_predictions(model, dloader)
            assert torch.isfinite(y).all()
            log_metrics(
                epoch, writer=writer, metrics=metrics, targets=y, predics=ŷ, prefix=key
            )

        # Model Checkpoint: the model via torch.jit.save, the optimizer (object
        # and state dict) together with the epoch / batch counters via torch.save.
        torch.jit.save(model, CHECKPOINTDIR.joinpath(f"{MODEL.__name__}-{epoch}"))
        torch.save(
            {
                "optimizer": optimizer,
                "optimizer_state": optimizer.state_dict(),
                "epoch": epoch,
                "batch": i,
            },
            CHECKPOINTDIR.joinpath(f"{optimizer.__class__.__name__}-{epoch}"),
        )

# Post Training Analysis

# Profiling

In [None]:
# Deliberate stop: a bare `raise` with no active exception raises a RuntimeError.
# NOTE(review): presumably this keeps "Run All" from continuing into the
# profiling cells below, which are meant to be run by hand - confirm.
raise

In [None]:
from torch.profiler import ProfilerActivity, profile, record_function

# Profile a single training step (forward, backward, optimizer step) with CPU
# and CUDA activities; the finished trace is passed to the tensorboard trace
# handler created with "profiling" as its argument.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=False,
    profile_memory=False,
    with_modules=True,
    with_flops=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler("profiling"),
) as prof:
    for batch in tqdm(TRAINLOADER):
        targets = batch.encoded_targets.to(device=DEVICE)
        model.zero_grad(set_to_none=True)
        predics = model.forward_batch(batch.timeseries)
        loss = LOSS(targets, predics)
        loss.backward()
        optimizer.step()
        break  # only the first batch is profiled

In [None]:
# Training-step profile: the 10 entries with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# Training-step profile: the 10 entries with the largest total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [None]:
# Profile the forward pass alone, this time recording input shapes and memory.
# `batch` is the batch left over from the preceding loop over TRAINLOADER;
# the name `inputs` used here before is not defined anywhere in the notebook.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    model.forward_batch(batch.timeseries)

In [None]:
# Forward-pass profile: the 10 entries with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# Forward-pass profile: the 10 entries with the largest total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))