# Optimizing performance by JIT-compiling the ODE model with TorchScript

This should be roughly 50% faster, even on CPU.

We follow the techniques described in https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

In [None]:
# Render inline matplotlib figures as SVG (IPython magic, not plain Python).
%config InlineBackend.figure_format = 'svg'

In [None]:
import torch
import torchdiffeq
from torch import nn, Tensor
from torch.nn import GRUCell
import numpy as np
from opt_einsum import contract
from tqdm.auto import trange
from typing import Union, Callable
from scipy import stats
import matplotlib.pyplot as plt
from scipy.integrate import odeint

In [None]:
from tsdm.util import scaled_norm, relative_error, timefun

In [None]:
n = 100
# A: n×n matrix with i.i.d. N(0, 1/n) entries; x: length-n standard normal vector.
# The tuple is evaluated left to right, so A is drawn before x.
A, x = (
    torch.normal(mean=torch.zeros(n, n), std=1 / np.sqrt(n)),
    torch.normal(mean=torch.zeros(n), std=1),
)

In [None]:
# Standard deviation of M @ x for A, its symmetric part and its skew-symmetric
# part (both rescaled by 1/sqrt(2)).
tuple(
    torch.std(M @ x)
    for M in (A, (A + A.T) / np.sqrt(2), (A - A.T) / np.sqrt(2))
)

In [None]:
class LinODECell(torch.jit.ScriptModule):
    """
    Linear System module

    x' = Ax + Bu + w
     y = Cx + Du + v

    Only the homogeneous, noise-free dynamics ``x' = Ax`` are implemented.
    The learnable kernel ``A`` is stored in ``self.kernel`` and a state is
    propagated with the matrix exponential: ``x(t + Δt) = exp(AΔt) x(t)``.
    """

    def __init__(
        self,
        input_size,
        kernel_initialization: Union[
            torch.Tensor, Callable[[int], torch.Tensor], str
        ] = None,
        homogeneous: bool = True,
        matrix_type: str = None,
        device=torch.device("cpu"),
        dtype=torch.float32,
    ):
        """
        input_size: int
            State dimension M; the kernel has shape (M, M).
        kernel_initialization: torch.Tensor, callable, str, array-like or None
            - None or "normal": Gaussian matrix with i.i.d. N(0, 1/M) entries
            - "symmetric": (G + G^T) / sqrt(2) for such a Gaussian G
            - "skew_symmetric": (G - G^T) / sqrt(2) for such a Gaussian G
            - torch.Tensor: a detached copy is used as the initial kernel
            - callable f: int -> tensor-like, called as f(input_size)
            - anything else is converted with torch.tensor
        homogeneous: bool
            Only True is supported.
        matrix_type: str
            Currently unused.
        device, dtype:
            Device and dtype the module (and thus the kernel) is moved to.

        Raises
        ------
        NotImplementedError
            If homogeneous is False.
        ValueError
            If kernel_initialization is an unknown string.
        """
        super(LinODECell, self).__init__()

        # Fail fast, before allocating any parameters.
        if not homogeneous:
            raise NotImplementedError("Inhomogeneous Linear Model not implemented yet.")

        def _gaussian():
            return torch.randn(input_size, input_size) / np.sqrt(input_size)

        def _symmetric():
            g = _gaussian()
            return (g + g.T) / np.sqrt(2)

        def _skew_symmetric():
            g = _gaussian()
            return (g - g.T) / np.sqrt(2)

        named_initializations = {
            "normal": _gaussian,
            "symmetric": _symmetric,
            "skew_symmetric": _skew_symmetric,
        }

        if kernel_initialization is None:
            self.kernel_initialization = _gaussian
        elif isinstance(kernel_initialization, str):
            if kernel_initialization not in named_initializations:
                raise ValueError(
                    f"Unknown kernel_initialization {kernel_initialization!r}; "
                    f"expected one of {sorted(named_initializations)}."
                )
            self.kernel_initialization = named_initializations[kernel_initialization]
        elif isinstance(kernel_initialization, torch.Tensor):
            # isinstance (not type ==) also accepts nn.Parameter and other subclasses.
            self._kernel_initialization = kernel_initialization.clone().detach()
            self.kernel_initialization = lambda: self._kernel_initialization.clone()
        elif callable(kernel_initialization):
            # as_tensor avoids torch.tensor's copy-construct warning when f
            # already returns a tensor; detach().clone() decouples the kernel
            # from whatever f returned.
            self.kernel_initialization = lambda: torch.as_tensor(
                kernel_initialization(input_size)
            ).detach().clone()
        else:
            self.kernel_initialization = lambda: torch.tensor(kernel_initialization)

        self.kernel = nn.Parameter(self.kernel_initialization())

        self.to(device=device, dtype=dtype)

    @torch.jit.script_method
    def forward(self, Δt, x):
        """
        Propagate ``x`` by ``Δt`` via the matrix exponential: xhat = exp(A Δt) x.

        Inputs:
        Δt: (...,)
        x:  (..., M)

        Outputs:
        xhat:  (..., M)
        """
        # Scale the kernel by every time step, giving shape (..., M, M).
        AΔt = torch.einsum("kl, ... -> ...kl", self.kernel, Δt)
        expAΔt = torch.matrix_exp(AΔt)
        # Matrix-vector product exp(AΔt) @ x, broadcast over the leading dims.
        xhat = torch.einsum("...kl, ...l -> ...k", expAΔt, x)

        return xhat

In [None]:
class LinODEv2(nn.Module):
    """Eager PyTorch model that rolls a ``LinODECell`` forward over a time grid.

    The trajectory is collected in a Python list and stacked at the end.
    """

    def __init__(self, *cell_args, **cell_kwargs):
        super().__init__()
        # Every argument is handed unchanged to the underlying cell.
        self.cell = LinODECell(*cell_args, **cell_kwargs)
        # Shortcut to the cell's learnable kernel (the same Parameter object).
        self.kernel = self.cell.kernel

    def forward(self, x0: Tensor, T: Tensor) -> Tensor:
        """Return the states at all time points ``T``, starting from ``x0``."""
        return self.__forward__(x0, T)

    def __forward__(self, x0, T):
        r"""
        Propagate ``x0`` through the consecutive time steps of ``T``.

        Parameters
        ----------
        x0: :class:`torch.Tensor`
            state at the first time point
        T: :class:`torch.Tensor`
            time points (presumably 1-D and sorted)

        Returns
        -------
        Xhat: :class:`torch.Tensor`
            ``Xhat[0]`` is ``x0``; ``Xhat[k]`` is the cell applied to
            ``Xhat[k-1]`` with step ``T[k] - T[k-1]``
        """
        trajectory = [x0]
        for step in torch.diff(T):
            trajectory.append(self.cell(step, trajectory[-1]))
        return torch.stack(trajectory)

In [None]:
class LinODE(torch.jit.ScriptModule):
    """TorchScript model that rolls a ``LinODECell`` forward over a time grid.

    The trajectory is written into a preallocated tensor inside a scripted loop.
    All constructor arguments are forwarded to ``LinODECell``.
    """

    def __init__(self, *cell_args, **cell_kwargs):
        super(LinODE, self).__init__()
        self.cell = LinODECell(*cell_args, **cell_kwargs)
        # Same Parameter object as the cell's kernel, exposed for convenience.
        self.kernel = self.cell.kernel

    @torch.jit.script_method
    def forward(self, x0: Tensor, T: Tensor) -> Tensor:
        """
        Propagate ``x0`` along the time points ``T``.

        Inputs:
        x0: (..., M)   state at time T[0]; leading batch dimensions are allowed
        T:  (N,)       time points

        Outputs:
        x:  (N, ..., M) with x[0] == x0 and x[k] = cell(T[k] - T[k-1], x[k-1])
        """
        # Allocate (N, *x0.shape) instead of (N, len(x0)) so that batched
        # initial states work, matching the (..., M) contract of LinODECell.
        x = torch.empty([len(T)] + list(x0.shape), dtype=x0.dtype, device=x0.device)
        if len(T) == 0:
            # Empty time grid: empty trajectory (x[0] below would be out of range).
            return x

        x[0] = x0
        z = x0
        ΔT = torch.diff(T)
        for i, Δt in enumerate(ΔT):
            z = self.cell(Δt, z)
            x[i + 1] = z

        return x

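# Testing the forward pass
We compare against `scipy.integrate.odeint`. Note: `test_LinODE` below currently only measures the runtime of one forward and backward pass of `LinODEv2`; it does not compute any errors.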

In [None]:
def test_LinODE(dim=None, num=None, precision="single", device="cpu"):
    """Time one forward + backward pass of ``LinODEv2`` on a random linear ODE.

    Parameters
    ----------
    dim: int, optional
        State dimension. Drawn uniformly from ``[2, 100)`` when ``None``.
    num: int, optional
        Number of time points, including both endpoints (must be >= 2).
        Drawn uniformly from ``[20, 1000)`` when ``None``.
    precision: {"single", "double"}
        Floating point precision of all arrays and tensors.
    device: str
        Torch device to run on.

    Returns
    -------
    The time reported by ``timefun`` for computing ``‖model(x0, T)‖`` and
    back-propagating through it.

    Raises
    ------
    ValueError
        If ``num < 2``.
    """
    numpy_dtype = {"single": np.float32, "double": np.float64}[precision]
    torch_dtype = {"single": torch.float32, "double": torch.float64}[precision]

    # Draw random sizes only for arguments the caller left unspecified.
    # randint never returns 0 here, so an `or`-fallback would never use
    # the caller's value.
    num = np.random.randint(low=20, high=1000) if num is None else num
    dim = np.random.randint(low=2, high=100) if dim is None else dim
    if num < 2:
        raise ValueError(f"num must be at least 2 (both endpoints), got {num}.")

    # Random problem: sorted time grid containing t0 and t1, random kernel and state.
    t0, t1 = np.random.uniform(low=-10, high=10, size=(2,)).astype(numpy_dtype)
    A = np.random.randn(dim, dim).astype(numpy_dtype)
    x0 = np.random.randn(dim).astype(numpy_dtype)
    T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
    T = np.sort([t0, *T, t1]).astype(numpy_dtype)

    # torch.tensor keeps the requested dtype. torch.Tensor(...) would first
    # round to the default float32 and silently lose FP64 precision.
    A = torch.tensor(A, device=device, dtype=torch_dtype)
    x0 = torch.tensor(x0, device=device, dtype=torch_dtype)
    T = torch.tensor(T, device=device, dtype=torch_dtype)
    model = LinODEv2(
        input_size=dim, kernel_initialization=A, dtype=torch_dtype, device=device
    )

    @timefun
    def matexp_loss(model, x0, T):
        X = model(x0, T)
        r = torch.linalg.norm(X)
        r.backward()
        return X, r

    _, matexp_time = matexp_loss(model, x0, T)
    return matexp_time

## Checking the LinODE error

We compare the results of our LinODE against scipy's `odeint`, aggregated over random problems with varying numbers of dimensions.

In [None]:
# Collect 1000 random single-precision CPU trials, then transpose the stacked results.
err_single = np.transpose([test_LinODE() for _trial in trange(1000)])

In [None]:
# Same as the CPU sweep, but each trial runs on the GPU.
err_single_cuda = np.transpose([test_LinODE(device="cuda") for _trial in trange(1000)])

In [None]:
# Median and mean of the single-precision CUDA results.
(np.median(err_single_cuda), np.mean(err_single_cuda))

In [None]:
# (median, mean) for the CPU run, followed by (median, mean) for the CUDA run.
tuple(
    statistic(results)
    for results in (err_single, err_single_cuda)
    for statistic in (np.median, np.mean)
)

In [None]:
# Collect 1000 random double-precision CPU trials, then transpose the stacked results.
err_double = np.transpose([test_LinODE(precision="double") for _trial in trange(1000)])

In [None]:
# Grid of error distributions: rows = precision (FP32, FP64), columns = norm (L1, L2, Linf).
# NOTE(review): `visualize_distribution` is not imported anywhere in the visible
# source; presumably it comes from a tsdm plotting helper. Confirm.
with plt.style.context("bmh"):
    fig, ax = plt.subplots(
        ncols=3, nrows=2, figsize=(10, 5), tight_layout=True, sharey=True, sharex=True
    )

    # Draw, annotate and save inside the style context. Otherwise the "bmh" style
    # only affects figure/axes creation and not the plotted artists, labels or output.
    for i, err in enumerate((err_single, err_double)):
        for j, p in enumerate((1, 2, np.inf)):
            visualize_distribution(err[j], log=True, ax=ax[i, j])
            if j == 0:
                # Row label ("FP32"/"FP64") placed left of the first column's y-label.
                ax[i, 0].annotate(
                    f"FP{32*(i+1)}",
                    xy=(0, 0.5),
                    xytext=(-ax[i, 0].yaxis.labelpad - 5, 0),
                    xycoords=ax[i, 0].yaxis.label,
                    textcoords="offset points",
                    size="xx-large",
                    ha="right",
                    va="center",
                )
            if i == 1:
                ax[i, j].set_xlabel(f"scaled, relative L{p} error")

    fig.savefig("linode_error_plot_torchscript.svg")

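# Testing the backward pass

We compare against `torchdiffeq.odeint_adjoint` (imported as `odeint` below), checking both the trajectories and the gradients with respect to the kernel.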

In [None]:
from torchdiffeq import odeint_adjoint as odeint
import gc
from time import time
from collections import namedtuple
from pandas import DataFrame, MultiIndex
import pandas

In [None]:
def test_LinODE(dim=None, num=None, precision="double", device="cpu"):
    """Compare ``LinODE`` against the ``odeint`` solver in the forward and backward pass.

    ``odeint`` is the global name that the import cell above binds to
    ``torchdiffeq.odeint_adjoint``. Both methods solve a random system
    ``x' = A x`` on a random sorted time grid. Forward errors compare the
    trajectories. Backward errors compare the gradients of ``‖X‖`` with respect
    to ``A`` (solver) and ``model.kernel`` (LinODE), which start from the same values.

    Parameters
    ----------
    dim: int, optional
        State dimension. Drawn uniformly from ``[2, 100)`` when ``None``.
    num: int, optional
        Number of time points, including both endpoints (must be >= 2).
        Drawn uniformly from ``[20, 1000)`` when ``None``.
    precision: {"single", "double"}
        Floating point precision of all arrays and tensors.
    device: str
        Torch device to run on.

    Returns
    -------
    pandas.DataFrame
        A single row. Under "forward" and "backward" it holds the scaled relative
        1-, 2- and inf-norm errors plus both timings. It also has the columns
        ``num``, ``dim``, ``cond`` (condition number of A) and ``spec``
        (spectral norm of A). Errors are NaN whenever a required timing is NaN.

    Raises
    ------
    ValueError
        If ``num < 2``.
    """
    numpy_dtype = {"single": np.float32, "double": np.float64}[precision]
    torch_dtype = {"single": torch.float32, "double": torch.float64}[precision]

    # Draw random sizes only for arguments the caller left unspecified.
    # randint never returns 0 here, so an `or`-fallback would never use
    # the caller's value.
    num = np.random.randint(low=20, high=1000) if num is None else num
    dim = np.random.randint(low=2, high=100) if dim is None else dim
    if num < 2:
        raise ValueError(f"num must be at least 2 (both endpoints), got {num}.")

    # Random problem: sorted time grid containing t0 and t1, random kernel and state.
    t0, t1 = np.random.uniform(low=-10, high=10, size=(2,)).astype(numpy_dtype)
    A = np.random.randn(dim, dim).astype(numpy_dtype)
    x0 = np.random.randn(dim).astype(numpy_dtype)
    T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
    T = np.sort([t0, *T, t1]).astype(numpy_dtype)
    cond = np.linalg.cond(A)
    spec = np.linalg.norm(A, ord=2)

    # torch.tensor keeps the requested dtype. torch.Tensor(...) would first
    # round to the default float32 and silently lose FP64 precision.
    A = torch.tensor(A, requires_grad=True, device=device, dtype=torch_dtype)
    x0 = torch.tensor(x0, device=device, dtype=torch_dtype)
    T = torch.tensor(T, device=device, dtype=torch_dtype)
    model = LinODE(
        input_size=dim, kernel_initialization=A, dtype=torch_dtype, device=device
    )

    def odeint_loss(A, x0, T):
        # adjoint_params=(A,) makes the adjoint solver produce A.grad on backward.
        X = odeint(lambda t, x: A @ x, x0, T, adjoint_params=(A,))
        r = torch.linalg.norm(X)
        return X, r

    def matexp_loss(model, x0, T):
        X = model(x0, T)
        r = torch.linalg.norm(X)
        return X, r

    def scaled_errors(approx, exact):
        """Scaled 1-, 2- and inf-norms of the relative error of approx w.r.t. exact."""
        error = relative_error(approx, exact)
        return [float(scaled_norm(error, p=p)) for p in (1, 2, np.inf)]

    result_odeint, odeint_ftime = timefun(odeint_loss)(A, x0, T)
    result_matexp, matexp_ftime = timefun(matexp_loss)(model, x0, T)

    # A backward pass is only timed when its forward timing is not NaN
    # (presumably NaN is how `timefun` reports a failed call).
    if np.isnan(odeint_ftime):
        odeint_btime = float("nan")
    else:
        _, r = result_odeint
        _, odeint_btime = timefun(r.backward)()

    if np.isnan(matexp_ftime):
        matexp_btime = float("nan")
    else:
        _, rhat = result_matexp
        _, matexp_btime = timefun(rhat.backward)()

    if np.isnan([odeint_ftime, matexp_ftime]).any():
        fward_errors = [float("nan")] * 3
    else:
        X, _ = result_odeint
        Xhat, _ = result_matexp
        fward_errors = scaled_errors(Xhat, X)

    if np.isnan([odeint_btime, matexp_btime]).any():
        bward_errors = [float("nan")] * 3
    else:
        bward_errors = scaled_errors(model.kernel.grad, A.grad)

    columns = pandas.MultiIndex.from_product(
        [
            ["forward", "backward"],
            ["1-norm", "2-norm", "inf-norm", "time (odeint)", "time (matexp)"],
        ]
    )
    data = np.array(
        [
            [
                *fward_errors,
                odeint_ftime,
                matexp_ftime,
                *bward_errors,
                odeint_btime,
                matexp_btime,
            ]
        ]
    )
    df = pandas.DataFrame(data, columns=columns)
    df["num"] = num
    df["dim"] = dim
    df["cond"] = cond
    df["spec"] = spec

    return df

In [None]:
# One single-precision GPU trial; its one-row result overwrites the FP32 CSV.
df = test_LinODE(device="cuda", precision="single")
df.to_csv("forward_backward_error_fp32.csv", index=False, mode="w")

In [None]:
# Sweep random double-precision problems on the GPU. A trial that raises is
# counted and skipped, so a single failure does not abort the sweep. Failures
# are reported, and all successful results are kept and written to disk.
n_trials = 1000
fp64_results = []
n_failed = 0
last_failure = None
for _ in trange(n_trials):
    try:
        fp64_results.append(test_LinODE(precision="double", device="cuda"))
    except Exception as exc:  # deliberate best-effort sweep; see summary below
        n_failed += 1
        last_failure = exc

if n_failed:
    print(f"{n_failed} of {n_trials} trials failed; last error: {last_failure!r}")

if fp64_results:
    df = pandas.concat(fp64_results, ignore_index=True)
    df.to_csv("forward_backward_error_fp64.csv", mode="w", index=False)