# iResNet-Implementation

In [None]:
# IPython magic: render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'

In [None]:
import tsdm
import warnings
import torch
import math
import torchdiffeq
from torch import nn, Tensor
from torch.nn import GRUCell
import numpy as np
from opt_einsum import contract
from tqdm.auto import trange
from typing import Union, Callable
import scipy
from scipy import stats
import matplotlib.pyplot as plt
from scipy.integrate import odeint

from typing import Union

In [None]:
from tsdm.util import ACTIVATIONS, deep_dict_update, deep_kval_update, scaled_norm

# Display the activation registry (name -> activation constructor); iResNetBlock
# looks activations up here by the name stored in its "activation" hyperparameter.
ACTIVATIONS

In [None]:
class LinearContraction(torch.jit.ScriptModule):
    """Linear layer whose effective weight is rescaled to be a contraction.

    On every call the weight ``W`` is multiplied by ``min(1, c / σ_max(W))``,
    where ``σ_max`` is the spectral norm (largest singular value) of ``W``.
    Hence ``‖φ(x) - φ(y)‖ <= c‖x - y‖`` (the bias cancels in differences).

    Parameters
    ----------
    input_size: int
        Number of input features.
    output_size: int
        Number of output features.
    bias: bool, default True
        Whether to add a learnable bias.
    """

    __constants__ = ["input_size", "output_size"]
    input_size: int
    output_size: int
    weight: Tensor
    bias: Union[Tensor, None]

    def __init__(self, input_size: int, output_size: int, bias: bool = True) -> None:
        super(LinearContraction, self).__init__()
        # TorchScript constants must be built-in Python scalars; sizes drawn with
        # ``np.random.randint`` are numpy integers and would be rejected.
        self.input_size = int(input_size)
        self.output_size = int(output_size)
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-uniform weight (a=sqrt(5)); bias uniform in ±1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        return "input_size={}, output_size={}, bias={}".format(
            self.input_size, self.output_size, self.bias is not None
        )

    @torch.jit.script_method
    def forward(self, input: Tensor, c: float = 0.97) -> Tensor:
        """Apply the rescaled linear map; ``c`` is the target Lipschitz bound."""
        σ_max = torch.linalg.norm(self.weight, ord=2)
        # ``clamp`` keeps the factor on the weight's device; the former
        # ``torch.minimum(..., torch.ones(1))`` mixed in a CPU tensor.
        fac = torch.clamp(c / σ_max, max=1.0)
        return nn.functional.linear(input, fac * self.weight, self.bias)

In [None]:
def test_LinearContraction(
    n_samples: Union[int, None] = 10_000,
    dim_in: Union[int, None] = None,
    dim_out: Union[int, None] = None,
) -> None:
    """
    Tests empirically whether the LinearContraction module is a contraction.

    Two independent standard-normal point clouds are mapped through a freshly
    initialized ``LinearContraction``; the test asserts that no pairwise
    distance grows and plots the distribution of distance ratios.

    Parameters
    ----------
    n_samples: int, optional
        Points per cloud; a random value in [1000, 10000) if falsy.
    dim_in, dim_out: int, optional
        Input / output dimension; random values in [2, 100) if falsy.
    """
    # Cast to built-in ``int``: numpy integers are not valid TorchScript constants.
    n_samples = int(n_samples or np.random.randint(low=1000, high=10_000))
    dim_in = int(dim_in or np.random.randint(low=2, high=100))
    dim_out = int(dim_out or np.random.randint(low=2, high=100))
    x = torch.randn(n_samples, dim_in)
    y = torch.randn(n_samples, dim_in)
    # All n_samples x n_samples pairwise distances: memory grows quadratically.
    distances = torch.cdist(x, y)

    model = LinearContraction(dim_in, dim_out)
    xhat = model(x)
    yhat = model(y)
    latent_distances = torch.cdist(xhat, yhat)

    assert torch.all(latent_distances <= distances)

    scaling_factor = (latent_distances / distances).flatten()
    fig, ax = plt.subplots(figsize=(8, 4), tight_layout=True)
    tsdm.util.visualize_distribution(scaling_factor, ax=ax)
    ax.set_title(
        f"LinearContraction -- Scaling Factor Distribution (samples:{n_samples}, dim-in:{dim_in}, dim-out:{dim_out})"
    )
    ax.set_xlabel(r"$s(x, y) = \frac{\|\phi(x)-\phi(y)\|}{\|x-y\|}$")
    ax.set_ylabel(r"density $p(s\mid x, y)$ where $x_i,y_i\sim \mathcal N(0,1)$")


