# TorchScript save/load experiments with `LinODECell`

In [1]:
# Echo the value of a cell's last expression or assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures inside the notebook.
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt

# Fixed-width float display: four decimals, no scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# Seed the generator so that random draws repeat under Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

In [3]:
from pathlib import Path

# Serialized models live in ./models; the directory is created on first use.
path = Path.cwd() / "models"
path.mkdir(exist_ok=True)

In [4]:
import torch
from linodenet.models import LinODECell as Model

# Build the model with 32 as its single constructor argument (eager module, not scripted).
model = Model(32)
# Target file inside the models directory, named after the model class.
filepath = path.joinpath(f"{Model.__name__}.pt")
# NOTE(review): the save is disabled, so the load in the next cell relies on a
# file written by an earlier run — confirm this is intended.
# torch.jit.save(model, filepath)

In [11]:
# Read the TorchScript archive stored at `filepath` back from disk.
model2 = torch.jit.load(filepath)

In [None]:
# NOTE(review): interactive help lookup left over from debugging; drop it before sharing the notebook.
?torch.jit.load

## Test without Initialization / Regularization

In [26]:
import logging
from typing import Any, Final, Optional, Union, Callable

import torch
from torch import Tensor, jit, nn

from linodenet.initializations import INITIALIZATIONS, Initialization, gaussian
from linodenet.models.iResNet import iResNet
from linodenet.projections import PROJECTIONS, Projection
from linodenet.util import deep_dict_update

# Logger named after the current module.
__logger__ = logging.getLogger(__name__)

# NOTE(review): apart from LinODECell, none of these names is defined in this
# notebook — presumably the list was copied along with the class; verify.
__all__: Final[list[str]] = [
    "ConcatEmbedding",
    "ConcatProjection",
    "LinODE",
    "LinODECell",
    "LinODEnet",
]


