# Optimizing performance by using TorchScript to JIT-compile the ODE model

We make use of the details provided at https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

In [1]:
%config InlineBackend.figure_format = 'svg'

In [2]:
import torch
import torchdiffeq
from torch import nn
from torch.nn import GRUCell
import numpy as np
from opt_einsum import contract
from tqdm.auto import trange
from typing import Union, Callable
from scipy import stats
import matplotlib.pyplot as plt
from scipy.integrate import odeint

In [3]:
# Render all figure text through LaTeX and load amsmath in the LaTeX preamble.
# NOTE(review): the doubled braces in the preamble look like an f-string escape
# that was carried over into a plain raw string -- confirm LaTeX accepts
# `\usepackage{{amsmath}}` or reduce it to single braces.
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{{amsmath}}')

In [4]:
def scaled_Lp(x, p=2):
    """Return the size-normalised L^p norm ``mean(|x|**p) ** (1/p)`` of *x*.

    The mean (instead of the sum) is taken over *all* entries of ``x``, so the
    result does not grow with the number of elements.

    Parameters
    ----------
    x : array_like
        Input values; only their absolute values are used.
    p : float, default 2
        Order of the norm. Special cases:

        * ``0``      -- geometric mean of ``|x|`` (the ``p -> 0`` limit),
        * ``1``      -- mean absolute value,
        * ``2``      -- root mean square,
        * ``np.inf`` -- maximum absolute value.

        Any other value is evaluated generically in extended precision.

    Returns
    -------
    scalar
        The scaled norm. For generic ``p`` the result has dtype
        ``np.longdouble``.
    """
    x = np.abs(x)

    if p == 0:
        # The p -> 0 limit of the power mean is the geometric mean, see
        # https://math.stackexchange.com/q/282271/99220
        return stats.gmean(x, axis=None)
    if p == 1:
        return np.mean(x)
    if p == 2:
        return np.sqrt(np.mean(x**2))
    if p == np.inf:
        return np.max(x)

    # Generic order: raise to the power in extended precision to limit
    # overflow/underflow. ``np.longdouble`` is used instead of ``np.float128``
    # because the latter name is not defined on every platform.
    x = x.astype(np.longdouble)
    return np.mean(x**p) ** (1 / p)

In [5]:
def visualize_distribution(x, bins=50, log=True, ax=None):
    """Draw a density histogram of *x* annotated with summary statistics.

    NaN entries are removed before plotting; their share of the original data
    is reported in the annotation.

    Parameters
    ----------
    x : array_like
        Sample to visualise. It is flattened as a side effect of the NaN
        filtering.
    bins : int, default 50
        Number of histogram bins. With ``log=True`` this is the number of
        logarithmically spaced bin edges.
    log : bool, default True
        If true, both axes are logarithmic, the bin edges are log-spaced
        between ``floor(q01)`` and ``q99`` of ``log10(x)``, and samples outside
        that range are dropped. The statistics in the annotation are then
        computed on the trimmed sample.
        NOTE(review): non-positive samples produce ``-inf``/``nan`` under
        ``log10`` -- presumably ``x`` is expected to be strictly positive.
    ax : matplotlib axes, optional
        Axes to draw into; a new 10x6 figure is created when omitted.

    Notes
    -----
    The annotation is a LaTeX ``tabular``, so it presumably needs
    ``text.usetex`` to be enabled to render properly.
    """
    x = np.array(x)
    nans = np.isnan(x)
    x = x[~nans]  # boolean indexing also flattens x to 1-D

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6), tight_layout=True)

    ax.grid(axis='x')
    ax.set_axisbelow(True)

    if log:
        z = np.log10(x)
        ax.set_xscale('log')
        ax.set_yscale('log')
        low = np.floor(np.quantile(z, 0.01))
        high = np.quantile(z, 0.99)
        x = x[(z >= low) & (z <= high)]
        bins = np.logspace(low, high, num=bins, base=10)
    ax.hist(x, bins=bins, density=True)

    # Depending on the SciPy version, ``stats.mode`` returns either length-1
    # arrays or plain scalars for 1-D input; ``atleast_1d`` handles both.
    mode = np.atleast_1d(stats.mode(x)[0])[0]

    rows = (
        # raw f-string: ``\%`` must reach LaTeX verbatim
        rf"NaNs   & {100*np.mean(nans):.2f}\%",
        f"Mean   & {np.mean(x):.2e}",
        f"Median & {np.median(x):.2e}",
        f"Mode   & {mode:.2e}",
        f"stdev  & {np.std(x):.2e}",
    )
    table = (
        r"\begin{tabular}{ll}"
        + "".join(row + r" \\ " for row in rows)
        + r"\end{tabular}"
    )
    ax.text(.975, .975, table,
            transform=ax.transAxes, va='top', ha='right', snap=True)