test_LinearContraction()

In [None]:
class iResNetBlock(torch.jit.ScriptModule):
    """Invertible residual block ``y = x + g(x)``.

    ``g`` (``self.bottleneck``) is two ``LinearContraction`` layers followed by
    an activation. The block is inverted by the fixed-point iteration
    ``x <- y - g(x)``.

    Keyword hyperparameters are merged (via ``deep_dict_update``) into a
    per-instance copy of the class-level defaults ``HP``.
    """

    __constants__ = ["input_size", "output_size", "maxiter"]
    input_size: int
    hidden_size: int
    output_size: int
    maxiter: int
    bias: bool

    # Class-level defaults. Never mutated: every instance works on a deep copy.
    HP = {
        "activation": "ReLU",
        "activation_config": {"inplace": False},
        "bias": True,
        "hidden_size": None,
        "input_size": None,
        "maxiter": 100,
    }

    def __init__(self, input_size: int, **HP):
        super(iResNetBlock, self).__init__()
        from copy import deepcopy

        # Updating the class attribute in place would leak one instance's
        # settings (e.g. hidden_size) into every block created afterwards.
        hp = deepcopy(type(self).HP)
        hp["input_size"] = input_size
        deep_dict_update(hp, HP)
        self.HP = hp

        # Built-in ints: numpy integers are not valid TorchScript constants.
        self.input_size = int(input_size)
        self.output_size = self.input_size
        self.hidden_size = int(
            hp["hidden_size"] or math.ceil(math.sqrt(self.input_size))
        )

        self.maxiter = int(hp["maxiter"])
        self.bias = bool(hp["bias"])

        activation = ACTIVATIONS[hp["activation"]]

        self.bottleneck = nn.Sequential(
            LinearContraction(self.input_size, self.hidden_size, self.bias),
            LinearContraction(self.hidden_size, self.input_size, self.bias),
            activation(**hp["activation_config"]),
        )

    @torch.jit.script_method
    def forward(self, x):
        """n-dim to n-dim: ``x + g(x)``."""

        xhat = x + self.bottleneck(x)

        return xhat

    @torch.jit.script_method
    def inverse(
        self,
        y,
        maxiter: Union[int, None] = None,
        rtol: float = 1e-05,
        atol: float = 1e-08,
    ):
        """Solve ``x + g(x) = y`` by iterating ``x <- y - g(x)``.

        Stops once ``|x_new - x| <= atol + rtol * |x|`` elementwise; otherwise
        warns after ``maxiter`` iterations and returns the last iterate.
        ``maxiter=None`` uses the block's configured ``self.maxiter``.
        Gradients are not blocked: autograd records every iteration.
        """
        n_iter = self.maxiter
        if maxiter is not None:
            n_iter = maxiter

        xhat = y.clone()
        xhat_dash = y.clone()
        residual = torch.zeros_like(y)

        for _ in range(n_iter):
            xhat_dash = y - self.bottleneck(xhat)
            residual = torch.abs(xhat_dash - xhat) - rtol * torch.absolute(xhat)

            if torch.all(residual <= atol):
                return xhat_dash
            xhat = xhat_dash

        warnings.warn(
            f"No convergence in {n_iter} iterations. Max residual:{torch.max(residual)} > {atol}."
        )
        return xhat_dash