class LinODECell(nn.Module):
    r"""Linear System module, solves `ẋ = Ax`, i.e. `x̂ = e^{A\Delta t}x`.

    Parameters
    ----------
    input_size: int
        Dimension `d` of the state; the kernel has shape `(d, d)`.
    kernel_initialization: Optional[Union[str, Tensor, Initialization]]
        How to obtain the initial kernel. One of

        - ``None``: use `gaussian(input_size)`,
        - a string key of `INITIALIZATIONS`,
        - a callable that maps `input_size` to a matrix,
        - a fixed matrix (a `Tensor`, or anything `Tensor(...)` accepts)
          of shape `(input_size, input_size)`.
    kernel_projection: Optional[Union[str, Projection]]
        Function applied to the kernel before every forward pass. One of
        ``None`` (use `PROJECTIONS["identity"]`), a string key of `PROJECTIONS`,
        or a callable.

    Attributes
    ----------
    input_size:  int
        The dimensionality of the input space.
    output_size: int
        The dimensionality of the output space (equal to `input_size`).
    kernel: Tensor
        The system matrix, stored as a trainable parameter.

    Raises
    ------
    AssertionError
        If the supplied matrix (or the matrix produced by the supplied callable)
        does not have shape `(input_size, input_size)`.
    NotImplementedError
        If `kernel_projection` is neither ``None``, a known key, nor callable.
    """

    input_size: Final[int]
    output_size: Final[int]

    kernel: Tensor
    # kernel_initialization: Callable[[], Tensor]
    kernel_projection: Projection

    def __init__(
        self,
        input_size: int,
        kernel_initialization: Optional[Union[str, Tensor, Initialization]] = None,
        kernel_projection: Optional[Union[str, Projection]] = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = input_size

        def kernel_initialization_dispatch():
            """Turn the user-supplied initialization into a zero-argument factory."""
            if kernel_initialization is None:
                return lambda: gaussian(input_size)
            # Only strings are looked up in the registry. A membership test with an
            # unhashable array-like (e.g. a list or ndarray) raises TypeError on a
            # hash-based mapping, which would make the array-like fallback at the
            # end of this function unreachable.
            if (
                isinstance(kernel_initialization, str)
                and kernel_initialization in INITIALIZATIONS
            ):
                _init = INITIALIZATIONS[kernel_initialization]
                return lambda: _init(input_size)
            if callable(kernel_initialization):
                # Draw once up front to validate the shape of what the callable produces.
                assert Tensor(kernel_initialization(input_size)).shape == (
                    input_size,
                    input_size,
                )
                return lambda: Tensor(kernel_initialization(input_size))
            if isinstance(kernel_initialization, Tensor):
                assert kernel_initialization.shape == (input_size, input_size)
                return lambda: kernel_initialization
            # Fallback: anything else is handed to the Tensor constructor.
            assert Tensor(kernel_initialization).shape == (input_size, input_size)
            return lambda: Tensor(kernel_initialization)

        # this looks funny, but it needs to be written that way to be compatible with torchscript
        def kernel_regularization_dispatch():
            """Resolve the user-supplied projection to a callable."""
            if kernel_projection is None:
                _kernel_regularization = PROJECTIONS["identity"]
            elif isinstance(kernel_projection, str) and kernel_projection in PROJECTIONS:
                _kernel_regularization = PROJECTIONS[kernel_projection]
            elif callable(kernel_projection):
                _kernel_regularization = kernel_projection
            else:
                raise NotImplementedError(f"{kernel_projection=} unknown")
            return _kernel_regularization

        self._kernel_initialization = kernel_initialization_dispatch()
        self._kernel_regularization = kernel_regularization_dispatch()
        self.kernel = nn.Parameter(self._kernel_initialization())

    def kernel_initialization(self) -> Tensor:
        r"""Draw an initial kernel matrix (random or static)."""
        return self._kernel_initialization()

    @jit.export
    def kernel_regularization(self, w: Tensor) -> Tensor:
        r"""Regularize the Kernel, e.g. by projecting onto skew-symmetric matrices."""
        return self._kernel_regularization(w)

    @jit.export
    def forward(self, t: Tensor, x0: Tensor) -> Tensor:
        r"""Signature: `[...,] x [...,d] -> [...,d]`.

        Parameters
        ----------
        t: Tensor, shape=(...,)
            The time difference `t_1 - t_0` between `x0` and `xhat`.
        x0:  Tensor, shape=(...,DIM)
            Time observed value at `t_0`

        Returns
        -------
        xhat:  Tensor, shape=(...,DIM)
            The predicted value at `t_1`
        """
        # Project the raw parameter first, then scale it by every entry of t.
        A = self.kernel_regularization(self.kernel)
        At = torch.einsum("kl, ... -> ...kl", A, t)
        expAt = torch.matrix_exp(At)
        # Apply the matrix exponential to the state along its last axis.
        xhat = torch.einsum("...kl, ...l -> ...k", expAt, x0)
        return xhat

In [27]:
from pathlib import Path

# Serialized models live in ./models; the directory is created on first use.
path = Path.cwd() / "models"
path.mkdir(exist_ok=True)

In [28]:
# Script the cell, write it into the models directory under the class name, then read it back.
model = jit.script(LinODECell(32))
filepath = path / f"{LinODECell.__name__}.pt"
jit.save(model, filepath)
model2 = jit.load(filepath)

In [48]:
class MyModule(torch.jit.ScriptModule):
    r"""Minimal module for trying out TorchScript save/load round-trips.

    Attributes
    ----------
    my_constant: int
        Constant attribute, always set to 2.
    kernel: Tensor
        Trainable matrix initialised with `gaussian(input_size)`.
    """

    # NOTE(review): this subclasses `torch.jit.ScriptModule` directly while `forward`
    # carries no script decorator — confirm that `forward` is actually compiled.
    my_constant: Final[int]
    kernel: Tensor

    def __init__(self, input_size):
        super().__init__()
        self.my_constant = 2
        initial_kernel = gaussian(input_size)
        self.kernel = nn.Parameter(initial_kernel)

    def forward(self, x):
        r"""Return the matrix exponential of `kernel @ x`."""
        product = torch.matmul(self.kernel, x)
        return torch.matrix_exp(product)

In [49]:
# Script the module, save it to models/test_model.pt and load it back.
model = torch.jit.script(MyModule(10))
# Plain string literal: the original f-string had no placeholders.
filepath = path.joinpath("test_model.pt")
torch.jit.save(model, filepath)
model2 = torch.jit.load(filepath)

In [89]:
from typing import List

In [90]:
class LSTMLayer(torch.nn.Module):
    r"""Apply a recurrent cell step by step along the first axis of the input.

    Parameters
    ----------
    cell:
        Factory (e.g. a cell class); it is called as `cell(*cell_args)`.
    *cell_args:
        Positional arguments forwarded to `cell`.
    """

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    # @jit.export
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        r"""Feed every slice of `input` (split along axis 0) through the cell.

        Each step calls `self.cell(slice, state)` and expects a pair
        `(out, new_state)` back; `new_state` is passed on to the next step.

        Returns
        -------
        The per-step outputs stacked along a new first axis, and the final state.
        """
        # Explicit element type so TorchScript can type the initially empty list.
        collected = torch.jit.annotate(List[Tensor], [])
        for step_input in input.unbind(0):
            out, state = self.cell(step_input, state)
            collected.append(out)
        return torch.stack(collected), state

In [92]:
# Eager (un-scripted) smoke test: a layer wrapping RNNCell(10, 12), fed 7 input slices of width 10.
# NOTE(review): `LSTMLayer.forward` unpacks `out, state` from the cell's return value and
# its type comment expects a (Tensor, Tensor) state, while a single tensor is passed
# here — confirm this call runs as intended with `torch.nn.RNNCell`.
model = LSTMLayer(torch.nn.RNNCell, 10, 12)
model(torch.randn(7, 10), torch.randn(7, 12))

In [69]:
# Script the layer, save it to models/test_model.pt and load it back.
model = torch.jit.script(LSTMLayer(torch.nn.RNNCell, 10, 12))
# Plain string literal: the original f-string had no placeholders.
filepath = path.joinpath("test_model.pt")
torch.jit.save(model, filepath)
model2 = torch.jit.load(filepath)

In [None]:
class ExampleNoUnicode(nn.Module)