In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import logging
from collections import OrderedDict
from math import sqrt
from typing import Any, Optional, cast

import torch
from linodenet.models.encoders.ft_transformer import (
    get_activation_fn,
    get_nonglu_activation_fn,
)
from linodenet.util import (
    ReverseDense,
    ReZeroCell,
    autojit,
    deep_dict_update,
    initialize_from_config,
)
from torch import Tensor, jit, nn
from torch.nn import functional as F

# Notebook-scoped logger; the root logger is configured to show INFO and above.
__logger__ = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)

## Old ResNet

In [None]:
@autojit
class ResNet_(nn.Module):
    """Residual Network for tabular inputs (numerical and optional categorical features).

    The numerical features and the flattened embeddings of the categorical features
    are concatenated and projected to width ``d`` by ``first_layer``. The result
    passes through ``n_layers`` residual blocks
    ``x <- x + linear1(act(linear0(norm(x))))`` (with optional dropout), then a final
    normalization, activation and linear head.

    Parameters
    ----------
    d_numerical: int
        Number of numerical input features.
    categories: Optional[list[int]]
        Number of distinct values of each categorical feature, or ``None``.
    d_embedding: int
        Embedding dimension per categorical feature.
    d: int
        Width of the residual stream.
    d_hidden_factor: float
        The hidden width of each block is ``int(d * d_hidden_factor)``.
    n_layers: int
        Number of residual blocks.
    activation: str
        Activation name passed to ``get_activation_fn``. If it ends in ``"glu"``,
        the first linear layer of each block outputs twice the hidden width.
    normalization: str
        ``"batchnorm"`` or ``"layernorm"``; any other value raises ``KeyError``.
    hidden_dropout: float
        Dropout probability after the activation inside each block (0 disables it).
    residual_dropout: float
        Dropout probability on each block's output before the skip-add
        (0 disables it).
    d_out: int
        Output size of the head; a trailing size-1 dimension is squeezed away.
    """

    def __init__(
        self,
        *,
        d_numerical: int,
        categories: Optional[list[int]],
        d_embedding: int,
        d: int,
        d_hidden_factor: float,
        n_layers: int,
        activation: str,
        normalization: str,
        hidden_dropout: float,
        residual_dropout: float,
        d_out: int,
    ) -> None:
        super().__init__()

        def make_normalization():
            return {"batchnorm": nn.BatchNorm1d, "layernorm": nn.LayerNorm}[
                normalization
            ](d)

        self.main_activation = get_activation_fn(activation)
        self.last_activation = get_nonglu_activation_fn(activation)
        self.residual_dropout = residual_dropout
        self.hidden_dropout = hidden_dropout

        d_in = d_numerical
        d_hidden = int(d * d_hidden_factor)

        if categories is not None:
            d_in += len(categories) * d_embedding
            # Start index of each feature's value range inside the shared table.
            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
            self.register_buffer("category_offsets", category_offsets)
            embeddings = nn.Embedding(sum(categories), d_embedding)
            nn.init.kaiming_uniform_(embeddings.weight, a=sqrt(5))
            self.category_embeddings: Optional[nn.Embedding] = embeddings
        else:
            # Explicit ``None`` placeholders. Without them ``forward`` would raise
            # AttributeError instead of its own clear assertion message.
            self.register_buffer("category_offsets", None)
            self.category_embeddings = None

        self.first_layer = nn.Linear(d_in, d)
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "norm": make_normalization(),
                        "linear0": nn.Linear(
                            d, d_hidden * (2 if activation.endswith("glu") else 1)
                        ),
                        "linear1": nn.Linear(d_hidden, d),
                    }
                )
                for _ in range(n_layers)
            ]
        )
        self.last_normalization = make_normalization()
        self.head = nn.Linear(d, d_out)

    def forward(
        self, x_num: Optional[Tensor], x_cat: Optional[Tensor] = None
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        x_num: Optional[Tensor]
            Numerical features, or ``None``.
        x_cat: Optional[Tensor]
            Integer category indices, or ``None``. Requires the model to have been
            built with ``categories``. Assumed to be (batch, n_categories);
            TODO confirm.

        Returns
        -------
        Tensor
            Output of the head, with a trailing size-1 dimension squeezed.

        Raises
        ------
        ValueError
            If both ``x_num`` and ``x_cat`` are ``None``.
        """
        tensors = []
        if x_num is not None:
            tensors.append(x_num)
        if x_cat is not None:
            assert self.category_embeddings is not None, "No category embeddings!"
            assert self.category_offsets is not None, "No category offsets!"
            # Shift each column into its own slice of the shared table, embed, and
            # flatten the per-feature embeddings into a single vector per row.
            tensors.append(
                self.category_embeddings(
                    x_cat + self.category_offsets[None]  # type: ignore[index]
                ).view(x_cat.size(0), -1)
            )
        if len(tensors) == 0:
            raise ValueError("At least one of x_num and x_cat must be given!")
        x = torch.cat(tensors, dim=-1)

        x = self.first_layer(x)
        for layer in self.layers:
            layer = cast(dict[str, nn.Module], layer)
            z = x
            z = layer["norm"](z)
            z = layer["linear0"](z)
            z = self.main_activation(z)

            if self.hidden_dropout:
                z = F.dropout(z, self.hidden_dropout, self.training)

            z = layer["linear1"](z)

            if self.residual_dropout:
                z = F.dropout(z, self.residual_dropout, self.training)
            x = x + z
        x = self.last_normalization(x)
        x = self.last_activation(x)
        x = self.head(x)
        x = x.squeeze(-1)

        return x


@autojit
class ResNetBlock(nn.Sequential):
    """Pre-activation ResNet block.

    The block is a sequence of ``num_subblocks`` sub-blocks named ``subblock{k}``.
    Each sub-block is an ``nn.Sequential`` of fresh layers built from the configs in
    ``HP["subblocks"]``, with their sizes set to ``input_size``. The residual skip
    connection is not part of this module (see ``ResNet``).

    References
    ----------
    - | Identity Mappings in Deep Residual Networks
      | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
      | European Conference on Computer Vision 2016
      | https://link.springer.com/chapter/10.1007/978-3-319-46493-0_38
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "num_subblocks": 2,
        "subblocks": [
            # {
            #     "__name__": "BatchNorm1d",
            #     "__module__": "torch.nn",
            #     "num_features": int,
            #     "eps": 1e-05,
            #     "momentum": 0.1,
            #     "affine": True,
            #     "track_running_stats": True,
            # },
            ReverseDense.HP,
        ],
    }

    def __init__(self, **HP: Any) -> None:
        # Local import: ``deepcopy`` is only imported further down in this notebook.
        from copy import deepcopy

        # Work on private copies. The layer configs below get ``input_size`` filled
        # in, which must not leak into ``self.HP``, ``ReverseDense.HP`` or the
        # caller's dictionaries.
        HP = deep_dict_update(deepcopy(self.HP), deepcopy(HP))

        assert HP["input_size"] is not None, "input_size is required!"
        size = HP["input_size"]

        for layer in HP["subblocks"]:
            kind = layer["__name__"]
            if kind == "Linear":
                layer["in_features"] = size
                layer["out_features"] = size
            elif kind == "BatchNorm1d":
                layer["num_features"] = size
            else:
                # Project layers (e.g. ReverseDense) use input_size / output_size.
                layer["input_size"] = size
                layer["output_size"] = size

        subblocks: OrderedDict[str, nn.Module] = OrderedDict(
            (
                f"subblock{k}",
                nn.Sequential(
                    *[initialize_from_config(layer) for layer in HP["subblocks"]]
                ),
            )
            for k in range(HP["num_subblocks"])
        )

        # Initialize the Sequential container exactly once, with named sub-blocks.
        super().__init__(subblocks)
        self.CFG = HP


