# Regular MNIST

Cobbled together a regular MNIST training loop, referring to:
+ https://gist.github.com/kdubovikov/eb2a4c3ecadd5295f68c126542e59f0a
+ https://github.com/uber-research/intrinsic-dimension/blob/master/intrinsic_dim/model_builders.py#L81
+ https://arxiv.org/pdf/1804.08838.pdf

In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader

## Data

In [42]:
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Lambda(lambda x: torch.flatten(x))
])

In [43]:
train = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    # natively stored as PIL images
    transform=dataset_transform
)

In [45]:
test = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    train=False,
    transform=dataset_transform
)

In [46]:
train

Dataset MNIST
    Number of datapoints: 60000
    Root location: /home/tnwei/.torchdata/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

In [47]:
test

Dataset MNIST
    Number of datapoints: 10000
    Root location: /home/tnwei/.torchdata/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

In [48]:
train.data.shape

torch.Size([60000, 28, 28])

In [49]:
train_loader = DataLoader(train, batch_size=100, shuffle=True)
# Returns (torch.Size([100, 784]), torch.Size([100]))

In [50]:
test_loader = DataLoader(test, batch_size=500, shuffle=False)

In [59]:
a[0].shape, a[1].shape

(torch.Size([100, 784]), torch.Size([100]))

## Net definition

In [88]:
class BasicMNIST(nn.Module):
    def __init__(self):
        """
        Paper uses 784-200-200-10
        ref: https://arxiv.org/pdf/1804.08838.pdf
        
        Ref in github:
        https://github.com/uber-research/intrinsic-dimension/blob/9754ebe1954e82973c7afe280d2c59850f281dca/intrinsic_dim/model_builders.py#L81
        """
        super().__init__()
        self.hidden1 = nn.Linear(784, 200)
        self.hidden2 = nn.Linear(200, 10)
        
    def forward(self, x):
        x = self.hidden1(x)
        x = F.relu(x)
        x = self.hidden2(x)
        x = F.relu(x)
        x = F.log_softmax(x, dim=-1)  # (batch_size, dims)
        return x

## Training

In [100]:
net = BasicMNIST()
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
num_epochs = 5

In [101]:
loss_history = []
acc_history = []

In [103]:
# Train
net.train()

for _ in range(num_epochs):
    for batch_id, (features, target) in enumerate(train_loader):
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = net(features)
        loss = F.nll_loss(preds, target)
        loss.backward()
        loss_history.append(loss.item())
        opt.step()

        if batch_id % 100 == 0:
            print(loss.item())

2.3020222187042236
0.6678874492645264
0.46086710691452026
0.6339600086212158
0.7151810526847839
0.7028355598449707
0.4760339856147766
0.5044549107551575
0.6090096235275269
0.5877191424369812
0.7133115530014038
0.48733580112457275
0.6384894251823425
0.35630112886428833
0.5135530233383179
0.49185654520988464
0.5114858746528625
0.5866591930389404
0.6023568511009216
0.5403339266777039
0.5768228769302368
0.6577960252761841
0.5027185678482056
0.4349723160266876
0.5165771245956421
0.6568780541419983
0.6702275276184082
0.6917394399642944
0.5791258811950684
0.4683903753757477


In [104]:
# Test
net.eval()

test_loss = 0
correct = 0

for features, target in test_loader:
    output = net(features)
    test_loss += F.nll_loss(output, target).item()
    pred = torch.argmax(output, dim=-1) # get the index of the max log-probability
    correct += pred.eq(target).cpu().sum()

test_loss = test_loss
test_loss /= len(test_loader) # loss function already averages over batch size
accuracy = 100. * correct / len(test_loader.dataset)
acc_history.append(accuracy)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    accuracy))


Test set: Average loss: 0.6160, Accuracy: 8555/10000 (86%)