In [6]:
class LinODECell(torch.jit.ScriptModule):
    """
    Linear System module
    
    x' = Ax + Bu + w
     y = Cx + Du + v
    
    """

    def __init__(self, input_size, 
                 kernel_initialization: Union[torch.Tensor, Callable[int, torch.Tensor]] = None, 
                 homogeneous: bool =True, 
                 matrix_type: str =None,
                 device=torch.device('cpu'),
                 dtype=torch.float32,
                ):
        """
        kernel_initialization: torch.tensor or callable
            either a tensor to assign to the kernel at initialization
            or a callable f: int -> torch.Tensor|L
        """
        super(LinODECell, self).__init__()
        
        
        if kernel_initialization is None:
            self.kernel_initialization = lambda: torch.randn(input_size, input_size)/np.sqrt(input_size)
        elif callable(kernel_initialization):
            self.kernel = lambda: torch.tensor(kernel_initialization(input_size))  
        else:
            self.kernel_initialization = lambda: torch.tensor(kernel_initialization)

        self.kernel = nn.Parameter(self.kernel_initialization())    
        
        if not homogeneous:
            self.bias = nn.Parameter(torch.randn(input_size))
            raise NotImplementedError("Inhomogeneous Linear Model not implemented yet.")
            
        self.to(device=device, dtype=dtype)
       
    @torch.jit.script_method
    def forward(self, Δt, x):nput_size, hidden_size, 
        """
        Inputs:
        Δt: (...,)
        x:  (..., M)
        
        Outputs:
        xhat:  (..., M)
        
        
        Forward using matrix exponential
        # TODO: optimize if clauses away by changing definition in constructor.
        """
        
        AΔt    = torch.einsum('kl, ... -> ...kl', self.kernel, Δt)
        expAΔt = torch.matrix_exp(AΔt)
        xhat   = torch.einsum('...kl, ...l -> ...k', expAΔt, x)

        return xhat

In [7]:
class LinODE(torch.jit.ScriptModule):
    """Roll a ``LinODECell`` forward over a sequence of time points.

    All constructor arguments are forwarded unchanged to ``LinODECell``.
    ``forward(x0, T)`` starts from ``x0`` and applies the cell once for every
    consecutive difference of ``T``; the visited states, including ``x0``
    itself, are stacked along a new leading dimension.
    """

    def __init__(self, *cell_args, **cell_kwargs):
        super(LinODE, self).__init__()
        self.cell = LinODECell(*cell_args, **cell_kwargs)

    @torch.jit.script_method
    def forward(self, x0, T):
        # type: (Tensor, Tensor) -> Tensor

        # One state per time point; the first one is the initial condition.
        trajectory = [x0]

        for Δt in torch.diff(T):
            trajectory.append(self.cell(Δt, trajectory[-1]))

        return torch.stack(trajectory)

