In [None]:
# General Libraries
#! pip install pandas numpy keras nltk spacy tensorflow torch


In [None]:
#! pip install scikit-learn

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from sklearn.metrics import f1_score
import numpy as np
import os
from tqdm import tqdm

In [None]:
import tensorflow as tf

tf.keras.backend.clear_session()

# clear gpu memory using torch
import torch
torch.cuda.empty_cache()

# clear output
from IPython.display import clear_output
clear_output()

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
device

device(type='cpu')

In [None]:
train_path = ("/content/drive/MyDrive/ViHOS/data/Sequence_labeling_based_version/Word/dev_BIO_Word.csv")
dev_path = ("/content/drive/MyDrive/ViHOS/data/Sequence_labeling_based_version/Word/test_BIO_Word.csv")
test_path = ("/content/drive/MyDrive/ViHOS/data/Sequence_labeling_based_version/Word/train_BIO_Word.csv")

In [None]:
from transformers import (
    AutoModel, AutoConfig, XLMRobertaModel,
    AutoTokenizer, AutoModelForSequenceClassification
)

input_model = XLMRobertaModel.from_pretrained("xlm-roberta-base")
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
input_model.resize_token_embeddings(len(tokenizer))

clear_output()

# Data

In [None]:
import pandas as pd
import transformers
import torch
import torch.nn as nn
import pandas as pd

#clear output
from IPython.display import clear_output
clear_output()

In [None]:
def prepare_data(file_path):
    df = pd.read_csv(file_path)

    # remove nan
    df = df.dropna()
    df = df.reset_index(drop=True)

    texts = df['Word'].tolist()
    spans = df['Tag'].tolist()

    # convert spans to binary representation
    binary_spans = []
    for span in spans:
        binary_span = []
        span = span.split(' ')
        for s in span:
            if s == 'O':
                binary_span.append(0)
            else:
                binary_span.append(1)
        binary_spans.append(binary_span)

    return texts, binary_spans

In [None]:
# Dataloader function
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, spans, tokenizer, max_len):
        self.texts = texts
        self.spans = spans
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        span = self.spans[idx]

        # Tokenize and prepare the input data
        encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')

        # Convert tensors to appropriate shape
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'spans': torch.tensor(span)  # Ensure spans is converted to tensor
        }

# class TextDataset(torch.utils.data.Dataset):
#     def __init__(self, texts, spans, tokenizer, max_len):
#         self.texts = [tokenizer(text,
#                                 padding='max_length',
#                                 max_length = 64, truncation=True,
#                                 return_tensors="pt")for text in texts]
#         self.spans = []

#         for span in spans:
#             if len(span) < max_len:
#                 self.spans.append(span + [0] * (max_len - len(span)))
#             else:
#                 self.spans.append(span[:max_len])

#         self.spans = torch.tensor(self.spans)

#     def __len__(self):
#         return len(self.spans)

#     def __getitem__(self, index):
#         return self.texts[index], self.spans[index]

def create_dataloader(texts, spans, batch_size, tokenizer, max_len, shuffle=True):
    dataset = TextDataset(texts, spans, tokenizer, max_len)
    # return texts
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    print(dataloader)
    return dataloader

In [None]:
# import os
# cwd = os.getcwd()

In [None]:
# train_dir_text = str(os.path.join(cwd,train_path).replace('/codes/Models',''))
# dev_dir_text = str(os.path.join(cwd,dev_path).replace('/codes/Models',''))
# test_dir_text = str(os.path.join(cwd,test_path).replace('/codes/Models',''))

In [None]:
batch_size = 32
train_dataloader = create_dataloader(*prepare_data(train_path), batch_size=batch_size, tokenizer = tokenizer, max_len=64)
dev_dataloader = create_dataloader(*prepare_data(dev_path), batch_size=batch_size, tokenizer = tokenizer, max_len=64, shuffle=False)
test_dataloader = create_dataloader(*prepare_data(test_path), batch_size=batch_size, tokenizer = tokenizer, max_len=64)

