# Implicit Layers
Normally, a layer is expressed as the application of some explicit function $f: \mathcal{X} \to \mathcal{Z}$. That is,
\begin{equation*}
    z = f(x)
\end{equation*}
Instead, we want to consider an implicit function $g: \mathcal{X} \times \mathcal{Z} \to \mathbb{R}^n$, where the output $z$ of the layer is constrained to be a root of $g$, that is,
\begin{equation*}
    g(x, z) = 0
\end{equation*}

One advantage of this formulation is that the definition of the layer says nothing about how $z$ is actually computed, whereas an explicit layer fixes the computation. More fundamentally, this separates the computation of the layer from the definition of the layer. Implicit layers can also be far less memory intensive, as there are no intermediate hidden states to store for backpropagation. Indeed, gradients can be computed directly from the implicit function theorem.
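
As a sketch of how the implicit function theorem yields gradients (the notation $z^*(x)$ for the root as a function of the input is introduced here only for illustration), differentiate the constraint $g(x, z^*(x)) = 0$ with respect to $x$:
\begin{equation*}
    \frac{\partial g}{\partial x} + \frac{\partial g}{\partial z} \frac{\partial z^*}{\partial x} = 0
    \quad \Longrightarrow \quad
    \frac{\partial z^*}{\partial x} = -\left( \frac{\partial g}{\partial z} \right)^{-1} \frac{\partial g}{\partial x}
\end{equation*}
This holds provided $\frac{\partial g}{\partial z}$ is invertible at the solution. The right-hand side depends only on the final $z^*$, not on the intermediate steps used to find it.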

## Fixed point iteration layer
Suppose we have an input and an output $x, z \in \mathbb{R}^n$ and model weights $W \in \mathbb{R}^{n\times n}$. Then consider a fixed point iteration layer where we initialize $z := 0$ and repeat the following until convergence:
\begin{equation*}
z := \tanh{(Wz + x)}
\end{equation*}

If the iteration converges, we reach a fixed point $z^*$ such that $z^* = \tanh{(Wz^* + x)}$. In this case, the implicit representation is as follows.
\begin{equation*}
g(x, z) = z - \tanh{(Wz + x)}
\end{equation*}
For now, we will assume that the iteration converges for typical values of $W$.
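
One sufficient condition for convergence, which the text above does not state and is added here as an assumption, is that the spectral norm of $W$ is below 1: since $\tanh$ is 1-Lipschitz, the update is then a contraction and the iteration converges to a unique fixed point. The sketch below rescales a random matrix to spectral norm 0.5 and checks this numerically; `W_small` and `iterate` are hypothetical names used only for illustration.

```python
import torch

torch.manual_seed(0)
n = 50
W = torch.randn(n, n, dtype=torch.float64)
# Assumption for this sketch: rescale W so its spectral norm
# (largest singular value) is 0.5, which makes the update a contraction.
W_small = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)
x = torch.randn(n, dtype=torch.float64)


def iterate(W, x, tol=1e-8, max_iter=500):
    """Run z := tanh(Wz + x) from z = 0; return the final z and the step count."""
    z = torch.zeros_like(x)
    for k in range(1, max_iter + 1):
        z_next = torch.tanh(W @ z + x)
        err = torch.norm(z_next - z)
        z = z_next
        if err < tol:
            break
    return z, k


z_star, iters = iterate(W_small, x)
residual = torch.norm(z_star - torch.tanh(W_small @ z_star + x))
print(f"spectral norm: {torch.linalg.matrix_norm(W_small, ord=2).item():.2f}")
print(f"stopped after {iters} iterations, residual {residual.item():.2e}")
```

A trained $W$ need not satisfy this condition, so it is only a sanity check, not a guarantee for the layer below.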

In [1]:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TanhFixedPointLayer(nn.Module):
    def __init__(
        self, out_features: int, tol: float = 1e-4, max_iter: int = 50
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False).to(device)
        self.tol = tol
        self.max_iter = max_iter

    def forward(self, x) -> torch.Tensor:
        z = torch.zeros_like(x).to(device)
        self.iterations = 0

        while self.iterations < self.max_iter:
            z_next = torch.tanh(self.linear(z) + x)
            self.err = torch.norm(z - z_next)
            z = z_next
            self.iterations += 1
            if self.err < self.tol:
                break

        return z


In [2]:
layer = TanhFixedPointLayer(50)
X = torch.randn(10, 50).to(device)
Z = layer(X)
print(f"terminated after {layer.iterations} iterations with error {layer.err}")


terminated after 14 iterations with error 4.8777183110360056e-05


Let's use this layer on MNIST data instead of random data.

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

mnist_train = datasets.MNIST(
    "data/", train=True, download=True, transform=transforms.ToTensor()
)
mnist_test = datasets.MNIST(
    "data/", train=False, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)


In [4]:
import torch.optim as optim

torch.manual_seed(0)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 100),
    TanhFixedPointLayer(100, max_iter=200),
    nn.Linear(100, 10),
).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)


In [5]:
!jupyter nbextension enable --py widgetsnbextension
from typing import Tuple
from tqdm.notebook import tqdm


def epoch(
    loader: DataLoader, model: nn.Module, opt: optim.Optimizer = None, monitor=None
) -> Tuple[float, float, float]:
    total_loss, total_err, total_monitor = 0.0, 0.0, 0.0
    model.eval() if opt is None else model.train()

    for X, y in tqdm(loader, leave=False):
        X, y = X.to(device), y.to(device)
        yp = model(X)
        loss = nn.CrossEntropyLoss()(yp, y)
        if opt:
            opt.zero_grad()
            loss.backward()
            if sum(torch.sum(torch.isnan(p.grad)) for p in model.parameters()) == 0:
                opt.step()

        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
        if monitor is not None:
            total_monitor += monitor(model)

    return (
        total_err / len(loader.dataset),
        total_loss / len(loader.dataset),
        total_monitor / len(loader),
    )


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: OK


In [6]:
for i in range(10):
    if i == 5:
        opt.param_groups[0]["lr"] = 1e-2

    train_err, train_loss, train_fpiter = epoch(
        train_loader, model, opt, lambda x: x[2].iterations
    )
    test_err, test_loss, test_fpiter = epoch(
        test_loader, model, monitor=lambda x: x[2].iterations
    )
    print(
        f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, FP Iters: {train_fpiter:.2f} | "
        + f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, FP Iters: {test_fpiter:.2f}"
    )


Train Error: 0.1113, Loss: 0.4034, FP Iters: 53.40 | Test Error: 0.0717, Loss: 0.2418, FP Iters: 56.55
Train Error: 0.0574, Loss: 0.1941, FP Iters: 53.09 | Test Error: 0.0498, Loss: 0.1639, FP Iters: 53.86
Train Error: 0.0437, Loss: 0.1479, FP Iters: 56.71 | Test Error: 0.0459, Loss: 0.1469, FP Iters: 57.44
Train Error: 0.0364, Loss: 0.1223, FP Iters: 64.10 | Test Error: 0.0384, Loss: 0.1335, FP Iters: 60.41
Train Error: 0.0312, Loss: 0.1044, FP Iters: 77.47 | Test Error: 0.0379, Loss: 0.1218, FP Iters: 83.96
Train Error: 0.0214, Loss: 0.0743, FP Iters: 80.36 | Test Error: 0.0318, Loss: 0.1057, FP Iters: 78.79
Train Error: 0.0196, Loss: 0.0685, FP Iters: 76.08 | Test Error: 0.0313, Loss: 0.1033, FP Iters: 74.38
Train Error: 0.0187, Loss: 0.0655, FP Iters: 77.63 | Test Error: 0.0306, Loss: 0.1051, FP Iters: 75.75
Train Error: 0.0180, Loss: 0.0633, FP Iters: 79.81 | Test Error: 0.0300, Loss: 0.1045, FP Iters: 78.45
Train Error: 0.0174, Loss: 0.0612, FP Iters: 80.81 | Test Error: 0.0304, Loss: 0.1035, FP Iters: 79.67


This is kinda disappointing. We achieve passable results, but at a considerable speed loss. Since so many iterations are being done, we're essentially running a 50-80 layer MLP to get performance on par with a single-layer one. However, there are key differences (one of which is the lack of exploding gradients) that will become clearer once we introduce a few more ideas.

## Alternative root-finding techniques
Above, we simply iterated the equation that defined the layer. However, we can use a faster root-finding method, such as Newton's method, to improve upon this dramatically. For some function $g: \mathbb{R}^n \to \mathbb{R}^n$, we want to find a root $z$ such that $g(z) = 0$. To do so, we repeat the update
\begin{equation*}
z := z - \left( \frac{\partial g}{\partial z}\right)^{-1} g(z)
\end{equation*}
where $\frac{\partial g}{\partial z}$ is the Jacobian of $g$ with respect to $z$. We could use automatic differentiation to compute it, but in this case it is simple to express the Jacobian in closed form.
\begin{gather*}
g(x, z) = z - \tanh{(Wz + x)}\\
\frac{\partial g}{\partial z} = I - \text{diag}(\text{sech}^2(Wz + x))W
\end{gather*}
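
As a quick sanity check on the closed form, the sketch below compares it against the Jacobian computed by automatic differentiation for a single small, randomly chosen $W$, $x$, and $z$. The sizes and the names `J_closed` and `J_auto` are illustrative assumptions, not part of the layer itself.

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.float64)
x = torch.randn(n, dtype=torch.float64)
z = torch.randn(n, dtype=torch.float64)


def g(z):
    # Residual of the fixed point equation for a single sample.
    return z - torch.tanh(W @ z + x)


# Closed form: I - diag(sech^2(Wz + x)) W, using sech = 1 / cosh.
sech2 = 1 / torch.cosh(W @ z + x) ** 2
J_closed = torch.eye(n, dtype=torch.float64) - torch.diag(sech2) @ W

# Jacobian of g with respect to z from autograd.
J_auto = torch.autograd.functional.jacobian(g, z)

print(torch.allclose(J_closed, J_auto))
```

The layer below builds the same matrix per sample with broadcasting instead of `torch.diag`.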

In [37]:
class TanhNewtonLayer(nn.Module):
    def __init__(
        self, out_features: int, tol: float = 1e-4, max_iter: int = 50
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False).to(device)
        self.tol = tol
        self.max_iter = max_iter

    def forward(self, x) -> torch.Tensor:
        z = torch.tanh(x)
        self.iterations = 0

        while self.iterations < self.max_iter:
            z_linear = self.linear(z) + x
            g = z - torch.tanh(z_linear)
            self.err = torch.norm(g)
            if self.err < self.tol:
                break

            J = (
                torch.eye(z.shape[1]).to(device)[None, :, :]
                - (1 / torch.cosh(z_linear) ** 2)[:, :, None]
                * self.linear.weight[None, :, :]
            )
            z = z - torch.linalg.solve(J, g)
            self.iterations += 1

        g = z - torch.tanh(self.linear(z) + x)
        z[torch.norm(g, dim=1) > self.tol, :] = 0
        return z


In [38]:
layer = TanhNewtonLayer(50)
X = torch.randn(10, 50).to(device)
Z = layer(X)
print(f"iters: {layer.iterations}, err: {layer.err}")

iters: 3, err: 9.452685389987892e-07


This converges faster, but note that we are solving a linear system at every iteration, which is significantly more expensive. We can still plug it into the same training loop as before and see what happens.

In [39]:
torch.manual_seed(0)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 100),
    TanhNewtonLayer(100, max_iter=40),
    nn.Linear(100, 10),
).to(device)
opt = optim.SGD(model.parameters(), lr=1e-1)

for i in range(8):
    if i == 5:
        opt.param_groups[0]["lr"] = 1e-2

    train_err, train_loss, train_fpiter = epoch(
        train_loader, model, opt, lambda x: x[2].iterations
    )
    test_err, test_loss, test_fpiter = epoch(
        test_loader, model, monitor=lambda x: x[2].iterations
    )
    print(
        f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, Newton Iters: {train_fpiter:.2f} | "
        + f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, Newton Iters: {test_fpiter:.2f}"
    )


Train Error: 0.1153, Loss: 0.4183, Newton Iters: 6.91 | Test Error: 0.0769, Loss: 0.2601, Newton Iters: 6.83
Train Error: 0.0606, Loss: 0.2084, Newton Iters: 6.77 | Test Error: 0.0523, Loss: 0.1727, Newton Iters: 6.84
Train Error: 0.0474, Loss: 0.1597, Newton Iters: 6.93 | Test Error: 0.0472, Loss: 0.1543, Newton Iters: 6.66
Train Error: 0.0389, Loss: 0.1300, Newton Iters: 6.98 | Test Error: 0.0427, Loss: 0.1466, Newton Iters: 7.98
Train Error: 0.0406, Loss: 0.1403, Newton Iters: 7.42 | Test Error: 0.0389, Loss: 0.1298, Newton Iters: 6.70
Train Error: 0.0276, Loss: 0.0939, Newton Iters: 6.92 | Test Error: 0.0349, Loss: 0.1151, Newton Iters: 6.90
Train Error: 0.0251, Loss: 0.0870, Newton Iters: 7.18 | Test Error: 0.0337, Loss: 0.1117, Newton Iters: 7.08
Train Error: 0.0238, Loss: 0.0829, Newton Iters: 7.51 | Test Error: 0.0339, Loss: 0.1124, Newton Iters: 7.98


Like the above, this also works. However, it is even slower, despite requiring dramatically fewer iterations. This is because it must solve a linear system with the full $n \times n$ Jacobian on every iteration for each sample in each minibatch, which costs $O(n^3)$ per sample with a dense solver. As the hidden size increases, this computation becomes increasingly expensive to the point of being intractable (taking inverses sucks). Quasi-Newton methods, which avoid forming and factoring the exact Jacobian, improve on this aspect.
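
To make the quasi-Newton idea concrete, here is a minimal single-sample sketch of Broyden's method, one such scheme. The text does not say which method it has in mind, so this choice, the small random `W`, and the helper name `broyden` are all assumptions for illustration. Instead of solving with the true Jacobian, it keeps a running estimate `H` of the inverse Jacobian and refines it with a cheap rank-one update at each step.

```python
import torch

torch.manual_seed(0)
n = 20
# Assumption: a small random W so that the problem is well conditioned.
W = 0.05 * torch.randn(n, n, dtype=torch.float64)
x = torch.randn(n, dtype=torch.float64)


def g(z):
    return z - torch.tanh(W @ z + x)


def broyden(g, z0, tol=1e-8, max_iter=100):
    """Find a root of g with Broyden's ("good") method for a single sample."""
    z = z0.clone()
    gz = g(z)
    # Estimate of the inverse Jacobian, initialized to the identity.
    H = torch.eye(z.numel(), dtype=z.dtype)
    for k in range(max_iter):
        if torch.norm(gz) < tol:
            return z, k
        dz = -H @ gz  # quasi-Newton step, no linear solve needed
        z_new = z + dz
        g_new = g(z_new)
        dg = g_new - gz
        H_dg = H @ dg
        # Sherman-Morrison form of Broyden's rank-one update of H.
        H = H + torch.outer(dz - H_dg, dz @ H) / (dz @ H_dg)
        z, gz = z_new, g_new
    return z, max_iter


z_star, iters = broyden(g, torch.zeros(n, dtype=torch.float64))
print(f"iters: {iters}, residual: {torch.norm(g(z_star)).item():.2e}")
```

Each step costs $O(n^2)$ rather than the $O(n^3)$ of a dense solve, at the price of needing more iterations than exact Newton.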

Another issue is that, since Newton's method is implemented directly inside the autograd system, the intermediate states of the Jacobian and the solver must also be stored for the backward pass, which increases memory use. It is also numerically unstable when the Jacobian is near-singular (this is why the NaN check exists in the `epoch` function).
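
The near-singular case can be illustrated with a tiny hand-picked example (an assumption for illustration, not taken from the training run above). With $x = 0$ and $z = 0$ we have $\text{sech}^2(0) = 1$, so the Jacobian reduces to $I - W$; choosing $W = \text{diag}(w, 0)$ gives a condition number of $1 / (1 - w)$, which blows up as $w$ approaches 1.

```python
import torch


def newton_jacobian(W, z, x):
    # I - diag(sech^2(Wz + x)) W for a single sample.
    pre = W @ z + x
    return torch.eye(len(z), dtype=W.dtype) - torch.diag(1 / torch.cosh(pre) ** 2) @ W


z = torch.zeros(2, dtype=torch.float64)
x = torch.zeros(2, dtype=torch.float64)

conds = []
for w in [0.5, 0.9, 0.99, 0.999]:
    W = torch.diag(torch.tensor([w, 0.0], dtype=torch.float64))
    J = newton_jacobian(W, z, x)
    conds.append(torch.linalg.cond(J).item())
    print(f"w = {w}: cond(J) = {conds[-1]:.1f}")
```

A large condition number means small errors in $g$ are amplified in the Newton step, which is one way the update can produce very large or non-finite values.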