# Question: What is the best way of implementing a LinearContraction layer in Python?

I.e., a linear layer whose weight matrix $A$ satisfies $‖A‖_2 = σ_{\max}(A) ≤ 1$.

**TODOs:**

- test `torch.nn.utils.parametrizations.spectral_norm`

In [None]:
# IPython notebook configuration: cell output display, SVG figure format,
# automatic module reloading and inline matplotlib plots.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# NumPy random generator for ad-hoc experiments.
# NOTE(review): `rng` is not used by any visible cell (the ground-truth cell
# draws from scipy's global RNG instead) — confirm it is still needed.
rng = np.random.default_rng()
# Called without arguments; presumably a placeholder for custom print options.
np.set_printoptions()

## Automatic Backward pass

In [None]:
import torch
from torch import Tensor, jit
from torch.linalg import vector_norm

# Make new float tensors default to CUDA when a GPU is available, CPU otherwise.
# Consequence: tensors created later without an explicit device (e.g. the
# `torch.randn` calls in the timing cells) live on the GPU when one is present.
# NOTE(review): newer torch releases deprecate `set_default_tensor_type` in
# favour of `set_default_dtype` / `set_default_device` — confirm for the torch
# version in use.
torch.set_default_tensor_type(
    torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
)

In [None]:
from torch.nn.utils.parametrizations import spectral_norm
from torch import nn, jit

In [None]:
# Reference: torch's built-in spectral-norm parametrization applied to a
# 20 -> 30 linear layer; the next cell checks whether it can be scripted.
snm = spectral_norm(nn.Linear(20, 30))

### The spectral norm shipped with torch is not JIT-scriptable

In [None]:
# The torch spectral-norm parametrization is expected to fail under TorchScript.
# Catch and report the error so that "Run All" does not stop at this cell.
# NOTE(review): the exact exception type raised by `jit.script` is not pinned
# down here, hence the broad `except Exception`.
try:
    jit.script(snm)
except Exception as err:
    print(f"jit.script(snm) failed: {type(err).__name__}: {err}")
else:
    print("jit.script(snm) succeeded")

## Improved Custom Implementation

We will more or less duplicate the torch implementation, with some minor improvements:

- Run init until convergence
- Provide option to run forward until convergence



## Autograd Template

In [None]:
from torch.autograd import Function