@autojit
class ResNet(nn.ModuleList):
    """A ResNet model: ``num_blocks`` blocks applied residually as ``x <- x + block(x)``.

    Each block is an ``nn.Sequential`` of modules built from the configs in
    ``HP["blocks"]``. By default that is a ``ResNetBlock`` followed by a
    ``ReZeroCell``. Configs that have an ``input_size`` key receive the model's
    ``input_size``.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "num_blocks": 5,
        "blocks": [
            ResNetBlock.HP,
            ReZeroCell.HP,
        ],
    }

    def __init__(self, **HP: Any) -> None:
        # Local import: ``deepcopy`` is only imported further down in this notebook.
        from copy import deepcopy

        # Private copies, so filling in ``input_size`` never mutates the shared
        # class-level configs (e.g. ``ResNetBlock.HP``) or the caller's dicts.
        HP = deep_dict_update(deepcopy(self.HP), deepcopy(HP))

        assert HP["input_size"] is not None, "input_size is required!"

        # Pass the input_size down to every block config that accepts it.
        for block_cfg in HP["blocks"]:
            if "input_size" in block_cfg:
                block_cfg["input_size"] = HP["input_size"]

        blocks: list[nn.Module] = [
            nn.Sequential(*[initialize_from_config(layer) for layer in HP["blocks"]])
            for _ in range(HP["num_blocks"])
        ]

        # Initialize the ModuleList exactly once.
        super().__init__(blocks)
        self.CFG = HP

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass: apply every block with a residual skip connection.

        Parameters
        ----------
        x: Tensor

        Returns
        -------
        Tensor
        """
        for block in self:
            x = x + block(x)
        return x

