In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (1.0,))
])

In [3]:
from torchvision.datasets import MNIST

download_root = '.data/MNIST_DATASET'

train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)
valid_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)
test_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)

In [4]:
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [5]:
n_iters = 6000
num_epochs = n_iters / (len(train_dataset) / 100)
num_epochs = int(num_epochs)
num_epochs

10

In [6]:
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()
        
    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)
    
    def forward(self, x , hidden):
        hx, cx = hidden
        x = x.view(-1, x.size(1))
        
        gates = self.x2h(x) + self.h2h(hx)
        gates = gates.squeeze()
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        
        cy = torch.mul(cx, forgetgate) +  torch.mul(ingate, cellgate)        
        hy = torch.mul(outgate, F.tanh(cy))   
        return (hy, cy)

In [7]:
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias=True):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        
        self.lstm = LSTMCell(input_dim, hidden_dim, layer_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(device)
        h0.requires_grad_(True)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(device)
        c0.requires_grad_(True)
        
        outs = []
        cn= c0[0,:,:]
        hn = h0[0,:,:]
        
        for seq in range(x.size(1)):
            hn, cn = self.lstm(x[:,seq,:], (hn,cn))
            outs.append(hn)
            
        out = outs[-1].squeeze()
        out = self.fc(out)
        return out

In [8]:
input_dim = 28
hidden_dim = 128
layer_dim = 1
output_dim = 10

model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
model.to(device)
criterion = nn.CrossEntropyLoss()
lr = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [9]:
seq_dim = 28 
loss_list = []
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):          
        images = images.view(-1, seq_dim, input_dim).to(device)
        labels = labels.to(device)
        images.requires_grad_(True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()        
        loss_list.append(loss.item())
        iter += 1
        
        if iter % 500 == 0:         
            correct = 0
            total = 0
            for images, labels in valid_loader:
                images = images.view(-1, seq_dim, input_dim).to(device)

                
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted.cpu() == labels.cpu()).sum()
            
            accuracy = 100 * correct / total
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

Iteration: 500. Loss: 2.265587091445923. Accuracy: 18.59000015258789
Iteration: 1000. Loss: 0.705100953578949. Accuracy: 70.7300033569336
Iteration: 1500. Loss: 0.34851381182670593. Accuracy: 83.30000305175781
Iteration: 2000. Loss: 0.15733450651168823. Accuracy: 93.16999816894531
Iteration: 2500. Loss: 0.1560400277376175. Accuracy: 94.79000091552734
Iteration: 3000. Loss: 0.04884061589837074. Accuracy: 96.08999633789062
Iteration: 3500. Loss: 0.020289091393351555. Accuracy: 96.37000274658203
Iteration: 4000. Loss: 0.36019933223724365. Accuracy: 96.72000122070312
Iteration: 4500. Loss: 0.14839287102222443. Accuracy: 96.37000274658203
Iteration: 5000. Loss: 0.07194089889526367. Accuracy: 96.83999633789062
Iteration: 5500. Loss: 0.09262465685606003. Accuracy: 97.31999969482422
Iteration: 6000. Loss: 0.05349844694137573. Accuracy: 96.87000274658203
Iteration: 6500. Loss: 0.133072167634964. Accuracy: 97.66000366210938
Iteration: 7000. Loss: 0.09087693691253662. Accuracy: 97.73999786376953


In [10]:
def evaluate(model, val_iter):    
    corrects, total, total_loss = 0, 0, 0
    model.eval()
    for images, labels in val_iter:
        images = images.view(-1, seq_dim, input_dim).to(device)
        labels = labels.to(device)
        
        logit = model(images)
        loss = F.cross_entropy(logit, labels, reduction = "sum")
        _, predicted = torch.max(logit.data, 1)
        total += labels.size(0)
        total_loss += loss.item()
        corrects += (predicted == labels).sum().item()

    avg_loss = total_loss / len(val_iter.dataset)
    avg_accuracy = corrects / total
    return avg_loss, avg_accuracy

In [11]:
test_loss, test_acc = evaluate(model,test_loader)
print("Test Loss: %5.2f | Test Accuracy: %5.2f" % (test_loss, test_acc))

Test Loss:  0.07 | Test Accuracy:  0.98
