# TODO

- Compare the speed and memory usage of batched evaluation against an (asynchronous?) per-sample loop over the batch — the loop also supports variable-length sequences.

In [None]:
# IPython autoreload extension: mode 2 reloads all (non-excluded) modules
# before executing code, so edits to imported packages take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import torch

# Keep DataFrame previews short: show at most 5 rows.
pd.options.display.max_rows = 5

In [None]:
from tsdm.datasets import Electricity

In [None]:
# Full Electricity table, plus a small sample of its first 100 rows.
X = Electricity.dataset
x = X.iloc[:100]
# Timestamps of the sample as float hours elapsed since the first row.
t = (X.index[:100] - X.index[0]) / np.timedelta64(1, "h")
x

In [None]:
# Move the sample onto the GPU as float32 tensors.
device = torch.device("cuda")
dtype = torch.float32
X = torch.tensor(x.values, dtype=dtype, device=device)
T = torch.tensor(t.values, dtype=dtype, device=device)
# Gaps between consecutive time stamps; Δt is the first gap, x0 the first observation.
ΔT = T.diff()
Δt = ΔT[0]
x0 = X[0]
# Fixed-length batches: three stacked copies of each tensor.
T_batch, X_batch, Δt_batch, ΔT_batch, x0_batch = (
    torch.stack([v, v, v]) for v in (T, X, Δt, ΔT, x0)
)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

# Build the cell on the target device/dtype, then print input and output
# shapes for an unbatched and a batched call.
model = LinODECell(370, kernel_regularization="skew-symmetric").to(
    device=device, dtype=dtype
)
for step, state in ((Δt, x0), (Δt_batch, x0_batch)):
    print(step.shape, state.shape)
    print(model(step, state).shape)

In [None]:
# Display the shape of the current batch of step sizes.
Δt_batch.shape

In [None]:
# Larger fixed-length batch (BATCHSIZE identical copies) for the timing cells.
BATCHSIZE = 32
Δt_batch, x0_batch = (torch.stack(BATCHSIZE * [v]) for v in (Δt, x0))

In [None]:
%%timeit
Y = [model(Δt, Δx) for Δt, Δx in zip(Δt_batch, x0_batch)]

In [None]:
%%timeit
Y = []
for Δt, Δx in zip(Δt_batch, x0_batch):
    Y += [model(Δt, Δx)]

In [None]:
%%timeit

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    Y = []

    for Δt, Δx in zip(Δt_batch, x0_batch):
        Y += [model(Δt, Δx)]
        s.wait_stream(torch.cuda.current_stream())

In [None]:
%%timeit
Y = model(Δt_batch, x0_batch)

# Variable-Length Batch

In [None]:
LEN, DIM = X.shape

# Draw BATCHSIZE random prefix lengths in [1, LEN]. `endpoint=True` makes the
# full-length series a possible draw; a fixed seed keeps the batch (and
# therefore the timings below) reproducible across runs.
rng = np.random.default_rng(seed=0)
prefix_lengths = [int(n) for n in rng.integers(1, LEN, size=BATCHSIZE, endpoint=True)]

# Lists of variable-length prefixes of the same series (cannot be stacked).
X_batch = [X[:n] for n in prefix_lengths]
T_batch = [T[:n] for n in prefix_lengths]

print([len(x) for x in X_batch])

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

# Full model; the second positional argument (250) is presumably a
# hidden/latent size (confirm against LinODEnet's signature).
model = LinODEnet(370, 250).to(device=device, dtype=dtype)
# Sanity check on the single full-length series.
print(*(v.shape for v in (T, X)))
print(model(T, X).shape)

In [None]:
from torchinfo import summary

# Layer summary for a batch of BATCHSIZE full-length inputs: one
# (BATCHSIZE, LEN) tensor for the times and one (BATCHSIZE, LEN, DIM) tensor
# for the observations.
batched_input_sizes = [(BATCHSIZE, LEN), (BATCHSIZE, LEN, DIM)]
summary(model, input_size=batched_input_sizes)