In [None]:
# Flat hyperparameter dictionary.
# NOTE(review): `n_dim` is not defined anywhere in the visible part of the
# notebook, so this cell raises NameError unless it is defined elsewhere.
HP = {
    # Size of the latent state
    'n_ode_gru_dims' : 6,
    # Number of layers in ODE func in recognition ODE
    'n_layers' : 1,
    # Number of units per layer in ODE func
    'n_units' : 100,
    # nonlinearity used
    'nonlinear' : nn.Tanh,
    # presumably: whether the observation mask is concatenated to the input -- TODO confirm
    'concat_mask' : True,
    # dimensionality of input
    'input_dim' : n_dim,
    # device: 'cpu' or 'cuda'
    'device' : torch.device('cpu'),
    # Number of units per layer in each of GRU update networks
    'n_gru_units' : 100,
    # measurement error
    'obsrv_std' : 0.01,
    # presumably: enable a binary classification head -- TODO confirm
    'use_binary_classif' : False,
    # presumably: train the classifier jointly with reconstruction -- TODO confirm
    'train_classif_w_reconstr' : False,
    # presumably: classify at every time point instead of once per sequence -- TODO confirm
    'classif_per_tp' :  False,
    # number of outputs
    'n_labels' : 1,
    # relative tolerance of ODE solver
    'odeint_rtol': 1e-3,
    # absolute tolerance of ODE solver
    'odeint_atol': 1e-4,
    # batch_size (note: this key is spelled with a hyphen, unlike the others)
    'batch-size' : 50,
    # learn-rate
    'lr': 1e-2,
}

In [15]:
import collections

def deep_update(d: dict, new: dict) -> dict:
    """
    https://stackoverflow.com/a/30655448/9318372
    Update a nested dictionary or similar mapping.
    Modify ``d`` in place and return it.

    For every key of ``new``:

    * a non-empty mapping is merged recursively into ``d[key]``; if ``d[key]``
      is missing or is not itself a mutable mapping, it is replaced by a fresh
      dict first,
    * any other value (including an empty mapping) overwrites ``d[key]``.
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABCs live in
    # ``collections.abc``.
    from collections.abc import Mapping, MutableMapping

    for key, value in new.items():
        if isinstance(value, Mapping) and value:
            target = d.get(key)
            if not isinstance(target, MutableMapping):
                # Nothing to merge into (missing key, None, scalar, ...).
                target = {}
            d[key] = deep_update(target, value)
        else:
            d[key] = value
    return d


def deep_update_keys(d: dict, **new_kv) -> dict:
    """
    Overrides values in nested dictionary.

    Walks the dictionary tree rooted at ``d``. Every key whose value is not a
    non-empty mapping is replaced by ``new_kv[key]`` if the key appears in
    ``new_kv``; non-empty mappings are descended into instead of being
    replaced. Keys of ``new_kv`` that occur nowhere in the tree are ignored.

    Modifies ``d`` in place and returns it.
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABCs live in
    # ``collections.abc``.
    from collections.abc import Mapping

    for key, value in d.items():
        if isinstance(value, Mapping) and value:
            # The nested mapping is updated in place, no re-assignment needed.
            deep_update_keys(value, **new_kv)
        elif key in new_kv:
            # Re-assigning an existing key is safe while iterating.
            d[key] = new_kv[key]
    return d


### Example

In [24]:
# Demo of deep_update_keys: override a key by name.
HP = {
    'GRUCell': {'input_size':None, 'bias' : True, 'hidden_size' : None},
    'LinODE' : {'input_size':None, 'hidden_size': None, 'initialization': None},
}
# Only the 'LinODE' sub-dictionary is passed in, so only its 'input_size'
# entry is set to 10; the 'GRUCell' entry is not touched.
deep_update_keys(HP['LinODE'], input_size=10)
HP

In [21]:
# Demo of deep_update: merge a nested dictionary of new values into HP.
HP = {
    'GRUCell': {'input_size':None, 'bias' : True, 'hidden_size' : None},
    'LinODE' : {'input_size':None, 'hidden_size': None, 'initialization': None},
}
# Keys missing from new_HP (e.g. LinODE's 'hidden_size') keep their old value.
new_HP = {
    'GRUCell': {'input_size':10, 'bias' : False, 'hidden_size' : 20},
    'LinODE' : {'input_size':10},
}
deep_update(HP, new_HP)
HP

