# Checking that the linodenet models accept batched inputs

In [None]:
# IPython magics: load the autoreload extension and use mode 2, so edited
# source modules (e.g. linodenet) are re-imported without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Explicitly request TorchScript via the PYTORCH_JIT environment variable.
# This cell runs before the one that imports torch, so the value is in place
# when torch loads.
os.environ.update(PYTORCH_JIT="1")

In [None]:
import numpy as np
import pandas as pd
import torch
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

pd.options.display.max_rows = 5  # keep DataFrame previews in cell output short

In [None]:
from tsdm.datasets import Electricity

# Peek at the first 100 rows of the real Electricity dataset.
# NOTE(review): X is reassigned to random data in the next cell and x / t are
# not used by any later cell, so this dataset is only displayed here.
X = Electricity.dataset
x = X.iloc[:100]
t = X.index[:100]
# Time elapsed since the first row, in hours.
# NOTE(review): assumes X.index is datetime-like (the subtraction and division
# by np.timedelta64 require it) — TODO confirm.
t = (t - t[0]) / np.timedelta64(1, "h")
x

In [None]:
# Synthetic test data: `nsteps` observations of dimension NDIM at sorted random times.
# A seeded generator makes every run, and any shape/graph failure, reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
nsteps = 100
NDIM = 64
BATCH_SIZE = 7
OUTER_BATCH_SIZE = 3
X = rng.random((nsteps, NDIM))  # uniform values in [0, 1), shape (nsteps, NDIM)
T = np.sort(rng.standard_normal(nsteps))  # non-decreasing time stamps, shape (nsteps,)
dtype = torch.float32
device = torch.device("cpu")

In [None]:
# Move the synthetic data to torch and derive per-step quantities.
X = torch.tensor(X, dtype=dtype, device=device)
T = torch.tensor(T, dtype=dtype, device=device)
ΔT = T.diff()  # consecutive time differences
Δt = ΔT[0]  # first time step (0-d tensor)
x0 = X[0]  # first observation, used as the initial state

# Replicate each tensor BATCH_SIZE times along a new leading batch axis.
# This is a materialized copy, equivalent to torch.stack([tensor] * BATCH_SIZE).
T_batch, X_batch, Δt_batch, ΔT_batch, x0_batch = (
    tensor.unsqueeze(0).repeat(BATCH_SIZE, *([1] * tensor.ndim))
    for tensor in (T, X, Δt, ΔT, x0)
)

In [None]:
def graph_logdir(model, logdir="runs"):
    """Return a fresh TensorBoard run directory path for ``model``.

    The path has the form ``{logdir}/{ModelClassName}/{timestamp}``. The
    timestamp separates hours, minutes and seconds with ``-`` rather than
    ``:`` because ``:`` is not allowed in directory names on Windows.
    """
    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    return f"{logdir}/{model.__class__.__name__}/{timestamp}"


def save_graph(model, inputs, *, logdir="runs", verbose=True):
    """Trace ``model`` on ``inputs`` and write its graph to TensorBoard.

    Args:
        model: module to trace.
        inputs: example input, either a single tensor or a tuple of the
            positional arguments the model is called with.
        logdir: root directory for run folders (default ``"runs"``).
        verbose: forwarded to ``SummaryWriter.add_graph``; the default
            ``True`` matches the previous hard-coded behavior.
    """
    # The context manager closes the writer (flushing events) even if tracing fails.
    with SummaryWriter(graph_logdir(model, logdir)) as writer:
        writer.add_graph(model, inputs, verbose=verbose)

In [None]:
from linodenet.models import LinearContraction

model = LinearContraction(NDIM, 17)
model.to(device=device, dtype=dtype)
# Print [input shape, output shape] for a single vector, a batch of vectors,
# a sequence, and a batch of sequences.
for sample in (x0, x0_batch, X, X_batch):
    print([tuple(w.shape) for w in (sample, model(sample))])
save_graph(model, x0)

In [None]:
from linodenet.models import iResNetBlock

model = iResNetBlock(NDIM)
model.to(device=device, dtype=dtype)
# Print [input shape, output shape] for each input layout.
for sample in (x0, x0_batch, X, X_batch):
    print([tuple(w.shape) for w in (sample, model(sample))])
save_graph(model, x0)

In [None]:
from linodenet.models import iResNet

model = iResNet(NDIM)
model.to(device=device, dtype=dtype)
# Print [input shape, output shape] for each input layout.
for sample in (x0, x0_batch, X, X_batch):
    print([tuple(w.shape) for w in (sample, model(sample))])
save_graph(model, x0)

In [None]:
from linodenet.models import LinODECell

