In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Sequence
from math import prod, sqrt
from typing import Optional, Protocol, Union, runtime_checkable

import torch
import torch.linalg
from numpy.typing import NDArray
from scipy import stats
from torch import BoolTensor, Tensor, jit

# import linodenet
# from linodenet.constants import TRUE
# from linodenet.projections import functional as projections
# from linodenet.types import Device, Dtype, Shape

In [None]:
from typing import Union

import torch
from torch import Tensor, jit


def wrapped_matrix_norm(x: Tensor, p: Union[int, str] = "fro") -> Tensor:
    r"""Compute the matrix norm of ``x`` via :func:`torch.linalg.matrix_norm`.

    Thin wrapper whose ``ord`` argument is typed ``Union[int, str]`` so that its
    TorchScript-compatibility can be probed with ``jit.script``.

    Args:
        x: Tensor whose last two dimensions are treated as the matrix.
        p: Norm order forwarded as ``ord`` (e.g. ``"fro"``, ``"nuc"``, ``1``, ``2``).

    Returns:
        The matrix norm(s) of ``x``, one per leading batch index.
    """
    # The original referenced an undefined name `r` instead of the argument `x`.
    return torch.linalg.matrix_norm(x, ord=p)


# Try to script the wrapper. NOTE(review): TorchScript may reject passing a
# `Union[int, str]` value to the overloaded `ord` argument — this is an experiment.
jit.script(wrapped_matrix_norm)

In [None]:
# NOTE(review): `matrix_norm` is only imported further down
# (`from torch.linalg import matrix_norm`), so on a fresh kernel this raises NameError.
jit.script(matrix_norm)

In [None]:
# Probe whether `ord=None` is accepted by `torch.linalg.matrix_norm`.
torch.linalg.matrix_norm(torch.randn(3, 3), ord=None)

In [None]:
# NOTE(review): `linodenet` is never imported (its import at the top is commented
# out), so this and the next cells raise NameError on a fresh kernel.
func = linodenet.initializations.canonical_skew_symmetric

In [None]:
# Try wrapping the library initializer with `torch.compile`.
torch.compile(func)

In [None]:
# NOTE(review): no example inputs are passed to `torch.export.export` here;
# check its signature — it expects example arguments for tracing.
torch.export.export(linodenet.initializations.canonical_skew_symmetric)