<torch.utils.data.dataloader.DataLoader object at 0x7a6d888e7e80>
<torch.utils.data.dataloader.DataLoader object at 0x7a6d887aef50>
<torch.utils.data.dataloader.DataLoader object at 0x7a6d848fa0b0>


In [None]:
for b in train_dataloader:
    print(b['input_ids'])
    break

tensor([[    0,  3711,     2,  ...,     1,     1,     1],
        [    0,  4194,   454,  ...,     1,     1,     1],
        [    0, 44433,     2,  ...,     1,     1,     1],
        ...,
        [    0, 47364,    19,  ...,     1,     1,     1],
        [    0,  2579,     2,  ...,     1,     1,     1],
        [    0,   550,     2,  ...,     1,     1,     1]])


# Create Model

In [None]:
def calculate_f1(preds, y):
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    return f1_score(y.cpu(), max_preds.cpu(), average='macro')

In [None]:
class MultiTaskModel(nn.Module):
    def __init__(self, input_model):
        super(MultiTaskModel, self).__init__()
        self.bert = input_model
        self.span_classifier = nn.Linear(768, 1)  # Chuyển từ kích thước đầu ra của BERT (768) thành 1
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        last_hidden_state = output[0]
        last_hidden_state = self.dropout(last_hidden_state)
        span_logits = self.span_classifier(last_hidden_state)

        span_logits = span_logits.mean(dim=1)  # Tính trung bình theo chiều 1 để giảm kích thước thành [batch_size, 1]
        span_logits = span_logits.unsqueeze(-1)  # Thêm chiều cuối cùng để có kích thước [batch_size, 1, 1]
        span_logits = torch.sigmoid(span_logits)

        return span_logits


