In [None]:
from sklearn.metrics import f1_score
import numpy as np
import os
from tqdm import tqdm

In [None]:
import tensorflow as tf
import torch
from IPython.display import clear_output

# Reset the global state held by the Keras backend.
tf.keras.backend.clear_session()

# Release cached, currently unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

# Clear this cell's output area.
clear_output()

In [None]:
import torch

# Run on CUDA when PyTorch reports an available GPU, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [None]:
# CSV locations, relative to the notebook's working directory.
# NOTE(review): train/dev point into the sequence-labeling release while test
# points into the span-extraction release -- confirm that the test CSV has the
# same columns that prepare_data() reads.
train_path = "data/Sequence_labeling_based_version/Word/train_BIO_Word.csv"
dev_path = "data/Sequence_labeling_based_version/Word/dev_BIO_Word.csv"
test_path = "data/Span Extraction-based version/test.csv"

In [None]:
from transformers import (
    AutoModel, AutoConfig, XLMRobertaModel,
    AutoTokenizer, AutoModelForSequenceClassification
)

# Single source of truth for the checkpoint so model and tokenizer cannot drift apart.
MODEL_NAME = "vinai/phobert-large"

# Let AutoModel choose the architecture from the checkpoint's own config instead
# of forcing XLMRobertaModel: PhoBERT is published as a RoBERTa-type checkpoint,
# and instantiating it through a different model class triggers a model-type
# mismatch warning from transformers.
input_model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Keep the embedding matrix size in sync with the tokenizer's vocabulary size.
input_model.resize_token_embeddings(len(tokenizer))

clear_output()

# Data

In [None]:
import pandas as pd
import transformers
import torch
import torch.nn as nn
import pandas as pd

#clear output
from IPython.display import clear_output
clear_output()

In [None]:
def prepare_data(file_path):
    """Read a CSV of texts and tag strings and binarise the tags.

    Rows containing any missing value are discarded. The ``spans`` column is
    expected to hold space-separated tags; the tag ``'O'`` maps to 0 and every
    other tag maps to 1.

    Args:
        file_path: path of a CSV file with ``text`` and ``spans`` columns.

    Returns:
        A ``(texts, binary_spans)`` tuple: the list of text values and, for each
        row, a list of 0/1 integers (one per space-separated tag).
    """
    frame = pd.read_csv(file_path).dropna().reset_index(drop=True)

    texts = frame['text'].tolist()
    tag_strings = frame['spans'].tolist()

    # Split on a single space (not arbitrary whitespace) so that tag positions
    # are derived exactly as they appear in the file.
    binary_spans = [
        [0 if tag == 'O' else 1 for tag in tag_string.split(' ')]
        for tag_string in tag_strings
    ]

    return texts, binary_spans

In [None]:
# Dataloader function
class TextDataset(torch.utils.data.Dataset):
    """Dataset pairing tokenized texts with fixed-length label vectors.

    Args:
        texts: iterable of strings; each one is passed to ``tokenizer`` on its own.
        spans: list of per-text label lists; each list is right-padded with 0 or
            truncated so that it has exactly ``max_len`` entries.
        tokenizer: callable invoked as ``tokenizer(text, padding='max_length',
            max_length=max_len, truncation=True, return_tensors="pt")``.
        max_len: target length used for BOTH the tokenizer padding/truncation
            and the label padding/truncation.
    """

    def __init__(self, texts, spans, tokenizer, max_len):
        # Tokenize eagerly, one text at a time. ``max_len`` drives the encoded
        # length (it used to be a hard-coded 64), so encodings and labels always
        # share the same length, whatever value the caller picks.
        self.texts = [
            tokenizer(text,
                      padding='max_length',
                      max_length=max_len, truncation=True,
                      return_tensors="pt")
            for text in texts
        ]

        # Right-pad with 0 when too short, cut when too long.
        # NOTE(review): labels are fitted by list position only; nothing here
        # aligns them with the tokenizer's output positions (special tokens,
        # sub-word pieces) -- confirm the intended alignment.
        self.spans = torch.tensor(
            [(span + [0] * (max_len - len(span)))[:max_len] for span in spans]
        )

    def __len__(self):
        """Return the number of examples."""
        return len(self.spans)

    def __getitem__(self, index):
        """Return ``(encoding, labels)`` for the example at ``index``."""
        return self.texts[index], self.spans[index]

