<a href="https://colab.research.google.com/github/tomonari-masada/course2024-nlp/blob/main/07_PyTorch_3_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets tokenizers

In [None]:
import random
import time
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from tokenizers import Tokenizer, normalizers
from tokenizers.models import WordPiece
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer


def set_seed(seed):
  """Seed every RNG this notebook relies on with the same value.

  Covers Python's `random`, NumPy's global generator, and PyTorch on the CPU
  and on all CUDA devices (in that order).
  """
  seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
  for seeder in seeders:
    seeder(seed)

set_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
dataset = load_dataset("ag_news")
# Class id -> human-readable name, as listed on the AG News dataset card.
ag_news_label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

# Hold out 5% of the official training split for validation and keep the
# official test split untouched. The explicit seed makes the split identical
# on every re-run of this cell, independent of prior global-RNG usage.
train_valid = dataset["train"].train_test_split(test_size=0.05, seed=0)
dataset = DatasetDict({
    "train": train_valid["train"],
    "valid": train_valid["test"],
    "test": dataset["test"],
})
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 114000
    })
    valid: Dataset({
        features: ['text', 'label'],
        num_rows: 6000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [None]:
# Train a WordPiece tokenizer from scratch, using the training split only.
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization pipeline: NFD decomposition, lowercasing, accent stripping.
tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
tokenizer.pre_tokenizer = Whitespace()

wordpiece_trainer = WordPieceTrainer(special_tokens=["[UNK]"])
training_texts = dataset["train"]["text"]
tokenizer.train_from_iterator(training_texts, trainer=wordpiece_trainer)






In [None]:
# token -> id mapping; `vocab` is reused later to size the embedding table.
vocab = tokenizer.get_vocab()
# Tokens ordered by id. The full list is tens of thousands of entries, so only
# a preview is printed to keep the notebook output readable.
tokens_by_id = sorted(vocab, key=vocab.get)
print(f"vocab size: {len(tokens_by_id)}")
print(tokens_by_id[:100])

In [None]:
def collate_batch(batch):
  """Turn a list of `{"label", "text"}` records into EmbeddingBag inputs.

  Returns a `(labels, token_ids, offsets)` triple, all int64 and moved to the
  notebook-level `device`:
    labels    -- one label per record,
    token_ids -- the token ids of every text concatenated into one 1-D tensor,
    offsets   -- the start position of each text inside `token_ids`.
  """
  labels = torch.tensor([item["label"] for item in batch], dtype=torch.int64)
  encoded = [
      torch.tensor(tokenizer.encode(item["text"]).ids, dtype=torch.int64)
      for item in batch
  ]
  # Start offsets are the running sum of the lengths of all preceding texts.
  lengths = [0] + [ids.size(0) for ids in encoded]
  offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)
  token_ids = torch.cat(encoded)
  return labels.to(device), token_ids.to(device), offsets.to(device)

In [None]:
BATCH_SIZE = 32
# Every split shares the batch size and collate function; only training is shuffled.
_loader_kwargs = {"batch_size": BATCH_SIZE, "collate_fn": collate_batch}
train_dataloader = DataLoader(dataset["train"], shuffle=True, **_loader_kwargs)
valid_dataloader = DataLoader(dataset["valid"], **_loader_kwargs)
test_dataloader = DataLoader(dataset["test"], **_loader_kwargs)

In [None]:
class TextClassificationModel(nn.Module):
  """Bag-of-embeddings text classifier with two residual ReLU layers.

  Each text's token embeddings are pooled by `nn.EmbeddingBag`, passed through
  dropout and two ReLU layers with residual connections, and finally projected
  to `num_class` logits.

  Args:
    vocab_size: number of token ids covered by the embedding table.
    embed_dim: embedding width; also the hidden width, which the residual
      additions require to match.
    num_class: number of output classes.
    dropout: dropout probability applied to the pooled embedding. The default
      0.5 is the `nn.Dropout` default the model has always used.
  """

  def __init__(self, vocab_size, embed_dim, num_class, dropout=0.5):
    super().__init__()
    self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
    self.fc1 = nn.Linear(embed_dim, embed_dim)
    self.fc2 = nn.Linear(embed_dim, embed_dim)
    self.fc3 = nn.Linear(embed_dim, num_class)
    self.act = nn.ReLU()
    self.dropout = nn.Dropout(dropout)

  def forward(self, text, offsets):
    """Return logits of shape (number of bags, num_class).

    Args:
      text: concatenated token ids of the whole batch (as built by the collate
        function).
      offsets: start index of each text within `text`.
    """
    embedded = self.dropout(self.embedding(text, offsets))
    # Residual connections: each hidden layer adds its input back.
    hidden = self.act(self.fc1(embedded)) + embedded
    hidden = self.act(self.fc2(hidden)) + hidden
    return self.fc3(hidden)

In [None]:
# Class ids observed in the training split. CrossEntropyLoss expects class
# indices in [0, num_class), so fail early with a clear message otherwise.
unique_labels = set(dataset["train"]["label"])
print(unique_labels)
num_class = len(unique_labels)
assert unique_labels == set(range(num_class)), (
    f"labels must be exactly 0..{num_class - 1}, got {sorted(unique_labels)}"
)

vocab_size = len(vocab)
emsize = 128  # embedding width, also the hidden width of the classifier

model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

{0, 1, 2, 3}


In [None]:
learning_rate = 1e-3

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# step_size=1 (an int, as StepLR documents): every scheduler.step() multiplies
# the learning rate by gamma. The training loop only calls step() when the
# validation accuracy drops, so this acts as a manual decay-on-plateau schedule.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3)

In [None]:
def train(dataloader):
  """Run one training epoch over `dataloader` and return the training accuracy.

  Uses the notebook-level `model`, `optimizer` and `criterion`. Each batch is a
  `(label, text, offsets)` triple.

  Accuracy is computed from the predictions made *before* each optimizer step,
  so it is a running estimate over the epoch rather than the accuracy of the
  final weights.
  """
  model.train()
  total_acc, total_count = 0, 0
  for label, text, offsets in dataloader:
    optimizer.zero_grad()
    predicted_label = model(text, offsets)
    loss = criterion(predicted_label, label)
    loss.backward()
    optimizer.step()
    total_acc += (predicted_label.argmax(1) == label).sum().item()
    total_count += label.size(0)
  return total_acc / total_count

In [None]:
def evaluate(dataloader):
  """Return the classification accuracy of the notebook-level `model`.

  The model is switched to eval mode and run without gradient tracking. Each
  batch from `dataloader` is a `(label, text, offsets)` triple.
  """
  model.eval()
  n_correct = 0
  n_seen = 0
  with torch.no_grad():
    for labels, token_ids, bag_offsets in dataloader:
      logits = model(token_ids, bag_offsets)
      hits = logits.argmax(1) == labels
      n_correct += hits.sum().item()
      n_seen += labels.size(0)
  return n_correct / n_seen

In [None]:
epochs = 20

# `total_accu` holds the previous epoch's validation accuracy (None before the
# first epoch). Whenever validation accuracy falls below it, the scheduler is
# stepped once, which decays the learning rate.
total_accu = None
start_time = time.time()
for epoch in range(epochs):
  accu_train = train(train_dataloader)
  accu_val = evaluate(valid_dataloader)
  val_got_worse = total_accu is not None and accu_val < total_accu
  if val_got_worse:
    scheduler.step()
  total_accu = accu_val
  # Cumulative wall-clock time since the first epoch started.
  elapsed = time.time() - start_time
  current_lr = optimizer.param_groups[0]['lr']
  print(
      f"| epoch {epoch+1} ({elapsed:.2f}s) | "
      f"lr={current_lr:.2e} | "
      f"train accu {accu_train:.3f} | "
      f"val accu {accu_val:.3f}"
  )

In [None]:
# Final held-out evaluation: the test split was used neither for training nor
# for the learning-rate schedule (which follows validation accuracy).
accu_test = evaluate(test_dataloader)
print(f"test accuracy {accu_test:8.3f}")

test accuracy    0.928


In [None]:
def predict(text):
  """Return the predicted class id (an int) for a single raw text string.

  Uses the notebook-level `tokenizer`, `model` and `device`. The forward pass
  runs in eval mode (so dropout is disabled) without gradient tracking. The
  model's previous train/eval mode is restored afterwards.
  """
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      input_ids = tokenizer.encode(text).ids
      # Explicit int64: for an empty/whitespace-only text `torch.tensor([])`
      # would default to float32, which the embedding lookup cannot take.
      token_ids = torch.tensor(input_ids, dtype=torch.int64).to(device)
      # A single bag starting at position 0 covers the whole text.
      offsets = torch.tensor([0]).to(device)
      output = model(token_ids, offsets)
      return output.argmax(1).item()
  finally:
    model.train(was_training)

ex_text_str = """
Despite the bustle of Christmas shoppers in Leeds, small businesses
are sharing their concerns about the upcoming challenges they face.
In Labour's first budget in 14 years,
chancellor Rachel Reeves raised employer National Insurance contributions
and announced minimum wage increases.
"""

# Map the predicted class id back to its human-readable AG News label.
predicted_class = predict(ex_text_str)
print("This is a {} news".format(ag_news_label[predicted_class]))

This is a Business news
