In [None]:
# Notebook setup.
# 'last_expr_or_assign' also displays the value of an assignment when it is the
# last statement of a cell (not only bare expressions).
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically re-import edited modules before each cell runs (mode 2 = all modules).
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from importlib import metadata
from typing import Any

import torch
import torchinfo

from torch import jit
from linodenet.models import LinODE, LinODECell, LinODEnet
from linodenet.models.filters import SequentialFilter
from linodenet.projections.functional import skew_symmetric, symmetric

# Input feature dimension; used as the model's "input_size" in MODEL_CONFIG.
NUM_DIM: int = 128
# Device and floating-point precision the model is moved to via model.to(...).
DEVICE: torch.device = torch.device("cpu")
DTYPE: torch.dtype = torch.float32


def join_dicts(d: dict[str, Any], *, sep: str = "/") -> dict[str, Any]:
    """Recursively flatten a nested dict by joining key paths with ``sep``.

    With the default separator, ``{"a": {"b": 1, "c": {"d": 2}}, "e": 3}``
    becomes ``{"a/b": 1, "a/c/d": 2, "e": 3}``.

    Args:
        d: Dictionary to flatten; every value that is itself a ``dict`` is
            expanded recursively.
        sep: String inserted between a parent key and its child keys.
            Defaults to ``"/"``.

    Returns:
        A new flat dictionary; ``d`` is left unmodified. Keys follow the
        depth-first insertion order of ``d``. An empty nested dict contributes
        no entries, and if two paths produce the same joined key, the one
        visited later wins.
    """
    result: dict[str, Any] = {}

    def _collect(key: Any, val: Any) -> None:
        # Descend into nested dicts, extending the key path; store leaves as-is.
        if isinstance(val, dict):
            for subkey, item in val.items():
                _collect(f"{key}{sep}{subkey}", item)
        else:
            result[key] = val

    for key, val in d.items():
        _collect(key, val)
    return result


def add_prefix(d: dict[str, Any], /, prefix: str) -> dict[str, Any]:
    """Namespace every key of ``d`` under ``prefix``.

    Each key ``k`` becomes ``f"{prefix}/{k}"``. Values are reused unchanged,
    key order is preserved, and the input mapping is not modified.
    """
    prefixed: dict[str, Any] = {}
    for name, value in d.items():
        prefixed[f"{prefix}/{name}"] = value
    return prefixed


# OPTIMIZER_CONFIG = {
#     "__name__": "SGD",
#     "lr": 0.001,
#     "momentum": 0,
#     "dampening": 0,
#     "weight_decay": 0,
#     "nesterov": False,
# }

# OPTIMIZER_CONFIG = {
#     "__name__": "Adam",
#     "lr": 0.01,
#     "betas": (0.9, 0.999),
#     "eps": 1e-08,
#     "weight_decay": 0,
#     "amsgrad": False,
# }


# Optimizer hyperparameters. "__name__" names the optimizer class; the other
# keys mirror the keyword arguments of torch.optim.AdamW.
# Note: weight_decay=0.001 is deliberately below AdamW's PyTorch default (0.01).
OPTIMIZER_CONFIG: dict[str, Any] = {
    "__name__": "AdamW",
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 0.001,
    "amsgrad": False,
}


# Sub-module configs. "__name__" names the component class. Entries whose value
# is the type `int` are placeholders, presumably filled in with concrete sizes
# when the model is built — TODO confirm against linodenet.
SYSTEM: dict[str, Any] = {
    "__name__": "LinODECell",
    "input_size": int,
    "kernel_initialization": "skew-symmetric",
}

EMBEDDING: dict[str, Any] = {
    "__name__": "ConcatEmbedding",
    "input_size": int,
    "hidden_size": int,
}
# NOTE(review): FILTER is not referenced by MODEL_CONFIG, which uses
# SequentialFilter.HP instead; keep the two in sync or drop one.
FILTER: dict[str, Any] = {
    "__name__": "SequentialFilter",
    "input_size": int,
    "hidden_size": int,
    "autoregressive": True,
}

# FILTER = {
#     "__name__": "RecurrentCellFilter",
#     "concat": True,
#     "input_size": int,
#     "hidden_size": int,
#     "autoregressive": True,
#     "Cell": {
#         "__name__": "GRUCell",
#         "input_size": int,
#         "hidden_size": int,
#         "bias": True,
#         "device": None,
#         "dtype": None,
#     },
# }
from linodenet.models.encoders import ResNet, iResNet

