In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'
%load_ext autoreload
%autoreload 2

import logging

logging.basicConfig(level=logging.INFO)

from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from typing import Protocol, runtime_checkable

import torch
from torch import Tensor, jit, nn, tensor
from torch.linalg import matrix_norm

# from linodenet.lib import singular_triplet
from torchinfo import summary

# from linodenet.parametrize import Parametrization, SpectralNormalization

In [None]:
def deprecated(func=None, msg=None, /, *, category=DeprecationWarning, stacklevel=1):
    """Indicate that a class, function or overload is deprecated.

    Supported call forms:

    - ``@deprecated`` -- bare decorator, default message;
    - ``@deprecated()`` / ``@deprecated("message")`` -- decorator factory;
    - ``deprecated(func, "message")`` -- direct application.

    Calling the decorated object emits ``category`` with the message via
    :func:`warnings.warn`, then forwards all arguments and returns the result.
    The message is also stored on the wrapper as ``__deprecated__``.

    Note:
        The object is wrapped in a plain function, so a decorated class is
        replaced by a factory function (``isinstance`` checks against the
        decorated name will no longer work).

    Args:
        func: The object to deprecate, or the message when used as
            ``deprecated("message")``.
        msg: Custom warning message; defaults to ``"<qualname> is deprecated."``.
        category: Warning class to emit.
        stacklevel: Passed to :func:`warnings.warn`, counted from the caller of
            the decorated object (1 = the line that calls it).

    Raises:
        TypeError: If a message is given both as ``func`` and as ``msg``.
    """
    import functools
    import warnings

    if isinstance(func, str):
        # used as deprecated("message") -> shift arguments
        if msg is not None:
            raise TypeError("deprecated() got a message both as 'func' and as 'msg'")
        msg, func = func, None

    def make_default_message(obj):
        """Build the fallback message from the object's (qualified) name."""
        name = getattr(obj, "__qualname__", None) or getattr(obj, "__name__", None) or repr(obj)
        return f"{name} is deprecated."

    def wrap(target, message):
        """Return a wrapper around ``target`` that warns on every call."""
        if message is None:
            message = make_default_message(target)

        @functools.wraps(target)
        def wrapped(*args, **kwargs):
            # +1 so that stacklevel=1 points at the caller, not at this wrapper
            warnings.warn(message, category=category, stacklevel=stacklevel + 1)
            return target(*args, **kwargs)

        wrapped.__deprecated__ = message
        return wrapped

    if func is None:
        # used with brackets -> decorator factory
        def decorator(decorated):
            return wrap(decorated, msg)

        return decorator

    # used without brackets (or applied directly) -> wrap func
    return wrap(func, msg)

In [None]:
from collections.abc import Callable

# Scratch check: `collections.abc.Callable` can be used with isinstance; an int
# is not callable, so this cell displays False.
isinstance(1, Callable)

In [None]:
from typing_extensions import deprecated

In [None]:
def deprecated(func_or_msg=None, /, **kwargs):
    """Mark a callable as deprecated.

    Usage:
        ``@deprecated`` (no brackets, default message),
        ``@deprecated()`` (brackets, default message), or
        ``@deprecated("message")`` (brackets, custom message).

    Calling the decorated callable emits a warning via :func:`warnings.warn`
    and then forwards all arguments. The message is stored on the wrapper as
    ``__deprecated__``.

    Keyword Args:
        category: Warning class to emit (default ``DeprecationWarning``).
        stacklevel: Passed to :func:`warnings.warn`, counted from the caller
            of the decorated callable (default 1).

    Raises:
        TypeError: If an unsupported keyword argument is given.
    """
    import functools
    import warnings

    category = kwargs.pop("category", DeprecationWarning)
    stacklevel = kwargs.pop("stacklevel", 1)
    if kwargs:
        raise TypeError(
            f"deprecated() got unexpected keyword argument(s): {', '.join(sorted(kwargs))}"
        )

    def _apply(func, message):
        """Wrap ``func`` so that every call emits the deprecation warning."""
        if message is None:
            name = getattr(func, "__qualname__", None) or getattr(func, "__name__", None) or repr(func)
            message = f"{name} is deprecated."

        @functools.wraps(func)
        def wrapped(*args, **kw):
            # +1 so that stacklevel=1 points at the caller, not at this wrapper
            warnings.warn(message, category=category, stacklevel=stacklevel + 1)
            return func(*args, **kw)

        wrapped.__deprecated__ = message
        return wrapped

    if func_or_msg is None or isinstance(func_or_msg, str):
        # used with brackets: deprecated() or deprecated("message")
        message = func_or_msg

        def decorator(func):
            return _apply(func, message)

        return decorator

    # used without brackets: func_or_msg is the callable itself
    return _apply(func_or_msg, None)

