In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [85]:
from torch.export import export, Dim

In [None]:
from collections.abc import Mapping

import torch
from torch import Tensor, jit, nn
from typing_extensions import Any, Final, Optional, Self

In [98]:
class ReZeroCell(nn.Module):
    r"""ReZero module.

    Multiplies its input -- or, if a submodule is given, the output of that
    submodule -- by a scalar that is initialized to zero unless specified.

    Args:
        module: Optional submodule applied to the input before scaling.
            If ``None``, the input itself is scaled.
        scalar: Optional initial value of the scalar; defaults to ``0.0``.
        learnable: If ``True`` the scalar is stored as an ``nn.Parameter``,
            otherwise it is registered as a buffer.
    """

    HP = {
        "__name__": __qualname__,
        "__module__": __name__,
    }
    r"""The hyperparameter dictionary"""

    # CONSTANTS
    learnable: Final[bool]
    r"""CONST: Whether the scalar is learnable."""

    # PARAMETERS
    scalar: Tensor
    r"""The scalar to multiply the inputs by."""

    def __init__(
        self,
        module: Optional[nn.Module] = None,
        *,
        scalar: Optional[Tensor] = None,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.learnable = bool(learnable)
        self.module = module
        initial_value = torch.as_tensor(0.0 if scalar is None else scalar)
        if self.learnable:
            self.scalar = nn.Parameter(initial_value)
        else:
            # A plain tensor attribute is invisible to ``state_dict()`` and is
            # not moved by ``.to(...)``; a buffer keeps the fixed scalar in
            # sync with the rest of the module.
            self.register_buffer("scalar", initial_value)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        """.. Signature:: ``(...,) -> (...,)``."""
        if self.module is None:
            return self.scalar * x
        return self.scalar * self.module(x)


In [107]:
# Wrap a 3 -> 3 linear layer in a ReZeroCell and compile it with TorchScript.
mod = ReZeroCell(nn.Linear(3,3))
m = jit.script(mod)

In [108]:
m.graph

graph(%self : __torch__.___torch_mangle_9.ReZeroCell,
      %x.1 : Tensor):
  %scalar : Tensor = prim::GetAttr[name="scalar"](%self)
  %module : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="module"](%self)
  %9 : Tensor = prim::CallMethod[name="forward"](%module, %x.1) # /tmp/ipykernel_60887/4245271241.py:40:29
  %10 : Tensor = aten::mul(%scalar, %9) # /tmp/ipykernel_60887/4245271241.py:40:15
  return (%10)

In [94]:
# Compile the module with dynamic-shape support; the following cell calls `model(...)`.
model = torch.compile(mod, dynamic=True)

In [64]:
model(torch.randn(1))

tensor([0.], grad_fn=<CompiledFunctionBackward>)

In [96]:
# Example input passed to `export`; it has two dimensions.
x = torch.randn(7,2)
args = (x,)
# Declare both dimensions of `x` as dynamic (named "batch" and "feature").
shapes = {"x" : {0: Dim("batch"), 1: Dim("feature")}}
exported_mod = export(mod, args, dynamic_shapes=shapes)

In [97]:
exported_mod(torch.randn(4))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [92]:
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    """Two-branch example network.

    Each branch is a ``Linear`` layer followed by ``ReLU`` (64 -> 32 and
    128 -> 64 features). A tensor of ones with 32 entries is added to the
    output of the first branch.
    """

    def __init__(self):
        super().__init__()
        self.branch1 = self._linear_relu(64, 32)
        self.branch2 = self._linear_relu(128, 64)
        # Stored as a plain attribute (not a registered buffer), as in the original.
        self.buffer = torch.ones(32)

    @staticmethod
    def _linear_relu(in_features, out_features):
        """Build a ``Linear -> ReLU`` stack."""
        layers = [torch.nn.Linear(in_features, out_features), torch.nn.ReLU()]
        return torch.nn.Sequential(*layers)

    def forward(self, x1, x2):
        """Return ``(branch1(x1) + buffer, branch2(x2))``."""
        hidden1, hidden2 = self.branch1(x1), self.branch2(x2)
        return (hidden1 + self.buffer, hidden2)

# Example inputs handed to `export`: one tensor per forward argument (x1, x2).
example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
# (the same `Dim` object is used for both inputs).
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

# Export a fresh instance of M with the dynamic-shape specification above.
exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[32, 64]", arg1_1: "f32[32]", arg2_1: "f32[64, 128]", arg3_1: "f32[64]", arg4_1: "f32[32]", l_x1_: "f32[s0, 64]", l_x2_: "f32[s0, 128]"):
            # File: /tmp/ipykernel_60887/258828328.py:17, code: out1 = self.branch1(x1)
            t: "f32[64, 32]" = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
            addmm: "f32[s0, 32]" = torch.ops.aten.addmm.default(arg1_1, l_x1_, t);  arg1_1 = l_x1_ = t = None
            relu: "f32[s0, 32]" = torch.ops.aten.relu.default(addmm);  addmm = None
            
            # File: /tmp/ipykernel_60887/258828328.py:18, code: out2 = self.branch2(x2)
            t_1: "f32[128, 64]" = torch.ops.aten.t.default(arg2_1);  arg2_1 = None
            addmm_1: "f32[s0, 64]" = torch.ops.aten.addmm.default(arg3_1, l_x2_, t_1);  arg3_1 = l_x2_ = t_1 = None
            relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None
            
  

