In [1]:
# Display settings: show the value of a cell's last expression OR last assignment,
# and render inline matplotlib figures as SVG.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
# Automatically reload edited imported modules before each cell is executed.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt

SEED = 0  # fixed seed so numpy draws from `rng` are reproducible across kernel restarts
rng = np.random.default_rng(SEED)
# NOTE(review): called without arguments this appears to leave numpy's print options
# unchanged (it does not reset them) — pass explicit options or drop the call.
np.set_printoptions()

## Automatic Backward pass

In [3]:
import torch
from torch import Tensor, jit
from torch.linalg import vector_norm

# Create new float tensors on the GPU by default when one is available; otherwise keep
# the CPU defaults so the notebook still runs on machines without CUDA.
# NOTE: `set_default_tensor_type` is deprecated in PyTorch >= 2.1 in favour of
# `torch.set_default_device` / `torch.set_default_dtype`.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [4]:
@jit.script
def spectral_norm(
    A: Tensor, atol: float = 1e-6, rtol: float = 1e-6, maxiter: int = 1000
) -> Tensor:
    r"""Compute the spectral norm `‖A‖_2 = σ_max(A)` by power iteration on `AᵀA`.

    The iteration estimates `λ = λ_max(AᵀA)` and returns `√λ`. Gradients come from
    ordinary autograd, which differentiates through every iteration step.

    Stopping criterion (whichever happens first):
    - `maxiter` iterations have been performed
    - `‖(AᵀA - λI)x‖_2 ≤ rtol⋅‖λx‖_2 + atol`

    Parameters
    ----------
    A: Tensor
        2-D matrix of shape `(m, n)`. Must be non-zero: for `A = 0` (and
        `maxiter ≥ 1`) the normalisation divides 0 by 0 and the result is NaN.
    atol: float = 1e-6
        Absolute tolerance of the residual test.
    rtol: float = 1e-6
        Relative tolerance of the residual test.
    maxiter: int = 1000
        Maximum number of power-iteration steps.

    Returns
    -------
    Tensor
        0-dim tensor holding the estimate of `σ_max(A)`.
    """
    # Tuple-unpacking the shape also enforces that A is 2-D.
    _, n = A.shape

    # Random unit start vector on A's own device/dtype (not the global default type).
    x = torch.randn(n, device=A.device, dtype=A.dtype)
    x = x / vector_norm(x)

    z = A.T @ (A @ x)  # AᵀA x without forming AᵀA explicitly
    c = vector_norm(z, dim=0)
    lam = c / vector_norm(x, dim=0)  # current estimate of λ_max(AᵀA)

    for _ in range(maxiter):
        x = z / c
        z = A.T @ (A @ x)
        c = vector_norm(z, dim=0)
        lam = c / vector_norm(x, dim=0)
        r = z - lam * x  # eigen-residual (AᵀA - λI)x
        if vector_norm(r) <= rtol * vector_norm(lam * x) + atol:
            break

    # λ_max(AᵀA) = σ_max(A)²
    return torch.sqrt(lam)

## Custom Backward pass

In [5]:
class SpectralNorm(torch.autograd.Function):
    r"""Spectral norm `‖A‖_2 = √{λ_max(AᵀA)} = σ_max(A)` with an analytic backward pass.

    The spectral norm `‖A‖_2 ≔ sup_x ‖Ax‖_2 / ‖x‖_2` equals `σ_max(A)`, the largest
    singular value of `A`. It is computed efficiently via power iteration on `AᵀA`.

    If `σ_max` is a simple singular value, the derivative is

    .. math::
        \frac{∂‖A‖_2}{∂A} = uvᵀ

    where `u, v` are the left/right singular vectors belonging to `σ_max`. The
    backward pass is therefore a single outer product instead of differentiating
    through the iteration.

    Call as ``SpectralNorm.apply(A)`` or ``SpectralNorm.apply(A, atol, rtol, maxiter)``.
    """

    @staticmethod
    def forward(
        ctx, A: Tensor, atol: float = 1e-6, rtol: float = 1e-6, maxiter: int = 1000
    ) -> Tensor:
        """Estimate `σ_max(A)` and save the singular vectors `u, v` for backward.

        Power iteration stops after `maxiter` steps or once
        `‖(AᵀA - λI)x‖_2 ≤ rtol⋅‖λx‖_2 + atol`. `A` must be a non-zero 2-D matrix:
        for `A = 0` (and `maxiter ≥ 1`) the normalisation divides 0 by 0 and the
        result is NaN.
        """
        # Tuple-unpacking the shape also enforces that A is 2-D.
        _, n = A.shape

        # Random unit start vector on A's own device/dtype.
        x = torch.randn(n, device=A.device, dtype=A.dtype)
        x = x / vector_norm(x)

        z = A.T @ (A @ x)  # AᵀA x without forming AᵀA explicitly
        c = vector_norm(z, dim=0)
        lam = c / vector_norm(x, dim=0)  # current estimate of λ_max(AᵀA)

        for _ in range(maxiter):
            x = z / c
            z = A.T @ (A @ x)
            c = vector_norm(z, dim=0)
            lam = c / vector_norm(x, dim=0)
            r = z - lam * x  # eigen-residual (AᵀA - λI)x
            if vector_norm(r) <= rtol * vector_norm(lam * x) + atol:
                break

        sigma_max = torch.sqrt(lam)  # λ_max(AᵀA) = σ_max(A)²

        # Right singular vector: the normalised dominant eigenvector of AᵀA.
        # Left singular vector: u = Av / σ_max, renormalised because σ_max is approximate.
        v = x / vector_norm(x)
        u = A @ v / sigma_max
        u = u / vector_norm(u)

        ctx.save_for_backward(u, v)
        return sigma_max

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output: Tensor) -> "tuple[Tensor, None, None, None]":
        """Return `grad_output ⋅ uvᵀ` for `A` and no gradient for the other inputs.

        Marked once-differentiable: `u, v` are saved as constants, so a second
        derivative computed from this formula would silently be wrong.
        """
        u, v = ctx.saved_tensors
        # One entry per forward input (A, atol, rtol, maxiter); autograd accepts
        # the trailing Nones even when only A was passed to `apply`.
        return grad_output * torch.outer(u, v), None, None, None

## Test against ground truth

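**Theorem:** $\frac{\partial \|A\|_2}{\partial A} = u_1 v_1^\top$, if $A = \sum_i \sigma_i u_i v_i^\top$ is the SVD of $A$ and the largest singular value is simple, i.e. $\sigma_1 > \sigma_2$ (otherwise $\|A\|_2$ is not differentiable at $A$).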

In [6]:
from scipy.stats import ortho_group, dirichlet
import numpy as np

# Ground truth: build X = U diag(σ) Vᵀ from random orthogonal U, V so that the exact
# spectral norm σ_1 and its gradient u_1 v_1ᵀ are known in closed form.
gen = np.random.default_rng(0)  # fixed seed -> the same test matrix on every run
torch.manual_seed(0)  # seed torch too, so later torch.randn start vectors are reproducible

M, N = 64, 128
K = min(M, N)
U = ortho_group.rvs(M, random_state=gen)
V = ortho_group.rvs(N, random_state=gen)
# Singular values: a random point of the probability simplex, sorted in decreasing order.
œÉ = np.flip(np.sort(dirichlet.rvs(np.ones(K), random_state=gen).squeeze()))
œÉt = œÉ[0]  # exact spectral norm ‖X‖_2
X = np.einsum("i, mi, ni -> mn", œÉ, U[:, :K], V[:, :K])
X = torch.tensor(X).double()
H = torch.randn(M, N).double()
u = torch.tensor(U[:, 0])
v = torch.tensor(V[:, 0])
gt = torch.outer(u, v)  # exact gradient ∂‖X‖_2/∂X = u_1 v_1ᵀ
# u_1 v_1ᵀ is the gradient only if σ_1 is a simple singular value. Ending the cell on
# a statement also keeps the 64×128 `gt` from being auto-displayed (see cell 1 config).
assert œÉ[0] > œÉ[1], "largest singular value is not simple; gradient undefined"

### Forward and backward error of each method

In [7]:
# Score every implementation against the SVD ground truth:
#   forward error  = |σ_1 - σ̂|
#   backward error = RMS of (u_1 v_1ᵀ - ∇_A σ̂)
methods = dict(
    norm=lambda mat: torch.linalg.norm(mat, ord=2),
    matrix_norm=lambda mat: torch.linalg.matrix_norm(mat, ord=2),
    svdvals=lambda mat: torch.linalg.svdvals(mat)[0],
    spectral_norm=spectral_norm,
    SpectralNorm=SpectralNorm.apply,
)

for label, fn in methods.items():
    param = torch.nn.Parameter(X.clone(), requires_grad=True)
    estimate = fn(param)
    estimate.backward()
    fward_error = (œÉt - estimate).abs().item()
    bward_error = (gt - param.grad).pow(2).mean().sqrt().item()
    print(f"{fward_error:.4e}  {bward_error:.4e}", label)

## Speed Tests

### with norm

In [8]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
œÉ_max = torch.linalg.norm(X, ord=2)
œÉ_max.backward()

### with matrix_norm

In [9]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
œÉ_max = torch.linalg.matrix_norm(X, ord=2)
œÉ_max.backward()

### with svdvals

In [10]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
œÉ_max = torch.linalg.svdvals(X)[0]
œÉ_max.backward()

### with spectral_norm

In [11]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
œÉ_max = spectral_norm(X)
œÉ_max.backward()

### with SpectralNorm

In [12]:
%%timeit -r 10 -n 10
X = torch.nn.Parameter(torch.randn(M, N), requires_grad=True)
œÉ_max = SpectralNorm.apply(X)
œÉ_max.backward()