In [None]:
# NOTE(review): no `foo` is defined in any visible cell of this notebook, so on a
# fresh kernel this call raises NameError — define it first or delete this cell.
foo()

In [None]:
import torch
import torch.linalg
from torch import BoolTensor, Tensor, jit

In [None]:
from abc import abstractmethod
from collections.abc import Callable
from contextlib import AbstractContextManager
from typing import Protocol, runtime_checkable

import torch
from torch import Tensor, jit, nn


@runtime_checkable
class ParametrizationProto(Protocol):
    """Structural interface shared by parametrization modules.

    Three hooks are abstract (cache reset, cache recomputation, projection); the
    other two come with default bodies that simply raise ``NotImplementedError``.

    Note:
        To work with JIT, the listed methods must be annotated with @jit.export.
    """

    @abstractmethod
    def reset_cache(self) -> None:
        """Reset the cached weight matrix."""

    @abstractmethod
    def recompute_cache(self) -> None:
        """Recompute the cached weight matrix."""

    @abstractmethod
    def projection(self) -> None:
        """Project the cached weight matrix."""

    @jit.export
    def right_inverse(self) -> None:
        """Compute the right inverse of the parametrization.

        Raises:
            NotImplementedError: Always, unless overridden by an implementer.
        """
        raise NotImplementedError

    @jit.export
    def reset_parameters(self) -> None:
        """Reapply the initialization.

        Raises:
            NotImplementedError: Always, unless overridden by an implementer.
        """
        raise NotImplementedError


class Parametrize(nn.Module, ParametrizationProto):
    """Parametrization of a single tensor.

    The raw tensor is registered as the parameter ``parametrized_tensor``; the result
    of applying ``parametrization`` to it is kept in the buffer ``cached_tensor``.

    Args:
        tensor: Tensor to parametrize. An existing ``nn.Parameter`` is registered
            as-is (so other modules holding it keep sharing the same object); any
            other tensor is wrapped in a new ``nn.Parameter``.
        parametrization: Function mapping the raw tensor to its parametrized value.
    """

    # Parameters:
    parametrized_tensor: Tensor
    # Buffers:
    cached_tensor: Tensor

    def __init__(
        self,
        tensor: Tensor,
        parametrization: Callable[[Tensor], Tensor],
    ) -> None:
        super().__init__()

        # get the parametrization (stored first: it is needed to initialize the cache)
        self._parametrization = parametrization

        # register_parameter only accepts nn.Parameter; wrap plain tensors, but keep an
        # existing Parameter unchanged so callers can keep sharing it (e.g. model.weight).
        if not isinstance(tensor, nn.Parameter):
            tensor = nn.Parameter(tensor)
        self.register_parameter("parametrized_tensor", tensor)

        # Fill the cache with the parametrized value; torch.empty_like alone would leave
        # it holding uninitialized memory until the first recompute_cache() call.
        self.register_buffer("cached_tensor", torch.empty_like(tensor))
        with torch.no_grad():
            self.cached_tensor.copy_(self.parametrization(self.parametrized_tensor))

    def forward(self) -> Tensor:
        """Apply the parametrization to the raw tensor and return the result."""
        return self.parametrization(self.parametrized_tensor)

    @jit.export
    def parametrization(self, x: Tensor) -> Tensor:
        """Apply the parametrization function given to the constructor to ``x``."""
        return self._parametrization(x)

    @jit.export
    def recompute_cache(self) -> None:
        """Recompute ``cached_tensor`` from the current ``parametrized_tensor``.

        The result is copied in-place into the buffer. When called with autograd
        enabled, the copy is recorded in the graph, so gradients reaching the cache
        propagate back to ``parametrized_tensor``.
        """
        # Compute the cached weight matrix
        new_tensor = self.forward()
        self.cached_tensor.copy_(new_tensor)

    @jit.export
    def projection(self) -> None:
        """Replace ``parametrized_tensor`` by ``parametrization(parametrized_tensor)``.

        Runs without gradient tracking: refreshes the cache, then copies the cached
        value back into the raw parameter.
        """
        with torch.no_grad():
            # update the cached weight matrix
            self.recompute_cache()
            self.parametrized_tensor.copy_(self.cached_tensor)

    @jit.export
    def reset_cache(self) -> None:
        """Project, detach the cache from the old graph, and rebuild it with autograd."""
        # apply projection step.
        self.projection()

        # reengage the autograd engine
        # detach() is necessary to avoid "Trying to backward through the graph a second time" error
        self.cached_tensor.detach_()

        # recompute the cache
        # Note: we need the second run to set up the gradients
        self.recompute_cache()

    @jit.export
    def reset_cache_expanded(self) -> None:
        """Inlined equivalent of ``reset_cache`` (projection and recompute written out)."""
        with torch.no_grad():
            new_tensor = self.forward()
            self.cached_tensor.copy_(new_tensor)
            self.parametrized_tensor.copy_(self.cached_tensor)

        # reengage the autograd engine
        # detach() is necessary to avoid "Trying to backward through the graph a second time" error
        self.cached_tensor.detach_()

        # recompute the cache
        # Note: we need the second run to set up the gradients
        new_tensor = self.forward()
        self.cached_tensor.copy_(new_tensor)

