In [1]:
import pandas as pd
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score,f1_score, precision_score, recall_score
from tqdm import tqdm
from scipy.stats import entropy
import math
import random

# Seed PyTorch, the stdlib RNG and NumPy (in that order) with the same value
# so repeated runs of the notebook draw the same random numbers.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(42)

# Train / validation / test splits.
train_df, val_df, test_df = (
    pd.read_csv(_csv_path)
    for _csv_path in ('train.csv', 'validation.csv', 'test.csv')
)

# Augmented variants of the test split used for test-time augmentation.
# NOTE(review): the evaluation code groups rows of these files in threes, so
# presumably each test sentence is expanded into 3 consecutive rows - confirm.
T_TTA_test_df, R_TTA_test_df, MSTTA_test_df = (
    pd.read_csv(_csv_path)
    for _csv_path in (
        'T_TTA_enhanced_test.csv',
        'R_TTA_enhanced_test.csv',
        'MSTTA_enhanced_test.csv',
    )
)

class SentimentDataset(Dataset):
    """Map-style dataset that tokenizes one sentence per item.

    Args:
        dataframe: table with a ``sentence`` column (text) and a ``label``
            column (integer class id).
        tokenizer: callable invoked as ``tokenizer(text, **kwargs)`` that
            returns a mapping with ``input_ids`` and ``attention_mask``
            tensors (presumably a Hugging Face tokenizer - verify against
            the caller).
        max_len: sequence length every sentence is padded/truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return the encoded sentence and its label for row ``index``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (the tokenizer
            output with its leading axis removed) and ``label`` (0-dim
            ``torch.long`` tensor).
        """
        # Fetch the row once instead of doing a positional lookup per column.
        row = self.dataframe.iloc[index]
        encoding = self.tokenizer(
            row['sentence'],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
        # squeeze(0) removes only the leading batch axis. A bare squeeze()
        # would also collapse the sequence axis when max_len == 1, handing the
        # DataLoader 0-dim tensors.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(row['label'], dtype=torch.long)
        }

def create_data_loader(df, tokenizer, max_len, batch_size, shuffle=False):
    """Wrap ``df`` in a ``SentimentDataset`` and return a ``DataLoader`` over it.

    Args:
        df: dataframe handed to ``SentimentDataset``.
        tokenizer: tokenizer handed to ``SentimentDataset``.
        max_len: padded/truncated sequence length.
        batch_size: number of items per batch.
        shuffle: whether the loader reshuffles the data each epoch
            (defaults to ``False``, the ``DataLoader`` default).

    Returns:
        A single-process ``DataLoader`` (``num_workers=0``).
    """
    ds = SentimentDataset(df, tokenizer, max_len)
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just makes PyTorch emit a warning.
        pin_memory=torch.cuda.is_available(),
    )

# Checkpoint name shared by the tokenizer and the classification model.
PRE_TRAINED_MODEL_NAME = 'albert-base-v2'
tokenizer = AlbertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
model = AlbertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME, num_labels=5)

# Print, for every named parameter, whether it will receive gradients.
for _param_name, _parameter in model.named_parameters():
    print(f"{_param_name}: {_parameter.requires_grad}")

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Hyper-parameters.
max_len, batch_size, epochs, learning_rate = 128, 32, 10, 1e-5


def _ordered_loader(frame):
    """Build a non-shuffling loader for an evaluation/test dataframe."""
    return create_data_loader(frame, tokenizer, max_len, batch_size, shuffle=False)


# Only the training loader shuffles; every evaluation loader keeps row order.
train_data_loader = create_data_loader(train_df, tokenizer, max_len, batch_size, shuffle=True)
val_data_loader = _ordered_loader(val_df)
test_data_loader = _ordered_loader(test_df)
T_TTA_test_data_loader = _ordered_loader(T_TTA_test_df)
R_TTA_test_data_loader = _ordered_loader(R_TTA_test_df)
MSTTA_test_data_loader = _ordered_loader(MSTTA_test_df)

optimizer = AdamW(model.parameters(), lr=learning_rate)

# The scheduler is stepped once per batch, so its horizon is batches * epochs;
# the first 10% of those steps are warm-up.
total_steps = epochs * len(train_data_loader)
_warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=_warmup_steps,
    num_training_steps=total_steps,
)

loss_fn = torch.nn.CrossEntropyLoss().to(device)

def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler):
    """Run one training pass over ``data_loader``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        data_loader: yields dict batches with ``input_ids``,
            ``attention_mask`` and ``label`` tensors.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar loss.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.

    Returns:
        ``(accuracy, mean_loss)``: ``accuracy`` is a 0-dim float64 tensor
        (correct predictions / ``len(data_loader.dataset)``), ``mean_loss`` is
        the per-batch loss averaged over ``len(data_loader)``. Both are 0 when
        the loader yields nothing.
    """
    model.train()
    total_loss = 0.0
    # A tensor accumulator (rather than the int 0) keeps `.double()` valid
    # even when the loader produces no batches.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)

    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs.logits, dim=1)
        loss = loss_fn(outputs.logits, labels)

        correct_predictions += torch.sum(preds == labels)
        total_loss += loss.item()

        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Guard both denominators so an empty dataset/loader yields zeros instead
    # of a ZeroDivisionError.
    num_examples = max(len(data_loader.dataset), 1)
    num_batches = max(len(data_loader), 1)
    return correct_predictions.double() / num_examples, total_loss / num_batches

def eval_model(model, data_loader, loss_fn, device):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        data_loader: yields dict batches with ``input_ids``,
            ``attention_mask`` and ``label`` tensors.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar loss.
        device: device the batch tensors are moved to.

    Returns:
        ``(accuracy, mean_loss)``: ``accuracy`` is a 0-dim float64 tensor
        (correct predictions / ``len(data_loader.dataset)``), ``mean_loss`` is
        the per-batch loss averaged over ``len(data_loader)``. Both are 0 when
        the loader yields nothing.
    """
    model.eval()
    total_loss = 0.0
    # A tensor accumulator (rather than the int 0) keeps `.double()` valid
    # even when the loader produces no batches.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs.logits, dim=1)
            loss = loss_fn(outputs.logits, labels)

            correct_predictions += torch.sum(preds == labels)
            total_loss += loss.item()

    # Guard both denominators so an empty dataset/loader yields zeros instead
    # of a ZeroDivisionError.
    num_examples = max(len(data_loader.dataset), 1)
    num_batches = max(len(data_loader), 1)
    return correct_predictions.double() / num_examples, total_loss / num_batches


# Early-stopping state: best validation accuracy seen so far, how many
# consecutive non-improving epochs are tolerated, and how many have occurred.
best_accuracy = 0
patience = 3
early_stop_counter = 0

for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')
    print('-' * 50)

    train_acc, train_loss = train_epoch(model, train_data_loader, loss_fn, optimizer, device, scheduler)
    print(f'Train loss {train_loss}, accuracy {train_acc}')

    val_acc, val_loss = eval_model(model, val_data_loader, loss_fn, device)
    print(f'Val loss {val_loss}, accuracy {val_acc}')

    if not (val_acc > best_accuracy):
        # No improvement: count it, and stop once patience is exhausted.
        early_stop_counter += 1
        if early_stop_counter < patience:
            continue
        print("Early stopping triggered")
        break

    # New best validation accuracy: checkpoint the weights and reset the counter.
    torch.save(model.state_dict(), 'albert_Compare_model.pth')
    best_accuracy = val_acc
    early_stop_counter = 0



Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForSequenceClassification: ['predictions.dense.bias', 'predictions.LayerNorm.bias', 'predictions.LayerNorm.weight', 'predictions.dense.weight', 'predictions.decoder.weight', 'predictions.decoder.bias', 'predictions.bias']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
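You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.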

albert.embeddings.word_embeddings.weight: True
albert.embeddings.position_embeddings.weight: True
albert.embeddings.token_type_embeddings.weight: True
albert.embeddings.LayerNorm.weight: True
albert.embeddings.LayerNorm.bias: True
albert.encoder.embedding_hidden_mapping_in.weight: True
albert.encoder.embedding_hidden_mapping_in.bias: True
albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.weight: True
albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.bias: True
albert.encoder.albert_layer_groups.0.albert_layers.0.attention.query.weight: True
albert.encoder.albert_layer_groups.0.albert_layers.0.attention.query.bias: True
albert.encoder.albert_layer_groups.0.albert_layers.0.attention.key.weight: True
albert.encoder.albert_layer_groups.0.albert_layers.0.attention.key.bias: True
albert.encoder.albert_layer_groups.0.albert_layers.0.attention.value.weight: True
albert.encoder.albert_layer_groups.0.albert_layers.0.attention.value.bias: True
alb

Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:32<00:00,  2.89it/s]


Train loss 1.4170479828052307, accuracy 0.37862827715355807


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.87it/s]


Val loss 1.2039932761873517, accuracy 0.4713896457765667
Epoch 2/10
--------------------------------------------------


Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:31<00:00,  2.92it/s]


Train loss 1.1172361592674969, accuracy 0.5070224719101124


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.86it/s]


Val loss 1.1543403540338788, accuracy 0.4786557674841053
Epoch 3/10
--------------------------------------------------


Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:31<00:00,  2.92it/s]


Train loss 0.971652243244514, accuracy 0.5773642322097379


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.83it/s]


Val loss 1.1019754460879734, accuracy 0.5131698455949137
Epoch 4/10
--------------------------------------------------


Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:32<00:00,  2.90it/s]


Train loss 0.817343610279569, accuracy 0.6645599250936329


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.78it/s]


Val loss 1.146733592237745, accuracy 0.5331516802906449
Epoch 5/10
--------------------------------------------------


Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:32<00:00,  2.88it/s]


Train loss 0.678650792953227, accuracy 0.7437968164794008


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.69it/s]


Val loss 1.2285424590110778, accuracy 0.49227974568574023
Epoch 6/10
--------------------------------------------------


Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:33<00:00,  2.85it/s]


Train loss 0.5367936474107178, accuracy 0.8195224719101124


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.72it/s]


Val loss 1.2654025963374547, accuracy 0.5149863760217983
Epoch 7/10
--------------------------------------------------


Training: 100%|██████████████████████████████████████████████████████████████████████| 267/267 [01:34<00:00,  2.84it/s]


Train loss 0.39620456841777774, accuracy 0.8854166666666666


Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 35/35 [00:04<00:00,  7.63it/s]

Val loss 1.4068447896412442, accuracy 0.49409627611262485
Early stopping triggered





In [2]:
def predict_probabilities(model, data_loader, device):
    """Return softmax class probabilities for every item in ``data_loader``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        data_loader: yields dict batches with ``input_ids`` and
            ``attention_mask`` tensors.
        device: device the batch tensors are moved to.

    Returns:
        NumPy array built from one probability row per item, in loader order.
    """
    model.eval()
    collected_rows = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Predicting"):
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            # Normalise over the last (class) axis, then move to host memory.
            batch_probs = torch.softmax(logits, dim=-1).cpu().numpy()
            for row in batch_probs:
                collected_rows.append(row)

    return np.array(collected_rows)

# Restore the best checkpoint written during training. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved on a GPU can
# still be loaded on a CPU-only machine.
model.load_state_dict(torch.load('albert_Compare_model.pth', map_location=device))
# Put the model in inference mode explicitly instead of relying on whatever
# mode the previous cell left it in.
model.eval()
true_labels = test_df['label'].values

# Baseline: one prediction per original (non-augmented) test sentence.
simple_preds = []

with torch.no_grad():
    for d in tqdm(test_data_loader, desc="Predicting Simple"):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs.logits, dim=1)
        simple_preds.extend(preds.cpu().numpy())

simple_accuracy = accuracy_score(true_labels, simple_preds)
print(f'Simple Accuracy: {simple_accuracy}')


def calculate_uncertainty_weights(predictions):
    """Turn a group of predictions into inverse-uncertainty weights.

    For each prediction the uncertainty is, element-wise,
    ``max(0, 0.5 * log(2 * pi * var)) + (p - mean)^2 / (sqrt(3) * var)``
    where ``mean`` and ``var`` are taken across the group (axis 0) and ``var``
    has 1e-10 added. Weights are the reciprocals of the uncertainties
    (plus 1e-10), normalised so they sum to 1 along axis 0.

    Args:
        predictions: sequence/array of predictions; axis 0 indexes the group
            members.

    Returns:
        Array with one weight entry per prediction entry.
    """
    group_mean = np.mean(predictions, axis=0)
    group_var = np.var(predictions, axis=0) + 1e-10

    # Both terms below depend only on the group variance, so compute them once.
    log_term = np.maximum(0, 0.5 * np.log(2 * math.pi * group_var))
    scaled_var = math.sqrt(3) * group_var

    per_member = np.array([
        log_term + ((member - group_mean) ** 2 / scaled_var)
        for member in predictions
    ])

    inverse = 1 / (per_member + 1e-10)
    return inverse / inverse.sum(axis=0)


def get_grouped_probs(data_loader, model, device, group_size=3):
    """Predict probabilities and split them into consecutive groups.

    Args:
        data_loader: loader passed to ``predict_probabilities``.
        model: model passed to ``predict_probabilities``.
        device: device passed to ``predict_probabilities``.
        group_size: number of consecutive rows per group (default 3, the
            value that was previously hard-coded).

    Returns:
        List of array slices, each holding ``group_size`` consecutive
        probability rows. If the row count is not a multiple of
        ``group_size`` the last slice is shorter.

    Raises:
        ValueError: if ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")
    sentence_probs = predict_probabilities(model, data_loader, device)
    return [
        sentence_probs[start:start + group_size]
        for start in range(0, len(sentence_probs), group_size)
    ]


def compute_weighted_predictions(grouped_probs, weight_type="average"):
    """Collapse each group of predictions into one weighted prediction.

    Args:
        grouped_probs: iterable of groups; each group is an array whose axis 0
            indexes the predictions to combine.
        weight_type: ``"average"`` for equal weights, or ``"uncertainty"`` for
            the weights returned by ``calculate_uncertainty_weights``.

    Returns:
        Array with one combined prediction per group.

    Raises:
        ValueError: if ``weight_type`` is not one of the supported values.
    """
    # Validate once up front so a bad value is reported even when there are
    # no groups to process.
    if weight_type not in ("average", "uncertainty"):
        raise ValueError(f"Unknown weight type: {weight_type}")

    weighted_probs = []
    for group in grouped_probs:
        if weight_type == "average":
            # Derive the equal weights from the actual group length; a fixed
            # three-element vector makes np.average fail on any other size.
            members = len(group)
            weights = np.array([1 / members] * members)
        else:
            weights = calculate_uncertainty_weights(group)
        weighted_probs.append(np.average(group, axis=0, weights=weights))
    return np.array(weighted_probs)

def _weighted_metrics(y_true, y_pred):
    """Return accuracy plus support-weighted F1, precision and recall."""
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average='weighted'),
        "precision": precision_score(y_true, y_pred, average='weighted'),
        "recall": recall_score(y_true, y_pred, average='weighted')
    }


# Probabilities for the plain test set and for each augmented test set
# (the latter grouped into consecutive slices by get_grouped_probs).
test_simple_probs = predict_probabilities(model, test_data_loader,device)
T_TTA_grouped_probs = get_grouped_probs(T_TTA_test_data_loader, model,device)
R_TTA_grouped_probs = get_grouped_probs(R_TTA_test_data_loader, model,device)
MSTTA_grouped_probs = get_grouped_probs(MSTTA_test_data_loader, model,device)

# Baseline metrics without any augmentation.
simple_preds = test_simple_probs.argmax(axis=1)
simple_accuracy = accuracy_score(true_labels, simple_preds)
simple_precision = precision_score(true_labels, simple_preds, average='weighted')
simple_recall = recall_score(true_labels, simple_preds, average='weighted')
simple_f1 = f1_score(true_labels, simple_preds, average='weighted')
print(f"Simple Accuracy: {simple_accuracy:.4f}, Precision: {simple_precision:.4f}, Recall: {simple_recall:.4f}, F1: {simple_f1:.4f}")

weight_types = ["average",  "uncertainty"]
results = {}

# For every weighting scheme, fuse each group into one prediction and score it
# against the labels of the original test set.
for weight_type in weight_types:
    T_TTA_weighted_probs = compute_weighted_predictions(T_TTA_grouped_probs, weight_type=weight_type)
    R_TTA_weighted_probs = compute_weighted_predictions(R_TTA_grouped_probs, weight_type=weight_type)
    MSTTA_weighted_probs = compute_weighted_predictions(MSTTA_grouped_probs, weight_type=weight_type)

    T_TTA_preds = T_TTA_weighted_probs.argmax(axis=1)
    R_TTA_preds = R_TTA_weighted_probs.argmax(axis=1)
    MSTTA_preds = MSTTA_weighted_probs.argmax(axis=1)

    true_labels = test_df['label'].values
    results[f"T_TTA {weight_type}"] = _weighted_metrics(true_labels, T_TTA_preds)
    results[f"R_TTA {weight_type}"] = _weighted_metrics(true_labels, R_TTA_preds)
    results[f"MSTTA {weight_type}"] = _weighted_metrics(true_labels, MSTTA_preds)

for method, metrics in results.items():
    print(f"{method} - Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}")


Predicting Simple: 100%|███████████████████████████████████████████████████████████████| 70/70 [00:09<00:00,  7.55it/s]


Simple Accuracy: 0.539366515837104


Predicting: 100%|██████████████████████████████████████████████████████████████████████| 70/70 [00:09<00:00,  7.71it/s]
Predicting: 100%|████████████████████████████████████████████████████████████████████| 208/208 [00:27<00:00,  7.64it/s]
Predicting: 100%|████████████████████████████████████████████████████████████████████| 208/208 [00:27<00:00,  7.66it/s]
Predicting: 100%|████████████████████████████████████████████████████████████████████| 208/208 [00:27<00:00,  7.67it/s]


Simple Accuracy: 0.5394, Precision: 0.5498, Recall: 0.5394, F1: 0.5350
T_TTA average - Accuracy: 0.5367, Precision: 0.5528, Recall: 0.5367, F1: 0.5311
R_TTA average - Accuracy: 0.5362, Precision: 0.5547, Recall: 0.5362, F1: 0.5301
MSTTA average - Accuracy: 0.5425, Precision: 0.5599, Recall: 0.5425, F1: 0.5361
T_TTA uncertainty - Accuracy: 0.5357, Precision: 0.5500, Recall: 0.5357, F1: 0.5306
R_TTA uncertainty - Accuracy: 0.5398, Precision: 0.5553, Recall: 0.5398, F1: 0.5347
MSTTA uncertainty - Accuracy: 0.5443, Precision: 0.5580, Recall: 0.5443, F1: 0.5394
