# KIWI dataset: task definition, encoders and collate function

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
from collections.abc import *
from dataclasses import KW_ONLY, dataclass
from typing import Any, NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
from pandas import DataFrame, Index, MultiIndex, Series

# Fixed seed so that anything drawn from `rng` is reproducible after a kernel restart.
SEED = 2022
rng = np.random.default_rng(SEED)
# Print arrays with 4 fixed decimals and without scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# Emit log records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import torch
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch import nan as NAN
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Sampler as TorchSampler

from tsdm.datasets import KiwiDataset, TimeSeriesCollection
from tsdm.random.samplers import HierarchicalSampler, SlidingWindowSampler
from tsdm.tasks import TimeSeriesSampleGenerator, TimeSeriesTask
from tsdm.tasks.base import Batch, Sample, TimeSeriesSampleGenerator, TimeSeriesTask
from tsdm.utils.data import folds_as_frame, folds_as_sparse_frame, folds_from_groups
from tsdm.utils.types import KeyVar

In [None]:
class KiwiSampleGenerator(TimeSeriesSampleGenerator):
    r"""Sample generator with the column selection used for the KIWI dataset.

    Thin wrapper around `TimeSeriesSampleGenerator` that fixes the column
    names passed as ``observables``, ``covariates`` and ``targets``.
    """

    def __init__(self, dataset):
        r"""Forward ``dataset`` to the base class together with the fixed column lists."""
        observables = [
            "Base",
            "DOT",
            "Glucose",
            "OD600",
            "Acetate",
            "Fluo_GFP",
            "Temperature",
            "pH",
        ]
        covariates = [
            "Cumulated_feed_volume_glucose",
            "Cumulated_feed_volume_medium",
            "InducerConcentration",
            "StirringSpeed",
            "Flow_Air",
            "Probe_Volume",
        ]
        # Both target columns also appear in the observables list.
        targets = ["OD600", "Fluo_GFP"]

        super().__init__(
            dataset,
            observables=observables,
            covariates=covariates,
            targets=targets,
        )


class KiwiTask(TimeSeriesTask):
    r"""Forecasting task on the KIWI dataset.

    The task loads `KiwiDataset`, casts its timeseries to ``float64`` and
    provides the factory methods (`make_sampler`, `make_folds`,
    `make_generator`) on top of `TimeSeriesTask`.
    """
    # dataset: TimeSeriesCollection = KiwiDataset()
    observation_horizon: str = "2h"
    r"""Time span of the observation window; first horizon given to the sliding-window sampler."""
    forecasting_horizon: str = "1h"
    r"""Time span of the forecasting window; second horizon given to the sliding-window sampler."""

    def __init__(self) -> None:
        r"""Load the KIWI dataset, cast the timeseries to float64 and initialize the base task."""
        dataset = KiwiDataset()
        dataset.timeseries = dataset.timeseries.astype("float64")
        super().__init__(dataset=dataset)

    @staticmethod
    def default_metric(*, targets, predictions):
        r"""TODO: implement this."""

    def default_collate(self):
        r"""TODO: implement this."""

    # def make_encoder(self, key: KeyVar, /) -> ModularEncoder:
    #     ...

    def make_sampler(self, key: KeyVar, /) -> TorchSampler:
        r"""Create the sampler for the split stored under ``key``.

        Every time series of the split gets its own `SlidingWindowSampler`
        whose horizons are taken from `observation_horizon` and
        `forecasting_horizon` (previously hard-coded as ``["2h", "1h"]``,
        which are still the defaults). The per-series samplers are combined
        by a non-shuffling `HierarchicalSampler`.
        """
        split: TimeSeriesCollection = self.splits[key]
        subsamplers = {
            # A fresh horizons list per sampler, so no list object is shared.
            series_id: SlidingWindowSampler(
                series.index,
                horizons=[self.observation_horizon, self.forecasting_horizon],
                stride="1h",
            )
            for series_id, series in split.items()
        }
        return HierarchicalSampler(split, subsamplers, shuffle=False)  # type: ignore[return-value]

    def make_folds(self, /) -> DataFrame:
        r"""Create the fold table, grouping by run_id and color (which indicates replicates).

        Rows of the metadata that share ``(run_id, color)`` receive the same
        group number, and the folds are derived from these group numbers, so
        such rows are handled as one unit by `folds_from_groups`.
        """
        md = self.dataset.metadata
        groups = md.groupby(["run_id", "color"], sort=False).ngroup()
        # NOTE(review): train/valid/test = 7/1/2 are presumably relative sizes — confirm.
        folds = folds_from_groups(
            groups, seed=2022, num_folds=5, train=7, valid=1, test=2
        )
        df = folds_as_frame(folds)
        return folds_as_sparse_frame(df)

    def make_generator(self, key: KeyVar, /) -> KiwiSampleGenerator:
        r"""Create the `KiwiSampleGenerator` for the split stored under ``key``."""
        split = self.splits[key]
        return KiwiSampleGenerator(split)

