## Description

**Dataset**

Each example is a sequence of T random MNIST digits. Goal is to predict the sum.

**Model**

A CNN processes each digit. The activations are fed into an LSTM.



## Install

In [0]:
!pip3 install torch torchvision numpy

## Imports

In [0]:
from matplotlib import pyplot as plt
import numpy as np

import torch as th
from torch import nn
import torchvision
from torchvision import transforms

## Config

In [91]:
debug = False
device = th.device('cuda' if th.cuda.is_available() and not debug else 'cpu')
print(f'Using {device}')



sequence_length = 2
conv_out_size = 64
hidden_size = 128
num_layers = 2
num_classes = 1 + 9 * sequence_length

num_epochs = 10
batch_size = 100
learning_rate = 0.001

Using cuda


## Dataset: MNIST Pair Sum 

In [0]:
# Download and construct MNIST dataset.
train_dataset = torchvision.datasets.MNIST(root='~/code/data/mnist/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='~/code/data/mnist/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Data loader (input pipeline)
train_loader = th.utils.data.DataLoader(dataset=train_dataset,
                                        batch_size=sequence_length * batch_size,
                                        shuffle=True)
test_loader = th.utils.data.DataLoader(dataset=test_dataset,
                                       batch_size=sequence_length * batch_size,
                                       shuffle=False)

def encode_sequence_data(images, labels):
  # Create sequence by chunking along batch dimension.
  # [batch, ...] -> [seq_len, batch, ...]
  assert images.size(0) % sequence_length == 0
  assert labels.size(0) % sequence_length == 0
  image_chunks = th.chunk(images, sequence_length)
  images = th.stack(image_chunks)
  label_chunks = th.chunk(labels, sequence_length)
  labels = th.stack(label_chunks)
  labels = th.sum(labels, dim=0)  # [seq_len, batch, ...] -> [batch, ...]
  return images, labels

## Model (Conv RNN)

In [0]:
# Convolutional neural network (with 2 convolutional layers).
class CNN(nn.Module):
  def __init__(self, num_outputs):
    super().__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.fc = nn.Linear(7 * 7 * 32, num_outputs)
    
  def forward(self, x):
    out = self.layer1(x)
    out = self.layer2(out)
    out = out.reshape(out.size(0), -1)
    out = self.fc(out)
    return out

In [0]:
class ConvRNN(nn.Module):
  def __init__(self, conv_net, conv_output_size, hidden_size, num_layers,
               num_classes):
    super(ConvRNN, self).__init__()
    self.conv_net = conv_net
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(conv_output_size, hidden_size, num_layers)
    self.fc = nn.Linear(hidden_size, num_classes)
    
  def forward(self, x):
    """
    Args:
      x: input tensor, shaped [seq_len, batch_size, channel, height, width]
    """
    # Set initial hidden and cell states.
    h0 = th.zeros(self.num_layers, x.size(1), self.hidden_size).to(device)
    c0 = th.zeros(self.num_layers, x.size(1), self.hidden_size).to(device)
    
    # Apply conv_net to get activations.
    orig_shape = x.shape
    combined_batch_shape = th.Size([-1]) + x.shape[2:]
    x = x.view(*combined_batch_shape)
    conv_outs = self.conv_net(x)
    conv_outs = conv_outs.view(orig_shape[:2] + th.Size([-1]))
    
    # Forward propagate LSTM.
    # out shape (seq_len, batch_size, hid_size)
    out, _ = self.lstm(conv_outs, (h0, c0))
    
    # Decode the hidden state of the last time step.
    out = self.fc(out[-1, :, :])
    return out

In [0]:
cnn = CNN(conv_out_size)
model = ConvRNN(cnn, conv_out_size, hidden_size, num_layers, num_classes).to(device)

## Train

In [96]:
# Loss and optimizer.
loss_fn = nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=learning_rate)


num_steps = len(train_loader)
for epoch in range(num_epochs):
  for step, (images, labels) in enumerate(train_loader):
    images, labels = encode_sequence_data(images, labels)
    images = images.to(device)
    labels = labels.to(device)

    # Forward
    outputs = model(images)
    loss = loss_fn(outputs, labels)
    
    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (step + 1) % 100 == 0:
      print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{num_steps}], '
            f'Loss: {loss.item():.4}')  

Epoch [1/10], Step [100/300], Loss: 2.124
Epoch [1/10], Step [200/300], Loss: 1.54
Epoch [1/10], Step [300/300], Loss: 1.247
Epoch [2/10], Step [100/300], Loss: 0.7427
Epoch [2/10], Step [200/300], Loss: 0.5738
Epoch [2/10], Step [300/300], Loss: 0.4459
Epoch [3/10], Step [100/300], Loss: 0.3697
Epoch [3/10], Step [200/300], Loss: 0.333
Epoch [3/10], Step [300/300], Loss: 0.3481
Epoch [4/10], Step [100/300], Loss: 0.1263
Epoch [4/10], Step [200/300], Loss: 0.3218
Epoch [4/10], Step [300/300], Loss: 0.09702
Epoch [5/10], Step [100/300], Loss: 0.1656
Epoch [5/10], Step [200/300], Loss: 0.07498
Epoch [5/10], Step [300/300], Loss: 0.0914
Epoch [6/10], Step [100/300], Loss: 0.03339
Epoch [6/10], Step [200/300], Loss: 0.167
Epoch [6/10], Step [300/300], Loss: 0.1996
Epoch [7/10], Step [100/300], Loss: 0.09553
Epoch [7/10], Step [200/300], Loss: 0.07876
Epoch [7/10], Step [300/300], Loss: 0.1219
Epoch [8/10], Step [100/300], Loss: 0.05022
Epoch [8/10], Step [200/300], Loss: 0.1675
Epoch [8/10

## Test

In [97]:
with th.no_grad():
  correct, total = 0, 0
  for images, labels in test_loader:
    images, labels = encode_sequence_data(images, labels)
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    _, predicted = th.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
  accuracy = correct / total
  print(f'Accuracy of model on 10000 test images: {100 * accuracy:0.2f}%')

Accuracy of model on 10000 test images: 97.72%


## Save model

In [0]:
th.save(model.state_dict(), '/tmp/mnist_rnn.ckpt')