# Invertible layers

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

from typing import Any

import torch
import torch.utils.cpp_extension
from torch import Tensor, jit, nn
from torch.linalg import matrix_norm, vector_norm

In [None]:
class SpectralNorm(torch.autograd.Function):
    r"""$‖A‖_2=λ_\max(A^⊤A)$.

    The spectral norm $∥A∥_2 ≔ \sup_x ∥Ax∥_2 / ∥x∥_2$ can be shown to be equal to
    $σ_{\max}(A) = \sqrt{λ_{\max} (A^⊤A)}$, the largest singular value of $A$.

    It can be computed efficiently via Power iteration.

    One can show that the derivative is equal to:

    .. math::  \pdv{½∥A∥_2}{A} = uv^⊤

    where $u,v$ are the left/right-singular vector corresponding to $σ_\max$

    References
    ----------
    - | `Spectral Normalization for Generative Adversarial Networks
        <https://openreview.net/forum?id=B1QRgziT->`_
      | Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
      | `International Conference on Learning Representations 2018
        <https://iclr.cc/Conferences/2018>`_
    """

    @staticmethod
    def forward(ctx: Any, *tensors: Tensor, **kwargs: Any) -> Tensor:
        r""".. Signature:: ``(m, n) -> 1``."""
        A = tensors[0]
        atol: float = kwargs["atol"] if "atol" in kwargs else 1e-6
        rtol: float = kwargs["rtol"] if "rtol" in kwargs else 1e-6
        maxiter: int = kwargs["maxiter"] if "maxiter" in kwargs else 1000
        m, n, *other = A.shape
        assert not other, "Expected 2D input."
        # initialize u and v, median should be useful guess.
        u = u_next = A.median(dim=1).values
        v = v_next = A.median(dim=0).values
        σ: Tensor = torch.einsum("ij, i, j ->", A, u, v)

        for _ in range(maxiter):
            u = u_next / torch.norm(u_next)
            v = v_next / torch.norm(v_next)
            # choose optimal σ given u and v: σ = argmin ‖A - σuvᵀ‖²
            σ = torch.einsum("ij, i, j ->", A, u, v)  # u.T @ A @ v
            # Residual: if Av = σu and Aᵀu = σv
            u_next = A @ v
            v_next = A.T @ u
            σu = σ * u
            σv = σ * v
            ru = u_next - σ * u
            rv = v_next - σ * v
            if (
                vector_norm(ru) <= rtol * vector_norm(σu) + atol
                and vector_norm(rv) <= rtol * vector_norm(σv) + atol
            ):
                break

        ctx.save_for_backward(u, v)
        return σ

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tensor:
        u, v = ctx.saved_tensors
        return torch.einsum("..., i, j -> ...ij", grad_outputs[0], u, v)

    vjp = backward

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Jacobian-vector product forward mode."""
        u, v = ctx.saved_tensors
        return torch.einsum("...ij, i, j -> ...", grad_inputs[0], u, v)

In [None]:
# TODO: write a single-layer iResNet in C++ (e.g. as a torch.utils.cpp_extension),
# exposing these entry points:
#   - forward
#   - backward
#   - inverse
#   - inverse_backward

## iResNet layer

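Consider any module with forward pass $y = f(x) = x + g(x)$, where $g = g_L ∘ g_{L-1} ∘ … ∘ g_1$ is a contraction, $\text{Lip}(g) < 1$. This holds e.g. if every layer satisfies $\text{Lip}(g_k) < 1$, since $\text{Lip}(g) ≤ ∏_k \text{Lip}(g_k)$.
Then $x = y - g(x)$ is a fixed-point equation. By the Banach fixed-point theorem it has a unique solution, and fixed-point iteration converges to it:

$$ x_{k+1} = y - g(x_k) $$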


Alternatively, we could solve it with gradient descent. Question: is the fixed-point iteration equivalent to some gradient-descent scheme? Rewriting the update:


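$$ x_{k+1} = y - g(x_k) = x_k - \big(x_k + g(x_k) - y\big) = x_k - \big(f(x_k) - y\big) $$

This is a gradient-descent step with unit step size on a potential $Φ$ satisfying $∇Φ(x) = f(x) - y$. Such a potential exists only if $𝐃[f]$ is symmetric, i.e. if $f$ is itself a gradient field. In general, therefore, the fixed-point iteration is *not* gradient descent on any potential.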

## Backward pass for the iResNet inverse

So we computed $x(y) = f^{-1}(y)$ via fixed-point iteration in the inverse pass. What is its derivative? By the inverse function theorem:

$$𝐃[f^{-1}](y) = \big(𝐃[f](x)\big)^{-1}$$

Using this fact, we can compute vector-Jacobian products (VJPs) as

$$ \big[∆y ↦ ⟨ v ∣ \big(𝐃[f](x)\big)^{-1} ∆y⟩\big] = \big[∆y ↦ ⟨ \big(𝐃[f](x)\big)^{-⊤} v ∣ ∆y ⟩\big] $$

Hence, with $x = f^{-1}(y)$, the VJP is $w = \big(𝐃[f](x)\big)^{-⊤} v$, i.e. the solution of the linear system $𝐃[f](x)^⊤ w = v$.

The big question: How do we get the transpose?!

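⟹ All sublayers must make the transpose available! Write $f = f_1 ∘ f_2 ∘ … ∘ f_n$. The chain rule then gives the following, with each $𝐃[f_i]$ evaluated at the corresponding intermediate point:

$$ 𝐃[f](x)^⊤ v = 𝐃[f_1∘f_2∘…∘f_n]^⊤ v = \big(𝐃[f_1]\,𝐃[f_2] ⋯ 𝐃[f_n]\big)^⊤ v = 𝐃[f_n]^⊤ ⋯ 𝐃[f_1]^⊤ v $$

But this is just the VJP of $f$! Reverse-mode autodiff already provides it, so the transpose never has to be formed explicitly.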


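Thus, the goal becomes: find $w$ such that

$$ \text{VJP}(f, x, w) = 𝐃[f](x)^⊤ w = v $$

This linear system only requires the map $w ↦ \text{VJP}(f, x, w)$, so any matrix-free iterative solver can be used. One option is the fixed-point iteration $w ← v - 𝐃[g](x)^⊤ w$, which converges since $\text{Lip}(g) < 1$; another is a Krylov method such as GMRES.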


However, we need a solver library that works with general tensor-shaped data, not just flat vectors.

In [None]:
class iResNet(nn.Sequential):
    r"""Stack of residual blocks $x ↦ x + g_k(x)$ with a fixed-point inverse.

    Each wrapped module ``g_k`` is applied as a residual block, so the forward
    map is $f = f_L ∘ … ∘ f_1$ with $f_k(x) = x + g_k(x)$. The inverse and the
    inverse VJP use fixed-point iterations, which converge when every $g_k$ is
    a contraction ($\operatorname{Lip}(g_k) < 1$); this is not checked here.
    """

    def __init__(self, *layers: nn.Module) -> None:
        # nn.Module.__init__ must run before any submodule can be registered.
        super().__init__(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply every block as a residual update ``x ← x + g_k(x)``."""
        for layer in self:
            x = x + layer(x)
        return x

    def inverse(self, y: Tensor, maxiter: int = 10) -> Tensor:
        """Invert ``forward`` via fixed point iteration.

        Blocks are inverted last-to-first; each solves ``x = y_k - g_k(x)``
        with ``maxiter`` fixed-point steps starting from ``x = y_k``.
        """
        for layer in list(self)[::-1]:
            x = y
            for _ in range(maxiter):
                x = y - layer(x)
            y = x
        return y

    def vjp_inverse(self, outer_grad: Tensor, ctx: dict, maxiter: int = 10) -> Tensor:
        r"""VJP of the inverse map, evaluated at $x$ = ``ctx["x"]`` $= f^{-1}(y)$.

        Returns $w = 𝐃[f](x)^{-⊤} v$ with $v$ = ``outer_grad``. Since
        $𝐃[f] = J_L ⋯ J_1$, we have $w = J_L^{-⊤} ⋯ J_1^{-⊤} v$, so the block
        solves are applied first-to-last. Each $J_k^{-⊤} u$ solves
        $(I + 𝐃[g_k]^⊤) w = u$ via the fixed-point iteration
        $w ← u - 𝐃[g_k]^⊤ w$, using only VJPs of $g_k$.
        """
        # Recompute the input seen by each block during the forward pass.
        block_inputs = []
        x = ctx["x"]
        with torch.no_grad():
            for layer in self:
                block_inputs.append(x)
                x = x + layer(x)

        w = outer_grad
        for layer, x_k in zip(self, block_inputs):
            rhs = w
            for _ in range(maxiter):
                _, g_vjp = torch.autograd.functional.vjp(layer, x_k, w)
                w = rhs - g_vjp
        return w

In [None]:
def anderson(f, x0, m=5, lam=1e-4, max_iter=50, tol=1e-2, beta=1.0):
    """Anderson acceleration for the fixed-point problem ``x = f(x)``.

    Args:
        f: Map applied to tensors shaped like ``x0``; must return the same shape.
        x0: Initial guess. The first dimension is the batch dimension; all
            remaining dimensions are flattened internally.
        m: Number of past iterates kept in the history (must be at least 2).
        lam: Tikhonov regularization added to the Gram matrix of residuals.
        max_iter: Maximum iteration count, including the two initial plain
            fixed-point steps.
        tol: Stop once the relative residual ``‖f(x) - x‖ / (1e-5 + ‖f(x)‖)``
            (computed over the whole batch) drops below this value.
        beta: Mixing parameter; ``beta=1`` uses only the mixed ``f`` values,
            smaller values blend in the mixed iterates.

    Returns:
        Tuple of the final iterate (same shape as ``x0``) and the list of
        relative residuals recorded at each accelerated step.
    """
    bsz = x0.shape[0]
    flat_x0 = x0.reshape(bsz, -1)
    dim = flat_x0.shape[1]
    X = torch.zeros(bsz, m, dim, dtype=x0.dtype, device=x0.device)
    F = torch.zeros_like(X)
    # Seed the history with two plain fixed-point steps.
    X[:, 0], F[:, 0] = flat_x0, f(x0).reshape(bsz, -1)
    X[:, 1], F[:, 1] = F[:, 0], f(F[:, 0].reshape(x0.shape)).reshape(bsz, -1)

    # Bordered system [[0, 1ᵀ], [1, GGᵀ + λI]] [μ; α] = [1; 0] enforces Σα = 1.
    H = torch.zeros(bsz, m + 1, m + 1, dtype=x0.dtype, device=x0.device)
    H[:, 0, 1:] = H[:, 1:, 0] = 1
    y = torch.zeros(bsz, m + 1, 1, dtype=x0.dtype, device=x0.device)
    y[:, 0] = 1

    res = []
    k = 1  # index of the latest iterate if the loop below does not run
    for k in range(2, max_iter):
        n = min(k, m)
        G = F[:, :n] - X[:, :n]
        H[:, 1 : n + 1, 1 : n + 1] = (
            torch.bmm(G, G.transpose(1, 2))
            + lam * torch.eye(n, dtype=x0.dtype, device=x0.device)[None]
        )
        # torch.solve is deprecated/removed; torch.linalg.solve takes (A, B).
        alpha = torch.linalg.solve(H[:, : n + 1, : n + 1], y[:, : n + 1])[
            :, 1 : n + 1, 0
        ]  # (bsz x n)

        X[:, k % m] = (
            beta * (alpha[:, None] @ F[:, :n])[:, 0]
            + (1 - beta) * (alpha[:, None] @ X[:, :n])[:, 0]
        )
        F[:, k % m] = f(X[:, k % m].reshape(x0.shape)).reshape(bsz, -1)
        res.append(
            (F[:, k % m] - X[:, k % m]).norm().item()
            / (1e-5 + F[:, k % m].norm().item())
        )
        if res[-1] < tol:
            break
    return X[:, k % m].reshape(x0.shape), res

In [None]:
from typing import Final

import torch.autograd as autograd


class DEQFixedPoint(nn.Module):
    r"""Equilibrium layer returning the fixed point $z^* = x - f(z^*)$.

    The forward solve runs without autograd. A gradient hook on the output
    replaces the incoming gradient with the implicit-function-theorem gradient
    $g = (I - ∂h/∂z)^{-⊤}\,\text{grad}$ for $h(z) = x - f(z)$, computed by
    fixed-point iteration. Both iterations converge when $f$ is a contraction.
    """

    maxiter: Final[int]

    def __init__(self, maxiter: int = 5, dim: int = 5):
        super().__init__()
        self.f = nn.Linear(dim, dim)
        # Number of fixed-point steps, used both in forward and in the backward solve.
        self.maxiter = maxiter

    def forward(self, x: Tensor) -> Tensor:
        # compute forward pass without recording the iterations on the tape
        with torch.no_grad():
            z = x.clone()
            for _ in range(self.maxiter):
                z = x - self.f(z)

        # re-engage tape with one final step so gradients reach x and f's parameters
        z = x - self.f(z)

        # A hook can only be registered on tensors that require grad
        # (not the case e.g. under torch.no_grad()).
        if z.requires_grad:
            # Separate one-step graph from a detached copy of the equilibrium,
            # used for vector-Jacobian products w.r.t. z in the backward pass.
            z0 = z.clone().detach().requires_grad_()
            f0 = x - self.f(z0)
            z.register_hook(lambda grad: self.custom_backward(grad, z0, f0))
        return z

    def vjp_f(self, z0: Tensor, y: Tensor) -> tuple:
        """Return ``(f(z0), yᵀ ∂f/∂z(z0))`` via ``torch.autograd.functional.vjp``."""
        return autograd.functional.vjp(self.f, z0, y)

    def custom_backward(self, grad: Tensor, z0: Tensor, f0: Tensor) -> Tensor:
        r"""Solve $g = \text{grad} + (∂h/∂z)^⊤ g$ with ``self.maxiter`` fixed-point steps.

        ``f0`` must be $h(z0) = x - f(z0)$ computed from ``z0`` with autograd enabled.
        """
        g = grad
        for _ in range(self.maxiter):
            (step,) = autograd.grad(f0, z0, g, retain_graph=True)
            g = step + grad
        return g

In [None]:
# Instantiate the DEQ demo module with its default arguments.
model = DEQFixedPoint()

In [None]:
# Forward pass on a random, unbatched length-5 input.
model(torch.randn(5))

In [None]:
# Try compiling the module with TorchScript.
# NOTE(review): forward uses torch.no_grad() and registers a Python gradient hook;
# confirm TorchScript can compile these before relying on the scripted module.
jit.script(model)

In [None]:
# Evaluates (without calling) the bound `register_hook` method of a fresh tensor,
# presumably to inspect it interactively.
torch.randn(5).register_hook

In [None]:
# Apply nn.Linear(4, 5) to an unbatched 1-D input with 4 features.
nn.Linear(4, 5)(torch.randn(4))

## DEMO

## Example: spectral normalization layer


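Consider $y = x + \frac{A}{‖A‖_2}x$. Note that $x ↦ \frac{A}{‖A‖_2}x$ has Lipschitz constant exactly $1$, so it is not a strict contraction. To guarantee convergence of the fixed-point inverse, a factor $c < 1$ is needed: $y = x + c\,\frac{A}{‖A‖_2}x$.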

# Example: Linear Solver Layer


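Consider $f:(A, b) ↦ x = \text{solve}(A, b) = A^{-1}b$.

Differentiating $Ax = b$ gives $dx = A^{-1}(db - dA\,x)$. Hence $\frac{∂x_i}{∂A_{jk}} = -(A^{-1})_{ij}\,x_k$ and $\frac{∂f}{∂b} = \text{solve}(A, 𝕀) = A^{-1}$. NumPy solves a 3-D right-hand side as a stack of matrices along its second axis. Under that batched convention, the first Jacobian can be written as $\frac{∂f}{∂A} = \text{solve}(A^⊤, -𝕀⊗x)$.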

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# Show NumPy arrays with 4 digits after the decimal point.
np.set_printoptions(precision=4)

In [None]:
# Random 5×5 system A x = b together with its reference solution x.
A, b = np.random.randn(5, 5), np.random.randn(5)
x = np.linalg.solve(A, b)

In [None]:
# Forward-mode Jacobian of x = solve(A, b). jax.jacfwd differentiates w.r.t. the
# first positional argument by default (argnums=0), so g(A, b) is ∂x/∂A with
# shape x.shape + A.shape = (5, 5, 5).
g = jax.jacfwd(jnp.linalg.solve)
g(A, b)

In [None]:
# Closed form of ∂x/∂A for x = solve(A, b): J[i, j, k] = -(A⁻¹)[i, j] · x[k].
# np.linalg.solve treats the 3-D right-hand side B[l, j, k] = -δ_lj x_k as a
# stack of 5×5 matrices and solves along axis 1, which is why A.T is used:
# result[l, i, k] = -(A⁻ᵀ)[i, l] x[k] = -(A⁻¹)[l, i] x[k] = J[l, i, k].
# Should agree with the jax.jacfwd result g(A, b).
np.linalg.solve(A.T, -np.einsum("ij, k -> ijk", np.eye(5), x))