In [None]:
# Training function with gradient accumulation
def train(model, train_dataloader, dev_dataloader, criterion_span, optimizer_spans, device, num_epochs, accumulation_steps):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        print(f'Epoch: {epoch+1}/{num_epochs}')
        for i, batch in enumerate(tqdm(train_dataloader)):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            spans = batch['spans'].float().to(device)

            # Forward pass
            span_logits = model(input_ids, attention_mask)
            loss = criterion_span(span_logits.squeeze(), spans.squeeze()) / accumulation_steps

            # Backward pass and optimization
            loss.backward()
            if (i + 1) % accumulation_steps == 0:
                optimizer_spans.step()
                optimizer_spans.zero_grad()

            total_loss += loss.item() * accumulation_steps

        # Handle the last accumulation step if not perfectly divisible
        if (i + 1) % accumulation_steps != 0:
            optimizer_spans.step()
            optimizer_spans.zero_grad()

        # Calculate validation loss and F1-score
        val_loss = 0
        span_preds = []
        span_targets = []

        model.eval()
        with torch.no_grad():
            for batch in tqdm(dev_dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                spans = batch['spans'].float().to(device)
                span_logits = model(input_ids, attention_mask)
                loss_span = criterion_span(span_logits.squeeze(), spans.squeeze())
                val_loss += loss_span.item()

                # Collect predictions and targets for F1 score
                span_preds.append(span_logits.squeeze().cpu().numpy().flatten())
                span_targets.append(spans.cpu().numpy().flatten())

        # Flatten lists of predictions and targets
        span_preds = np.concatenate(span_preds)
        span_targets = np.concatenate(span_targets)
        span_preds = (span_preds > 0.5).astype(int)
        span_f1 = f1_score(span_targets, span_preds, average='macro')

        print(f'Training loss: {total_loss / len(train_dataloader)}')
        print(f'Validation loss: {val_loss / len(dev_dataloader)}')
        print(f'Span F1-score: {span_f1}')



# def train(model, train_dataloader, dev_dataloader, criterion_span, optimizer_spans, device, num_epochs):
#     model.train()
#     for epoch in range(num_epochs):
#         total_loss = 0
#         print('Epoch: ', epoch+1)
#         for texts, spans in tqdm(train_dataloader):
#             input_ids = texts['input_ids'].squeeze(1).to(device)
#             attention_mask = texts['attention_mask'].to(device)
#             spans = spans.float().to(device)

#             optimizer_spans.zero_grad()
#             span_logits = model(input_ids, attention_mask)
#             loss_span = criterion_span(span_logits.squeeze(), spans)

#             loss = loss_span
#             loss.backward()

#             optimizer_spans.step()
#             total_loss += loss.item()

#         # Calculate validation loss and macro F1-score
#         val_loss = 0
#         span_preds = []
#         span_targets = []

#         for texts, spans in tqdm(dev_dataloader):
#             input_ids = texts['input_ids'].squeeze(1).to(device)
#             attention_mask = texts['attention_mask'].to(device)
#             spans = spans.float().to(device)
#             with torch.no_grad():
#                 span_logits = model(input_ids, attention_mask)
#                 loss_span = criterion_span(span_logits.squeeze(), spans)

#                 val_loss += loss_span #+ loss_label

#             # Save the true labels and predicted labels for each sample
#             span_preds.append(span_logits.squeeze().cpu().numpy().flatten())
#             span_targets.append(spans.cpu().numpy().flatten())

#         span_preds = np.concatenate(span_preds)
#         span_targets = np.concatenate(span_targets)
#         span_preds = (span_preds > 0.5).astype(int)
#         span_f1 = f1_score(span_targets, span_preds, average='macro')

#         print('Training loss: ', total_loss/len(train_dataloader))
#         print('Validation loss: ', val_loss/len(dev_dataloader))
#         print('Span F1-score: ', span_f1)


# Start Training

In [None]:
# import optim
import torch.optim as optim

# Set the number of epochs and the device to use for training
num_epochs = 100
accumulation_steps = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
# Create an instance of the multi-task model
model = MultiTaskModel(input_model = input_model)
model.to(device)

criterion_span = nn.BCELoss()

# Define the optimizer
optimizer_spans = optim.Adam(list(model.parameters()), lr=5e-6, weight_decay=1e-5)

train(model = model, train_dataloader = train_dataloader, dev_dataloader = dev_dataloader ,
      criterion_span = criterion_span, optimizer_spans = optimizer_spans ,device = device, num_epochs = num_epochs,
      accumulation_steps=accumulation_steps)

Epoch: 1/100


100%|██████████| 436/436 [02:41<00:00,  2.70it/s]
100%|██████████| 420/420 [00:46<00:00,  9.07it/s]


Training loss: 0.49886026400901856
Validation loss: 0.42146154854978835
Span F1-score: 0.45313072880596406
Epoch: 2/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.07it/s]


Training loss: 0.40457544468958445
Validation loss: 0.3650069402619487
Span F1-score: 0.45313072880596406
Epoch: 3/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.05it/s]


Training loss: 0.35451893880963326
Validation loss: 0.339992056006477
Span F1-score: 0.6858525531777235
Epoch: 4/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.05it/s]


Training loss: 0.3173425789888299
Validation loss: 0.3266323388775899
Span F1-score: 0.7484248240729137
Epoch: 5/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.07it/s]


Training loss: 0.27823310350537844
Validation loss: 0.32553797302146753
Span F1-score: 0.7489516485099821
Epoch: 6/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.05it/s]


Training loss: 0.2517178525481749
Validation loss: 0.35497870514435426
Span F1-score: 0.7498911695323272
Epoch: 7/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.05it/s]


Training loss: 0.23104148303423452
Validation loss: 0.33992839748305936
Span F1-score: 0.7534703940847907
Epoch: 8/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.10it/s]