In [None]:
# Instantiate the task; KiwiTask.__init__ loads KiwiDataset and casts its timeseries to float64.
task = KiwiTask()

In [None]:
# Inspect the dataset held by the task (rendered as the cell output).
task.dataset

In [None]:
# Dataloader stored under the key (0, "train") — presumably fold 0, training partition; confirm.
dataloader = task.dataloaders[(0, "train")]

In [None]:
# Sampler stored under the same key — presumably the object built by KiwiTask.make_sampler; confirm.
sampler = task.samplers[(0, "train")]

In [None]:
# Generator stored under the same key — presumably the object built by KiwiTask.make_generator; confirm.
generator = task.generators[(0, "train")]

In [None]:
# Inspect the splits mapping; make_sampler and make_generator index it by key.
task.splits

In [None]:
# Pull the first batch from a fresh iterator over the dataloader.
batch = next(iter(dataloader))

In [None]:
# Keep the first element of the batch as the working example for the cells below.
sample = batch[0]

In [None]:
# Inspect the input part of the sample.
sample.inputs

In [None]:
# Inspect the target part of the sample.
sample.targets

In [None]:
# Inspect the key attached to the sample.
sample.key

# Designing an encoder for the data

In [None]:
from tsdm.encoders import *

In [None]:
# Short aliases for the pieces of the example sample.
t = sample.inputs.x.index  # index of the observed frame
x = sample.inputs.x  # observed values
t_target = sample.inputs.t_target  # presumably the time points to forecast — confirm
y = sample.targets.y  # target values

In [None]:
# Per-column metadata of the value features; its "scale", "lower" and "upper" columns are read below.
VF = task.dataset.value_features

In [None]:
column_encoders = {}

for col, scale, lower, upper in VF[["scale", "lower", "upper"]].itertuples():
    encoder = BoundaryEncoder(lower, upper, mode="clip")
    match scale:
        case "percent":
            encoder = (
                LogitBoxCoxEncoder()
                @ LinearScaler(lower, upper)
                @ BoundaryEncoder(lower, upper, mode="clip")
            )
        case "absolute":
            if upper < np.inf:
                encoder = (
                    BoxCoxEncoder()
                    @ LinearScaler(lower, upper)
                    @ BoundaryEncoder(lower, upper, mode="clip")
                )
            else:
                encoder = BoxCoxEncoder() @ BoundaryEncoder(lower, upper, mode="clip")
        case "linear":
            encoder = IdentityEncoder()
        case _:
            raise ValueError(f"{scale=} unknown")
    column_encoders[col] = encoder
column_encoders

# original data

In [None]:
%matplotlib inline
# Work on a copy so that the frame owned by the task is not modified.
ts = task.dataset.timeseries.copy()
# All-NaN placeholder column — presumably added so the column count fills the 5x3 plot grid; confirm.
ts["dummy"] = float("nan")

In [None]:
# Histogram of every column of the raw data, drawn into a 3x5 grid of axes.
fig, ax = plt.subplots(ncols=5, nrows=3, figsize=(20, 6), constrained_layout=True)
# density=True normalizes each histogram; log=True uses a logarithmic count axis.
ts.hist(ax=ax, density=True, log=True, bins=20)
fig.savefig("data_original.pdf")
print(list(ts.columns))

# plain standardization

In [None]:
# Baseline: fit a plain Standardizer on the frame and encode the same frame with it.
encoder = Standardizer()
encoder.fit(ts)
encoded = encoder.encode(ts)
# Same histogram grid as for the raw data, now for the encoded values.
fig, ax = plt.subplots(ncols=5, nrows=3, figsize=(20, 6), constrained_layout=True)
encoded.hist(ax=ax, density=True, log=True, bins=20)
fig.savefig("data_encoded_standardizer.pdf")
print(list(encoded.columns))

# With BoxCox

In [None]:
# Chain the per-column encoders with a Standardizer.
# NOTE(review): presumably the right-hand FrameEncoder is applied first when encoding — confirm.
encoder = Standardizer() @ FrameEncoder(column_encoders)
encoder.fit(ts)
encoded = encoder.encode(ts)
# Same histogram grid as before, for the output of the combined encoder.
fig, ax = plt.subplots(ncols=5, nrows=3, figsize=(20, 6), constrained_layout=True)
encoded.hist(ax=ax, density=True, log=True, bins=20)
fig.savefig("data_encoded_box_cox.pdf")
print(list(encoded.columns))

## residual error

