In [None]:
!pip install transformers

In [None]:
from os import path
import json
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR
from transformers import BertModel
import matplotlib.pyplot as plt

# Working directory for this chapter on the mounted Google Drive.
base = "/content/drive/MyDrive/NLP100/ch09"

# Seed torch's global RNG so DataLoader shuffling and the classifier-head
# initialisation are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

fp_train = "80/train.csv"
fp_valid = "80/valid.csv"
fp_words = "80/word_ids.json"
df_train = pd.read_csv(path.join(base, fp_train), index_col=0)
df_valid = pd.read_csv(path.join(base, fp_valid), index_col=0)
# Context manager so the vocabulary file handle is closed instead of leaked.
with open(path.join(base, fp_words), "r", encoding="utf-8") as fh_words:
  word_ids = json.load(fh_words)
df_train.head()

In [None]:
# Padded sequence length: word count of the longest training title
# (whitespace-split, matching title_to_ids).
num_words_of_title = max(len(headline.split()) for headline in df_train["TITLE"])


def title_to_ids(t, vocab=None, max_len=None):
  """Encode a title as fixed-length lists of word ids and attention-mask flags.

  Args:
    t: Title string; tokenised with ``str.split()`` (whitespace).
    vocab: Mapping word -> integer id. Defaults to the global ``word_ids``.
    max_len: Length of the returned lists. Defaults to the global
      ``num_words_of_title``. Words beyond this length are dropped, so titles
      longer than the longest training title (e.g. in the validation set) no
      longer raise IndexError.

  Returns:
    ``(ids, mask)``, two lists of length ``max_len``. An in-vocabulary word gets
    its id and mask 1. Unknown words and padding positions get id 0 and mask 0.

  NOTE(review): these ids come from the project's own vocabulary, not BERT's
  WordPiece tokenizer, yet they are passed to BertModel as input ids (and no
  [CLS]/[SEP] tokens are added). Confirm this is intended.
  """
  if vocab is None:
    vocab = word_ids
  if max_len is None:
    max_len = num_words_of_title
  ids = [0] * max_len
  mask = [0] * max_len
  for i, w in enumerate(t.split()[:max_len]):
    if w in vocab:
      ids[i] = vocab[w]
      mask[i] = 1
  return ids, mask


# Sanity check on a sample headline: first five ids on one line, first five mask flags on the next.
res, mask = title_to_ids("Europe reaches crunch point on banking union")
print(res[:5], mask[:5], sep="\n")

In [None]:
def _encode_split(df):
  """Encode one data split into model inputs and labels.

  Returns:
    X: int64 tensor built from ``title_to_ids`` for every TITLE; X[:, 0] holds
       the word ids and X[:, 1] the attention mask.
    y: int64 (torch.long) tensor of CATEGORY labels.
  """
  X = torch.tensor([title_to_ids(title) for title in df["TITLE"]], dtype=torch.long)
  # CrossEntropyLoss requires int64 targets; numpy's "int" is int32 on some
  # platforms (e.g. Windows), so request int64 explicitly.
  y = torch.tensor(df["CATEGORY"].to_numpy().astype("int64"), dtype=torch.long)
  return X, y


X_train, y_train = _encode_split(df_train)
X_valid, y_valid = _encode_split(df_valid)
print(X_train[:5])

In [None]:
batch_size = 64
num_workers = 2


def _to_dataset(X, y):
  """Split each encoded row into (ids, mask, label) triples, the order the training loop unpacks."""
  return [(X_i[0], X_i[1], y_i) for X_i, y_i in zip(X, y)]


dataset_train = _to_dataset(X_train, y_train)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
dataset_valid = _to_dataset(X_valid, y_valid)
# Evaluation metrics don't depend on order, so keep validation deterministic.
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Hidden width of bert-base (size of pooler_output). Kept for reference; Model
# reads the real value from the loaded checkpoint's config so the head always matches.
hidden_size = 768

