# KIWI runs: cross-validation splits, loss scaling, and data loading

In [85]:
# Session setup. With 'last_expr_or_assign', a cell's trailing expression *or*
# trailing assignment is displayed.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically re-import edited modules (e.g. tsdm, cross_validate_kiwi_runs)
# before each cell executes.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [86]:
import pickle
import pandas
import pandas as pd
import numpy as np
from pandas import DataFrame, Series

In [88]:
from cross_validate_kiwi_runs import ReplicateBasedSplitter, create_replicate_dict

# NOTE(review): pickle.load can execute arbitrary code; only load this file
# if it comes from a trusted source.
with open("kiwi_experiments_and_run_355.pk", "rb") as f:
    experiments_per_run = pickle.load(f)

# Reference (color, run_id) lookup; a version rebuilt from the metadata is
# asserted equal to it further down.
col_run_to_exp = dict(create_replicate_dict(experiments_per_run))

# Collect every (train, test) fold produced by the replicate-based splitter.
splitter = ReplicateBasedSplitter()
stefan_splits = [
    (train_keys, test_keys)
    for train_keys, test_keys in splitter.split(col_run_to_exp)
]
stefan_splits

In [89]:
from tsdm.datasets import KIWI_RUNS

# Run the dataset's cleaning step. NOTE(review): presumably (re)generates the
# cleaned data that `.metadata` / `.dataset` read next — confirm in tsdm.
KIWI_RUNS.clean()

In [90]:
# `metadata` has a "run_id" index level and a "color" column (both used below);
# `timeseries` holds the measurements (columns such as Glucose, OD600, DOT, Base).
metadata = KIWI_RUNS.metadata
timeseries = KIWI_RUNS.dataset

In [91]:
# Rebuild the (color, run_id) -> [metadata index labels] lookup directly from
# the metadata, and check that it agrees with the one loaded from the pickle.
reverse_lookup = {}

for run_id in metadata.index.unique("run_id"):
    # Explicit label-based selection; a list key keeps every index level.
    colors = metadata["color"].loc[[run_id]]
    for color in colors.unique():
        mask = colors == color
        reverse_lookup[(color, run_id)] = colors[mask].index.tolist()

assert reverse_lookup == col_run_to_exp

## groupby solution

The same reverse lookup in one step via `groupby(...).groups`, following https://stackoverflow.com/a/51329888/9318372

In [71]:
# One-step version: `.groups` maps each (color, run_id) key to the index labels
# of its rows; convert each to a plain list so it compares equal to the reference.
rev = {
    key: labels.tolist()
    for key, labels in metadata.groupby(["color", "run_id"]).groups.items()
}
assert rev == col_run_to_exp

## Custom splitting logic

In [72]:
from sklearn.model_selection import ShuffleSplit
from itertools import chain

# Shuffle whole (color, run_id) groups rather than individual rows, so all rows
# of a group land on the same side of each split.
splitter = ShuffleSplit(n_splits=5, random_state=0, test_size=0.25)
groups = metadata.groupby(["color", "run_id"])
group_idx = groups.ngroup()  # per row: integer id of its (color, run_id) group

# ShuffleSplit only uses the number of samples, so split the group ids
# 0..ngroups-1 explicitly rather than passing the GroupBy object itself.
splits = DataFrame(index=metadata.index)
for i, (train, test) in enumerate(splitter.split(np.arange(groups.ngroups))):
    splits[i] = group_idx.isin(train).map({False: "test", True: "train"})

splits.columns.name = "split"
# Store the categorical version so later cells use it, not just the display.
splits = splits.astype("string").astype("category")

## Loss function

Divide 'Glucose' by 10, 'OD600' by 20, 'DOT' by 100 and 'Base' by 200, then compute the RMSE over these scaled values.

In [73]:
# Loss targets; every one of them must exist as a time-series column.
targets = {"Glucose", "OD600", "DOT", "Base"}
assert targets.issubset(timeseries.columns)

In [74]:
# Per-column minimum (with max and range below; presumably to sanity-check the loss scale factors).
timeseries.min()

In [75]:
# Per-column maximum.
timeseries.max()

In [76]:
# Per-column value range (max - min).
timeseries.max() - timeseries.min()

In [77]:
from itertools import product

In [12]:
# All (split number, partition) keys for the 5 splits — the key form passed to TASK.splits(...) below.
list(product(range(5), ("train", "test")))

In [13]:
# Column dtypes of the time-series table.
timeseries.dtypes

In [14]:
# Index labels of the rows marked "train" in split 0.
mask = splits[0].eq("train")
idx = splits[0].index[mask]

In [15]:
# Split-0 training rows: move measurement_time (index level 2) out of the index so
# `.loc[idx]` matches on the remaining levels, then append it back as the last level.
timeseries.reset_index(level=2).loc[idx].set_index(["measurement_time"], append=True)

# Implementation

In [78]:
# Repeats the setup from the top of the notebook, presumably so this section can
# be run on its own; if autoreload is already loaded, `%load_ext` just reports it.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Splits

In [79]:
from tsdm.tasks import KIWI_RUNS_TASK
# Task object for the KIWI runs; `TASK.splits(key)` fetches a split below.
TASK = KIWI_RUNS_TASK()