In [None]:
# Round trip: decode the encoded frame again.
decoded = encoder.decode(encoded)
# The encoding must neither create nor remove missing values.
assert (encoded.isna() == ts.isna()).all().all()
# Largest per-column mean absolute reconstruction error (rendered as the cell output).
(decoded - ts).abs().mean().max()

In [None]:
# Frame encoder that also handles the index levels: run_id and exp_id are passed
# through unchanged, measurement_time goes through TimeDeltaEncoder and MinMaxScaler.
enc = FrameEncoder(
    column_encoders=column_encoders,
    index_encoders={
        "run_id": IdentityEncoder(),
        "exp_id": IdentityEncoder(),
        "measurement_time": MinMaxScaler() @ TimeDeltaEncoder(),
    },
)

In [None]:
# Fit the frame encoder on the full timeseries frame.
enc.fit(ts)

In [None]:
# Encode the frame the encoder was fitted on; overwrites the earlier `encoded`.
encoded = enc.encode(ts)

In [None]:
# Split the encoded frame into named groups and convert them to tensors.
# NOTE(review): the Ellipsis for "X" presumably means "all remaining columns" — confirm.
f2t = Frame2TensorDict(
    groups={"key": ["run_id", "exp_id"], "T": ["measurement_time"], "X": ...},
    dtypes={"T": "float32", "X": "float32"},
)

f2t.fit(encoded)
# Result of the conversion is rendered as the cell output.
f2t.encode(encoded)

# Single combined encoder

In [None]:
# All steps combined into one encoder: per-column/index encoding (FrameEncoder),
# standardization (Standardizer) and conversion to a dict of tensors (Frame2TensorDict).
# Unlike `enc` above, only measurement_time gets an index encoder here.
encoder = (
    Frame2TensorDict(
        groups={"key": ["run_id", "exp_id"], "T": ["measurement_time"], "X": ...},
        dtypes={"T": "float32", "X": "float32"},
    )
    @ Standardizer()
    @ FrameEncoder(
        column_encoders=column_encoders,
        index_encoders={
            # "run_id": IdentityEncoder(),
            # "exp_id": IdentityEncoder(),
            "measurement_time": MinMaxScaler()
            @ TimeDeltaEncoder(),
        },
    )
)

In [None]:
# NOTE: rebinds `ts` to the task's own frame (no copy, no "dummy" column),
# replacing the copy that the plotting cells above worked on.
ts = task.dataset.timeseries
encoder.fit(ts)
encoded = encoder.encode(ts)

In [None]:
# Round trip through the combined encoder.
decoded = encoder.decode(encoded)
# Mean absolute reconstruction error, averaged per column and then over columns.
MAD = (decoded - ts).abs().mean().mean()

## Applying to slice

In [None]:
# Apply the encoder fitted on the full frame to the inputs of a single sample.
encoded = encoder.encode(sample.inputs.x)

In [None]:
# Round trip for the single sample.
decoded = encoder.decode(encoded)
# Missing values must sit at the same positions after decoding.
assert (decoded.isna() == sample.inputs.x.isna()).all().all()
# Mean absolute reconstruction error on the sample, averaged over columns.
MAD = (decoded - sample.inputs.x).abs().mean().mean()

# Collate_fn

## collate_fn with encoder

In [None]:
# Encode inputs and targets of the sample and unpack the values of the returned mapping.
# NOTE(review): exactly two values are unpacked, while the encoder's Frame2TensorDict
# declares three groups ("key", "T", "X") — confirm that "key" is dropped, otherwise this raises.
TX, X = encoder.encode(sample.inputs.x).values()
TY, Y = encoder.encode(sample.targets.y).values()

In [None]:
# The number of missing values must be the same before and after encoding.
assert sample.inputs.x.isna().sum().sum() == X.isnan().sum()
assert sample.targets.y.isna().sum().sum() == Y.isnan().sum()

In [None]:
# Step-by-step version of the collate logic, run directly on `batch`.
# NOTE(review): the loop below rebinds the notebook-level names `sample`, `t`, `x`,
# `t_target` and `y`; after this cell they refer to the last element of `batch`.
x_vals: list[Tensor] = []
y_vals: list[Tensor] = []
x_time: list[Tensor] = []
y_time: list[Tensor] = []
x_mask: list[Tensor] = []
y_mask: list[Tensor] = []

