# Sequential filter blocks: `SequentialFilterBlock` and `SequentialFilter` prototypes

In [None]:
# Notebook setup: IPython display/autoreload configuration and logging.
# Comments are kept on their own lines because IPython passes the rest of a
# magic line to the magic as its argument string.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Autoreload mode 2: reload all modules before executing each cell.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook output.
%matplotlib inline

import logging

# Configure the root logger to emit records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Print arrays with exactly 4 decimals and without scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# Unseeded generator: values differ between runs.
# NOTE(review): `rng` is not used in the cells visible here — confirm it is still needed.
rng = np.random.default_rng()

In [None]:
from typing import Any, Dict, Final, List

import torch
from linodenet.models.filters import FilterABC, KalmanCell
from linodenet.util import (
    ReverseDense,
    ReZero,
    autojit,
    deep_dict_update,
    initialize_from_config,
)
from torch import Tensor, jit, nn
from torchinfo import summary

In [None]:
@autojit
class SequentialFilterBlock(FilterABC, nn.ModuleList):
    """Filter cell followed by a stack of layers, applied as a residual update.

    The forward pass computes ``x + layers(filter(y, x))``.

    Hyperparameters (see ``DEFAULT_HP``)
    ------------------------------------
    input_size
        Size passed to the filter and to every layer config that declares an
        ``input_size`` / ``output_size`` key.
    filter
        Config dict for the filter cell (defaults to ``KalmanCell.HP`` with
        ``autoregressive=True``).
    layers
        List of config dicts for the layers applied after the filter.

    Any positional ``*args`` are placed in front of the configured layers.
    """

    DEFAULT_HP: dict = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "filter": KalmanCell.HP | {"autoregressive": True},
        "layers": [ReverseDense.HP | {"bias": False}, ReZero.HP],
    }

    input_size: Final[int]

    def __init__(self, *args, **HP: Any) -> None:
        """Build the filter and the layer stack from the merged hyperparameters.

        Parameters
        ----------
        *args
            Extra modules that are prepended to the configured layers.
        **HP
            Overrides for the top-level keys of ``DEFAULT_HP``.
        """
        # Local import keeps this block self-contained within the notebook.
        from copy import deepcopy

        super().__init__()

        # ``|`` merges only the top level, so the nested "filter"/"layers"
        # dicts would still be the objects stored in DEFAULT_HP (or passed in
        # by the caller). Deep-copy before writing sizes into them so that
        # constructing an instance never alters shared configuration.
        self.HP = deepcopy(self.DEFAULT_HP | HP)
        HP = self.HP
        self.input_size = input_size = HP["input_size"]
        HP["filter"]["input_size"] = input_size

        layers: list[nn.Module] = []

        for layer in HP["layers"]:
            # Only fill in the sizes a layer config actually declares.
            if "input_size" in layer:
                layer["input_size"] = input_size
            if "output_size" in layer:
                layer["output_size"] = input_size

            layers.append(initialize_from_config(layer))

        layers = list(args) + layers
        self.filter = initialize_from_config(HP["filter"])
        self.layers = nn.Sequential(*layers)

    @jit.export
    def forward(self, y: Tensor, x: Tensor) -> Tensor:
        """Return ``x`` plus the layer-transformed filter output.

        NOTE(review): presumably ``y`` is the observation and ``x`` the current
        state estimate — confirm against ``FilterABC``.
        """
        z = self.filter(y, x)
        for module in self.layers:
            z = module(z)
        return x + z

In [None]:
# Smoke test: build a block for 3 features, call it, then call it via jit.script.
x = torch.randn(3)
model = SequentialFilterBlock(input_size=3)
print(summary(model))
# The same tensor is passed for both arguments `y` and `x`.
model(x, x)
scripted = jit.script(model)
# Element-wise comparison of the input with the scripted module's output.
# NOTE(review): presumably expected to be all True at initialization (block
# acting as identity) — confirm; exact float equality may be brittle.
x == scripted(x, x)

In [None]:
from linodenet.util.layers import Repeat

In [None]:
class SequentialFilter(FilterABC, nn.Sequential):
    """Chain of filter modules applied one after another to the state ``x``.

    Hyperparameters (see ``DEFAULT_HP``)
    ------------------------------------
    independent
        If true, every position in the chain gets its own module; otherwise a
        single module object is reused ``copies`` times.
    copies
        Number of modules in the chain.
    input_size
        Written into the module config before instantiation. The default is a
        placeholder and must be supplied by the caller.
    module
        Either a config dict understood by ``initialize_from_config`` or an
        already constructed ``nn.Module``.
    """

    DEFAULT_HP: dict = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "independent": True,
        "copies": 2,
        "input_size": int,
        "module": SequentialFilterBlock.DEFAULT_HP,
    }

    HP: Dict[str, Any]
    """The HP"""

    def __init__(self, **HP: Any) -> None:
        """Instantiate the chain from the merged hyperparameters.

        Parameters
        ----------
        **HP
            Overrides for the top-level keys of ``DEFAULT_HP``.
        """
        # Local import keeps this block self-contained within the notebook.
        from copy import deepcopy

        config = self.DEFAULT_HP | HP
        module_spec = config["module"]
        num_copies = config["copies"]
        is_instance = isinstance(module_spec, nn.Module)

        if not is_instance:
            # The default spec *is* SequentialFilterBlock.DEFAULT_HP; copy it
            # before writing so that class-level defaults (and caller-owned
            # dicts) are never modified.
            module_spec = deepcopy(module_spec)
            module_spec["input_size"] = config["input_size"]

        def build(fresh: bool) -> nn.Module:
            """Return one chain element; ``fresh`` requests an unshared one."""
            if is_instance:
                return deepcopy(module_spec) if fresh else module_spec
            # Hand out a private copy of the config for each instantiation.
            return initialize_from_config(deepcopy(module_spec))

        copies: list[nn.Module]
        if config["independent"]:
            copies = [build(fresh=True) for _ in range(num_copies)]
        elif num_copies > 0:
            # Shared weights: the very same module object at every position.
            copies = [build(fresh=False)] * num_copies
        else:
            copies = []

        # The stored hyperparameters keep only a string form of the module spec.
        config["module"] = str(module_spec)

        super().__init__(*copies)
        # Assigned after Module.__init__ so attribute bookkeeping is set up.
        self.HP = config

    @jit.export
    def forward(self, y: Tensor, x: Tensor) -> Tensor:
        """Feed ``(y, x)`` through every module in order, updating ``x`` each time."""
        for module in self:
            x = module(y, x)
        return x

In [None]:
# Smoke test: build the default chain for 4 features and call it directly
# and through jit.script.
x = torch.randn(4)
model = SequentialFilter(input_size=4)
# The same tensor is passed for both arguments `y` and `x`.
model(x, x)
scripted = jit.script(model)
scripted(x, x)
# Last expression of the cell, so the summary is what the notebook displays.
summary(model)