In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score,precision_score, recall_score
from tqdm import tqdm
import math
import random

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

class SST5Dataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len):
        self.data = pd.read_csv(file_path)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence = self.data.iloc[idx]['sentence']
        label = self.data.iloc[idx]['label']

       
        encoding = self.tokenizer(
            sentence,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long)
        }

class HuggingFaceEmbedding(nn.Module):
    def __init__(self, model_name="bert-base-uncased", freeze=True):
        super(HuggingFaceEmbedding, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.embedding_model = BertModel.from_pretrained(model_name)
        if freeze:
            for param in self.embedding_model.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.embedding_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state  

class BiLSTMModel(nn.Module):
    def __init__(self, transformer_model_name, hidden_dim, output_dim, n_layers, dropout):
        super(BiLSTMModel, self).__init__()
        self.embedding = HuggingFaceEmbedding(model_name=transformer_model_name, freeze=True)
        embedding_dim = self.embedding.embedding_model.config.hidden_size
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                            bidirectional=True, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask):
        embedded = self.embedding(input_ids, attention_mask)
        lstm_out, _ = self.lstm(self.dropout(embedded))
        
        forward_last_hidden = lstm_out[:, -1, :self.lstm.hidden_size]
        backward_last_hidden = lstm_out[:, 0, self.lstm.hidden_size:]
        
        last_hidden = torch.cat((forward_last_hidden, backward_last_hidden), dim=1)

        return self.fc(self.dropout(last_hidden))

def create_data_loader(file_path, tokenizer, max_len, batch_size, shuffle):
    dataset = SST5Dataset(file_path, tokenizer, max_len)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def train_model(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)

def evaluate_model(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            predictions.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')
    return total_loss / len(data_loader), accuracy, f1

def predict_probabilities(model, data_loader, device):
    model.eval()
    sentence_probs = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            outputs = model(input_ids, attention_mask)
            sentence_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())  
    return np.array(sentence_probs)

def calculate_uncertainty_weights(predictions):
    mean_prediction = np.mean(predictions, axis=0)
    variance = np.var(predictions, axis=0) + 1e-10
    uncertainties = []

    for prediction in predictions:
        diff_square = (prediction - mean_prediction) ** 2
        temp_var = 0.5 * np.log(2 * math.pi * variance)
        temp_max = np.maximum(0, temp_var)
        final_uncertainty = temp_max + (diff_square / (math.sqrt(3) * variance))
        uncertainties.append(final_uncertainty)

    weights = 1 / (np.array(uncertainties) + 1e-10)
    return weights / weights.sum(axis=0)  


def get_grouped_probs(data_loader, model,device):
    sentence_probs = predict_probabilities(model, data_loader,device)
    return [sentence_probs[i:i+3] for i in range(0, len(sentence_probs), 3)]


def compute_weighted_predictions(grouped_probs, weight_type="average"):
    weighted_probs = []
    for group in grouped_probs:
        if weight_type == "average":
            weights = np.array([1/3, 1/3, 1/3])
        elif weight_type == "uncertainty":
            weights = calculate_uncertainty_weights(group)
        else:
            raise ValueError(f"Unknown weight type: {weight_type}")
        weighted_probs.append(np.average(group, axis=0, weights=weights))
    return np.array(weighted_probs)

