<a href="https://colab.research.google.com/github/tomonari-masada/course2025-nlp/blob/main/08_text_classification_with_LLM_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLMを使ったテキスト分類

* LLMをfreezeさせたままの状態で、分類用のヘッドだけを訓練する。
* 事前学習済みのautoregressiveな言語モデルを使う。
  * 最終レイヤの出力を平均したものをテキストの埋め込みとして使う。

## 準備

In [None]:
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from transformers import set_seed, AutoModel, AutoTokenizer

# Seed the random number generators through transformers' helper.
set_seed(0)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(f"Using {device} device")

## データセット
* AG Newsデータセットを使う。

In [None]:
ds = load_dataset("ag_news")

# Hold out 5% of the original training split to serve as a validation set.
train_valid = ds["train"].train_test_split(test_size=0.05)

# Rebuild the dataset dictionary with explicit train / valid / test splits.
ds = DatasetDict(
    train=train_valid["train"],
    valid=train_valid["test"],
    test=ds["test"],
)

ds

In [None]:
# Collect the distinct label ids that occur in the training split.
unique_labels = set(ds["train"]["label"])
num_class = len(unique_labels)
print(f"Number of classes: {num_class}")

# Human-readable names for the AG News label ids.
ag_news_label = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}

## LLM
* `LiquidAI/LFM2-350M`を使う。
  * 軽量なautoregressive LLM。

### モデルとトークナイザの取得

In [None]:
# Tokenizer and model are loaded from the same checkpoint name.
MODEL_NAME = "LiquidAI/LFM2-350M"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)

* パラメータ数を確認する。

In [None]:
# Total number of scalar elements across all parameter tensors of the model.
params = sum(p.numel() for p in model.parameters())

print(f"The model has {params:,} parameters")

### パラメータの凍結

In [None]:
# Turn off gradient tracking on every model parameter so none of them is updated.
for weight in model.parameters():
  weight.requires_grad = False

### モデルのモジュールの確認

In [None]:
# Display the model's module structure as the cell output.
model

### モデルの詳細を知る方法＝ソースを読む
* GitHubにログインしてtransformersライブラリで`Lfm2ForCausalLM`を検索
  * 「/」キーを押してから検索語を入力すればよい。
* あるいは、下の場所でモデルがありそうなディレクトリを探す。
  * https://github.com/huggingface/transformers/tree/main/src/transformers/models

## トークナイザ

### padding

* paddingなし

In [None]:
# Tokenize the first three training texts without padding.
# The rows are sliced first so that only three examples are read, instead of
# fetching the whole text column and slicing it afterwards.
tokenizer(ds["train"][:3]["text"])

* paddingあり

In [None]:
# Tokenize the same three training texts, this time with padding enabled.
# The rows are sliced first so that only three examples are read, instead of
# fetching the whole text column and slicing it afterwards.
tokenizer(ds["train"][:3]["text"], padding=True)

## DataLoader

### collate関数
* トークナイザを使っているだけ。

In [None]:
def collate_fn(batch):
  """Convert a list of dataset samples into one tokenized batch.

  Args:
    batch: List of samples, each indexable by "text" and "label".

  Returns:
    A pair (tokenized, labels): the tokenizer output for the texts (padded,
    as PyTorch tensors) and a tensor of the labels, both moved to the
    module-level `device`.
  """
  texts = [sample["text"] for sample in batch]
  labels = [sample["label"] for sample in batch]
  encoded = tokenizer(texts, padding=True, return_tensors="pt").to(device)
  label_tensor = torch.tensor(labels).to(device)
  return encoded, label_tensor

* バッチサイズはGPUメモリの容量で決める。

In [None]:
# Number of samples per batch.
BATCH_SIZE = 8

