In [None]:
# Ask IPython to echo the value of a cell's final expression *or* assignment.
# NOTE(review): behaviour inferred from the option value's name; confirm against the IPython docs.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
# Load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Sequence
from math import prod, sqrt
from typing import Optional, Protocol, Union, runtime_checkable

import torch
import torch.linalg
from numpy.typing import NDArray
from scipy import stats
from torch import BoolTensor, Tensor, jit, nn
from torch.optim import SGD
import linodenet
from linodenet.constants import TRUE
from linodenet.parametrize import *
from linodenet.projections import functional as projections
from linodenet.types import Device, Dtype, Shape
from linodenet.testing import check_jit_serialization

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

In [None]:
# Random operands: a 5×5 matrix and a length-5 vector (drawn in this order).
U = torch.randn(size=(5, 5))
x = torch.randn(size=(5,))

In [None]:
# Equation "ij, k -> ik": index j does not appear in the output, so it is summed
# over; the result has entries out[i, k] = (sum_j U[i, j]) * x[k].
# NOTE(review): if a matrix-vector product was intended, the equation would be
# "ij, j -> i" — confirm the intent.
torch.einsum("ij, k -> ik", U, x)

In [None]:
from collections.abc import Mapping, Sized


class Foo(Sized):
    """Minimal concrete ``Sized`` subclass; its instances are hashable."""

    def __iter__(self): ...

    def __len__(self): ...


class Bar(Foo, Mapping):
    """``Foo`` mixed with ``Mapping``; its instances are *not* hashable.

    ``Mapping`` defines ``__eq__`` without ``__hash__``, which sets
    ``__hash__`` to ``None``; ``Bar`` picks that up through its MRO.
    """

    def __getitem__(self, key): ...


hash(Foo())  # ✔
try:
    hash(Bar())
except TypeError as err:
    # The failure is the point of the demonstration: show it without aborting
    # a "Restart & Run All" of the notebook.
    print(f"{type(err).__name__}: {err}")  # ✘ TypeError: unhashable type: 'Bar'

In [None]:
# Do all 7 random 5×5 matrices have rank at most 6?
# (A 5×5 matrix has rank at most 5, so this is expected to be True.)
torch.linalg.matrix_rank(torch.randn(7, 5, 5)).le(6).all().item()

In [None]:
# Inspect which __eq__ implementation Bar resolves to.
# NOTE(review): presumably to trace why hash(Bar()) fails — an inherited
# __eq__ without a matching __hash__ disables hashing.
Bar.__eq__

In [None]:
# Show Bar's method resolution order, i.e. the order in which its base
# classes are searched for attributes such as __eq__ and __hash__.
Bar.mro()

In [None]:
class RankOne(nn.Module):
    """Parametrization that represents a matrix as an outer product of two vectors.

    ``forward`` builds the rank-one matrix from the two vectors and
    ``right_inverse`` maps a matrix back to such a pair of vectors.
    Both methods accept optional leading batch dimensions.
    """

    def forward(self, x, y):
        """Return the outer product of ``x`` (..., m) and ``y`` (..., n), shape (..., m, n)."""
        # Form a rank 1 matrix multiplying two vectors
        return x.unsqueeze(-1) @ y.unsqueeze(-2)

    def right_inverse(self, Z):
        """Project ``Z`` (..., m, n) onto the rank-one matrices.

        Returns the leading left and right singular vectors, each scaled by the
        square root of the leading singular value, so that
        ``forward(*right_inverse(Z))`` is the rank-one truncation of ``Z``.
        """
        # Project Z onto the rank 1 matrices
        U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        # Return rescaled singular vectors.
        # Index the LAST axis of S so every matrix in a batch uses its own
        # leading singular value; `S[0]` would instead select the first batch
        # entry. For a single 2-D matrix both forms coincide.
        s0_sqrt = S[..., 0].sqrt().unsqueeze(-1)
        return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt


# Plain linear layer: report its hash and its parameters before parametrizing.
model = nn.Linear(4, 4)
print(hash(model), dict(model.named_parameters()), sep="\n")

# Register the rank-one parametrization on "weight" and report the hash of
# the module that the registration call returns.
linear_rank_one = P.register_parametrization(model, "weight", RankOne())
print(hash(linear_rank_one))

# Numerical rank of the parametrized weight.
print(int(torch.linalg.matrix_rank(linear_rank_one.weight)))

In [None]:
# Name -> parameter mapping of the parametrized layer.
{name: param for name, param in linear_rank_one.named_parameters()}

In [None]:
# Problem sizes: batch size, input features, output features.
B = 7
N = 3
M = 5
# Random regression data (inputs are drawn before targets).
inputs = torch.randn((B, N))
targets = torch.randn((B, M))
# Bias-free linear map from N to M features.
model = nn.Linear(N, M, bias=False)

In [None]:
# Hand-rolled registration of an UpperTriangular parametrization on model.weight.
# NOTE(review): presumably a manual stand-in for the commented-out call below — confirm.
# register_parametrization(model, "weight", UpperTriangular)
param = UpperTriangular(model.weight)  # build the parametrization object from the current weight
delattr(model, "weight")  # remove the existing attribute before re-registering the name "weight"
model.register_buffer("weight", param.cached_parameter)  # "weight" is now registered as a buffer
model.register_module("weight_parametrization", param)  # the parametrization becomes a submodule
model.register_parameter("weight_original", param.original_parameter)  # parameter registered under a new name
# List the parameters the module exposes after the re-registration.
dict(model.named_parameters())

In [None]:
# Compile the module with TorchScript, then list the parameters of the result.
scripted = torch.jit.script(model)
{name: param for name, param in scripted.named_parameters()}

In [None]:
# Pass the scripted module through the serialization check and keep its result.
# NOTE(review): presumably a save/load round trip — confirm in linodenet.testing.
loaded = check_jit_serialization(scripted)
# Refresh the parametrization on the resulting module.
loaded.weight_parametrization.update_parametrization()
# Plain SGD over all parameters of the resulting module.
optim = SGD(params=loaded.parameters(), lr=0.1)
{name: param for name, param in loaded.named_parameters()}

In [None]:
# Baseline: norm of the residual before any optimizer step, computed without
# autograd tracking.
with torch.no_grad():
    original_loss = (loaded(inputs) - targets).norm()
    print(original_loss)

# Display the weight, the original parameter and the parametrization module together.
loaded.weight, loaded.weight_original, loaded.weight_parametrization

In [None]:
# One optimization step.
loaded.zero_grad(set_to_none=True)  # clear gradients left over from a previous step
loss = (loaded(inputs) - targets).norm()  # evaluated BEFORE this step's parameter update
print(loss)
loss.backward()
optim.step()
# Refresh the parametrization now that the optimizer has changed the parameters.
# NOTE(review): presumably recomputes `weight` from `weight_original` — confirm in linodenet.
loaded.weight_parametrization.update_parametrization()

In [None]:
# Name -> parameter mapping after the optimizer step.
{name: param for name, param in loaded.named_parameters()}

In [None]:
# `loss` from the training cell is evaluated before optim.step(), so after a
# single top-to-bottom run it equals `original_loss` and cannot show any
# improvement. Re-evaluate the loss with the updated parameters instead.
with torch.no_grad():
    updated_loss = (loaded(inputs) - targets).norm()
updated_loss < original_loss