In [None]:
# Also display the value when a cell ends with an assignment, not only a bare expression.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload edited modules (e.g. linodenet) before each cell runs.
%load_ext autoreload
%autoreload 2
# NOTE(review): no visible cell plots anything, so this line and the SVG setting
# above look unused here — confirm against the rest of the notebook before removing.
%matplotlib inline

In [None]:
import logging

logging.basicConfig(level=logging.DEBUG)

In [None]:
import linodenet

In [None]:
from math import sqrt

import torch
from torch import Tensor, jit

# Opt in to TF32 math in both cuDNN and matmul kernels.
# - cuDNN: TF32 is already allowed by default; set here to make the choice explicit.
# - matmul: PyTorch 1.12 and later default this to False, so enable it explicitly.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 0  # fixed seed so the benchmark inputs are identical on every Restart & Run All
torch.manual_seed(SEED)

B = 48  # batch size
L = 256  # sequence length
N = 128  # latent size

# Allocate directly on DEVICE rather than building on CPU and copying with `.to()`.
T = torch.rand(B, L, device=DEVICE)  # per-step scalars multiplying A, shape (B, L), in [0, 1)
Z = torch.randn(B, L, N, device=DEVICE)  # vectors, shape (B, L, N)
A = torch.randn(N, N, device=DEVICE) / sqrt(N)  # entries have variance 1/N

In [None]:
@jit.script
def forward_batch(T: Tensor, A: Tensor, Z: Tensor) -> Tensor:
    """Compute ``matrix_exp(T[...] * A) @ Z[..., :]`` for all leading indices at once.

    Batched version: one ``matrix_exp`` call covers every leading index of ``T``.

    Args:
        T: scalars multiplying ``A``, shape ``(...)``.
        A: square matrix, shape ``(N, N)``.
        Z: vectors, shape ``(..., N)``; leading dims broadcast against ``T``.

    Returns:
        Tensor of shape ``(..., N)``.
    """
    # Outer product of the scalars with A: (...) x (N, N) -> (..., N, N).
    scaled_A = torch.einsum("..., mn -> ...mn", T, A)
    # One matrix exponential per leading index, followed by a batched mat-vec with Z.
    return torch.einsum("...mn, ...n -> ...m", torch.linalg.matrix_exp(scaled_A), Z)


@jit.script
def forward_loop(T: Tensor, A: Tensor, Z: Tensor) -> Tensor:
    """Compute ``matrix_exp(T[..., k] * A) @ Z[..., k, :]`` one sequence step at a time.

    Loop-based reference implementation: one ``matrix_exp`` call per step ``k``
    of the sequence (LEN) axis.

    Args:
        T: per-step scalars multiplying ``A``, shape ``(..., LEN)``.
        A: square matrix, shape ``(N, N)``.
        Z: per-step vectors, shape ``(..., LEN, N)``.

    Returns:
        Tensor of shape ``(..., LEN, N)``.
    """
    # Bring the LEN axis to the front: (..., LEN) -> (LEN, ...) and
    # (..., LEN, N) -> (LEN, ..., N). moveaxis(source, destination): the
    # source must be the LEN axis, not axis 0.
    T_steps = T.moveaxis(-1, 0).unbind(0)
    Z_steps = Z.moveaxis(-2, 0).unbind(0)
    y_list: list[Tensor] = []

    for t, z in zip(T_steps, Z_steps):  # iterate over LEN
        At = torch.einsum("..., mn -> ...mn", t, A)  # (...) x (N, N) -> (..., N, N)
        expAt = torch.linalg.matrix_exp(At)
        # Multiply by this step's slice `z`, not the full tensor `Z`.
        expAtz = torch.einsum("...mn, ...n -> ...m", expAt, z)  # (..., N)
        y_list.append(expAtz)

    # stack (not cat) creates the new leading LEN axis: (LEN, ..., N) -> (..., LEN, N).
    y = torch.stack(y_list).moveaxis(0, -2)

    return y

In [None]:
%%timeit
forward_batch(T, A, Z);

In [None]:
%%timeit
forward_loop(T, A, Z);