In [None]:
from torchinfo import summary

In [None]:
model = ResNet(input_size=2)  # ModuleList-based ResNet built from hyperparameters
summary(model=model)

## ResNet V2

In [None]:
from typing import TypeVar, overload

from tsdm.util.decorators import trace

Self = TypeVar("Self")  # noqa: Y001


@autojit
class ResNet(nn.Sequential):
    """A ResNet model: a sequence of blocks applied residually as ``x <- x + block(x)``.

    There are two ways to construct it:

    - from modules: ``ResNet(block1, block2, ...)``
    - from hyperparameters: ``ResNet(input_size=..., num_blocks=..., block_cfg=...)``,
      which ``__new__`` dispatches to :meth:`from_hyperparameters`.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        "num_blocks": 5,
        "block": ResNetBlock.HP,
    }

    def __new__(cls: type[Self], *blocks: nn.Module, **HP: Any) -> Self:
        """Dispatch hyperparameter-style construction to ``from_hyperparameters``."""
        # Mixing both styles is an error. Passing neither must stay legal, because
        # ``copy.deepcopy`` / ``pickle`` re-create instances via ``cls.__new__(cls)``.
        assert not (blocks and HP), "Provide either blocks, or hyperparameters!"

        if HP:
            return cls.from_hyperparameters(**HP)  # type: ignore[attr-defined]

        return super().__new__(cls)  # type: ignore[misc]

    def __init__(self, *args: nn.Module, **kwargs: Any) -> None:
        # In the hyperparameter path, ``__new__`` already returned a fully
        # initialized instance. Python still calls ``__init__`` with the same
        # keywords, so skip re-initialization.
        if kwargs:
            return
        super().__init__(*args)

    @classmethod
    def from_hyperparameters(
        cls: type[Self],
        *,
        input_size: int,
        num_blocks: int = 5,
        block_cfg: dict = ResNetBlock.HP,
    ) -> Self:
        """Create a ResNet model from hyperparameters.

        Parameters
        ----------
        input_size: int
            Feature size; written into ``block_cfg`` if it has an ``input_size`` key.
        num_blocks: int
            Number of blocks, each built independently from ``block_cfg``.
        block_cfg: dict
            Config passed to ``initialize_from_config``. It is never mutated.
        """
        # Copy before filling in ``input_size``. The default is the shared
        # ``ResNetBlock.HP`` class attribute, which must not be modified.
        if "input_size" in block_cfg:
            block_cfg = {**block_cfg, "input_size": input_size}

        blocks: list[nn.Module] = [
            initialize_from_config(block_cfg) for _ in range(num_blocks)
        ]
        return cls(*blocks)  # type: ignore[call-arg]

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass: apply every block with a residual skip connection.

        Parameters
        ----------
        x: Tensor

        Returns
        -------
        Tensor
        """
        for block in self:
            x = x + block(x)
        return x

In [None]:
# Build from hyperparameters, round-trip through TorchScript serialization, summarize.
model = jit.script(ResNet.from_hyperparameters(input_size=2))
jit.save(model, "model.pt")
model = jit.load("model.pt")
summary(model=model)

In [None]:
# Same round-trip, but via the keyword-dispatching constructor.
model = jit.script(ResNet(input_size=2))
jit.save(model, "model.pt")
model = jit.load("model.pt")
summary(model=model)

In [None]:
model = ResNet(*(nn.Linear(3, 3) for _ in range(2)))  # construct from explicit modules
summary(model=model)

## ReZero v2

In [None]:
from typing import Union

from torch._jit_internal import _copy_to_script_wrapper


