# ODE-RNN with a linear ODE instead of a general one

- try with/without encoder
- first run without missing values
- later with missing values

```
for i in 1,2,..., N:
    h_i' = ODESolve(f, h_{i-1}, (t_{i-1}, t_i))
    h_i = RNNCell(h_i', x_i)
o_i = OutputNN(h_i) for i = 1, ..., N
```

In [None]:
%config InlineBackend.figure_format = 'svg'

In [None]:
import torch
from torch import nn
from torch.nn import GRUCell
import numpy as np
from opt_einsum import contract
from tqdm.auto import trange

In [None]:
from typing import Union, Callable

In [None]:
from scipy import stats
import matplotlib.pyplot as plt


def visualize_distribution(x, bins=50, log=True, ax=None):
    """Plot a density histogram of ``x`` and print summary statistics.

    NaN entries are always dropped. With ``log=True`` both axes are
    logarithmic, only strictly positive values are kept (the only ones a log
    axis can show), the values whose log10 lies outside the 1%-99% quantile
    band are trimmed, and ``bins`` logarithmically spaced bins are used.

    Parameters
    ----------
    x : array_like
        Values to plot.
    bins : int, default 50
        Number of bins (with ``log=False`` anything ``ax.hist`` accepts).
    log : bool, default True
        Use logarithmic axes and bins.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.

    Returns
    -------
    The Axes that was drawn into.

    Raises
    ------
    ValueError
        If no values are left after filtering.
    """
    x = np.asarray(x, dtype=float).ravel()
    x = x[~np.isnan(x)]
    if log:
        # log10 of zero / negative values is -inf / NaN, which would poison the
        # quantiles and the bin edges computed below.
        x = x[x > 0]
    if x.size == 0:
        raise ValueError("visualize_distribution: no values left to plot")

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6), tight_layout=True)

    if log:
        z = np.log10(x)
        ax.set_xscale("log")
        ax.set_yscale("log")
        low = np.quantile(z, 0.01)
        high = np.quantile(z, 0.99)
        keep = (z >= low) & (z <= high)
        if not keep.any():
            # With very few distinct values the 1%-99% band can fall strictly
            # between data points; use the full range instead.
            low, high = z.min(), z.max()
            keep = np.ones_like(keep)
        x = x[keep]
        bins = np.logspace(low, high, num=bins, base=10)
    ax.hist(x, bins=bins, density=True)
    ax.set_ylabel("density")
    # stats.mode returns shape-(1,) arrays on older SciPy and scalars on
    # SciPy >= 1.11; ravel() handles both.
    mode = np.ravel(stats.mode(x).mode)[0]
    print(
        f"median: {np.median(x):.2e}   mode:{mode:.2e}   mean: {np.mean(x):.2e}  stdev:{np.std(x):.2e}"
    )
    return ax

In [None]:
class LinODE(nn.Module):
    """
    Linear System module

    x' = Ax + Bu + w
     y = Cx + Du + v

    """

    def __init__(
        self,
        input_size,
        kernel_initialization: Union[torch.Tensor, Callable[int, torch.Tensor]] = None,
        homogeneous: bool = True,
        matrix_type: str = None,
        device=torch.device("cpu"),
        dtype=torch.float32,
    ):
        """
        kernel_initialization: torch.tensor or callable
            either a tensor to assign to the kernel at initialization
            or a callable f: int -> torch.Tensor|L
        """
        super(LinODE, self).__init__()

        if kernel_initialization is None:
            self.kernel_initialization = lambda: torch.randn(
                input_size, input_size
            ) / np.sqrt(input_size)
        elif callable(kernel_initialization):
            self.kernel = lambda: torch.tensor(kernel_initialization(input_size))
        else:
            self.kernel_initialization = lambda: torch.tensor(kernel_initialization)

        self.kernel = nn.Parameter(self.kernel_initialization())

        if not homogeneous:
            self.bias = nn.Parameter(torch.randn(input_size))
            raise NotImplementedError("Inhomogeneous Linear Model not implemented yet.")

        self.to(device=device, dtype=dtype)

    def forward(self, Δt, x):
        """
        Inputs:
        Δt: (...,)
        x:  (..., M)

        Outputs:
        xhat:  (..., M)


        Forward using matrix exponential
        # TODO: optimize if clauses away by changing definition in constructor.
        """
        #         Δt = torch.diff(t)
        #         print(Δt.shape, x.shape)
        AΔt = contract("kl, ... -> ...kl", self.kernel, Δt)
        expAΔt = torch.matrix_exp(AΔt)
        #         print(expAΔt.shape)
        xhat = contract("...kl, ...l -> ...k", expAΔt, x)

        return xhat