In [37]:
class ReZero(nn.ModuleList):
    r"""A ReZero model.

    Applies the given blocks sequentially as residual updates,
    ``x <- x + weights[k] * block_k(x)``, with one scalar weight per block.

    Args:
        blocks: The modules to apply, in order.
        weights: Optional initial values for the per-block weights. If
            ``None``, the weights are initialized to zeros (one per block),
            so the model starts out as the identity map.
    """

    def __init__(
        self, blocks: list[nn.Module], weights: Optional[Tensor] = None
    ) -> None:
        super().__init__(blocks)

        self.weights = nn.Parameter(
            torch.zeros(len(blocks)) if weights is None else weights
        )

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        """.. Signature:: ``(...,) -> (...,)``."""
        for k, block in enumerate(self):
            x = x + self.weights[k] * block(x)
        return x

In [40]:
# Build a ReZero stack of three 3 -> 3 linear layers and compile it with TorchScript.
mod = ReZero([nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3,3)])
m = jit.script(mod)
m

RecursiveScriptModule(
  original_name=ReZero
  (0): RecursiveScriptModule(original_name=Linear)
  (1): RecursiveScriptModule(original_name=Linear)
  (2): RecursiveScriptModule(original_name=Linear)
)

In [41]:
m(torch.randn(7, 3))

tensor([[ 0.4803, -1.1254, -0.8179],
        [-0.8927, -0.9191,  0.3820],
        [-0.3159,  2.1621, -1.3353],
        [-0.3503,  0.8862,  1.0051],
        [-0.4189,  0.7528, -0.8066],
        [ 2.1228,  1.4426,  1.6509],
        [-0.5515, -1.2795,  0.0146]], grad_fn=<DifferentiableGraphBackward>)

In [42]:
m

RecursiveScriptModule(
  original_name=ReZero
  (0): RecursiveScriptModule(original_name=Linear)
  (1): RecursiveScriptModule(original_name=Linear)
  (2): RecursiveScriptModule(original_name=Linear)
)

In [45]:
m[:1].weights

Parameter containing:
tensor([0.], requires_grad=True)

In [1]:
aa
class Constant(nn.Module):
    r"""Constant function."""

    def __init__(self, value: float | Tensor) -> None:
        super().__init__()
        self.register_buffer("value", torch.as_tensor(value))

    def forward(self, _: Tensor) -> Tensor:
        return self.value


In [5]:
def foo(x: Tensor, y: Tensor) -> None:
    """Print each pair of slices along the first dimension, then their sum."""
    for left, right in zip(x, y):
        total = left + right
        print(left, right)
        print(total)

In [7]:
jit.script(foo)(torch.rand(3, 3), torch.rand(3, 3))

 0.9633
 0.2476
 0.9208
[ CPUFloatType{3} ]  0.1607
 0.1260
 0.8766
[ CPUFloatType{3} ]
 1.1239
 0.3736
 1.7974
[ CPUFloatType{3} ]
 0.0732
 0.6961
 0.3948
[ CPUFloatType{3} ]  0.1357
 0.2627
 0.4702
[ CPUFloatType{3} ]
 0.2089
 0.9588
 0.8650
[ CPUFloatType{3} ]
 0.4245
 0.0720
 0.6513
[ CPUFloatType{3} ]  0.5397
 0.8934
 0.6739
[ CPUFloatType{3} ]
 0.9642
 0.9655
 1.3251
[ CPUFloatType{3} ]


In [4]:
jit.script(Constant(1.0))()

RuntimeError: forward() is missing value for argument '_'. Declaration: forward(__torch__.Constant self, Tensor _) -> Tensor

In [None]:
from importlib import metadata
from typing import Any

import torch
import torchinfo

from torch import jit
from linodenet.models import LinODE, LinODECell, LinODEnet
from linodenet.models.filters import SequentialFilter
from linodenet.projections.functional import skew_symmetric, symmetric

# Global settings: NUM_DIM is used as the model's "input_size" in MODEL_CONFIG;
# DEVICE and DTYPE are passed to `model.to(...)` after the model is created.
NUM_DIM = 128
DEVICE = torch.device("cpu")
DTYPE = torch.float32


def join_dicts(d: dict[str, Any]) -> dict[str, Any]:
    """Flatten a nested dict, joining the keys of each level with '/'.

    Non-dict values are kept under their (joined) key; nested dicts are
    flattened recursively. Insertion order is preserved.
    """
    flat: dict[str, Any] = {}
    for name, value in d.items():
        if not isinstance(value, dict):
            flat[name] = value
            continue
        # Prefix the children with the parent key, then flatten them in turn.
        nested = {f"{name}/{inner}": entry for inner, entry in value.items()}
        flat.update(join_dicts(nested))
    return flat


