In [None]:
# Notebook setup: display the last expression (or the value of a trailing
# assignment), render inline figures as SVG, and auto-reload edited modules.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt

# Global plot styling for every figure in the notebook.
plt.style.use("bmh")
plt.rcParams["axes.axisbelow"] = True  # draw ticks/grid lines behind plotted data
import numpy as np
import pandas
import pandas as pd
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
import matplotlib.dates as mdates

# NOTE(review): unseeded and not used in the visible cells; seed it if it is
# meant to make anything reproducible.
rng = np.random.default_rng()
# NOTE(review): called with no arguments, which appears to leave NumPy's print
# options unchanged — confirm whether a reset was intended.
np.set_printoptions()

import torch
import torchinfo
from torch import Tensor, jit, tensor
from torch.utils.data import DataLoader
from copy import deepcopy

## Config

In [None]:
# Which run to plot: (Run_id, Experiment_id).
ID = (439, 15325)

# Checkpoint location. Earlier candidates that were tried:
#   2021-11-12T00:52:10, 2021-11-12T00:51:55, 2021-11-12T00:51:48
#   "checkpoints/2021-11-15T12:05:00/LinODEnet-0"
#   "checkpoints/LinODEnet/KIWI_RUNS/skew_allways/2021-11-15T16:05:41/LinODEnet-0"
#   "adam/2021-11-15T20:38:52/LinODEnet-0"
PATH = "checkpoints/LinODEnet/KIWI_RUNS/"
NAME = "adam/2021-11-15T20:38:52/LinODEnet-0"
MODEL_FILE = f"{PATH}{NAME}"  # full path of the model checkpoint

# Tensor placement / precision shared by the rest of the notebook.
DEVICE = torch.device("cpu")
DTYPE = torch.float32
# 0-d NaN scalar used to blank out entries of input tensors.
NAN = torch.full((), float("nan"), dtype=DTYPE, device=DEVICE)

In [None]:
def apply_along_axes(a: Tensor, b: Tensor, op, axes: tuple[int, ...]) -> Tensor:
    """Apply ``op(a, b)`` so that ``b`` broadcasts over ``axes`` of ``a``.

    ``a`` is permuted so that ``axes`` (in the order given) come first and the
    remaining axes follow in ascending order. ``op(a, b)`` is evaluated in that
    layout, which lets ``b`` broadcast against the trailing (non-``axes``)
    dimensions. The result is then permuted back to ``a``'s original layout.

    Typical use: subtract a reduction such as ``a.sum(dim=axes)`` from ``a``.

    Args:
        a: Input tensor.
        b: Second operand; must broadcast against ``a``'s shape with ``axes``
            removed (remaining axes in ascending order).
        op: Binary callable such as ``torch.sub``; expected to return a tensor
            with the same rank as ``a``.
        axes: Axes of ``a`` to broadcast ``b`` over; negative values allowed.

    Returns:
        The result of ``op`` in the original axis order of ``a``.

    Raises:
        ValueError: If ``axes`` refers to the same axis more than once.
    """
    rank = a.ndim
    axes = tuple(ax % rank for ax in axes)
    if len(set(axes)) != len(axes):
        raise ValueError(f"duplicate axes {axes} for tensor of rank {rank}")

    # Forward permutation: requested axes first, then the rest in order.
    perm = axes + tuple(ax for ax in range(rank) if ax not in axes)
    # Inverse permutation, so that inverse[perm[j]] == j restores the layout.
    inverse = [0] * rank
    for position, ax in enumerate(perm):
        inverse[ax] = position

    return op(a.permute(perm), b).permute(inverse)

In [None]:
# Scratch check of a NaN-aware mean / sample standard deviation along `axes`,
# using `apply_along_axes` to broadcast the per-slice mean back over `axes`.
data = torch.randn(7, 8, 9, generator=torch.Generator().manual_seed(0))
data[data > 0] = NAN  # mark roughly half of the entries as missing
axes = (-1,)
mask = ~torch.isnan(data)
count = mask.sum(dim=axes)
masked = torch.where(mask, data, tensor(0.0))
assert not torch.isnan(masked).any(), "masking left NaNs behind"
mean = masked.sum(dim=axes) / count  # NaN for slices without any observation
residual = apply_along_axes(masked, mean, torch.sub, axes=axes)
# Missing entries were zero-filled, so their residual is -mean; multiplying by
# the mask zeroes them (NaN * 0 stays NaN, so all-missing slices remain NaN).
residual = residual * mask
# Bessel-corrected denominator, clamped to >= 1 so that slices with a single
# observation do not divide by zero.
variance = (residual**2).sum(dim=axes) / (count - 1).clamp(min=1)
stdv = variance.sqrt()

