<a href="https://colab.research.google.com/github/ysj9909/DL_practice_from_scratch/blob/main/LSTM_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**MNIST classification using LSTM from scratch - code practice**

In [4]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import copy

In [63]:
# Hyper-parameters
num_epochs = 5
learning_rate = 0.001
batch_size = 100

In [64]:
class LSTM_scratch(nn.Module):
  def __init__(self, input_dim = 28, sequence_length = 28, hidden_dim = 128, output_dim = 10):
    self.W_f = nn.init.xavier_normal_(torch.empty(input_dim + hidden_dim, hidden_dim))
    self.b_f = torch.zeros(hidden_dim)
    self.W_i = nn.init.xavier_normal_(torch.empty(input_dim + hidden_dim, hidden_dim))
    self.b_i = torch.zeros(hidden_dim)
    self.W_o = nn.init.xavier_normal_(torch.empty(input_dim + hidden_dim, hidden_dim))
    self.b_o = torch.zeros(hidden_dim)
    self.W_C = nn.init.xavier_normal_(torch.empty(input_dim + hidden_dim, hidden_dim))
    self.b_C = torch.zeros(hidden_dim)
    
    self.W_out = nn.init.xavier_normal_(torch.empty(hidden_dim, output_dim))
    self.b_out = torch.zeros(output_dim)

    self.hidden_dim = hidden_dim
    self.seq_length = sequence_length
    self.input_dim = input_dim

  def forward_t(self, imgs, targets, lr = learning_rate, train = True):
    # imgs : (batch_size, 1, sequenc_length, input_dim)
    # targets : (batch_size, )
    batch_size = imgs.size(0)

    h = torch.zeros(batch_size, self.seq_length + 1, self.hidden_dim)
    c = torch.zeros(batch_size, self.seq_length + 1, self.hidden_dim)
    inputs = torch.zeros(batch_size, self.seq_length, self.hidden_dim + self.input_dim)
    f_gates = torch.zeros(batch_size, self.seq_length, self.hidden_dim)
    i_gates = torch.zeros_like(f_gates)
    o_gates = torch.zeros_like(f_gates)
    c_s = torch.zeros_like(f_gates)
    

    for i in range(self.seq_length ):
      inputs[:, i, :] = torch.cat([h[:, i, :], imgs[:, 0, i, :]], dim = -1)
      f_gates[:, i, :] = torch.sigmoid(torch.mm(inputs[:, i, :], self.W_f) + self.b_f)
      i_gates[:, i, :]= torch.sigmoid(torch.mm(inputs[:, i, :], self.W_i) + self.b_i)
      o_gates[:, i, :] = torch.sigmoid(torch.mm(inputs[:, i, :], self.W_o) + self.b_o)
      c_s[:, i, :] = torch.tanh(torch.mm(inputs[:, i, :], self.W_C) + self.b_C)
      c[:, i + 1, :] = f_gates[:, i, :] * c[:, i, :] + i_gates[:, i, :] * c_s[:, i, :]
      h[:, i + 1, :] = o_gates[:, i, :] * torch.tanh(c[:, i + 1, :])
    
    outputs = torch.mm(h[:, -1, :], self.W_out) + self.b_out

    if train:
      exp_out = torch.exp(outputs)
      softmax = exp_out / torch.sum(exp_out, dim = -1, keepdim = True)
      loss = 0
      dsoftmax = copy.deepcopy(softmax)
      for i, idx in enumerate(targets):
        dsoftmax[i, idx] -= 1
        loss -= torch.log(softmax[i, idx])
      loss /= batch_size
      
      db_out = torch.sum(dsoftmax, dim = 0)
      dW_out = torch.mm(h[:, -1, :].T, dsoftmax)
      dW_f = 0
      dW_i = 0
      dW_o = 0
      dW_C = 0
      db_f = 0
      db_i = 0
      db_o = 0
      db_C = 0
      dc_ = 0
      dh = torch.zeros_like(h)
      dc = torch.zeros_like(c)
    
      
      dh[:, -1, :] = torch.mm(dsoftmax, self.W_out.T)
      for i in range(-1, -self.seq_length - 1, -1):
        do_gate = dh[:, i, :] * torch.tanh(c[:, i, :])
        do_gate_in = do_gate * o_gates[:, i, :] * (1 - o_gates[:, i, :])
        dh[:, i - 1, :] += torch.mm(do_gate_in, self.W_o.T)[:, :self.hidden_dim]      
        dW_o  += torch.mm(inputs[:, i, :].T, do_gate_in)
        db_o += torch.sum(do_gate_in, dim = 0)

        dc[:, i, :] += dh[:, i, :] * o_gates[:, i, :] * (1 - torch.tanh(c[:, i, :]) ** 2)
        
        df_gate = dc[:, i, :] * c[:, i - 1, :]
        df_gate_in = df_gate * f_gates[:, i, :] * (1 - f_gates[:, i, :])
        dh[:, i - 1, :] += torch.mm(df_gate_in, self.W_f.T)[:, :self.hidden_dim]
        dW_f += torch.mm(inputs[:, i, :].T, df_gate_in)
        db_f += torch.sum(df_gate_in, dim = 0)

        dc[:, i - 1, : ] += dc[:, i, :] * f_gates[:, i, :]

        di_gate = dc[:, i, :] * c_s[:, i, :]
        di_gate_in = di_gate * i_gates[:, i, :] * (1 - i_gates[:, i, :])
        dh[:, i - 1, :] += torch.mm(di_gate_in, self.W_i.T)[:, :self.hidden_dim]
        dW_i += torch.mm(inputs[:, i, :].T, di_gate_in)
        db_i += torch.sum(di_gate_in, dim = 0)

        dc_ += dc[:, i, :] * i_gates[:, i, :]
        dc_in = dc_ * (1 - c_s[:, i, :] ** 2)
        dh[:, i -1 , :] += torch.mm(dc_in, self.W_C.T)[:, :self.hidden_dim]
        dW_C += torch.mm(inputs[:, i, :].T, dc_in)
        db_C += torch.sum(dc_in, dim = 0)
      
      # SGD
      self.W_f -= lr * dW_f
      self.W_i -= lr * dW_i
      self.W_o -= lr * dW_o
      self.W_out -= lr * dW_out
      self.W_C -= lr * dW_C
      self.b_f -= lr * db_f
      self.b_i -= lr * db_i
      self.b_o -= lr * db_o
      self.b_out -= lr * db_out
      self.b_C -= lr * db_C

      return outputs, loss
    return outputs

