# Optimizing performance by using TorchScript to JIT-compile the ODE model

We follow the techniques described in the PyTorch blog post on optimizing CUDA RNNs with TorchScript: https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

In [1]:
# Render inline matplotlib figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'

In [2]:
import pandas
import torch
import torchdiffeq
from torch import Tensor, nn, jit, optim
from torch.nn import GRUCell
import numpy as np
from opt_einsum import contract
from tqdm.auto import trange
from typing import Union, Callable
from scipy import stats
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from typing import Final

In [3]:
import tsdm
from tsdm.util import deep_dict_update
from linodenet.models import (
    LinODEnet,
    LinODE,
    LinODECell,
    iResNet,
    iResNetBlock,
    LinearContraction,
)
from linodenet import init

In [4]:
# Toy problem used by the demo cells below:
#   b  - number of observations / time stamps (rows of X, length of T)
#   n  - not used by the visible cells (it is rebound later in the notebook)
#   k  - number of features per observation (columns of X)
#   l  - output size passed to LinearContraction(k, l)
torch.manual_seed(0)  # make the toy data reproducible under Restart & Run All
b, n, k, l = 20, 5, 7, 3
X = torch.randn(b, k)
# Sort the random time stamps so observations are in chronological order
# and every time step in ΔT is non-negative.
T = torch.sort(torch.randn(b)).values
ΔT = torch.diff(T)

In [5]:
# LinearContraction constructed with sizes (k, l) — presumably (in_features, out_features); confirm.
# Display the output shape on the toy batch X of shape (b, k).
model = LinearContraction(k, l)
model(X).shape

In [6]:
# A single iResNetBlock on k features; display its output shape for the toy batch X.
model = iResNetBlock(k)
model(X).shape

In [11]:
# iResNet on k features. Besides the output shape, check invertibility:
# model(model.inverse(X)) should reproduce X, so the norm below is presumably
# close to 0 when the inverse is accurate.
model = iResNet(k)
reconstruction_error = torch.linalg.norm(model(model.inverse(X)) - X)
# Show both results; previously `model(X).shape` was a non-final expression and
# its value was silently discarded.
model(X).shape, reconstruction_error

In [8]:
# Build a GRU input for the first observation: its values concatenated with
# its missingness mask, which is why the GRUCell input size is 2 * k.
model = GRUCell(2 * k, k)
x = X[[0]]  # list index keeps a leading batch dimension: shape (1, k)
mask = torch.isnan(x)
# torch.isnan returns a bool tensor. Cast it to x's dtype explicitly rather than
# relying on torch.cat type promotion, and replace missing values by 0 (as done
# for x0 elsewhere in this notebook) so NaNs never reach the GRU.
values = torch.where(mask, torch.zeros_like(x), x)
c = torch.cat([values, mask.to(x.dtype)], dim=-1)
x.shape, c.shape

In [9]:
# LinODECell on k features, called with one time step ΔT[0] and one
# observation X[0] (shape (k,)); displays the resulting tensor.
model = LinODECell(k)
model(ΔT[0], X[0])

In [14]:
# LinODE on k features, called with X[0] and all b - 1 time steps ΔT —
# presumably propagating X[0] as the initial state through each step; confirm.
model = LinODE(k)
model(X[0], ΔT).shape

In [13]:
# Full LinODEnet constructed with (k, 2 * k) — presumably (input size, hidden size);
# applied to time stamps T and observations X, displaying the output shape.
model = LinODEnet(k, 2 * k)
model(T, X).shape

In [11]:
# Replace missing values of the first observation by 0, presumably to obtain an
# initial state x0. zeros_like keeps x0 in X's dtype and device.
x0 = torch.where(torch.isnan(X[0]), torch.zeros_like(X[0]), X[0])
# Display all shapes at once; a non-final `X.shape, T.shape` would be discarded.
X.shape, T.shape, x0.shape

In [12]:
# FIXME(review): `LinODE_RNN` is neither defined nor imported anywhere in the
# visible notebook, so this cell raises NameError on a fresh kernel.
# NOTE(review): if input_size is meant to be the feature count of X, note that
# the toy X above has k = 7 features, not 10 — confirm the intended sizes.
model = LinODE_RNN(input_size=10, hidden_size=20)
model(T, X)

