In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable as V
import torchvision.transforms as transforms
import torchvision.datasets as dsets

In [3]:
train_dataset = dsets.MNIST(root = './data', train = True, transform = transforms.ToTensor(), download = True)

test_dataset = dsets.MNIST(root='./data', train = False, transform = transforms.ToTensor())

In [4]:
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)

test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = batch_size, shuffle = False)

In [8]:
class RNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first = True, nonlinearity='relu')
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        h0 = V(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))
        
        out, hn = self.rnn(x, h0)
        out = self.fc(out[:,-1,:])
        
        return out
        

In [9]:
input_dim = 28
hidden_dim = 100
layer_dim = 1
output_dim = 10

In [10]:
model = RNN(input_dim, hidden_dim, layer_dim, output_dim)

In [11]:
criterion = nn.CrossEntropyLoss()

In [13]:
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr = lr)

In [17]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 5
seq_dim = 28
it = 0
for e in range(epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = V(images.view(-1, seq_dim, input_dim).to(device))
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(images)
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        optimizer.step()
        
        it+=1
        
        if it%500 == 0:
            correct=0
            total=0
            for images, labels in test_loader:
                images = V(images.view(-1, seq_dim, input_dim).to(device))
                outputs = model(images)
                
                _, predicted = torch.max(outputs.data, 1)
                if torch.cuda.is_available():
                    correct+= (predicted.cpu() == labels.cpu()).sum()
                else:
                    correct+= (predicted == labels).sum()
                    
                total+=labels.size(0)
                
            accuracy = 100 * correct/total
            print("iter: {} loss: {} accuracy: {}".format(it,loss.data,accuracy))

iter: 500 loss: 0.9657265543937683 accuracy: 65
iter: 1000 loss: 0.9033140540122986 accuracy: 76
iter: 1500 loss: 0.4439593553543091 accuracy: 81
iter: 2000 loss: 0.5289924144744873 accuracy: 84
iter: 2500 loss: 0.46200478076934814 accuracy: 86
iter: 3000 loss: 0.5108069181442261 accuracy: 89
