# Problem

We want to inspect internal parts of a model of interest (for example its buffers or intermediate activations) and possibly reuse them, e.g. as additional loss terms.

In [1]:
# Notebook setup: also display the value of a trailing assignment, render inline
# figures as SVG, and automatically reload edited modules (such as linodenet)
# before executing each cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
from torch import nn, jit, Tensor
from typing import NamedTuple, Union, TypeVar, Final

from linodenet.models import LinearContraction, LinODEnet

In [3]:
# Library LinearContraction (imported from linodenet.models above); presumably
# (input_size=3, output_size=4) as in the local re-implementation further down.
model = LinearContraction(3, 4)

In [4]:
# NOTE(review): the local re-implementation further down stores this
# hyperparameter in the buffer `C`, not `c`; confirm that the installed
# linodenet version exposes `c`.
model.c
# Forward a single random 3-vector; the result is discarded.
_ = model(torch.randn(3))

In [5]:
def f(_, x):
    """Return ``x`` unchanged, ignoring the first argument."""
    return x

In [6]:
# Re-inspect `c` after the forward pass in the cell above.
model.c

In [7]:
# Full LinODEnet model; presumably (input_size=7, latent_size=8) — TODO confirm
# against the linodenet signature.
model = LinODEnet(7, 8)

# Presumably (time stamps of shape (10,), observations of shape (10, 7)), with
# the trailing 7 matching the first constructor argument — TODO confirm.
init = (torch.randn(10), torch.randn(10, 7))

_ = model(*init)

In [8]:
# Drill into a nested layer of the encoder (presumably a LinearContraction)
# and read its `c`; the attribute name is assumed from the library — confirm.
model.encoder.blocks[0].bottleneck[0].c

In [10]:
# Map every buffer's qualified name to its tensor, including submodules' buffers.
dict(model.named_buffers(recurse=True))

In [5]:
import logging
from math import sqrt
from typing import Any, Final, Optional

import torch
from torch import Tensor, jit, nn
from torch.linalg import matrix_norm, vector_norm
from torch.nn import functional


