# Sample Template

Here is some sample text, followed by a first block of code.

In [None]:
from collections import OrderedDict
from math import sqrt
from typing import Any, Final, Iterable, Optional, TypeVar

import torch
from torch import Tensor, jit, nn
from torch._jit_internal import _copy_to_script_wrapper
from torch.linalg import matrix_norm, vector_norm
from torch.nn import functional
from typing_extensions import Self

from linodenet.models.encoders.invertible_layers import LinearContraction, iResNetBlock

In [None]:
# Bias-free linear map from 4 input features to 3 output features (weight shape (3, 4)).
m = nn.Linear(in_features=4, out_features=3, bias=False)

In [None]:
class Stop(nn.Module):
    r"""Module that raises as soon as it is called.

    Calling it always fails, so it can serve as a tripwire at a point in a
    model that execution must not reach.
    """

    def forward(self, *args: Any, **kwargs: Any):
        r"""Raise unconditionally; all arguments are ignored.

        Raises
        ------
        RuntimeError
            Always, with a message naming the module class.
        """
        raise RuntimeError(
            f"{type(self).__name__} module was called; execution should not reach it."
        )

In [None]:
class iResNet(nn.Module):
    r"""Invertible ResNet consisting of a stack of invertible blocks (e.g. `iResNetBlock`).

    The model is built from already-constructed blocks: ``iResNet(block1, block2, ...)``.
    `forward` / `encode` apply the blocks in the given order, while `decode` calls
    each block's ``decode`` in reversed order. On construction, a companion model
    made of the blocks' ``inverse`` modules (in reversed order) is stored as
    ``inverse``.

    References
    ----------
    - | Invertible Residual Networks
      | Jens Behrmann, Will Grathwohl, Ricky T. Q. Chen, David Duvenaud, Jörn-Henrik Jacobsen
      | International Conference on Machine Learning 2019
      | http://proceedings.mlr.press/v97/behrmann19a.html

    Attributes
    ----------
    input_size: int
        The dimensionality of the input space. Declared as a constant, but not
        assigned by ``__init__``.
    output_size: int
        The dimensionality of the output space. Declared as a constant, but not
        assigned by ``__init__``.
    blocks: nn.Sequential
        Sequential model consisting of the given blocks.
    inverse: Optional[iResNet]
        Model made of the blocks' ``inverse`` modules in reversed order; ``None``
        on that companion model itself.
    HP: dict
        Nested dictionary containing the default hyperparameters.
    """

    # Constants
    input_size: Final[int]
    r"""CONST: The dimensionality of the inputs."""
    output_size: Final[int]
    r"""CONST: The dimensionality of the outputs."""

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "maxiter": 10,
        "input_size": None,
        "dropout": None,
        "bias": True,
        "nblocks": 5,
        "rezero": False,
        "iResNetBlock": {
            "input_size": None,
            "activation": "ReLU",
            "activation_config": {"inplace": False},
            "bias": True,
            "hidden_size": None,
            "maxiter": 100,
        },
    }
    r"""The hyperparameter dictionary"""

    def __new__(
        cls, *modules: nn.Module, inverse: Optional[Self] = None, **hparams: Any
    ) -> Self:
        r"""Create an instance from blocks, or delegate to `from_hyperparameters`.

        Exactly one of ``modules`` or ``hparams`` must be non-empty.
        ``inverse`` is accepted only so the signature matches ``__init__``.
        """
        # Compare truthiness rather than XOR-ing the lengths: XOR of lengths would
        # accept e.g. 1 block together with 2 hyperparameters.
        assert bool(modules) != bool(hparams), (
            "Provide either blocks, or hyperparameters!"
        )

        if hparams:
            return cls.from_hyperparameters(**hparams)

        return super().__new__(cls)

    def __init__(
        self, *modules: nn.Module, inverse: Optional[Self] = None, **hparams: Any
    ) -> None:
        r"""Initialize from a sequence of invertible blocks.

        Parameters
        ----------
        modules: nn.Module
            The blocks, applied in the given order by `forward` / `encode`.
            Each block needs a ``decode`` method (used by `decode`) and, unless
            ``inverse`` is given, an ``inverse`` module.
        inverse: Optional[iResNet]
            Passed internally when building the companion inverse model; when
            given, no further inverse model is constructed.
        hparams: Any
            Hyperparameters; these are handled by ``__new__`` through
            `from_hyperparameters`.
        """
        assert bool(modules) != bool(hparams), (
            "Provide either blocks, or hyperparameters!"
        )
        if hparams:
            # `__new__` already returned the instance built by `from_hyperparameters`.
            # Python still calls `__init__` on it, and re-running `nn.Module.__init__`
            # would wipe the constructed model, so there is nothing to do here.
            return

        super().__init__()
        self.blocks = nn.Sequential(*modules)

        if inverse is None:
            # The inverse of a composition applies the inverse blocks in reversed
            # order, mirroring the traversal order used by `decode`.
            self.inverse = type(self)(
                *[layer.inverse for layer in self.blocks[::-1]], inverse=self
            )
        else:
            # This instance *is* the companion inverse model. The back-reference is
            # deliberately not stored: assigning an nn.Module attribute registers it
            # as a submodule, which would create a cycle in the module tree.
            self.inverse = None

    @classmethod
    def from_hyperparameters(cls, cfg: Optional[dict] = None, **hparams: Any) -> Self:
        r"""Construct the model from hyperparameters (not implemented yet).

        Accepts either a configuration dictionary ``cfg`` or keyword
        hyperparameters, the latter being how ``__new__`` forwards them.

        Raises
        ------
        NotImplementedError
            Always, for now.
        """
        raise NotImplementedError

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        """Compute the encoding (identical to `encode`)."""
        return self.blocks(x)

    @jit.export
    def encode(self, x: Tensor) -> Tensor:
        """Compute the encoding by applying the blocks in order."""
        return self.blocks(x)

    @jit.export
    def decode(self, y: Tensor) -> Tensor:
        r"""Compute the inverse via each block's ``decode`` (fixed-point iteration), in reversed order."""
        for layer in self.blocks[::-1]:  # traverse in reverse
            y = layer.decode(y)
        return y

In [None]:
# Wrap a LinearContraction layer with m = n = 5 in a single iResNetBlock.
m = n = 5
layer = LinearContraction(m, n)
model = iResNetBlock(layer)

In [None]:
# model = jit.script(model)

In [None]:
# Draw a standard-normal test input with 5 entries.
x = torch.randn(size=(5,))

In [None]:
# Eager (not scripted) forward pass of the single block on the random input.
model(x)

In [None]:
# Stack the same block twice and compile the resulting iResNet with TorchScript.
f = jit.script(iResNet(*[model] * 2))

In [None]:
# Invert x through the scripted model: each block's `decode` is applied in reversed order.
f.decode(x)

## A second heading

And some more text.