In [None]:
class LinODE_RNN(torch.jit.ScriptModule):
    """Recurrent model alternating linear ODE propagation and GRU updates.

    Between two observations the hidden state is propagated by a
    ``LinODECell`` (``dynamic``); at each observation it is corrected by a
    ``GRUCell`` (``filter``). A small MLP (``decoder``) maps every hidden
    state back to the observation space.
    """

    # Template of the hyperparameters. It is never modified: every instance
    # works on its own deep copy, stored as ``self.HP``.
    HP = {
        'input_size'  : None,
        'hidden_size' : None,
        'GRUCell'     : {'input_size' : None, 'bias' : True, 'hidden_size' : None},
        'LinODE'      : {
            'input_size' : None,
            'initialization': None},
        'Decoder'     : {
            'input_size': None,
            'hidden_layers': 1,
            'activation': 'tanh',
            'units': None,
            'output_size': None
        }
    }

    def __init__(self, input_size, HP=None):
        """
        input_size: int
            dimension of a single observation
        HP: dict, optional
            overrides for the default hyperparameters, using the same layout
            as the class attribute ``HP``. Nested dictionaries are merged key
            by key. ``HP['hidden_size']`` sets the size of the hidden state
            (default: ``2 * input_size``).

        Raises ValueError if the resulting sizes of the components do not fit
        together.
        """
        import copy

        super(LinODE_RNN, self).__init__()

        overrides = {} if HP is None else HP

        # Setup default hyperparameters
        hidden_size = overrides.get('hidden_size')
        if hidden_size is None:
            hidden_size = 2 * input_size

        config = copy.deepcopy(type(self).HP)
        config['input_size'] = input_size
        config['hidden_size'] = hidden_size
        config['GRUCell'].update(input_size=input_size, hidden_size=hidden_size)
        config['LinODE']['input_size'] = hidden_size
        config['Decoder'].update(
            input_size=hidden_size,
            units=2 * input_size,
            output_size=input_size,
        )

        # Apply the user overrides on top of the derived defaults.
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(config.get(key), dict):
                config[key].update(value)
            elif key == 'hidden_size' and value is None:
                continue  # already resolved to the default above
            else:
                config[key] = value

        # The components are chained, so their sizes have to agree.
        checks = (
            ("GRUCell.input_size", config['GRUCell']['input_size'], config['input_size']),
            ("GRUCell.hidden_size", config['GRUCell']['hidden_size'], config['hidden_size']),
            ("LinODE.input_size", config['LinODE']['input_size'], config['hidden_size']),
            ("Decoder.input_size", config['Decoder']['input_size'], config['hidden_size']),
            ("Decoder.output_size", config['Decoder']['output_size'], config['input_size']),
        )
        for name, actual, expected in checks:
            if actual != expected:
                raise ValueError(
                    "Inconsistent hyperparameters: %s is %r, expected %r"
                    % (name, actual, expected)
                )

        self.HP = config

        # Initialize the components
        linode = config['LinODE']
        self.dynamic = LinODECell(
            linode['input_size'],
            kernel_initialization=linode.get('initialization'),
        )
        self.filter  = GRUCell(**config['GRUCell'])
        self.encoder = None

        # NOTE(review): the decoder layout is fixed (two tanh hidden layers);
        # 'hidden_layers' and 'activation' of the Decoder config are not
        # consulted yet.
        decoder = config['Decoder']
        self.decoder = nn.Sequential(
            nn.Linear(decoder['input_size'], decoder['units']),
            nn.Tanh(),
            nn.Linear(decoder['units'], decoder['units']),
            nn.Tanh(),
            nn.Linear(decoder['units'], decoder['output_size']),
        )

    @torch.jit.script_method
    def forward(self, X, T):
        # type: (Tensor, Tensor) -> Tensor
        # shapes: X:  BATCH x SEQ_LEN x DIM
        #         T:  SEQ_LEN  (assumed: one time grid shared by the whole
        #                       batch -- TODO confirm)
        # returns Xhat: BATCH x SEQ_LEN x DIM

        # The hidden size equals the size of the square ODE kernel; reading it
        # there avoids accessing the hyperparameter dict inside TorchScript.
        h = torch.zeros(X.size(0), self.dynamic.kernel.size(0),
                        dtype=X.dtype, device=X.device)
        hidden = [h]  # the hidden state at the first time point is zero

        for i, Δt in enumerate(torch.diff(T)):
            # propagate to the next time point, then correct with the observation
            h_dash = self.dynamic(Δt, h)
            h = self.filter(X[:, i + 1], h_dash)
            hidden.append(h)

        # BATCH x SEQ_LEN x HIDDEN -> BATCH x SEQ_LEN x DIM
        Xhat = self.decoder(torch.stack(hidden, dim=1))

        return Xhat
    
    
    