# Hands on KiWi 2022-10-06

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%matplotlib inline

import datetime
import warnings
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from torchinfo import summary

# Silence UserWarning messages attributed to modules whose name starts with "tsdm".
warnings.filterwarnings(action="ignore", category=UserWarning, module="tsdm")

# Load **T**ime **S**eries **D**atasets & **M**odels (`tsdm`)

- Source: https://git.tu-berlin.de/bvt-htbd/kiwi/tf1/tsdm
- Documentation: https://bvt-htbd.gitlab-pages.tu-berlin.de/kiwi/tf1/tsdm/


### Installation

```bash
pip install tsdm --extra-index-url https://__token__:<your_personal_token>@git.tu-berlin.de/api/v4/projects/6694/packages/pypi/simple
```

**Note: This is all considered alpha software; the API might change between versions!**

In [None]:
import tsdm

## Time Series Datasets

### A **time series** is a tuple $D = (\mathbf{TS}, M)$

- Time-indexed data $\mathbf{TS}=\{(t_n, v_n) \mid n=1,\ldots,N\}$
  - timestamps $t\in\mathcal{T}$, values $v\in\mathcal{V}$
- Time-independent metadata $M\in\mathcal{M}$


### A **time series collection** is a tuple $C = (I,S,G)$ consisting of

- Index $I\subseteq\mathcal{I}$ (of IDs)
- Collection of time series $S = \{D_i=(\mathbf{TS}_i, M_i) \mid i\in\mathcal{I},\ \mathbf{TS}_i\in(\mathcal{T}\times\mathcal{V})_i^*,\ M_i\in\mathcal{M}_i\}$
- Index-independent global data $G\in\mathcal{G}$
- If $\mathcal{T}_i=\mathcal{T}$, $\mathcal{V}_i=\mathcal{V}$ and $\mathcal{M}_i=\mathcal{M}$ for all $i$ ⟶ **equimodal**

## Examples

1. Clinical data:

    - index $\mathcal{I}$: patient / admission id
    - metadata $M_i$: patient metadata (age, sex, preconditions, ...)
    - values $V_i$: measured data (heart rate, blood pressure, etc.) 
    - globals $G$: units of measurement, measurement devices used, etc.
    
2. Bioreactor data

    - index $\mathcal{I}$: experiment / run id
    - metadata $M_i$: bacterial strain used, reactor size, reactor type
    - values $V_i$: measured data (O₂-, Glucose-, Acetate- concentration, stirring speed)
    - globals $G$: units of measurement, measurement devices used, etc. 

# Load the dataset

In [None]:
from tsdm.datasets import KIWI_RUNS

# Instantiate the KIWI_RUNS dataset object; its `timeseries`, `metadata` and
# `units` attributes are read in the following cells.
dataset = KIWI_RUNS()

In [None]:
# Short handle on the dataset's `timeseries` attribute; it is indexed with
# (run_id, experiment_id) keys further below.
ts = dataset.timeseries

In [None]:
# Short handle on the dataset's `metadata` attribute.
md = dataset.metadata

In [None]:
# Display the dataset's `units` attribute.
dataset.units

# Load the model & encoder

Preliminary API; ideally, this should be replaced with a database lookup.

- Model depends both on dataset and task, in particular the fold.
- Encoder depends both on model and dataset.
- ⇝ tag-based lookup?: model, dataset, fold, epoch, hyperparameters

In [None]:
from tsdm.models.pretrained import LinODEnet

# Instantiate the LinODEnet model exposed by `tsdm.models.pretrained`.
# NOTE(review): presumably this loads pretrained weights — confirm against the tsdm docs.
model = LinODEnet()
# Display torchinfo's summary of the model.
summary(model)

## Make a prediction

In [None]:
# Display the time series table; a (run_id, experiment_id) key is selected from it in the next cell.
ts

In [None]:
# Pick a single experiment by its (run_id, experiment_id) key.
run_id, experiment_id = 510, 16871
key = (run_id, experiment_id)

# Cast that experiment's rows to float; the explicit copy keeps `s` independent of `ts`.
s = ts.loc[key].astype(float).copy()
s

In [None]:
# Display the table with index levels 0 and 1 dropped (the same call is applied
# to the training split before fitting the encoder below).
ts.reset_index(level=[0, 1], drop=True)

In [None]:
# Explicit imports instead of `from tsdm.encoders import *`, so every name used
# below can be traced to its origin and nothing else leaks into the namespace.
from tsdm.encoders import (
    FloatEncoder,
    Frame2Tensor,
    FrameEncoder,
    MinMaxScaler,
    Standardizer,
    TimeDeltaEncoder,
)
from tsdm.tasks import KIWI_RUNS_TASK

# Training split stored under the key (0, "train") — presumably fold 0; confirm
# against the task definition. `train_md` is not used further in this notebook.
train_ts, train_md = KIWI_RUNS_TASK().splits[(0, "train")]

# Encoder pipeline composed with `@`: a FrameEncoder (Standardizer/FloatEncoder
# for the columns, MinMaxScaler/TimeDeltaEncoder for the index) combined with
# Frame2Tensor. Its `encode` result is unpacked as a (T, X) pair further below.
encoder = Frame2Tensor() @ FrameEncoder(
    Standardizer() @ FloatEncoder(), index_encoders=MinMaxScaler() @ TimeDeltaEncoder()
)

# Fit on the training split with index levels 0 and 1 dropped.
# NOTE(review): the entry labelled 355 is excluded before fitting — presumably a
# faulty run; document the reason or confirm it is still required.
encoder.fit(train_ts.drop([355]).reset_index(level=[0, 1], drop=True))

