# Evaluate LinODEnet on Final Product Task

In [1]:
# Notebook-wide IPython configuration: echo the last expression or assignment
# of every cell, render inline figures as SVG, and reload modules automatically.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

# Print arrays with four fixed decimals and without scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# NOTE(review): no seed is passed here, so anything drawn from `rng` will
# differ between runs -- confirm whether a fixed seed is wanted.
rng = np.random.default_rng()
# import logging
# logging.basicConfig(level=logging.INFO)

# Load Task

In [2]:
from tsdm.tasks import KIWI_FINAL_PRODUCT, KIWI_RUNS_TASK

# Task configuration: the split key used to index `task.splits`, the target
# handed to the final-product task, and the batch size used by the loaders.
SPLIT = 0
TARGET = "OD600"
BATCH_SIZE = 128

# Build the runs task once and read both column groups from that single
# instance instead of constructing the task separately for each attribute.
runs_task = KIWI_RUNS_TASK()
observables = runs_task.observables
controls = runs_task.controls

# Final-product task for the chosen target, together with its training split.
task = KIWI_FINAL_PRODUCT(target=TARGET)
ts_train, md_train = task.splits[SPLIT, "train"]

## Set up Dataset object

In [3]:
from tsdm.datasets.torch import TimeSeriesDataset, MappingDataset
from tsdm.util.decorators import IterItems

# Time series and metadata of the test split.
ts, md = task.splits[SPLIT, "test"]

# One TimeSeriesDataset per metadata index entry; each carries its metadata
# row and the matching `task.final_vec` entry as a 2-tuple.
TSDs = dict()
for idx in md.index:
    run_metadata = (md.loc[idx], task.final_vec.loc[idx])
    TSDs[idx] = TimeSeriesDataset(ts.loc[idx], metadata=run_metadata)

# Wrap the mapping so that iterating the dataset goes through IterItems.
DS = IterItems(MappingDataset(TSDs))

# Set up Encoder

In [4]:
from tsdm.encoders.modular import *

# Encoding pipeline for the time-series frames. The stages are chained in the
# order given: TensorEncoder, DataFrameEncoder (FloatEncoder for the columns,
# MinMaxScaler @ DateTimeEncoder for the index), Standardizer.
encoder = ChainedEncoder(
    TensorEncoder(),
    DataFrameEncoder(
        FloatEncoder(),
        index_encoders=MinMaxScaler() @ DateTimeEncoder(),
    ),
    Standardizer(),
)
# Fit on the training split rather than on the held-out test split (`ts`), so
# that nothing the encoders learn in `fit` comes from the data being evaluated.
encoder.fit(ts_train.reset_index([0, 1], drop=True))
# Column position of the target within the task's time-series frame.
task.target_idx = task.timeseries.columns.get_loc(task.target)
# Encoder for the target alone: it reuses the target's entry of the last
# pipeline stage (`encoder[-1][task.target_idx]`).
target_encoder = TensorEncoder() @ FloatEncoder() @ encoder[-1][task.target_idx]

## Define collate_fn

In [5]:
from typing import NamedTuple
from torch import Tensor
import torch

class Batch(NamedTuple):
    """Container for one collated batch.

    Fields are stored positionally in the order declared below and can be
    accessed by name. The annotations say ``Tensor``; NamedTuple does not
    enforce them at runtime.
    """

    # Identifiers of the samples in the batch.
    index: Tensor
    # Encoded time series, one entry per sample.
    timeseries: Tensor
    # Per-sample metadata.
    metadata: Tensor
    # Raw (unencoded) targets.
    targets: Tensor
    # Targets after encoding.
    encoded_targets: Tensor

    def __repr__(self):
        """Format the batch with ``repr_mapping``, rendering fields via ``repr_array``."""
        # Imported locally because no cell brings these helpers into the
        # notebook namespace by name; without this, displaying a batch fails
        # with a NameError instead of showing its contents.
        from tsdm.util.strings import repr_array, repr_mapping

        return repr_mapping(
            self._asdict(), title=self.__class__.__name__, repr_fun=repr_array
        )