In [65]:
model_scratch = LSTM_scratch()

In [66]:
for epoch in range(num_epochs):
  for i, (imgs, labels) in enumerate(train_loader):
    _, loss = model_scratch.forward_t(imgs, labels)

    if (i + 1) % 100 == 0:
      print(f"Epoch[{epoch + 1} / {num_epochs}], Step [{i + 1} / {len(train_loader)}], Loss : {loss.item()}")

Epoch[1 / 5], Step [100 / 600], Loss : 2.308032751083374
Epoch[1 / 5], Step [200 / 600], Loss : 2.2962279319763184
Epoch[1 / 5], Step [300 / 600], Loss : 2.21468186378479
Epoch[1 / 5], Step [400 / 600], Loss : 2.0790462493896484
Epoch[1 / 5], Step [500 / 600], Loss : 2.0688376426696777
Epoch[1 / 5], Step [600 / 600], Loss : 1.7364228963851929
Epoch[2 / 5], Step [100 / 600], Loss : 1.6522705554962158
Epoch[2 / 5], Step [200 / 600], Loss : 1.679588794708252
Epoch[2 / 5], Step [300 / 600], Loss : 1.3982651233673096
Epoch[2 / 5], Step [400 / 600], Loss : 1.1255719661712646
Epoch[2 / 5], Step [500 / 600], Loss : 1.442572832107544
Epoch[2 / 5], Step [600 / 600], Loss : 1.1934150457382202
Epoch[3 / 5], Step [100 / 600], Loss : 1.1459053754806519
Epoch[3 / 5], Step [200 / 600], Loss : 1.2142002582550049
Epoch[3 / 5], Step [300 / 600], Loss : 1.4791487455368042
Epoch[3 / 5], Step [400 / 600], Loss : 2.148773431777954
Epoch[3 / 5], Step [500 / 600], Loss : 1.8318679332733154
Epoch[3 / 5], Step [

In [68]:
correct = 0
total = 0
for imgs, labels in test_loader:
  outputs, _ = model_scratch.forward_t(imgs,labels)
  _, predicted = torch.max(outputs, dim = -1)
  correct += (labels == predicted).sum().item()
  total += imgs.size(0)

print(f"Accuracy of the model from scratch : {100 * correct / total}")

Accuracy of the model from scratch : 67.04


**Torch Optimizer를 이용한 학습과 비교**

In [7]:
data_transforms = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.5), (0.5))])

