# Optimizing performance by using TorchScript to JIT-compile the ODE model

We follow the techniques described in the PyTorch blog post at https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

In [None]:
# Render inline matplotlib figures in high-DPI ("retina") format.
%config InlineBackend.figure_format = 'retina'

In [None]:
from typing import Callable, Final, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas
import torch
import torchdiffeq
from opt_einsum import contract
from scipy import stats
from scipy.integrate import odeint
from torch import Tensor, jit, nn, optim
from torch.nn import GRUCell
from tqdm.auto import trange

In [None]:
import tsdm
from tsdm.util import deep_dict_update

from linodenet import init
from linodenet.models import (
    LinearContraction,
    LinODE,
    LinODECell,
    LinODEnet,
    iResNet,
    iResNetBlock,
)

In [None]:
from pandas import DataFrame, DatetimeIndex, Series, Timedelta, Timestamp
from pandas.tseries.offsets import DateOffset

In [None]:
# Load the electricity dataset (presumably a DataFrame with a DatetimeIndex — the
# index is differenced into timedeltas below).
df = tsdm.load_dataset("electricity")
# Time differences between consecutive index entries.
ΔT = np.diff(df.index)
# Sampling interval: the first step, converted to minute resolution.
Δt = ΔT[0].astype("timedelta64[m]")
# The series must be regularly sampled: every step has to equal Δt.
assert np.all(ΔT == Δt)
# NOTE(review): N and M are read BEFORE the first year is dropped below, so N is the
# row count of the untrimmed data (M, the column count, is unaffected by the trim).
N, M = df.shape
# remove first year from the data (useless zeros)
# span = (number of Δt-steps in 365 days) - 1, so the new first row lies exactly one
# step before index[0] + 365 days — TODO confirm the `- 1` is intentional.
span = np.timedelta64(365, "D") // Δt - 1
df = df.iloc[span:]

In [None]:
# Shorthand for the index of the (trimmed) DataFrame.
time = df.index

# Show the first and last timestamps of the retained data.
time[0], time[-1]

In [None]:
# Offset between the first retained timestamp and the 2014-03-31 candidate split date.
df.index[0] - Timestamp("2014-03-31")

### Train/test split
For details, see the N-BEATS paper.

In [None]:
# Candidate split points (presumably the alternative conventions referenced in the
# N-BEATS paper). Only the last one — a 7-day hold-out at the end of the data — is
# used below.
split_dates = [
    Timestamp("2014-09-01"),
    Timestamp("2014-03-31"),
    df.index[-1] - DateOffset(days=7),
]
assert Series(split_dates).isin(df.index).all()
split = split_dates[-1]

# `.loc` label slicing is inclusive on BOTH ends, so `df.loc[:split]` and
# `df.loc[split:]` would both contain the row at `split`, leaking a test sample
# into the training data. Use a half-open split instead: train < split <= test.
X_TRAIN = df.loc[df.index < split]
X_TEST = df.loc[df.index >= split]

### Pre-processing

In [None]:
# Option 1: normalization

In [None]:
# Option 2: aggregation — resample both splits to hourly resolution.
# Both splits must use the SAME aggregation; mixing `.sum()` and `.mean()` puts
# them on different scales. `.mean()` is used because the split point
# (`df.index[-1] - 7 days`) need not fall on an hour boundary, and a partially
# filled boundary bin would be biased low by `.sum()`.
X_TRAIN_HOURLY = X_TRAIN.resample("1H").mean()
X_TEST_HOURLY = X_TEST.resample("1H").mean()
X_TRAIN_HOURLY, X_TEST_HOURLY

In [None]:
# Inspect a coarser (2-hour) aggregation of the training split.
# NOTE(review): `T_TRAIN` is not defined in any earlier cell; the training frame is
# `X_TRAIN`. `.resample()` alone only returns a lazy resampler object, so an
# aggregation (matching the hourly cell) is applied to compute the values.
X_TRAIN.resample("2H").mean()

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Floating-point precision used for tensors in this notebook.
dtype = torch.float32

In [None]:
# Convert the training split into a tensor with the notebook-wide dtype, passing
# the underlying NumPy array rather than the DataFrame itself.
# NOTE(review): the original referenced `T_TRAIN`, which is not defined in any
# earlier cell; the training frame is `X_TRAIN`.
torch.tensor(X_TRAIN.to_numpy(), dtype=dtype)

In [None]:
# Last timestamp of the data, shown together with the point one calendar month earlier.
final_time = time[-1]
final_time, final_time - DateOffset(months=1)

In [None]:
# Whole dataset as a tensor in the notebook-wide dtype (float32). Without an
# explicit dtype the tensor inherits the frame's NumPy dtype (typically float64),
# which mismatches PyTorch's default float32 model parameters.
X = torch.tensor(df.to_numpy(), dtype=dtype)

In [None]:
# NOTE(review): not used in the visible cells — presumably a sequence/window
# length; TODO confirm.
LEN = 100

In [None]:
# Model sized to the data: presumably one input feature per column of `df`, so the
# input size is derived from the data instead of hard-coding 370.
# NOTE(review): `LinODE_RNN` is not imported from `linodenet.models` nor defined in
# any earlier cell — TODO import/define it before running this cell.
model = LinODE_RNN(input_size=df.shape[1], hidden_size=400)
optimizer = optim.Adamax(model.parameters(), lr=0.001)

In [None]:
# Training loop: 1000 optimizer steps with a tqdm progress bar.
# NOTE(review): this loop is incomplete. `train_res` is not defined in any earlier
# cell and nothing inside the loop recomputes it. A forward pass producing a dict
# with (at least) a "loss" tensor must run after `zero_grad()` on every iteration;
# otherwise `backward()` would reuse one autograd graph, which PyTorch frees after
# the first backward pass (unless retain_graph=True). TODO add the forward pass.
for n in (pbar := trange(1000)):
    optimizer.zero_grad()

    # Show every entry of the result dict as a float in the progress bar.
    pbar.set_postfix({key: float(val) for key, val in train_res.items()})
    # Backpropagate the loss and apply one optimizer update.
    train_res["loss"].backward()
    optimizer.step()

In [None]:
# Experiment: multiply a standard-normal vector by the identity plus i.i.d.
# Gaussian noise (standard deviation 1/n) and report how its mean and standard
# deviation change, for five independent noise draws.
n = 1000
x = np.random.randn(n)
for k in range(5):
    A = np.random.normal(loc=0, scale=1 / n, size=(n, n))
    A += np.eye(n)  # turn the pure noise matrix into identity + noise
    y = A @ x
    print(f"{y.mean():.6f}  {y.std():.6f}")

In [None]:
# Four-axis shape with every axis of length 5, plus a standalone size constant of 5.
SHAPE = (5,) * 4
DIM = 5