In [49]:
import torch

# 0. Hyperparameters

In [50]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's random number generator so that weight initialisation and the
# shuffled batch order are repeatable after "Restart Kernel -> Run All".
# NOTE: some GPU kernels may still be non-deterministic, so CUDA runs can
# differ slightly between executions.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
# Each 28x28 MNIST image is fed to the recurrent network as a sequence of
# `sequence_length` rows, every row being one time step of `input_size` pixels.
sequence_length = 28
input_size = 28
hidden_size = 128   # width of the hidden state in every recurrent layer
num_layers = 10     # number of stacked recurrent layers
num_classes = 10    # digits 0-9
batch_size = 100
num_epochs = 8

In [51]:
import torchvision
import torchvision.transforms as transforms

# 1. Data Load

In [52]:
# Build both MNIST splits under ./datasets. download=True fetches the files
# when they are not already present there, and ToTensor() converts every image
# to a tensor as it is read.
train_data = torchvision.datasets.MNIST(
    './datasets', train=True, transform=transforms.ToTensor(), download=True)
test_data = torchvision.datasets.MNIST(
    './datasets', train=False, transform=transforms.ToTensor(), download=True)

# 2. Define Dataloader

In [53]:
# Training batches are reshuffled every epoch; test batches keep dataset order.
# Both loaders yield batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [54]:
# Pull one batch from the test loader (images in `a`, labels in `b`) and show
# the image tensor's shape as a sanity check of the loader output.
a, b = next(iter(test_loader))
a.size()

torch.Size([100, 1, 28, 28])

# 3. Define Models

In [55]:
import torch.nn as nn
import torch.nn.functional as F

In [56]:
class Mnist(nn.Module):
  """Recurrent sequence classifier that can run as a vanilla RNN, an LSTM or a GRU.

  All three recurrent stacks are always constructed, so the module's
  ``state_dict`` has the same keys whichever variant is selected; ``model``
  only decides which stack ``forward`` runs.

  Args:
    input_size: number of features per time step.
    hidden_size: size of the hidden state of every recurrent layer.
    num_layers: number of stacked recurrent layers.
    num_classes: number of outputs of the final linear layer.
    model: which recurrent stack to use: "RNN", "LSTM" or "GRU".
    drop_percent: probability handed to ``nn.Dropout``. The dropout layer is
      created but is not applied anywhere in ``forward`` at the moment.
  """

  SUPPORTED_MODELS = ("RNN", "LSTM", "GRU")

  def __init__(self, input_size, hidden_size, num_layers, num_classes, model="RNN", drop_percent=0.2):
    super(Mnist, self).__init__()

    self.model = model
    self.hidden_size = hidden_size
    self.num_layers = num_layers

    self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self.dropout = nn.Dropout(drop_percent)

    self.fc = nn.Linear(hidden_size, num_classes)


  def forward(self, x):
    """Classify a batch of sequences.

    Args:
      x: batch-first input of shape (batch, seq_len, input_size).

    Returns:
      Tensor of shape (batch, num_classes): the linear layer applied to the
      recurrent output at the last time step.

    Raises:
      ValueError: if ``self.model`` is not one of "RNN", "LSTM" or "GRU".
    """
    if self.model not in self.SUPPORTED_MODELS:
      raise ValueError(f"choose a model in ['RNN', 'LSTM', 'GRU'], got {self.model!r}")

    if self.model == "GRU":
      # No explicit initial state: nn.GRU starts from zeros by default.
      out, hidden = self.gru(x)
    else:
      # Zero initial state of shape (num_layers, batch, hidden_size), created
      # on the input's own device/dtype rather than on a notebook global.
      h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
      if self.model == "RNN":
        out, hidden = self.rnn(x, h0)
      else:
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        out, (hidden, cell_state) = self.lstm(x, (h0, c0))

    # Only the output of the final time step feeds the classifier.
    out = self.fc(out[:,-1,:])
    return out



# 4-1. RNN Model Set Loss / Optimizer

In [57]:
# Vanilla-RNN variant of the classifier, trained with Adam on cross-entropy.
learning_rate = 1e-4

rnn = Mnist(input_size, hidden_size, num_layers, num_classes, model='RNN')
rnn = rnn.to(device)  # Module.to moves the parameters and returns the module itself
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# 5-1. RNN Train

In [58]:
# RNN

####### Train #######
total_step = len(train_loader)  # number of batches per epoch
for epoch in range(num_epochs):
  for i, (image, label) in enumerate(train_loader):
    # Reshape every image in the batch into sequence_length rows of input_size pixels.
    image = image.reshape(-1, sequence_length, input_size).to(device)
    label = label.to(device)

    # Clear the gradients of the previous step before computing new ones
    optimizer.zero_grad()

    # Forward pass and loss
    output = rnn(image)
    loss = criterion(output, label)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Report every 400th step and at the last step of each epoch
    step_no = i + 1
    if step_no == total_step or step_no % 400 == 0:
      print(f"Epoch [{epoch+1}/{num_epochs}], Step[{i+1}/{total_step}], Loss:{loss.item():.4f}")

