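# Scriptable container modules: `Series`, `Parallel` and `Repeat`

Prototypes of hyperparameter-configurable wrappers around `nn.Sequential` and `nn.ModuleList`; each one is checked against `torch.jit.script` below.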

In [1]:
# Interactive-display and reload configuration for this notebook.
# Display the last expression of each cell, or the assigned value when the
# last statement is an assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG instead of the default raster format.
%config InlineBackend.figure_format = 'svg'
# Re-import edited modules (e.g. linodenet) automatically before each cell runs;
# the extension must be loaded before `%autoreload` can be used.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Route INFO-and-above log records reaching the root logger to a stream handler.
logging.basicConfig(level=logging.INFO)

In [2]:
import matplotlib.pyplot as plt
import numpy as np

SEED = 0  # fixed so that Restart & Run All reproduces every random draw
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
rng = np.random.default_rng(SEED)

In [3]:
import torch
from torch import Tensor, jit, nn
from torchinfo import summary
from linodenet.util import autojit

In [4]:
from typing import Any, Dict, Final, List

In [5]:
class Series(nn.Sequential):
    """An augmentation of nn.Sequential.

    Modules can be supplied positionally (exactly like ``nn.Sequential``) and/or
    through the ``modules`` hyperparameter. Positional modules come first.

    Parameters
    ----------
    *args: Any
        Forwarded to ``nn.Sequential`` ahead of any config-built modules.
    **HP: Any
        Hyperparameters merged over ``DEFAULT_HP``. The first entry of
        ``HP["modules"]`` is dropped (presumably a placeholder mirroring the
        ``[None]`` default -- TODO confirm); every remaining entry is turned into
        a module with ``initialize_from_config``.
    """

    DEFAULT_HP = {"modules": [None]}

    # Merged hyperparameters: DEFAULT_HP overridden by the constructor kwargs.
    HP: Dict[str, Any]

    def __init__(self, *args: Any, **HP: Any) -> None:
        self.HP = self.DEFAULT_HP | HP
        HP = self.HP

        configured: list[nn.Module] = []

        if HP["modules"] != [None]:
            # Drop the leading entry by rebinding to a slice instead of `del`:
            # deleting in place also truncated the caller's own list, so reusing
            # one config lost an extra module on every construction.
            HP["modules"] = HP["modules"][1:]
            if HP["modules"]:
                # NOTE(review): this name was never imported in the notebook;
                # assumed to live in linodenet.util next to `autojit` -- confirm.
                from linodenet.util import initialize_from_config

                configured = [initialize_from_config(cfg) for cfg in HP["modules"]]

        super().__init__(*args, *configured)

In [6]:
# Seed torch so the demo input, and every output derived from it, is identical
# on Restart & Run All.
torch.manual_seed(0)
x = torch.randn(3)

In [7]:
model = nn.Sequential(nn.ReLU(), nn.Linear(3, 4))
scripted = jit.script(model)
# Scripting changes the execution engine, not the numbers.
torch.testing.assert_close(scripted(x), model(x))
scripted(x)

In [8]:
model = Series(nn.ReLU(), nn.Linear(3, 4))
scripted = jit.script(model)
# `Series` must give the same result as its eager version once scripted.
torch.testing.assert_close(scripted(x), model(x))
scripted(x)

In [9]:
model = nn.Sequential()
scripted = jit.script(model)
# An empty container scripts fine and passes its input straight through.
torch.testing.assert_close(scripted(x), x)
scripted

## Parallel

In [10]:
class Parallel(nn.ModuleList):
    """Apply every contained module to the same input and collect the outputs.

    Modules can be passed positionally (``Parallel(m1, m2)``), as one iterable
    (``Parallel([m1, m2])``, the ``nn.ModuleList`` convention), and/or through
    the ``modules`` hyperparameter. Positional modules come first.

    Parameters
    ----------
    *args: Any
        Modules, or a single non-module iterable of modules (``None`` = empty).
    **HP: Any
        Hyperparameters merged over ``DEFAULT_HP``. The first entry of
        ``HP["modules"]`` is dropped (presumably a placeholder mirroring the
        ``[None]`` default -- TODO confirm); every remaining entry is turned into
        a module with ``initialize_from_config``.
    """

    DEFAULT_HP = {"modules": [None]}

    # Merged hyperparameters: DEFAULT_HP overridden by the constructor kwargs.
    HP: Dict[str, Any]

    def __init__(self, *args: Any, **HP: Any) -> None:
        self.HP = self.DEFAULT_HP | HP
        HP = self.HP

        if len(args) == 1 and not isinstance(args[0], nn.Module):
            # Keep accepting the ModuleList-style call `Parallel([m1, m2])`
            # (and `Parallel(None)`, which nn.ModuleList treats as empty).
            args = () if args[0] is None else tuple(args[0])

        configured: list[nn.Module] = []

        if HP["modules"] != [None]:
            # Slice instead of `del` so the caller's list is left untouched.
            HP["modules"] = HP["modules"][1:]
            if HP["modules"]:
                # NOTE(review): this name was never imported in the notebook;
                # assumed to live in linodenet.util next to `autojit` -- confirm.
                from linodenet.util import initialize_from_config

                configured = [initialize_from_config(cfg) for cfg in HP["modules"]]

        # nn.ModuleList takes ONE iterable. Unpacking the modules here (as
        # before) raised TypeError for `Parallel(m1, m2)` and config modules.
        super().__init__([*args, *configured])

    @jit.export
    def forward(self, x: Tensor) -> List[Any]:
        r"""Forward pass.

        Parameters
        ----------
        x: Tensor
            Input fed unchanged to every contained module.

        Returns
        -------
        list
            ``[module(x) for module in self]``, in registration order.
        """
        result: List[Any] = []

        for module in self:
            result.append(module(x))

        return result

In [11]:
model = Parallel([nn.ReLU(), nn.Linear(3, 4)])
scripted = jit.script(model)
# Every branch of the scripted module must match eager execution.
torch.testing.assert_close(scripted(x), model(x))
scripted(x)

## Repeat

In [12]:
@autojit
class Repeat(nn.Sequential):
    """Chain several copies of one module, one after another.

    Hyperparameters (merged over ``DEFAULT_HP``)
    --------------------------------------------
    module:
        Either a ready-made ``nn.Module`` instance or a config passed to
        ``initialize_from_config``.
    copies: int
        How many times the module appears in the chain.
    independent: bool
        If True, every position holds its own module (own parameters); if
        False, one module object is repeated, so all positions share parameters.

    After construction ``self.HP["module"]`` holds ``str()`` of the original value.
    """

    DEFAULT_HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "module": None,
        "copies": 1,
        "independent": True,
    }

    # Merged hyperparameters: DEFAULT_HP overridden by the constructor kwargs.
    HP: Dict[str, Any]

    def __init__(self, **HP: Any) -> None:
        from copy import deepcopy

        self.HP = self.DEFAULT_HP | HP
        HP = self.HP

        copies: list[nn.Module] = []

        for k in range(HP["copies"]):
            if isinstance(HP["module"], nn.Module):
                # The given instance is the first copy; further copies are deep
                # copies. Appending the same object every time (as before) made
                # "independent" copies silently share their parameters.
                module = HP["module"] if k == 0 else deepcopy(HP["module"])
            else:
                # NOTE(review): this name was never imported in the notebook;
                # assumed to live in linodenet.util next to `autojit` -- confirm.
                from linodenet.util import initialize_from_config

                module = initialize_from_config(HP["module"])

            if HP["independent"]:
                copies.append(module)
            else:
                # Shared parameters: the very same object at every position.
                copies = [module] * HP["copies"]
                break

        # Keep only a string description of the module in the stored HP
        # (presumably so HP stays printable/serializable -- TODO confirm).
        HP["module"] = str(HP["module"])
        super().__init__(*copies)

In [13]:
model = Repeat(module=nn.ReLU(), copies=3)
scripted = jit.script(model)
# Scripting (possibly already done by `autojit`) must not change the result.
torch.testing.assert_close(scripted(x), model(x))
scripted(x)

In [14]:
# Layer / parameter overview of `model`, the Repeat instance built above.
summary(model=model)