# Regular MNIST

The paper cites using a LeNet for CNN MNIST. Referring to `build_cnn_model_direct_mnist` in `model_builders.py`:

```python
with tf.name_scope('net') as scope:
    xx = Convolution2D(16, kernel_size=3, strides=1, init='he_normal', padding='valid', activation='relu', kernel_regularizer=l2(weight_decay))(preproc_images)
    xx = Convolution2D(16, 3, 3, init='he_normal', padding='valid', activation='relu', kernel_regularizer=l2(weight_decay))(xx)
    xx = MaxPooling2D((2, 2))(xx)
    xx = Convolution2D(16, 3, 3, init='he_normal', padding='valid', activation='relu', kernel_regularizer=l2(weight_decay))(xx)
    xx = BatchNormalization(momentum=0.5)(xx)
    xx = Convolution2D(16, 3, 3, init='he_normal', padding='valid', activation='relu', kernel_regularizer=l2(weight_decay))(xx)  # (8, 8)
    xx = MaxPooling2D((2, 2))(xx)  # (4, 4)
    xx = Flatten()(xx)
    xx = Dense(800, kernel_initializer='he_normal', activation='relu', kernel_regularizer=l2(weight_decay))(xx)
    xx = Dense(800, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(xx)
    xx = BatchNormalization(momentum=0.5)(xx)
    xx = Activation('relu')(xx)
    xx = Dense(500, kernel_initializer='he_normal', activation='relu', kernel_regularizer=l2(weight_decay))(xx)
    logits = Dense(10, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(xx)
    model = ExtendedModel(input=input_images, output=logits)
```

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader

## Data

In [52]:
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), 
    # torchvision.transforms.Lambda(lambda x: torch.flatten(x))
])

In [53]:
train = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    # natively stored as PIL images
    transform=dataset_transform
)

In [54]:
test = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    train=False,
    transform=dataset_transform
)

In [55]:
train

Dataset MNIST
    Number of datapoints: 60000
    Root location: /home/tnwei/.torchdata/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [56]:
test

Dataset MNIST
    Number of datapoints: 10000
    Root location: /home/tnwei/.torchdata/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [57]:
train.data.shape

torch.Size([60000, 28, 28])

In [58]:
train_loader = DataLoader(train, batch_size=100, shuffle=True)
# Returns (torch.Size([100, 784]), torch.Size([100]))

In [59]:
test_loader = DataLoader(test, batch_size=500, shuffle=False)

## Net definition

In [21]:
class PaperLeNet(nn.Module):
    def __init__(self):
        """
        Paper cites using LeNet arch
        ref: https://arxiv.org/pdf/1804.08838.pdf
        
        But from the implementation in github, this isn't exactly LeNet:
        https://github.com/uber-research/intrinsic-dimension/blob/9754ebe1954e82973c7afe280d2c59850f281dca/intrinsic_dim/model_builders.py#L347
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=1, padding="valid")
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=3, padding="valid")
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
        
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=3, padding="valid")
        
        # num_features is out_channels of prev conv?
        # at this point, the network has a 32x32 image reduced to 1x1!
        self.bn1 = nn.BatchNorm2d(num_features=16, momentum=0.5) 
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=3, padding="valid")
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))
        
        self.flatten1 = nn.Flatten()
        
        self.dense1 = ()
        
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        
        x = self.conv3(x)
        x = self.bn1(x)
        x = self.conv4(x)
        x = self.maxpool2(x)
        
        x = self.flatten1(x)
        return x

The current net above is a portion of the Tensorflow net replicated. The conv layer after the batchnorm simply doesn't have enough strides to work on! I double checked to ensure that I was reading keras docs correctly. Strange. 

```python
LeNet()(torch.randn(10, 1, 28, 28))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_48053/620529947.py in <module>
----> 1 LeNet()(torch.randn(10, 1, 28, 28))

~/miniconda3/envs/gan/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_48053/700110480.py in forward(self, x)
     32         x = self.conv3(x)
     33         x = self.bn1(x)
---> 34         x = self.conv4(x)
     35         x = self.maxpool2(x)
     36 

~/miniconda3/envs/gan/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/gan/lib/python3.9/site-packages/torch/nn/modules/conv.py in forward(self, input)
    444 
    445     def forward(self, input: Tensor) -> Tensor:
--> 446         return self._conv_forward(input, self.weight, self.bias)
    447 
    448 class Conv3d(_ConvNd):

~/miniconda3/envs/gan/lib/python3.9/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    440                             weight, bias, self.stride,
    441                             _pair(0), self.dilation, self.groups)
--> 442         return F.conv2d(input, weight, bias, self.stride,
    443                         self.padding, self.dilation, self.groups)
    444 

RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (3 x 3). Kernel size can't be greater than actual input size
```

The LeNet used in this paper isn't exactly a LeNet. There's a batchnorm mixed in in a really weird place. 

I prefer building a classic LeNet from other sources. Found the original paper on LeCun's personal website at http://yann.lecun.org/exdb/publis/pdf/lecun-89e.pdf. The individual neurons are described specifically in non-deep learning lingo, which makes it a bit difficult to read.

[PyImageSearch](https://www.pyimagesearch.com/2021/07/19/pytorch-training-your-first-convolutional-neural-network-cnn/) describes LeNet in a tutorial. The overall arch looks similar to the OG implementation, just that can't be sure on the exact parameters.

In [47]:
class LeNet(nn.Module):
    def __init__(self):
        """
        LeNet from PyImageSearch:
        https://www.pyimagesearch.com/2021/07/19/pytorch-training-your-first-convolutional-neural-network-cnn/
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=(5, 5), stride=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=(5, 5), stride=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        
        self.flatten1 = nn.Flatten()
        
        self.fc1 = nn.Linear(in_features=800, out_features=500)
        self.relu3 = nn.ReLU()
        
        self.fc2 = nn.Linear(in_features=500, out_features=10)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        
        x = self.flatten1(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.logsoftmax(x)
        
        return x

## Training

In [60]:
net = LeNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
num_epochs = 5

In [61]:
loss_history = []
acc_history = []

In [62]:
# Train
net.train()

for _ in range(num_epochs):
    for batch_id, (features, target) in enumerate(train_loader):
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = net(features)
        loss = F.nll_loss(preds, target)
        loss.backward()
        loss_history.append(loss.item())
        opt.step()

        if batch_id % 100 == 0:
            print(loss.item())

2.3076186180114746
0.391210675239563
0.389596551656723
0.375776469707489
0.14125771820545197
0.23916095495224
0.13795307278633118
0.17592500150203705
0.2031337320804596
0.10884179174900055
0.2261829376220703
0.17889922857284546
0.1087673157453537
0.11061125993728638
0.18922947347164154
0.08641938120126724
0.18402229249477386
0.22066231071949005
0.1210862323641777
0.05056007206439972
0.11251470446586609
0.14379653334617615
0.12517106533050537
0.05036019906401634
0.16461332142353058
0.1605762243270874
0.11386830359697342
0.1159994974732399
0.05315348133444786
0.14150004088878632


In [63]:
# Test
net.eval()

test_loss = 0
correct = 0

for features, target in test_loader:
    output = net(features)
    test_loss += F.nll_loss(output, target).item()
    pred = torch.argmax(output, dim=-1) # get the index of the max log-probability
    correct += pred.eq(target).cpu().sum()

test_loss = test_loss
test_loss /= len(test_loader) # loss function already averages over batch size
accuracy = 100. * correct / len(test_loader.dataset)
acc_history.append(accuracy)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    accuracy))


Test set: Average loss: 0.1356, Accuracy: 9577/10000 (96%)



No surprises here.