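# Testing whether batch processing works

Checks that `LinODECell`, `LinODE` and `LinODEnet` run on inputs with one leading batch dimension (`BATCH_SIZE` stacked copies) and with two (`OUTER_BATCH_SIZE × BATCH_SIZE` stacked copies), by comparing the shapes of unbatched and batched calls.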

In [1]:
# Reload edited library modules before each cell runs, so changes to the
# imported packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch

from linodenet.models import LinODE, LinODECell, LinODEnet
from tsdm.datasets import Electricity

pd.options.display.max_rows = 5  # keep DataFrame previews short

In [3]:
# Peek at a real dataset: the first 100 rows of Electricity, plus their
# timestamps expressed as hours elapsed since the first of those rows
# (assumes a datetime-like index, as the timedelta division implies).
# Bound to `electricity` rather than `X` so this cell cannot clobber the
# synthetic `X` used by the batching checks below.
electricity = Electricity.dataset
x = electricity.iloc[:100]
t = electricity.index[:100]
t = (t - t[0]) / np.timedelta64(1, 'h')
x

In [4]:
# Synthetic problem shared by all batching checks below.
SEED = 42
nsteps = 100
NDIM = 64
BATCH_SIZE = 7
OUTER_BATCH_SIZE = 3
# Seeded generator so every Restart & Run All sees the same data.
rng = np.random.default_rng(SEED)
X = rng.random((nsteps, NDIM))            # observations, shape (nsteps, NDIM)
T = np.sort(rng.standard_normal(nsteps))  # non-decreasing time stamps, shape (nsteps,)
dtype = torch.float32
device = torch.device('cpu')

In [5]:
# Move the synthetic data to the configured device/dtype.
# `torch.as_tensor` (rather than `torch.tensor`) keeps this cell re-runnable:
# on a second run `X` and `T` are already tensors, and `torch.tensor(tensor)`
# would emit a copy-construct warning.
X = torch.as_tensor(X, dtype=dtype, device=device)
T = torch.as_tensor(T, dtype=dtype, device=device)
ΔT = torch.diff(T)  # consecutive time differences
Δt = ΔT[0]          # first time step (0-dim tensor)
x0 = X[0]           # first observation


def _stack_batch(tensor):
    """Stack BATCH_SIZE copies of ``tensor`` along a new leading batch axis."""
    return torch.stack([tensor] * BATCH_SIZE)


T_batch  = _stack_batch(T)
X_batch  = _stack_batch(X)
Δt_batch = _stack_batch(Δt)
ΔT_batch = _stack_batch(ΔT)
x0_batch = _stack_batch(x0)

In [6]:
# LinODECell: one step (Δt, x0) unbatched vs. the stacked *_batch copies.
model = LinODECell(NDIM, kernel_regularization="skew-symmetric")
model.to(device=device, dtype=dtype)
print(Δt.shape, x0.shape)
y = model(Δt, x0)
print(y.shape)
print(Δt_batch.shape, x0_batch.shape)
y_batch = model(Δt_batch, x0_batch)
print(y_batch.shape)
# Batching works if the batched call returns one unbatched-shaped result per
# batch element; the batch dims are the leading dims x0_batch has beyond x0.
batch_shape = x0_batch.shape[: x0_batch.ndim - x0.ndim]
assert y_batch.shape == (*batch_shape, *y.shape), (y_batch.shape, y.shape)

In [7]:
# LinODE: trajectory from (T, x0) unbatched vs. the stacked *_batch copies.
model = LinODE(NDIM, kernel_regularization='skew-symmetric')
model.to(device=device, dtype=dtype)
print(T.shape, x0.shape)
y = model(T, x0)
print(y.shape)
print(T_batch.shape, x0_batch.shape)
y_batch = model(T_batch, x0_batch)
print(y_batch.shape)
# Batching works if the batched call returns one unbatched-shaped result per
# batch element; the batch dims are the leading dims x0_batch has beyond x0.
batch_shape = x0_batch.shape[: x0_batch.ndim - x0.ndim]
assert y_batch.shape == (*batch_shape, *y.shape), (y_batch.shape, y.shape)

In [8]:
# LinODEnet: full sequence (T, X) unbatched vs. the stacked *_batch copies.
# NOTE(review): 400 is the second positional argument (presumably the latent
# size) — confirm against the LinODEnet signature.
model = LinODEnet(NDIM, 400, embedding_type='concat')
model.to(device=device, dtype=dtype)
print(T.shape, X.shape)
y = model(T, X)
print(y.shape)
print(T_batch.shape, X_batch.shape)
y_batch = model(T_batch, X_batch)
print(y_batch.shape)
# Batching works if the batched call returns one unbatched-shaped result per
# batch element; the batch dims are the leading dims X_batch has beyond X.
batch_shape = X_batch.shape[: X_batch.ndim - X.ndim]
assert y_batch.shape == (*batch_shape, *y.shape), (y_batch.shape, y.shape)

In [9]:
# Add an outer batch axis: every *_batch tensor becomes
# (OUTER_BATCH_SIZE, BATCH_SIZE, *unbatched_shape).
# Built from the unbatched tensors, not from the previous *_batch values, so
# re-running this cell does not keep adding batch dimensions.
def _tile_outer(tensor):
    """Repeat ``tensor`` into shape (OUTER_BATCH_SIZE, BATCH_SIZE, *tensor.shape)."""
    return tensor.repeat(OUTER_BATCH_SIZE, BATCH_SIZE, *([1] * tensor.ndim))


T_batch  = _tile_outer(T)
X_batch  = _tile_outer(X)
Δt_batch = _tile_outer(Δt)
ΔT_batch = _tile_outer(ΔT)
x0_batch = _tile_outer(x0)

In [10]:
# LinODECell again, now with the nested (outer × inner) *_batch copies.
model = LinODECell(NDIM, kernel_regularization="skew-symmetric")
model.to(device=device, dtype=dtype)
print(Δt.shape, x0.shape)
y = model(Δt, x0)
print(y.shape)
print(Δt_batch.shape, x0_batch.shape)
y_batch = model(Δt_batch, x0_batch)
print(y_batch.shape)
# Batching works if the batched call returns one unbatched-shaped result per
# batch element; the batch dims are the leading dims x0_batch has beyond x0.
batch_shape = x0_batch.shape[: x0_batch.ndim - x0.ndim]
assert y_batch.shape == (*batch_shape, *y.shape), (y_batch.shape, y.shape)

In [11]:
# LinODE again, now with the nested (outer × inner) *_batch copies.
model = LinODE(NDIM, kernel_regularization='skew-symmetric')
model.to(device=device, dtype=dtype)
print(T.shape, x0.shape)
y = model(T, x0)
print(y.shape)
print(T_batch.shape, x0_batch.shape)
y_batch = model(T_batch, x0_batch)
print(y_batch.shape)
# Batching works if the batched call returns one unbatched-shaped result per
# batch element; the batch dims are the leading dims x0_batch has beyond x0.
batch_shape = x0_batch.shape[: x0_batch.ndim - x0.ndim]
assert y_batch.shape == (*batch_shape, *y.shape), (y_batch.shape, y.shape)

In [12]:
# LinODEnet again, now with the nested (outer × inner) *_batch copies.
# NOTE(review): 250 is the second positional argument (presumably the latent
# size) and differs from the 400 used in the single-batch check — confirm
# this is intentional.
model = LinODEnet(NDIM, 250)
model.to(device=device, dtype=dtype)
print(T.shape, X.shape)
y = model(T, X)
print(y.shape)
print(T_batch.shape, X_batch.shape)
y_batch = model(T_batch, X_batch)
print(y_batch.shape)
# Batching works if the batched call returns one unbatched-shaped result per
# batch element; the batch dims are the leading dims X_batch has beyond X.
batch_shape = X_batch.shape[: X_batch.ndim - X.ndim]
assert y_batch.shape == (*batch_shape, *y.shape), (y_batch.shape, y.shape)