In [None]:
def test_iResNetBlock(
    n_samples: Union[int, None] = 1_000,
    input_size: Union[int, None] = None,
    hidden_size: Union[int, None] = None,
) -> None:
    """
    Tests empirically whether the iResNetBlock is indeed invertible.

    Plots the distributions of the left/right inverse errors
    ``‖x - φ⁻¹(φ(x))‖``, ``‖y - φ(φ⁻¹(y))‖`` and of the distances
    ``‖x - φ(x)‖``, ``‖y - φ⁻¹(y)‖`` on standard-normal samples. It then
    asserts that the 99% quantile of both inverse errors is at most 1e-6.

    Parameters
    ----------
    n_samples: int, optional
        Number of samples; a random value in [1000, 10000) if falsy.
    input_size: int, optional
        Block width; a random value in [2, 100) if falsy.
    hidden_size: int, optional
        Bottleneck width; the block's own default when ``None``.
    """
    # The arguments were previously overwritten unconditionally. Cast to
    # built-in ``int``: numpy integers are not valid TorchScript constants.
    n_samples = int(n_samples or np.random.randint(low=1000, high=10_000))
    input_size = int(input_size or np.random.randint(low=2, high=100))
    HP = {} if hidden_size is None else {"hidden_size": int(hidden_size)}

    model = iResNetBlock(input_size, **HP)
    # Report the width the model really uses (formerly a random, unused value).
    hidden_size = model.hidden_size

    x = torch.randn(n_samples, input_size)
    y = torch.randn(n_samples, input_size)

    fx = model(x)
    xhat = model.inverse(fx)

    ify = model.inverse(y)
    yhat = model(ify)

    dist_lmap = tsdm.utils.scaled_norm(x - fx, axis=-1)
    dist_rmap = tsdm.utils.scaled_norm(y - ify, axis=-1)
    err_linverse = tsdm.utils.scaled_norm(x - xhat, axis=-1)
    err_rinverse = tsdm.utils.scaled_norm(y - yhat, axis=-1)

    fig, ax = plt.subplots(
        ncols=2, nrows=2, figsize=(10, 5), tight_layout=True, sharex="row", sharey="row"
    )
    tsdm.utils.visualize_distribution(err_linverse, ax=ax[0, 0])
    tsdm.utils.visualize_distribution(err_rinverse, ax=ax[0, 1])
    tsdm.utils.visualize_distribution(dist_lmap, ax=ax[1, 0])
    tsdm.utils.visualize_distribution(dist_rmap, ax=ax[1, 1])

    ax[0, 0].set_xlabel(r"$r_\text{left}(x) = \|x - \phi^{-1}(\phi(x))\|$")
    ax[0, 0].set_ylabel(r"$p(r_\text{left} \mid x)$ where $x_i \sim \mathcal N(0,1)$")
    ax[0, 1].set_xlabel(r"$r_\text{right}(y) = \|y - \phi(\phi^{-1}(y))\|$")
    ax[0, 1].set_ylabel(r"$p(r_\text{right}\mid y)$ where $y_j \sim \mathcal N(0,1)$")

    ax[1, 0].set_xlabel(r"$d_\text{left}(x) = \|x - \phi(x)\|$")
    ax[1, 0].set_ylabel(r"$p(d_\text{left} \mid x)$ where $x_i \sim \mathcal N(0,1)$")
    ax[1, 1].set_xlabel(r"$d_\text{right}(y) = \|y - \phi^{-1}(y)\|$")
    ax[1, 1].set_ylabel(r"$p(d_\text{right} \mid y)$ where $y_j \sim \mathcal N(0,1)$")
    fig.suptitle(
        f"iResNetBlock -- Inversion property (samples:{n_samples}, dim-in:{input_size}, dim-hidden:{hidden_size})",
        fontsize=16,
    )

    # Assert only after the figure is complete, so a failure still leaves a
    # fully labelled plot for inspection.
    assert torch.quantile(err_linverse, 0.99) <= 10**-6
    assert torch.quantile(err_rinverse, 0.99) <= 10**-6


test_iResNetBlock()

In [None]:
def printgradnorm(self, grad_input, grad_output):
    """Debug hook: print types, sizes and the norm of the gradients a module sees.

    ``self`` is the module the hook is attached to. ``grad_input`` and
    ``grad_output`` are indexable gradient containers; only their first
    entries are inspected.
    """
    owner = self.__class__.__name__

    def emit_type(label, value):
        # Matches ``print(label + ": ", type(value))``, including the double space.
        print(f"{label}:  {type(value)!s}")

    print(f"Inside {owner} backward")
    print(f"Inside class:{owner}")
    print()
    emit_type("grad_input", grad_input)
    emit_type("grad_input[0]", grad_input[0])
    emit_type("grad_output", grad_output)
    emit_type("grad_output[0]", grad_output[0])
    print()
    print(f"grad_input size: {grad_input[0].size()!s}")
    print(f"grad_output size: {grad_output[0].size()!s}")
    print(f"grad_input norm: {grad_input[0].norm()!s}")
    print("grad_input norm:", grad_input[0].norm())