In [None]:
def scaled_Lp(x, p=2):
    """Return the size-normalized L^p norm of ``x``: ``mean(|x|**p) ** (1/p)``.

    Unlike the plain L^p norm this does not grow with the number of entries,
    so errors of differently sized arrays can be compared directly.

    Special exponents use their limiting forms: ``0`` -> geometric mean,
    ``np.inf`` -> maximum, ``-np.inf`` -> minimum of ``|x|``. Other exponents
    are evaluated in extended precision (``np.longdouble``) to reduce
    overflow / underflow of ``|x|**p``.
    """
    x = np.abs(x)
    if p == 0:
        # https://math.stackexchange.com/q/282271/99220
        return stats.gmean(x, axis=None)
    if p == 1:
        return np.mean(x)
    if p == 2:
        return np.sqrt(np.mean(x**2))
    if p == np.inf:
        return np.max(x)
    if p == -np.inf:
        return np.min(x)
    # np.float128 does not exist on every platform; longdouble is the portable
    # name for the widest native float type.
    x = x.astype(np.longdouble)
    return np.mean(x**p) ** (1 / p)

In [None]:
# Draw a random test problem x' = A x on a random, sorted time grid.
num = np.random.randint(low=20, high=1000)  # number of time points
dim = np.random.randint(low=2, high=100)  # state dimension
# Sort the endpoints so the grid really starts at t0 and ends at t1;
# np.random.uniform is only specified for low <= high.
t0, t1 = np.sort(np.random.uniform(low=-10, high=10, size=(2,)))
A = np.random.randn(dim, dim)
x0 = np.random.randn(dim)
T = np.random.uniform(low=t0, high=t1, size=num - 2)
T = np.sort([t0, *T, t1])

In [None]:
T = torch.as_tensor(T, dtype=torch.float32)
ΔT = torch.diff(T)
Xhat = torch.empty(num, dim, dtype=torch.float32)
Xhat[0] = torch.as_tensor(x0, dtype=torch.float32)
# Use the eager LinODE defined above: LinODECell is only defined further down,
# so referring to it here fails when the notebook is run top to bottom.
model = LinODE(input_size=dim, kernel_initialization=A).to(dtype=torch.float32)

In [None]:
# One propagation step: predict the state at T[1] from the initial state Xhat[0].
model(ΔT[0], Xhat[0])

In [None]:
# Forget any earlier definition of torch_linodeint before the cells below
# redefine it. pop() with a default does not raise when the name does not
# exist yet (e.g. when the notebook is run top to bottom).
globals().pop("torch_linodeint", None)

# Optimizing the RNN implementation

We make use of the details provided at https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

In [None]:
from torch import jit

In [None]:
class LinODECell(jit.ScriptModule):
    """
    Linear System module

    x' = Ax + Bu + w
     y = Cx + Du + v

    """

    def __init__(
        self,
        input_size,
        kernel_initialization: Union[torch.Tensor, Callable[int, torch.Tensor]] = None,
        homogeneous: bool = True,
        matrix_type: str = None,
        device=torch.device("cpu"),
        dtype=torch.float32,
    ):
        """
        kernel_initialization: torch.tensor or callable
            either a tensor to assign to the kernel at initialization
            or a callable f: int -> torch.Tensor|L
        """
        super(LinODECell, self).__init__()

        if kernel_initialization is None:
            self.kernel_initialization = lambda: torch.randn(
                input_size, input_size
            ) / np.sqrt(input_size)
        elif callable(kernel_initialization):
            self.kernel = lambda: torch.tensor(kernel_initialization(input_size))
        else:
            self.kernel_initialization = lambda: torch.tensor(kernel_initialization)

        self.kernel = nn.Parameter(self.kernel_initialization())

        if not homogeneous:
            self.bias = nn.Parameter(torch.randn(input_size))
            raise NotImplementedError("Inhomogeneous Linear Model not implemented yet.")

        self.to(device=device, dtype=dtype)

    @jit.script_method
    def forward(self, Δt, x):
        """
        Inputs:
        Δt: (...,)
        x:  (..., M)

        Outputs:
        xhat:  (..., M)


        Forward using matrix exponential
        # TODO: optimize if clauses away by changing definition in constructor.
        """

        AΔt = torch.einsum("kl, ... -> ...kl", self.kernel, Δt)
        expAΔt = torch.matrix_exp(AΔt)
        xhat = torch.einsum("...kl, ...l -> ...k", expAΔt, x)

        return xhat