Training loss: 0.21478049882576553
Validation loss: 0.3667199858774742
Span F1-score: 0.7578034393839364
Epoch: 9/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.20332127471172481
Validation loss: 0.37268892096444256
Span F1-score: 0.7625362932748925
Epoch: 10/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.1975943792030352
Validation loss: 0.4057709395530678
Span F1-score: 0.7582022253376479
Epoch: 11/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.1929516184288974
Validation loss: 0.37563966714910096
Span F1-score: 0.7539688784532739
Epoch: 12/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.10it/s]


Training loss: 0.18627558853648124
Validation loss: 0.4086452567062917
Span F1-score: 0.7613452402388912
Epoch: 13/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.18360065899515918
Validation loss: 0.40519845156619944
Span F1-score: 0.7595399437772142
Epoch: 14/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.10it/s]


Training loss: 0.1798122635435894
Validation loss: 0.41917396962110487
Span F1-score: 0.7578636250197296
Epoch: 15/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.1770416672315893
Validation loss: 0.4545219454249101
Span F1-score: 0.7599712509546356
Epoch: 16/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.11it/s]


Training loss: 0.1770011853221633
Validation loss: 0.4317412876302288
Span F1-score: 0.7595072551918929
Epoch: 17/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.10it/s]


Training loss: 0.1763733862130262
Validation loss: 0.438738785675239
Span F1-score: 0.7612711120811462
Epoch: 18/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.11it/s]


Training loss: 0.17428985781916376
Validation loss: 0.4624053645745984
Span F1-score: 0.7574717769712149
Epoch: 19/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.17359871213931008
Validation loss: 0.4691816866043068
Span F1-score: 0.7583784986060624
Epoch: 20/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.17147702410594995
Validation loss: 0.48803345907390827
Span F1-score: 0.7620477082033147
Epoch: 21/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.17521964248116
Validation loss: 0.4889148705683294
Span F1-score: 0.7585558517319
Epoch: 22/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.10it/s]


Training loss: 0.17415455164422
Validation loss: 0.4331118837353729
Span F1-score: 0.7550700597380295
Epoch: 23/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.1737623666327686
Validation loss: 0.47601125219038554
Span F1-score: 0.7612648988852719
Epoch: 24/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.07it/s]


Training loss: 0.1737760501409616
Validation loss: 0.45227204872561355
Span F1-score: 0.7625343329097924
Epoch: 25/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.05it/s]


Training loss: 0.16910900207612356
Validation loss: 0.5104074934231383
Span F1-score: 0.7575403175480802
Epoch: 26/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.16941526219453834
Validation loss: 0.47034455497882194
Span F1-score: 0.7633172842638045
Epoch: 27/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.03it/s]


Training loss: 0.16765326519129337
Validation loss: 0.5089757519641093
Span F1-score: 0.7601435960339364
Epoch: 28/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.1659241335321252
Validation loss: 0.5258283574666296
Span F1-score: 0.7620340808272729
Epoch: 29/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.02it/s]


Training loss: 0.16676855252184178
Validation loss: 0.50361610556997
Span F1-score: 0.7629484734139476
Epoch: 30/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.16583681737559788
Validation loss: 0.5195321539150817
Span F1-score: 0.7618512773639146
Epoch: 31/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.11it/s]


Training loss: 0.16528536978288802
Validation loss: 0.506963334259178
Span F1-score: 0.7642134630154285
Epoch: 32/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.05it/s]


Training loss: 0.1648246278610388
Validation loss: 0.5570344380901329
Span F1-score: 0.7602204452389059
Epoch: 33/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.03it/s]


Training loss: 0.16494091429722008
Validation loss: 0.529118036132838
Span F1-score: 0.7632152874428038
Epoch: 34/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.06it/s]


Training loss: 0.16432633517532175
Validation loss: 0.5384750751689786
Span F1-score: 0.7639966564007243
Epoch: 35/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.06it/s]


Training loss: 0.1706495169216955
Validation loss: 0.5034693677909672
Span F1-score: 0.7637474031748916
Epoch: 36/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.06it/s]


Training loss: 0.1737223178481495
Validation loss: 0.4670281908075724
Span F1-score: 0.7563291154219492
Epoch: 37/100


