# Optimizing performance by using TorchScript to JIT-compile the ODE model

We make use of the details provided at https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

In [None]:
%config InlineBackend.figure_format = 'svg'

In [None]:
import torch
import torchdiffeq
from torch import nn
from torch.nn import GRUCell
import numpy as np
from opt_einsum import contract
from tqdm.auto import trange
from typing import Union, Callable
import scipy
from scipy import stats
import matplotlib.pyplot as plt
from scipy.integrate import odeint

In [None]:
# Render all figure text through LaTeX and load amsmath in the LaTeX preamble.
# The preamble is a plain raw string (not a format string), so the braces must
# be single: doubled braces would be passed to LaTeX verbatim.
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")

In [None]:
def scaled_Lp(x, p=2):
    r"""Return the size-normalised L^p norm of ``x``.

    The value is ``mean(|x|**p) ** (1/p)`` taken over all elements of ``x``,
    i.e. the usual p-norm with the sum replaced by a mean.

    Special cases:
        * ``p == 0``: geometric mean of ``|x|`` (the p -> 0 limit, see
          https://math.stackexchange.com/q/282271/99220).
        * ``p == 1``: mean of ``|x|``.
        * ``p == 2``: root mean square.
        * ``p == np.inf``: maximum of ``|x|``.

    For any other ``p`` the computation is carried out in extended precision
    (``np.longdouble``) and the result is returned in that dtype.
    """
    x = np.abs(x)
    if p == 0:
        return stats.gmean(x, axis=None)
    if p == 1:
        return np.mean(x)
    if p == 2:
        return np.sqrt(np.mean(x**2))
    if p == np.inf:
        return np.max(x)
    # np.longdouble is the portable spelling; np.float128 is not defined on
    # every platform.
    x = x.astype(np.longdouble)
    return np.mean(x**p) ** (1 / p)

In [None]:
def visualize_distribution(x, bins=50, log=True, ax=None):
    r"""Draw a density histogram of ``x`` annotated with summary statistics.

    NaN entries are removed before plotting; their share is reported in the
    annotation.

    Parameters
    ----------
    x : array-like
        Values to visualise. Converted with ``np.array``.
    bins : int
        Number of histogram bins. With ``log=True`` this is the number of
        logarithmically spaced bin edges.
    log : bool
        If true, both axes are log-scaled and the data are clipped to the
        range between the floored 1% quantile and the 99% quantile of
        ``log10(x)``.
        NOTE(review): ``log10`` presumably expects strictly positive data.
    ax : matplotlib axes, optional
        Axes to draw on. A new 10x6 figure is created when omitted.

    The annotation is a LaTeX ``tabular``, so it presumably only renders
    correctly with ``text.usetex`` enabled.
    """
    x = np.array(x)
    nans = np.isnan(x)
    x = x[~nans]

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6), tight_layout=True)

    ax.grid(axis="x")
    ax.set_axisbelow(True)

    if log:
        z = np.log10(x)
        ax.set_xscale("log")
        ax.set_yscale("log")
        low = np.floor(np.quantile(z, 0.01))
        high = np.quantile(z, 0.99)
        x = x[(z >= low) & (z <= high)]
        bins = np.logspace(low, high, num=bins, base=10)
    ax.hist(x, bins=bins, density=True)

    # stats.mode returns a length-1 array on older SciPy and a scalar on newer
    # releases; atleast_1d makes the lookup work for both.
    mode = np.atleast_1d(stats.mode(x)[0])[0]

    row_sep = r" \\ "
    rows = [
        # Raw f-string: "\%" must reach LaTeX with its backslash intact.
        rf"NaNs   & {100 * np.mean(nans):.2f}\%",
        f"Mean   & {np.mean(x):.2e}",
        f"Median & {np.median(x):.2e}",
        f"Mode   & {mode:.2e}",
        f"stdev  & {np.std(x):.2e}",
    ]
    ax.text(
        0.975,
        0.975,
        r"\begin{tabular}{ll}" + row_sep.join(rows) + row_sep + r"\end{tabular}",
        transform=ax.transAxes,
        va="top",
        ha="right",
        snap=True,
    )