model = LinODECell(NDIM, kernel_projection="skew-symmetric")
model.to(device=device, dtype=dtype)
# The cell takes (time step, state); check unbatched and batched inputs.
for step, state in ((Δt, x0), (Δt_batch, x0_batch)):
    print([tuple(w.shape) for w in (step, state, model(step, state))])
save_graph(model, (Δt, x0))

In [None]:
from linodenet.models import LinODE

model = LinODE(NDIM, kernel_projection="skew-symmetric")
model.to(device=device, dtype=dtype)
# The ODE takes (time stamps, initial state); check unbatched and batched inputs.
for times, state in ((T, x0), (T_batch, x0_batch)):
    print([tuple(w.shape) for w in (times, state, model(times, state))])
save_graph(model, (T, x0))

In [None]:
from linodenet.models import LinODEnet

model = LinODEnet(NDIM, 400, embedding_type="concat")
model.to(device=device, dtype=dtype)
# LinODEnet is called as model(time stamps, observations); check unbatched and batched.
print([tuple(w.shape) for w in (T, X, model(T, X))])
print([tuple(w.shape) for w in (T_batch, X_batch, model(T_batch, X_batch))])
# Trace with the same (T, X) signature exercised above. Previously this passed
# (Δt, X): a single 0-d time step in place of the time-stamp vector T.
save_graph(model, (T, X))

## Multiple Batch Dimensions

In [None]:
# Add an outer batch axis of size OUTER_BATCH_SIZE in front of every batched
# tensor. This is a materialized copy, equivalent to
# torch.stack([tensor] * OUTER_BATCH_SIZE).
T_batch, X_batch, Δt_batch, ΔT_batch, x0_batch = (
    batched.unsqueeze(0).repeat(OUTER_BATCH_SIZE, *([1] * batched.ndim))
    for batched in (T_batch, X_batch, Δt_batch, ΔT_batch, x0_batch)
)

In [None]:
from linodenet.models import LinearContraction

model = LinearContraction(NDIM, 17)
model.to(device=device, dtype=dtype)
# Same shape report, now with two leading batch axes on the *_batch inputs.
for sample in (x0, x0_batch, X, X_batch):
    print([tuple(w.shape) for w in (sample, model(sample))])
save_graph(model, x0)

In [None]:
from linodenet.models import iResNetBlock

model = iResNetBlock(NDIM)
model.to(device=device, dtype=dtype)
# Same shape report, now with two leading batch axes on the *_batch inputs.
for sample in (x0, x0_batch, X, X_batch):
    print([tuple(w.shape) for w in (sample, model(sample))])
save_graph(model, x0)

In [None]:
from linodenet.models import iResNet

model = iResNet(NDIM)
model.to(device=device, dtype=dtype)
# Same shape report, now with two leading batch axes on the *_batch inputs.
for sample in (x0, x0_batch, X, X_batch):
    print([tuple(w.shape) for w in (sample, model(sample))])
save_graph(model, x0)

In [None]:
from linodenet.models import LinODECell

model = LinODECell(NDIM, kernel_projection="skew-symmetric")
model.to(device=device, dtype=dtype)
# (time step, state) pairs: unbatched, then with two leading batch axes.
for step, state in ((Δt, x0), (Δt_batch, x0_batch)):
    print([tuple(w.shape) for w in (step, state, model(step, state))])
save_graph(model, (Δt, x0))

In [None]:
from linodenet.models import LinODE

model = LinODE(NDIM, kernel_projection="skew-symmetric")
model.to(device=device, dtype=dtype)
# (time stamps, initial state) pairs: unbatched, then with two leading batch axes.
for times, state in ((T, x0), (T_batch, x0_batch)):
    print([tuple(w.shape) for w in (times, state, model(times, state))])
save_graph(model, (T, x0))

In [None]:
from linodenet.models import LinODEnet

model = LinODEnet(NDIM, 250)
model.to(device=device, dtype=dtype)
# (time stamps, observations) pairs: unbatched, then with two leading batch axes.
for times, values in ((T, X), (T_batch, X_batch)):
    print([tuple(w.shape) for w in (times, values, model(times, values))])
save_graph(model, (T, X))

In [None]:
[*model.modules()]  # every submodule of the LinODEnet built above, including itself

In [None]:
# Show the attribute and method names available on the LinODEnet instance.
dir(model)

In [None]:
# Print the documentation of register_forward_pre_hook, looked up on a throwaway Linear layer.
help(torch.nn.Linear(3, 4).register_forward_pre_hook)

In [None]:
# Display the bound register_forward_pre_hook method of a throwaway Linear layer.
torch.nn.Linear(3, 4).register_forward_pre_hook

In [None]:
# NOTE(review): identical to the previous cell; looks like a leftover and can likely be deleted.
torch.nn.Linear(3, 4).register_forward_pre_hook