def get_data(dataset, sample):
    """Build the model input for one sample.

    Parameters
    ----------
    dataset
        Object indexable as ``dataset[key, start:stop]``; the lookup must
        return a pair whose second element is a 2-tuple starting with the
        time-series frame.
    sample
        Nested pair ``((key, slice), (timeseries, (metadata, target)))``.

    Returns
    -------
    tuple
        ``(data, target)``: a copy of the frame from the slice start up to the
        target's index label, with the ``observables`` columns set to NaN
        from the slice stop through that label, and the unchanged target.
    """
    (key, window), (_series, (_meta, target)) = sample
    horizon_end = target.index.item()
    # Evaluated as in the original; the value itself is not used afterwards.
    target.item()
    window_start = window.start
    window_end = window.stop
    _, (frame, _frame_meta) = dataset[key, window_start:horizon_end]
    # Work on a copy so the frame held by the dataset is left untouched.
    masked = frame.copy()
    masked.loc[window_end:horizon_end, observables] = float("nan")
    return masked, target

@torch.no_grad()
def mycollate(batch: list, dataset=DS):
    """Collate a list of raw samples into a ``Batch``.

    Parameters
    ----------
    batch
        Samples as yielded by the data loader; each one is passed to
        ``get_data`` together with ``dataset``.
    dataset
        Dataset that ``get_data`` indexes to fetch the time series. Defaults
        to the module-level ``DS`` (bound when this function is defined).

    Returns
    -------
    Batch
        ``timeseries`` holds one encoded, masked frame per sample and
        ``targets`` the matching raw targets. ``index``, ``metadata`` and
        ``encoded_targets`` are still returned as empty lists.
    """
    index = []
    timeseries = []
    metadata = []
    targets = []
    encoded_targets = []

    for sample in batch:
        data, target = get_data(dataset, sample)
        # Encode the masked time series prepared by get_data. The previous
        # code passed `target` to the time-series encoder and dropped `data`.
        timeseries.append(encoder.encode(data))
        targets.append(target)
        # encoded_targets.append(target_encoder.encode(target))

    # TODO: stacking/concatenation of the collected fields is not enabled yet.
    # index = torch.stack(index)
    # targets = pandas.concat(targets)
    # encoded_targets = torch.concat(encoded_targets)

    return Batch(index, timeseries, metadata, targets, encoded_targets)

## Get DataLoaders

In [6]:
# TRAINLOADER = task.get_dataloader(
#     (SPLIT, "train"),
#     batch_size=BATCH_SIZE,
#     collate_fn=mycollate,
#     # pin_memory=True,
#     drop_last=True,
#     shuffle=True,
#     # num_workers=6,
#     # num_workers=os.cpu_count() // 4,
#     # persistent_workers=True,
# )


# Evaluation loader: unshuffled and keeping the last partial batch.
# It reads the "test" split because `mycollate` looks samples up in `DS`,
# which is built from the test split; loading "train" here would query that
# dataset with keys taken from a different split.
EVALLOADER = task.get_dataloader(
    (SPLIT, "test"),
    batch_size=BATCH_SIZE,
    collate_fn=mycollate,
    # pin_memory=True,
    drop_last=False,
    shuffle=False,
    # num_workers=6,
    # num_workers=os.cpu_count() // 4,
    # persistent_workers=True,
)

In [7]:
# Fetch the first batch from EVALLOADER; as the cell's last expression it is
# also displayed.
next(iter(EVALLOADER))

In [22]:
# Iterate the sampler and keep a handle on the dataset of the loader that is
# actually defined in this notebook. `TRAINLOADER` only exists in
# commented-out code, so referring to it fails on a fresh kernel.
isamp = iter(EVALLOADER.sampler)
dset = EVALLOADER.dataset

In [90]:
# Draw a sample index inside this cell so it does not depend on a later cell
# having been executed first.
idx = next(isamp)
key, slc = idx
start_time, stop_time = slc.start, slc.stop
# Index label of the last metadata entry of this sample.
target_time = dset[idx][1].metadata[-1].index.item()
# Fetch the series again, now from the window start up to that label.
_, (t, m) = dset[key, start_time:target_time]
data = t.copy()
# Blank out the observables from the window stop through the target label.
data.loc[stop_time:target_time, observables] = float("nan")

In [92]:
# Display the frame whose observables were set to NaN in the previous cell.
data

In [58]:
# Draw the next index from the sampler iterator.
idx = next(isamp)




# Load Model

In [6]:
import torch
from torch import jit
from torchinfo import summary

# Location of the serialized TorchScript checkpoint: directory prefix plus
# the run-specific file name.
PATH = "checkpoints/LinODEnet/KIWI_RUNS/"
NAME = "SurrLoss+Sequential_Filter/2021-12-20T13:43:00/LinODEnet-70"

# Load the scripted model onto the CPU, then show a depth-1 layer summary.
model = jit.load(PATH + NAME, map_location=torch.device("cpu"))
summary(model, depth=1)