<a href="https://colab.research.google.com/github/tomonari-masada/course2025-nlp/blob/main/08_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

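# LSTM

Text classification on AG News with a recurrent neural network. Note: the model defined below uses `nn.GRU`, a gated recurrent unit closely related to the LSTM.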

In [None]:
import os
import time
import random
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


from datasets import load_dataset, DatasetDict
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer


def set_seed(seed):
  """Seed every RNG used in this notebook with the same value.

  Covers Python's `random`, NumPy's global generator, PyTorch's CPU generator
  and the generators of all CUDA devices.
  """
  seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
  for seeder in seeders:
    seeder(seed)

set_seed(0)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(f"Using {device} device")

In [None]:
# Human-readable names for the four AG News class IDs; these mirror the names of
# the dataset's "label" ClassLabel feature ("Sci/Tech", not "Sci/Tec").
ag_news_label = { 0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech" }

ds = load_dataset("ag_news")

# Hold out 5% of the official training split for validation. The explicit seed makes
# the split identical every time this cell runs. Without it, the split is derived from
# the global NumPy state, which changes each time the cell is re-executed.
train_valid = ds["train"].train_test_split(test_size=0.05, seed=0)

ds = DatasetDict({
    "train": train_valid["train"],
    "valid": train_valid["test"],
    "test": ds["test"],
})

ds

In [None]:
# Byte-pair-encoding tokenizer; symbols it cannot encode map to "[UNK]".
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# Split on whitespace/punctuation first, then normalize: Unicode NFD, lowercase, strip accents.
tokenizer.pre_tokenizer = Whitespace()
tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])

# "[PAD]" is registered first among the special tokens so that the padding token gets ID 0.
trainer = BpeTrainer(vocab_size=30_000, special_tokens=["[PAD]", "[UNK]"])
tokenizer.train_from_iterator(ds["train"]["text"], trainer=trainer)

In [None]:
# Directory that receives the trained tokenizer and model weights.
# On Colab, a mounted Google Drive folder such as
# "/content/drive/MyDrive/2025courses/nlp" can be used here instead.
save_dir = "./"
# exist_ok=True makes the cell idempotent and avoids a check-then-create race.
os.makedirs(save_dir, exist_ok=True)

tokenizer.save(os.path.join(save_dir, "my-tokenizer.json"))

In [None]:
# ID of the "[PAD]" special token; collate_padded_batch fills short sequences with it.
padding_value = tokenizer.token_to_id("[PAD]")
# token_to_id returns None for tokens missing from the vocabulary. Fail here with a clear
# message rather than later inside pad_sequence.
if padding_value is None:
  raise ValueError('"[PAD]" is not in the tokenizer vocabulary; register it as a special token.')

def collate_padded_batch(batch):
  """Collate AG News examples into padded tensors for the RNN classifier.

  Args:
    batch: list of examples, each a mapping with "label" and "text" keys.

  Returns:
    Tuple ``(labels, padded_ids, lengths)``, all moved to ``device``:
      labels: 1-D int64 tensor with one label per example.
      padded_ids: int64 tensor of shape (len(batch), longest sequence), right-padded
        with ``padding_value`` (batch_first layout).
      lengths: 1-D int64 tensor with the unpadded token count of each example.
  """
  texts = [example["text"] for example in batch]
  labels = torch.tensor([example["label"] for example in batch], dtype=torch.int64)

  encodings = tokenizer.encode_batch(texts)
  id_tensors = [torch.tensor(enc.ids, dtype=torch.int64) for enc in encodings]
  lengths = torch.tensor([len(ids) for ids in id_tensors], dtype=torch.int64)
  padded_ids = pad_sequence(id_tensors, batch_first=True, padding_value=padding_value)

  # NOTE(review): moving tensors to `device` inside collate_fn is only safe with
  # single-process loading (num_workers=0) when `device` is CUDA.
  return labels.to(device), padded_ids.to(device), lengths.to(device)

In [None]:
BATCH_SIZE = 64

# Only the training split is shuffled; validation and test batches keep dataset order.
train_dataloader = DataLoader(
    ds["train"], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_padded_batch
)
valid_dataloader, test_dataloader = (
    DataLoader(ds[split_name], batch_size=BATCH_SIZE, collate_fn=collate_padded_batch)
    for split_name in ("valid", "test")
)

In [None]:
from torch import nn

class TextClassificationModel(nn.Module):
  """Recurrent text classifier: token embedding -> stacked RNN -> linear head.

  Args:
    vocab_size: number of token IDs covered by the embedding table.
    embed_dim: embedding size, also used as the RNN hidden size.
    num_class: number of output classes (logits produced per example).
    num_layers: number of stacked recurrent layers (default 5).
    rnn_type: "gru" (default) or "lstm".
  """

  _RNN_CLASSES = {"gru": nn.GRU, "lstm": nn.LSTM}

  def __init__(self, vocab_size, embed_dim, num_class, num_layers=5, rnn_type="gru"):
    super().__init__()
    try:
      rnn_cls = self._RNN_CLASSES[rnn_type.lower()]
    except KeyError:
      raise ValueError(f"rnn_type must be 'gru' or 'lstm', got {rnn_type!r}") from None
    self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=False)
    self.rnn = rnn_cls(embed_dim, embed_dim, num_layers=num_layers, batch_first=True)
    self.fc = nn.Linear(embed_dim, num_class)

  def forward(self, text, lengths):
    """Return class logits of shape (batch, num_class).

    Args:
      text: (batch, seq_len) int64 token IDs. Presumably right-padded, as produced
        by pad_sequence in the collate function.
      lengths: (batch,) unpadded length of each sequence; each must be >= 1.
    """
    embedded = self.embedding(text)
    # pack_padded_sequence needs CPU lengths; enforce_sorted=False accepts any order.
    # Packing stops the RNN at each sequence's real end, so padding tokens never
    # reach the final hidden state.
    packed_input = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
    _, hidden = self.rnn(packed_input)
    if isinstance(hidden, tuple):  # LSTM returns (h_n, c_n)
      hidden = hidden[0]
    # hidden: (num_layers, batch, embed_dim); classify from the top layer's final state.
    return self.fc(hidden[-1])

In [None]:
# Number of classes, read from the dataset's ClassLabel feature instead of scanning
# every training label.
num_class = ds["train"].features["label"].num_classes
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size}, num_class: {num_class}")

emsize = 64 # embedding vector dimension (a free hyperparameter you choose)
print(f"Embedding size: {emsize}")

model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

In [None]:
epochs = 20
learning_rate = 1e-4

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Each scheduler.step() call halves the learning rate (step_size=1, gamma=0.5).
# The training loop only calls it when validation accuracy drops below the best so far.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

In [None]:
def train(dataloader, log_interval=500, max_grad_norm=0.1):
  """Run one training epoch using the global `model`, `optimizer` and `criterion`.

  Args:
    dataloader: iterable of (label, text, lengths) batches, e.g. built with
      collate_padded_batch. `len(dataloader)` is shown in the progress log.
    log_interval: print running accuracy every this many batches (default 500).
    max_grad_norm: gradients are clipped to this total L2 norm before every
      optimizer step (default 0.1).
  """
  model.train()
  total_acc, total_count = 0, 0
  start_time = time.time()

  for idx, (label, text, lengths) in enumerate(dataloader):
    optimizer.zero_grad()
    predicted_label = model(text, lengths)
    loss = criterion(predicted_label, label)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    # Accuracy uses the logits computed before this step's parameter update.
    total_acc += (predicted_label.argmax(1) == label).sum().item()
    total_count += label.size(0)
    if idx % log_interval == 0 and idx > 0:
      elapsed = time.time() - start_time
      print(
          f"||| {idx:5d}/{len(dataloader):5d} batches | "
          f"time: {elapsed:5.2f}s | "
          f"accuracy {total_acc / total_count:8.3f}"
      )
      # Reset so each log line reports only the most recent interval.
      total_acc, total_count = 0, 0
      start_time = time.time()

In [None]:
def evaluate(dataloader):
  """Return the classification accuracy of the global `model` over `dataloader`.

  Puts the model in eval mode and disables gradient tracking.

  Args:
    dataloader: iterable of (label, text, lengths) batches.

  Returns:
    Fraction of examples whose arg-max logit equals the label.

  Raises:
    ValueError: if `dataloader` yields no examples.
  """
  model.eval()
  total_acc, total_count = 0, 0

  with torch.no_grad():
    for label, text, lengths in dataloader:
      predicted_label = model(text, lengths)
      total_acc += (predicted_label.argmax(1) == label).sum().item()
      total_count += label.size(0)

  if total_count == 0:
    raise ValueError("evaluate() received an empty dataloader")
  return total_acc / total_count

In [None]:
total_accu = None  # best validation accuracy seen so far (None before the first epoch)

for epoch in range(epochs):
  epoch_start_time = time.time()
  train(train_dataloader)
  accu_val = evaluate(valid_dataloader)
  # Record a new best on the first epoch or when validation accuracy does not drop.
  # Otherwise (it fell below the best so far) decay the learning rate.
  if total_accu is None or not (total_accu > accu_val):
    total_accu = accu_val
  else:
    scheduler.step()
  print("-" * 59)
  elapsed = time.time() - epoch_start_time
  print(
      f"| end of epoch {epoch+1:3d} | "
      f"time: {elapsed:5.2f}s | "
      f"lr = {optimizer.param_groups[0]['lr']:.3e} | "
      f"validation accuracy {accu_val:8.3f}"
  )
  print("-" * 82)

In [None]:
# Persist only the learned parameters (state_dict); the architecture is rebuilt from code on reload.
torch.save(obj=model.state_dict(), f=os.path.join(save_dir, "my-model.pt"))

In [None]:
# Rebuild the architecture and load the saved weights.
# map_location remaps stored tensors onto the current device, so a checkpoint written
# on a GPU also loads on a CPU-only runtime. weights_only=True restricts unpickling
# to tensors and plain containers.
model = TextClassificationModel(vocab_size, emsize, num_class)
model.load_state_dict(
    torch.load(os.path.join(save_dir, "my-model.pt"), map_location=device, weights_only=True)
)
model.to(device)
model.eval()

In [None]:
print("Checking the results of test dataset...")
accu_test = evaluate(test_dataloader)
print(f"test accuracy {accu_test:8.3f}")

In [None]:
def predict(text):
  """Return the predicted class ID for one raw string, using the global model.

  Args:
    text: raw input text; it is tokenized with the global `tokenizer`.

  Returns:
    int class ID (arg-max over the model's logits).

  Raises:
    ValueError: if `text` yields no tokens (e.g. an empty string). The model packs
      its input, and pack_padded_sequence rejects zero-length sequences.
  """
  token_ids = tokenizer.encode(text).ids
  if not token_ids:
    raise ValueError("predict() needs text that produces at least one token")
  with torch.no_grad():
    # A batch of one sequence, so no padding is required.
    token_tensor = torch.tensor([token_ids], dtype=torch.int64, device=device)
    lengths = torch.tensor([len(token_ids)], dtype=torch.int64, device=device)
    logits = model(token_tensor, lengths)
  return logits.argmax(1).item()


# Sample sports article used to sanity-check predict().
# Adjacent string literals are joined at compile time. Unlike a backslash-continued
# literal, the source indentation is not copied into the string as runs of spaces.
ex_text_str = (
    "MEMPHIS, Tenn. – Four days ago, Jon Rahm was "
    "enduring the season’s worst weather conditions on Sunday at The "
    "Open on his way to a closing 75 at Royal Portrush, which "
    "considering the wind and the rain was a respectable showing. "
    "Thursday’s first round at the WGC-FedEx St. Jude Invitational "
    "was another story. With temperatures in the mid-80s and hardly any "
    "wind, the Spaniard was 13 strokes better in a flawless round. "
    "Thanks to his best putting performance on the PGA Tour, Rahm "
    "finished with an 8-under 62 for a three-stroke lead, which "
    "was even more impressive considering he’d never played the "
    "front nine at TPC Southwind."
)

print("This is a {} news".format(ag_news_label[predict(ex_text_str)]))