# A hand-written linear layer `y = x Wᵀ + b` implemented as a custom autograd Function.
class LinearFunction(Function):
    """Affine map ``input @ weight.T (+ bias)`` with an explicit backward pass.

    ``torch.autograd.Function`` requires ``forward`` and ``backward`` to be
    static methods. ``input`` must be 2-D because it is multiplied with
    ``mm``. ``bias`` is optional and, when given, is added to every row.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Stash everything the gradient formulas need (bias may be None).
        ctx.save_for_backward(input, weight, bias)
        product = input.mm(weight.t())
        if bias is None:
            return product
        product += bias.unsqueeze(0).expand_as(product)
        return product

    @staticmethod
    def backward(ctx, grad_output):
        # The single output yields a single incoming gradient. One gradient
        # (or None) is returned per forward input; when bias was not passed,
        # the extra trailing None is ignored by autograd.
        x, w, b = ctx.saved_tensors
        wants = ctx.needs_input_grad

        # Skipping gradients nobody asked for is only an optimisation.
        grad_x = grad_output.mm(w) if wants[0] else None
        grad_w = grad_output.t().mm(x) if wants[1] else None
        # `b is not None` is tested first: needs_input_grad has no third
        # entry when bias was omitted from the call.
        grad_b = grad_output.sum(0) if b is not None and wants[2] else None

        return grad_x, grad_w, grad_b

In [None]:
# Functional handle: custom autograd Functions are invoked through `.apply`.
linear = LinearFunction.apply

In [None]:
from torch.autograd import gradcheck

# gradcheck compares the analytical gradients returned by `backward` with
# finite-difference approximations at the given (double-precision) inputs and
# returns True if they agree within the tolerances.
# Both code paths of LinearFunction are checked: without and with the bias.
inputs = (
    torch.randn(20, 20, dtype=torch.double, requires_grad=True),
    torch.randn(30, 20, dtype=torch.double, requires_grad=True),
)
bias = torch.randn(30, dtype=torch.double, requires_grad=True)
test = gradcheck(linear, inputs, eps=1e-6, atol=1e-4)
test_bias = gradcheck(linear, (*inputs, bias), eps=1e-6, atol=1e-4)
print(test, test_bias)

## SpectralNorm Implementation

In [None]:
from tqdm.auto import tqdm, trange

In [None]:
from typing import Any


class SpectralNorm(torch.autograd.Function):
    r"""`‖A‖_2 = σ_{𝗆𝖺𝗑}(A) = √{λ_{𝗆𝖺𝗑}(AᵀA)}`.

    The spectral norm `∥A∥_2 ≔ 𝗌𝗎𝗉_x ∥Ax∥_2 / ∥x∥_2` can be shown to be equal to
    `σ_\max(A) = √{λ_{𝗆𝖺𝗑} (AᵀA)}`, the largest singular value of `A`.

    It can be computed efficiently via power iteration. If `σ_\max` is a simple
    singular value, its derivative is

    .. math::
        \frac{∂∥A∥_2}{∂A} = uvᵀ

    where `u, v` are the left/right singular vectors corresponding to `σ_\max`.

    References
    ----------
    - | `Spectral Normalization for Generative Adversarial Networks
        <https://openreview.net/forum?id=B1QRgziT->`_
      | Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
      | `International Conference on Learning Representations 2018
        <https://iclr.cc/Conferences/2018>`_
    """

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Forward-mode derivative: `d‖A‖_2 = ⟨uvᵀ, dA⟩ = uᵀ dA v` (a scalar)."""
        # NOTE(review): depending on the torch version, forward-mode AD may
        # require the vectors to be saved via `ctx.save_for_forward` — confirm.
        u, v = ctx.saved_tensors
        return u @ grad_inputs[0] @ v

    @staticmethod
    def forward(ctx: Any, *tensors: Tensor, **kwargs: Any) -> Tensor:
        r"""Compute `σ_max(A)` by power iteration and save `u, v` for backward.

        Parameters
        ----------
        ctx
            Autograd context; receives the singular vectors `u, v`.
        tensors
            `tensors[0]` is the 2-D matrix `A`; further entries are ignored.
        kwargs
            Optional `atol` (default 1e-6), `rtol` (default 1e-6) and `maxiter`
            (default 1000) for the stopping criterion
            `‖Aᵀu - σv‖_2 ≤ rtol⋅σ + atol`.

        Returns
        -------
        Tensor
            Scalar tensor `σ_max(A) ≥ 0` (exactly 0 for the zero matrix).

        Raises
        ------
        ValueError
            If `A` is not 2-D.
        """
        A = tensors[0]
        atol: float = kwargs.get("atol", 1e-6)
        rtol: float = kwargs.get("rtol", 1e-6)
        maxiter: int = kwargs.get("maxiter", 1000)
        if A.ndim != 2:
            raise ValueError(f"Expected 2D input, got shape {tuple(A.shape)}.")
        m, n = A.shape

        if not bool((A != 0).any()):
            # ‖0‖_2 = 0. With u = v = 0 the backward returns the zero matrix,
            # which is a valid subgradient at A = 0 (instead of NaN).
            ctx.save_for_backward(A.new_zeros(m), A.new_zeros(n))
            return A.new_zeros(())

        # Start from the column medians (cheap, deterministic guess). Fall back
        # to a random vector if A annihilates that guess; e.g. the medians of
        # a sparse matrix such as the identity are all zero.
        v = A.median(dim=0).values
        if not bool((A @ v != 0).any()):
            v = torch.randn_like(v)
        v = v / vector_norm(v)
        u = A @ v
        σ = vector_norm(u)
        u = u / σ

        for _ in range(maxiter):
            # u and v are updated sequentially (u = Av/‖Av‖), so σ = ‖Av‖ = uᵀAv
            # is never negative. The left residual Av - σu vanishes by
            # construction, so only the right residual Aᵀu - σv is checked.
            w = A.T @ u
            if vector_norm(w - σ * v) <= rtol * σ + atol:
                break
            # w ≠ 0 because vᵀw = uᵀAv = σ > 0.
            v = w / vector_norm(w)
            u = A @ v
            σ = vector_norm(u)
            u = u / σ

        ctx.save_for_backward(u, v)
        return σ

    @staticmethod
    def backward(ctx: Any, grad_outputs: Tensor) -> Tensor:
        r"""Backward pass: `∂L/∂A = (∂L/∂σ) ⋅ uvᵀ`.

        Parameters
        ----------
        ctx
            Autograd context holding the singular vectors `u, v`.
        grad_outputs
            Scalar gradient `∂L/∂σ`.
        """
        u, v = ctx.saved_tensors
        return torch.outer(grad_outputs * u, v)

In [None]:
# Functional alias for the power-iteration Function defined above.
# NOTE: this rebinds `spectral_norm`, shadowing the torch parametrization
# imported from `torch.nn.utils.parametrizations` earlier in the notebook.
spectral_norm = SpectralNorm.apply

In [None]:
%%timeit
# Times a complete gradcheck of the power-iteration Function (many forward
# evaluations), not a single forward/backward; `print` runs on every loop.
inputs = torch.randn(20, 30, dtype=torch.double, requires_grad=True)
test = gradcheck(spectral_norm, inputs, eps=1e-6, atol=1e-4)
print(test)

