In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [53]:
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 15
learning_rate = 0.0001
dropout_prob = 0.5


In [4]:
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 137607316.13it/s]

Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 23210518.07it/s]


Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 43910142.77it/s]

Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3433765.10it/s]


Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw



Here we're using lstm

The LSTM uses a set of gates (input gate, forget gate, and output gate) to selectively update the hidden state. The updating process allows the model to decide what information to retain or forget from the current input and past hidden state.

The linear layer takes the hidden state of the last time step from the LSTM and transforms it to the output space.

What uniquely happens is, while each hidden state consists of weights that are updated based on previous outputs...and this helps the model retain the important information that passes on to the linear layer.

This learning process enables the model to assign higher importance (weights) to features that are relevant for the task, helping the model retain important information and make informed decisions.

LEARN MORE ABOUT RNN....

In [54]:
class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super(RNN, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, num_classes)
    self.dropout = nn.Dropout(dropout_prob)
  def forward(self, x):

    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
    c0 = torch.zeros(self.num_layers, x.size(0),self.hidden_size).to(device)

    out, _ = self.lstm(x, (h0, c0))
    out = self.dropout(out)

    out = self.fc(out[:, -1, :])
    return out

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



In [55]:
total_step = len(train_loader)
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.reshape(-1, sequence_length, input_size).to(device)
    labels = labels.to(device)

    # forward pass
    outputs= model(images)
    loss = criterion(outputs, labels)

    # backward adnd optimize
    outputs = model(images)
    loss.backward()
    optimizer.step()

    if(i+1)%100 == 0:
      print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))



Epoch [1/15], Step [100/600], Loss: 2.1955
Epoch [1/15], Step [200/600], Loss: 1.8800
Epoch [1/15], Step [300/600], Loss: 1.6854
Epoch [1/15], Step [400/600], Loss: 1.7076
Epoch [1/15], Step [500/600], Loss: 1.7873
Epoch [1/15], Step [600/600], Loss: 1.3335
Epoch [2/15], Step [100/600], Loss: 1.5284
Epoch [2/15], Step [200/600], Loss: 1.3467
Epoch [2/15], Step [300/600], Loss: 1.0718
Epoch [2/15], Step [400/600], Loss: 1.0585
Epoch [2/15], Step [500/600], Loss: 0.6676
Epoch [2/15], Step [600/600], Loss: 0.8675
Epoch [3/15], Step [100/600], Loss: 0.6389
Epoch [3/15], Step [200/600], Loss: 0.4492
Epoch [3/15], Step [300/600], Loss: 0.6394
Epoch [3/15], Step [400/600], Loss: 0.7350
Epoch [3/15], Step [500/600], Loss: 0.6075
Epoch [3/15], Step [600/600], Loss: 0.5168
Epoch [4/15], Step [100/600], Loss: 0.8979
Epoch [4/15], Step [200/600], Loss: 0.6162
Epoch [4/15], Step [300/600], Loss: 0.3616
Epoch [4/15], Step [400/600], Loss: 0.6315
Epoch [4/15], Step [500/600], Loss: 0.4583
Epoch [4/15

In [56]:
model.eval()
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.reshape(-1, sequence_length, input_size).to(device)
    labels = labels.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))


Test Accuracy of the model on the 10000 test images: 97.44 %