In [None]:
class LinODECell(torch.jit.ScriptModule):
    """
    Linear System module

    x' = Ax + Bu + w
     y = Cx + Du + v

    Only the homogeneous part ``x' = Ax`` is implemented: a single step of
    length Δt is propagated exactly as ``x(t + Δt) = exp(A Δt) x(t)``, where
    ``A`` is the learnable ``kernel`` parameter.
    """

    def __init__(
        self,
        input_size,
        kernel_initialization: Union[
            torch.Tensor, Callable[[int], torch.Tensor], None
        ] = None,
        homogeneous: bool = True,
        matrix_type: str = None,
        device=torch.device("cpu"),
        dtype=torch.float32,
    ):
        """
        input_size: int
            dimension M of the state; the kernel has shape (M, M)
        kernel_initialization: torch.tensor or callable
            either a tensor (or array-like) to assign to the kernel at
            initialization or a callable f: int -> torch.Tensor that is
            called with ``input_size``. If None, the kernel is drawn from a
            standard normal distribution scaled by 1/sqrt(input_size).
        homogeneous: bool
            must be True; the inhomogeneous model raises NotImplementedError
        matrix_type: str
            currently unused
        device, dtype:
            the module is moved to this device / dtype after construction
        """
        super(LinODECell, self).__init__()

        # Fail before any parameter is registered.
        if not homogeneous:
            raise NotImplementedError("Inhomogeneous Linear Model not implemented yet.")

        if kernel_initialization is None:
            initial = torch.randn(input_size, input_size) / np.sqrt(input_size)
        elif callable(kernel_initialization):
            initial = kernel_initialization(input_size)
        else:
            initial = kernel_initialization

        # Always copy, so the parameter never aliases the caller's data.
        if isinstance(initial, torch.Tensor):
            initial = initial.clone().detach()
        else:
            initial = torch.tensor(initial)

        self.kernel = nn.Parameter(initial)

        self.to(device=device, dtype=dtype)

    @torch.jit.script_method
    def forward(self, Δt, x):
        """
        Inputs:
        Δt: (...,)
        x:  (..., M)

        Outputs:
        xhat:  (..., M)


        Forward using matrix exponential
        """

        # AΔt[..., k, l] = kernel[k, l] * Δt[...]
        AΔt = torch.einsum("kl, ... -> ...kl", self.kernel, Δt)
        expAΔt = torch.matrix_exp(AΔt)
        # xhat[..., k] = sum_l expAΔt[..., k, l] * x[..., l]
        xhat = torch.einsum("...kl, ...l -> ...k", expAΔt, x)

        return xhat

