# **Recurrent Neural Networks - Required Assignment**

Let's implement an **LSTM**!
<br><br><br>
**Prerequisites**:

- <u>PyTorch</u> (Optional Assignment 1)

<br>

**Additional background**: (helpful, but not required)

- <u>Tokenization</u>, <u>Word Embedding</u> (Optional Assignment 2)

<br><br><br><br><br>

In [3]:
!pip install transformers
!pip install datasets


In [4]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset

from tqdm import tqdm

<br><br>

We will fetch the [Rotten Tomatoes dataset](https://huggingface.co/datasets/rotten_tomatoes) and the tokenizer of [pretrained BERT](https://huggingface.co/bert-base-uncased) from [Hugging Face](https://huggingface.co).

To reduce the training burden, let's also take the weights of the word embedding layer built into the pretrained BERT.

In [5]:
# https://huggingface.co/datasets/rotten_tomatoes
dataset = load_dataset("rotten_tomatoes")

# https://huggingface.co/bert-base-uncased
pretrained_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
pretrained_embeddings = AutoModel.from_pretrained("bert-base-uncased").embeddings.word_embeddings


<br><br>

The base BERT model embeds each token as a 768-dimensional vector. That is a lot for our small dataset and small model, so let's use PCA to reduce it to 64 dimensions.

In [6]:
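# torch.pca_lowrank returns (U, S, V); index 0 selects U, a (vocab_size, 64) matrix used here as the reduced embeddings.
nano_embed = torch.pca_lowrank(pretrained_embeddings.weight.detach(), q=64)[0]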

<br><br>

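But is it really fine to simply cut the embeddings down to 64 dimensions? BERT's d_model is presumably 768 for a reason, and the information loss could be very large.

Let's find out: using cosine similarity, we'll check whether the reduced embedding layer still retains a reasonable amount of information about the tokens.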

In [7]:
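# Similarity score for every pair of tokens. Note: the denominator is the dot product of
# element-wise absolute values, not the product of L2 norms, so this is a cosine-like score.
cos = (nano_embed @ nano_embed.T) / (nano_embed.abs() @ nano_embed.abs().T)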
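
The cell above divides by `nano_embed.abs() @ nano_embed.abs().T`, which is not quite the textbook denominator. As a minimal pure-Python sketch (the helper names `cosine_similarity` and `abs_normalized_score` and the toy vectors are illustrative assumptions, not part of the assignment), the two scores can be compared on a small example:

```python
import math

def cosine_similarity(a, b):
    """Textbook cosine similarity: dot(a, b) / (||a||_2 * ||b||_2)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def abs_normalized_score(a, b):
    """Score whose denominator is the dot product of element-wise absolute values."""
    dot = sum(x * y for x, y in zip(a, b))
    denom = sum(abs(x) * abs(y) for x, y in zip(a, b))
    return dot / denom

# Hypothetical toy vectors, chosen only to show that the two scores can differ.
u = [1.0, 0.0]
v = [1.0, 1.0]

print(cosine_similarity(u, v))     # 1 / sqrt(2), about 0.707
print(abs_normalized_score(u, v))  # 1.0
```

Both scores lie in [-1, 1] and agree in sign, so either can be used to rank neighbours, but the rankings are not guaranteed to match.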

In [8]:
# Try different values for word! If the token is not in the tokenizer's vocab, an empty list is returned.
word = "jackson"

([*map(pretrained_tokenizer.decode, cos[pretrained_tokenizer.vocab[word]].argsort(descending=True)[1:21])] if word in pretrained_tokenizer.vocab else [])

['la',
 'mississippi',
 '##gon',
 'historic',
 'detroit',
 'basketball',
 'narrative',
 '2016',
 'owen',
 'michigan',
 'jake',
 'dawson',
 'by',
 '2015',
 'realized',
 'digital',
 'though',
 'them',
 'but',
 'hall']

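The information seems to be preserved fairly well.

(Side note: if you get a bit greedier and reduce to 32 dimensions, the information loss becomes hard to ignore.)

<br><br>

Now let's implement the LSTM! We originally planned to use a BiLSTM, but it underfit badly, so we prepared a plain LSTM instead.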

<br><br><br><br>
#### <span style="color:red">**<u>Q1.</u>**</span>

Fill in the blanks in `class LSTMCell`.

In [9]:
class LSTMCell(nn.Module):
    def __init__(self, d_x, d_h): # d_x: dimensionality of x (scalar int)
                                  # d_h: dimensionality of h (scalar int)
        super().__init__()
        d_stack = d_x + d_h
        ######################### START OF YOUR CODE #########################

        # Input/output sizes of the input-gate, candidate, and output-gate layers.
        # (The layers themselves are created right after this block.)

        dim1 = d_stack
        dim2 = d_h
        dim3 = d_stack
        dim4 = d_h
        dim5 = d_stack
        dim6 = d_h

        ########################## END OF YOUR CODE ##########################
        self.W_f = nn.Linear(d_stack, d_h)
        self.W_i = nn.Linear(dim1, dim2)
        self.W_C = nn.Linear(dim3, dim4)
        self.W_o = nn.Linear(dim5, dim6)


    # forward takes h_{t-1} and C_{t-1} from step t-1 and x_t from step t as inputs.

    def forward(self, x, h, C): # x: x_t
                                # h: h_{t-1}
                                # C: C_{t-1}
        stack = torch.cat([x, h])
        ######################### START OF YOUR CODE #########################

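        f = torch.sigmoid(self.W_f(stack))  # forget gate
        i = torch.sigmoid(self.W_i(stack))  # input gate
        C_ = torch.tanh(self.W_C(stack))    # candidate cell state

        C_t = f * C + i * C_                # new cell state

        o = torch.sigmoid(self.W_o(stack))  # output gate
        h_t = o * torch.tanh(C_t)           # new hidden state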

        ########################## END OF YOUR CODE ##########################
        return h_t, C_t
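
For reference, the update computed by `forward` corresponds to the standard LSTM equations written out below. Here $[x_t; h_{t-1}]$ is the concatenation produced by `torch.cat`, $\sigma$ is the sigmoid, $\odot$ is the element-wise product, and the bias terms $b_\ast$ are the ones `nn.Linear` includes by default.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f [x_t; h_{t-1}] + b_f\right) \\
i_t &= \sigma\left(W_i [x_t; h_{t-1}] + b_i\right) \\
\tilde{C}_t &= \tanh\left(W_C [x_t; h_{t-1}] + b_C\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma\left(W_o [x_t; h_{t-1}] + b_o\right) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
```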

In [10]:
class LSTM(nn.Module):
    def __init__(self, vocab_size, d_out, pretrained_embeddings):
        super().__init__()
        vocab_size = pretrained_embeddings.shape[0]
        d_h = d_model = pretrained_embeddings.shape[1]

        self.embed = nn.Embedding(vocab_size, d_model, _weight=pretrained_embeddings.clone()) # word embedding layer
        self.cell = LSTMCell(d_x=d_model, d_h=d_h) # LSTM cell
        self.out = nn.Linear(d_h, d_out, bias=True) # output layer

        self.h_init = nn.Parameter(torch.zeros(d_h), requires_grad=False) # initial h
        self.C_init = nn.Parameter(torch.zeros(d_h), requires_grad=False) # initial C

    def forward(self, input_ids):
        embedded = self.embed(input_ids).squeeze(0) # drop the batch dimension: (1, seq_len, d_model) -> (seq_len, d_model)

        h = self.h_init.clone() # initialize h
        C = self.C_init.clone() # initialize C
        for x in embedded:
            h, C = self.cell(x, h, C) # iterate over embedded sequence

        return self.out(h).squeeze() # pass the last hidden state through the output layer and return the result
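
To get a feel for the size of this model, the parameter counts can be worked out by hand. The sketch below is plain arithmetic rather than PyTorch code; the helper names are hypothetical, and the vocabulary size of 30522 is an assumption based on `bert-base-uncased`, not a value printed in this notebook.

```python
def linear_params(d_in, d_out, bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return d_in * d_out + (d_out if bias else 0)

def lstm_cell_params(d_x, d_h):
    """Four gate layers (forget, input, candidate, output), each mapping d_x + d_h -> d_h."""
    return 4 * linear_params(d_x + d_h, d_h)

d_model = d_h = 64   # embedding size after the PCA step, reused as the hidden size
vocab_size = 30522   # assumption: vocabulary size of bert-base-uncased

cell_count = lstm_cell_params(d_model, d_h)
embed_count = vocab_size * d_model
out_count = linear_params(d_h, 1)

print(cell_count)   # 33024
print(embed_count)  # 1953408
print(out_count)    # 65
```

Under these assumptions almost all trainable parameters sit in the embedding table, while the recurrent cell itself is small.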

<br><br><br><br>
#### <span style="color:red">**<u>Q2.</u>**</span>

Train the model so that the test accuracy is at least 0.7.

In [15]:
######################### START OF YOUR CODE #########################

# Change as needed. Falls back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

########################## END OF YOUR CODE ##########################

model = LSTM(vocab_size=pretrained_tokenizer.vocab_size, d_out=1, pretrained_embeddings=nano_embed).to(device)

In [16]:
######################### START OF YOUR CODE #########################

# Try adjusting the learning rate.
lr = 1e-3

########################## END OF YOUR CODE ##########################

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
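
`nn.BCEWithLogitsLoss` takes the raw logit rather than a probability. As a rough illustration, here is one standard numerically stable way to write binary cross-entropy on a logit in pure Python; it is a sketch of the math, not a claim about PyTorch's internals, and the function name is hypothetical.

```python
import math

def bce_with_logits(z, y):
    """Binary cross-entropy for logit z and target y in {0, 1}.

    Algebraically equal to -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))],
    rearranged so that exp() never receives a large positive argument.
    """
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# A model that outputs logit 0 (probability 0.5) pays ln(2) per example.
print(bce_with_logits(0.0, 1.0))
print(math.log(2))

# Large logits do not overflow.
print(bce_with_logits(1000.0, 0.0))  # 1000.0
```

This gives a useful sanity check for the training log: a running loss near 0.693 means the model is still predicting roughly 0.5 for every review.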

In [17]:
train_loader = DataLoader(dataset["train"], shuffle=True)

In [19]:
######################### START OF YOUR CODE #########################

# Change as needed.
num_print = 100
num_batch = 5

########################## END OF YOUR CODE ##########################


# train

save_l = 0
optimizer.zero_grad()
for i, data in enumerate(tqdm(train_loader)):
    text, label = data["text"][0], data["label"][0]
    input_ids = pretrained_tokenizer.encode(text, return_tensors="pt").to(device)
    y_pred = model(input_ids)

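    label = label.to(device) * 1.  # cast the integer label to float for BCEWithLogitsLoss
    loss = criterion(y_pred, label)
    loss.backward()  # gradients accumulate across iterations until zero_grad()

    if not (i+1)%num_batch:  # update the weights once every num_batch examples
        optimizer.step()
        optimizer.zero_grad()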

    save_l += loss.item()
    if not (i+1)%num_print:
        print(f"{i+1:>5} iter: {save_l/num_print}")
        save_l = 0

  100 iter: 0.6908905136585236
  200 iter: 0.692317972779274
  300 iter: 0.6955159133672715
  400 iter: 0.6931998366117478
  500 iter: 0.6924934542179108
  600 iter: 0.6906831341981888
  700 iter: 0.6972316181659699
  800 iter: 0.691937215924263
  900 iter: 0.690854794383049
 1000 iter: 0.6943992245197296
 1100 iter: 0.6842437314987183
 1200 iter: 0.672454991042614
 1300 iter: 0.6652621346712112
 1400 iter: 0.6521550276875496
 1500 iter: 0.6090054339170456
 1600 iter: 0.6639975306391716
 1700 iter: 0.654328342974186
 1800 iter: 0.6951161147654057
 1900 iter: 0.5832218547165394
 2000 iter: 0.5799493651092053
 2100 iter: 0.6361083556711674
 2200 iter: 0.5856978644430637
 2300 iter: 0.6444257663190365
 2400 iter: 0.6248102267086506
 2500 iter: 0.596658306568861
 2600 iter: 0.5309760695695878
 2700 iter: 0.5049080022424459
 2800 iter: 0.5725296476483345
 2900 iter: 0.6453920075297356
 3000 iter: 0.5832522836327553
 3100 iter: 0.5005799189954996
 3200 iter: 0.602520526573062
 3300 iter: 0.5788698254525662
 3400 iter: 0.6114147171378136
 3500 iter: 0.5508464317023755
 3600 iter: 0.5996473225951194
 3700 iter: 0.5030214154720306
 3800 iter: 0.611012642160058
 3900 iter: 0.5556478096544742
 4000 iter: 0.5406424234807491
 4100 iter: 0.6520971815288067
 4200 iter: 0.5464426547288894
 4300 iter: 0.46208203487098215
 4400 iter: 0.5581745383143425
 4500 iter: 0.5288301074504852
 4600 iter: 0.5770467602461576
 4700 iter: 0.5945577459782362
 4800 iter: 0.560080293416977
 4900 iter: 0.5148062076419592
 5000 iter: 0.4805477052181959
 5100 iter: 0.5214091904461384
 5200 iter: 0.6387228964269162
 5300 iter: 0.5641566333174706
 5400 iter: 0.5497177828848362
 5500 iter: 0.5451958326995373
 5600 iter: 0.5024213564395904
 5700 iter: 0.559044220149517
 5800 iter: 0.5667527135461569
 5900 iter: 0.5885455860197544
 6000 iter: 0.5332659043371677
 6100 iter: 0.49147191748023034
 6200 iter: 0.5556527596712112
 6300 iter: 0.6113109631091356
 6400 iter: 0.551787471473217
 6500 iter: 0.5152902472764254
 6600 iter: 0.6107702668756246
 6700 iter: 0.49105853237211705
 6800 iter: 0.541813224479556
 6900 iter: 0.5443097711354494
 7000 iter: 0.4794328175485134
 7100 iter: 0.5666265635564923
 7200 iter: 0.612522789761424
 7300 iter: 0.5836378903687001
 7400 iter: 0.4853319113701582
 7500 iter: 0.58140515178442
 7600 iter: 0.4886979480087757
 7700 iter: 0.5525704152882099
 7800 iter: 0.4526114536821842
 7900 iter: 0.5519016380235553
 8000 iter: 0.4857998421415687
 8100 iter: 0.5544625057280064
 8200 iter: 0.4525121823698282
 8300 iter: 0.5825672331638634
 8400 iter: 0.5043963022902608
 8500 iter: 0.491114164609462

100%|██████████| 8530/8530 [04:09<00:00, 34.15it/s]
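
The training loop calls `loss.backward()` on every example but only steps the optimizer every `num_batch` examples, which is a form of gradient accumulation. The pure-Python sketch below mimics that pattern on a one-parameter model; it assumes plain SGD instead of the notebook's Adam, and the function names and toy data are hypothetical.

```python
def grad(w, x, t):
    """Gradient of the squared error (w*x - t)**2 with respect to w."""
    return 2.0 * x * (w * x - t)

def train_with_accumulation(samples, w, lr, num_batch):
    acc = 0.0                         # plays the role of the parameter's .grad buffer
    for i, (x, t) in enumerate(samples):
        acc += grad(w, x, t)          # like loss.backward(): gradients add up
        if (i + 1) % num_batch == 0:
            w -= lr * acc             # like optimizer.step(), here plain SGD
            acc = 0.0                 # like optimizer.zero_grad()
    return w

# Toy data generated by t = 2 * x, so the ideal weight is 2.0.
samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w = train_with_accumulation(samples, w=0.0, lr=0.01, num_batch=2)
print(w)  # about 1.1 after two updates
```

Note that the gradients are summed rather than averaged, and that any examples left over after the last full group never trigger an update. In the run above this does not matter, because 8530 is divisible by 5.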


In [20]:
test_loader = DataLoader(dataset["test"], shuffle=True)


# test

res = torch.tensor(0)
with torch.no_grad():
    for i, data in enumerate(tqdm(test_loader)):
        text, label = data["text"][0], data["label"][0]
        input_ids = pretrained_tokenizer.encode(text, return_tensors="pt").to(device)
        y_pred = model(input_ids)
        res += ((1 if y_pred > 0 else 0) == label)

print("Test accuracy:", res.item() / dataset["test"].num_rows)

100%|██████████| 1066/1066 [00:10<00:00, 102.92it/s]

Test accuracy: 0.7363977485928705
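
The test loop classifies a review as positive when the logit is greater than 0, which is the same as thresholding the sigmoid probability at 0.5. A small pure-Python sketch of that rule and of the accuracy computation, using made-up logits and labels:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def accuracy(logits, labels):
    """Fraction of examples where the thresholded logit matches the 0/1 label."""
    correct = sum(int(z > 0) == y for z, y in zip(logits, labels))
    return correct / len(labels)

# Hypothetical model outputs and ground-truth labels.
logits = [2.3, -0.7, 0.1, -1.5]
labels = [1, 0, 0, 0]

# Thresholding the logit at 0 agrees with thresholding the probability at 0.5.
print(all((z > 0) == (sigmoid(z) > 0.5) for z in logits))  # True
print(accuracy(logits, labels))  # 0.75
```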





In [21]:
# For inspection
# Try different values of n and look at the trained model's predictions
n = 123

print(dataset["test"][n])
with torch.no_grad():
    print(model(pretrained_tokenizer.encode(dataset["test"][n]["text"], return_tensors="pt").to(device)).sigmoid().item())

{'text': "ana's journey is not a stereotypical one of self-discovery , as she's already comfortable enough in her own skin to be proud of her rubenesque physique . . .", 'label': 1}
0.639428973197937