class LinearContraction(nn.Module):
    r"""A linear layer `f(x) = A⋅x + b` satisfying the contraction property `‖f(x)-f(y)‖_2 ≤ ‖x-y‖_2`.

    This is achieved by normalizing the weight matrix by
    `A' = A⋅\min(\tfrac{c}{‖A‖_2}, 1)`, where `c<1` is a hyperparameter.
    The bias cancels in `f(x)-f(y)`, so it does not affect the property.

    Parameters
    ----------
    input_size: int
        The dimensionality of the input space.
    output_size: int
        The dimensionality of the output space.
    c: float, default 0.97
        The regularization hyperparameter; the effective weight has spectral
        norm at most `c`.
    bias: bool, default True
        Whether to learn an additive bias.

    Attributes
    ----------
    input_size:  int
        The dimensionality of the input space.
    output_size: int
        The dimensionality of the output space.
    C: Tensor
        Scalar buffer holding the regularization hyperparameter `c`.
    ONE: Tensor
        Scalar buffer holding `1.0`, the upper bound of the rescaling factor.
    spectral_norm: Tensor
        Scalar buffer holding `‖A‖_2`. At construction it is detached. Each
        forward pass overwrites it with a value that keeps its autograd
        history, so it can be reused in additional loss terms.
    weight: Tensor
        The weight matrix
    bias: Tensor or None
        The bias Tensor if present, else None.
    """

    input_size: Final[int]
    output_size: Final[int]

    # Buffers registered in __init__; annotated so TorchScript knows their types.
    C: Tensor
    ONE: Tensor
    spectral_norm: Tensor

    def __init__(
        self, input_size: int, output_size: int, c: float = 0.97, bias: bool = True
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Uninitialized storage; reset_parameters() fills it right below.
        self.weight = nn.Parameter(torch.empty(output_size, input_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_size))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

        self.register_buffer("ONE", torch.tensor(1.0))
        self.register_buffer("C", torch.tensor(float(c)))
        # Detach: a norm computed directly from the Parameter would be a non-leaf
        # autograd tensor stored in a buffer, which breaks copy.deepcopy and
        # pins an autograd graph for no reason.
        self.register_buffer(
            "spectral_norm", matrix_norm(self.weight.detach(), ord=2)
        )

    def reset_parameters(self) -> None:
        r"""Reset both weight matrix and bias vector."""
        nn.init.kaiming_uniform_(self.weight, a=sqrt(5))
        if self.bias is not None:
            bound = 1 / sqrt(self.input_size)
            nn.init.uniform_(self.bias, -bound, bound)

    @jit.export
    def forward(self, x: Tensor) -> Tensor:
        r"""Signature: `[..., input_size] ⟶ [..., output_size]`.

        Parameters
        ----------
        x: Tensor

        Returns
        -------
        Tensor
        """
        # Recompute ‖A‖_2 on every call and keep it (with its autograd history)
        # in the buffer, so it can be inspected or used in extra loss terms.
        self.spectral_norm = matrix_norm(self.weight, ord=2)
        # Shrink A only if its spectral norm exceeds C; otherwise use it as is.
        fac = torch.minimum(self.C / self.spectral_norm, self.ONE)
        return functional.linear(x, fac * self.weight, self.bias)

In [6]:
# Check that the locally defined LinearContraction compiles with TorchScript.
jit.script(LinearContraction(13, 17))

In [14]:
import torch
from torch import nn, jit, Tensor
from typing import NamedTuple, Union, TypeVar, Final

dtypes = TypeVar("dtypes")

# Either a single shape tuple or a list of shape tuples.
sigtype = Union[tuple[dtypes, ...], list[tuple[dtypes, ...]]]

# Entry types allowed in a shape tuple: `type[...]` (presumably meant to admit
# the Ellipsis placeholder `...`), a named dimension (str) or a fixed size (int).
_ShapeEntry = Union[type[...], str, int]


class Signature(NamedTuple):
    """Input/output shape signature, e.g. ``Signature(inputs=(..., "S"), outputs=(..., "S", 5))``."""

    inputs: sigtype[_ShapeEntry]
    outputs: sigtype[_ShapeEntry]


class M(nn.Module):
    """Demo two-layer module that caches its hidden activation in a buffer.

    ``forward`` stores the output of the first layer ``A`` in the buffer ``z``
    without detaching it, so it can be inspected after a call or reused in an
    additional loss term.

    Parameters
    ----------
    input_size: int
    hidden_size: int
    output_size: int
    """

    ZERO: Tensor
    weightx: Tensor
    z: Tensor

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.A = nn.Linear(input_size, hidden_size)
        self.B = nn.Linear(hidden_size, output_size)
        # NOTE(review): despite its name, ZERO holds an empty tensor (shape (0,)),
        # not a scalar zero; confirm which is intended.
        self.register_buffer("ZERO", torch.tensor(()))
        # Alias of A.weight: the same tensor also appears among the buffers and
        # in the state_dict under the key "weightx".
        self.register_buffer("weightx", self.A.weight)
        # Empty placeholder until the first forward pass overwrites it.
        self.register_buffer("z", torch.tensor(()))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``B(A(x))`` and cache the hidden activation ``A(x)`` in ``self.z``."""
        self.z = self.A(x)
        return self.B(self.z)


# Script the demo module; `z` still holds the empty placeholder registered in
# __init__ until forward runs.
model = jit.script(M(3, 4, 5))

model.z

In [16]:
# Buffers are part of the state dict; "weightx" duplicates the tensor of "A.weight".
model.state_dict()

In [10]:
# Example signature: a list of input shape tuples and a single output shape.
Signature(
    inputs=[(..., "S"), (..., "S", 5)],
    outputs=(..., "S", 5),
)

In [170]:
# Look up the documentation of the Ellipsis object `...`.
help(Ellipsis)

In [145]:
# NOTE(review): the attribute `a` is commented out in M above, so with the
# current definition this access would raise AttributeError.
model.a

In [108]:
# Forward a batch of 4 three-dimensional inputs; this also refreshes model.z.
y = model(torch.randn(4, 3))

In [109]:
# Hidden activation cached by the last forward pass: shape (4, 4) here.
model.z

In [110]:
y

In [118]:
# Plain SGD with learning rate 0.1 over all module parameters.
from torch.optim import SGD

optim = SGD(model.parameters(), 0.1)

In [131]:
model.zero_grad()
# The cached activation model.z is used directly as an extra loss term; this
# works because forward stores it without detaching, so gradients reach A
# both through B(z) and through z itself.
loss = torch.sum(model(torch.randn(4, 3))) + torch.sum(model.z)
loss.backward()
optim.step()
# Inspect the updated weights (the "weightx" buffer aliases this tensor).
model.A.weight

In [94]:
# After a batch of 2, model.z has shape (2, 4).
model(torch.randn(2, 3))
model.z

In [84]:
# named_buffers() also walks submodules; _buffers holds only this module's own.
dict(model.named_buffers())

In [79]:
dict(model._buffers)

In [72]:
dir(model)

In [31]:
# NOTE(review): M.__init__ requires input_size, hidden_size and output_size,
# so this call raises TypeError as written.
M()

In [2]:
import torchinfo
from linodenet.models import LinODEnet

In [3]:
# LinODEnet with encoder/decoder configs overriding "nblocks" (presumably two
# blocks each — confirm against linodenet's config schema).
model = LinODEnet(10, 10, Encoder_cfg={"nblocks": 2}, Decoder_cfg={"nblocks": 2})
# Print a per-layer summary of the model.
torchinfo.summary(model)

In [6]:
dir(model)

In [4]:
# State-dict keys mentioning "spectral_norm"; presumably the buffers of the
# contraction layers (cf. the local LinearContraction above) — confirm.
[key for key in model.state_dict().keys() if "spectral_norm" in key]

In [10]:
# IPython help for Module.buffers.
?model.buffers

In [5]:
import torch
from torch import nn

In [6]:
# Three chained linear layers: 3 -> 4 -> 5 -> 6.
blocks = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5), nn.Linear(5, 6))
blocks

In [8]:
# Slicing with step -1 yields a new Sequential with the layers in reverse order.
# NOTE(review): the reversed stack Linear(5, 6) -> Linear(4, 5) -> Linear(3, 4)
# is not shape-compatible, so it cannot be applied to data as-is.
blocks[::-1]

In [24]:
# Same expression, re-evaluated.
blocks[::-1]

In [9]:
 # NOTE(review): this line is not valid Python (stray leading space and no
 # assignment target); the target name was probably lost during export,
 # e.g. a non-ASCII identifier. Restore the intended name or delete the cell.
 =2 

In [12]:
def f(x):
    """Apply the affine map ``2 * x + 1`` to ``x``."""
    doubled = 2 * x
    return doubled + 1

In [1]:
import numpy as np
import numba
import torch

In [20]:
# Micro-benchmark, variant 1: advance the pair (x, y) with a tuple assignment.
# The exact statement form is what is being timed against h() below, so the
# body is intentionally left as is.
@torch.jit.script
# @numba.njit
def g() -> float:
    """Run 10000 tuple-assignment updates and return the final ``y`` (10001.0)."""
    x: float = 0.0
    y: float = 1.0
    # `k` is unused; only the iteration count matters. `x` is written but
    # never read, mirroring h().
    for k in range(10000):
        y, x = y + 1.0, y
    return y

In [21]:
%%timeit
# Time variant 1 (tuple assignment).
g()

In [22]:
# Micro-benchmark, variant 2: the same update as g(), written as two sequential
# assignments. The exact statement form is what is being timed, so the body
# is intentionally left as is.
@torch.jit.script
# @numba.njit
def h() -> float:
    """Run 10000 sequential-assignment updates and return the final ``y`` (10001.0)."""
    x: float = 0.0
    y: float = 1.0
    # `k` is unused; only the iteration count matters. `x` is written but
    # never read, mirroring g().
    for k in range(10000):
        x = y
        y = y + 1.0
    return y

In [23]:
%%timeit
# Time variant 2 (sequential assignments).
h()