In [None]:
class LinODE(torch.jit.ScriptModule):
    """
    Vectorized solver for the homogeneous linear ODE x' = Ax.

    ``forward(x0, T)`` evaluates ``exp(A (T[i] - T[0])) x0`` for every time
    in ``T`` with one batched matrix exponential; ``A`` is the learnable
    ``kernel`` parameter.

    NOTE: a later cell of this notebook defines another class with the same
    name, which replaces this one once that cell has been executed.
    """

    def __init__(
        self,
        input_size,
        kernel_initialization: Union[
            torch.Tensor, Callable[[int], torch.Tensor], None
        ] = None,
        homogeneous: bool = True,
        matrix_type: str = None,
        device=torch.device("cpu"),
        dtype=torch.float32,
    ):
        """
        input_size: int
            dimension M of the state; the kernel has shape (M, M)
        kernel_initialization: torch.tensor or callable
            either a tensor (or array-like) to assign to the kernel at
            initialization or a callable f: int -> torch.Tensor that is
            called with ``input_size``. If None, the kernel is drawn from a
            standard normal distribution scaled by 1/sqrt(input_size).
        homogeneous: bool
            must be True; the inhomogeneous model raises NotImplementedError
        matrix_type: str
            currently unused
        device, dtype:
            the module is moved to this device / dtype after construction
        """
        super(LinODE, self).__init__()

        # Fail before any parameter is registered.
        if not homogeneous:
            raise NotImplementedError("Inhomogeneous Linear Model not implemented yet.")

        if kernel_initialization is None:
            initial = torch.randn(input_size, input_size) / np.sqrt(input_size)
        elif callable(kernel_initialization):
            initial = kernel_initialization(input_size)
        else:
            initial = kernel_initialization

        # Always copy, so the parameter never aliases the caller's data.
        if isinstance(initial, torch.Tensor):
            initial = initial.clone().detach()
        else:
            initial = torch.tensor(initial)

        self.kernel = nn.Parameter(initial)

        self.to(device=device, dtype=dtype)

    @torch.jit.script_method
    def forward(self, x0, T):
        # type: (Tensor, Tensor) -> Tensor
        # Elapsed time of every query point relative to the first one.
        ΔT = T - T[0]
        # AΔT[..., k, l] = kernel[k, l] * ΔT[...]
        AΔT = torch.einsum("kl, ... -> ...kl", self.kernel, ΔT)
        expAΔT = torch.matrix_exp(AΔT)
        # Apply every propagator to the same initial state x0.
        Xhat = torch.einsum("...kl, ...l -> ...k", expAΔT, x0)

        return Xhat

In [None]:
class LinODE(torch.jit.ScriptModule):
    """
    Sequential solver for x' = Ax built on top of ``LinODECell``.

    ``forward(x0, T)`` returns a tensor of shape ``(len(T), len(x0))`` whose
    first row is ``x0`` and whose row ``i + 1`` is obtained by advancing row
    ``i`` through the cell by the step ``T[i + 1] - T[i]``.

    All constructor arguments are forwarded unchanged to ``LinODECell``.
    """

    def __init__(self, *cell_args, **cell_kwargs):
        super(LinODE, self).__init__()
        self.cell = LinODECell(*cell_args, **cell_kwargs)

    @torch.jit.script_method
    def forward(self, x0, T):
        # type: (Tensor, Tensor) -> Tensor

        ΔT = torch.diff(T)

        # Allocate the output with the dtype and device of x0; the default
        # (float32 on CPU) would downcast double-precision states and leave
        # the result on the CPU for CUDA inputs.
        xhat = torch.empty((len(T), len(x0)), dtype=x0.dtype, device=x0.device)
        xhat[0] = x0

        ret = x0

        for i, Δt in enumerate(ΔT):
            ret = self.cell(Δt, ret)
            xhat[i + 1] = ret
        return xhat


#         x = torch.jit.annotate(List[Tensor], [])
#         x += [x0]

#         for i, Δt in enumerate(ΔT):
#             x += [self.cell(Δt, x[-1])]

#         return torch.stack(x)