In [None]:
# Timing recorded for the gradcheck cell above (units not recorded; presumably
# per-loop time as reported by %%timeit): 15.1 ± 4.94

In [None]:
@jit.script
def spectral_norm(
    A: Tensor, atol: float = 1e-6, rtol: float = 1e-6, maxiter: int = 1000
) -> Tensor:
    r"""Compute the spectral norm `‖A‖_2 = σ_max(A)` by power iteration on `AᵀA`.

    Power iteration approximates the dominant eigenpair `(λ, x)` of `AᵀA`,
    and the spectral norm is `√λ`. No custom backward is defined: gradients
    flow through the unrolled iterations via autograd.

    Stopping criterion (whichever comes first):
    - `maxiter` iterations reached
    - `‖(AᵀA - λI)x‖_2 ≤ 𝗋𝗍𝗈𝗅⋅‖λx‖_2 + 𝖺𝗍𝗈𝗅`

    Parameters
    ----------
    A: Tensor
        2-D matrix.
    atol: float = 1e-6
        Absolute tolerance of the residual.
    rtol: float = 1e-6
        Relative tolerance of the residual.
    maxiter: int = 1000
        Maximum number of power iterations.

    Returns
    -------
    Tensor
        Scalar tensor `σ_max(A)`; exactly 0 (with zero gradient) for `A = 0`.

    Raises
    ------
    ValueError
        If `A` is not 2-D.
    """
    if A.dim() != 2:
        raise ValueError("Expected a 2D matrix.")
    if not bool((A != 0).any()):
        # Power iteration would divide 0 by 0 here. Return 0 while staying
        # connected to A, so backward yields the zero (sub)gradient.
        return A.sum() * 0.0

    n = A.shape[1]
    x = torch.randn(n, device=A.device, dtype=A.dtype)
    x = x / vector_norm(x)

    z = A.T @ (A @ x)
    c = vector_norm(z)
    λ = c / vector_norm(x)

    for _ in range(maxiter):
        x = z / c
        z = A.T @ (A @ x)
        c = vector_norm(z)
        λ = c / vector_norm(x)  # Rayleigh-type estimate of λ_max(AᵀA)
        r = z - λ * x
        if vector_norm(r) <= rtol * vector_norm(λ * x) + atol:
            break

    return torch.sqrt(λ)

## Custom Backward pass

In [None]:
class SpectralNorm(torch.autograd.Function):
    r"""`‖A‖_2 = σ_{𝗆𝖺𝗑}(A) = √{λ_{𝗆𝖺𝗑}(AᵀA)}` with an analytic backward pass.

    The spectral norm `∥A∥_2 ≔ 𝗌𝗎𝗉_x ∥Ax∥_2 / ∥x∥_2` can be shown to be equal to
    `σ_\max(A) = √{λ_{𝗆𝖺𝗑} (AᵀA)}`, the largest singular value of `A`.

    The forward pass computes it by power iteration on `AᵀA`. If `σ_\max` is a
    simple singular value, its derivative is

    .. math::
        \frac{∂∥A∥_2}{∂A} = uvᵀ

    where `u, v` are the left/right singular vectors corresponding to `σ_\max`.
    The backward pass therefore only needs `u` and `v`, not the iterations.
    """

    @staticmethod
    def forward(
        ctx, A: Tensor, atol: float = 1e-6, rtol: float = 1e-6, maxiter: int = 1000
    ) -> Tensor:
        r"""Return `σ_max(A)` and save the singular vectors `u, v` for backward.

        Stops after `maxiter` iterations or once
        `‖(AᵀA - λI)x‖_2 ≤ rtol⋅‖λx‖_2 + atol`.

        Raises
        ------
        ValueError
            If `A` is not 2-D.
        """
        if A.dim() != 2:
            raise ValueError(f"Expected 2D input, got shape {tuple(A.shape)}.")
        m, n = A.shape

        if not bool((A != 0).any()):
            # ‖0‖_2 = 0. With u = v = 0 the backward returns the zero matrix,
            # which is a valid subgradient at A = 0 (instead of NaN).
            ctx.save_for_backward(A.new_zeros(m), A.new_zeros(n))
            return A.new_zeros(())

        x = torch.randn(n, device=A.device, dtype=A.dtype)
        x = x / vector_norm(x)

        z = A.T @ (A @ x)
        c = vector_norm(z)
        λ = c / vector_norm(x)

        for _ in range(maxiter):
            x = z / c
            z = A.T @ (A @ x)
            c = vector_norm(z)
            λ = c / vector_norm(x)
            r = z - λ * x
            if vector_norm(r) <= rtol * vector_norm(λ * x) + atol:
                break

        σ_max = torch.sqrt(λ)

        # x approximates the right singular vector v; u follows from Av = σu.
        v = x / vector_norm(x)
        u = A @ v / σ_max
        u = u / vector_norm(u)

        ctx.save_for_backward(u, v)
        return σ_max

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        r"""Backward pass: `∂L/∂A = (∂L/∂σ) ⋅ uvᵀ`."""
        u, v = ctx.saved_tensors
        # One entry per forward input. atol/rtol/maxiter are not
        # differentiable; trailing Nones are dropped when they were not passed.
        return grad_output * torch.outer(u, v), None, None, None