In [18]:
# A split comes back as a pair (unpacked below as `ts, md`); show its first element.
TASK.splits((4, "train"))[0]

## Preprocessing

In [19]:
from sklearn.preprocessing import StandardScaler

In [20]:
preprocessor = StandardScaler()
# Training partition of split 4; presumably (time series, metadata) — confirm against KIWI_RUNS_TASK.
ts, md = TASK.splits((4, "train"))


In [21]:
# Fit the scaler on this training split and standardise it.
# `transform` returns the scaled array. Relying on `copy=False` to update the
# DataFrame in place is not reliable: sklearn may operate on a converted copy,
# e.g. under pandas copy-on-write or for non-float dtypes. So assign the
# result back explicitly, keeping ts's index and columns.
preprocessor.fit(ts)
ts = DataFrame(preprocessor.transform(ts), index=ts.index, columns=ts.columns)
ts

## Encoding in torch

In [22]:
import torch
from tsdm.encoders.functional import time2float

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# Make measurement_time (index level 2) a regular column. Guarded so that
# re-running the cell does not try to reset a level that was already moved out.
if "measurement_time" in ts.index.names:
    ts = ts.reset_index(level="measurement_time")
ts


In [23]:
# Preview the measurement times converted to floats (the same conversion builds T next).
time2float(ts["measurement_time"].values)

In [24]:
# Float times become T; every remaining column becomes the feature tensor X.
time_values = time2float(ts["measurement_time"].values)
feature_values = ts.drop(columns=["measurement_time"]).values
T = torch.tensor(time_values, device=device, dtype=dtype)
X = torch.tensor(feature_values, device=device, dtype=dtype)

## Creating DataSetCollection Object

In [25]:
from tsdm.datasets import DataSetCollection
from torch.utils.data import TensorDataset

In [26]:
# One TensorDataset of (time, feature) tensors per distinct index key, i.e. the
# index levels left after measurement_time became a column.
# NOTE(review): `ts.index == idx` scans the whole index for every key, so this is
# O(#keys x #rows); a groupby-based position lookup would avoid that if it gets slow.
shared_index = ts.index.unique().values
masks = {idx:(ts.index==idx) for idx in shared_index}
datasets = {idx: TensorDataset(T[masks[idx]], X[masks[idx]]) for idx in shared_index}

In [27]:
from pandas import Series

In [28]:
# Per-key datasets as a Series indexed by key, for a quick look.
# NOTE(review): `s` is not referenced again in the cells shown here.
s = Series(datasets)

In [29]:
# Combine the per-key datasets into one collection, looked up by key (e.g. `dataset[some_index]`).
dataset = DataSetCollection(datasets)

In [30]:
# Pick an arbitrary key (the index label of row 42) and show its dataset.
some_index = ts.index[42]
dataset[some_index]

In [31]:
# First 10 entries of that key's dataset.
dataset[some_index][:10]

## Creating CollectionSampler Object

In [34]:
from tsdm.util.samplers import CollectionSampler, SequenceSampler
from torch.utils.data import TensorDataset
from functools import partial

# Sampler factory with seq_len=100 and shuffle=True fixed. NOTE(review): presumably
# CollectionSampler applies it to each member of `dataset` — confirm in tsdm.
subsampler = partial(SequenceSampler, seq_len=100, shuffle=True)
sampler = CollectionSampler(dataset, subsampler=subsampler)

In [35]:
# Draw the first sample from the sampler and look it up in the collection.
sample = next(iter(sampler))
element = dataset[sample]

In [36]:
from tqdm.auto import tqdm
# Exhaust the sampler once with a progress bar, to check that it iterates cleanly and how fast.
for b in tqdm(sampler):
    ...

## DataLoader Object

In [37]:
from torch.utils.data import DataLoader

In [38]:
# Batches of 32 items, drawn in the order produced by `sampler`.
dloader = DataLoader(dataset, sampler=sampler, batch_size=32)

In [39]:
# Fetch and show the first batch.
next(iter(dloader))

In [40]:
# Iterate the whole DataLoader once to check that it runs end to end.
for batch in tqdm(dloader):
    ...

## Testing the implemented variant

In [53]:
from tsdm.tasks import KIWI_RUNS_TASK
# Separate task instance (independent of `TASK` above) to exercise the packaged pipeline.
task = KIWI_RUNS_TASK()

In [59]:
# Dataloaders provided by the task, keyed by (split, partition), e.g. (0, "train") below.
task.dataloaders

In [81]:
# NOTE: rebinds `dloader`, replacing the hand-built DataLoader above.
dloader = task.batchloader[0]

In [83]:
# NOTE: rebinds T and X (previously the full time/feature tensors) to the two parts of one batch.
T, X = next(iter(dloader))

In [84]:
# Shapes of the two batch tensors.
T.shape, X.shape

In [61]:
# Iterate the task's batchloader once to check that it runs end to end.
for batch in tqdm(dloader):
    ...

In [60]:
# Same end-to-end check for the task's (0, "train") dataloader.
for batch in tqdm(task.dataloaders[(0, "train")]):
    ...