Epoch [1/8], Step[400/600], Loss:0.3970
Epoch [1/8], Step[600/600], Loss:0.2683
Epoch [2/8], Step[400/600], Loss:0.1387
Epoch [2/8], Step[600/600], Loss:0.1556
Epoch [3/8], Step[400/600], Loss:0.1144
Epoch [3/8], Step[600/600], Loss:0.0453
Epoch [4/8], Step[400/600], Loss:0.0736
Epoch [4/8], Step[600/600], Loss:0.1011
Epoch [5/8], Step[400/600], Loss:0.0828
Epoch [5/8], Step[600/600], Loss:0.1495
Epoch [6/8], Step[400/600], Loss:0.0540
Epoch [6/8], Step[600/600], Loss:0.0636
Epoch [7/8], Step[400/600], Loss:0.1969
Epoch [7/8], Step[600/600], Loss:0.0939
Epoch [8/8], Step[400/600], Loss:0.0451
Epoch [8/8], Step[600/600], Loss:0.0692


# 6-1. RNN Test

In [59]:
######## TEST ########
# Measure accuracy in evaluation mode so that layers such as dropout behave
# deterministically, then put the model back into whatever mode it was in.
was_training = rnn.training
rnn.eval()

correct = 0
with torch.no_grad():  # no gradients are needed for inference
  for image, label in test_loader:
    image = image.reshape(-1, sequence_length, input_size).to(device)
    label = label.to(device)
    output = rnn(image)
    # The predicted class is the index of the largest output per sample
    _, pred = torch.max(output, 1)
    correct += (pred == label).sum().item()

rnn.train(was_training)

print(f'Test Accuracy of {rnn.model} model on the {len(test_data)} test images: {100 * correct / len(test_data)}%')

Test Accuracy of RNN model on the 10000 test images: 97.91%


# 4-2. LSTM Model Set Loss / Optimizer

In [60]:
# LSTM variant of the classifier, trained with Adam on cross-entropy.
learning_rate = 2e-4

lstm = Mnist(input_size, hidden_size, num_layers, num_classes, model='LSTM')
lstm = lstm.to(device)  # Module.to moves the parameters and returns the module itself
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# 5-2. LSTM Train

In [61]:
# LSTM

####### Train #######
total_step = len(train_loader)  # number of batches per epoch
for epoch in range(num_epochs):
  for i, (image, label) in enumerate(train_loader):
    # Reshape every image in the batch into sequence_length rows of input_size pixels.
    image = image.reshape(-1, sequence_length, input_size).to(device)
    label = label.to(device)

    # Clear the gradients of the previous step before computing new ones
    optimizer.zero_grad()

    # Forward pass and loss
    output = lstm(image)
    loss = criterion(output, label)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Report every 400th step and at the last step of each epoch
    step_no = i + 1
    if step_no == total_step or step_no % 400 == 0:
      print(f"Epoch [{epoch+1}/{num_epochs}], Step[{i+1}/{total_step}], Loss:{loss.item():.4f}")

Epoch [1/8], Step[400/600], Loss:0.8035
Epoch [1/8], Step[600/600], Loss:0.5112
Epoch [2/8], Step[400/600], Loss:0.2986
Epoch [2/8], Step[600/600], Loss:0.1427
Epoch [3/8], Step[400/600], Loss:0.1825
Epoch [3/8], Step[600/600], Loss:0.2486
Epoch [4/8], Step[400/600], Loss:0.1815
Epoch [4/8], Step[600/600], Loss:0.1275
Epoch [5/8], Step[400/600], Loss:0.0421
Epoch [5/8], Step[600/600], Loss:0.1701
Epoch [6/8], Step[400/600], Loss:0.0932
Epoch [6/8], Step[600/600], Loss:0.1385
Epoch [7/8], Step[400/600], Loss:0.0547
Epoch [7/8], Step[600/600], Loss:0.1600
Epoch [8/8], Step[400/600], Loss:0.1150
Epoch [8/8], Step[600/600], Loss:0.0661


# 6-2. LSTM Test

In [62]:
######## TEST ########
# Measure accuracy in evaluation mode so that layers such as dropout behave
# deterministically, then put the model back into whatever mode it was in.
was_training = lstm.training
lstm.eval()

correct = 0
with torch.no_grad():  # no gradients are needed for inference
  for image, label in test_loader:
    image = image.reshape(-1, sequence_length, input_size).to(device)
    label = label.to(device)
    output = lstm(image)
    # The predicted class is the index of the largest output per sample
    _, pred = torch.max(output, 1)
    correct += (pred == label).sum().item()

lstm.train(was_training)

print(f'Test Accuracy of {lstm.model} model on the {len(test_data)} test images: {100 * correct / len(test_data)}%')

Test Accuracy of LSTM model on the 10000 test images: 97.7%


# 4-3. GRU Model Set Loss / Optimizer

In [63]:
# GRU variant of the classifier, trained with Adam on cross-entropy.
learning_rate = 1e-4