# ENCODER = {"__name__": "ResNet", "__module__": "linodenet.models.encoders","input_size": int, "nblocks": 5, "rezero": True}
# DECODER = {"__name__": "ResNet", "__module__": "linodenet.models.encoders","input_size": int, "nblocks": 5, "rezero": True}


# Scheduler hyperparameters. "__name__" names the scheduler class; the other keys
# mirror torch.optim.lr_scheduler.ReduceLROnPlateau's arguments. Notes above each
# key paraphrase the PyTorch documentation.
LR_SCHEDULER_CONFIG: dict[str, Any] = {
    "__name__": "ReduceLROnPlateau",
    # "min": reduce the LR once the monitored quantity stops decreasing
    # ("max": once it stops increasing). PyTorch default: "min".
    "mode": "min",
    # Multiplicative decay: new_lr = lr * factor. PyTorch default: 0.1.
    "factor": 0.1,
    # Epochs without improvement tolerated before reducing the LR; e.g. with
    # patience=2 the first 2 bad epochs are ignored and the LR drops after the
    # 3rd if the loss still has not improved. PyTorch default: 10.
    "patience": 10,
    # Minimum change that counts as a new optimum. PyTorch default: 1e-4.
    "threshold": 1e-4,
    # "rel": dynamic_threshold = best * (1 - threshold) in min mode,
    #        best * (1 + threshold) in max mode;
    # "abs": dynamic_threshold = best - threshold in min mode,
    #        best + threshold in max mode. PyTorch default: "rel".
    "threshold_mode": "rel",
    # Epochs to wait after a reduction before resuming normal operation.
    # PyTorch default: 0.
    "cooldown": 0,
    # Lower bound on the LR (a scalar, or a list with one value per param
    # group). PyTorch default: 0.
    "min_lr": 1e-08,
    # LR updates smaller than eps are ignored. PyTorch default: 1e-8.
    "eps": 1e-08,
    # Print a message to stdout on every LR update. PyTorch default: False.
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases (2.2+)
    # and emits a warning there — confirm against the installed version.
    "verbose": True,
}

# Keyword arguments for LinODEnet; the model is built as MODEL(**MODEL_CONFIG).
# "Filter", "Encoder" and "Decoder" use the classes' `.HP` attribute dicts
# (presumably the library defaults); "System" and "Embedding" use the local
# SYSTEM / EMBEDDING dicts.
# NOTE(review): "Encoder" and "Decoder" both refer to the same dict object
# (the class attribute ResNet.HP). If anything mutates one of these configs in
# place, the other entry and the class-level default change with it; use
# copy.deepcopy(ResNet.HP) per entry if that turns out to happen.
MODEL_CONFIG: dict[str, Any] = {
    "__name__": "LinODEnet",
    "input_size": NUM_DIM,
    "hidden_size": 128,
    "embedding_type": "concat",
    "Filter": SequentialFilter.HP,
    "System": SYSTEM,
    "Encoder": ResNet.HP,
    "Decoder": ResNet.HP,
    "Embedding": EMBEDDING,
}


# Single flat view of all configs above, keyed "<Section>/<key>"
# (e.g. "Optimizer/lr", "Model/Encoder/..."); nested dicts are expanded.
HPARAMS = join_dicts(
    dict(
        Optimizer=OPTIMIZER_CONFIG,
        LR_Scheduler=LR_SCHEDULER_CONFIG,
        Model=MODEL_CONFIG,
    )
)

In [None]:
MODEL = LinODEnet
# Build the model from its config and move it to the chosen device/precision in
# one step (nn.Module.to returns the module itself).
model = MODEL(**MODEL_CONFIG).to(device=DEVICE, dtype=DTYPE)
torchinfo.summary(model)

In [None]:
# Round-trip the model through TorchScript serialisation. jit.save only accepts
# a ScriptModule, so the eager model is compiled with jit.script first (a no-op
# if it is already scripted); the reloaded module replaces `model`.
jit.save(jit.script(model), "model.pt")
model = jit.load("model.pt")
torchinfo.summary(model)