# Only the training loader shuffles; all three loaders share collate_fn.
train_dataloader = DataLoader(
    ds["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
valid_dataloader = DataLoader(
    ds["valid"],
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    ds["test"],
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

In [None]:
# Fetch one batch from the training loader and display it as the cell output.
next(iter(train_dataloader))

## 分類モデル

### LLMの挙動の確認

In [None]:
# Take one training batch and run it through the model to inspect the output.
tokenized, labels = next(iter(train_dataloader))
output = model(
    input_ids=tokenized["input_ids"],
    attention_mask=tokenized["attention_mask"],
)

In [None]:
# Display the object returned by the model call above.
output

In [None]:
# Shape of the final-layer hidden states.
# Presumably (batch, sequence length, hidden size) — confirm against the printed value.
output.last_hidden_state.shape

### 分類モデルの定義

In [None]:
class TextClassificationModel(nn.Module):
  """Classifier that mean-pools the backbone's final hidden states.

  The hidden states at positions where the attention mask is non-zero are
  averaged into one vector per sequence, and a single linear layer maps that
  vector to class logits.

  Args:
    model: Backbone that is called with `input_ids` and `attention_mask`,
      returns an object with `last_hidden_state`, and exposes
      `config.hidden_size` and `device`.
    num_class: Number of output classes.
  """

  def __init__(self, model, num_class):
    super().__init__()
    self.model = model
    # Create the linear head on the same device as the backbone.
    self.fc = nn.Linear(model.config.hidden_size, num_class).to(model.device)

  def forward(self, input_ids, attention_mask):
    """Return class logits for a batch.

    Args:
      input_ids: Token ids passed through to the backbone.
      attention_mask: Mask with the same leading two dimensions as the
        hidden states; non-zero marks positions included in the average.

    Returns:
      The output of the linear head applied to the mean-pooled hidden states.
    """
    output = self.model(input_ids=input_ids, attention_mask=attention_mask)
    # Zero out masked positions before summing over the sequence dimension.
    mask = attention_mask.unsqueeze(-1)
    summed = (output.last_hidden_state * mask).sum(1)
    # Clamp the count so a fully masked row yields zeros instead of NaN (0 / 0).
    token_counts = attention_mask.sum(1, keepdim=True).clamp(min=1)
    return self.fc(summed / token_counts)

## 学習のための準備

### 訓練を実行するヘルパ関数

In [None]:
def train(clf, dataloader, optimizer, criterion):
  """Run one pass over `dataloader`, updating `clf` after every batch.

  Args:
    clf: Model called as clf(input_ids, attention_mask) and returning logits.
    dataloader: Sized iterable of (tokenized, labels) batches, where
      `tokenized` is indexable by "input_ids" and "attention_mask".
    optimizer: Optimizer that is zeroed and stepped once per batch.
    criterion: Loss function called as criterion(logits, labels).

  Returns:
    None. Progress is printed every 10 batches; the accuracy and loss shown
    cover the batches since the previous report, while the time shown is
    measured from the start of the call.
  """
  report_every = 10
  clf.train()
  correct = loss_sum = seen = 0
  began = time.time()

  for idx, (tokenized, labels) in enumerate(dataloader):
    batch_size = labels.size(0)

    optimizer.zero_grad()
    logits = clf(tokenized["input_ids"], tokenized["attention_mask"])
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    correct += (logits.argmax(1) == labels).sum().item()
    # Weight the batch loss by the batch size so the report is a per-sample mean.
    loss_sum += loss.item() * batch_size
    seen += batch_size

    # No report on the very first batch or between reporting points.
    if idx == 0 or idx % report_every != 0:
      continue

    elapsed = time.time() - began
    print(
        f"||| {idx:5d}/{len(dataloader):5d} batches | "
        f"time: {elapsed:5.2f}s | "
        f"accuracy {correct / seen:8.3f} | "
        f"loss {loss_sum / seen:8.3f}"
    )
    # Start a fresh reporting window.
    correct = loss_sum = seen = 0

### 評価を実行するヘルパ関数

In [None]:
from tqdm import tqdm

def evaluate(clf, dataloader, criterion):
  """Compute the classification accuracy of `clf` over `dataloader`.

  Args:
    clf: Model called as clf(input_ids, attention_mask) and returning logits.
    dataloader: Iterable of (tokenized, labels) batches, where `tokenized`
      is indexable by "input_ids" and "attention_mask".
    criterion: Unused. Only accuracy is reported, so no loss is computed;
      the parameter is kept so existing call sites keep working.

  Returns:
    The fraction of samples whose arg-max prediction equals the label.

  Raises:
    ZeroDivisionError: If `dataloader` yields no samples.
  """
  clf.eval()
  total_acc, total_count = 0, 0

  # Gradients are not needed for evaluation.
  with torch.no_grad():
    for tokenized, labels in tqdm(dataloader):
      logits = clf(tokenized["input_ids"], tokenized["attention_mask"])
      total_acc += (logits.argmax(1) == labels).sum().item()
      total_count += labels.size(0)
  return total_acc / total_count

## 学習の実行

In [None]:
# Hyperparameters for this run.
epochs = 10
learning_rate = 1e-4

# Attach a new classification head to the model loaded earlier.
clf = TextClassificationModel(model, num_class)

# To fine-tune every LLM parameter as well, uncomment the two lines below.
#for param in clf.model.parameters():
#  param.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=clf.parameters(), lr=learning_rate)

In [None]:
# One training pass followed by one validation pass per epoch.
for epoch in range(epochs):
  epoch_start_time = time.time()
  train(clf, train_dataloader, optimizer, criterion)
  accu_val = evaluate(clf, valid_dataloader, criterion)
  print("-" * 59)
  elapsed = time.time() - epoch_start_time
  # The learning rate uses scientific notation: with a fixed three decimals,
  # a value such as 1e-4 would be displayed as 0.000.
  print(
      f"| end of epoch {epoch+1:3d} | "
      f"time: {elapsed:5.2f}s | "
      f"lr = {optimizer.param_groups[0]['lr']:.1e} | "
      f"validation accuracy {accu_val:8.3f}"
  )
  print("-" * 82)

## 別の分類モデル
* 最後のトークンに対応する最終レイヤの出力だけを使う。

* コードの書き方を調べるには・・・
* transformersのgithubで`GenericForSequenceClassification`を探す。
  * https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_layers.py

In [None]:
class TextClassificationModel(nn.Module):
  """Classifier that uses the final hidden state of the last real token.

  For each sequence, the hidden state at the last position whose attention
  mask is non-zero is passed to a single linear layer that produces class
  logits.

  Args:
    model: Backbone that is called with `input_ids` and `attention_mask`,
      returns an object with `last_hidden_state`, and exposes
      `config.hidden_size` and `device`.
    num_class: Number of output classes.
  """

  def __init__(self, model, num_class):
    super().__init__()
    self.model = model
    # Create the linear head on the same device as the backbone.
    self.fc = nn.Linear(model.config.hidden_size, num_class).to(model.device)

  def forward(self, input_ids, attention_mask):
    """Return class logits for a batch.

    Args:
      input_ids: Token ids passed through to the backbone.
      attention_mask: Mask with one entry per token position; non-zero marks
        real (non-padding) tokens.

    Returns:
      The output of the linear head applied to each sequence's last
      real-token hidden state.
    """
    output = self.model(input_ids=input_ids, attention_mask=attention_mask)
    hidden = output.last_hidden_state

    # Locate the last real token from the attention mask rather than by
    # comparing ids with config.pad_token_id: this also works when the config
    # has no pad token id, or when the pad id is also a genuine final token.
    real_token_mask = attention_mask.to(hidden.device, torch.int32)
    token_indices = torch.arange(
        real_token_mask.shape[-1], device=hidden.device, dtype=torch.int32
    )
    # Masked positions contribute 0, so arg-max picks the largest unmasked
    # index regardless of which side the padding is on.
    last_token = (token_indices * real_token_mask).argmax(-1)

    sample_indices = torch.arange(hidden.shape[0], device=hidden.device)
    pooled = hidden[sample_indices, last_token]
    return self.fc(pooled)

In [None]:
# Hyperparameters for this run.
epochs = 10
learning_rate = 1e-4

# Build the classifier from the most recently defined TextClassificationModel.
clf = TextClassificationModel(model, num_class)

# To fine-tune every LLM parameter as well, uncomment the two lines below.
#for param in clf.model.parameters():
#  param.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=clf.parameters(), lr=learning_rate)

In [None]:
# One training pass followed by one validation pass per epoch.
for epoch in range(epochs):
  epoch_start_time = time.time()
  train(clf, train_dataloader, optimizer, criterion)
  accu_val = evaluate(clf, valid_dataloader, criterion)
  print("-" * 59)
  elapsed = time.time() - epoch_start_time
  # The learning rate uses scientific notation: with a fixed three decimals,
  # a value such as 1e-4 would be displayed as 0.000.
  print(
      f"| end of epoch {epoch+1:3d} | "
      f"time: {elapsed:5.2f}s | "
      f"lr = {optimizer.param_groups[0]['lr']:.1e} | "
      f"validation accuracy {accu_val:8.3f}"
  )
  print("-" * 82)

## テストセットで最終評価

In [None]:
# Score the current classifier (`clf`) on the test split.
print("Checking the results of test dataset...")
accu_test = evaluate(
    clf=clf,
    dataloader=test_dataloader,
    criterion=criterion,
)
print(f"test accuracy {accu_test:8.3f}")