In [None]:
%%timeit
Y = [model(T, X) for T, X in zip(T_batch, X_batch)]

In [None]:
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Length of every series in the variable-length batch (required for packing).
lengths = list(map(len, X_batch))

In [None]:
# Right-pad every series to the longest one; NaN marks padded positions.
X_batch_padded = pad_sequence(X_batch, batch_first=True, padding_value=float("nan"))
# Pad the *time* stamps (previously this padded X_batch a second time).
T_batch_padded = pad_sequence(T_batch, batch_first=True, padding_value=float("nan"))

In [None]:
# Pack both padded batches with the same per-series lengths; the lengths are
# not sorted, hence enforce_sorted=False.
X_batch_packed, T_batch_packed = (
    pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    for padded in (X_batch_padded, T_batch_padded)
)

In [None]:
# Display the packed time stamps (a PackedSequence) to inspect its layout.
T_batch_packed

In [None]:
# Experiment: feed the packed sequences straight into the model.
# NOTE(review): nothing visible here shows that LinODEnet.forward accepts
# PackedSequence inputs — confirm before relying on this cell.
model(T_batch_packed, X_batch_packed)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

model = LinODE(370, kernel_regularization="skew-symmetric")
model.to(device=device, dtype=dtype)
print(T.shape, x0.shape)
print(model(T, x0).shape)
# `T_batch` may be a list of variable-length series at this point (it has no
# `.shape`), and `x0_batch` may hold BATCHSIZE copies. Build a self-contained
# fixed-length batch of 3 so this check does not depend on earlier cells.
T_fixed = torch.stack([T, T, T])
x0_fixed = torch.stack([x0, x0, x0])
print(T_fixed.shape, x0_fixed.shape)
print(model(T_fixed, x0_fixed).shape)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

model = LinODEnet(370, 400, embedding_type="concat")
model.to(device=device, dtype=dtype)
print(T.shape, X.shape)
print(model(T, X).shape)
# `T_batch`/`X_batch` may be lists of variable-length series at this point
# (no `.shape`, not stackable). Build a self-contained fixed-length batch of
# 3 so this check does not depend on earlier cells.
T_fixed = torch.stack([T, T, T])
X_fixed = torch.stack([X, X, X])
print(T_fixed.shape, X_fixed.shape)
print(model(T_fixed, X_fixed).shape)

In [None]:
# Build two-level batches of shape (2, 3, ...) from the single-sample tensors.
# Rebuilding from T/X/Δt/ΔT/x0 avoids two problems with nesting the current
# *_batch values:
#   - earlier cells may have rebound T_batch/X_batch to Python lists of
#     variable-length series, which torch.stack cannot stack;
#   - Δt_batch/x0_batch may hold BATCHSIZE copies while ΔT_batch holds 3.
# It also makes the cell safe to re-run (no repeated nesting).
T_batch, X_batch, Δt_batch, ΔT_batch, x0_batch = (
    torch.stack([torch.stack([v, v, v])] * 2) for v in (T, X, Δt, ΔT, x0)
)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

# LinODECell on an unbatched call and on the current (nested) batches.
model = LinODECell(370, kernel_regularization="skew-symmetric").to(
    device=device, dtype=dtype
)
for step, state in ((Δt, x0), (Δt_batch, x0_batch)):
    print(step.shape, state.shape)
    print(model(step, state).shape)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

# LinODE on the full time grid, unbatched and with the current (nested) batches.
model = LinODE(370, kernel_regularization="skew-symmetric").to(
    device=device, dtype=dtype
)
for times, state in ((T, x0), (T_batch, x0_batch)):
    print(times.shape, state.shape)
    print(model(times, state).shape)

In [None]:
from linodenet.models import LinODEnet, LinODECell, LinODE

# Full LinODEnet, unbatched and with the current (nested) batches.
model = LinODEnet(370, 250).to(device=device, dtype=dtype)
for times, observations in ((T, X), (T_batch, X_batch)):
    print(times.shape, observations.shape)
    print(model(times, observations).shape)