train_dataset = torchvision.datasets.MNIST(root = "./data", train = True, download = True, transform = data_transforms)
test_dataset = torchvision.datasets.MNIST(root = "./data", train = False, transform = data_transforms)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [8]:
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = batch_size, shuffle = False)

In [13]:
class LSTM_Torch(nn.Module):
  def __init__(self, input_dim = 28, hidden_dim = 128, sequence_length = 28, num_classes = 10):
    super(LSTM_Torch, self).__init__()
    
    self.sequence_length = sequence_length
    self.hidden_dim = hidden_dim

    self.W_f = nn.Linear(input_dim + hidden_dim, hidden_dim)
    self.W_i = nn.Linear(input_dim + hidden_dim, hidden_dim)
    self.W_o = nn.Linear(input_dim + hidden_dim, hidden_dim)
    self.W_C = nn.Linear(input_dim + hidden_dim, hidden_dim)

    self.fc_output = nn.Linear(hidden_dim, num_classes)
    self.tanh = nn.Tanh()

  def forward(self, imgs):
    # imgs : (batch_size, 1, sequence_length, input_dim)
    h = torch.zeros(imgs.size(0), self.hidden_dim)
    c = torch.zeros(imgs.size(0), self.hidden_dim)
    for i in range(self.sequence_length):
      input = torch.cat([h, imgs[:,0, i, :]], dim = -1)   # (batch_size, input_dim + hidden_dim)

      input_gate = torch.sigmoid(self.W_i(input))
      forget_gate = torch.sigmoid(self.W_f(input))
      output_gate = torch.sigmoid(self.W_o(input))

      C_tilda = self.tanh(self.W_C(input))
      c = forget_gate * c + input_gate * C_tilda
      h = output_gate * self.tanh(c)
    
    outputs = self.fc_output(h)

    return outputs

In [14]:
model  = LSTM_Torch()

In [30]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

In [31]:
# Train the model
for epoch in range(num_epochs):
  for i, (imgs, labels) in enumerate(train_loader):
    outputs  = model(imgs)

    optimizer.zero_grad()
    loss = criterion(outputs, labels)

    loss.backward()
    optimizer.step()

    if (i + 1) % 100 == 0:
      print(f"Epoch [{epoch + 1} / {num_epochs}], Step [{i + 1} / {len(train_loader)}], Loss : {loss.item()}")


Epoch [1 / 5], Step [100 / 600], Loss : 2.210110664367676
Epoch [1 / 5], Step [200 / 600], Loss : 1.8703222274780273
Epoch [1 / 5], Step [300 / 600], Loss : 1.4303443431854248
Epoch [1 / 5], Step [400 / 600], Loss : 1.052319049835205
Epoch [1 / 5], Step [500 / 600], Loss : 0.8794002532958984
Epoch [1 / 5], Step [600 / 600], Loss : 0.869049072265625
Epoch [2 / 5], Step [100 / 600], Loss : 0.40183424949645996
Epoch [2 / 5], Step [200 / 600], Loss : 0.5404083132743835
Epoch [2 / 5], Step [300 / 600], Loss : 0.4130474925041199
Epoch [2 / 5], Step [400 / 600], Loss : 0.38171806931495667
Epoch [2 / 5], Step [500 / 600], Loss : 0.26108217239379883
Epoch [2 / 5], Step [600 / 600], Loss : 0.30149534344673157
Epoch [3 / 5], Step [100 / 600], Loss : 0.18814486265182495
Epoch [3 / 5], Step [200 / 600], Loss : 0.1365203708410263
Epoch [3 / 5], Step [300 / 600], Loss : 0.2682044506072998
Epoch [3 / 5], Step [400 / 600], Loss : 0.2381170690059662
Epoch [3 / 5], Step [500 / 600], Loss : 0.150743186473

In [32]:
# Test the model
with torch.no_grad():
  correct = 0
  total = 0
  for imgs, labels in test_loader:

    outputs = model(imgs)
    _, predicted = torch.max(outputs, dim = -1)

    correct += (predicted == labels).sum().item()
    total += imgs.size(0)

  print(f"Accuracy of the model on the test data : {round(100 * correct / total, 3)} ")



Accuracy of the model on the test data : 97.0 