100%|██████████| 436/436 [02:27<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.07it/s]


Training loss: 0.17977792958071062
Validation loss: 0.42471808350334567
Span F1-score: 0.7624261088834691
Epoch: 38/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.16989785228608126
Validation loss: 0.49220925510994024
Span F1-score: 0.7619372819033227
Epoch: 39/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.11it/s]


Training loss: 0.16629433309778982
Validation loss: 0.5317576361997497
Span F1-score: 0.7641557923672755
Epoch: 40/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.08it/s]


Training loss: 0.16407134779942556
Validation loss: 0.5291818395390042
Span F1-score: 0.7619491696027687
Epoch: 41/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.16368929112603495
Validation loss: 0.5499034897202537
Span F1-score: 0.7629296218487396
Epoch: 42/100


100%|██████████| 436/436 [02:28<00:00,  2.95it/s]
100%|██████████| 420/420 [00:46<00:00,  9.09it/s]


Training loss: 0.16291314753951555
Validation loss: 0.5166276289416211
Span F1-score: 0.76654363610631
Epoch: 43/100


100%|██████████| 436/436 [02:28<00:00,  2.94it/s]
 69%|██████▉   | 291/420 [00:32<00:14,  9.02it/s]

# Load and test

In [None]:
# def test(model, test_dataloader, device):
#     model.eval()
#     span_preds = []
#     span_targets = []
#     for texts, spans in tqdm(test_dataloader):
#         input_ids = texts['input_ids'].squeeze(1).to(device)
#         attention_mask = texts['attention_mask'].to(device)
#         spans = spans.float().to(device)
#         with torch.no_grad():
#             span_logits = model(input_ids, attention_mask)

#         span_preds.append(span_logits.squeeze().cpu().numpy().flatten())
#         span_targets.append(spans.cpu().numpy().flatten())

#     span_preds = np.concatenate(span_preds)
#     span_targets = np.concatenate(span_targets)
#     span_preds = (span_preds > 0.5).astype(int)
#     span_f1 = f1_score(span_targets, span_preds, average='macro')

#     print("Span F1 Score: {:.4f}".format(span_f1))

def test(model, test_dataloader, device):
    model.eval()
    span_preds = []
    span_targets = []

    for batch in tqdm(test_dataloader):
        # Print keys to understand the structure of batch
        print("Batch keys:", batch.keys())

        # Adjust the keys based on what you find
        # For example, if the keys are 'input_ids' and 'labels', update as follows:
        texts = batch['input_ids']  # Adjust this key according to your dataset
        spans = batch['labels']  # Adjust this key according to your dataset

        input_ids = texts.squeeze(1).to(device)
        attention_mask = batch['attention_mask'].to(device)
        spans = spans.float().to(device)

        with torch.no_grad():
            span_logits = model(input_ids, attention_mask)

        span_preds.append(span_logits.squeeze().cpu().numpy().flatten())
        span_targets.append(spans.cpu().numpy().flatten())

    span_preds = np.concatenate(span_preds)
    span_targets = np.concatenate(span_targets)
    span_preds = (span_preds > 0.5).astype(int)
    span_f1 = f1_score(span_targets, span_preds, average='macro')

    print("Span F1 Score: {:.4f}".format(span_f1))


In [None]:
# Save the model
model_path = "/content/drive/MyDrive/ViHOS/data/ViHos_50epoch.pth"
torch.save(model.state_dict(), model_path)


In [None]:
# Load the model
device = torch.device("cpu")

model = MultiTaskModel(input_model=input_model)  # Reinitialize your model architecture
model.load_state_dict(torch.load("/content/drive/MyDrive/ViHOS/data/ViHos.pth", map_location=torch.device('cpu')))
model.to(device)  # Move the model to the appropriate device (GPU or CPU)
model.eval()  # Set the model to evaluation mode