In [None]:
# TorchScript version of the model, initialized with the random kernel A drawn earlier.
model = LinODECell(input_size=dim, kernel_initialization=A)

In [None]:
def torch_linodeint(model, x0, T):
    """Integrate a linear ODE over the time grid ``T`` step by step.

    Parameters
    ----------
    model : callable
        ``model(Δt, x)`` returns the state reached from ``x`` after a time step
        ``Δt`` (e.g. a ``LinODE`` or ``LinODECell`` instance).
    x0 : array-like or torch.Tensor, shape (M,)
        State at ``T[0]``.
    T : torch.Tensor, shape (N,)
        Time points.

    Returns
    -------
    torch.Tensor, shape (N, M)
        Row ``i`` is the state at ``T[i]``; row 0 is ``x0``.
    """
    ΔT = torch.diff(T)
    x = torch.as_tensor(x0)
    states = [x]
    for Δt in ΔT:
        x = model(Δt, x)
        states.append(x)
    # Stack the collected states; every row is an actual result, none is left
    # as uninitialized buffer memory.
    return torch.stack(states)

In [None]:
# Not decorated with @torch.jit.script: TorchScript assumes unannotated
# arguments are Tensors, so the call `model(...)` cannot be compiled and the
# decoration fails at definition time. The per-step work is already compiled
# when `model` is a LinODECell, whose forward is a jit.script_method.
def torch_linodeint(model, x0, T):
    """Integrate a linear ODE over the time grid ``T`` step by step.

    Parameters
    ----------
    model : callable
        ``model(Δt, x)`` returns the state reached from ``x`` after a time step
        ``Δt`` (e.g. a ``LinODECell`` instance).
    x0 : array-like or torch.Tensor, shape (M,)
        State at ``T[0]``.
    T : torch.Tensor, shape (N,)
        Time points.

    Returns
    -------
    torch.Tensor, shape (N, M)
        Row ``i`` is the state at ``T[i]``; row 0 is ``x0``.
    """
    ΔT = torch.diff(T)
    x = torch.as_tensor(x0)
    states = [x]
    for Δt in ΔT:
        x = model(Δt, x)
        states.append(x)
    return torch.stack(states)