In [None]:
from tsdm.tasks import KIWI_RUNS_TASK

# Instantiate the KIWI runs task and fetch the split-0 train/test loaders.
task = KIWI_RUNS_TASK()
train_key, test_key = (0, "train"), (0, "test")
# Sanity check: the run we want to plot must belong to the training split.
assert ID in task.splits[train_key][0].index
TRAINLOADER = task.dataloaders[train_key]
EVALLOADER = task.dataloaders[test_key]

In [None]:
# Raw time series of the selected run `ID`, cast to float32 (same as DTYPE).
ts = task.timeseries.loc[ID].astype("float32")

## Load the Model

In [None]:
# Load the TorchScript checkpoint onto the configured device (instead of a
# hard-coded "cpu"), and switch to eval mode since it is only used for
# inference here (affects dropout / batch-norm layers, if the model has any).
model = torch.jit.load(MODEL_FILE, map_location=DEVICE)
model.eval()
torchinfo.summary(model)

## First look: plot the first sample of a training batch

In [None]:
def prep_batch(
    batch: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Split a ``(T, X)`` batch into model inputs, forecast targets and originals.

    Everything from ``task.observation_horizon`` onward along the time axis
    (second-to-last dimension of ``X``) is the forecast window. In ``inputs``,
    the target and observable channels of that window are blanked with NaN.
    The target channels of that window are returned as ``targets``.

    Args:
        batch: ``(T, X)`` pair; ``X`` is indexed as ``[..., time, channel]``
            (presumably ``(batch_size, horizon, num_dim)`` — TODO confirm).

    Returns:
        ``(T, inputs, targets, originals)``; ``originals`` is an untouched
        copy of ``X``.

    Note:
        Reads the notebook globals ``task`` and ``NAN``.
    """
    T, X = batch
    horizon = task.observation_horizon
    # Index with `...` everywhere so leading (batch) dimensions are optional.
    targets = X[..., horizon:, task.targets.index].clone()
    originals = X.clone()
    inputs = X.clone()
    inputs[..., horizon:, task.targets.index] = NAN
    inputs[..., horizon:, task.observables.index] = NAN
    return T, inputs, targets, originals

In [None]:
# Persistent iterator over the training loader: each run of the next cell
# pulls a fresh batch via `next(iload)`; re-run this cell to start over.
iload = iter(TRAINLOADER)

In [None]:
batch = next(iload)
times, inputs, targets, originals = (x.to(device=DEVICE) for x in prep_batch(batch))
with torch.no_grad():  # inference only, so no autograd graph is needed
    outputs = model(times, inputs)

# Keep only the first sample of the batch, as NumPy arrays for plotting.
times, inputs, outputs, targets, originals = (
    x[0].detach().cpu().numpy() for x in (times, inputs, outputs, targets, originals)
)

times.shape, outputs.shape, inputs.shape, targets.shape, originals.shape

In [None]:
fig, axes = plt.subplots(
    nrows=2, ncols=2, sharex=True, figsize=(8, 4), constrained_layout=True
)

# One panel per target channel: observed values vs. the model output.
for ax, (idx, target) in zip(axes.flatten(), task.targets.items()):
    data = originals[:, idx]
    mask = ~np.isnan(data)  # only draw the points that were actually observed

    ax.plot(
        times[mask],
        data[mask],
        ls="-",
        lw=0.5,
        marker=".",
        ms=3,
        label=f"observed (n={mask.sum()})",
    )
    ax.plot(
        times,
        outputs[:, idx],
        ls="-",
        lw=0.5,
        marker=".",
        ms=3,
        label="model output",
    )
    ax.set_title(target)
    # Per-line labels keep observed data and model output distinguishable.
    ax.legend()

## Decoded predictions vs. the original (raw) observations

In [None]:
def make_plot(axes, task, batch):
    """Overlay the model's decoded prediction for one batch on ``axes``.

    Runs ``model`` on the prepared batch and decodes the first sample back to
    the original data scale with ``preprocessor.decode``. For each target
    channel, it draws the observation window dotted and the forecast window
    solid, both in one shared color.

    Args:
        axes: Array of matplotlib Axes, one per entry of ``task.targets``.
        task: Task object providing ``targets`` and ``observation_horizon``.
        batch: A ``(T, X)`` batch as accepted by ``prep_batch``.

    Returns:
        The decoded frame that was plotted.

    Note:
        Reads the globals ``model`` and ``preprocessor``; ``prep_batch`` reads
        the global ``task``, not the ``task`` argument.
    """
    times, inputs, _, _ = (x.to(device=DEVICE) for x in prep_batch(batch))
    with torch.no_grad():  # inference only
        outputs = model(times, inputs)

    reconstructed = preprocessor.decode(
        (times[0].detach().cpu(), outputs[0].detach().cpu())
    ).astype("float32")

    horizon = task.observation_horizon
    for ax, (idx, _target) in zip(axes.flatten(), task.targets.items()):
        # Let the first line take the next color from the axes' color cycle,
        # then reuse it for the second line. This avoids the private
        # `ax._get_lines.prop_cycler`, which newer matplotlib (3.8+) lacks.
        (observed_line,) = ax.plot(
            reconstructed.index[:horizon],
            reconstructed.iloc[:horizon, idx],
            ls=":",
            lw=0.5,
        )
        ax.plot(
            reconstructed.index[horizon:],
            reconstructed.iloc[horizon:, idx],
            ls="-",
            lw=0.5,
            color=observed_line.get_color(),
        )
    return reconstructed

In [None]:
# Build an un-shuffled loader over the single run `ID`, so that positions in
# `slices` are stable between runs of the notebook.
dloader = TRAINLOADER
dataset = dloader.dataset[ID]  # presumably the dataset restricted to run ID — confirm
# Deep copies, so changing them (e.g. `shuffle` below) leaves TRAINLOADER's own
# objects untouched. `preprocessor` is also read as a global by `make_plot`.
preprocessor = deepcopy(dloader.preprocessor)
sampler = deepcopy(dloader.sampler[ID])
sampler.shuffle = False  # deterministic order, so positional picks are reproducible
LOADER = DataLoader(dataset, sampler=sampler)  # default batch_size=1
# NOTE(review): building a Series from the loader appears to materialize every
# batch in memory at once.
slices = Series(LOADER)

In [None]:
%matplotlib widget

fig, axes = plt.subplots(
    nrows=2, ncols=2, sharex=True, figsize=(16, 8), constrained_layout=True
)

for ax, (idx, target) in zip(axes.flatten(), task.targets.items()):
    data = ts[target]
    times = ts.index.values
    mask = ~np.isnan(data)
    ax.plot(
        times[mask],
        data[mask],
        ls="-",
        lw=0.5,
        marker=".",
        ms=3,
    )
    ax.legend([f"{target} - observations"])

In [None]:
# Five evenly spaced windows of the run: positions 0, 300, 600, 900, 1200.
batches = slices.iloc[list(range(0, 1500, 300))]

for batch in batches:
    make_plot(axes, task, batch)

# Export the figure; "/" in the checkpoint name would create sub-directories.
pdf_name = NAME.replace("/", "_") + ".pdf"
fig.savefig(pdf_name)

In [None]:
# Scratch: show the first element of a fresh training batch (`T`, going by the
# `T, X = batch` unpacking in `prep_batch`). Creates a new iterator on every run.
next(iter(TRAINLOADER))[0]

In [None]:
from tsdm.encoders import DateTimeEncoder, MinMaxScaler

In [None]:
# Scratch: fit a composed encoder on the run's timestamps and encode them.
# NOTE(review): `@` composes the two encoders; presumably DateTimeEncoder is
# applied first and MinMaxScaler second — confirm against tsdm.encoders.
enc = MinMaxScaler() @ DateTimeEncoder()
enc.fit(ts.index)
encoded = enc.encode(ts.index)

In [None]:
# NOTE(review): `data` is whatever the last loop left behind — after a
# top-to-bottom run, the final target column of `ts` from the raw-observation
# plot. This is hidden state; reference the intended column explicitly.
np.array(data)

In [None]:
# Scratch: fit a standalone DateTimeEncoder on the run's timestamps; as the last
# expression of the cell, IPython displays whatever `fit` returns.
DateTimeEncoder().fit(ts.index)