def add_prefix(d: dict[str, Any], /, prefix: str) -> dict[str, Any]:
    """Return a copy of ``d`` in which every key is replaced by ``prefix/key``."""
    prefixed: dict[str, Any] = {}
    for name, value in d.items():
        prefixed[f"{prefix}/{name}"] = value
    return prefixed


# OPTIMIZER_CONFIG = {
#     "__name__": "SGD",
#     "lr": 0.001,
#     "momentum": 0,
#     "dampening": 0,
#     "weight_decay": 0,
#     "nesterov": False,
# }

# OPTIMIZER_CONFIG = {
#     "__name__": "Adam",
#     "lr": 0.01,
#     "betas": (0.9, 0.999),
#     "eps": 1e-08,
#     "weight_decay": 0,
#     "amsgrad": False,
# }


# Active optimizer settings (AdamW); included in HPARAMS under "Optimizer".
OPTIMIZER_CONFIG = {
    "__name__": "AdamW",
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 0.001,
    "amsgrad": False,
}


# Component configurations referenced by MODEL_CONFIG.
# NOTE(review): the sizes are given as the type `int`, not as numbers --
# presumably placeholders to be filled in elsewhere; confirm.
SYSTEM = {
    "__name__": "LinODECell",
    "input_size": int,
    "kernel_initialization": "skew-symmetric",
}

EMBEDDING = {
    "__name__": "ConcatEmbedding",
    "input_size": int,
    "hidden_size": int,
}
# NOTE(review): FILTER is not referenced by MODEL_CONFIG in this cell, which
# uses `SequentialFilter.HP` instead.
FILTER = {
    "__name__": "SequentialFilter",
    "input_size": int,
    "hidden_size": int,
    "autoregressive": True,
}

# FILTER = {
#     "__name__": "RecurrentCellFilter",
#     "concat": True,
#     "input_size": int,
#     "hidden_size": int,
#     "autoregressive": True,
#     "Cell": {
#         "__name__": "GRUCell",
#         "input_size": int,
#         "hidden_size": int,
#         "bias": True,
#         "device": None,
#         "dtype": None,
#     },
# }
from linodenet.models.encoders import ResNet, iResNet

# ENCODER = {"__name__": "ResNet", "__module__": "linodenet.models.encoders","input_size": int, "nblocks": 5, "rezero": True}
# DECODER = {"__name__": "ResNet", "__module__": "linodenet.models.encoders","input_size": int, "nblocks": 5, "rezero": True}


# Learning-rate scheduler settings (ReduceLROnPlateau); included in HPARAMS
# under "LR_Scheduler". The comment below each entry describes that option.
LR_SCHEDULER_CONFIG = {
    "__name__": "ReduceLROnPlateau",
    "mode": "min",
    # (str) - One of `min`, `max`. In `min` mode, lr will be reduced when the quantity monitored has stopped decreasing; in `max` mode it will be reduced when the quantity monitored has stopped increasing. Default: 'min'.
    "factor": 0.1,
    # (float) - Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
    "patience": 10,
    # (int) - Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn't improved then. Default: 10.
    "threshold": 0.0001,
    # (float) - Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
    "threshold_mode": "rel",
    # (str) - One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * ( 1 + threshold ) in `max` mode or best * ( 1 - threshold ) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'.
    "cooldown": 0,
    # (int) - Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
    "min_lr": 1e-08,
    # (float or list) - A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
    "eps": 1e-08,
    # (float) - Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
    "verbose": True,
    # (bool) - If True, prints a message to stdout for each update. Default: False.
}

# Model settings; unpacked as keyword arguments when the model is instantiated.
# Note: "Encoder" and "Decoder" both reference the same `ResNet.HP` object.
MODEL_CONFIG = {
    "__name__": "LinODEnet",
    "input_size": NUM_DIM,
    "hidden_size": 128,
    "embedding_type": "concat",
    "Filter": SequentialFilter.HP,
    "System": SYSTEM,
    "Encoder": ResNet.HP,
    "Decoder": ResNet.HP,
    "Embedding": EMBEDDING,
}


# Flat view of all hyperparameters: nested keys are joined with "/"
# (e.g. "Optimizer/lr"), see `join_dicts`.
HPARAMS = join_dicts({
    "Optimizer": OPTIMIZER_CONFIG,
    "LR_Scheduler": LR_SCHEDULER_CONFIG,
    "Model": MODEL_CONFIG,
})

In [None]:
# Instantiate the model from MODEL_CONFIG, move it to the configured
# device/dtype, and display a torchinfo summary.
MODEL = LinODEnet
model = MODEL(**MODEL_CONFIG)
model.to(device=DEVICE, dtype=DTYPE)
torchinfo.summary(model)

In [None]:
# Round-trip the model through TorchScript serialization.
# `jit.save` expects a ScriptModule, so script the model first; `jit.script`
# returns an already-scripted module unchanged.
scripted_model = jit.script(model)
jit.save(scripted_model, "model.pt")
model = jit.load("model.pt")
torchinfo.summary(model)