## Install

In [0]:
!pip3 install torch torchvision numpy

## Imports

In [0]:
from matplotlib import pyplot as plt
import numpy as np

import torch as th
from torch import nn
import torchvision
from torchvision import transforms

## Config

In [5]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = th.device('cuda' if th.cuda.is_available() else 'cpu')
print(f'Using {device}')

# Seed PyTorch's global RNG so weight initialisation and batch shuffling start
# from the same state on every Restart & Run All. GPU kernels may still add a
# small amount of run-to-run variation.
seed = 0
th.manual_seed(seed)

# Model hyper-parameters. Every image is reshaped to
# (sequence_length, input_size) before it is fed to the network, i.e. one
# 28-pixel row per time step.
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10

# Optimisation hyper-parameters.
num_epochs = 2
batch_size = 100
learning_rate = 0.003

Using cuda


## MNIST Dataset

In [0]:
# Fetch MNIST (downloading it on first use) and build the two splits.
# Both splits share the same storage root and the same tensor conversion,
# so the common arguments are collected once.
mnist_kwargs = dict(root='~/code/data/mnist/',
                    transform=transforms.ToTensor(),
                    download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Input pipeline: mini-batch iterators over both splits. Only the training
# split is shuffled; the test split is read in its stored order.
train_loader = th.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)
test_loader = th.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False)

## Model (Bi-directional RNN)

In [0]:
class BiRNN(nn.Module):
  """Bidirectional multi-layer LSTM followed by a linear classifier.

  Args:
    input_size: Number of features per time step.
    hidden_size: Hidden units per direction in each LSTM layer.
    num_layers: Number of stacked LSTM layers.
    num_classes: Size of the output (logit) vector.
  """

  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super().__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                        bidirectional=True)
    self.fc = nn.Linear(hidden_size * 2, num_classes)  # 2 for bidirectional

  def forward(self, x):
    """Return class logits of shape (batch, num_classes).

    Args:
      x: Input of shape (batch, seq_len, input_size).
    """
    # Zero initial hidden and cell states; x2 for the two directions. They are
    # created directly with the input's device and dtype, so the module does
    # not depend on a notebook-global `device` and also works after
    # `.double()` / `.half()` casts.
    state_shape = (self.num_layers * 2, x.size(0), self.hidden_size)
    h0 = th.zeros(state_shape, device=x.device, dtype=x.dtype)
    c0 = th.zeros(state_shape, device=x.device, dtype=x.dtype)

    # Forward propagate LSTM.
    out, _ = self.lstm(x, (h0, c0))  # out shape (batch_size, seq_len, hid_size * 2)

    # Decode the hidden state of the last time step.
    # NOTE: at this index the backward half of `out` has only processed the
    # final input step, while the forward half has seen the whole sequence.
    # The readout is kept unchanged so existing checkpoints behave the same.
    out = self.fc(out[:, -1, :])
    return out

# Build the network from the config values, then move its parameters to the
# selected device.
model = BiRNN(input_size=input_size,
              hidden_size=hidden_size,
              num_layers=num_layers,
              num_classes=num_classes)
model = model.to(device)

## Train

In [10]:
# Loss and optimizer.
loss_fn = nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=learning_rate)

log_every = 100  # Report the current batch loss once per this many steps.
num_steps = len(train_loader)

# Make the mode explicit: a freshly built module is already in training mode,
# but this keeps the cell correct if the model was switched to eval() elsewhere.
model.train()
for epoch in range(num_epochs):
  for step, (images, labels) in enumerate(train_loader):
    # Flatten the channel axis away: each image becomes a sequence of
    # `sequence_length` steps with `input_size` features each.
    images = images.reshape(-1, sequence_length, input_size)
    images = images.to(device)
    labels = labels.to(device)

    # Forward
    outputs = model(images)
    loss = loss_fn(outputs, labels)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % log_every == 0:
      # `.4f` prints a fixed four decimals; the previous `.4` meant four
      # significant digits and produced ragged values (0.176 vs 0.07829).
      print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{num_steps}], '
            f'Loss: {loss.item():.4f}')

Epoch [1/2], Step [100/600], Loss: 0.6131
Epoch [1/2], Step [200/600], Loss: 0.3457
Epoch [1/2], Step [300/600], Loss: 0.176
Epoch [1/2], Step [400/600], Loss: 0.1556
Epoch [1/2], Step [500/600], Loss: 0.1344
Epoch [1/2], Step [600/600], Loss: 0.1192
Epoch [2/2], Step [100/600], Loss: 0.1553
Epoch [2/2], Step [200/600], Loss: 0.0965
Epoch [2/2], Step [300/600], Loss: 0.07829
Epoch [2/2], Step [400/600], Loss: 0.132
Epoch [2/2], Step [500/600], Loss: 0.0866
Epoch [2/2], Step [600/600], Loss: 0.09273


## Test

In [12]:
# Evaluate in inference mode, then put the model back into whatever mode it
# was in so a later training run is unaffected.
was_training = model.training
model.eval()
try:
  with th.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
      images = images.reshape(-1, sequence_length, input_size)
      images = images.to(device)
      labels = labels.to(device)
      outputs = model(images)
      # Autograd is already disabled here, so the legacy `.data` detour is
      # unnecessary; take the index of the largest logit per sample.
      _, predicted = th.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
finally:
  model.train(was_training)

accuracy = correct / total
# Report the number of images actually evaluated instead of a hard-coded count.
print(f'Accuracy of model on {total} test images: {100 * accuracy:0.2f}%')

Accuracy of model on 10000 test images: 97.82%


## Save model

In [0]:
th.save(model.state_dict(), '/tmp/mnist_birnn.ckpt')