Consider:
- loss $\ell(x, \hat x)$
- $\hat x =  F^{-1}(z, \theta)$ where $F(z) = z + g(z, \theta)$.
    - The inverse solves the fixed point equation $\hat x(z,\theta) = z - g(\hat x(z, \theta), \theta)$
- Then 

$$
\frac{\partial \ell(x, \hat x)}{\partial \theta} 
= \frac{\partial \ell(x, \hat x)}{\partial \hat x}\frac{\partial \hat x}{\partial \theta} 
$$

where, since $\hat x(z,\theta) = z - g(\hat x(z, \theta), \theta)$, implicit differentiation gives

$$
\frac{\partial \hat x}{\partial \theta} 
= - \frac{\partial g}{\partial \hat x}\frac{\partial \hat x}{\partial \theta} - \frac{\partial g}{\partial \theta} 
\implies
\Big(I + \frac{\partial g}{\partial \hat x}\Big)\frac{\partial \hat x}{\partial \theta}  = - \frac{\partial g}{\partial \theta}
$$

Plugging this into the loss we have

$$\begin{aligned}
\frac{\partial \ell(x, \hat x)}{\partial \theta} 
&= \frac{\partial \ell(x, \hat x)}{\partial \hat x}\bigg(-\Big(I + \frac{\partial g}{\partial \hat x}\Big)^{-1}\frac{\partial g}{\partial \theta}\bigg)
\\
&= a^T \frac{\partial g}{\partial \theta} \qquad\text{where }  \Big(I + \frac{\partial g}{\partial \hat x}\Big)^T a  = - \Big(\frac{\partial \ell}{\partial \hat x}\Big)^T
\end{aligned}$$

Since $g$ is a contraction ($\|\partial g/\partial \hat x\| < 1$), the matrix $I + \partial g/\partial \hat x$ is invertible. The vector $a$ can then be computed using only vector–Jacobian products, via the fixed-point iteration $a \leftarrow -\big(\frac{\partial \ell}{\partial \hat x}\big)^T - \big(\frac{\partial g}{\partial \hat x}\big)^T a$.


In [None]:
class inverse_iteration(torch.autograd.Function):
    """Invert ``y = x + g(x)`` by fixed-point iteration, with an implicit backward.

    Usage: ``inverse_iteration.apply(y, bottleneck[, maxiter, rtol, atol])``.

    ``forward`` iterates ``x <- y - g(x)`` (``g = bottleneck``) until
    ``|x_new - x| <= atol + rtol * |x|`` elementwise, or for at most ``maxiter``
    steps. No autograd graph is kept for these iterations.

    ``backward`` applies the implicit-function theorem. With ``J = dg/dx`` at
    the solution, it solves ``(I + J)^T w = grad_output`` by iterating
    ``w <- grad_output - J^T w`` using vector-Jacobian products.
    - ``w`` is returned as the gradient w.r.t. ``y``.
    - ``-w^T dg/dtheta`` is the gradient w.r.t. the bottleneck parameters.

    Both iterations converge when ``g`` is a contraction. The bottleneck's
    parameters are not tensor inputs of this Function, so autograd cannot route
    gradients to them; they are instead accumulated directly into ``p.grad``.
    """

    @staticmethod
    def forward(
        ctx, input, bottleneck, maxiter: int = 1000, rtol: float = 1e-05, atol: float = 1e-08
    ):
        x = input.clone()
        for _ in range(maxiter):
            # fixed point iteration
            x_next = input - bottleneck(x)
            done = bool(torch.all(torch.abs(x_next - x) <= atol + rtol * torch.abs(x)))
            x = x_next
            if done:
                break

        ctx.bottleneck = bottleneck
        ctx.maxiter = maxiter
        ctx.rtol = rtol
        ctx.atol = atol
        ctx.save_for_backward(x)

        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        bottleneck = ctx.bottleneck
        params = [p for p in bottleneck.parameters() if p.requires_grad]

        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            gx = bottleneck(x)

            # Adjoint solve (I + J)^T w = grad_output via w <- grad_output - J^T w.
            w = grad_output.clone()
            for _ in range(ctx.maxiter):
                (jtw,) = torch.autograd.grad(gx, x, w, retain_graph=True)
                w_next = grad_output - jtw
                done = bool(
                    torch.all(torch.abs(w_next - w) <= ctx.atol + ctx.rtol * torch.abs(w))
                )
                w = w_next
                if done:
                    break

            param_grads = (
                torch.autograd.grad(gx, params, -w, allow_unused=True) if params else ()
            )

        for p, dp in zip(params, param_grads):
            if dp is not None:
                p.grad = dp if p.grad is None else p.grad + dp

        # One entry per forward input; trailing Nones for the optional arguments.
        return w, None, None, None, None