def create_dataloader(texts, spans, batch_size, tokenizer, max_len, shuffle=True):
    """Wrap texts and label lists in a ``TextDataset`` and return a ``DataLoader`` over it.

    Args:
        texts: texts forwarded to ``TextDataset``.
        spans: per-text label lists forwarded to ``TextDataset``.
        batch_size: number of examples per batch.
        tokenizer: tokenizer forwarded to ``TextDataset``.
        max_len: sequence length forwarded to ``TextDataset``.
        shuffle: whether the loader shuffles the examples (default ``True``).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches from the dataset.
    """
    return torch.utils.data.DataLoader(
        TextDataset(texts, spans, tokenizer, max_len),
        batch_size=batch_size,
        shuffle=shuffle,
    )

In [None]:
batch_size = 64
# One shared sequence length for all three splits (tokenizer padding and label padding).
MAX_LEN = 64

train_dataloader = create_dataloader(*prepare_data(train_path), batch_size=batch_size,
                                     tokenizer=tokenizer, max_len=MAX_LEN)
# Evaluation splits are iterated in file order: shuffling does not change the
# aggregated metrics and only makes per-example outputs harder to trace back.
dev_dataloader = create_dataloader(*prepare_data(dev_path), batch_size=batch_size,
                                   tokenizer=tokenizer, max_len=MAX_LEN, shuffle=False)
test_dataloader = create_dataloader(*prepare_data(test_path), batch_size=batch_size,
                                    tokenizer=tokenizer, max_len=MAX_LEN, shuffle=False)

# Create Model

In [None]:
def calculate_f1(preds, y):
    """Return the macro-averaged F1 between arg-max predictions and targets.

    Args:
        preds: score tensor; the predicted class is the arg-max along dim 1
            (the reduced dimension is kept with size 1).
        y: target tensor.

    Returns:
        The value returned by ``f1_score(..., average='macro')``.
    """
    # Index of the highest score along dim 1.
    predicted = preds.argmax(dim=1, keepdim=True)
    targets_cpu = y.cpu()
    predicted_cpu = predicted.cpu()
    return f1_score(targets_cpu, predicted_cpu, average='macro')