In [None]:
def test_LinODE(
    dim=None,
    num=None,
    tol=1e-3,
    precision="single",
    relative_error=True,
    device=torch.device("cpu"),
):
    """Compare ``LinODE`` against ``scipy.integrate.odeint`` on a random problem.

    A random matrix ``A``, initial state ``x0`` and sorted time grid ``T`` are
    drawn, ``x' = Ax`` is solved with both SciPy and the ``LinODE`` model
    (initialised with ``A``), and the element-wise error is summarised.

    Parameters
    ----------
    dim : int, optional
        State dimension. Drawn uniformly from [2, 100) when omitted.
    num : int, optional
        Number of time points. Drawn uniformly from [20, 1000) when omitted.
    tol : float
        Currently unused.
    precision : {"single", "double"}
        Floating point precision of both the NumPy and the torch data.
    relative_error : bool
        If true, the absolute error is divided by ``|X| + eps``.
    device : torch.device
        Device on which the model is evaluated.

    Returns
    -------
    np.ndarray
        ``scaled_Lp`` of the error for p = 1, 2 and infinity, in that order.

    Raises
    ------
    ValueError
        If ``precision`` is neither "single" nor "double".
    """
    if precision == "single":
        eps = 2**-24
        numpy_dtype = np.float32
        torch_dtype = torch.float32
    elif precision == "double":
        eps = 2**-53
        numpy_dtype = np.float64
        torch_dtype = torch.float64
    else:
        raise ValueError(
            f"precision must be 'single' or 'double', got {precision!r}"
        )

    # Only fall back to a random size when the caller did not supply one.
    if num is None:
        num = np.random.randint(low=20, high=1000)
    if dim is None:
        dim = np.random.randint(low=2, high=100)

    t0, t1 = np.random.uniform(low=-10, high=10, size=(2,)).astype(numpy_dtype)
    A = np.random.randn(dim, dim).astype(numpy_dtype)
    x0 = np.random.randn(dim).astype(numpy_dtype)
    # num - 2 interior points plus both end points, sorted ascending.
    T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
    T = np.sort([t0, *T, t1]).astype(numpy_dtype)

    def func(t, x):
        return A @ x

    X = odeint(func, x0, T, tfirst=True)

    model = LinODE(
        input_size=dim, kernel_initialization=A, dtype=torch_dtype, device=device
    )
    Xhat = model(
        torch.tensor(x0, dtype=torch_dtype, device=device),
        torch.tensor(T, dtype=torch_dtype, device=device),
    )
    Xhat = Xhat.detach().cpu().numpy()

    err = np.abs(X - Xhat)

    if relative_error:
        # eps keeps the division finite where the reference solution is zero.
        err /= np.abs(X) + eps

    return np.array([scaled_Lp(err, p=p) for p in (1, 2, np.inf)])

## Standalone Speed Test

How long does it take to integrate the ODE?

In [None]:
# Shared benchmark problem: x' = A x with `dim` states on `num` sorted time
# points between two random end points, all in float32.
numpy_dtype = np.float32
num = 1000
dim = 100
t0, t1 = np.random.uniform(low=-10, high=10, size=(2,)).astype(numpy_dtype)
A = np.random.randn(dim, dim).astype(numpy_dtype)
x0 = np.random.randn(dim).astype(numpy_dtype)
T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
T = np.sort([t0, *T, t1]).astype(numpy_dtype)
y = np.random.randn(dim).astype(numpy_dtype)
func = lambda t, x: A @ x

# torch cpu setup
torch_dtype = torch.float32
device = torch.device("cpu")
# Move the data first and wrap it in a Parameter afterwards, so the result is
# always a leaf tensor that receives `.grad`.
A_cpu = torch.nn.Parameter(torch.from_numpy(A).to(dtype=torch_dtype, device=device))
func_cpu = lambda t, x: A_cpu @ x
T_cpu = torch.tensor(T).to(dtype=torch_dtype, device=device)
x0_cpu = torch.tensor(x0).to(dtype=torch_dtype, device=device)
y_cpu = torch.tensor(y).to(dtype=torch_dtype, device=device)
model_cpu = LinODE(input_size=dim, kernel_initialization=A).to(
    dtype=torch_dtype, device=device
)

# torch gpu setup
# NOTE: the objects below are placed on the "cuda" device, so this part
# presumably needs a CUDA-enabled PyTorch installation.
device = torch.device("cuda")
# Parameter(...).to("cuda") would return a non-leaf copy of a CPU parameter;
# building the Parameter from the already-moved tensor keeps it a GPU leaf.
A_gpu = torch.nn.Parameter(torch.from_numpy(A).to(dtype=torch_dtype, device=device))
func_gpu = lambda t, x: A_gpu @ x
T_gpu = torch.tensor(T).to(dtype=torch_dtype, device=device)
x0_gpu = torch.tensor(x0).to(dtype=torch_dtype, device=device)
y_gpu = torch.tensor(y).to(dtype=torch_dtype, device=device)
model_gpu = LinODE(input_size=dim, kernel_initialization=A).to(
    dtype=torch_dtype, device=device
)

In [None]:
# Compare the gradient w.r.t. the system matrix obtained through torchdiffeq
# with the one obtained through the LinODE model, for the same loss.

