In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline


import pickle

import numpy as np
import pandas as pd
import torch
from pandas import DataFrame, Index, MultiIndex
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torchinfo import summary

In [None]:
from tsdm.models.pretrained import LinODEnet

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the pretrained LinODEnet bundle on DEVICE.
pretrained = LinODEnet(device=DEVICE)

# Load the Model

In [None]:
# Take the network out of the pretrained bundle and print its torchinfo
# summary with depth=2.
model = pretrained["model"]
summary(model, depth=2)

# Load the Encoder

In [None]:
# Encoder shipped with the pretrained bundle; the prediction helpers call its
# .encode() / .decode() methods.
encoder = pretrained["encoder"]

In [None]:
# Column labels taken from the `column_encoders` of the encoder's last element.
# NOTE(review): presumably these are the channels the model was trained on and
# their order matters — confirm against the encoder definition.
USED_COLUMNS = Index(encoder[-1].column_encoders)

# Load the Optimizer

In [None]:
# Optimizer stored with the pretrained bundle (not used by the cells visible here).
optimizer = pretrained["optimizer"]

# Load the pickled data

In [None]:
def make_dataframes_from_pickle(
    filename: str,
) -> tuple[DataFrame, DataFrame, DataFrame]:
    """Load a pickled experiment collection and stack it into three DataFrames.

    The pickle is expected to hold a mapping ``experiment_id -> tables`` where
    ``tables`` is itself a mapping from table name to DataFrame.  For each of
    the three table names the per-experiment frames are concatenated, with the
    experiment key becoming an outer index level named ``"experiment_id"``.

    Parameters
    ----------
    filename:
        Path of the pickle file to read.

    Returns
    -------
    tuple[DataFrame, DataFrame, DataFrame]
        ``(timeseries, metadata, setpoints)`` built from the tables
        ``"measurements_aggregated"``, ``"metadata"`` and ``"setpoints"``.

    Raises
    ------
    KeyError
        If an experiment lacks one of the three tables.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(filename, "rb") as file:
        data = pickle.load(file)

    def stack(table_name: str) -> DataFrame:
        # One frame per experiment, stacked under an "experiment_id" level.
        per_experiment = {key: tables[table_name] for key, tables in data.items()}
        return pd.concat(per_experiment, names=["experiment_id"])

    timeseries = stack("measurements_aggregated")
    # NOTE(review): previously all three outputs were read from
    # "measurements_aggregated" (copy-paste slip).  The table names below are
    # assumed from the return names — confirm against the pickle's schema.
    metadata = stack("metadata")
    setpoints = stack("setpoints")

    return timeseries, metadata, setpoints

In [None]:
# Load timeseries (TS), metadata (MD) and setpoints (SP); the trailing `;`
# keeps the assignment from being echoed (ast_node_interactivity is
# 'last_expr_or_assign').
TS, MD, SP = make_dataframes_from_pickle("example_510.pk");

## Clean the timeseries

In [None]:
def clean_timeseries(ts: DataFrame, columns: "Index | None" = None) -> DataFrame:
    """Align a timeseries frame with the expected columns and time dtype.

    The result has exactly the requested columns, in the requested order:
    columns not requested are dropped and requested columns that are absent
    are added filled with NaN.  The ``measurement_time`` index level is
    converted to ``timedelta64`` (interpreting numeric values as seconds)
    unless it already is one, so the function can safely be applied twice.
    The input frame is not modified.

    Parameters
    ----------
    ts:
        Frame whose index contains a level named ``"measurement_time"``.
    columns:
        Column labels to keep, in output order.  Defaults to the notebook
        global ``USED_COLUMNS``.

    Returns
    -------
    DataFrame
        A new frame with ``measurement_time`` as the innermost index level.
    """
    wanted = USED_COLUMNS if columns is None else Index(columns)
    present = ts.columns
    drop_columns = list(present.difference(wanted))
    miss_columns = list(wanted.difference(present))

    print(f">>> Dropping columns {drop_columns}")
    print(f">>> Adding columns {miss_columns}")
    # reindex drops, adds (as NaN) and orders the columns in one step and
    # returns a new frame, so nothing is written into a slice of the input.
    ts = ts.reindex(columns=wanted)

    # Fix the timestamp type.
    ts = ts.reset_index("measurement_time")
    # A proper dtype test; the previous comparison against a misspelt dtype
    # string never recognised an existing timedelta column.
    if not pd.api.types.is_timedelta64_dtype(ts["measurement_time"]):
        print(">>> Converting float (seconds) to timedelta64")
        ts["measurement_time"] = pd.to_timedelta(ts["measurement_time"], unit="s")
    return ts.set_index(["measurement_time"], append=True)

In [None]:
# Replace TS with its cleaned version (this rebinds the name, so the raw
# frame is only recoverable by re-running the loading cell).
TS = clean_timeseries(TS)

# Get Predictions with a Loop - Slow

In [None]:
@torch.no_grad()
def get_predictions(ts: DataFrame):
    """Run the model on every series in ``ts``, one series at a time.

    If ``ts`` carries a MultiIndex, all levels except the last identify a
    series: the function calls itself on ``ts.loc[key]`` for each unique key
    (in order of appearance) and concatenates the results under those level
    names.  Otherwise ``ts`` is encoded with the global ``encoder``, moved to
    ``DEVICE``, passed through the global ``model`` and the prediction is
    decoded again.  Gradient tracking is disabled for the whole call.

    Returns whatever ``encoder.decode`` produces for a single series, or the
    ``pd.concat`` of those objects for a MultiIndex input.
    """
    if not isinstance(ts.index, MultiIndex):
        # Single series: encode -> device -> model -> decode.
        times, values = encoder.encode(ts).values()
        times = times.to(device=DEVICE)
        values = values.to(device=DEVICE)
        return encoder.decode({"T": times, "X": model(times, values)})

    outer_names = ts.index.names[:-1]
    per_series = {}
    for series_key in ts.index.droplevel(-1).unique():
        per_series[series_key] = get_predictions(ts.loc[series_key])
    return pd.concat(per_series, names=outer_names)


# Predict series by series; the assignment is echoed as cell output because
# ast_node_interactivity is 'last_expr_or_assign'.
preds = get_predictions(TS)

## Predict in Batch Mode - Way Faster!

In [None]:
@torch.no_grad()
def get_predictions_batch(ts: DataFrame) -> DataFrame:
    """Run the model on all series in ``ts`` in a single padded batch.

    For a MultiIndex frame, all levels except the last identify a series.
    The whole frame is encoded once with the global ``encoder``, the encoded
    tensors are split per series, NaN-padded to a common length and sent
    through the global ``model`` in one call.  Each prediction is then cut
    back to its true length, decoded, and the results are concatenated under
    the series level names.  Frames without a MultiIndex are delegated to
    ``get_predictions``.  Gradient tracking is disabled for the whole call.

    Parameters
    ----------
    ts:
        Timeseries frame.  NOTE(review): rows of one series are assumed to be
        contiguous, since the encoded tensors are split by consecutive run
        lengths — confirm for inputs not built with ``pd.concat``.

    Returns
    -------
    DataFrame
        Result of ``pd.concat`` over the decoded predictions, keyed by series
        (assuming ``encoder.decode`` yields pandas objects — TODO confirm).
    """
    if not isinstance(ts.index, MultiIndex):
        return get_predictions(ts)

    names = ts.index.names[:-1]
    # sort=False keeps the groups in order of first appearance, i.e. the row
    # order of `ts` and hence of the encoded tensors.  With the default
    # sort=True the sizes (and keys) are ordered by sorted key and would be
    # paired with the wrong rows whenever the frame is not sorted by key.
    sizes = ts.groupby(names, sort=False).size()
    lengths = sizes.to_list()

    T, X = encoder.encode(ts).values()
    T = T.to(device=DEVICE)
    X = X.to(device=DEVICE)
    # Split the flat tensors per series and pad to (batch, max_len, ...).
    T = pad_sequence(torch.split(T, lengths), batch_first=True, padding_value=torch.nan)
    X = pad_sequence(torch.split(X, lengths), batch_first=True, padding_value=torch.nan)

    XHAT = model(T, X)

    # Strip the padding again before decoding each series.
    decoded = {
        key: encoder.decode({"T": t[:size], "X": xhat[:size]})
        for key, t, xhat, size in zip(sizes.index, T, XHAT, lengths)
    }
    return pd.concat(decoded, names=names)


# Batched prediction over all series; the result is displayed as the cell output.
get_predictions_batch(TS)