gru = Mnist(input_size, hidden_size, num_layers, num_classes, model='GRU')
gru = gru.to(device)  # Module.to moves the parameters and returns the module itself
optimizer = torch.optim.Adam(gru.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# 5-3. GRU Train

In [64]:
# GRU

####### Train #######
total_step = len(train_loader)  # number of batches per epoch
for epoch in range(num_epochs):
  for i, (image, label) in enumerate(train_loader):
    # Reshape every image in the batch into sequence_length rows of input_size pixels.
    image = image.reshape(-1, sequence_length, input_size).to(device)
    label = label.to(device)

    # Clear the gradients of the previous step before computing new ones
    optimizer.zero_grad()

    # Forward pass and loss
    output = gru(image)
    loss = criterion(output, label)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Report every 400th step and at the last step of each epoch
    step_no = i + 1
    if step_no == total_step or step_no % 400 == 0:
      print(f"Epoch [{epoch+1}/{num_epochs}], Step[{i+1}/{total_step}], Loss:{loss.item():.4f}")

Epoch [1/8], Step[400/600], Loss:0.6124
Epoch [1/8], Step[600/600], Loss:0.4126
Epoch [2/8], Step[400/600], Loss:0.2633
Epoch [2/8], Step[600/600], Loss:0.2179
Epoch [3/8], Step[400/600], Loss:0.1759
Epoch [3/8], Step[600/600], Loss:0.0791
Epoch [4/8], Step[400/600], Loss:0.1760
Epoch [4/8], Step[600/600], Loss:0.1398
Epoch [5/8], Step[400/600], Loss:0.0949
Epoch [5/8], Step[600/600], Loss:0.0858
Epoch [6/8], Step[400/600], Loss:0.0325
Epoch [6/8], Step[600/600], Loss:0.0635
Epoch [7/8], Step[400/600], Loss:0.0232
Epoch [7/8], Step[600/600], Loss:0.0979
Epoch [8/8], Step[400/600], Loss:0.0703
Epoch [8/8], Step[600/600], Loss:0.0300


# 6-3. GRU Test

In [65]:
######## TEST ########
# Measure accuracy in evaluation mode so that layers such as dropout behave
# deterministically, then put the model back into whatever mode it was in.
was_training = gru.training
gru.eval()

correct = 0
with torch.no_grad():  # no gradients are needed for inference
  for image, label in test_loader:
    image = image.reshape(-1, sequence_length, input_size).to(device)
    label = label.to(device)
    output = gru(image)
    # The predicted class is the index of the largest output per sample
    _, pred = torch.max(output, 1)
    correct += (pred == label).sum().item()

gru.train(was_training)

print(f'Test Accuracy of {gru.model} model on the {len(test_data)} test images: {100 * correct / len(test_data)}%')

Test Accuracy of GRU model on the 10000 test images: 97.65%


# 7. Save Model

In [66]:
# Persist only the learned parameters (state_dict) of each trained model,
# one checkpoint file per model in the current working directory.
for net, path in ((rnn, 'rnn.pth'), (lstm, 'lstm.pth'), (gru, 'gru.pth')):
  torch.save(net.state_dict(), path)

# 8. Load Model and Test All

In [69]:
# Rebuild each architecture and restore its saved parameters.
# map_location=device remaps the stored tensors onto the device in use, so
# checkpoints written during a CUDA run still load on a CPU-only machine.
model1 = Mnist(input_size, hidden_size, num_layers, num_classes, model='RNN').to(device)
model1.load_state_dict(torch.load("rnn.pth", map_location=device))

model2 = Mnist(input_size, hidden_size, num_layers, num_classes, model='LSTM').to(device)
model2.load_state_dict(torch.load("lstm.pth", map_location=device))

model3 = Mnist(input_size, hidden_size, num_layers, num_classes, model='GRU').to(device)
model3.load_state_dict(torch.load("gru.pth", map_location=device))

<All keys matched successfully>

In [70]:
######## TEST ########
def count_correct(net):
  """Return how many test images `net` classifies correctly.

  Runs `net` over every batch of `test_loader` with gradients disabled and in
  evaluation mode, then restores the training/evaluation mode `net` had before
  the call.

  Args:
    net: model taking a (batch, sequence_length, input_size) tensor and
      returning one score per class for every sample.

  Returns:
    Number of samples whose highest-scoring class equals the label.
  """
  was_training = net.training
  net.eval()

  hits = 0
  with torch.no_grad():
    for image, label in test_loader:
      image = image.reshape(-1, sequence_length, input_size).to(device)
      label = label.to(device)
      output = net(image)
      # The predicted class is the index of the largest output per sample
      _, pred = torch.max(output, 1)
      hits += (pred == label).sum().item()

  net.train(was_training)
  return hits


# Evaluate the three reloaded models with the same procedure.
for net in (model1, model2, model3):
  correct = count_correct(net)
  print(f'Test Accuracy of {net.model} model on the {len(test_data)} test images: {100 * correct / len(test_data)}%')

Test Accuracy of RNN model on the 10000 test images: 97.92%
Test Accuracy of LSTM model on the 10000 test images: 97.67%
Test Accuracy of GRU model on the 10000 test images: 97.91%
