# Auto-JIT class decorators for `torch.nn.Module`

In [1]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
from torch import jit, Tensor, nn
from torch.nn import Module

In [3]:
from linodenet.models import LinearContraction

# Instantiate a LinearContraction with positional args (3, 4).
# NOTE(review): presumably (input_size, output_size) — confirm against linodenet.models.
# `model` is not referenced by any of the other cells shown here.
model = LinearContraction(3, 4)

In [9]:
# Inspect the class that torch.jit.script returns for scripted nn.Module instances.
jit.RecursiveScriptModule

In [4]:
# Scripted modules remain nn.Module subclasses; relevant because the autojit_*
# helpers below hand back scripted instances from a Module-typed class.
issubclass(jit.RecursiveScriptModule, Module)

In [2]:
import numpy as np
import matplotlib.pyplot as plt

# Compact, fixed-point float display for NumPy arrays (4 decimals, no scientific notation).
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)

# Seed the generator so any random draws are reproducible on Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

In [8]:
import torch
from torch import jit, Tensor, nn
from torch.nn import Module

In [15]:
r"""Module Docstring."""


from functools import wraps


def wrapfunc(other):
    """Build a decorator that post-processes a function's return value.

    ``wrapfunc(g)(f)`` is a function equivalent to
    ``lambda *args, **kwargs: g(f(*args, **kwargs))`` that keeps ``f``'s
    metadata (``__name__``, ``__doc__``, ...) via ``functools.wraps``.

    Args:
        other: Callable applied to whatever the decorated function returns.

    Returns:
        A decorator taking a function and returning the composed wrapper.
    """

    def decorator(inner):
        @wraps(inner)
        def composed(*args, **kwargs):
            result = inner(*args, **kwargs)
            return other(result)

        return composed

    return decorator


def autojit_a(basecls: type[Module]) -> type[Module]:
    """Class decorator: calling the decorated class returns a scripted module.

    Variant A: the returned class subclasses ``basecls``. Its ``__new__``
    builds an ordinary ``basecls`` instance from the caller's constructor
    arguments and returns ``jit.script`` of it. The scripted module is not an
    instance of the returned class, so Python does not run ``__init__`` again.

    Args:
        basecls: An ``nn.Module`` subclass that ``torch.jit.script`` can compile.

    Returns:
        A class carrying ``basecls``'s name/docstring whose instantiation
        yields a ``torch.jit.RecursiveScriptModule``.

    Raises:
        AssertionError: If ``basecls`` is not a subclass of ``Module``.
    """
    assert issubclass(basecls, Module)

    # updated=(): a class __dict__ is a read-only mappingproxy and cannot be merged.
    @wraps(basecls, updated=())
    class WrappedClass(basecls):
        def __new__(cls, *args, **kwargs):
            # Forward the caller's constructor arguments to the real class.
            instance = basecls(*args, **kwargs)
            return jit.script(instance)

    return WrappedClass


def autojit_b(basecls: type[Module]) -> type[Module]:
    """Class decorator: calling the decorated class returns a scripted module.

    Variant B: the returned class subclasses plain ``Module`` (not
    ``basecls``) but carries ``basecls``'s metadata (name, qualname,
    docstring, ``__wrapped__``). Its ``__new__`` builds a ``basecls`` instance
    from the caller's constructor arguments and returns ``jit.script`` of it.
    The scripted module is not an instance of the returned class, so Python
    does not run ``__init__`` again.

    Args:
        basecls: An ``nn.Module`` subclass that ``torch.jit.script`` can compile.

    Returns:
        A ``Module`` subclass whose instantiation yields a
        ``torch.jit.RecursiveScriptModule``.

    Raises:
        AssertionError: If ``basecls`` is not a subclass of ``Module``.
    """
    assert issubclass(basecls, Module)

    # Copy basecls's metadata (previously Module's was copied by mistake);
    # updated=() because a class __dict__ is a read-only mappingproxy.
    @wraps(basecls, updated=())
    class WrappedClass(Module):
        def __new__(cls, *args, **kwargs):
            # Forward the caller's constructor arguments to the real class.
            instance = basecls(*args, **kwargs)
            return jit.script(instance)

    return WrappedClass


def autojit_c(basecls: type[Module]) -> type[Module]:
    """Class decorator: patch ``basecls.__new__`` in place so instances come back scripted.

    Variant C: no new class is created; ``basecls`` itself is modified and
    returned. The patched ``__new__`` allocates the object with the original
    ``__new__``, runs ``__init__`` explicitly (``jit.script`` rejects an
    uninitialized module), and returns ``jit.script`` of the result. The
    scripted module is not an instance of ``basecls``, so Python does not run
    ``__init__`` a second time.

    Note: this mutates ``basecls``; subclasses inherit the patched ``__new__``.

    Args:
        basecls: An ``nn.Module`` subclass that ``torch.jit.script`` can compile.

    Returns:
        ``basecls`` itself, now yielding ``torch.jit.RecursiveScriptModule``
        instances when called.

    Raises:
        AssertionError: If ``basecls`` is not a subclass of ``Module``.
    """
    assert issubclass(basecls, Module)

    original_new = basecls.__new__

    def __new__(cls, *args, **kwargs):
        # object.__new__ rejects extra arguments once a class overrides __new__,
        # so constructor arguments are only forwarded to a custom __new__.
        if original_new is object.__new__:
            instance = original_new(cls)
        else:
            instance = original_new(cls, *args, **kwargs)
        # Mirror type.__call__: only initialize objects of the requested type.
        if isinstance(instance, cls):
            instance.__init__(*args, **kwargs)
        return jit.script(instance)

    # __new__ is implicitly static only inside a class body; make it explicit here.
    basecls.__new__ = staticmethod(__new__)

    return basecls

In [12]:
class MyModule(Module):
    """Toy module that scales its input by a stored scalar tensor ``a``."""

    # Declared attribute type (TorchScript reads class-level annotations).
    a: Tensor

    def __init__(self, a: float = 1.0):
        Module.__init__(self)
        self.a = torch.tensor(a)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` multiplied elementwise by ``self.a``."""
        return torch.mul(self.a, x)


MyModule()

In [17]:
@autojit_a
class MyModule(Module):
    """Same toy scaling module, decorated with ``autojit_a`` (subclass-based wrapper)."""

    # Declared attribute type (TorchScript reads class-level annotations).
    a: Tensor

    def __init__(self, a: float = 1.0):
        Module.__init__(self)
        self.a = torch.tensor(a)

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` elementwise by ``self.a``."""
        return torch.mul(self.a, x)


MyModule()

In [18]:
@autojit_b
class MyModule(Module):
    """Same toy scaling module, decorated with ``autojit_b`` (plain-Module wrapper)."""

    # Declared attribute type (TorchScript reads class-level annotations).
    a: Tensor

    def __init__(self, a: float = 1.0):
        Module.__init__(self)
        self.a = torch.tensor(a)

    def forward(self, x: Tensor) -> Tensor:
        """Elementwise product of the stored factor and ``x``."""
        return torch.mul(self.a, x)


MyModule()

In [19]:
@autojit_c
class MyModule(Module):
    """Same toy scaling module, decorated with ``autojit_c`` (patches ``__new__`` in place)."""

    # Declared attribute type (TorchScript reads class-level annotations).
    a: Tensor

    def __init__(self, a: float = 1.0):
        Module.__init__(self)
        self.a = torch.tensor(a)

    def forward(self, x: Tensor) -> Tensor:
        """Multiply ``x`` by the scalar tensor ``self.a``."""
        return torch.mul(self.a, x)


MyModule()