### Training with data augmentation

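To combat potential overfitting, we can add some basic data augmentation. Each time a training sample is fed to the network, we randomly shift it by up to one pixel in the vertical and/or horizontal directions. The shift is applied with `torch.roll`, so pixels pushed off one edge wrap around to the opposite edge.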

We also increase the number of passes again, to 60, since data augmentation (1) increases the effective "size" of the training set and (2) adds some "noise" to the training process, so the model needs more passes to converge.
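
To make the shift concrete, here is a minimal sketch on a toy 4x4 "image" (the real inputs are 16x16). It assumes the same `torch.roll` call used in the model below; the variable names are illustrative only.

```python
import torch

# Toy 1x1x4x4 "image" with distinct values so every shift is easy to see.
x = torch.arange(16.0).reshape(1, 1, 4, 4)

# All combinations of shifts in {-1, 0, 1} along height (dim 2) and width (dim 3).
variants = [
    torch.roll(x, shifts=(dy, dx), dims=(2, 3))
    for dy in (-1, 0, 1)
    for dx in (-1, 0, 1)
]

# Rolling is circular: shifting down by one row moves the last row to the top.
shifted_down = torch.roll(x, shifts=(1, 0), dims=(2, 3))
print(shifted_down[0, 0, 0])  # the original last row: 12, 13, 14, 15

# Count how many distinct views a single source image can produce.
num_distinct = len({tuple(v.flatten().tolist()) for v in variants})
print(num_distinct)  # 9
```

Each training image can therefore show up in as many as nine slightly different positions, which is the sense in which the effective training set grows.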

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# The following code is adapted from "modern.py" in the repository.

class ModernDataAugNet(nn.Module):

    def __init__(self):
        super().__init__()

        # initialization as described in the paper to the best of my ability, but it doesn't look right...
        winit = lambda fan_in, *shape: (torch.rand(*shape) - 0.5) * 2 * 2.4 / fan_in**0.5
        macs = 0 # keep track of MACs (multiply accumulates)
        acts = 0 # keep track of number of activations

        # H1 layer parameters and their initialization
        self.H1w = nn.Parameter(winit(5*5*1, 12, 1, 5, 5))
        self.H1b = nn.Parameter(torch.zeros(12, 8, 8)) # presumably init to zero for biases
        macs += (5*5*1) * (8*8) * 12
        acts += (8*8) * 12

        # H2 layer parameters and their initialization
        """
        H2 neurons all connect to only 8 of the 12 input planes, with an unspecified pattern
        I am going to assume the most sensible block pattern where 4 planes at a time connect
        to differently overlapping groups of 8/12 input planes. We will implement this with 3
        separate convolutions that we concatenate the results of.
        """
        self.H2w = nn.Parameter(winit(5*5*8, 12, 8, 5, 5))
        self.H2b = nn.Parameter(torch.zeros(12, 4, 4)) # presumably init to zero for biases
        macs += (5*5*8) * (4*4) * 12
        acts += (4*4) * 12

        # H3 is a fully connected layer
        self.H3w = nn.Parameter(winit(4*4*12, 4*4*12, 30))
        self.H3b = nn.Parameter(torch.zeros(30))
        macs += (4*4*12) * 30
        acts += 30

        # output layer is also fully connected layer
        self.outw = nn.Parameter(winit(30, 30, 10))
        self.outb = nn.Parameter(torch.zeros(10))
        macs += 30 * 10
        acts += 10

        self.macs = macs
        self.acts = acts

    def forward(self, x):

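        # Note: basic data augmentation, applied in training mode only: roll the image
        # by -1, 0, or +1 pixels along each spatial axis (torch.roll wraps around the edges)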
        if self.training:
            shift_x, shift_y = np.random.randint(-1, 2, size=2)
            x = torch.roll(x, (shift_x, shift_y), (2, 3))

        # x has shape (1, 1, 16, 16)
        x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background
        x = F.conv2d(x, self.H1w, stride=2) + self.H1b
        x = torch.tanh(x)

        # x is now shape (1, 12, 8, 8)
        x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background
        slice1 = F.conv2d(x[:, 0:8], self.H2w[0:4], stride=2) # first 4 planes look at first 8 input planes
        slice2 = F.conv2d(x[:, 4:12], self.H2w[4:8], stride=2) # next 4 planes look at last 8 input planes
        slice3 = F.conv2d(torch.cat((x[:, 0:4], x[:, 8:12]), dim=1), self.H2w[8:12], stride=2) # last 4 planes are cross
        x = torch.cat((slice1, slice2, slice3), dim=1) + self.H2b
        x = torch.tanh(x)

        # x is now shape (1, 12, 4, 4)
        x = x.flatten(start_dim=1) # (1, 12*4*4)
        x = x @ self.H3w + self.H3b
        x = torch.tanh(x)

        # x is now shape (1, 30)
        x = x @ self.outw + self.outb

        # x is finally shape (1, 10)
        return x
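
The shape comments and the tallies kept in `__init__` can be re-derived by hand. The sketch below does so in plain Python, assuming the standard output-size formula for a padded, strided convolution; `conv_out` and `layers` are illustrative names, not part of the original code.

```python
def conv_out(size, kernel=5, stride=2, pad=2):
    """Spatial output size of a padded, strided convolution."""
    return (size + 2 * pad - kernel) // stride + 1

h1 = conv_out(16)  # 16x16 input -> H1 plane size
h2 = conv_out(h1)  # H1 planes   -> H2 plane size

# (weights, biases) per layer, following the tensor shapes in __init__.
# The conv biases are untied: one bias per output position, not one per plane.
layers = {
    "H1": (12 * 1 * 5 * 5, 12 * h1 * h1),
    "H2": (12 * 8 * 5 * 5, 12 * h2 * h2),
    "H3": (12 * h2 * h2 * 30, 30),
    "out": (30 * 10, 10),
}
n_params = sum(w + b for w, b in layers.values())

# Multiply-accumulates and activations, mirroring the running sums in __init__.
macs = (5*5*1) * (h1*h1) * 12 + (5*5*8) * (h2*h2) * 12 + (h2*h2*12) * 30 + 30 * 10
acts = 12 * h1 * h1 + 12 * h2 * h2 + 30 + 10

print(h1, h2)                # 8 4
print(n_params, macs, acts)  # 9760 63660 1000
```

These totals agree with the "model stats" printed when the model is constructed.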

In [None]:
# The following code is adapted from "modern.py" in the repository.

learning_rate = 3e-4

# init rng
torch.manual_seed(1337)
np.random.seed(1337)
torch.use_deterministic_algorithms(True)

# init a model
model = ModernDataAugNet()
print("model stats:")
print("# params:      ", sum(p.numel() for p in model.parameters())) # in paper total is 9,760
print("# MACs:        ", model.macs)
print("# activations: ", model.acts)

# init data
Xtr, Ytr = torch.load('train1989.pt')
Xte, Yte = torch.load('test1989.pt')

# init optimizer
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

def eval_split(split):
    # eval the full train/test set, batched implementation for efficiency
    model.eval()
    X, Y = (Xtr, Ytr) if split == 'train' else (Xte, Yte)
    with torch.no_grad(): # no gradients are needed for evaluation
        Yhat = model(X)
    loss = F.cross_entropy(Yhat, Y.argmax(dim=1))
    err = torch.mean((Y.argmax(dim=1) != Yhat.argmax(dim=1)).float())
    print(f"eval: split {split:5s}. loss {loss.item():e}. error {err.item()*100:.2f}%. misses: {int(err.item()*Y.size(0))}")
  
# train
for pass_num in range(60):

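    # learning rate decay: linear interpolation from learning_rate toward learning_rate / 3
    # (with 60 passes alpha tops out at 59/79, so the final rate is about learning_rate / 2)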
    alpha = pass_num / 79
    for g in optimizer.param_groups:
        g['lr'] = (1 - alpha) * learning_rate + alpha * (learning_rate / 3)

    # perform one epoch of training
    model.train()
    for step_num in range(Xtr.size(0)):

        # fetch a single example into a batch of 1
        x, y = Xtr[[step_num]], Ytr[[step_num]]

        # forward the model and the loss
        yhat = model(x)
        loss = F.cross_entropy(yhat, y.argmax(dim=1))

        # calculate the gradient and update the parameters
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # after each epoch, evaluate the train and test error / metrics
    print(pass_num + 1)
    eval_split('train')
    eval_split('test')

# save final model to file
torch.save(model.state_dict(), 'dataug_model.pt')

model stats:
# params:       9760
# MACs:         63660
# activations:  1000
1
eval: split train. loss 5.288423e-01. error 13.48%. misses: 982
eval: split test . loss 5.376812e-01. error 13.65%. misses: 274
2
eval: split train. loss 3.618094e-01. error 9.74%. misses: 710
eval: split test . loss 3.622440e-01. error 10.21%. misses: 204
3
eval: split train. loss 2.770010e-01. error 7.78%. misses: 567
eval: split test . loss 2.691664e-01. error 7.77%. misses: 155
4
eval: split train. loss 2.191166e-01. error 6.32%. misses: 460
eval: split test . loss 2.131962e-01. error 5.88%. misses: 117
5
eval: split train. loss 1.949559e-01. error 5.71%. misses: 416
eval: split test . loss 2.028036e-01. error 6.13%. misses: 122
6
eval: split train. loss 1.760966e-01. error 5.09%. misses: 371
eval: split test . loss 1.767326e-01. error 5.28%. misses: 105
7
eval: split train. loss 1.579356e-01. error 4.69%. misses: 341
eval: split test . loss 1.632146e-01. error 5.08%. misses: 102
8
eval: split train. los

Data augmentation, used in combination with our other improvements, seems to help a lot! Our test error is now comfortably under 3%.
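
As a rough check on the learning rate decay used in the training loop, the sketch below tabulates the interpolated rate per pass. The constants are copied from the training cell; the helper `lr_at` and its `denom` argument are illustrative names, not part of the original code.

```python
learning_rate = 3e-4
num_passes = 60

def lr_at(pass_num, denom=79):
    """Linear interpolation between learning_rate and learning_rate / 3."""
    alpha = pass_num / denom
    return (1 - alpha) * learning_rate + alpha * (learning_rate / 3)

schedule = [lr_at(p) for p in range(num_passes)]

print(f"first pass: {schedule[0]:.3e}")
print(f"last pass:  {schedule[-1]:.3e}")
print(f"ratio:      {schedule[-1] / schedule[0]:.3f}")
```

With these settings the rate falls steadily but ends at roughly half of its starting value, because `alpha` only reaches 59/79 on the last pass.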