In [None]:
class MultiTaskModel(nn.Module):
    """Encoder followed by a per-position binary classification head.

    Args:
        input_model: encoder module. It is called as
            ``input_model(input_ids=..., attention_mask=..., return_dict=False)``
            and the first element of its output is used as the per-position
            hidden states. If it exposes ``config.hidden_size`` that value sizes
            the classification head; otherwise 768 is assumed.
    """

    def __init__(self, input_model):
        super(MultiTaskModel, self).__init__()
        self.bert = input_model
        # Size the head from the encoder instead of hard-coding 768: a "large"
        # encoder has a wider hidden state and would fail with a shape mismatch
        # in the first forward pass.
        hidden_size = getattr(getattr(input_model, "config", None), "hidden_size", 768)
        self.span_classifier = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        """Return per-position probabilities.

        Args:
            input_ids: token id tensor forwarded to the encoder.
            attention_mask: attention mask forwarded to the encoder.

        Returns:
            Tensor with the encoder's leading dimensions and a trailing dimension
            of size 1, holding sigmoid probabilities (suitable for ``nn.BCELoss``).
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        last_hidden_state = self.dropout(output[0])
        span_logits = self.span_classifier(last_hidden_state)

        # Sigmoid is element-wise, so no permutation of the axes is needed around it.
        return torch.sigmoid(span_logits)

In [None]:
def train(model, train_dataloader, dev_dataloader, criterion_span, optimizer_spans, device, num_epochs):
    """Train ``model`` and print dev-set loss and macro F1 after every epoch.

    Every batch is a ``(texts, spans)`` pair in which ``texts['input_ids']`` and
    ``texts['attention_mask']`` carry a singleton dimension at position 1 (it is
    squeezed away here) and ``spans`` holds one 0/1 label per position.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns
            per-position probabilities with a trailing dimension of size 1.
        train_dataloader: sized iterable of training batches.
        dev_dataloader: sized iterable of validation batches.
        criterion_span: loss called as ``criterion_span(probabilities, targets)``.
        optimizer_spans: optimizer updating the model parameters.
        device: device the batch tensors are moved to.
        num_epochs: number of passes over ``train_dataloader``.

    Returns:
        None. Metrics are printed. The model is left in eval mode after the
        final validation pass.
    """
    for epoch in range(num_epochs):
        print('Epoch: ', epoch+1)

        # Switch back to training mode every epoch: the validation pass below
        # puts the model in eval mode.
        model.train()
        total_loss = 0.0
        for texts, spans in tqdm(train_dataloader):
            input_ids = texts['input_ids'].squeeze(1).to(device)
            attention_mask = texts['attention_mask'].squeeze(1).to(device)
            spans = spans.float().to(device)

            optimizer_spans.zero_grad()
            span_probs = model(input_ids, attention_mask)
            # Drop only the trailing size-1 dimension. A bare squeeze() would
            # also drop the batch dimension of a final batch holding a single
            # example, and the loss would then reject the mismatched shapes.
            loss = criterion_span(span_probs.squeeze(-1), spans)
            loss.backward()

            optimizer_spans.step()
            total_loss += loss.item()

        # Validation: eval mode disables dropout so the metrics are deterministic.
        # Note that every position (padding included) counts in loss and F1.
        model.eval()
        val_loss = 0.0
        span_preds = []
        span_targets = []

        with torch.no_grad():
            for texts, spans in tqdm(dev_dataloader):
                input_ids = texts['input_ids'].squeeze(1).to(device)
                attention_mask = texts['attention_mask'].squeeze(1).to(device)
                spans = spans.float().to(device)

                span_probs = model(input_ids, attention_mask).squeeze(-1)
                # Accumulate a Python float, not a tensor.
                val_loss += criterion_span(span_probs, spans).item()

                # Keep flat per-position predictions and targets for the epoch-level F1.
                span_preds.append(span_probs.cpu().numpy().flatten())
                span_targets.append(spans.cpu().numpy().flatten())

        span_preds = np.concatenate(span_preds)
        span_targets = np.concatenate(span_targets)
        # Threshold the probabilities at 0.5 to obtain hard 0/1 predictions.
        span_preds = (span_preds > 0.5).astype(int)
        span_f1 = f1_score(span_targets, span_preds, average='macro')

        print('Training loss: ', total_loss/len(train_dataloader))
        print('Validation loss: ', val_loss/len(dev_dataloader))
        print('Span F1-score: ', span_f1)


# Start Training

In [None]:
# import optim
import torch.optim as optim

# Training budget and target device.
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the pretrained encoder with the span head and move everything to the device.
model = MultiTaskModel(input_model=input_model)
model.to(device)

# The model already applies a sigmoid, so the loss works on probabilities.
criterion_span = nn.BCELoss()

# A single Adam optimizer updates every parameter (encoder and head).
optimizer_spans = optim.Adam(model.parameters(), lr=5e-6, weight_decay=1e-5)

train(
    model=model,
    train_dataloader=train_dataloader,
    dev_dataloader=dev_dataloader,
    criterion_span=criterion_span,
    optimizer_spans=optimizer_spans,
    device=device,
    num_epochs=num_epochs,
)

# Load and test

In [None]:
def test(model, test_dataloader, device):
    """Evaluate ``model`` on ``test_dataloader`` and report the macro F1.

    Every batch is a ``(texts, spans)`` pair in which ``texts['input_ids']`` and
    ``texts['attention_mask']`` carry a singleton dimension at position 1 (it is
    squeezed away here) and ``spans`` holds one 0/1 label per position.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns
            per-position probabilities with a trailing dimension of size 1.
        test_dataloader: iterable of evaluation batches.
        device: device the batch tensors are moved to.

    Returns:
        The macro-averaged F1 over all positions of all batches (also printed).
    """
    model.eval()
    span_preds = []
    span_targets = []

    with torch.no_grad():
        for texts, spans in tqdm(test_dataloader):
            input_ids = texts['input_ids'].squeeze(1).to(device)
            # Squeeze the mask like the ids so both reach the model as 2-D tensors.
            attention_mask = texts['attention_mask'].squeeze(1).to(device)
            spans = spans.float().to(device)

            # Drop only the trailing size-1 dimension; the batch dimension stays.
            span_probs = model(input_ids, attention_mask).squeeze(-1)

            span_preds.append(span_probs.cpu().numpy().flatten())
            span_targets.append(spans.cpu().numpy().flatten())

    span_preds = np.concatenate(span_preds)
    span_targets = np.concatenate(span_targets)
    # Threshold the probabilities at 0.5 to obtain hard 0/1 predictions.
    span_preds = (span_preds > 0.5).astype(int)
    span_f1 = f1_score(span_targets, span_preds, average='macro')

    print("Span F1 Score: {:.4f}".format(span_f1))
    return span_f1

In [None]:
# model = MultiTaskModel(input_model = input_model)
# model.load_state_dict(torch.load('model.pt'))
# model.to(device)

# test(model = model, test_dataloader = test_dataloader, device = device)