for sample in batch:
    # NOTE(review): the torch calls below need tensors; `.index` / `.values` of a
    # pandas frame are not tensors — confirm what type the samples hold here.
    t = sample.inputs.x.index
    x = sample.inputs.x.values
    t_target = sample.inputs.t_target
    y = sample.targets.y

    # get whole time interval
    time = torch.cat((t, t_target))
    sorted_idx = torch.argsort(time)

    # pad the x-values
    # one all-NaN row per target time point
    x_padding = torch.full(
        (t_target.shape[0], x.shape[-1]), fill_value=NAN, device=x.device
    )
    values = torch.cat((x, x_padding))

    # create a mask for looking up the target values
    # False on the observed rows, isfinite(y) on the appended target rows
    mask_y = y.isfinite()
    mask_pad = torch.zeros_like(x, dtype=torch.bool)
    mask_x = torch.cat((mask_pad, mask_y))

    # bring values, times and mask into the common time order
    x_vals.append(values[sorted_idx])
    x_time.append(time[sorted_idx])
    x_mask.append(mask_x[sorted_idx])

    y_time.append(t_target)
    y_vals.append(y)
    y_mask.append(mask_y)

# Pad every list to a common length and stack along a leading batch axis.
# NOTE(review): the bare .squeeze() drops every size-1 axis, including the batch
# axis when `batch` holds a single sample — confirm this is intended.
Batch(
    x_time=pad_sequence(x_time, batch_first=True).squeeze(),
    x_vals=pad_sequence(x_vals, batch_first=True, padding_value=NAN).squeeze(),
    x_mask=pad_sequence(x_mask, batch_first=True).squeeze(),
    y_time=pad_sequence(y_time, batch_first=True).squeeze(),
    y_vals=pad_sequence(y_vals, batch_first=True, padding_value=NAN).squeeze(),
    y_mask=pad_sequence(y_mask, batch_first=True).squeeze(),
)

In [None]:
def collate_fn(batch: list[Sample]) -> Batch:
    r"""Collate a list of samples into one padded batch.

    Transform the data slightly: t, x, t_target → T, X where X[t_target:] = NAN

    For each sample the observation times and the target times are
    concatenated and sorted. The observed values are extended by one all-NaN
    row per target time, and a boolean mask marks the appended rows whose
    target value is finite. The per-sample tensors are then padded to a common
    length with `pad_sequence` (values with NaN, times and masks with the
    default padding value).

    The padded tensors always keep the leading batch axis. The previous bare
    ``.squeeze()`` removed every size-1 axis, so a batch with a single sample
    lost its batch dimension.

    Args:
        batch: Samples to collate. NOTE(review): ``sample.inputs.x.index``,
            ``sample.inputs.x.values``, ``sample.inputs.t_target`` and
            ``sample.targets.y`` are passed to torch functions and must
            therefore be tensors — confirm what the sample generator yields.

    Returns:
        A `Batch` with fields x_time, x_vals, x_mask, y_time, y_vals, y_mask.
    """
    x_vals: list[Tensor] = []
    y_vals: list[Tensor] = []
    x_time: list[Tensor] = []
    y_time: list[Tensor] = []
    x_mask: list[Tensor] = []
    y_mask: list[Tensor] = []

    for sample in batch:
        t = sample.inputs.x.index
        x = sample.inputs.x.values
        t_target = sample.inputs.t_target
        y = sample.targets.y

        # get whole time interval
        time = torch.cat((t, t_target))
        sorted_idx = torch.argsort(time)

        # pad the x-values: one all-NaN row per target time point
        x_padding = torch.full(
            (t_target.shape[0], x.shape[-1]), fill_value=NAN, device=x.device
        )
        values = torch.cat((x, x_padding))

        # create a mask for looking up the target values:
        # False on the observed rows, isfinite(y) on the appended target rows.
        # NOTE(review): concatenating along dim 0 requires x and y to agree in
        # all trailing dimensions — confirm for the chosen targets.
        mask_y = y.isfinite()
        mask_pad = torch.zeros_like(x, dtype=torch.bool)
        mask_x = torch.cat((mask_pad, mask_y))

        # bring values, times and mask into the common time order
        x_vals.append(values[sorted_idx])
        x_time.append(time[sorted_idx])
        x_mask.append(mask_x[sorted_idx])

        y_time.append(t_target)
        y_vals.append(y)
        y_mask.append(mask_y)

    # No squeeze: the result keeps the layout (batch, time, ...) for any batch size.
    return Batch(
        x_time=pad_sequence(x_time, batch_first=True),
        x_vals=pad_sequence(x_vals, batch_first=True, padding_value=NAN),
        x_mask=pad_sequence(x_mask, batch_first=True),
        y_time=pad_sequence(y_time, batch_first=True),
        y_vals=pad_sequence(y_vals, batch_first=True, padding_value=NAN),
        y_mask=pad_sequence(y_mask, batch_first=True),
    )

In [None]:
# Index of the inputs of whatever `sample` currently refers to.
t = sample.inputs.x.index

In [None]:
# Run the collate function on the batch fetched earlier; the result is rendered as the cell output.
collate_fn(batch)

## collate_fn using encoder!