In [None]:
def test_LinODE(dim=None, num=None, tol=1e-3, precision="single", relative_error=True):
    """Compare LinODE against scipy's odeint on a random linear system.

    A random problem x' = A x is drawn with ``dim`` states on a sorted random
    grid of ``num`` time points inside [-10, 10]; ``dim`` / ``num`` are drawn
    at random when not given. LinODE propagates x0 step by step and the result
    is compared with ``scipy.integrate.odeint``.

    Parameters
    ----------
    dim, num : int, optional
        state dimension and number of time points (``num >= 2``).
    tol : float
        currently unused.
    precision : {"single", "double"}
        numpy / torch dtype used for the problem and the model.
    relative_error : bool
        divide the absolute error by ``|X| + eps`` (eps = unit roundoff).

    Returns
    -------
    np.ndarray of shape (3,)
        scaled L1, L2 and L-infinity norms of the (relative) error.

    Raises
    ------
    ValueError
        for an unknown ``precision`` or ``num < 2``.
    """
    from scipy.integrate import odeint

    if precision == "single":
        eps = 2**-24
        numpy_dtype = np.float32
        torch_dtype = torch.float32
    elif precision == "double":
        eps = 2**-53
        numpy_dtype = np.float64
        torch_dtype = torch.float64
    else:
        raise ValueError(f"precision must be 'single' or 'double', got {precision!r}")

    # Explicit None checks: with `randint(...) or num` the random draw always
    # won, because randint never returns 0 for these ranges.
    num = np.random.randint(low=20, high=1000) if num is None else num
    dim = np.random.randint(low=2, high=100) if dim is None else dim
    if num < 2:
        raise ValueError(f"num must be at least 2, got {num}")

    # Order the endpoints: np.random.uniform is only specified for low <= high.
    t0, t1 = np.sort(np.random.uniform(low=-10, high=10, size=(2,))).astype(numpy_dtype)
    A = np.random.randn(dim, dim).astype(numpy_dtype)
    x0 = np.random.randn(dim).astype(numpy_dtype)
    T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
    T = np.sort([t0, *T, t1]).astype(numpy_dtype)

    def rhs(t, x):
        return A @ x

    X = odeint(rhs, x0, T, tfirst=True)

    model = LinODE(input_size=dim, kernel_initialization=A, dtype=torch_dtype)
    ΔT = torch.diff(torch.tensor(T))
    Xhat = torch.empty(num, dim, dtype=torch_dtype)
    Xhat[0] = torch.tensor(x0)

    # Forward only: without no_grad every step would be recorded in a single
    # autograd graph spanning the whole trajectory.
    with torch.no_grad():
        for i, Δt in enumerate(ΔT):
            Xhat[i + 1] = model(Δt, Xhat[i])

    err = np.abs(X - Xhat.cpu().numpy())

    if relative_error:
        err /= np.abs(X) + eps

    return np.array([scaled_Lp(err, p=p) for p in (1, 2, np.inf)])

## Checking LinODE error

We compare the results of our LinODE against scipy's `odeint` on 1000 random problems with varying state dimension and number of time points, and plot the distribution of the resulting errors.

In [None]:
# 1000 random single-precision problems; row k of `errs` holds the k-th error norm (L1, L2, Linf).
errs = np.stack([test_LinODE() for _ in trange(1_000)], axis=1)

In [None]:
fig, ax = plt.subplots(
    ncols=3, figsize=(12, 3), tight_layout=True, sharey=True, sharex=True
)

# One panel per norm: the rows of `errs` line up with the exponents below.
for axis, err_p, p in zip(ax, errs, (1, 2, np.inf)):
    visualize_distribution(err_p, log=True, ax=axis)
    axis.set_title(f"scaled, relative L{p} error")

In [None]:
# Same experiment in double precision; row k of `errs` holds the k-th error norm.
errs = np.stack([test_LinODE(precision="double") for _ in trange(1_000)], axis=1)

In [None]:
fig, ax = plt.subplots(
    ncols=3, figsize=(12, 3), tight_layout=True, sharey=True, sharex=True
)

# One panel per norm: the rows of `errs` line up with the exponents below.
for axis, err_p, p in zip(ax, errs, (1, 2, np.inf)):
    visualize_distribution(err_p, log=True, ax=axis)
    axis.set_title(f"scaled, relative L{p} error")

In [None]:
def test_LinODEA(dim=None, num=None, tol=1e-3, precision="single", relative_error=True):
    """Run only the scipy reference half of ``test_LinODE`` (useful for timing).

    Draws a random problem x' = A x with ``dim`` states on a sorted random grid
    of ``num`` time points inside [-10, 10] (random sizes when not given) and
    solves it with ``scipy.integrate.odeint``. ``tol`` and ``relative_error``
    are accepted for signature compatibility and are unused.

    Returns
    -------
    np.ndarray, shape (num, dim)
        odeint's solution at the grid points.

    Raises
    ------
    ValueError
        for an unknown ``precision`` or ``num < 2``.
    """
    from scipy.integrate import odeint

    dtypes = {"single": np.float32, "double": np.float64}
    if precision not in dtypes:
        raise ValueError(f"precision must be 'single' or 'double', got {precision!r}")
    numpy_dtype = dtypes[precision]

    # Explicit None checks: `randint(...) or num` always took the random draw.
    num = np.random.randint(low=20, high=1000) if num is None else num
    dim = np.random.randint(low=2, high=100) if dim is None else dim
    if num < 2:
        raise ValueError(f"num must be at least 2, got {num}")

    # Order the endpoints: np.random.uniform is only specified for low <= high.
    t0, t1 = np.sort(np.random.uniform(low=-10, high=10, size=(2,))).astype(numpy_dtype)
    A = np.random.randn(dim, dim).astype(numpy_dtype)
    x0 = np.random.randn(dim).astype(numpy_dtype)
    T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
    T = np.sort([t0, *T, t1]).astype(numpy_dtype)

    def rhs(t, x):
        return A @ x

    return odeint(rhs, x0, T, tfirst=True)