class Model(nn.Module):
  """Pretrained BERT encoder followed by a linear classifier on ``pooler_output``.

  ``forward`` returns unnormalised logits, which is what ``nn.CrossEntropyLoss``
  expects (it applies log-softmax internally). Apply ``self.softmax`` to the
  output when class probabilities are needed.
  """

  def __init__(self, num_classes=4, pretrained_name="bert-base-uncased"):
    super().__init__()
    self.bert = BertModel.from_pretrained(pretrained_name)
    self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)
    # Not applied in forward(); kept so callers can turn logits into probabilities.
    self.softmax = nn.Softmax(dim=1)

  def forward(self, x, mask):
    """Return class logits of shape (batch, num_classes).

    Args:
      x: token-id tensor passed to BertModel as input ids.
      mask: attention mask aligned with ``x`` (1 = attend, 0 = ignore).
    """
    y = self.bert(x, attention_mask=mask)
    # No softmax here: feeding probabilities to CrossEntropyLoss applies softmax
    # twice and squashes the gradients.
    return self.linear(y.pooler_output)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

model = Model().to(device)
print(model)

In [None]:
learning_rate = 0.05
num_epochs = 5
loss_fn = nn.CrossEntropyLoss()
# Only the classification head is optimised. Freeze BERT so backward() does not
# compute (and keep accumulating) gradients for weights the optimizer never steps.
model.bert.requires_grad_(False)
optimizer = SGD(model.linear.parameters(), lr=learning_rate)
scheduler = ExponentialLR(optimizer, gamma=0.95)

loss_train = []
correct_train = []
loss_valid = []
correct_valid = []


def evaluate(dataloader):
  """Return (mean per-example loss, accuracy) of the global ``model`` on ``dataloader``.

  ``loss_fn`` returns the mean over a batch, so each batch loss is weighted by
  its batch size before dividing by the dataset size. Dividing a plain sum of
  batch means by the dataset size would under-report the loss by roughly a
  factor of ``batch_size``.
  """
  model.eval()
  total_loss, total_correct = 0.0, 0
  with torch.no_grad():
    for X, mask, y in dataloader:
      X, mask, y = X.to(device), mask.to(device), y.to(device)
      pred = model(X, mask)
      total_loss += loss_fn(pred, y).item() * len(y)
      total_correct += (pred.argmax(1) == y).sum().item()
  size = len(dataloader.dataset)
  return total_loss / size, total_correct / size


for epoch in range(num_epochs):
  print(f"Epoch {epoch + 1}\n-------------------------------")
  # model.train() is deliberately not called: BERT is a frozen feature
  # extractor, so its dropout stays disabled while the head is trained.
  size = len(dataloader_train.dataset)
  for batch, (X, mask, y) in enumerate(dataloader_train):
    X, mask, y = X.to(device), mask.to(device), y.to(device)
    loss = loss_fn(model(X, mask), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if batch % 10 == 0:
      loss, current = loss.item(), batch * len(X)
      print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
  scheduler.step()

  loss, correct = evaluate(dataloader_train)
  loss_train.append(loss)
  correct_train.append(correct)

  loss, correct = evaluate(dataloader_valid)
  loss_valid.append(loss)
  correct_valid.append(correct)
  print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {loss:>8f} \n")

In [None]:
# Per-epoch loss curves; x-axis is the 1-based epoch number printed during training.
epochs = range(1, len(loss_train) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, loss_train, label="train")
ax.plot(epochs, loss_valid, label="valid")
ax.set(title="Loss", xlabel="Epoch", ylabel="Cross-entropy loss")
ax.legend()
plt.show()

In [None]:
# Per-epoch accuracy curves; x-axis is the 1-based epoch number printed during training.
epochs = range(1, len(correct_train) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, correct_train, label="train")
ax.plot(epochs, correct_valid, label="valid")
ax.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy (fraction correct)")
ax.legend()
plt.show()