In [None]:
# Fixtures for experimenting with a single iResNetBlock.
n_samples = 10_000
# Cast to built-in ``int``: numpy integers are not valid TorchScript constants.
input_size = int(np.random.randint(low=2, high=100))
# NOTE: hidden_size is not passed to the model below (HP is empty).
hidden_size = int(np.random.randint(low=2, high=100))
HP = {}
x = torch.randn(n_samples, input_size)
y = torch.randn(n_samples, input_size)
model = iResNetBlock(input_size, **HP)
modelB = LinearContraction(input_size, input_size)

In [None]:
# Run the block's fixed-point inverse on the random batch `x` (result is displayed, not stored).
model.inverse(x)

In [None]:
# Same inversion, but through the custom autograd Function using the block's bottleneck.
inverse_iteration.apply(x, model.bottleneck)

In [None]:
# Backpropagate a scalar loss through forward -> LinearContraction -> inverse.
# NOTE(review): iResNetBlock.inverse has its no_grad commented out, so autograd
# records (and backpropagates through) every fixed-point iteration.
torch.linalg.norm(model.inverse(modelB(model(x)))).backward()

In [None]:
class iResNet(torch.jit.ScriptModule):
    """Stack of ``iResNetBlock``s: ``F = f_n ∘ ... ∘ f_1`` with ``f_k(x) = x + g_k(x)``.

    ``inverse`` inverts the blocks one at a time, last block first.
    ``alt_inverse`` instead runs a single fixed-point iteration on the whole
    stack to solve ``F(x) = y``.

    Keyword hyperparameters are merged (via ``deep_dict_update``) into a
    per-instance copy of the class-level defaults ``HP``.
    """

    # Class-level defaults. Never mutated: every instance works on a deep copy.
    HP = {
        "maxiter": 10,
        "input_size": None,
        "dropout": None,
        "bias": True,
        "nBlocks": 5,
        "iResNetBlock": {
            "input_size": None,
            "activation": "ReLU",
            "activation_config": {"inplace": False},
            "bias": True,
            "hidden_size": None,
            "maxiter": 100,
        },
    }

    input_size: int
    output_size: int
    nblocks: int

    def __init__(self, input_size, **HP):
        super(iResNet, self).__init__()
        from copy import deepcopy

        # Updating the class attribute in place would leak one model's settings
        # into every model created afterwards.
        hp = deepcopy(type(self).HP)
        hp["input_size"] = input_size
        deep_dict_update(hp, HP)
        self.HP = hp

        # Built-in ints: numpy integers are not valid TorchScript values here.
        self.input_size = int(input_size)
        self.output_size = self.input_size
        hp["iResNetBlock"]["input_size"] = self.input_size

        self.nblocks = int(hp["nBlocks"])
        self.maxiter = int(hp["maxiter"])
        self.bias = hp["bias"]

        self.blocks = nn.Sequential(
            *[iResNetBlock(**hp["iResNetBlock"]) for _ in range(self.nblocks)]
        )

        # The same block objects in reverse order (parameters are shared, not
        # copied), since `reversed` is unavailable inside TorchScript v1.8.1.
        self.reversed_blocks = nn.Sequential(*reversed(self.blocks))

    @torch.jit.script_method
    def forward(self, x):
        """n-dim to n-dim"""

        for block in self.blocks:
            x = block(x)

        return x

    @torch.jit.script_method
    def inverse(
        self,
        y,
        maxiter: Union[int, None] = None,
        rtol: float = 1e-05,
        atol: float = 1e-08,
    ):
        """Invert the stack block by block (last block first) without autograd.

        ``maxiter``, ``rtol`` and ``atol`` are forwarded to every block's
        ``inverse``; ``maxiter=None`` lets each block use its own setting.
        """
        with torch.no_grad():
            for block in self.reversed_blocks:
                # `reversed` does not work in torchscript v1.8.1
                if maxiter is None:
                    y = block.inverse(y, rtol=rtol, atol=atol)
                else:
                    y = block.inverse(y, maxiter=maxiter, rtol=rtol, atol=atol)

        return y

    @torch.jit.script_method
    def alt_inverse(
        self,
        y,
        maxiter: Union[int, None] = None,
        rtol: float = 1e-05,
        atol: float = 1e-08,
    ):
        """Solve ``F(x) = y`` for the whole stack by fixed-point iteration.

        Iterates ``x <- x + (y - F(x))``, i.e. ``x <- y - (F(x) - x)``. This
        converges only if ``F - id`` is a contraction, which the per-block
        bounds do not guarantee for a stack of blocks. If the tolerance
        ``|x_new - x| <= atol + rtol * |x|`` is not met, it warns and returns
        the last iterate. ``maxiter=None`` uses ``self.maxiter``.
        """
        n_iter = self.maxiter
        if maxiter is not None:
            n_iter = maxiter

        xhat = y.clone()
        xhat_dash = y.clone()
        residual = torch.zeros_like(y)

        for _ in range(n_iter):
            # Formerly ``y - F(xhat)``, whose fixed point solves x + F(x) = y.
            xhat_dash = xhat + y - self.forward(xhat)
            residual = torch.abs(xhat_dash - xhat) - rtol * torch.absolute(xhat)

            if torch.all(residual <= atol):
                return xhat_dash
            xhat = xhat_dash

        warnings.warn(
            f"No convergence in {n_iter} iterations. Max residual:{torch.max(residual)} > {atol}."
        )
        return xhat_dash

