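# AutoJIT: returning TorchScript modules straight from `nn.Module` constructors

Scratch experiments with `autojit`, a class decorator whose `__new__` initializes each new instance and returns `torch.jit.script(instance)` in its place, plus small checks of how Python invokes `__new__` and `__init__`.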

In [None]:
# Display the value of the last expression *or* of a trailing assignment in each cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload imported modules (e.g. a locally edited package) before executing each cell.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Show INFO-level (and higher) log records in the notebook output.
logging.basicConfig(level=logging.INFO)

In [None]:
import logging
from functools import wraps
from typing import Any

import numpy as np

# Fixed seed so that every draw from `rng` is reproducible on Restart & Run All.
SEED = 0
# Compact array printing: 4 fixed decimals, no scientific notation for small values.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
rng = np.random.default_rng(SEED)

In [None]:
import logging
from typing import Any, TypeVar

import torch
from torch import Tensor, jit, nn

from tsdm.utils.decorators import trace

In [None]:
# Type variable bound to `nn.Module`: lets `autojit` declare that it returns a class of the
# same type as the class it receives.
nnModuleType = TypeVar("nnModuleType", bound=nn.Module)

In [None]:
# `functools.wraps` copies the attributes listed in WRAPPER_ASSIGNMENTS and calls `.update()`
# on those listed in WRAPPER_UPDATES (just `__dict__`). A class's `__dict__` is a read-only
# mappingproxy without `.update()`, which is what `updated=()` in `autojit` below avoids.
from functools import WRAPPER_ASSIGNMENTS, WRAPPER_UPDATES

WRAPPER_ASSIGNMENTS, WRAPPER_UPDATES

In [None]:
# One-argument `type(obj)` merely returns the type of `obj` (for an f-string: `str`), and
# `cls` is undefined at module level (NameError on a fresh kernel). The three-argument form
# `type(name, bases, namespace)` creates a new class, e.g. one named "AutoJIT@Module":
AutoJITModule = type(f"AutoJIT@{nn.Module.__name__}", (nn.Module,), {})

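Normalized (mean) $L_{p,q}$ norm of a matrix $A=(a_{ij})\in\mathbb{R}^{m\times n}$: the usual entrywise $L_{p,q}$ norm with both sums replaced by averages.

$${\displaystyle \|A\|_{p,q}=\left(\frac{1}{n}\sum _{j=1}^{n}\left(\frac{1}{m}\sum _{i=1}^{m}|a_{ij}|^{p}\right)^{\frac {q}{p}}\right)^{\frac {1}{q}}.}$$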

In [None]:
# Placeholder: this cell held only the undefined name `xx`, which raised NameError on
# Restart & Run All. TODO: presumably meant to implement the normalized L_{p,q} norm from
# the formula above -- implement it or delete this cell.

In [None]:
def autojit(
    base_class: type[nnModuleType], /, *, inherit: bool = False
) -> type[nnModuleType]:
    r"""Class decorator: make constructing ``base_class`` return a TorchScript module.

    The returned class subclasses ``base_class`` and overrides ``__new__`` so that
    ``Wrapped(*args, **kwargs)`` creates a plain instance, runs ``__init__`` on it
    explicitly, and returns ``jit.script(instance)`` instead of the instance itself.

    Parameters
    ----------
    base_class: type[nn.Module]
        The module class to wrap. Must be a subclass of ``nn.Module``.
    inherit: bool, default False
        Currently unused. TODO: decide its semantics. At present subclasses of the
        wrapped class are always auto-scripted, because they inherit ``__new__``.

    Returns
    -------
    type[nn.Module]
        A subclass of ``base_class`` that carries ``base_class``'s ``__name__``,
        ``__qualname__``, ``__module__`` and ``__doc__`` (copied by ``functools.wraps``).

    Raises
    ------
    TypeError
        If ``base_class`` is not a subclass of ``nn.Module``.

    Notes
    -----
    Since ``__new__`` returns an object that is not an instance of the class, Python
    does not call ``__init__`` a second time. The object returned by construction is
    whatever ``jit.script`` produces (a script module), so ``isinstance(obj, base_class)``
    is expected to be ``False``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (isinstance(base_class, type) and issubclass(base_class, nn.Module)):
        raise TypeError(f"autojit expects a subclass of nn.Module, got {base_class!r}")

    # `updated=()`: by default `wraps` would call `WrappedClass.__dict__.update(...)`,
    # but a class `__dict__` is a read-only mappingproxy without `.update()`.
    @wraps(base_class, updated=())
    class WrappedClass(base_class):  # type: ignore  # pylint: disable=too-few-public-methods
        r"""A simple Wrapper."""

        @trace
        def __new__(cls, *args: Any, **kwargs: Any) -> nnModuleType:  # type: ignore[misc]
            print(f"{cls=}, {args=}, {kwargs=}")
            # Only `cls` is forwarded: `object.__new__` rejects extra arguments when
            # `__new__` is overridden. The constructor arguments go to `__init__` below.
            instance: nnModuleType = super().__new__(cls)
            instance.__init__(*args, **kwargs)
            scripted: nnModuleType = jit.script(instance)
            # If __new__() does not return an instance of cls, then the new instance’s __init__() method will not be invoked!
            return scripted

    return WrappedClass

In [None]:
@autojit
class Series(nn.ModuleList):
    r"""Sequential composition of modules: ``Series(f, g)(x)`` computes ``g(f(x))``.

    Because the class is decorated with ``@autojit``, calling ``Series(*modules)``
    returns the result of ``jit.script`` applied to a freshly initialized instance,
    not the plain ``Series`` instance.
    """

    @trace
    def __init__(self, *modules: nn.Module) -> None:
        r"""Register ``modules``, in the given order, as entries of the ``nn.ModuleList``."""
        # Debug output: makes visible how often `__init__` runs per construction.
        print("__INIT__ CALLED")
        super().__init__(modules)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass: feed ``x`` through each block in order.

        Parameters
        ----------
        x: Tensor
            Input to the first block.

        Returns
        -------
        Tensor
            Output of the last block (``x`` itself if there are no blocks).
        """
        for block in self:
            x = block(x)
        return x


# `ResNet` is not decorated with `@autojit`: it already inherits the auto-scripting `__new__`
# from the decorated `Series`. Decorating it again would stack a second `__new__` on top of
# the inherited one (the inner one would then run `__init__` without the arguments).
class ResNet(Series):
    r"""Residual stack: each block is applied as ``x = x + block(x)``.

    Each block's output must be addable to its input (in practice: the same shape).
    Construction goes through the ``__new__`` inherited from ``Series``.
    """

    @trace
    def __init__(self, *args, **kwargs) -> None:
        r"""Pass all arguments through to ``Series.__init__``."""
        super().__init__(*args, **kwargs)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass with a skip connection around every block.

        Parameters
        ----------
        x: Tensor
            Input tensor.

        Returns
        -------
        Tensor
            ``x`` after all residual updates (``x`` itself if there are no blocks).
        """
        for block in self:
            x = x + block(x)
        return x

In [None]:
# Seed torch's global RNG so the layers' random initialization and the input `x`
# are identical on every Restart & Run All.
torch.manual_seed(0)

blocks = [
    nn.Linear(4, 4),
    nn.ReLU(),
    nn.Linear(4, 4),
]
x = torch.randn(7, 4)  # 7 rows of 4 features, matching the Linear layers' in_features=4
model = Series(*blocks)

In [None]:
y = model(x)
assert y.shape == x.shape, "Series of 4->4 blocks should preserve the input shape"
torch.linalg.norm(y).backward()
# Backpropagating through the scripted module must populate every parameter's gradient.
assert all(p.grad is not None for p in model.parameters())

In [None]:
# NOTE: built from the *same* `blocks` objects as the Series model above. If scripting shares
# their parameter tensors (expected -- TODO confirm), `.grad` from both backward passes
# accumulates in the same parameters.
model = ResNet(*blocks)

In [None]:
y = model(x)
# The residual update `x + block(x)` needs every block to preserve the shape of `x`.
assert y.shape == x.shape, "ResNet output shape should match its input shape"
torch.linalg.norm(y).backward()

In [None]:
class A:
    r"""Toy class with ``@trace``-decorated ``__new__`` and ``__init__``.

    Used to observe how instance construction invokes the two methods.
    """

    @trace
    def __new__(cls, *args, **kwargs):
        # Forward only `cls`: once `__new__` is overridden, `object.__new__` raises TypeError
        # on any extra argument. `type.__call__` still passes the constructor arguments
        # to `__init__` separately.
        return super().__new__(cls)

    @trace
    def __init__(self):
        super().__init__()


class B:
    r"""Toy class with ``@trace``-decorated ``__new__`` and ``__init__``.

    Currently identical in behavior to ``A``.
    """

    @trace
    def __new__(cls, *args, **kwargs):
        # Forward only `cls`: once `__new__` is overridden, `object.__new__` raises TypeError
        # on any extra argument. `type.__call__` still passes the constructor arguments
        # to `__init__` separately.
        return super().__new__(cls)

    @trace
    def __init__(self):
        super().__init__()


# Construct an `A`: Python calls the `@trace`-decorated `__new__` first, then `__init__`.
obj = A()

In [None]:
# Same experiment with `B` (whose definition matches `A`'s); rebinds `obj`.
obj = B()

In [None]:
class MyModule(nn.Module):
    r"""Minimal ``nn.Module`` holding a single ``nn.Identity`` child named ``layer``.

    No ``forward`` is defined, so instances can be built and inspected; calling one
    falls back to ``nn.Module``'s unimplemented default ``forward``.
    """

    def __init__(self) -> None:
        super(MyModule, self).__init__()
        identity = nn.Identity()
        self.layer = identity

In [None]:
# Display the module's repr, which lists the registered `layer` child.
MyModule()