In [None]:
# Run the scipy-only test function 1000 times (e.g. for timing); the results are discarded.
_ = [test_LinODEA() for k in trange(1000)]

In [None]:
def test_LinODEB(
    dim=None, num=None, tol=1e-3, precision="single", relative_error=True, device="cuda"
):
    """Run only the LinODE half of ``test_LinODE`` (useful for timing).

    Draws a random problem x' = A x with ``dim`` states on a sorted random grid
    of ``num`` time points inside [-10, 10] (random sizes when not given) and
    propagates x0 with LinODE on ``device`` (CUDA by default). No reference
    solution is computed; ``tol`` and ``relative_error`` are accepted for
    signature compatibility and are unused.

    Returns
    -------
    torch.Tensor, shape (num, dim), on ``device``
        the propagated states.

    Raises
    ------
    ValueError
        for an unknown ``precision`` or ``num < 2``.
    """
    dtypes = {
        "single": (np.float32, torch.float32),
        "double": (np.float64, torch.float64),
    }
    if precision not in dtypes:
        raise ValueError(f"precision must be 'single' or 'double', got {precision!r}")
    numpy_dtype, torch_dtype = dtypes[precision]
    device = torch.device(device)

    # Explicit None checks: `randint(...) or num` always took the random draw.
    num = np.random.randint(low=20, high=1000) if num is None else num
    dim = np.random.randint(low=2, high=100) if dim is None else dim
    if num < 2:
        raise ValueError(f"num must be at least 2, got {num}")

    # Order the endpoints: np.random.uniform is only specified for low <= high.
    t0, t1 = np.sort(np.random.uniform(low=-10, high=10, size=(2,))).astype(numpy_dtype)
    A = np.random.randn(dim, dim).astype(numpy_dtype)
    x0 = np.random.randn(dim).astype(numpy_dtype)
    T = np.random.uniform(low=t0, high=t1, size=num - 2).astype(numpy_dtype)
    T = np.sort([t0, *T, t1]).astype(numpy_dtype)

    model = LinODE(
        input_size=dim, kernel_initialization=A, dtype=torch_dtype, device=device
    )
    ΔT = torch.diff(torch.tensor(T, device=device))
    Xhat = torch.empty(num, dim, dtype=torch_dtype, device=device)
    Xhat[0] = torch.tensor(x0, device=device)

    # Forward only: avoid recording one autograd graph over the whole trajectory.
    with torch.no_grad():
        for i, Δt in enumerate(ΔT):
            Xhat[i + 1] = model(Δt, Xhat[i])
    return Xhat

In [None]:
# Run the LinODE-only test function 1000 times (e.g. for timing); the results are discarded.
_ = [test_LinODEB() for k in trange(1000)]

In [None]:
import matplotlib.pyplot as plt

In [None]:
def test_linode():
    """Compare LinODE with scipy's odeint on one random linear system.

    Draws an ``n``-dimensional system x' = A x (1 <= n < 100) and the default
    50-point ``np.linspace`` grid between two random times (the grid may run
    backwards in time). The system is solved with odeint and by repeated
    LinODE steps in double precision.

    Returns
    -------
    float
        the largest absolute difference between the two solutions.
    """
    from scipy.integrate import odeint

    n = np.random.randint(low=1, high=100)
    t0, t1 = np.random.uniform(low=-10, high=10, size=(2,))
    A = np.random.randn(n, n)
    x0 = np.random.randn(n)
    T = np.linspace(t0, t1)
    func = lambda t, x: A @ x
    x = odeint(func, x0, T, tfirst=True)

    # Pass the kernel at construction: assigning a plain tensor to the
    # registered `kernel` parameter raises TypeError in nn.Module.
    model = LinODE(input_size=n, kernel_initialization=A, dtype=torch.float64)

    ΔT = torch.diff(torch.from_numpy(T))
    xhat = torch.empty(len(T), n, dtype=torch.float64)
    xhat[0] = torch.from_numpy(x0)
    with torch.no_grad():
        for i, Δt in enumerate(ΔT):
            xhat[i + 1] = model(Δt, xhat[i])

    return np.max(np.abs(x - xhat.numpy()))