if __name__ == "__main__":
    
    transformer_model_name = 'bert-base-uncased'
    hidden_dim = 256
    output_dim = 5
    n_layers = 2
    dropout = 0.3
    max_len = 128
    batch_size = 32
    learning_rate = 1e-4
    n_epochs = 20
    early_stopping_patience = 3

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
    train_data = pd.read_csv('train.csv')
    sentences = train_data['sentence'].values
    tokenizer = BertTokenizer.from_pretrained(transformer_model_name)
    
    train_loader = create_data_loader('train.csv', tokenizer, max_len, batch_size,shuffle = True)
    val_loader = create_data_loader('validation.csv', tokenizer, max_len, batch_size,shuffle = False)
    test_data_loader = create_data_loader('test.csv', tokenizer, max_len, batch_size,shuffle = False)
    T_TTA_test_data_loader = create_data_loader('T_TTA_enhanced_test.csv', tokenizer, max_len, batch_size,shuffle = False)
    R_TTA_test_data_loader = create_data_loader('R_TTA_enhanced_test.csv', tokenizer, max_len, batch_size,shuffle = False)
    MSTTA_test_data_loader = create_data_loader('MSTTA_enhanced_test.csv', tokenizer, max_len, batch_size,shuffle = False)
 
    model = BiLSTMModel(transformer_model_name, hidden_dim, output_dim, n_layers, dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(n_epochs):
        train_loss = train_model(model, train_loader, optimizer, criterion, device)
        val_loss, val_accuracy, val_f1 = evaluate_model(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'bilstm_Compare_model.pth') 
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered")
                break

    
    model.load_state_dict(torch.load('bilstm_Compare_model.pth'))
    test_loss, test_accuracy, test_f1 = evaluate_model(model, test_data_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Test F1: {test_f1:.4f}")
    
    test_df = pd.read_csv('test.csv')
    true_labels = test_df['label'].values

    test_simple_probs = predict_probabilities(model, test_data_loader,device)
    T_TTA_grouped_probs = get_grouped_probs(T_TTA_test_data_loader, model,device)
    R_TTA_grouped_probs = get_grouped_probs(R_TTA_test_data_loader, model,device)
    MSTTA_grouped_probs = get_grouped_probs(MSTTA_test_data_loader, model,device)

    simple_preds = np.argmax(test_simple_probs, axis=1)
    simple_accuracy = accuracy_score(true_labels, simple_preds)
    simple_precision = precision_score(true_labels, simple_preds, average='weighted')
    simple_recall = recall_score(true_labels, simple_preds, average='weighted')
    simple_f1 = f1_score(true_labels, simple_preds, average='weighted')
    print(f"Simple Accuracy: {simple_accuracy:.4f}, Precision: {simple_precision:.4f}, Recall: {simple_recall:.4f}, F1: {simple_f1:.4f}")

    weight_types = ["average",  "uncertainty"]
    results = {}

    for weight_type in weight_types:
        T_TTA_weighted_probs = compute_weighted_predictions(T_TTA_grouped_probs, weight_type=weight_type)
        R_TTA_weighted_probs = compute_weighted_predictions(R_TTA_grouped_probs, weight_type=weight_type)
        MSTTA_weighted_probs = compute_weighted_predictions(MSTTA_grouped_probs, weight_type=weight_type)

        T_TTA_preds = np.argmax(T_TTA_weighted_probs, axis=1)
        R_TTA_preds = np.argmax(R_TTA_weighted_probs, axis=1)
        MSTTA_preds = np.argmax(MSTTA_weighted_probs, axis=1)
     
        true_labels = test_df['label'].values
        results[f"T_TTA {weight_type}"] = {
            "accuracy": accuracy_score(true_labels, T_TTA_preds),
            "f1": f1_score(true_labels, T_TTA_preds, average='weighted'),
            "precision": precision_score(true_labels, T_TTA_preds, average='weighted'),
            "recall": recall_score(true_labels, T_TTA_preds, average='weighted')
        }
        results[f"R_TTA {weight_type}"] = {
            "accuracy": accuracy_score(true_labels, R_TTA_preds),
            "f1": f1_score(true_labels, R_TTA_preds, average='weighted'),
            "precision": precision_score(true_labels, R_TTA_preds, average='weighted'),
            "recall": recall_score(true_labels, R_TTA_preds, average='weighted')
        }
        results[f"MSTTA {weight_type}"] = {
            "accuracy": accuracy_score(true_labels, MSTTA_preds),
            "f1": f1_score(true_labels, MSTTA_preds, average='weighted'),
            "precision": precision_score(true_labels, MSTTA_preds, average='weighted'),
            "recall": recall_score(true_labels, MSTTA_preds, average='weighted')
        }

    for method, metrics in results.items():
        print(f"{method} - Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}")



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test Loss: 1.2066, Test Accuracy: 0.4778, Test F1: 0.4378
Simple Accuracy: 0.4778, Precision: 0.4907, Recall: 0.4778, F1: 0.4378
T_TTA average - Accuracy: 0.4783, Precision: 0.5025, Recall: 0.4783, F1: 0.4327
R_TTA average - Accuracy: 0.4751, Precision: 0.4941, Recall: 0.4751, F1: 0.4314
MSTTA average - Accuracy: 0.4837, Precision: 0.5058, Recall: 0.4837, F1: 0.4389
T_TTA uncertainty - Accuracy: 0.4796, Precision: 0.5022, Recall: 0.4796, F1: 0.4374
R_TTA uncertainty - Accuracy: 0.4760, Precision: 0.4973, Recall: 0.4760, F1: 0.4342
MSTTA uncertainty - Accuracy: 0.4833, Precision: 0.5038, Recall: 0.4833, F1: 0.4400