In [None]:
from linodenet.lib import spectral_norm, spectral_norm_native
from linodenet.parametrize import SpectralNormalization
from linodenet.projections import is_symmetric, symmetric
from linodenet.testing import check_model

In [None]:
# Sanity check: a bias-free linear layer whose weight is symmetric should give the
# same results before and after its weight is wrapped by `Parametrize(..., symmetric)`.
# B = batch size, M = in_features, N = out_features.
B, M, N = 4, 3, 3
x = torch.randn(B, M)

# setup reference model
reference_model = nn.Linear(M, N, bias=False)
symmetrized_weight = symmetric(reference_model.weight)
reference_model.weight = nn.Parameter(symmetrized_weight)
assert is_symmetric(reference_model.weight)

# setup vanilla model
model = nn.Linear(M, N, bias=False)
with torch.no_grad():
    model.weight.copy_(reference_model.weight)

# check compatibility
check_model(model, input_args=(x,), reference_model=reference_model, test_jit=True)

# now, parametrize
weight = model.weight
param = Parametrize(weight, symmetric)
param.zero_grad(set_to_none=True)
# NOTE(review): `Parametrize.__init__` registers the Parameter it is given, so
# `param.parametrized_tensor` is the same object as `model.weight`. This assignment
# does not change which tensor `model.forward` uses: it still uses the raw weight,
# not `symmetric(weight)`. The check below passes because the weight is already
# symmetric, not because the parametrization is applied in the forward pass.
model.weight = param.parametrized_tensor
# assigning a Module attribute registers `param` as a submodule of `model`
model.param = param

# check compatibility
check_model(model, input_args=(x,), reference_model=reference_model, test_jit=True)

In [None]:
# Wrap a standalone random (m, n) Parameter with the symmetric parametrization.
m, n = 5, 5
# Named `raw_tensor` so it does not shadow `tensor` (torch.tensor, imported in the first cell).
raw_tensor = torch.randn(m, n)
weight = nn.Parameter(raw_tensor)
param = Parametrize(weight, symmetric)
param.zero_grad(set_to_none=True)

In [None]:
# Same wrapping as in the sanity check, but on an (m, n) linear layer *with* bias,
# followed by a torchinfo summary of the resulting module.
# NOTE(review): `symmetric` presumably needs a square weight; here m == n.
model = nn.Linear(m, n)
weight = model.weight
param = Parametrize(weight, symmetric)
param.zero_grad(set_to_none=True)
# `param.parametrized_tensor` is the same Parameter as `model.weight`, so this
# reassignment does not change what `model.forward` uses.
model.weight = param.parametrized_tensor
model.param = param
summary(model)

In [None]:
# Rebuild a bias-free model with the same shape as `reference_model` and copy its
# weight. Using (m, n) here would fail: `reference_model` was built with (M, N)
# features, and copy_ cannot broadcast its weight into an (n, m) weight.
model = nn.Linear(reference_model.in_features, reference_model.out_features, bias=False)
with torch.no_grad():
    model.weight.copy_(reference_model.weight)

In [None]:
# Input width must match `reference_model` (built with M input features, not m),
# otherwise both forward passes in check_model fail with a shape mismatch.
x = torch.randn(7, reference_model.in_features)
check_model(model, input_args=(x,), reference_model=reference_model, test_jit=True)

## Now, parametrize the model

In [None]:
# Wrap the current model's weight with the symmetric parametrization and re-check.
# NOTE(review): `param.parametrized_tensor` is the same Parameter object as
# `model.weight`, so `model.forward` keeps using the raw weight; the check does not
# exercise `symmetric(...)` in the forward pass.
weight = model.weight
param = Parametrize(weight, symmetric)
param.zero_grad(set_to_none=True)
model.weight = param.parametrized_tensor
model.param = param
check_model(model, input_args=(x,), reference_model=reference_model, test_jit=True)

In [None]:
# Display `symmetric` applied to the reference weight (which was symmetrized when
# `reference_model` was built, so this should reproduce it).
symmetric(reference_model.weight)

In [None]:
# Compile the model (including its `param` submodule) with TorchScript.
scripted = jit.script(model)

In [None]:
# Inspect the weight as seen by the scripted module.
scripted.weight

In [None]:
# Inspect the TorchScript source generated for the scripted model's forward().
scripted.code

In [None]:
# Run the projection + cache-rebuild cycle on the eager `param` module (not on `scripted`).
param.reset_cache()

In [None]:
# Inspect the raw parameter after the reset; it is the same object as `model.weight`.
param.parametrized_tensor