MultiTaskModel(
  (bert): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
            

In [59]:
from sklearn.metrics import f1_score, recall_score, accuracy_score
import numpy as np
import torch
from tqdm import tqdm

def test(model, test_dataloader, device):
    model.eval()  # Set the model to evaluation mode
    span_preds = []
    span_targets = []

    for batch in tqdm(test_dataloader):
        # Extract components from the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        spans = batch['spans'].float().to(device)  # Ensure spans are in float for loss computation

        with torch.no_grad():  # Disable gradient computation for inference
            span_logits = model(input_ids, attention_mask)  # Get model predictions

        # Append predictions and targets for later evaluation
        span_preds.append(span_logits.squeeze().cpu().numpy().flatten())
        span_targets.append(spans.cpu().numpy().flatten())

    # Concatenate all predictions and targets
    span_preds = np.concatenate(span_preds)
    span_targets = np.concatenate(span_targets)

    # Binarize predictions based on threshold
    span_preds = (span_preds > 0.5).astype(int)

    # Calculate metrics
    span_f1 = f1_score(span_targets, span_preds, average='macro')
    span_recall = recall_score(span_targets, span_preds, average='macro')
    span_accuracy = accuracy_score(span_targets, span_preds)

    print("Span F1 Score: {:.4f}".format(span_f1))
    print("Span Recall: {:.4f}".format(span_recall))
    print("Span Accuracy: {:.4f}".format(span_accuracy))


In [60]:
def create_subset(dataloader, subset_size=100):
    subset_texts = []
    subset_spans = []

    for i, batch in enumerate(dataloader):
        if i * dataloader.batch_size >= subset_size:
            break

        # Extract data based on batch structure
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        spans = batch['spans']

        subset_texts.append({
            'input_ids': input_ids,
            'attention_mask': attention_mask
        })
        subset_spans.append(spans)

    # Convert lists to tensors if needed
    subset_texts = {
        'input_ids': torch.cat([x['input_ids'] for x in subset_texts], dim=0),
        'attention_mask': torch.cat([x['attention_mask'] for x in subset_texts], dim=0)
    }
    subset_spans = torch.cat(subset_spans, dim=0)

    return subset_texts, subset_spans

# Create a subset dataloader
def create_subset_dataloader(dataloader, subset_size=1000):
    subset_texts, subset_spans = create_subset(dataloader, subset_size)

    # Define a new dataset and dataloader
    class SubsetDataset(torch.utils.data.Dataset):
        def __init__(self, texts, spans):
            self.texts = texts
            self.spans = spans

        def __len__(self):
            return len(self.spans)

        def __getitem__(self, idx):
            return {
                'input_ids': self.texts['input_ids'][idx],
                'attention_mask': self.texts['attention_mask'][idx],
                'spans': self.spans[idx]
            }

    subset_dataset = SubsetDataset(subset_texts, subset_spans)
    subset_dataloader = torch.utils.data.DataLoader(
        subset_dataset, batch_size=dataloader.batch_size, shuffle=False, num_workers=4
    )
    return subset_dataloader

# Assuming test_dataloader is defined
subset_dataloader = create_subset_dataloader(test_dataloader, subset_size=1000)




In [None]:
# Example of using the model for inference
test(model, subset_dataloader, device)


 16%|█▌        | 5/32 [00:47<03:42,  8.23s/it]

In [None]:
# Example input sentence
sentence = "<điền đoạn text để sử dụng thử model>"

# Step 1: Tokenize the input sentence
encoding = tokenizer(sentence, truncation=True, padding='max_length', max_length=64, return_tensors='pt')

# Step 2: Prepare input tensors
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# Step 3: Pass the tensors through the model
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    output = model(input_ids, attention_mask)

# Step 4: Interpret the output
# Assuming the model predicts binary spans (0 or 1)
span_logits = output.squeeze().cpu().numpy()
span_predictions = (span_logits > 0.5).astype(int)  # Convert probabilities to binary predictions

# Print the results
print("Input Sentence:", sentence)
print("Predicted Spans:", span_predictions)