## Test against ground truth

**Theorem:** If $A = ∑_i σ_i u_iv_i^𝖳$ is the SVD of $A$ and the largest singular value is simple ($σ_1 > σ_2$), then $\frac{∂‖A‖_2}{∂A} = u_1v_1^𝖳$.

In [None]:
from scipy.stats import ortho_group, dirichlet
import numpy as np

# Build a test matrix X = U diag(σ) Vᵀ with a known SVD, so that the exact
# spectral norm (σt) and its gradient u₁v₁ᵀ (gt) are available for comparison.
# NOTE: the draws come from scipy/numpy's global RNG and are unseeded, so the
# matrix (and the gap σ₁ - σ₂) changes between runs.
M, N = 64, 128
K = min(M, N)
U = ortho_group.rvs(M)  # random M×M orthogonal matrix (left singular vectors)
V = ortho_group.rvs(N)  # random N×N orthogonal matrix (right singular vectors)
# Singular values: one Dirichlet draw (positive, summing to 1), sorted descending.
σ = dirichlet.rvs(np.ones(min(M, N))).squeeze()
σ = np.flip(np.sort(σ))
σt = σ[0]  # true spectral norm ‖X‖_2
X = np.einsum("i, mi, ni -> mn", σ, U[:, :K], V[:, :K])  # Σ_i σ_i u_i v_iᵀ
X = torch.tensor(X).double()
H = torch.randn(M, N).double()  # NOTE(review): H is not used by any visible cell
u = torch.tensor(U[:, 0])
v = torch.tensor(V[:, 0])
gt = torch.outer(u, v)  # exact gradient ∂‖X‖_2/∂X = u₁v₁ᵀ

### linalg.norm

In [None]:
# Accuracy comparison against the ground-truth matrix X. For every method,
# print |σt - σ_max| (forward error), the RMS deviation of ∂σ_max/∂A from
# u₁v₁ᵀ (backward error), and the method's name.
methods = {
    "norm": lambda X: torch.linalg.norm(X, ord=2),
    "matrix_norm": lambda X: torch.linalg.matrix_norm(X, ord=2),
    "svdvals": lambda X: torch.linalg.svdvals(X)[0],
    "spectral_norm": spectral_norm,
    "SpectralNorm": SpectralNorm.apply,
}

for label, compute_norm in methods.items():
    # A fresh leaf per method, so gradients never accumulate across methods.
    A = torch.nn.Parameter(X.clone(), requires_grad=True)
    σ_max = compute_norm(A)
    σ_max.backward()
    g = A.grad
    fward_error = torch.abs(σt - σ_max).item()
    bward_error = torch.mean((gt - g) ** 2).sqrt().item()
    print(f"{fward_error:.4e}  {bward_error:.4e}", label)

## Speed Tests

### with norm

In [None]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
σ_max = torch.linalg.norm(X, ord=2)
σ_max.backward()
# GPU kernels are launched asynchronously: wait for them so that %%timeit
# measures the computation itself rather than only the kernel launches.
if X.is_cuda:
    torch.cuda.synchronize()

### with matrix_norm

In [None]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
σ_max = torch.linalg.matrix_norm(X, ord=2)
σ_max.backward()
# GPU kernels are launched asynchronously: wait for them so that %%timeit
# measures the computation itself rather than only the kernel launches.
if X.is_cuda:
    torch.cuda.synchronize()

### with svdvals

In [None]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
σ_max = torch.linalg.svdvals(X)[0]
σ_max.backward()
# GPU kernels are launched asynchronously: wait for them so that %%timeit
# measures the computation itself rather than only the kernel launches.
if X.is_cuda:
    torch.cuda.synchronize()

### with spectral_norm

In [None]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
σ_max = spectral_norm(X)
σ_max.backward()
# GPU kernels are launched asynchronously: wait for them so that %%timeit
# measures the computation itself rather than only the kernel launches.
if X.is_cuda:
    torch.cuda.synchronize()

### with SpectralNorm

In [None]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
σ_max = SpectralNorm.apply(X)
σ_max.backward()
# GPU kernels are launched asynchronously: wait for them so that %%timeit
# measures the computation itself rather than only the kernel launches.
if X.is_cuda:
    torch.cuda.synchronize()