## Install

In [0]:
!pip3 install torch torchvision numpy



## Imports

In [0]:
from matplotlib import pyplot as plt
import numpy as np
import torch as th
from torch import nn
import torchvision
from torchvision import transforms

## MNIST Dataset

In [0]:
# Download and construct MNIST dataset.
train_dataset = torchvision.datasets.MNIST(root='~/code/data/mnist/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='~/code/data/mnist/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

## Model

In [0]:
# Hyper-params
input_size = 28 * 28
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001


# Data loader (input pipeline)
train_loader = th.utils.data.DataLoader(dataset=train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
test_loader = th.utils.data.DataLoader(dataset=test_dataset,
                                       batch_size=batch_size,
                                       shuffle=False)

# Model
model = nn.Linear(input_size, num_classes)

# Loss and optimizer.
loss_fn = nn.CrossEntropyLoss()
optimizer = th.optim.SGD(model.parameters(), lr=learning_rate)

## Train

In [41]:
num_steps = len(train_loader)
for epoch in range(num_epochs):
  for step, (images, labels) in enumerate(train_loader):
    # Forward
    images = images.reshape(-1, input_size)
    outputs = model(images)
    loss = loss_fn(outputs, labels)
    
    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (step + 1) % 100 == 0:
      print(f'Epoch [{epoch}/{num_epochs}], Step [{step}/{num_steps}], '
            f'Loss: {loss.item():.4}')

Epoch [0/5], Step [99/600], Loss: 2.205
Epoch [0/5], Step [199/600], Loss: 2.058
Epoch [0/5], Step [299/600], Loss: 2.062
Epoch [0/5], Step [399/600], Loss: 1.972
Epoch [0/5], Step [499/600], Loss: 1.886
Epoch [0/5], Step [599/600], Loss: 1.814
Epoch [1/5], Step [99/600], Loss: 1.777
Epoch [1/5], Step [199/600], Loss: 1.676
Epoch [1/5], Step [299/600], Loss: 1.575
Epoch [1/5], Step [399/600], Loss: 1.601
Epoch [1/5], Step [499/600], Loss: 1.596
Epoch [1/5], Step [599/600], Loss: 1.436
Epoch [2/5], Step [99/600], Loss: 1.402
Epoch [2/5], Step [199/600], Loss: 1.336
Epoch [2/5], Step [299/600], Loss: 1.298
Epoch [2/5], Step [399/600], Loss: 1.228
Epoch [2/5], Step [499/600], Loss: 1.296
Epoch [2/5], Step [599/600], Loss: 1.271
Epoch [3/5], Step [99/600], Loss: 1.165
Epoch [3/5], Step [199/600], Loss: 1.131
Epoch [3/5], Step [299/600], Loss: 1.308
Epoch [3/5], Step [399/600], Loss: 1.137
Epoch [3/5], Step [499/600], Loss: 1.141
Epoch [3/5], Step [599/600], Loss: 1.048
Epoch [4/5], Step [9

## Test

In [39]:
with th.no_grad():
  correct, total = 0, 0
  for images, labels in test_loader:
    images = images.reshape(-1, input_size)
    outputs = model(images)
    _, predicted = th.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
  accuracy = correct / total
  print(f'Accuracy of model on 10000 test images: {100 * accuracy:0.2f}%')

Accuracy of model on 10000 test images: 85.43%


## Save model

In [0]:
th.save(model.state_dict(), '/tmp/model.ckpt')