In [45]:
# Mount Google Drive at /content/drive (Colab-only dependency).
from google.colab import drive
drive.mount('/content/drive')

# Drive folders used by the rest of the notebook: checkpoints are written to and
# read from `checkpoint_parent_dir`; the train/dev/test CSV paths are built from
# `data_parent_dir`.
checkpoint_parent_dir = "/content/drive/MyDrive/PBL6/Code/checkpoints"
data_parent_dir = "/content/drive/MyDrive/PBL6/Code/data"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# 1. Khai báo các thư viện cần thiết

In [46]:
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
from tqdm import tqdm
from IPython.display import clear_output
from typing import List, Tuple
import os
from os import walk
import matplotlib.pyplot as plt

# For training and evaluation
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import XLMRobertaModel, AutoTokenizer

# For data processing
import pandas as pd
from torch.utils.data import DataLoader, Dataset

### 1.1. Setup ban đầu

In [47]:
# Seed the NumPy and PyTorch generators so that DataLoader shuffling, dropout
# and the classifier-head initialisation are repeatable under
# "Restart kernel -> Run all".
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Reset any Keras/TensorFlow state left over from a previous run
tf.keras.backend.clear_session()

# Release cached GPU memory held by PyTorch
torch.cuda.empty_cache()

# Clear output of the cell
clear_output()

# Run on GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

### 1.2. Tạo input model

In [48]:
# CSV splits, all located in the shared data directory
train_path, dev_path, test_path = (
    f"{data_parent_dir}/train.csv",
    f"{data_parent_dir}/dev.csv",
    f"{data_parent_dir}/test.csv",
)

# The tokenizer and the encoder are loaded from the same pretrained checkpoint
pretrained_name = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
input_model = XLMRobertaModel.from_pretrained(pretrained_name)

# Size the embedding matrix to the tokenizer's vocabulary
input_model.resize_token_embeddings(len(tokenizer))

clear_output()

# 2. Xử lý dữ liệu

In [49]:
# Wipe whatever output this cell currently shows; the cell does nothing else.
clear_output()

### 2.1. Tiền xử lý data

In [50]:
def prepare_data(file_path: str) -> Tuple[List[str], List[List[int]]]:
    """Load one CSV split and convert its tag strings to binary span labels.

    The CSV must contain a ``Word`` column (the text of each sample) and a
    ``Tag`` column (whitespace-separated tags for that sample). Rows with a
    missing value in any column are dropped.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A ``(texts, binary_spans)`` pair of equal length. ``texts`` holds the
        ``Word`` values as strings; ``binary_spans[i]`` holds one integer per
        tag of sample ``i``: ``0`` for the tag ``'O'`` and ``1`` for any other
        tag.
    """
    df = pd.read_csv(file_path)

    # Drop rows with missing values
    df = df.dropna().reset_index(drop=True)

    # Coerce to str so a cell that pandas parsed as a number does not break
    # tokenization or tag splitting later on.
    texts = df['Word'].astype(str).tolist()

    # split() without an argument collapses runs of whitespace and ignores
    # leading/trailing blanks. split(' ') would yield empty strings there,
    # and each empty string would be counted as a positive (non-'O') label.
    binary_spans = [
        [0 if tag == 'O' else 1 for tag in str(tags).split()]
        for tags in df['Tag'].tolist()
    ]

    return texts, binary_spans

### 2.2. Tạo class TextDataset và hàm tạo dataloader

In [51]:
# Dataset pairing each tokenized text with its fixed-length label vector
class TextDataset(Dataset):
    """Tokenized texts together with span labels padded/truncated to ``max_len``.

    Attributes are built eagerly in ``__init__``:
      * ``texts``: list with one tokenizer output per input text.
      * ``spans``: integer tensor with one row of ``max_len`` labels per sample.

    NOTE(review): labels are padded by position only; nothing here maps them
    onto the tokenizer's output positions (special tokens, subword pieces).
    Presumably the labels are per whitespace-separated word -- confirm whether
    an explicit word-to-token alignment is needed.
    """

    def __init__(self, tokenizer, texts: List[str], spans: List[List[int]], max_len: int):
        # One tokenizer call per text; the output format is
        # {'input_ids': [...], 'attention_mask': [...]}.
        encodings = []
        for sentence in texts:
            encodings.append(
                tokenizer(sentence, padding='max_length', max_length=max_len, truncation=True, return_tensors="pt")
            )
        self.texts = encodings

        # Right-pad with 0 and cut at max_len in one expression: multiplying a
        # list by a non-positive count gives [], so long rows are only sliced.
        fixed_length = [
            (labels + [0] * (max_len - len(labels)))[:max_len]
            for labels in spans
        ]
        self.spans = torch.tensor(fixed_length)

    def __len__(self):
        return len(self.spans)

    def __getitem__(self, index):
        encoding = self.texts[index]
        labels = self.spans[index]
        return encoding, labels