@autojit
class ReZero(ResNet):
    r"""A ReZero model: ``x <- x + w[k] * block_k(x)`` with learnable scalars ``w``.

    The scalars live in the ``weights`` parameter, one per block, initialized to
    zero, so the freshly built model computes the identity map.
    """

    # Re-declare HP so that ``__name__``/``__module__`` identify ReZero rather than
    # the inherited ResNet entries.
    HP = {
        **ResNet.HP,
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
    }

    def __init__(self, *blocks: nn.Module, **kwargs: Any) -> None:
        # In the hyperparameter path, ``__new__`` already returned an initialized
        # instance. Python re-calls ``__init__`` with the keywords, so skip it.
        if kwargs:
            return
        super().__init__(*blocks)
        weights = torch.zeros(len(blocks), dtype=torch.float)
        self.register_parameter("weights", nn.Parameter(weights))

    @_copy_to_script_wrapper
    def __getitem__(
        self: nn.Sequential, item: Union[int, slice]
    ) -> Union[nn.Module, nn.Sequential]:
        r"""Get a single block (int index) or a sub-model (slice).

        A slice returns a new ``ReZero`` over the selected blocks. The block modules
        are shared. The new ``weights`` is a fresh parameter holding a copy of the
        corresponding entries (on the same device/dtype as this model's weights).
        """
        modules: list[nn.Module] = list(self._modules.values())
        if isinstance(item, slice):
            # Build from modules only. A ``weights=`` keyword would be treated as a
            # hyperparameter by ``ResNet.__new__`` and routed to
            # ``from_hyperparameters``, which rejects it.
            submodel = ReZero(*modules[item])
            submodel.weights.data = (
                self.weights[item].detach().clone()  # type: ignore[index]
            )
            return submodel
        return modules[item]

    @jit.export
    def __len__(self) -> int:
        """Number of blocks (one weight per block)."""
        return len(self.weights)

    @jit.export
    def length(self) -> int:
        """Number of blocks; same as ``len(self)``."""
        return len(self.weights)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass: ``x <- x + weights[k] * block_k(x)`` for every block.

        Parameters
        ----------
        x: Tensor

        Returns
        -------
        Tensor
        """
        for k, block in enumerate(self):
            x = x + self.weights[k] * block(x)
        return x

In [None]:
# Build ReZero from hyperparameters, round-trip through TorchScript, summarize.
model = jit.script(ReZero.from_hyperparameters(input_size=2))
jit.save(model, "model.pt")
model = jit.load("model.pt")
summary(model=model)

In [None]:
# Same round-trip, but via the keyword-dispatching constructor.
model = jit.script(ReZero(input_size=2))
jit.save(model, "model.pt")
model = jit.load("model.pt")
summary(model=model)

## Series v2

In [None]:
@autojit
class Series(nn.Sequential):
    r"""A sequential container that can also be built from layer configs.

    There are two ways to construct it:

    - from modules: ``Series(layer1, layer2, ...)``
    - from hyperparameters: ``Series(input_size=..., layers_cfg=[...])``, which
      ``__new__`` dispatches to :meth:`from_hyperparameters`.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": None,
        # Key matches the ``layers_cfg`` parameter of ``from_hyperparameters``.
        "layers_cfg": [
            ReverseDense.HP,
            ReverseDense.HP,
            # {
            #     "__name__": "BatchNorm1d",
            #     "__module__": "torch.nn",
            #     "num_features": int,
            #     "eps": 1e-05,
            #     "momentum": 0.1,
            #     "affine": True,
            #     "track_running_stats": True,
            # },
        ],
    }

    def __new__(cls: type[Self], *modules: nn.Module, **HP: Any) -> Self:
        """Dispatch hyperparameter-style construction to ``from_hyperparameters``."""
        # Mixing both styles is an error. Passing neither stays legal, because
        # deepcopy/pickle call ``cls.__new__(cls)`` without arguments.
        assert not (modules and HP), "Provide either modules, or hyperparameters!"

        if HP:
            return cls.from_hyperparameters(**HP)  # type: ignore[attr-defined]

        return super().__new__(cls)  # type: ignore[misc]

    def __init__(self, *args: nn.Module, **kwargs: Any) -> None:
        # In the hyperparameter path, ``__new__`` already returned an initialized
        # instance. Python re-calls ``__init__`` with the keywords, so skip it.
        if kwargs:
            return
        super().__init__(*args)

    @classmethod
    def from_hyperparameters(
        cls,
        *,
        input_size: int,
        layers_cfg: list[dict] = [ReverseDense.HP, ReverseDense.HP],  # noqa: B006
    ):
        """Create a Series model from hyperparameters.

        Parameters
        ----------
        input_size: int
            Feature size written into every layer config.
        layers_cfg: list[dict]
            One config per layer, passed to ``initialize_from_config``. These
            configs are never mutated.
        """
        from copy import deepcopy

        # Private copies. The default list holds the *same* shared
        # ``ReverseDense.HP`` dict twice, and caller-provided dicts must stay intact.
        configs = [deepcopy(cfg) for cfg in layers_cfg]

        for cfg in configs:
            kind = cfg["__name__"]
            if kind == "Linear":
                cfg["in_features"] = input_size
                cfg["out_features"] = input_size
            elif kind == "BatchNorm1d":
                cfg["num_features"] = input_size
            else:
                cfg["input_size"] = input_size
                cfg["output_size"] = input_size

        layers: list[nn.Module] = [initialize_from_config(cfg) for cfg in configs]
        return cls(*layers)