In [None]:
def canonical_skew_symmetric(
    n: Union[int, Sequence[int]],
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Return the canonical skew symmetric matrix of size $n=2k$.

    .. math:: 𝕁_n = 𝕀_k ⊗ \begin{bmatrix}0 & +1 \\ -1 & 0\end{bmatrix}

    i.e. a block-diagonal matrix made of $k$ copies of the $2×2$ block.
    Every row holds exactly one non-zero entry $±1$ (in distinct columns), so it is
    normalized such that if $x∼𝓝(0,𝕀)$, then $𝕁_n⋅x∼𝓝(0,𝕀)$.

    Args:
        n: Either the (even) matrix dimension, or a shape ``(*batch, n)``; in the
           latter case the matrix is repeated over the leading batch dimensions.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor (default: torch's default float dtype).

    Returns:
        Tensor of shape ``(*batch, n, n)``.

    Raises:
        AssertionError: If the last dimension is odd.
    """
    shape = (n,) if isinstance(n, int) else tuple(n)
    *batch, dim = shape
    assert dim % 2 == 0, "The dimension must be divisible by 2!"
    k = dim // 2

    eye = torch.eye(k, device=device, dtype=dtype)
    # Build the 2x2 block in the same dtype as `eye` to avoid mixed-dtype kron.
    J1 = torch.tensor([[0, 1], [-1, 0]], device=device, dtype=eye.dtype)
    # 𝕀_k ⊗ J1 (block diagonal), as documented above.
    J = torch.kron(eye, J1)
    # `repeat` needs one count per existing dim, so append (1, 1) for the matrix
    # dims; without batch dims this is simply a copy of J.
    return J.repeat(*batch, 1, 1)

In [None]:
def foo(x: Tensor, *, c: float) -> Tensor:
    r"""Return ``x`` multiplied elementwise by the keyword-only scalar ``c``.

    Presumably a probe for TorchScript's handling of keyword-only arguments
    (the function is passed to ``jit.script``).
    """
    scaled = x * c
    return scaled

In [None]:
# Check whether TorchScript can script a function with a keyword-only argument.
jit.script(foo)

In [None]:
# Batched matmul: the (4, 4) matrix broadcasts over the leading (1, 2, 3) dims.
torch.randn(1, 2, 3, 4) @ torch.randn(4, 4)

In [None]:
%%timeit
# Benchmark the library's skew-symmetric initializer for n=1024.
# NOTE(review): requires `linodenet` to be imported (its import at the top is commented out).
linodenet.initializations.canonical_skew_symmetric(1024)

In [None]:
%%timeit
# Benchmark the library's symplectic initializer for comparison.
linodenet.initializations.canonical_symplectic(1024)

In [None]:
# NOTE(review): redundant — torch is already imported at the top of the notebook.
import torch

In [None]:
# Small dense layer: 3 input features -> 4 output features.
m = torch.nn.Linear(3, 4)

In [None]:
from collections.abc import Mapping

import torch
from torch import Tensor, nn


class Foo(nn.Module, Mapping):
    r"""Module exposing a dictionary of named tensors through the read-only Mapping API.

    Because ``__len__`` is defined, an empty instance is falsy.

    Args:
        tensors: Optional initial mapping of names to tensors (copied into a new dict).
    """

    tensors: dict[str, Tensor]

    # `Mapping` defines `__eq__`, which implicitly sets `__hash__ = None` and would
    # make instances unhashable. `nn.Module` hashes modules (e.g. the memo set in
    # `named_modules`), so restore identity-based equality and hashing.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    def __init__(self, tensors: Optional[Mapping[str, Tensor]] = None) -> None:
        super().__init__()  # must run before assigning attributes on an nn.Module
        self.tensors = dict(tensors) if tensors is not None else {}

    def __len__(self) -> int:
        return len(self.tensors)

    def __iter__(self):
        return iter(self.tensors)

    def __getitem__(self, key: str) -> Tensor:
        return self.tensors[key]

In [None]:
import logging

import torch
from pkg_resources import load_entry_point
from torch import Tensor, jit, nn

from linodenet.lib import singular_triplet
from linodenet.parametrize import SimpleParametrization, SpectralNormalization
from linodenet.projections import is_symmetric, symmetric
from linodenet.testing import check_model

In [None]:
from copy import deepcopy

from torch.linalg import matrix_norm
from torch.optim import SGD

In [None]:
# create model, parametrization and inputs
inputs = torch.randn(2, 3)
model = nn.Linear(3, 3)
# integer test matrix whose spectral norm exceeds 1 (checked by the assert below)
weight = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
with torch.no_grad():
    # overwrite the layer's weight in place without recording autograd history
    model.weight.copy_(weight)
    assert matrix_norm(model.weight, ord=2) > 1

print(f"Original weight = {model.weight}")
print(f"Original norm =  {matrix_norm(model.weight, ord=2)}")

# wrap the (unnormalized) weight in a spectral-normalization parametrization
spec = SpectralNormalization(model.weight)
spec.weight, spec.original_weight

In [None]:
# backprop a scalar through the parametrization's weight
spec.weight.norm().backward()
# spec.zero_grad(set_to_none=True)

In [None]:
# Frobenius norm (the default `ord` of matrix_norm) of the parametrized weight
matrix_norm(spec.weight)

In [None]:
# spec.reset_cache()  # <--- never forget
# the parametrization's `weight` must be the very same object as its cache entry
assert spec.weight is spec.cached_tensors["weight"]
# assert spec.parametrized_tensor["weight"]
# cloned_model = deepcopy(model)

# register the parametrization
model.register_module("spec", spec)
# remove the weight attribute (it still exists on the parametrization)
del model.weight

# register the parametrization's weight-buffer as a buffer
model.register_buffer("weight", model.spec.cached_tensors["weight"])

# register the parametrization's weight as a parameter (optional)
model.register_parameter(
    "parametrized_weight", model.spec.parametrized_tensors["weight"]
)
# backprop through the re-wired weight buffer, then clear the gradients again
model.weight.norm().backward()
model.zero_grad(set_to_none=True)

In [None]:
# optimizer over the parametrization's parameters only
optim = SGD(model.spec.parameters(), lr=0.1)
# model attributes must alias the parametrization's tensors (identity, not equality)
assert model.weight is model.spec.weight
assert model.parametrized_weight is model.spec.parametrized_tensors["weight"]
# the cached weight must have spectral norm at most 1
assert matrix_norm(model.weight, ord=2) <= 1
model.weight, model.parametrized_weight

In [None]:
# refresh the cache and inspect both tensors again
spec.reset_cache()
model.weight, model.parametrized_weight

In [None]:
spec.original_weight

In [None]:
# experiment: detach the cached weight in place, overwrite it with the original
# weight, then backprop through it
spec.weight.detach_()
spec.weight.copy_(spec.original_weight)
spec.weight.norm().backward()

In [None]:
# A bare `raise` with no active exception raises RuntimeError; presumably a
# deliberate stop so that "Run All" halts here.
raise

In [None]:
# backprop the negated output norm through the model
model.zero_grad(set_to_none=True)
r = -model(inputs).norm()
r.backward()

In [None]:
# where did the gradient land: cached weight buffer vs. registered parameter?
model.weight.grad, model.parametrized_weight.grad

In [None]:
dict(model.named_parameters())

In [None]:
dict(model.named_buffers())

In [None]:
spec.reset_cache()  # <--- never forget

In [None]:
# model.zero_grad(set_to_none=True)
# model.spec.reset_cache()
r = -model(inputs).norm()
r.backward()
# print(model.weight.grad, model.parametrized_weight.grad)

# Snapshot both tensors before the optimizer step; the assert requires them equal.
# NOTE: "weigth" is a typo, kept because later cells reference these names.
cached_weigth_before = model.weight.clone()
params_weight_before = model.parametrized_weight.clone()
assert (cached_weigth_before == params_weight_before).all()

# perform a step
optim.step()
cached_weigth_step = model.weight.clone()
params_weight_step = model.parametrized_weight.clone()
# asserts: the step changed the parameter but left the cached weight untouched
assert (cached_weigth_before == cached_weigth_step).all()
assert not (params_weight_before == params_weight_step).all()
assert not (params_weight_step == cached_weigth_step).all()

# update the caches
model.spec.reset_cache()
cached_weigth_update = model.weight.clone()
params_weight_update = model.parametrized_weight.clone()
# refreshing the cache must not replace the aliased objects
assert model.weight is model.spec.weight
assert model.parametrized_weight is model.spec.parametrized_tensors["weight"]

In [None]:
# elementwise: did the cache change after reset_cache()?
cached_weigth_step == cached_weigth_update

In [None]:
# elementwise: does the refreshed cache equal the parameter?
cached_weigth_update == params_weight_update

In [None]:
# asserts: the refresh updated the cache, and it now matches the parameter
assert not (cached_weigth_step == cached_weigth_update).all()
assert (cached_weigth_update == params_weight_update).all()


# after = torch.cat([model.weight.clone(), model.parametrized_weight.clone()], dim=-1)
# assert not torch.allclose(before, after)
# print(before - after)
# print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
#
# print(torch.cat([model.weight, model.parametrized_weight], dim=-1))

In [None]:
cached_weigth_step, cached_weigth_update

In [None]:
params_weight_step, params_weight_update

In [None]:
model.spec.parametrized_tensor["weight"]

In [None]:
print("Step 1 --------------")
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
model.zero_grad(set_to_none=True)
r = model(inputs).norm()
r.backward()
# assert model.parametrized_weight.grad is not None
optim.step()
model.spec.reset_cache()
print("Recompute... --------------")
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
print("Step 2 --------------")
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
model.zero_grad(set_to_none=True)
r = model(inputs).norm()
r.backward()
# assert model.parametrized_weight.grad is not None
optim.step()
model.spec.reset_cache()
print("Recompute... --------------")
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(torch.cat([model.spec.weight, model.spec.parametrized_tensor["weight"]], dim=-1))
print(matrix_norm(model.weight, ord=2))
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
print("Step 3 --------------")
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
model.zero_grad(set_to_none=True)
r = model(inputs).norm()
r.backward()
optim.step()
model.spec.reset_cache()
print("Recompute... --------------")
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))
print(torch.cat([model.weight, model.parametrized_weight], dim=-1))
print(matrix_norm(model.weight, ord=2))


assert model.parametrized_weight is model.spec.parametrized_tensor["weight"]

In [None]:
# fresh backward pass of the negated output norm
model.zero_grad(set_to_none=True)
r = -model(inputs).norm()
r.backward()

In [None]:
# gradient on the registered parameter
model.parametrized_weight.grad

In [None]:
# check whether any gradient is stored on the `weight` buffer
model.weight.grad is None

In [None]:
# Re-create a fresh model and swap its `weight` parameter for a spectrally
# normalized one: the parametrization's raw tensor is registered as the parameter
# `parametrized_weight`, and `weight` becomes a buffer aliasing `spec.weight`.
with torch.no_grad():
    model = nn.Linear(3, 3)
    display(
        id(model.weight),
    )
    spec = SpectralNormalization(model.weight)
    model.spec = spec

    del model.weight
    model.register_parameter(
        "parametrized_weight", model.spec.parametrized_tensors["weight"]
    )

    model.register_buffer("weight", model.spec.weight)
    # model.weight.copy_(model.parametrized_weight)

    # object identities and values on the parametrization side ...
    print("-" * 80)
    display(
        f"{id(spec.weight)=}",
        spec.weight,
        f"{id(spec.parametrized_tensors['weight'])=}",
        spec.parametrized_tensors["weight"],
        spec.parametrized_tensors["weight"].grad,
    )
    # ... and on the model side; the ids are expected to match the ones above
    print("-" * 80)
    display(
        id(model.weight),
        model.weight,
        id(model.parametrized_weight),
        model.parametrized_weight,
        model.parametrized_weight.grad,
    )

    # NOTE(review): other cells refresh the cache via `spec.reset_cache()`;
    # confirm `recompute_cache()` exists on SpectralNormalization.
    spec.recompute_cache()


# forward + backward pass through the re-wired model
inputs = torch.randn(2, 3)
r = model(inputs)
r.norm().backward()

In [None]:
# Run linodenet's model checker on the re-wired model (presumably including a
# TorchScript check, given `test_jit=True`).
check_model(model, input_args=inputs, test_jit=True)

In [None]:
model.parametrized_weight

In [None]:
import sys

import torch
from torch import nn

from linodenet.models.encoders.invertible_layers import LinearContraction
from linodenet.parametrize import SpectralNormalization

# LinearContraction: print weight and cache before/after an explicit cache refresh.
model = LinearContraction(4, 4)

print(model.weight, model.cached_weight)
model.recompute_cache()
print(model.weight, model.cached_weight)

print("==============================================================")

# Standalone SpectralNormalization on a copy of a Linear layer's weight scaled by 2.
# Attribute is `parametrized_tensors` (plural), matching `cached_tensors`.
torch.manual_seed(0)
model = nn.Linear(4, 4)
param = nn.Parameter(model.weight.clone().detach() * 2)
spec = SpectralNormalization(param)

print(spec.parametrized_tensors["weight"], spec.weight, sep="\n")
assert spec.parametrized_tensors["weight"] is param
assert spec.weight is spec.cached_tensors["weight"]
print("==============================================================")

# Overwrite the cache in place with the raw parameter values; the identity
# asserts below check that no tensor objects were replaced.
spec.cached_tensors["weight"].copy_(spec.parametrized_tensors["weight"])

print(spec.parametrized_tensors["weight"], spec.weight, sep="\n")
assert spec.parametrized_tensors["weight"] is param
assert spec.weight is spec.cached_tensors["weight"]