In [None]:
# Scratch check: len() of an nn.Sequential (three Linear layers here).
len(nn.Sequential(nn.Linear(10, 11), nn.Linear(11, 12), nn.Linear(12, 13)))

In [None]:
from torchinfo import summary

n_samples = 10_000
# Cast to built-in ``int``: numpy integers are not valid TorchScript constants
# (and would print as ``np.int64(...)`` under NumPy 2).
input_size = int(np.random.randint(low=2, high=100))
nBlocks = int(np.random.randint(low=2, high=100))
HP = {"nBlocks": nBlocks}
print(f"{n_samples=}, {input_size=},  {nBlocks=}")
model = iResNet(input_size, **HP)
summary(model)

In [None]:
# Two independent standard-normal batches: x is pushed forward through the
# model, y is pulled back through its inverse in the next cell.
x = torch.randn(n_samples, input_size)
y = torch.randn(n_samples, input_size)

fx = model(x)

In [None]:
# Round trips: x -> F(x) -> F^{-1}(F(x)) and y -> F^{-1}(y) -> F(F^{-1}(y)).
xhat = model.inverse(fx)

ify = model.inverse(y)
yhat = model(ify)

In [None]:
# Per-sample distances of each map from the identity, then the inversion errors.
dist_lmap, dist_rmap, err_linverse, err_rinverse = [
    tsdm.utils.scaled_norm(diff, axis=-1)
    for diff in (x - fx, y - ify, x - xhat, y - yhat)
]

fig, ax = plt.subplots(
    nrows=2, ncols=2, figsize=(10, 5), tight_layout=True, sharex="row", sharey="row"
)
# Top row: inversion errors. Bottom row: distance from the identity.
tsdm.utils.visualize_distribution(err_linverse, ax=ax[0, 0])
tsdm.utils.visualize_distribution(err_rinverse, ax=ax[0, 1])
tsdm.utils.visualize_distribution(dist_lmap, ax=ax[1, 0])
tsdm.utils.visualize_distribution(dist_rmap, ax=ax[1, 1])