In [None]:
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` writes ``d["key"] = v`` and
    ``del d.key`` deletes ``d["key"]``. Keys are consulted only when normal
    attribute lookup fails, so dict methods such as ``items`` or ``update``
    keep working even when a key of the same name exists (read such keys with
    ``d["items"]``).

    Based on https://stackoverflow.com/a/14620633/9318372, but without the
    ``self.__dict__ = self`` trick, which lets keys shadow the dict methods and
    puts every instance in a reference cycle with itself.
    """

    def __getattr__(self, name):
        # Only called when regular lookup (methods, class attributes) fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


d = AttrDict({"k": 1, "l": 2})

In [None]:
import collections


def deep_update(source: dict, overrides: dict) -> dict:
    """
    Update a nested dictionary or similar mapping.
    Modify ``source`` in place.
    Reference: https://stackoverflow.com/a/30655448/9318372

    Non-empty mappings in ``overrides`` are merged recursively into the
    matching entry of ``source``. If that entry is missing or not a mapping, a
    fresh dict is used; a read-only mapping is copied into a dict first. Every
    other value, including an empty mapping, replaces the entry outright.

    Returns ``source``.
    """
    # collections.Mapping was removed in Python 3.10; the ABCs live in
    # collections.abc.
    from collections.abc import Mapping, MutableMapping

    for key, value in overrides.items():
        if isinstance(value, Mapping) and value:
            current = source.get(key)
            if not isinstance(current, MutableMapping):
                current = dict(current) if isinstance(current, Mapping) else {}
            source[key] = deep_update(current, value)
        else:
            source[key] = value
    return source

In [None]:
class LinODE_RNN(nn.Module):
    """Work-in-progress ODE-RNN whose hidden-state dynamics are a linear ODE.

    Planned update for every observation i::

        h_i' = ODESolve(f, h_{i-1}, (t_{i-1}, t_i))   # here: LinODE
        h_i  = RNNCell(h_i', x_i)

    Construction of the recurrent cell and the linear dynamics is implemented;
    ``forward`` is still a stub and encoder / decoder / filter are unset.
    """

    # default hyperparameters
    HP = {
        # what model to use for the recurrent cell. Options: LSTMCell, RNNCell, GRUCell
        "Cell": nn.GRUCell,
        # Recurrent cell options, passed as keyword arguments to HP["Cell"].
        "CellOptions": {"input_size": None, "hidden_size": None, "bias": True},
        # Linear ODE parameters, passed as keyword arguments to LinODE.
        "LinODE": {
            "input_size": None,
            "kernel_initialization": None,
            "matrix_type": None,
            "homogeneous": True,
        },
    }

    @staticmethod
    def _merge_HP(defaults: dict, overrides: dict) -> dict:
        """Return a deep copy of ``defaults`` with ``overrides`` merged in.

        Dict-valued sections are updated key by key; any other value replaces
        the default. ``defaults`` itself is never modified.
        """
        import copy

        merged = copy.deepcopy(defaults)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key].update(value)
            else:
                merged[key] = value
        return merged

    def __init__(self, input_size, HP: dict = None):
        """
        input_size: int
            number of observed channels M fed to the recurrent cell.
        HP: dict, optional
            overrides for the class-level default hyperparameters ``HP``.
        """
        super().__init__()
        # Per-instance copy so the class-level defaults are never mutated.
        self.HP = self._merge_HP(type(self).HP, HP or {})

        cell_options = self.HP["CellOptions"]
        if cell_options["input_size"] is None:
            cell_options["input_size"] = input_size
        if cell_options["hidden_size"] is None:
            # TODO(review): defaulting the hidden size to the input size is a
            # placeholder choice — confirm.
            cell_options["hidden_size"] = input_size
        ode_options = self.HP["LinODE"]
        if ode_options["input_size"] is None:
            # The ODE evolves the hidden state h, so it has hidden_size dims.
            ode_options["input_size"] = cell_options["hidden_size"]

        self.cell = self.HP["Cell"](**cell_options)
        self.dynamics = LinODE(**ode_options)
        # TODO: encoder / decoder / filter are not designed yet.
        self.encoder = None
        self.decoder = None
        self.filter = None

    def forward(self, t, x):
        """
        input: t: tensor shape (..., N,)
            Observation timepoints corresponding to the observed values
        input: x: tensor shape (..., N, M) dtype: float.
            Observed data, NaN indicates Missing values
        output: xhat: tensor shape (..., N, M)
            Predicted values. The values may differ from x for non-NaN entries, since the model assumes that observational data is noisy.
            Q: Does this make any sense for categorical data? Not really..., but one can use sigmoid for example.
        """
        raise NotImplementedError("LinODE_RNN.forward is not implemented yet.")

    def predict(self, t, x):
        xhat = self(t, x)

        # TODO: treat categorical features.

        return xhat

In [None]:
mask = np.random.choice([True, False], size=(5, 6))
# np.where needs both a "true" and a "false" operand: keep the random values
# where mask is True and put NaN elsewhere.
# NOTE(review): assumes the goal is simulating missing values as NaN — confirm.
np.where(mask, np.random.randn(5, 6), np.nan)

In [None]:
# Check what attribute access returns for a key whose name collides with a dict method ("items").
d = AttrDict()
d.update({"items": ["jacket", "necktie", "trousers"]})
d.items

How should the input be handled? There are several options:

1. Input $t_\text{obs}$, $x_\text{obs}$, and $t_\text{predict}$, return $x_\text{predict}$
    - similar to the regular ODESOLVE input, but with many observation times instead of a single initial condition.
2. Input $t_\text{obs+predict}$, $x_\text{obs}$, fill $x$ with NaN values at the prediction points (this reduces the problem to an imputation task)
3. Input $t$, $x$, $u$, where the controls $u$ may also be given at future time points (pre-scheduled controls)


### Question: how should the initial hidden state and the initial state estimate of the RNN be handled?

1. Initialize with zeros or randomly (crude, but sufficient for now)
2. Initialize through an initializer network:
    - a small DeepSet / time-series set-function network
    - an ODE-RNN encoder, as in the Latent-ODE encoder


In [None]:
class LinODERNN(nn.Module):
    """Work-in-progress GRU-based ODE-RNN with linear hidden-state dynamics.

    ``GRUCell`` is meant to update the hidden state at observations and
    ``LinODE`` to propagate it between observations; ``forward`` is still a
    stub.
    """

    # default hyperparameters
    HP = {
        "GRUCell": {"bias": True, "hidden_size": None},
        "LinODE": {"kernel_initialization": None},
    }

    def __init__(self, input_size, **hyperparameters):
        """
        input_size: int
            number of observed channels fed to the GRU cell.
        **hyperparameters: dict
            per-section overrides, e.g. ``GRUCell={"hidden_size": 32}``; each
            section is merged into a copy of the class-level ``HP``.
        """
        import copy

        super().__init__()
        # Per-instance copy so the class-level defaults are never mutated.
        self.HP = copy.deepcopy(type(self).HP)
        for section, options in hyperparameters.items():
            self.HP.setdefault(section, {}).update(options)

        gru_options = self.HP["GRUCell"]
        hidden_size = gru_options["hidden_size"]
        if hidden_size is None:
            # TODO(review): placeholder default — confirm the intended hidden size.
            hidden_size = input_size
        self.GRUCell = nn.GRUCell(input_size, hidden_size, bias=gru_options["bias"])
        # The ODE evolves the GRU hidden state, so its dimension is hidden_size.
        self.LinODE = LinODE(
            input_size=hidden_size,
            kernel_initialization=self.HP["LinODE"]["kernel_initialization"],
        )

    def forward(self, t, x):
        """
        input: t: tensor shape (..., N,)
            Observation timepoints corresponding to the observed values
        input: x: tensor shape (..., N, M) dtype: float.
            Observed data, NaN indicates Missing values
        output: xhat: tensor shape (..., N, M)
            Predicted values. The values may differ from x for non-NaN entries, since the model assumes that observational data is noisy.
            Q: Does this make any sense for categorical data? Not really..., but one can use sigmoid for example.
        """
        raise NotImplementedError("LinODERNN.forward is not implemented yet.")

    def predict(self, t, x):
        xhat = self(t, x)

        # TODO: treat categorical features.

        return xhat

In [None]:
import numpy as np
import matplotlib.pyplot as plt

N = 100_000  # number of random matrices
n = 20  # each matrix is n x n

A = np.random.randn(N, n, n)
# Split every matrix into its symmetric and skew-symmetric parts
# (transpose(0, 2, 1) transposes each matrix of the batch).
symA = (A + A.transpose(0, 2, 1)) / 2
skewA = (A - A.transpose(0, 2, 1)) / 2

In [None]:
# Condition numbers of the general, symmetric and skew-symmetric batches.
conds, symconds, skewconds = map(np.linalg.cond, (A, symA, skewA))

In [None]:
from scipy import stats


def visualize_distribution(x, bins=100, log=True, ax=None):
    """Plot a density histogram of ``x`` and print summary statistics of ``x``.

    NaN entries are always dropped. With ``log=True`` both axes are
    logarithmic, only strictly positive values are kept (the only ones a log
    axis can show), and ``bins`` logarithmically spaced bin edges span whole
    decades around the data.

    Returns the Axes that was drawn into.

    Raises
    ------
    ValueError
        If no values are left after filtering.
    """
    x = np.asarray(x, dtype=float).ravel()
    x = x[~np.isnan(x)]
    if log:
        x = x[x > 0]
    if x.size == 0:
        raise ValueError("visualize_distribution: no values left to plot")

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6), tight_layout=True)

    if log:
        # Bin edges come from log10(x), but the histogram is of x itself:
        # a log-scaled axis expects data in original units.
        z = np.log10(x)
        ax.set_xscale("log")
        ax.set_yscale("log")
        lo_exp = np.floor(np.min(z))
        # At least one decade, so the edges strictly increase.
        hi_exp = max(np.ceil(np.max(z)), lo_exp + 1)
        bins = np.logspace(lo_exp, hi_exp, num=bins, base=10)
    # Histogram of the argument (not of any global array).
    ax.hist(x, bins=bins, density=True)
    # stats.mode returns shape-(1,) arrays on older SciPy and scalars on
    # SciPy >= 1.11; ravel() handles both.
    mode = np.ravel(stats.mode(x).mode)[0]
    print(
        f"median: {np.median(x):.2}   mode:{mode:.2}   mean: {np.mean(x):.2}  stdev:{np.std(x):.2}"
    )
    return ax

In [None]:
fig, ax = plt.subplots(
    ncols=3, figsize=(12, 4), tight_layout=True, sharex=True, sharey=True
)
# Left to right: general, symmetric and skew-symmetric random matrices.
for panel, condition_numbers in zip(ax, (conds, symconds, skewconds)):
    visualize_distribution(condition_numbers, ax=panel)

In [None]:
# Inspect a 10x10 matrix filled in place by Kaiming-normal initialization.
nn.init.kaiming_normal_(torch.empty(10, 10))

In [None]:
def random_matrix(input_size, kind=None):
    """
    Draw a random square matrix of shape (input_size, input_size).

    kind options:
    None,            unstructured matrix with Kaiming-normal entries
    symmetric,       A = A^T, the symmetric part of a Kaiming-normal matrix
    skew symmetric,  A = -A^T, the skew-symmetric part ("skew-symmetric" also accepted)
    orthogonal,      A A^T = I, drawn with nn.init.orthogonal_
    normal,          listed but not implemented yet (raises NotImplementedError)

    Raises ValueError for any other ``kind``.
    """
    if kind == "normal":
        # NOTE(review): "normal" may mean a normal matrix (A A^T = A^T A) or
        # normally distributed entries — decide before implementing.
        raise NotImplementedError("random_matrix: kind='normal' is not implemented yet")
    if kind == "orthogonal":
        return nn.init.orthogonal_(torch.empty(input_size, input_size))

    A = nn.init.kaiming_normal_(torch.empty(input_size, input_size))

    if kind is None:
        return A
    if kind == "symmetric":
        return (A + A.T) / 2
    if kind in ("skew-symmetric", "skew symmetric"):
        return (A - A.T) / 2
    raise ValueError(f"random_matrix: unknown kind {kind!r}")

In [None]:
?GRUCell