In [None]:
# Echo the value of the last expression *or* the last assignment of every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension; mode 2 re-imports modules before each cell
# execution, so edits to the imported project code are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Final, Optional, TypeAlias

import torch
from torch import Tensor, jit, nn

from linodenet.utils import (
    ReverseDense,
    ReZeroCell,
    deep_dict_update,
    deep_keyval_update,
    initialize_from_config,
)

from linodenet.models.filters import FilterABC, KalmanCell
from torchinfo import summary

In [None]:
class SequentialFilterBlock(FilterABC, nn.ModuleList):
    r"""Residual block: a filter cell followed by a stack of post-processing layers.

    `forward` evaluates ``x + layers(filter(y, x))``, i.e. the filter output is
    passed through every layer in order and the result is added to ``x``.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "filter": KalmanCell.HP | {"autoregressive": True},
        "layers": [ReverseDense.HP | {"bias": False}, ReZeroCell.HP],
    }
    r"""Default hyperparameters; keyword arguments of ``__init__`` are merged on top."""

    input_size: Final[int]

    def __init__(self, *args: Any, **HP: Any) -> None:
        r"""Build the filter cell and the layer stack from the merged configuration.

        Parameters
        ----------
        *args
            Already constructed modules, placed in front of the configured layers.
        **HP
            Overrides merged into the class-level `HP` via `deep_dict_update`.
            ``input_size`` is written into the filter config and into every layer
            config that has an ``input_size`` and/or ``output_size`` key.
        """
        super().__init__()
        config = deep_dict_update(self.HP, HP)
        self.CFG = config

        size = config["input_size"]
        self.input_size = size
        config["filter"]["input_size"] = size

        # Instantiate the configured layers first (before the filter), after
        # forcing their sizes to match the block's input size.
        configured: list[nn.Module] = []
        for layer_cfg in config["layers"]:
            for size_key in ("input_size", "output_size"):
                if size_key in layer_cfg:
                    layer_cfg[size_key] = size
            configured.append(initialize_from_config(layer_cfg))

        self.filter: nn.Module = initialize_from_config(config["filter"])
        self.layers: Iterable[nn.Module] = nn.Sequential(*args, *configured)

    @jit.export
    def forward(self, y: Tensor, x: Tensor) -> Tensor:
        r"""Signature: ``[(..., m), (..., n)] -> (..., n)``.

        Returns ``x`` plus the filter output after it has passed through all layers.
        """
        correction = self.filter(y, x)
        for layer in self.layers:
            correction = layer(correction)
        return x + correction


class SequentialFilter(FilterABC, nn.ModuleList):
    r"""Multiple Filters applied sequentially.

    The container holds ``copies`` filter modules. `forward` passes the same
    observation ``y`` to each of them in order, threading the state ``x`` through.

    Hyperparameters (see `HP`)
    --------------------------
    independent
        If true, every copy is a separate module instance. Otherwise one single
        instance is repeated ``copies`` times (all copies share their parameters).
    copies
        Number of filter modules in the chain.
    input_size
        Written into the module config as its ``input_size``.
    module
        Either a config for `initialize_from_config` or an already constructed
        `nn.Module`.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "independent": True,
        "copies": 2,
        "input_size": int,
        "module": SequentialFilterBlock.HP,
    }
    r"""The HyperparameterDict of this class."""

    def __init__(self, **HP: Any) -> None:
        r"""Create the chain of filter modules.

        Parameters
        ----------
        **HP
            Overrides merged into the class-level `HP` via `deep_dict_update`.
        """
        # Local import: keeps this block self-contained without touching the
        # notebook's import cell.
        from copy import deepcopy

        super().__init__()
        self.CFG = HP = deep_dict_update(self.HP, HP)

        template = HP["module"]
        prebuilt = isinstance(template, nn.Module)

        # Only a config can receive the input size. Indexing an already built
        # module with a string key raises a TypeError, which previously made the
        # pre-built-module path unusable.
        if not prebuilt:
            template["input_size"] = HP["input_size"]

        copies: list[nn.Module] = []

        for index in range(HP["copies"]):
            if not prebuilt:
                module = initialize_from_config(template)
            elif index == 0 or not HP["independent"]:
                # The caller's instance itself becomes part of the chain.
                module = template
            else:
                # Independent copies of a pre-built module must not share
                # parameters, so every further copy is a deep copy.
                module = deepcopy(template)

            if not HP["independent"]:
                # Shared mode: repeat the very same instance and stop.
                copies = [module] * HP["copies"]
                break
            copies.append(module)

        # `self.CFG` is this same dict, so the stored configuration keeps only a
        # string representation of the module entry.
        HP["module"] = str(HP["module"])
        nn.ModuleList.__init__(self, copies)

    @jit.export
    def forward(self, y: Tensor, x: Tensor) -> Tensor:
        r"""Signature: ``[(..., m), (..., n)] -> (..., n)``.

        Applies every contained module as ``x = module(y, x)`` and returns the
        final ``x``.
        """
        for module in self:
            x = module(y, x)
        return x

In [None]:
# Smoke test: build a SequentialFilter with input_size=16 (all other
# hyperparameters at their class defaults) and show its torchinfo summary.
model = SequentialFilter(input_size=16)
summary(model)

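**Old filter:**

$$
\begin{aligned}
x &= x - \phi(\mathrm{Linear}(x - y)) \\
x &= x - \phi(\mathrm{Linear}(x - y))
\end{aligned}
$$

**New filter:**

$$
\begin{aligned}
x &= x - \alpha\,\mathrm{Linear}(x - y) \\
x &= x - \epsilon\,\phi(\mathrm{Linear}(x - y)) \\
x &= x - \epsilon\,\phi(\mathrm{Linear}(x - y))
\end{aligned}
$$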