def create_dataloader(texts, spans, batch_size, tokenizer, max_len, shuffle=True) -> DataLoader:
    """Wrap texts and span labels in a ``TextDataset`` and return a ``DataLoader`` over it.

    ``batch_size`` and ``shuffle`` are forwarded to the ``DataLoader``;
    ``tokenizer`` and ``max_len`` are forwarded to the ``TextDataset``.
    """
    return DataLoader(
        TextDataset(tokenizer, texts, spans, max_len),
        batch_size=batch_size,
        shuffle=shuffle,
    )

### 2.3. Tạo dataloader cho dữ liệu **train**, **dev** và **test**

In [None]:
batch_size = 64  # number of samples per batch
max_len = 64     # every sample is padded/truncated to this many positions

# Only the training split is shuffled; dev and test keep the file order so
# that evaluation runs visit the samples in a stable, repeatable order.
train_dataloader = create_dataloader(*prepare_data(train_path), batch_size=batch_size, tokenizer=tokenizer, max_len=max_len)
dev_dataloader = create_dataloader(*prepare_data(dev_path), batch_size=batch_size, tokenizer=tokenizer, max_len=max_len, shuffle=False)
test_dataloader = create_dataloader(*prepare_data(test_path), batch_size=batch_size, tokenizer=tokenizer, max_len=max_len, shuffle=False)

# 3. Tạo mô hình huấn luyện

### 3.1. Tạo lớp mô hình huấn luyện

In [None]:
class MultiTaskModel(nn.Module):
    """Encoder backbone followed by a per-position binary span classifier.

    Args:
        input_model: Encoder module. It is called as
            ``input_model(input_ids=..., attention_mask=..., return_dict=False)``
            and the first element of its output is used as the hidden states.
    """

    def __init__(self, input_model):
        super(MultiTaskModel, self).__init__()
        self.bert = input_model
        # Take the head's input width from the backbone config when it exposes
        # one, so other encoder sizes work; 768 stays the fallback.
        hidden_size = getattr(getattr(input_model, "config", None), "hidden_size", 768)
        self.span_classifier = nn.Linear(hidden_size, 1) # Classification head
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        """Return per-position probabilities of shape (batch_size, sequence_length, 1).

        Despite the historical "logits" naming used by callers, the values are
        already passed through a sigmoid and therefore lie in [0, 1].
        """
        # Hidden states of shape (batch_size, sequence_length, hidden_size)
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        last_hidden_state = self.dropout(output[0])

        # One score per position. Sigmoid is element-wise, so no axis
        # permutation is needed around it.
        return torch.sigmoid(self.span_classifier(last_hidden_state))

    def predict(self, input_ids, attention_mask, threshold: float = 0.5):
        """Return one binary label per sample, shape (batch_size, 1).

        Each position is thresholded, the binary decisions are averaged over
        the whole sequence axis, and the average is thresholded again with the
        same ``threshold``.

        Inference runs in eval mode (dropout disabled) and without gradient
        tracking; the module's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                span_probs = self.forward(input_ids, attention_mask)
        finally:
            self.train(was_training)

        # Per-position binary decisions (batch_size, sequence_length, 1)
        predictions = (span_probs > threshold).float()

        # Fraction of positive positions per sample (batch_size, 1)
        aggregated_predictions = torch.mean(predictions, dim=1)

        # Above the threshold -> HOS (1), otherwise NOT-HOS (0)
        return (aggregated_predictions > threshold).long()

### 3.2. Tạo các hàm tính toán chung

In [None]:
def calculate_f1(preds, y):
    """Macro-averaged F1 of predictions ``preds`` against true labels ``y``."""
    macro_f1 = f1_score(y_true=y, y_pred=preds, average='macro')
    return macro_f1

def calculate_accuracy(preds, y):
    """Accuracy of predictions ``preds`` against true labels ``y``."""
    accuracy = accuracy_score(y_true=y, y_pred=preds)
    return accuracy

def save_checkpoint(model: MultiTaskModel, filename: str):
    """Write the model's state dict to ``filename`` inside ``checkpoint_parent_dir``."""
    destination = f"{checkpoint_parent_dir}/{filename}"
    weights = model.state_dict()
    torch.save(weights, destination)