# backward() accumulates into .grad, so clear anything left over from earlier
# executions of this cell or from the timing cells; otherwise the two
# gradients would hold sums over different numbers of backward passes.
A_cpu.grad = None
for param in model_cpu.parameters():
    param.grad = None

yhat = torchdiffeq.odeint(func_cpu, x0_cpu, T_cpu)
r = torch.linalg.norm(yhat[-1] - y_cpu)
r.backward()

yhat = model_cpu(x0_cpu, T_cpu)
r = torch.linalg.norm(yhat[-1] - y_cpu)
r.backward()

# Gradient of the first parameter returned by the model.
grad = next(iter(model_cpu.parameters())).grad

err = torch.abs(A_cpu.grad - grad)
# 2**-24 keeps the relative error finite where the gradient is zero.
relerr = err / (torch.abs(grad) + 2**-24)
print(torch.mean(err))
print(torch.mean(relerr))

### scipy on cpu

In [None]:
%%timeit
# Baseline: SciPy's odeint on the NumPy problem (solve only, no gradients).
odeint(func, x0, T, tfirst=True)

### torch on cpu

In [None]:
%%timeit
# LinODE model on the CPU: forward pass over the whole time grid, then a
# backward pass through the norm of the full trajectory.
trajectory = model_cpu(x0_cpu, T_cpu)
torch.linalg.norm(trajectory).backward()

### torchdiffeq on cpu

In [None]:
%%timeit
# torchdiffeq on the CPU: numerical integration over the whole time grid,
# then a backward pass through the norm of the full trajectory.
trajectory = torchdiffeq.odeint(func_cpu, x0_cpu, T_cpu)
torch.linalg.norm(trajectory).backward()

### torch on gpu

In [None]:
%%timeit
# LinODE model on the GPU: forward pass plus backward pass through the norm.
y = model_gpu(x0_gpu, T_gpu)
r = torch.linalg.norm(y)
r.backward()
# CUDA kernels are launched asynchronously; wait for them so the measurement
# covers the actual computation and not just the kernel launches.
torch.cuda.synchronize()

### torchdiffeq on gpu

In [None]:
%%timeit
# torchdiffeq on the GPU: integration plus backward pass through the norm.
y = torchdiffeq.odeint(func_gpu, x0_gpu, T_gpu)
r = torch.linalg.norm(y)
r.backward()
# CUDA kernels are launched asynchronously; wait for them so the measurement
# covers the actual computation and not just the kernel launches.
torch.cuda.synchronize()

## Matrix Exponential comparison

In [None]:
%%timeit
# Matrix exponential of the NumPy matrix A via SciPy.
scipy.linalg.expm(A)

In [None]:
%%timeit
# Matrix exponential of the same matrix as a torch CPU tensor.
torch.matrix_exp(A_cpu)

In [None]:
%%timeit
# Matrix exponential of the same matrix on the GPU.
torch.matrix_exp(A_gpu)
# CUDA kernels are launched asynchronously; wait for completion so the timing
# reflects the computation rather than only the launch.
torch.cuda.synchronize()

## Forward pass only

In [None]:
%%timeit
# Forward-only baseline: SciPy's odeint on the NumPy problem.
odeint(func, x0, T, tfirst=True)

In [None]:
%%timeit
# Hand-written vectorized forward pass: exp(A (T[i] - T[0])) x0 for all i.
# The elapsed time since the first point is needed here (not the step sizes
# from torch.diff), because every propagator is applied to x0 itself.
ΔT = T_cpu - T_cpu[0]
AΔT = torch.einsum("kl, ... -> ...kl", A_cpu, ΔT)
expAΔT = torch.matrix_exp(AΔT)
Xhat = torch.einsum("...kl, ...l -> ...k", expAΔT, x0_cpu)

In [None]:
%%timeit
# Forward pass of the LinODE model on the CPU, without a backward pass.
Xhat = model_cpu(x0_cpu, T_cpu)