In [69]:
# Load the electricity dataset via tsdm and drop its first year.
# NOTE(review): this rebinds ΔT (previously a torch tensor of toy time steps)
# to a numpy array of timedeltas.
df = tsdm.load_dataset("electricity")
# Time steps between consecutive index entries.
ΔT = np.diff(df.index)
Δt = ΔT[0].astype("timedelta64[m]")
# The series must lie on a regular time grid: every step equals the first one.
assert np.all(ΔT == Δt)
N, M = df.shape  # N rows (time stamps), M columns — presumably one per series; confirm
# remove first year from the data (useless zeros)
# span = (number of grid steps in 365 days) - 1.
# NOTE(review): the `- 1` keeps the last row of the first 365 days; presumably
# this aligns the new start with a day boundary (e.g. if the index starts at
# 00:15 rather than 00:00) — confirm intent.
span = np.timedelta64(365, "D") // Δt - 1
df = df.iloc[span:]

In [70]:
# Time index of the trimmed dataset; show its first and last time stamps.
time = df.index

time[0], time[-1]

In [75]:
# pandas version, recorded for reproducibility of the date handling below.
pandas.__version__

In [78]:
from pandas import DataFrame, Timestamp, Timedelta
from pandas.tseries.offsets import DateOffset

In [79]:
# Time difference between the first retained time stamp and 2014-03-31.
df.index[0] - Timestamp("2014-03-31")

In [None]:
# Train/test split dates for the electricity benchmark (see N-BEATS paper):
#   1. 2014-09-01
#   2. 2014-03-31
#   3. one week before the last time stamp, i.e. the final 7 days form the test
#      set (the placeholder "date_7d"). The previous code subtracted 7 days from
#      df.index[0], which lies before the start of the data.
split_dates = [
    Timestamp("2014-09-01"),
    Timestamp("2014-03-31"),
    df.index[-1] - DateOffset(days=7),
]

# One test set per split date.
# NOTE(review): assumes the test set is everything from the split date onward
# (inclusive) — TODO confirm against the intended protocol.
X_TEST = [df.loc[split_date:] for split_date in split_dates]


In [34]:
# Last time stamp of the data, and the time stamp one calendar month before it.
final_time = time[-1]
final_time, final_time - DateOffset(months=1)

In [15]:
# Full dataset as a tensor (rebinds the toy X from the demo cells).
# The DataFrame values are presumably float64; cast to float32, torch's default
# parameter dtype, so that the model's weights and inputs share a dtype.
X = torch.tensor(df.to_numpy(), dtype=torch.float32)

In [16]:
# NOTE(review): LEN is not referenced by any visible cell — presumably an
# intended window / sequence length for training; confirm or remove.
LEN = 100

In [17]:
# FIXME(review): `LinODE_RNN` is neither defined nor imported in the visible
# notebook; this cell raises NameError on a fresh kernel until it is.
# Derive the input size from the data's column count instead of hard-coding it.
model = LinODE_RNN(input_size=X.shape[-1], hidden_size=400)
optimizer = optim.Adamax(model.parameters(), lr=0.001)

In [18]:
# Training loop (incomplete).
# FIXME(review): `train_res` is never assigned in the visible cells, so
# pbar.set_postfix(...) raises NameError on a fresh kernel. A forward pass that
# builds a dict containing at least a scalar "loss" tensor (plus any other
# float-convertible metrics to show in the progress bar) is missing after
# optimizer.zero_grad().
# NOTE(review): the loop variable `n` overwrites `n` from the toy setup cell.
for n in (pbar := trange(1000)):
    optimizer.zero_grad()

    pbar.set_postfix({key: float(val) for key, val in train_res.items()})
    train_res["loss"].backward()
    optimizer.step()

In [135]:
# Numerical check: apply a random perturbation of the identity,
# A = I + E with E_ij ~ N(0, scale**2) and scale = 1/n (numpy's `scale` is the
# standard deviation), to a standard-normal vector x. Print mean and std of
# y = A @ x for several draws of E.
rng = np.random.default_rng(0)  # seeded so the printed numbers are reproducible
n = 1000
x = rng.standard_normal(n)
# Use a dedicated loop variable: reusing `k` would clobber the feature dimension
# defined in the toy setup cell.
for trial in range(5):
    A = np.eye(n) + rng.normal(loc=0, scale=1 / n, size=(n, n))
    y = A @ x
    print(f"{y.mean():.6f}  {y.std():.6f}")

In [144]:
# Constants: a 4-D shape whose every axis has length DIM.
DIM = 5
SHAPE = (DIM,) * 4