In [1]:
# import libraries
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter
from utils import device, get_num_correct, RunBuilder
from rnns import RNN, GRU, LSTM, BLSTM

In [2]:
# declare hyperparameters
lr = 0.001
batch_size = 64
input_size = 28
hidden_size = 256
num_layers = 2
num_epochs = 8

In [3]:
# extract and transform the data
train_set = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
test_set = torchvision.datasets.MNIST(
    root='./data/',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
# prepare the data loaders
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=1)

In [4]:
# make an OrderedDict of RNNs
runs = OrderedDict(
    models = [
        RNN(input_size, hidden_size, num_layers),
        GRU(input_size, hidden_size, num_layers),
        LSTM(input_size, hidden_size, num_layers),
        BLSTM(input_size, hidden_size, num_layers)
    ]
)

In [5]:
# loss function (categorical cross-entropy)
criterion = nn.CrossEntropyLoss()

# iterate models in runs and train
for run in RunBuilder.get_runs(runs):
    model = run.models.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    net = type(model).__name__
    comment = f'-{net}'
    tb = SummaryWriter(comment=comment)

    for epoch in range(num_epochs):
        train_loss, train_correct = 0, 0

        model.train()
        train_loop = tqdm(train_loader)
        for batch in train_loop:
            images, labels = batch[0].squeeze(1).to(device), batch[1].to(device)
            preds = model(images)
            loss = criterion(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * labels.size(0)
            train_correct += get_num_correct(preds, labels)

            train_loop.set_description(f'{net:6s}Epoch [{epoch+1:2d}/{num_epochs}]')
            train_loop.set_postfix(loss=train_loss, acc=train_correct/len(train_set))


        tb.add_scalar('Train Loss', train_loss, epoch)
        tb.add_scalar('Train Accuracy', train_correct/len(train_set), epoch)
        

        model.eval()
        with torch.no_grad():
            test_loss, test_correct = 0, 0

            for batch in test_loader:
                images, labels = batch[0].squeeze(1).to(device), batch[1].to(device)
                preds = model(images)
                loss = criterion(preds, labels)

                test_loss += loss.item() * labels.size(0)
                test_correct += get_num_correct(preds, labels)

            tb.add_scalar('Test Loss', test_loss, epoch)
            tb.add_scalar('Test Accuracy', test_correct / len(test_set), epoch)


        for name, weight in model.named_parameters():
            tb.add_histogram(name, weight, epoch)
            tb.add_histogram(f'{name}.grad', weight.grad, epoch)

    torch.save(model.state_dict(), f'./models/with_rnns/model{comment}.ckpt')

RNN   Epoch [ 1/8]: 100%|██████████| 938/938 [00:13<00:00, 67.41it/s, acc=0.821, loss=3.32e+4]
RNN   Epoch [ 2/8]: 100%|██████████| 938/938 [00:13<00:00, 71.10it/s, acc=0.926, loss=1.55e+4]
RNN   Epoch [ 3/8]: 100%|██████████| 938/938 [00:13<00:00, 69.91it/s, acc=0.938, loss=1.32e+4]
RNN   Epoch [ 4/8]: 100%|██████████| 938/938 [00:13<00:00, 68.04it/s, acc=0.947, loss=1.12e+4]
RNN   Epoch [ 5/8]: 100%|██████████| 938/938 [00:13<00:00, 67.02it/s, acc=0.949, loss=1.07e+4]
RNN   Epoch [ 6/8]: 100%|██████████| 938/938 [00:13<00:00, 67.76it/s, acc=0.96, loss=8.58e+3]
RNN   Epoch [ 7/8]: 100%|██████████| 938/938 [00:13<00:00, 68.67it/s, acc=0.96, loss=8.48e+3]
RNN   Epoch [ 8/8]: 100%|██████████| 938/938 [00:13<00:00, 70.39it/s, acc=0.952, loss=1e+4]
GRU   Epoch [ 1/8]: 100%|██████████| 938/938 [00:14<00:00, 66.80it/s, acc=0.895, loss=1.89e+4]
GRU   Epoch [ 2/8]: 100%|██████████| 938/938 [00:14<00:00, 66.33it/s, acc=0.977, loss=4.38e+3]
GRU   Epoch [ 3/8]: 100%|██████████| 938/938 [00:14<00: