In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Prefer the GPU when PyTorch reports one as usable; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Display whether PyTorch can use a CUDA device in this session (the cell's value is shown as output).
torch.cuda.is_available()

True

In [3]:
# --- Input geometry ---------------------------------------------------------
# input_size is the per-time-step feature count handed to nn.LSTM.
# sequence_length is kept for reference; no visible cell reads it.
sequence_length = 28
input_size = 28
num_classes = 10

# --- Model capacity ---------------------------------------------------------
hidden_size = 256   # hidden units per direction, per layer
num_layers = 2      # stacked LSTM layers

# --- Optimisation -----------------------------------------------------------
batch_size = 64
learning_rate = 1e-3
num_epochs = 2

class LSTM(nn.Module):
  """Bidirectional, multi-layer LSTM followed by a linear classification head.

  Args:
    input_size: number of features per time step.
    hidden_size: hidden units per direction in each LSTM layer.
    num_layers: number of stacked LSTM layers.
    num_classes: length of the score vector produced for each sample.
  """

  def __init__(self,input_size,hidden_size, num_layers,num_classes):
    super().__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
    # The two directions are concatenated on the feature axis, hence hidden_size*2.
    self.fc = nn.Linear(hidden_size*2,num_classes)

  def forward(self,x):
    """Return class scores of shape (batch, num_classes).

    Args:
      x: tensor of shape (batch, seq_len, input_size) (batch_first layout).
    """
    # One initial hidden/cell slice per layer and per direction. Allocating
    # them from x keeps them on x's device and dtype, so the module does not
    # depend on a notebook-level `device` variable.
    state_shape = (self.num_layers*2, x.size(0), self.hidden_size)
    h0 = x.new_zeros(state_shape)
    c0 = x.new_zeros(state_shape)

    out, _ = self.lstm(x,(h0,c0))
    # Classify from the features at the last time step.
    # NOTE(review): for the reverse direction this position has only processed
    # the final input step; kept as-is to stay compatible with saved checkpoints.
    out = self.fc(out[:,-1,:])
    return out



In [4]:
# MNIST train/test splits converted with ToTensor; files are downloaded into dataset/ when missing.
train_dataset = datasets.MNIST(root='dataset/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='dataset/',train=False,transform=transforms.ToTensor(),download=True)
# Only the training batches are shuffled. Evaluation order does not affect the
# accuracy total, so the test loader keeps a fixed order for reproducible batches.
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [5]:
# Build the classifier from the notebook's hyperparameters, then move it to the selected device.
model = LSTM(
  input_size=input_size,
  hidden_size=hidden_size,
  num_layers=num_layers,
  num_classes=num_classes,
)
model = model.to(device)

In [9]:
# Adam updates every parameter of `model`; the loss is cross-entropy on the model's class scores.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [7]:
def save_model(state,filename='my_checkpoint.pth.tar'):
  """Serialise `state` to `filename` with torch.save.

  Args:
    state: object to store (the notebook passes a dict of state dicts).
    filename: destination path; defaults to 'my_checkpoint.pth.tar'.
  """
  torch.save(obj=state, f=filename)

def load_model(model,optimizer,checkpoint):
  """Restore `model` and `optimizer` in place from a checkpoint mapping.

  Args:
    model: object whose load_state_dict receives checkpoint['state_dict'].
    optimizer: object whose load_state_dict receives checkpoint['optimizer'].
    checkpoint: mapping holding the 'state_dict' and 'optimizer' entries.
  """
  # Model first, then optimizer.
  for target, key in ((model, 'state_dict'), (optimizer, 'optimizer')):
    target.load_state_dict(checkpoint[key])

In [11]:
# Resume from an earlier checkpoint when one exists. On a fresh run the file has
# not been written yet (the training cell creates it), so a missing file is
# reported instead of aborting the notebook. map_location remaps the stored
# tensors onto the current device, so a GPU-saved checkpoint loads on CPU too.
try:
  saved_state = torch.load('my_checkpoint.pth.tar', map_location=device)
except FileNotFoundError:
  print('my_checkpoint.pth.tar not found; continuing with the current weights')
else:
  load_model(model,optimizer,saved_state)

In [8]:
# Optimise `model` for num_epochs passes over train_loader, then write a checkpoint.
model.train()
for epoch in range(num_epochs):
  for batch_idx, (data,targets) in enumerate(train_loader):
    # squeeze(1) removes dim 1 when it has size 1 (presumably the single image
    # channel), leaving a (batch, rows, cols) tensor on the compute device.
    data = data.squeeze(1).to(device)
    targets = targets.to(device)

    # Clear gradients from the previous step before computing new ones.
    optimizer.zero_grad()

    # Forward pass and loss.
    scores = model(data)
    loss = criterion(scores,targets)

    # Backpropagate, then apply the parameter update.
    loss.backward()
    optimizer.step()

# Store the weights together with the optimizer state so load_model can restore both.
checkpoint = dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict())
save_model(checkpoint)

In [8]:
def check_accuracy(loader,model):
  """Print how many samples from `loader` the model classifies correctly.

  Args:
    loader: iterable yielding (inputs, labels) batches. Dim 1 of the inputs is
      squeezed before they are passed to the model.
    model: nn.Module returning per-class scores with classes on dim 1.

  The model is put in eval mode for the measurement and afterwards returned to
  whichever mode it was in before the call. An empty loader reports an accuracy
  of nan instead of raising ZeroDivisionError.
  """
  num_correct = 0
  num_samples = 0
  # Follow the model's own device rather than a notebook-level global; a model
  # without parameters leaves the batches where they already are.
  first_param = next(model.parameters(), None)
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.squeeze(1)
        if first_param is not None:
          x = x.to(first_param.device)
          y = y.to(first_param.device)
        scores = model(x)
        # Index of the highest score per sample is the predicted class.
        _, predictions = scores.max(1)
        num_correct += (predictions == y).sum()
        num_samples += predictions.size(0)
  finally:
    # Do not leave the caller's model stuck in eval mode.
    model.train(was_training)
  # The running total is a 0-dim tensor after the first batch; convert once here.
  num_correct = int(num_correct)
  accuracy = num_correct/num_samples if num_samples else float('nan')
  # The second number is the total sample count, not the number of misses.
  print(f'correct/total: {num_correct}/{num_samples}, accuracy: {accuracy}')

In [9]:
# Report accuracy on the held-out test split.
check_accuracy(loader=test_loader, model=model)

correct/false: 9803/10000, accuracy: 0.9803
