In [None]:
# Display the value of a cell's final node when it is an expression or an assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Ask the inline backend to render figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2
# Use matplotlib's inline backend so figures appear in the cell output.
%matplotlib inline

In [None]:
import torch
import torch.nn.utils.parametrize as parametrize
from torch import Tensor, jit, nn


def symmetric(X):
    """Build a symmetric matrix from the upper triangle of ``X``.

    The upper triangle (diagonal included) is kept as is, and the strictly
    upper part is mirrored across the diagonal into the lower triangle.
    Entries of ``X`` below the diagonal do not contribute to the result.
    The last two dimensions of ``X`` are treated as the matrix dimensions.
    """
    upper_with_diag = X.triu(diagonal=0)
    mirrored_lower = X.triu(diagonal=1).transpose(-2, -1)
    return upper_with_diag + mirrored_lower


# Sanity-check the helper on a random 3x3 matrix.
X = torch.rand(3, 3)
A = symmetric(X)
# A symmetric matrix is equal to its own transpose.
assert torch.allclose(A, A.transpose(0, 1))
print(A)

In [None]:
class LinearSymmetric(nn.Module):
    """Bias-free linear layer whose effective weight matrix is symmetric.

    A full square matrix is stored as the parameter and ``symmetric`` is
    applied to it on every forward pass, so the matrix used in the product
    is always symmetric.
    """

    def __init__(self, n_features):
        """Create an ``n_features x n_features`` weight drawn from ``torch.rand``."""
        super().__init__()
        shape = (n_features, n_features)
        self.weight = nn.Parameter(torch.rand(shape))

    def forward(self, x):
        """Return ``x`` right-multiplied by the symmetrised weight."""
        sym_weight = symmetric(self.weight)
        return torch.matmul(x, sym_weight)

In [None]:
# Instantiate the hand-written symmetric layer with 3 features and push a
# random batch of 8 rows through it.
layer = LinearSymmetric(3)
out = layer(torch.rand(8, 3))

In [None]:
class Symmetric(nn.Module):
    """Map a tensor to a symmetric one, using only its upper triangle.

    Written as a module so it can be passed to
    ``parametrize.register_parametrization``.
    """

    def forward(self, X):
        """Keep the upper triangle of ``X`` and mirror it below the diagonal."""
        kept = X.triu(diagonal=0)
        above_diagonal = X.triu(diagonal=1)
        return kept + above_diagonal.transpose(-2, -1)

In [None]:
# Rebind `layer` to a stock 3x3 linear layer ...
layer = nn.Linear(3, 3)
# ... and register Symmetric as a parametrization of its "weight" attribute.
parametrize.register_parametrization(layer, "weight", Symmetric())

In [None]:
class Skew(nn.Module):
    """Map a tensor to a skew-symmetric one over its last two dimensions.

    Only the entries strictly above the diagonal of the input are used; the
    diagonal of the result is zero.
    """

    def forward(self, X):
        """Return ``U - U^T`` where ``U`` is the strictly upper triangle of ``X``."""
        strict_upper = X.triu(diagonal=1)
        return strict_upper - strict_upper.transpose(-2, -1)


# Skew only touches the last two dimensions (triu / transpose(-1, -2)), so it
# can be registered on a convolution weight as well: each 3x3 kernel slice is
# parametrized independently.
cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())
# Print a few kernels
print(cnn.weight[0, 1])
print(cnn.weight[2, 2])

In [None]:
class Rezero(nn.Module):
    """Scale a tensor by a single learnable scalar ``alpha`` that starts at zero.

    Because ``alpha`` is initialised to 0, the module initially maps every
    input to zeros; the scale is then learned during training.
    """

    def __init__(self) -> None:
        """Create the scalar ``alpha`` parameter, initialised to zero."""
        super().__init__()
        # Use a float literal so the scalar is created in torch's default
        # dtype. `dtype=float` would select float64, leaving a float64
        # parameter inside an otherwise default-dtype module.
        # nn.Parameter already defaults to requires_grad=True.
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``alpha * x``.

        Args:
            x: Tensor to be rescaled.

        Returns:
            ``x`` multiplied elementwise by the scalar ``alpha``.
        """
        return self.alpha * x

In [None]:
from torchinfo import summary
from torch.optim import SGD

In [None]:
# Fresh 3x3 linear layer used for the Rezero experiment.
model = nn.Linear(3, 3)


# Route `model.weight` through Rezero, then show the torchinfo summary of the
# parametrized layer.
parametrize.register_parametrization(model, "weight", Rezero())
summary(model)

In [None]:
# Try compiling the parametrized model with TorchScript; the result is only
# displayed, not stored.
# NOTE(review): scripting modules that carry parametrizations may be
# unsupported by PyTorch — confirm this cell runs on the installed version.
jit.script(model)

In [None]:
# Plain SGD over all parameters exposed by `model`, with learning rate 0.01.
optim = SGD(model.parameters(), lr=0.01)

In [None]:
# Random input batch: 10 rows of 3 features drawn from a standard normal.
x = torch.randn(10, 3)

In [None]:
# One optimisation step. Clear any gradients left from a previous run first.
model.zero_grad()
# The layer is applied twice in a row; the loss is the mean squared output.
loss = torch.mean(model(model(x)) ** 2)
# Backpropagate and let SGD update the parameters.
loss.backward()
optim.step()

In [None]:
# Show the `alpha` attribute of the first parametrization registered on
# "weight", as it stands after the optimisation step.
model.parametrizations.weight[0].alpha

In [None]:
# Show the weight as exposed by the parametrized layer.
model.weight

In [None]:
# Rebind `model` to its TorchScript-compiled version.
# NOTE(review): scripting modules that carry parametrizations may be
# unsupported by PyTorch — confirm this cell runs on the installed version.
model = jit.script(model)

In [None]:
# Experiment: script a bare ModuleList of two 3x3 linear layers ...
module = nn.ModuleList([nn.Linear(3, 3), nn.Linear(3, 3)])

scripted = jit.script(module)
# ... and try adding the scripted result to itself.
# NOTE(review): unclear whether `+` is defined for the scripted object —
# this line may raise; confirm.
scripted + scripted

In [None]:
# Build the list in this cell and bind it, so the slice below does not depend
# on a `module` left over from an earlier cell. Previously the freshly built
# ModuleList was discarded.
module = nn.ModuleList([nn.Linear(3, 3), nn.Linear(3, 3)])
# Slicing a ModuleList: everything from the second entry onwards.
module[1:]

In [None]:
# Stack three 3x3 linear layers in a Sequential and keep all but the first.
# Appending a second nn.Sequential with `+` was tried here and left disabled.
module = nn.Sequential(*(nn.Linear(3, 3) for _ in range(3)))[1:]

In [None]:
from typing import Final


class foo(nn.Module):
    """Minimal module carrying one class-level attribute annotated ``Final``.

    The notebook compares the type of ``a`` on the plain class with its type
    on a ``jit.script``-compiled instance.
    """

    # Annotated as Final; the value is a one-element list of strings.
    a: Final[list[str]] = ["a"]

In [None]:
# Type of `a` as exposed by a TorchScript-compiled instance of `foo`.
type(jit.script(foo()).a)

In [None]:
# Type of `a` when read directly from the plain Python class.
type(foo.a)