def load_checkpoint(model: MultiTaskModel, filename: str) -> MultiTaskModel:
    """Load the state dict stored as ``filename`` in ``checkpoint_parent_dir`` into ``model``.

    Args:
        model: Model whose parameters are overwritten in place.
        filename: Checkpoint file name relative to ``checkpoint_parent_dir``.

    Returns:
        The same ``model`` instance, for call chaining.
    """
    # Deserialize onto the CPU first, so a checkpoint written on a GPU runtime
    # can still be loaded on a CPU-only runtime. load_state_dict then copies
    # the values into the model's parameters on whatever device they live on.
    state_dict = torch.load(f"{checkpoint_parent_dir}/{filename}", map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def get_epoch_name(id)->str:
    """Build the checkpoint file name for an epoch identifier (a number or a label such as "final")."""
    return "epoch_" + str(id) + ".pt"

### 3.3. Hàm huấn luyện mô hình

In [None]:
def _run_span_batch(model, texts, spans, criterion_span, device):
    """Run one batch through the model and return ``(probabilities, targets, loss)``."""
    # The encodings carry a singleton axis at dim 1 (one tokenizer call per
    # text with return_tensors="pt"); drop it from both tensors so the encoder
    # receives (batch, seq_len) inputs.
    input_ids = texts['input_ids'].squeeze(1).to(device)
    attention_mask = texts['attention_mask'].squeeze(1).to(device)
    targets = spans.float().to(device)

    # squeeze(-1) removes only the trailing classifier axis. A bare squeeze()
    # would also remove the batch axis when a batch holds a single sample,
    # making the loss fail on mismatched shapes.
    span_probs = model(input_ids, attention_mask).squeeze(-1)
    loss = criterion_span(span_probs, targets)
    return span_probs, targets, loss


def train(
    model: MultiTaskModel,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    criterion_span: nn.BCELoss,
    optimizer_spans: optim.Adam,
    device: torch.device,
    start_epoch: int,
    num_epochs: int
):
    """Train the span model, validating and checkpointing after every epoch.

    Args:
        model: Model to optimise; its output is treated as per-position probabilities.
        train_dataloader: Batches of ``(encoding, spans)`` used for optimisation.
        dev_dataloader: Batches of ``(encoding, spans)`` used for validation.
        criterion_span: Loss applied to ``(probabilities, float targets)``.
        optimizer_spans: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        start_epoch: Number of epochs already completed. When > 0 the
            checkpoint ``epoch_{start_epoch}.pt`` is loaded before resuming.
        num_epochs: Total number of epochs to reach.

    Returns:
        ``(total_train_loss, total_train_acc, total_val_loss, total_val_acc)``:
        four lists with one float per epoch run in this call. All lists are
        empty when ``start_epoch >= num_epochs``.
    """
    total_train_loss = [] # For plotting
    total_train_acc = [] # For plotting
    total_val_loss = [] # For plotting
    total_val_acc = [] # For plotting

    if start_epoch >= num_epochs:
        return total_train_loss, total_train_acc, total_val_loss, total_val_acc

    # Load checkpoint if start_epoch > 0
    if start_epoch > 0:
        print('Loading checkpoint...')
        model = load_checkpoint(model, get_epoch_name(start_epoch))

    for epoch in range(start_epoch, num_epochs):
        print('Epoch: ', epoch+1)

        # ---- Training pass ----
        # Re-enable training mode every epoch: validation below switches to eval.
        model.train()
        train_loss = 0.0
        train_span_preds = []
        train_span_targets = []
        for texts, spans in tqdm(train_dataloader):
            optimizer_spans.zero_grad()
            span_probs, targets, loss = _run_span_batch(model, texts, spans, criterion_span, device)
            loss.backward()
            optimizer_spans.step()

            train_loss += loss.item()
            train_span_preds.append(span_probs.detach().cpu().numpy().flatten())
            train_span_targets.append(targets.cpu().numpy().flatten())

        # ---- Validation pass ----
        # eval() disables dropout so validation metrics are deterministic.
        model.eval()
        val_loss = 0.0
        val_span_preds = []
        val_span_targets = []
        with torch.no_grad():
            for texts, spans in tqdm(dev_dataloader):
                span_probs, targets, loss = _run_span_batch(model, texts, spans, criterion_span, device)

                # .item() keeps the running sum a plain float instead of a
                # device tensor.
                val_loss += loss.item()
                val_span_preds.append(span_probs.cpu().numpy().flatten())
                val_span_targets.append(targets.cpu().numpy().flatten())

        # Validation loss, accuracy, F1-score (probabilities thresholded at 0.5)
        val_span_preds = (np.concatenate(val_span_preds) > 0.5).astype(int)
        val_span_targets = np.concatenate(val_span_targets)

        val_loss = val_loss/len(dev_dataloader)
        val_acc = calculate_accuracy(val_span_preds, val_span_targets)
        val_f1 = calculate_f1(val_span_preds, val_span_targets)

        # Train loss, accuracy
        train_span_preds = (np.concatenate(train_span_preds) > 0.5).astype(int)
        train_span_targets = np.concatenate(train_span_targets)

        train_loss = train_loss/len(train_dataloader)
        train_acc = calculate_accuracy(train_span_preds, train_span_targets)

        print(f" -> Train loss: {train_loss}; Train acc: {train_acc} -- Val loss: {val_loss}; Val acc: {val_acc} -- F1-score: {val_f1}")

        # Write the new checkpoint first and only then delete the previous one,
        # so a failed save never leaves the run without any checkpoint.
        is_last_epoch = epoch == num_epochs - 1
        checkpoint_name = get_epoch_name("final") if is_last_epoch else get_epoch_name(epoch+1)
        save_checkpoint(model, checkpoint_name)

        previous_checkpoint = f"{checkpoint_parent_dir}/{get_epoch_name(epoch)}"
        if os.path.exists(previous_checkpoint):
            os.remove(previous_checkpoint)

        # Save values for plotting
        total_train_loss.append(train_loss)
        total_train_acc.append(train_acc)
        total_val_loss.append(val_loss)
        total_val_acc.append(val_acc)

    return total_train_loss, total_train_acc, total_val_loss, total_val_acc

# 4. Huấn luyện và kiểm tra mô hình

### 4.1. Huấn luyện mô hình

In [None]:
num_epochs = 100

# Create an instance of the multi-task model
model = MultiTaskModel(input_model = input_model)
model.to(device)

# The model outputs probabilities (sigmoid already applied), hence BCELoss
criterion_span = nn.BCELoss()

# Define the optimizer
optimizer_spans = optim.Adam(list(model.parameters()), lr=5e-6, weight_decay=1e-5)

# Find the epoch to resume from by scanning the checkpoint directory.
# Only names produced by get_epoch_name() are considered ("epoch_<N>.pt" and
# "epoch_final.pt"): train() loads exactly "epoch_<start_epoch>.pt", so
# matching any other "*.pt*" file would pick an epoch that cannot be loaded.
filenames = next(walk(checkpoint_parent_dir), (None, None, []))[2]  # [] if the directory has no files
start_epoch = 0
for checkpoint_file in filenames:
    if checkpoint_file == get_epoch_name("final"):
        # Training already finished; train() returns empty histories and
        # does not load any weights in that case.
        start_epoch = num_epochs
        break

    if not (checkpoint_file.startswith("epoch_") and checkpoint_file.endswith(".pt")):
        continue

    epoch_token = checkpoint_file[len("epoch_"):-len(".pt")]
    if epoch_token.isdecimal():
        start_epoch = max(start_epoch, int(epoch_token))


total_train_loss, total_train_acc, total_val_loss, total_val_acc = train(
    model = model,
    train_dataloader = train_dataloader,
    dev_dataloader = dev_dataloader,
    criterion_span = criterion_span,
    optimizer_spans = optimizer_spans,
    device = device,
    start_epoch = start_epoch,
    num_epochs = num_epochs
)

### 4.2. Hiển thị các biểu đồ

In [None]:
def tensor_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; return any other value unchanged.

    Args:
        tensor: A ``torch.Tensor`` on any device, with or without gradient
            tracking, or an arbitrary non-tensor value.

    Returns:
        A NumPy array for tensor input, otherwise the input object itself.
    """
    if not torch.is_tensor(tensor):
        # If input is not a tensor, return it as is
        return tensor

    # detach() makes the conversion work for tensors that require grad
    # (calling .numpy() on those raises); cpu() is a no-op for tensors that
    # already live on the CPU.
    return tensor.detach().cpu().numpy()

In [None]:
def plot_training_results(total_train_loss, total_train_acc, total_val_loss, total_val_acc):
    """Draw loss and accuracy curves (training vs. validation) side by side.

    Each argument is a per-epoch sequence; entries may be tensors or plain
    numbers. The x axis counts epochs starting at 1.
    """
    # Tensors become NumPy values; anything else passes through unchanged
    train_loss_curve = list(map(tensor_to_numpy, total_train_loss))
    train_acc_curve = list(map(tensor_to_numpy, total_train_acc))
    val_loss_curve = list(map(tensor_to_numpy, total_val_loss))
    val_acc_curve = list(map(tensor_to_numpy, total_val_acc))

    epochs = range(1, len(train_loss_curve) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: training and validation loss
    loss_ax.plot(epochs, train_loss_curve, 'b', label='Training Loss')
    loss_ax.plot(epochs, val_loss_curve, 'r', label='Validation Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Right panel: training and validation accuracy
    acc_ax.plot(epochs, train_acc_curve, 'b', label='Training Accuracy')
    acc_ax.plot(epochs, val_acc_curve, 'r', label='Validation Accuracy')
    acc_ax.set_title('Training and Validation Accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()

# Plot the per-epoch histories returned by train()
plot_training_results(
    total_train_loss=total_train_loss,
    total_train_acc=total_train_acc,
    total_val_loss=total_val_loss,
    total_val_acc=total_val_acc,
)

### 4.3. Kiểm tra mô hình trên tập dữ liệu test

In [None]:
# Disabled evaluation helper: every line of this cell is commented out, so
# nothing below runs. As written, it would switch the model to eval mode, run it
# over `test_dataloader` without gradients, threshold the outputs at 0.5 and
# print the macro F1 score.
# def test(model, test_dataloader, device):
#     model.eval()
#     span_preds = []
#     span_targets = []
#     for texts, spans in tqdm(test_dataloader):
#         input_ids = texts['input_ids'].squeeze(1).to(device)
#         attention_mask = texts['attention_mask'].to(device)
#         spans = spans.float().to(device)
#         with torch.no_grad():
#             span_logits = model(input_ids, attention_mask)

#         span_preds.append(span_logits.squeeze().cpu().numpy().flatten())
#         span_targets.append(spans.cpu().numpy().flatten())

#     span_preds = np.concatenate(span_preds)
#     span_targets = np.concatenate(span_targets)
#     span_preds = (span_preds > 0.5).astype(int)
#     span_f1 = f1_score(span_targets, span_preds, average='macro')

#     print("Span F1 Score: {:.4f}".format(span_f1))

In [None]:
# Disabled companion cell: as written, it would rebuild the model around
# `input_model`, load "epoch_final.pt" from the checkpoint directory, move the
# model to `device` and call the commented-out test() helper.
# model = MultiTaskModel(input_model = input_model)
# model.load_state_dict(torch.load(f"{checkpoint_parent_dir}/epoch_final.pt", weights_only=True))
# model.to(device)

# test(model = model, test_dataloader = test_dataloader, device = device)