# LSTM, GRU
1. LSTM과 GRU가 기존 RNN과 어떻게 다른지 배우자.
2. 이전 실습에 이어 다양한 적용법(LM task, 양방향 및 다층 구조)을 배우자.

## 필요 패키지 import

In [1]:
from tqdm import tqdm
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torch

## 데이터 전처리
아래의 sample data를 확인
이전 실습과 동일

In [2]:
# Toy corpus: 10 variable-length sequences of token ids.
# Id 0 never appears below; it is reserved as the padding id.
vocab_size = 100
pad_id = 0  # id used to right-pad shorter sequences to a common length

data = [
    [85, 14, 80, 34, 99, 20, 31, 65, 53, 86, 3, 58, 30, 4, 11, 6, 50, 71, 74, 13],
    [62, 76, 79, 66, 32],
    [93, 77, 16, 67, 46, 74, 24, 70],
    [19, 83, 88, 22, 57, 40, 75, 82, 4, 46],
    [70, 28, 30, 24, 76, 84, 92, 76, 77, 51, 7, 20, 82, 94, 57],
    [58, 13, 40, 61, 88, 18, 92, 89, 8, 14, 61, 67, 49, 59, 45, 12, 47, 5],
    [22, 5, 21, 84, 39, 6, 9, 84, 36, 59, 32, 30, 69, 70, 82, 56, 1],
    [94, 21, 79, 24, 3, 86],
    [80, 80, 33, 63, 34, 63],
    [87, 32, 79, 65, 2, 96, 43, 80, 85, 20, 41, 52, 95, 50, 35, 96, 24, 80],
]

In [3]:
# Right-pad every sequence with pad_id up to the longest length, recording the
# unpadded length of each sequence in valid_lens (same order as data).
max_len = max(len(seq) for seq in data)
print(f"Maximum sequence length: {max_len}")

valid_lens = []
for idx, seq in enumerate(tqdm(data)):
    seq_len = len(seq)
    valid_lens.append(seq_len)
    if seq_len < max_len:
        data[idx] = seq + [pad_id] * (max_len - seq_len)
    

100%|██████████| 10/10 [00:00<00:00, 5636.75it/s]

Maximum sequence length: 20





In [4]:
# Build the (B, L) id tensor and reorder rows by valid length, longest first:
# pack_padded_sequence expects this order by default (enforce_sorted=True).
lengths = torch.LongTensor(valid_lens)
batch_lens, sorted_idx = torch.sort(lengths, descending=True)
batch = torch.LongTensor(data).index_select(0, sorted_idx)

print(batch)
print(batch_lens)

tensor([[85, 14, 80, 34, 99, 20, 31, 65, 53, 86,  3, 58, 30,  4, 11,  6, 50, 71,
         74, 13],
        [58, 13, 40, 61, 88, 18, 92, 89,  8, 14, 61, 67, 49, 59, 45, 12, 47,  5,
          0,  0],
        [87, 32, 79, 65,  2, 96, 43, 80, 85, 20, 41, 52, 95, 50, 35, 96, 24, 80,
          0,  0],
        [22,  5, 21, 84, 39,  6,  9, 84, 36, 59, 32, 30, 69, 70, 82, 56,  1,  0,
          0,  0],
        [70, 28, 30, 24, 76, 84, 92, 76, 77, 51,  7, 20, 82, 94, 57,  0,  0,  0,
          0,  0],
        [19, 83, 88, 22, 57, 40, 75, 82,  4, 46,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [93, 77, 16, 67, 46, 74, 24, 70,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [94, 21, 79, 24,  3, 86,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [80, 80, 33, 63, 34, 63,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [62, 76, 79, 66, 32,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]])
tensor([2

## LSTM 사용
LSTM에서는 hidden state 외에 cell state가 추가됩니다.
cell state의 shape는 hidden state와 동일한 `(num_layers * num_dirs, B, d_h)`입니다.

In [8]:
embedding_size = 256
hidden_size = 512
num_layers = 1
num_dirs = 1  # 1 = unidirectional, 2 = bidirectional

embedding = nn.Embedding(vocab_size, embedding_size)
lstm = nn.LSTM(
    input_size=embedding_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bidirectional=num_dirs > 1,  # the comparison is already a bool
)

# Initial states: (num_layers * num_dirs, B, d_h).
# The LSTM cell state has exactly the same shape as the hidden state.
h_0 = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size))
c_0 = torch.zeros_like(h_0)

In [10]:
# Embed ids: (B, L) -> (B, L, d_emb), then switch to time-major (L, B, d_emb)
# because the LSTM was built with the default batch_first=False.
batch_emb = embedding(batch)
packed_batch = pack_padded_sequence(batch_emb.transpose(0, 1), batch_lens)

# Feeding a PackedSequence makes the LSTM skip padded positions.
packed_outputs, (h_n, c_n) = lstm(packed_batch, (h_0, c_0))

print(packed_outputs)
print(packed_outputs.data.shape)  # (total valid tokens, d_h)
print(h_n.shape)
print(c_n.shape)

PackedSequence(data=tensor([[ 0.1140, -0.0754,  0.0414,  ...,  0.0184, -0.1301,  0.1099],
        [-0.0372, -0.0944,  0.1062,  ...,  0.0229,  0.0084,  0.0609],
        [-0.0436,  0.1051,  0.1014,  ..., -0.0444,  0.2490, -0.1016],
        ...,
        [ 0.0047,  0.1545,  0.2151,  ..., -0.1920,  0.0463,  0.0955],
        [ 0.1495, -0.0825, -0.1940,  ..., -0.0950,  0.0918,  0.0778],
        [ 0.1905, -0.0985, -0.1822,  ..., -0.0274,  0.0376,  0.1549]],
       grad_fn=<CatBackward>), batch_sizes=tensor([10, 10, 10, 10, 10,  9,  7,  7,  6,  6,  5,  5,  5,  5,  5,  4,  4,  3,
         1,  1]), sorted_indices=None, unsorted_indices=None)
torch.Size([123, 512])
torch.Size([1, 10, 512])
torch.Size([1, 10, 512])


In [11]:
# Unpack back to a padded time-major tensor plus each sequence's length.
(outputs, output_lens) = pad_packed_sequence(packed_outputs)
print(outputs.shape)
print(output_lens)

torch.Size([20, 10, 512])
tensor([20, 18, 18, 17, 15, 10,  8,  6,  6,  5])


## GRU 사용
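GRU는 cell state 없이 hidden state만 사용하므로 RNN과 동일한 방식으로 사용할 수 있다.
GRU를 이용하여 LM task를 수행해 보자. 각 time step에서 가장 확률이 높은 token을 골라 다음 time step의 입력으로 넣는다(greedy decoding).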

In [12]:
# Single-direction GRU with the same sizes as the LSTM above; no cell state.
gru = nn.GRU(
    input_size=embedding_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bidirectional=num_dirs > 1,  # the comparison is already a bool
)

In [22]:
output_layer = nn.Linear(in_features=hidden_size, out_features=vocab_size)  # d_h -> vocab logits

In [23]:
# First token of every sequence in the (B, L) batch -> shape (B,).
input_id = batch[:, 0]
hidden = torch.zeros(num_layers * num_dirs, batch.shape[0], hidden_size)

In [24]:
# Greedy generation: at each step pick the most likely token and feed it back
# as the next input.
for t in range(max_len):
    input_emb = embedding(input_id).unsqueeze(0)  # (1, B, d_emb): one time step
    output, hidden = gru(input_emb, hidden)

    output = output_layer(output)  # (1, B, vocab_size) logits
    # torch.max on raw logits returns logits, not probabilities. Normalize first
    # so `probs` really is the probability of the chosen token. softmax is
    # monotonic, so top_id is identical to the argmax of the logits.
    probs, top_id = torch.max(torch.softmax(output, dim=-1), dim=-1)

    print("*"*50)
    print(f"Time step: {t}")
    print(output.shape)
    print(probs.shape)
    print(top_id.shape)

    input_id = top_id.squeeze(0)  # (1, B) -> (B,) for the next embedding lookup

**************************************************
Time step: 0
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 1
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 2
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 3
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 4
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 5
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 6
torch.Size([1, 10, 100])
torch.Size([1, 10])
torch.Size([1, 10])
**************************************************
Time step: 7
torch.Size([1, 10, 100])
torch.Si

## 양방향 및 여러 layer 사용
이번엔 양방향(bidirectional) + 2개 이상의 layer를 쓸 때 얻을 수 있는 결과에 대해 알아보자.

In [25]:
num_layers = 2
num_dirs = 2  # bidirectional
dropout = 0.1  # applied to the outputs of every layer except the last

gru = nn.GRU(
    input_size=embedding_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=num_dirs > 1,  # the comparison is already a bool
)

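bidirectional이 되었고 layer의 개수가 2로 늘었기 때문에 hidden state의 shape는 `(num_layers * num_dirs, B, d_h) = (4, B, d_h)`가 된다. 또한 각 time step의 output에는 두 방향의 hidden state가 이어 붙여지므로, 마지막 차원이 `2 * d_h`(= 1024)가 된다.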

In [26]:
# Zero initial state, now (num_layers * num_dirs, B, d_h) = (4, B, d_h).
h_0 = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size))
batch_emb = embedding(batch)

# Time-major packed input, as in the LSTM example.
packed_batch = pack_padded_sequence(batch_emb.transpose(0, 1), batch_lens)
packed_outputs, h_n = gru(packed_batch, h_0)

print(packed_outputs)
print(packed_outputs.data.shape)  # (total valid tokens, num_dirs * d_h)
print(h_n.shape)

PackedSequence(data=tensor([[-0.0616,  0.1419, -0.0419,  ...,  0.1286,  0.1644, -0.0350],
        [-0.0446, -0.0617, -0.0475,  ...,  0.1917, -0.1740,  0.1693],
        [ 0.0255, -0.0287, -0.0590,  ..., -0.0269, -0.0133, -0.0027],
        ...,
        [-0.0650,  0.0372,  0.0685,  ...,  0.0623,  0.0891,  0.0533],
        [ 0.1509,  0.0227,  0.0581,  ...,  0.1506, -0.0588,  0.0192],
        [ 0.1018, -0.0215,  0.1022,  ...,  0.1361, -0.0378, -0.0353]],
       grad_fn=<CatBackward>), batch_sizes=tensor([10, 10, 10, 10, 10,  9,  7,  7,  6,  6,  5,  5,  5,  5,  5,  4,  4,  3,
         1,  1]), sorted_indices=None, unsorted_indices=None)
torch.Size([123, 1024])
torch.Size([4, 10, 512])


In [28]:
# Back to a padded (L, B, num_dirs * d_h) tensor plus per-sequence lengths.
(outputs, output_lens) = pad_packed_sequence(packed_outputs)

print(outputs.shape)
print(output_lens)

torch.Size([20, 10, 1024])
tensor([20, 18, 18, 17, 15, 10,  8,  6,  6,  5])


In [30]:
batch_size = h_n.shape[1]
# Split h_n's first axis (num_layers * num_dirs) into separate layer and
# direction axes, so h_n_split[layer, direction] is that final state.
h_n_split = h_n.view(num_layers, num_dirs, batch_size, hidden_size)
print(h_n_split)
print(h_n_split.shape)

tensor([[[[ 0.4368, -0.3862, -0.0914,  ..., -0.4564, -0.0240, -0.0892],
          [ 0.0985, -0.2560, -0.0031,  ..., -0.0098, -0.1318, -0.3166],
          [ 0.2565, -0.3446,  0.0408,  ...,  0.1227,  0.2785,  0.0633],
          ...,
          [-0.2482,  0.0838, -0.2070,  ..., -0.2314,  0.0856,  0.2959],
          [-0.3201,  0.0730,  0.0147,  ..., -0.6137,  0.0900, -0.0028],
          [-0.0086, -0.0173, -0.1789,  ..., -0.3638,  0.1985,  0.1712]],

         [[-0.1861,  0.2234,  0.5400,  ..., -0.0256, -0.0860,  0.2350],
          [-0.0284, -0.0254,  0.0959,  ..., -0.1586,  0.0432,  0.3732],
          [-0.0878,  0.3232, -0.0517,  ..., -0.2300, -0.3331,  0.3021],
          ...,
          [-0.5691,  0.1697,  0.4559,  ...,  0.0041, -0.0519,  0.1329],
          [ 0.0297, -0.1635,  0.2470,  ..., -0.0777, -0.3214, -0.0565],
          [-0.0101,  0.2400,  0.4176,  ..., -0.2866, -0.4759,  0.0602]]],


        [[[ 0.1018, -0.0215,  0.1022,  ...,  0.1940,  0.1381, -0.0305],
          [ 0.1600,  0.1439,