In [None]:
# Columns that are blanked out inside the forecasting window below.
observables = [
    "Base",
    "DOT",
    "Glucose",
    "OD600",
    "Acetate",
    "Fluo_GFP",
    "Volume",
    "pH",
]
# Remaining columns; they are left untouched over the whole window.
controls = [
    "Cumulated_feed_volume_glucose",
    "Cumulated_feed_volume_medium",
    "InducerConcentration",
    "StirringSpeed",
    "Flow_Air",
    "Temperature",
    "Probe_Volume",
]

# Time windows expressed as label slices: 3h-9h overall, split at 6h.
total_horizon = slice(np.timedelta64(3, "h"), np.timedelta64(9, "h"))
observation_horizon = slice(np.timedelta64(3, "h"), np.timedelta64(6, "h"))
forecasting_horizon = slice(np.timedelta64(6, "h"), np.timedelta64(9, "h"))

# Work on a copy of the selected window so that `s` itself is not modified.
inputs = s.loc[total_horizon].copy()
# Replace the observables with NaN inside the forecasting window.
inputs.loc[forecasting_horizon, observables] = np.nan
inputs

In [None]:
# Encode the masked input frame with the fitted encoder; the same call is
# unpacked as a (T, X) pair in the next cell.
encoded = encoder.encode(inputs)
# decoded = encoder.decode(encoded)

In [None]:
# Encode the masked inputs into the (T, X) pair the model is called with.
T, X = encoder.encode(inputs)

# Pure inference: no gradients are needed, so do not record an autograd graph.
with torch.no_grad():
    Xhat = model(T, X)

# Map the model output back through the encoder, paired with the same T.
predictions = encoder.decode((T, Xhat))

In [None]:
def get_predictions(key, t_start, t_mid, t_stop):
    """Run the model on one experiment with the observables masked after ``t_mid``.

    Relies on the notebook-level names ``ts``, ``observables``, ``encoder`` and
    ``model`` defined in earlier cells.

    Parameters
    ----------
    key
        Label passed to ``ts.loc`` to select one experiment, e.g. a
        ``(run_id, experiment_id)`` tuple.
    t_start, t_mid, t_stop
        Window boundaries used as ``.loc`` slice labels (the notebook passes
        ``np.timedelta64`` values). Rows selected by ``t_start:t_stop`` are fed
        to the model; within ``t_mid:t_stop`` the ``observables`` columns are
        replaced by NaN before encoding.

    Returns
    -------
    tuple
        ``(observations, predictions, Xhat)`` where ``observations`` is the
        whole selected experiment cast to float (not restricted to the window),
        ``predictions`` is ``encoder.decode((T, Xhat))`` and ``Xhat`` is the raw
        model output.
    """
    observations = ts.loc[key].astype(float).copy()

    # Copy the window so that masking does not write into `observations`.
    inputs = observations.loc[t_start:t_stop].copy()
    # NOTE(review): pandas label slices include both end points, so a row stamped
    # exactly `t_mid` is masked as well — confirm this is intended.
    inputs.loc[t_mid:t_stop, observables] = float("nan")

    T, X = encoder.encode(inputs)
    # Pure inference: no gradients are needed, so do not record an autograd graph.
    with torch.no_grad():
        Xhat = model(T, X)
    predictions = encoder.decode((T, Xhat))

    return observations, predictions, Xhat

# Make a Prediction with the model

In [None]:
%matplotlib inline

fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)

target = "DOT"  # "DOT" "Base", "DOT", "Glucose", "OD600", "Acetate", "Fluo_GFP", "pH"

key = run_id, exp_id = (
    474,
    16120,
)  # (474, 16120) (510, 16871)  (445, 15527)  (474, 16120) (484, 16346) (449, 15653)
h = np.timedelta64(1, "h")
t_start = np.timedelta64(1, "h")
t_mid = np.timedelta64(8, "h")
t_stop = t_mid + np.timedelta64(2, "h")

observations, predictions, xhat = get_predictions(
    (run_id, exp_id), t_start, t_mid, t_stop
)

Xhat_observation = predictions[t_start:t_mid]
Xhat_forecasting = predictions[t_mid:t_stop]
# T = observations.index.to_numpy() / np.timedelta64(1, "h")
# T_observation = Xhat_observation.index / np.timedelta64(1, "h")
# T_forecasting = Xhat_forecasting.index / np.timedelta64(1, "h")

ax.plot(observations.index / h, observations[target], ".b", label=f"{target} observed")
ax.plot(
    Xhat_observation.index / h,
    Xhat_observation[target],
    ":r",
    label=f"{target} estimated",
)
ax.plot(
    Xhat_forecasting.index / h,
    Xhat_forecasting[target],
    "-r",
    label=f"{target} estimated",
)
ax.axvspan(t_start, t_mid, facecolor="grey", alpha=0.3)
ax.axvspan(t_mid, t_stop, facecolor="green", alpha=0.3)
# ax.set_ylim(observations[target].min(), observations[target].max())
ax.legend()

## Glitches in the data

![Screenshot from 2022-10-06 13-32-35.png](attachment:8c8a07d9-2bbd-4b08-92a7-bdf2563ec021.png)
![Screenshot from 2022-10-06 13-33-03.png](attachment:2aa14f8c-5cee-4d3e-8fc8-37f34294f98d.png)
![Screenshot from 2022-10-06 13-33-17.png](attachment:284ba9a9-ab3c-4013-8878-f9311a5e7ae1.png)
![Screenshot from 2022-10-06 13-33-37.png](attachment:f71a0545-1f54-4959-8bb3-555dcbec09a2.png)

## What it should look like

![long_forecast.png](attachment:9290594d-d1e3-4942-b39f-49f6222d0e97.png)