In [None]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from IPython.display import display

### Step 1: Load the MNIST training dataset (it is downloaded to `./data` on the first run).

In [None]:
# MNIST training split (train=True is torchvision's default, written out here for
# clarity), stored under ./data and downloaded if missing. ToTensor turns each
# image into a float tensor.
mnist = torchvision.datasets.MNIST(
    "./data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

### Step 2: Define the model. In practice this is often an off-the-shelf architecture suited to your problem; here we define a small convolutional network.

In [None]:
class Model(nn.Module):
    """Small convolutional classifier for single-channel images (e.g. MNIST digits).

    Pairs of 3x3 convolutions (padding=1, so spatial size is preserved) alternate
    with fixed-size resampling to 14x14 and then 7x7. A final 10-channel
    convolution is global-average-pooled into one score per class.

    ``forward`` returns raw, unnormalized logits of shape ``(batch, 10)``. Feed them
    directly to ``F.cross_entropy`` (which applies log-softmax internally), or take
    ``argmax`` for a prediction.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): there is no ReLU between the second conv of each pair and
        # the following conv (only nearest-neighbour resampling, which is linear).
        # Each Conv2d -> Upsample -> Conv2d run therefore composes into a single
        # linear map. Adding nn.ReLU() there would shift the Sequential indices
        # (and state_dict keys), so the layout is deliberately left unchanged.
        self.layers = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 4, kernel_size=3, padding=1),
            # An explicit size smaller than the input makes Upsample downsample.
            # Its default mode is "nearest".
            nn.Upsample(size=(14, 14)),
            nn.Conv2d(4, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 4, kernel_size=3, padding=1),
            nn.Upsample(size=(7, 7)),
            nn.Conv2d(4, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 10, kernel_size=3, padding=1),
            # Average each of the 10 class maps down to a single score.
            nn.AdaptiveAvgPool2d(output_size=(1, 1))
        )

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)``. The first convolution
                requires exactly one input channel.

        Returns:
            Tensor of shape ``(batch, 10)`` with raw (pre-softmax) class scores.
        """
        # No sigmoid here. F.cross_entropy expects unnormalized logits. Squashing
        # them into (0, 1) caps the gap between class scores below 1, so the loss
        # can never drop under ~1.46 for 10 classes and gradients are weakened.
        # argmax predictions are unaffected because sigmoid is monotonic.
        return self.layers(x).flatten(start_dim=1)

### Step 3: Set up training. A `DataLoader` iterates over the dataset in shuffled batches of training examples.

In [None]:
# Yields shuffled mini-batches of 128 (images, labels) pairs. The data is
# reshuffled each time the loader is iterated, i.e. every epoch.
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=128, shuffle=True)

### Step 4: Create the model and an optimizer that applies gradient updates to its parameters.

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)

In [None]:
# Adam over every model parameter with a fixed learning rate of 1e-2.
# NOTE(review): that is 10x Adam's default lr (1e-3); lower it if the loss is unstable.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

### Step 5: Train the model. For each batch, the loop computes the loss (the error) and then adjusts the model parameters to reduce it.

In [None]:
# Per-batch training losses. This lives in its own cell so that re-running the
# training cell appends to the history instead of resetting it.
losses = list()

In [None]:
%%time 

# Train wherever the model's parameters live, so the device is chosen only where
# the model is created instead of being hard-coded here as well.
device = next(model.parameters()).device
num_epochs = 10

# Training mode. This is a no-op for the current architecture (no dropout or
# batch-norm), but keeps the loop correct if such layers are added.
model.train()

for _ in range(num_epochs):
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # F.cross_entropy applies log-softmax itself, so it must receive the
        # model's raw logits, not probabilities or sigmoid outputs.
        logits = model(images)
        loss = F.cross_entropy(logits, labels)

        loss.backward()
        optimizer.step()

        # One entry per batch (not per epoch); .item() yields a plain Python float.
        losses.append(loss.item())

### Step 6: Plot the losses. Check that the loss trends downward over time. Losses are recorded per batch, so expect a noisy, jagged curve rather than a perfectly smooth one.

In [None]:
# One point per training batch, in the order the batches were processed.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="Batch", ylabel="Cross-entropy loss")
plt.show()

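### Step 7: Evaluate the trained model. Here we spot-check a random example from the training set. A proper evaluation would use held-out examples the model never saw during training, such as the MNIST test split.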

In [None]:
def predict(tensor, net=None):
    """Return the predicted class index for a single 28x28 image.

    Args:
        tensor: Image tensor with 28 * 28 elements, e.g. ``(1, 28, 28)`` as produced
            by ``ToTensor``. It is reshaped to ``(1, 1, 28, 28)`` before inference.
        net: Model to run. Defaults to the notebook-level ``model``.

    Returns:
        0-d tensor holding the index of the highest-scoring class, on the model's
        device. Call ``.item()`` to get a plain Python int.
    """
    if net is None:
        net = model
    # Follow the model's device instead of hard-coding "cuda:0". Fall back to the
    # input's own device for a network without parameters.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else tensor.device
    # reshape (unlike view) also accepts non-contiguous inputs.
    batch = tensor.to(device).reshape(1, 1, 28, 28)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        return net(batch).argmax()

In [None]:
# Materialize every (image, label) pair of the dataset as a list so that a random
# example can be drawn from it. NOTE(review): this runs the ToTensor transform on
# every example in the dataset just to sample one. Indexing the dataset directly,
# e.g. mnist[random.randrange(len(mnist))], avoids that cost.
all_examples = list(mnist)

In [None]:
# Pick a random (image, label) pair and compare the model's prediction to the truth.
image, label = random.choice(all_examples)

print("input image:")
# ToTensor scaled pixels to [0, 1]. Round when mapping back to 0..255 so that
# float error (e.g. 254.99998) is not truncated down by one grey level.
display(Image.fromarray(np.uint8(np.round(255 * image[0].numpy()))))

print(f"actual label: {label}")
# .item() turns the 0-d result tensor into a plain int. Without it, the f-string
# prints the tensor repr, e.g. "tensor(7, device='cuda:0')".
print(f"predicted label: {predict(image).item()}")