In [None]:
model = Series(input_size=5)  # keyword construction dispatches to from_hyperparameters
summary(model=model)

## Class Repeat

In [None]:
from copy import deepcopy

## Repeat v2

In [None]:
class Repeat(nn.Sequential):
    """Repeat a block of layers multiple times.

    There are two ways to construct it:

    - from modules: ``Repeat(block1, block2, ...)``, which is a plain sequential
      container. Here ``num_copies`` / ``clone_layers`` are accepted but ignored.
    - from hyperparameters:
      ``Repeat(input_size=..., layers_cfg=[...], num_copies=..., clone_layers=...)``,
      which ``__new__`` dispatches to :meth:`from_hyperparameters`.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__module__": __module__,  # type: ignore[name-defined]
        "num_copies": 2,
        # Key matches the ``clone_layers`` parameter.
        "clone_layers": True,
    }

    def __new__(cls, *modules: nn.Module, **HP: Any):
        """Dispatch hyperparameter-style construction to ``from_hyperparameters``."""
        # ``num_copies`` / ``clone_layers`` on their own are tolerated by the
        # modules-style constructor. Any other keyword means hyperparameter style.
        if set(HP) - {"num_copies", "clone_layers"}:
            assert not modules, "Provide either modules, or hyperparameters!"
            return cls.from_hyperparameters(**HP)
        return super().__new__(cls)

    def __init__(
        self, *args, num_copies: int = 2, clone_layers: bool = True, **kwargs
    ) -> None:
        # In the hyperparameter path, ``__new__`` already returned an initialized
        # instance. Python re-calls ``__init__`` with the keywords, so skip it.
        if kwargs:
            return
        super().__init__(*args)

    @classmethod
    def from_hyperparameters(
        cls,
        *,
        input_size: int,
        layers_cfg: list[dict] = [ReverseDense.HP, ReverseDense.HP],  # noqa: B006
        num_copies: int = 2,
        clone_layers: bool = True,
    ):
        """Create a Repeat model from hyperparameters.

        Parameters
        ----------
        input_size: int
            Feature size written into every layer config.
        layers_cfg: list[dict]
            One config per layer of the repeated block; never mutated.
        num_copies: int
            How many times the block (a ``Series`` of the layers) is repeated.
        clone_layers: bool
            If True, every copy gets deep-copied layers (independent parameters).
            If False, all copies share the same layer instances.
        """
        from copy import deepcopy

        # Private copies: the default list holds the shared ``ReverseDense.HP`` twice.
        configs = [deepcopy(cfg) for cfg in layers_cfg]

        for cfg in configs:
            kind = cfg["__name__"]
            if kind == "Linear":
                cfg["in_features"] = input_size
                cfg["out_features"] = input_size
            elif kind == "BatchNorm1d":
                cfg["num_features"] = input_size
            else:
                cfg["input_size"] = input_size
                cfg["output_size"] = input_size

        layers: list[nn.Module] = [initialize_from_config(cfg) for cfg in configs]

        blocks: list[nn.Module] = []
        for _ in range(num_copies):
            if clone_layers:
                block = Series(*[deepcopy(layer) for layer in layers])
            else:
                block = Series(*layers)
            blocks.append(block)

        return cls(*blocks)

In [None]:
model = Repeat(input_size=5, clone_layers=False)  # copies share the same layers
summary(model=model)