ax[0, 0].set(
    xlabel=r"$r_\text{left}(x) = \|x - \phi^{-1}(\phi(x))\|$",
    ylabel=r"$p(r_\text{left} \mid x)$ where $x_i \sim \mathcal N(0,1)$",
)
ax[0, 1].set(
    xlabel=r"$r_\text{right}(y) = \|y - \phi(\phi^{-1}(y))\|$",
    ylabel=r"$p(r_\text{right}\mid y)$ where $y_j \sim \mathcal N(0,1)$",
)
ax[1, 0].set(
    xlabel=r"$d_\text{left}(x) = \|x - \phi(x)\|$",
    ylabel=r"$p(d_\text{left} \mid x)$ where $x_i \sim \mathcal N(0,1)$",
)
ax[1, 1].set(
    xlabel=r"$d_\text{right}(y) = \|y - \phi^{-1}(y)\|$",
    ylabel=r"$p(d_\text{right} \mid y)$ where $y_j \sim \mathcal N(0,1)$",
)
fig.suptitle(
    f"{model.__class__.__name__} -- Inversion property (samples:{n_samples}, dim-in:{input_size})",
    fontsize=16,
)

In [None]:
# Same diagnostics as for `inverse`, but inverting the whole stack at once with
# `alt_inverse` (a single fixed-point iteration on F(x) = y).
fx = model(x)
xhat = model.alt_inverse(fx)

ify = model.alt_inverse(y)
yhat = model(ify)

dist_lmap = tsdm.utils.scaled_norm(x - fx, axis=-1)
dist_rmap = tsdm.utils.scaled_norm(y - ify, axis=-1)
err_linverse = tsdm.utils.scaled_norm(x - xhat, axis=-1)
err_rinverse = tsdm.utils.scaled_norm(y - yhat, axis=-1)

fig, ax = plt.subplots(
    ncols=2, nrows=2, figsize=(10, 5), tight_layout=True, sharex="row", sharey="row"
)
tsdm.utils.visualize_distribution(err_linverse, ax=ax[0, 0])
tsdm.utils.visualize_distribution(err_rinverse, ax=ax[0, 1])
tsdm.utils.visualize_distribution(dist_lmap, ax=ax[1, 0])
tsdm.utils.visualize_distribution(dist_rmap, ax=ax[1, 1])

ax[0, 0].set_xlabel(r"$r_\text{left}(x) = \|x - \phi^{-1}(\phi(x))\|$")
ax[0, 0].set_ylabel(r"$p(r_\text{left} \mid x)$ where $x_i \sim \mathcal N(0,1)$")
ax[0, 1].set_xlabel(r"$r_\text{right}(y) = \|y - \phi(\phi^{-1}(y))\|$")
ax[0, 1].set_ylabel(r"$p(r_\text{right}\mid y)$ where $y_j \sim \mathcal N(0,1)$")

ax[1, 0].set_xlabel(r"$d_\text{left}(x) = \|x - \phi(x)\|$")
ax[1, 0].set_ylabel(r"$p(d_\text{left} \mid x)$ where $x_i \sim \mathcal N(0,1)$")
ax[1, 1].set_xlabel(r"$d_\text{right}(y) = \|y - \phi^{-1}(y)\|$")
ax[1, 1].set_ylabel(r"$p(d_\text{right} \mid y)$ where $y_j \sim \mathcal N(0,1)$")
# The title previously showed `hidden_size`, a stale global from the
# single-block cell that has nothing to do with this iResNet model.
fig.suptitle(
    f"{model.__class__.__name__} -- Inversion property (alt_inverse, samples:{n_samples}, dim-in:{input_size})",
    fontsize=16,
)

Does anyone know how to register a custom backward function for a PyTorch module? While implementing the i-ResNet architecture, I roughly have:

```python
class iResNetBlock(nn.Module):
    def __init__(self, input_size):
        self.bottleneck = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, input_size),
            nn.ReLU(),
        )
        
    def forward(self, x):
        return x + self.bottleneck(x)
    
    def inverse(self, y):
        x = y.clone()
        while not converged:
            # fixed point iteration
            x = y - self.bottleneck(x)
   
        return x
        
    def inverse_backwards():
        pass
    
class iResNet(nn.Module):
    def __init__(self, num_blocks):
        self.blocks = nn.Sequential(*[
            iResNetBlock for k in range(num_blocks)
        ])
        
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
    
    def inverse(self, y):
        for block in reversed(self.blocks):
            y = block.inverse(y)
        return y
```
    
