# `tanh` vs `ReLU` in a small CNN

Section 3.1 of the [AlexNet](https://proceedings.neurips.cc/paper_files/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)
paper claims that a network with `tanh` units takes much longer to reach a 25%
training error rate than the same network with `ReLU` units.
[PDLT](https://arxiv.org/abs/2106.10165) shows that a `tanh` network can be
tuned to criticality if it is initialized correctly.

Let's test this! I want to use a really small network, so how about a CNN with
two convolutional layers followed by two linear layers, trained on MNIST?

Here's Figure 1 in the AlexNet paper: 

![](../assets/alexnet-fig1.png)

In [5]:
from typing import Literal

import torch
import torch.nn as nn
import torchvision

In [7]:
# MNIST images are loaded (and downloaded on first use) into ../datasets.
# ToTensor converts each image to a float tensor, rescaling pixel values
# from [0, 255] to [0.0, 1.0].
dataset = torchvision.datasets.MNIST(
    root="../datasets",
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# batch_size=1 is the DataLoader default, spelled out here for clarity:
# each iteration yields a single (image, label) pair.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

In [8]:
# Display the dataset's mapping from human-readable class names to the
# integer label indices used as targets.
dataset.class_to_idx

{'0 - zero': 0,
 '1 - one': 1,
 '2 - two': 2,
 '3 - three': 3,
 '4 - four': 4,
 '5 - five': 5,
 '6 - six': 6,
 '7 - seven': 7,
 '8 - eight': 8,
 '9 - nine': 9}

In [9]:
class FourLayerCNN(nn.Module):
    """Small CNN with four weight layers: two Conv2d stages and two Linear layers.

    Architecture (for a ``(b, 1, 28, 28)`` input such as MNIST)::

        Conv2d(1 -> 32) -> act -> MaxPool(2)
        Conv2d(32 -> 64) -> act -> MaxPool(2)
        Flatten -> Linear(3136 -> 128) -> act -> Linear(128 -> n_classes) -> Softmax

    Weights are drawn from a zero-mean normal distribution with variance
    ``C_W / fan_in`` and biases are set to zero, following the criticality
    conditions from PDLT: ``C_W = 2`` for ReLU and ``C_W = 1`` for tanh.

    Args:
        activation_name: ``"relu"`` selects ``nn.ReLU``; any other value
            selects ``nn.Tanh``.
        n_classes: Size of the output layer. NOTE(review): the default of 100
            is kept for backward compatibility, but MNIST has 10 classes, so
            callers training on MNIST should pass ``n_classes=10``.

    The forward pass returns per-class probabilities (rows sum to 1), not
    logits. NOTE(review): if the training loop uses a loss that expects logits
    (e.g. ``nn.CrossEntropyLoss``), softmax would effectively be applied
    twice; confirm against the training code.
    """

    def __init__(self, activation_name: Literal["relu", "tanh"], n_classes=100):
        super().__init__()
        activation = nn.ReLU if activation_name == "relu" else nn.Tanh

        # MNIST is (28,28,1)
        # NOTE(review): the (3, 2) kernels are kept as-is to preserve parameter
        # shapes; a square (3, 3) kernel may have been intended.
        self.layers = nn.Sequential(
            # (b, 1, 28, 28) -> (b, 32, 28, 28)
            nn.Conv2d(1, 32, (3, 2), stride=1, padding="same"),
            activation(),
            # (b, 32, 28, 28) -> (b, 32, 14, 14)
            nn.MaxPool2d((2, 2), stride=2),
            # (b, 32, 14, 14) -> (b, 64, 14, 14)
            nn.Conv2d(32, 64, (3, 2), stride=1, padding="same"),
            activation(),
            # (b, 64, 14, 14) -> (b, 64, 7, 7)
            nn.MaxPool2d((2, 2), stride=2),
            # (b, 64, 7, 7) -> (b, 3136)
            nn.Flatten(),
            # (b, 3136) -> (b, 128)
            nn.Linear(3136, 128),
            activation(),
            # (b, 128) -> (b, n_classes)
            nn.Linear(128, n_classes),
            nn.Softmax(dim=1),
        )

        # Initialization based on PDLT conditions for criticality:
        # Var[W] = C_W / fan_in and b = 0, with C_W = 2 (ReLU) or 1 (tanh).
        c_w = 2.0 if activation_name == "relu" else 1.0
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                # Each output unit sums over every input channel and every
                # kernel position, so the channel count is part of the fan-in.
                fan_in = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1]
            elif isinstance(layer, nn.Linear):
                # The variance scales with the number of inputs, not outputs.
                fan_in = layer.in_features
            else:
                continue
            # Use the in-place initializers; the un-suffixed nn.init.normal /
            # nn.init.constant are deprecated aliases.
            nn.init.normal_(layer.weight, mean=0.0, std=(c_w / fan_in) ** 0.5)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Run ``x`` through the network and return class probabilities."""
        return self.layers(x)