# COVID‑19 Misinformation Detection – BERT Variants

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import BertForSequenceClassification, BertTokenizer, BertConfig, BertModel
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, f1_score
import time
import datetime
import re
import random
from string import punctuation
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader
import gc

# Set device
torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    torch.cuda.manual_seed(0)

print(f"Using GPU: {use_cuda}")

# Load and preprocess data
train = pd.read_csv('Constraint_Train.csv')
val = pd.read_csv('Constraint_Val.csv')
train["label"] = train["label"].map({"real": 1, "fake": 0})
val["label"] = val["label"].map({"real": 1, "fake": 0})

# Combine datasets
data = pd.concat([train, val], axis=0, ignore_index=True).drop(["id"], axis=1)

# Initialize tokenizer
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

tweets = data.tweet.values
labels = data.label.values

# Text preprocessing function
def preprocess(data):
    # Remove URL and hashtag
    for i in range(data.shape[0]):
        text = data[i].lower()
        text1 = ''.join([word+" " for word in text.split()])
        data[i] = text1

    # Regular expressions for cleaning
    giant_url_regex = ('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|'
        '[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
    mention_regex = '@[\w\-]+'
    hashtag_regex = '#[\w\-]+'
    space_pattern = '\s+'

    for i in range(data.shape[0]):
        text_string = data[i]
        parsed_text = re.sub(hashtag_regex, '', text_string)
        parsed_text = re.sub(giant_url_regex, '', parsed_text)
        parsed_text = re.sub(mention_regex, '', parsed_text)
        # Remove punctuation
        parsed_text = re.sub(r"[{}]+".format(punctuation), '', parsed_text)
        parsed_text = re.sub(space_pattern, ' ', parsed_text)
        data[i] = parsed_text
    return data

tweets = preprocess(tweets)
print("Sample processed tweet:", tweets[0])

# Check the length distribution of tweets
max_len = 0
ind = [100, 200, 300, 400, 500, 512]
for i in ind:
    count = 0
    for tweet in tweets:
        max_len = max(max_len, len(tweet))
        if len(tweet) > i:
            count += 1
    print(f"Count of sentence length over {i} is: {count}")
print('Max sentence length: ', max_len)

# Tokenize and encode data
input_ids = []
attention_masks = []
for tweet in tweets:
    encoded_dict = tokenizer.encode_plus(
                        tweet,
                        add_special_tokens = True,
                        max_length = 512,
                        truncation = True,
                        padding = 'max_length',
                        return_attention_mask = True,
                        return_tensors = 'pt',
                   )

    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

# Convert lists to tensors
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(labels)

# Print sample
print('Original: ', tweets[0])
print('Token IDs:', input_ids[0])

# Create dataset and split to train/validation
dataset = TensorDataset(input_ids, attention_masks, labels)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

print(f'{train_size:>5,} training samples')
print(f'{val_size:>5,} validation samples')

# Create DataLoaders
batch_size = 16
train_dataloader = DataLoader(
            train_dataset,
            shuffle = True,
            batch_size = batch_size
        )

validation_dataloader = DataLoader(
            val_dataset,
            shuffle = False,
            batch_size = batch_size
        )

# Utility function for time formatting
def format_time(elapsed):
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

# BERT Model Training Function (Unfrozen)
def train_bert_unfrozen():
    print("Training BERT unfrozen model...")
    # Initialize model
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels = 2,
        output_attentions = False,
        output_hidden_states = True,
    )
    model.to(device)

    # Set parameters
    optimizer = AdamW(model.parameters(),
                    lr = 5e-5,
                    eps = 1e-8
                    )
    epochs = 4

    # Set seed
    seed_val = 42
    random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    # Training loop
    training_stats = []
    total_t0 = time.time()
    best_accuracy = 0
    best_model_state = None

    for epoch_i in range(0, epochs):
        # Training
        print("")
        print(f'Epoch {epoch_i + 1} / {epochs}')
        print('Training...')

        t0 = time.time()
        total_train_loss = 0
        total_train_accuracy = 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            labels = batch[2].to(device)

            model.zero_grad()

            # Forward pass
            outputs = model(input_ids,
                           token_type_ids=None,
                           attention_mask=input_mask,
                           labels=labels)

            loss = outputs.loss
            logits = outputs.logits

            total_train_loss += loss.item()

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Calculate training accuracy
            pred = torch.argmax(logits, dim=1)
            total_train_accuracy += torch.sum(pred == labels).item()

        # Calculate average metrics
        avg_train_accuracy = total_train_accuracy / len(train_dataloader.dataset)
        avg_train_loss = total_train_loss / len(train_dataloader.dataset)
        print(f"  Accuracy: {avg_train_accuracy}")
        print(f"  Training loss: {avg_train_loss}")

        # Validation
        print("")
        print("Validation...")
        model.eval()
        total_eval_accuracy = 0
        total_eval_loss = 0
        y_true = []
        y_pred = []

        # No gradient calculation needed for validation
        with torch.no_grad():
            for batch in validation_dataloader:
                input_ids = batch[0].to(device)
                input_mask = batch[1].to(device)
                labels = batch[2].to(device)

                # Forward pass
                outputs = model(input_ids,
                               token_type_ids=None,
                               attention_mask=input_mask,
                               labels=labels)

                loss = outputs.loss
                logits = outputs.logits

                total_eval_loss += loss.item()

                # Calculate validation accuracy
                pred = torch.argmax(logits, dim=1)
                total_eval_accuracy += torch.sum(pred == labels).item()

                # Store true labels and predictions for metrics
                y_true.append(labels.cpu())
                y_pred.append(pred.cpu())

        # Calculate average metrics
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader.dataset)
        print(f"  Accuracy: {avg_val_accuracy}")
        avg_val_loss = total_eval_loss / len(validation_dataloader.dataset)
        print(f"  Validation loss: {avg_val_loss}")

        training_time = format_time(time.time() - t0)
        print()

        # Concatenate batches
        y_true = torch.cat(y_true).numpy()
        y_pred = torch.cat(y_pred).numpy()

        # Calculate metrics
        print(f"This epoch took: {training_time}")
        roc_auc = roc_auc_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        print(f'ROC-AUC score: {roc_auc}')
        print(f'F1 score: {f1}')
        print()

        # Store stats
        training_stats.append(
            {
                'epoch': epoch_i + 1,
                'Train Accur.': avg_train_accuracy,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Valid. Accur.': avg_val_accuracy,
                'ROC-AUC': roc_auc,
                'F1': f1,
                'Training Time': training_time,
            }
        )

        # Save best model
        if avg_val_accuracy > best_accuracy:
            best_accuracy = avg_val_accuracy
            best_model_state = model.state_dict().copy()

    print()
    print("="*10)
    print("Summary")
    print(f"Total time {format_time(time.time()-total_t0)}")
    print(f"Best validation accuracy: {best_accuracy}")

    # Create a new model instance and load the best saved weights
    final_model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels = 2,
        output_attentions = False,
        output_hidden_states = True,
    )
    final_model.load_state_dict(best_model_state)
    final_model.to(device)

    # Save the model state instead of the entire model
    torch.save(final_model.state_dict(), "bert_unfrozen.pt")

    return final_model, training_stats

# Custom BERT model with frozen encoder
class BertFrozenEncoder(nn.Module):
    def __init__(self):
        super(BertFrozenEncoder, self).__init__()
        # Load base BERT model
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # Freeze all BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        # Add classifier layer
        self.classifier = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        # Get BERT outputs
        outputs = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids)

        # Get CLS token representation (for classification)
        pooled_output = outputs.pooler_output

        # Get logits
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        # Return outputs similar to BertForSequenceClassification
        return type('BertOutput', (), {
            'loss': loss,
            'logits': logits,
            'hidden_states': outputs.hidden_states,
            'attentions': outputs.attentions
        })

# BERT Model Training Function (Frozen)
def train_bert_frozen():
    print("Training BERT frozen model...")
    # Initialize model
    model = BertFrozenEncoder()
    model.to(device)

    # Set parameters
    optimizer = AdamW(model.classifier.parameters(),  # Only train classifier
                    lr = 5e-5,
                    eps = 1e-8
                    )
    epochs = 4

    # Set seed
    seed_val = 42
    random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    # Training loop
    training_stats = []
    total_t0 = time.time()
    best_accuracy = 0
    best_model_state = None

    for epoch_i in range(0, epochs):
        # Training
        print("")
        print(f'Epoch {epoch_i + 1} / {epochs}')
        print('Training...')

        t0 = time.time()
        total_train_loss = 0
        total_train_accuracy = 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            labels = batch[2].to(device)

            model.zero_grad()

            # Forward pass
            outputs = model(input_ids,
                           token_type_ids=None,
                           attention_mask=input_mask,
                           labels=labels)

            loss = outputs.loss
            logits = outputs.logits

            total_train_loss += loss.item()

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.classifier.parameters(), 1.0)
            optimizer.step()

            # Calculate training accuracy
            pred = torch.argmax(logits, dim=1)
            total_train_accuracy += torch.sum(pred == labels).item()

        # Calculate average metrics
        avg_train_accuracy = total_train_accuracy / len(train_dataloader.dataset)
        avg_train_loss = total_train_loss / len(train_dataloader.dataset)
        print(f"  Accuracy: {avg_train_accuracy}")
        print(f"  Training loss: {avg_train_loss}")

        # Validation
        print("")
        print("Validation...")
        model.eval()
        total_eval_accuracy = 0
        total_eval_loss = 0
        y_true = []
        y_pred = []

        # No gradient calculation needed for validation
        with torch.no_grad():
            for batch in validation_dataloader:
                input_ids = batch[0].to(device)
                input_mask = batch[1].to(device)
                labels = batch[2].to(device)

                # Forward pass
                outputs = model(input_ids,
                               token_type_ids=None,
                               attention_mask=input_mask,
                               labels=labels)

                loss = outputs.loss
                logits = outputs.logits

                total_eval_loss += loss.item()

                # Calculate validation accuracy
                pred = torch.argmax(logits, dim=1)
                total_eval_accuracy += torch.sum(pred == labels).item()

                # Store true labels and predictions for metrics
                y_true.append(labels.cpu())
                y_pred.append(pred.cpu())

        # Calculate average metrics
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader.dataset)
        print(f"  Accuracy: {avg_val_accuracy}")
        avg_val_loss = total_eval_loss / len(validation_dataloader.dataset)
        print(f"  Validation loss: {avg_val_loss}")

        training_time = format_time(time.time() - t0)
        print()

        # Concatenate batches
        y_true = torch.cat(y_true).numpy()
        y_pred = torch.cat(y_pred).numpy()

        # Calculate metrics
        print(f"This epoch took: {training_time}")
        roc_auc = roc_auc_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        print(f'ROC-AUC score: {roc_auc}')
        print(f'F1 score: {f1}')
        print()

        # Store stats
        training_stats.append(
            {
                'epoch': epoch_i + 1,
                'Train Accur.': avg_train_accuracy,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Valid. Accur.': avg_val_accuracy,
                'ROC-AUC': roc_auc,
                'F1': f1,
                'Training Time': training_time,
            }
        )

        # Save best model
        if avg_val_accuracy > best_accuracy:
            best_accuracy = avg_val_accuracy
            best_model_state = model.state_dict().copy()

    print()
    print("="*10)
    print("Summary")
    print(f"Total time {format_time(time.time()-total_t0)}")
    print(f"Best validation accuracy: {best_accuracy}")

    # Create a new model instance and load the best saved weights
    final_model = BertFrozenEncoder()
    final_model.load_state_dict(best_model_state)
    final_model.to(device)

    # Save the model state instead of the entire model
    torch.save(final_model.state_dict(), "bert_frozen.pt")

    return final_model, training_stats

# Main execution flow
def main():
    # Train unfrozen BERT model
    print("Starting training for unfrozen BERT model...")
    unfrozen_model, unfrozen_stats = train_bert_unfrozen()

    # Free up GPU memory
    del unfrozen_model
    torch.cuda.empty_cache()
    gc.collect()

    # Train frozen BERT model
    print("Starting training for frozen BERT model...")
    frozen_model, frozen_stats = train_bert_frozen()

    # Compare performance metrics
    print("\nPerformance comparison:")
    print("\nUnfrozen BERT:")
    for epoch, stat in enumerate(unfrozen_stats):
        print(f"Epoch {epoch+1}: Accuracy = {stat['Valid. Accur.']:.4f}, F1 = {stat['F1']:.4f}")

    print("\nFrozen BERT:")
    for epoch, stat in enumerate(frozen_stats):
        print(f"Epoch {epoch+1}: Accuracy = {stat['Valid. Accur.']:.4f}, F1 = {stat['F1']:.4f}")

    # Plot training and validation loss for both models
    plt.figure(figsize=(12, 6))

    # Plot training loss
    plt.subplot(1, 2, 1)
    plt.plot([stat['Training Loss'] for stat in frozen_stats], label='Frozen BERT')
    plt.plot([stat['Training Loss'] for stat in unfrozen_stats], label='Unfrozen BERT')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot validation loss
    plt.subplot(1, 2, 2)
    plt.plot([stat['Valid. Loss'] for stat in frozen_stats], label='Frozen BERT')
    plt.plot([stat['Valid. Loss'] for stat in unfrozen_stats], label='Unfrozen BERT')
    plt.title('Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_comparison.png')
    plt.close()

    # Plot training and validation accuracy for both models
    plt.figure(figsize=(12, 6))

    # Plot training accuracy
    plt.subplot(1, 2, 1)
    plt.plot([stat['Train Accur.'] for stat in frozen_stats], label='Frozen BERT')
    plt.plot([stat['Train Accur.'] for stat in unfrozen_stats], label='Unfrozen BERT')
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    # Plot validation accuracy
    plt.subplot(1, 2, 2)
    plt.plot([stat['Valid. Accur.'] for stat in frozen_stats], label='Frozen BERT')
    plt.plot([stat['Valid. Accur.'] for stat in unfrozen_stats], label='Unfrozen BERT')
    plt.title('Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('accuracy_comparison.png')
    plt.close()

    print("\nTraining complete. Model states saved to 'bert_frozen.pt' and 'bert_unfrozen.pt'")
    print("Performance comparison plots saved to 'loss_comparison.png' and 'accuracy_comparison.png'")

if __name__ == "__main__":
    # Install required packages if running on Colab
    try:
        import google.colab
        print("Running on Colab, installing required packages...")
        !pip install -q transformers matplotlib
    except:
        print("Not running on Colab")

    # Run the main function
    main()

## Evaluation

In [18]:
# Import necessary libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import BertForSequenceClassification, BertTokenizer, BertConfig, BertModel
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, classification_report
import time
import datetime
import re
import random
import seaborn as sns
from string import punctuation
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader
from collections import Counter
from wordcloud import WordCloud, STOPWORDS
import gc

# Set device
torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    torch.cuda.manual_seed(0)

print(f"Using GPU: {use_cuda}")

# Load and preprocess data
train = pd.read_csv('Constraint_Train.csv')
val = pd.read_csv('Constraint_Val.csv')
train["label"] = train["label"].map({"real": 1, "fake": 0})
val["label"] = val["label"].map({"real": 1, "fake": 0})

# Combine datasets
data = pd.concat([train, val], axis=0, ignore_index=True).drop(["id"], axis=1)

# Initialize tokenizer
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

tweets = data.tweet.values
labels = data.label.values

# Text preprocessing function
def preprocess(data):
    # Remove URL and hashtag
    for i in range(data.shape[0]):
        text = data[i].lower()
        text1 = ''.join([word+" " for word in text.split()])
        data[i] = text1

    # Regular expressions for cleaning
    giant_url_regex = ('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|'
        '[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
    mention_regex = '@[\w\-]+'
    hashtag_regex = '#[\w\-]+'
    space_pattern = '\s+'

    for i in range(data.shape[0]):
        text_string = data[i]
        parsed_text = re.sub(hashtag_regex, '', text_string)
        parsed_text = re.sub(giant_url_regex, '', parsed_text)
        parsed_text = re.sub(mention_regex, '', parsed_text)
        # Remove punctuation
        parsed_text = re.sub(r"[{}]+".format(punctuation), '', parsed_text)
        parsed_text = re.sub(space_pattern, ' ', parsed_text)
        data[i] = parsed_text
    return data

tweets = preprocess(tweets)
print("Sample processed tweet:", tweets[0])

# Tokenize and encode data
input_ids = []
attention_masks = []
for tweet in tweets:
    encoded_dict = tokenizer.encode_plus(
                        tweet,
                        add_special_tokens = True,
                        max_length = 512,
                        truncation = True,
                        padding = 'max_length',
                        return_attention_mask = True,
                        return_tensors = 'pt',
                   )

    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

# Convert lists to tensors
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(labels)

# Create dataset and split to train/validation
dataset = TensorDataset(input_ids, attention_masks, labels)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

print(f'{train_size:>5,} training samples')
print(f'{val_size:>5,} validation samples')

# Create DataLoaders for evaluation
batch_size = 16
validation_dataloader = DataLoader(
            val_dataset,
            shuffle = False,
            batch_size = batch_size
        )

# Custom BERT model with frozen encoder (needed for loading)
class BertFrozenEncoder(nn.Module):
    def __init__(self):
        super(BertFrozenEncoder, self).__init__()
        # Load base BERT model
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # Freeze all BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        # Add classifier layer
        self.classifier = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        # Get BERT outputs
        outputs = self.bert(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids)

        # Get CLS token representation (for classification)
        pooled_output = outputs.pooler_output

        # Get logits
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        # Return outputs similar to BertForSequenceClassification
        return type('BertOutput', (), {
            'loss': loss,
            'logits': logits,
            'hidden_states': outputs.hidden_states,
            'attentions': outputs.attentions
        })

# Extract important features
def get_feature_importance(model, validation_dataloader, model_type="unfrozen", num_samples=100):
    """Analyze feature importance based on attention weights"""
    print(f"\nExtracting feature importance for {model_type} model...")

    # Sample data
    sample_input_ids = []
    sample_masks = []
    sample_labels = []

    count = 0
    for batch in validation_dataloader:
        if count >= num_samples:
            break

        batch_size = batch[0].size(0)
        samples_to_take = min(batch_size, num_samples - count)

        sample_input_ids.append(batch[0][:samples_to_take])
        sample_masks.append(batch[1][:samples_to_take])
        sample_labels.append(batch[2][:samples_to_take])

        count += samples_to_take

    sample_input_ids = torch.cat(sample_input_ids, dim=0).to(device)
    sample_masks = torch.cat(sample_masks, dim=0).to(device)
    sample_labels = torch.cat(sample_labels, dim=0).to(device)

    model.eval()
    word_importance = {}

    with torch.no_grad():
        # Get model outputs with attention weights
        if model_type == "unfrozen":
            outputs = model(input_ids=sample_input_ids,
                          attention_mask=sample_masks,
                          labels=sample_labels,
                          output_attentions=True)
            attentions = outputs.attentions[-1]  # Last layer attention
        else:
            # For the custom model, attention may need to be accessed differently
            outputs = model.bert(input_ids=sample_input_ids,
                               attention_mask=sample_masks,
                               output_attentions=True)
            attentions = outputs.attentions[-1]

    # Average attention weights across heads and samples
    avg_attention = attentions.mean(dim=[0, 1])  # Shape: [seq_len, seq_len]

    # Get attention from CLS token (used for classification)
    cls_attention = avg_attention[0, :].cpu().numpy()

    # Map tokens to attention weights
    for i in range(min(100, len(cls_attention))):  # Look at top tokens
        if cls_attention[i] > 0.01:  # Only consider significant attention
            token_id = sample_input_ids[0][i].item()
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            if token not in word_importance:
                word_importance[token] = cls_attention[i]
            else:
                word_importance[token] = max(word_importance[token], cls_attention[i])

    # Sort by importance
    sorted_words = sorted(word_importance.items(), key=lambda x: x[1], reverse=True)
    top_words = sorted_words[:30]

    print(f"Top 30 important words for {model_type} model:")
    for token, importance in top_words:
        print(f"{token}: {importance:.4f}")

    # Create word cloud for visualization
    wordcloud = WordCloud(width=800, height=400, background_color="white",
                        colormap="viridis", max_words=100)
    wordcloud.generate_from_frequencies(word_importance)
    plt.figure(figsize=(10, 6))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title(f"Important Features - {model_type.capitalize()} BERT")
    plt.savefig(f"feature_importance_{model_type}.png")
    plt.close()

    return word_importance

# Error Analysis Function
def analyze_errors(model, dataloader, model_type="bert_unfrozen"):
    """Analyze model errors to understand strengths and weaknesses"""
    print(f"\nAnalyzing errors for {model_type} model...")
    model.eval()
    all_preds = []
    all_labels = []
    all_texts = []
    all_probs = []  # For confidence analysis

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            labels = batch[2].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=input_mask)

            # Get predictions
            if model_type == "bert_unfrozen":
                logits = outputs.logits
            else:  # For custom model
                logits = outputs.logits

            probs = torch.nn.functional.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)

            # Convert to lists
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

            # Get original texts
            for i in range(input_ids.size(0)):
                text = tokenizer.decode(input_ids[i], skip_special_tokens=True)
                all_texts.append(text)

    # Separate into correct and incorrect predictions
    correct = []
    incorrect = []
    for i in range(len(all_preds)):
        if all_preds[i] == all_labels[i]:
            correct.append((all_texts[i], all_labels[i], all_probs[i]))
        else:
            incorrect.append((all_texts[i], all_labels[i], all_preds[i], all_probs[i]))

    # Analysis of errors
    print(f"Total examples: {len(all_preds)}")
    print(f"Correct predictions: {len(correct)} ({len(correct)/len(all_preds)*100:.2f}%)")
    print(f"Incorrect predictions: {len(incorrect)} ({len(incorrect)/len(all_preds)*100:.2f}%)")

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Fake', 'Real'],
                yticklabels=['Fake', 'Real'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - {model_type}')
    plt.savefig(f'confusion_matrix_{model_type}.png')
    plt.close()

    # Detailed classification report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=['Fake', 'Real']))

    # Analyze error patterns
    real_as_fake = [item for item in incorrect if item[1] == 1 and item[2] == 0]
    fake_as_real = [item for item in incorrect if item[1] == 0 and item[2] == 1]

    # Length analysis
    real_as_fake_lengths = [len(item[0].split()) for item in real_as_fake]
    fake_as_real_lengths = [len(item[0].split()) for item in fake_as_real]

    print(f"\nReal news classified as fake - avg length: {np.mean(real_as_fake_lengths):.2f} words")
    print(f"Fake news classified as real - avg length: {np.mean(fake_as_real_lengths):.2f} words")

    # Word frequency in errors
    def get_top_words(texts, n=20):
        words = []
        for text in texts:
            words.extend(text.lower().split())
        return Counter(words).most_common(n)

    print("\nMost common words in real news classified as fake:")
    real_as_fake_texts = [item[0] for item in real_as_fake]
    print(get_top_words(real_as_fake_texts))

    print("\nMost common words in fake news classified as real:")
    fake_as_real_texts = [item[0] for item in fake_as_real]
    print(get_top_words(fake_as_real_texts))

    # Confidence analysis
    real_as_fake_conf = [item[3][0] for item in real_as_fake]  # Confidence for fake class
    fake_as_real_conf = [item[3][1] for item in fake_as_real]  # Confidence for real class

    print(f"\nAvg confidence when real news classified as fake: {np.mean(real_as_fake_conf):.4f}")
    print(f"Avg confidence when fake news classified as real: {np.mean(fake_as_real_conf):.4f}")

    # Generate word clouds for misclassified examples
    if real_as_fake_texts:
        wordcloud = WordCloud(width=800, height=400, background_color="white",
                            colormap="Reds", max_words=100, collocations=False)
        wordcloud.generate(" ".join(real_as_fake_texts))
        plt.figure(figsize=(10, 6))
        plt.imshow(wordcloud, interpolation="bilinear")
        plt.axis("off")
        plt.title(f"Words in Real News Classified as Fake - {model_type}")
        plt.savefig(f"real_as_fake_wordcloud_{model_type}.png")
        plt.close()

    if fake_as_real_texts:
        wordcloud = WordCloud(width=800, height=400, background_color="white",
                            colormap="Blues", max_words=100, collocations=False)
        wordcloud.generate(" ".join(fake_as_real_texts))
        plt.figure(figsize=(10, 6))
        plt.imshow(wordcloud, interpolation="bilinear")
        plt.axis("off")
        plt.title(f"Words in Fake News Classified as Real - {model_type}")
        plt.savefig(f"fake_as_real_wordcloud_{model_type}.png")
        plt.close()

    # Return error examples for further analysis
    return real_as_fake, fake_as_real, {
        'predictions': all_preds,
        'true_labels': all_labels,
        'texts': all_texts,
        'accuracy': len(correct) / len(all_preds),
        'f1': f1_score(all_labels, all_preds)
    }

# Analyze example confidence
def analyze_confidence(model, dataloader, model_type="unfrozen"):
    """Analyze prediction confidence patterns"""
    print(f"\nAnalyzing prediction confidence for {model_type} model...")
    model.eval()
    confidences = []
    labels = []
    correct_pred_conf = []
    incorrect_pred_conf = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            batch_labels = batch[2].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=input_mask)

            # Get predictions
            if model_type == "bert_unfrozen":
                logits = outputs.logits
            else:  # For custom model
                logits = outputs.logits

            probs = torch.nn.functional.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)

            # Get confidence scores (probability of predicted class)
            batch_confidences = torch.gather(probs, 1, preds.unsqueeze(1)).squeeze(1)

            # Separate confidences for correct and incorrect predictions
            correct_mask = (preds == batch_labels)

            correct_conf = batch_confidences[correct_mask].cpu().numpy()
            incorrect_conf = batch_confidences[~correct_mask].cpu().numpy()

            correct_pred_conf.extend(correct_conf)
            incorrect_pred_conf.extend(incorrect_conf)

            confidences.extend(batch_confidences.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    # Confidence statistics
    avg_confidence = np.mean(confidences)
    avg_correct_conf = np.mean(correct_pred_conf) if correct_pred_conf else 0
    avg_incorrect_conf = np.mean(incorrect_pred_conf) if incorrect_pred_conf else 0

    print(f"Average confidence: {avg_confidence:.4f}")
    print(f"Average confidence for correct predictions: {avg_correct_conf:.4f}")
    print(f"Average confidence for incorrect predictions: {avg_incorrect_conf:.4f}")

    # Plot confidence distributions
    plt.figure(figsize=(10, 6))
    plt.hist(correct_pred_conf, alpha=0.5, bins=20, label='Correct predictions')
    plt.hist(incorrect_pred_conf, alpha=0.5, bins=20, label='Incorrect predictions')
    plt.xlabel('Confidence')
    plt.ylabel('Count')
    plt.title(f'Confidence Distribution - {model_type.capitalize()} BERT')
    plt.legend()
    plt.savefig(f'confidence_dist_{model_type}.png')
    plt.close()

    return {
        'avg_confidence': avg_confidence,
        'avg_correct_conf': avg_correct_conf,
        'avg_incorrect_conf': avg_incorrect_conf,
        'correct_conf': correct_pred_conf,
        'incorrect_conf': incorrect_pred_conf
    }

# Compare frozen vs unfrozen models
def compare_models(frozen_results, unfrozen_results):
    """Compare performance between frozen and unfrozen BERT models"""
    # Extract predictions and labels
    frozen_preds = frozen_results['predictions']
    unfrozen_preds = unfrozen_results['predictions']
    true_labels = frozen_results['true_labels']  # Same for both

    # Calculate agreement statistics
    agreements = sum(1 for f, u in zip(frozen_preds, unfrozen_preds) if f == u)
    agreement_rate = agreements / len(frozen_preds)

    # Calculate correct agreements/disagreements
    correct_both = sum(1 for f, u, t in zip(frozen_preds, unfrozen_preds, true_labels)
                       if f == u and f == t)
    wrong_both = sum(1 for f, u, t in zip(frozen_preds, unfrozen_preds, true_labels)
                     if f == u and f != t)
    frozen_correct_only = sum(1 for f, u, t in zip(frozen_preds, unfrozen_preds, true_labels)
                             if f != u and f == t)
    unfrozen_correct_only = sum(1 for f, u, t in zip(frozen_preds, unfrozen_preds, true_labels)
                               if f != u and u == t)

    # Print comparison results
    print("\nModel Comparison - Frozen vs. Unfrozen BERT:")
    print(f"Agreement rate: {agreement_rate:.4f} ({agreements}/{len(frozen_preds)})")
    print(f"Both correct: {correct_both}/{len(frozen_preds)} ({correct_both/len(frozen_preds)*100:.2f}%)")
    print(f"Both wrong: {wrong_both}/{len(frozen_preds)} ({wrong_both/len(frozen_preds)*100:.2f}%)")
    print(f"Only frozen correct: {frozen_correct_only}/{len(frozen_preds)} ({frozen_correct_only/len(frozen_preds)*100:.2f}%)")
    print(f"Only unfrozen correct: {unfrozen_correct_only}/{len(frozen_preds)} ({unfrozen_correct_only/len(frozen_preds)*100:.2f}%)")

    # Accuracy comparison
    print(f"\nFrozen model accuracy: {frozen_results['accuracy']:.4f}")
    print(f"Unfrozen model accuracy: {unfrozen_results['accuracy']:.4f}")

    # F1 score comparison
    print(f"Frozen model F1 score: {frozen_results['f1']:.4f}")
    print(f"Unfrozen model F1 score: {unfrozen_results['f1']:.4f}")

    # Create comparison plot
    metrics = ['Accuracy', 'F1 Score']
    frozen_values = [frozen_results['accuracy'], frozen_results['f1']]
    unfrozen_values = [unfrozen_results['accuracy'], unfrozen_results['f1']]

    x = np.arange(len(metrics))
    width = 0.35

    fig, ax = plt.subplots(figsize=(10, 6))
    frozen_bars = ax.bar(x - width/2, frozen_values, width, label='Frozen BERT')
    unfrozen_bars = ax.bar(x + width/2, unfrozen_values, width, label='Unfrozen BERT')

    ax.set_ylabel('Score')
    ax.set_title('Performance Comparison: Frozen vs. Unfrozen BERT')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.legend()

    # Add values on top of bars
    def add_labels(bars):
        for bar in bars:
            height = bar.get_height()
            ax.annotate(f'{height:.4f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),  # 3 points vertical offset
                        textcoords="offset points",
                        ha='center', va='bottom')

    add_labels(frozen_bars)
    add_labels(unfrozen_bars)

    plt.tight_layout()
    plt.savefig('model_comparison.png')
    plt.close()

    # Analyze examples where models disagree
    disagreements = []
    for i in range(len(frozen_preds)):
        if frozen_preds[i] != unfrozen_preds[i]:
            disagreements.append({
                'text': frozen_results['texts'][i],
                'true_label': true_labels[i],
                'frozen_pred': frozen_preds[i],
                'unfrozen_pred': unfrozen_preds[i]
            })

    # Print some examples where models disagree
    print("\nExamples where models disagree:")
    for i, example in enumerate(disagreements[:5]):
        print(f"\nExample {i+1}:")
        print(f"Text: {example['text'][:100]}...")
        print(f"True label: {'Real' if example['true_label'] == 1 else 'Fake'}")
        print(f"Frozen BERT prediction: {'Real' if example['frozen_pred'] == 1 else 'Fake'}")
        print(f"Unfrozen BERT prediction: {'Real' if example['unfrozen_pred'] == 1 else 'Fake'}")
        print(f"Correct model: {'Both' if (example['frozen_pred'] == example['true_label'] and example['unfrozen_pred'] == example['true_label']) else 'Frozen' if example['frozen_pred'] == example['true_label'] else 'Unfrozen' if example['unfrozen_pred'] == example['true_label'] else 'Neither'}")

    # Save the disagreement examples
    with open('model_disagreements.txt', 'w') as f:
        f.write(f"Total disagreements: {len(disagreements)}\n\n")
        for i, example in enumerate(disagreements):
            f.write(f"Example {i+1}:\n")
            f.write(f"Text: {example['text']}\n")
            f.write(f"True label: {'Real' if example['true_label'] == 1 else 'Fake'}\n")
            f.write(f"Frozen BERT prediction: {'Real' if example['frozen_pred'] == 1 else 'Fake'}\n")
            f.write(f"Unfrozen BERT prediction: {'Real' if example['unfrozen_pred'] == 1 else 'Fake'}\n")
            f.write(f"Correct model: {'Both' if (example['frozen_pred'] == example['true_label'] and example['unfrozen_pred'] == example['true_label']) else 'Frozen' if example['frozen_pred'] == example['true_label'] else 'Unfrozen' if example['unfrozen_pred'] == example['true_label'] else 'Neither'}\n\n")

    return {
        'agreement_rate': agreement_rate,
        'correct_both': correct_both / len(frozen_preds),
        'wrong_both': wrong_both / len(frozen_preds),
        'frozen_correct_only': frozen_correct_only / len(frozen_preds),
        'unfrozen_correct_only': unfrozen_correct_only / len(frozen_preds),
        'disagreements': disagreements
    }

# Main function for loading and analyzing models
def analyze_trained_models():
    print("Loading trained models...")

    # Load unfrozen model
    unfrozen_model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels = 2,
    )
    try:
        # Try to load with state_dict
        unfrozen_model.load_state_dict(torch.load("bert_unfrozen.pt", map_location='cpu'))
        print("Loaded unfrozen model state_dict successfully")
    except Exception as e:
        print(f"Error loading unfrozen model state_dict: {e}")
        print("This may be because your model was saved in a different format.")
        try:
            # Try to load full model with weights_only=False
            unfrozen_model = torch.load("bert_unfrozen.pt", map_location='cpu', weights_only=False)
            print("Loaded unfrozen model successfully with weights_only=False")
        except Exception as e:
            print(f"Error loading unfrozen model: {e}")
            print("Using a newly initialized model instead - results will not be accurate!")

    unfrozen_model.to(device)
    unfrozen_model.eval()

    # Load frozen model
    frozen_model = BertFrozenEncoder()
    try:
        # Try to load with state_dict
        frozen_model.load_state_dict(torch.load("bert_frozen.pt", map_location='cpu'))
        print("Loaded frozen model state_dict successfully")
    except Exception as e:
        print(f"Error loading frozen model state_dict: {e}")
        print("This may be because your model was saved in a different format.")
        try:
            # Try to load full model with weights_only=False
            frozen_model = torch.load("bert_frozen.pt", map_location='cpu', weights_only=False)
            print("Loaded frozen model successfully with weights_only=False")
        except Exception as e:
            print(f"Error loading frozen model: {e}")
            print("Using a newly initialized model instead - results will not be accurate!")

    frozen_model.to(device)
    frozen_model.eval()

    # Analyze feature importance
    try:
        unfrozen_features = get_feature_importance(unfrozen_model, validation_dataloader, "unfrozen")
        frozen_features = get_feature_importance(frozen_model, validation_dataloader, "frozen")
    except Exception as e:
        print(f"Error during feature importance analysis: {e}")
        print("Skipping feature importance analysis.")

    # Analyze confidence patterns
    try:
        unfrozen_conf = analyze_confidence(unfrozen_model, validation_dataloader, "unfrozen")
        frozen_conf = analyze_confidence(frozen_model, validation_dataloader, "frozen")
    except Exception as e:
        print(f"Error during confidence analysis: {e}")
        print("Skipping confidence analysis.")

    # Analyze errors
    print("Analyzing unfrozen model errors...")
    _, _, unfrozen_results = analyze_errors(unfrozen_model, validation_dataloader, "bert_unfrozen")

    print("Analyzing frozen model errors...")
    _, _, frozen_results = analyze_errors(frozen_model, validation_dataloader, "bert_frozen")

    # Compare models
    print("Comparing models...")
    comparison_results = compare_models(frozen_results, unfrozen_results)

    # Combine all results into a summary report
    with open('analysis_summary.txt', 'w') as f:
        f.write("ANALYSIS SUMMARY\n")
        f.write("===============\n\n")

        f.write("MODEL PERFORMANCE\n")
        f.write("----------------\n")
        f.write(f"Frozen BERT accuracy: {frozen_results['accuracy']:.4f}\n")
        f.write(f"Unfrozen BERT accuracy: {unfrozen_results['accuracy']:.4f}\n")
        f.write(f"Frozen BERT F1 score: {frozen_results['f1']:.4f}\n")
        f.write(f"Unfrozen BERT F1 score: {unfrozen_results['f1']:.4f}\n\n")

        f.write("MODEL AGREEMENT ANALYSIS\n")
        f.write("-----------------------\n")
        f.write(f"Agreement rate: {comparison_results['agreement_rate']:.4f}\n")
        f.write(f"Both correct: {comparison_results['correct_both']*100:.2f}%\n")
        f.write(f"Both wrong: {comparison_results['wrong_both']*100:.2f}%\n")
        f.write(f"Only frozen correct: {comparison_results['frozen_correct_only']*100:.2f}%\n")
        f.write(f"Only unfrozen correct: {comparison_results['unfrozen_correct_only']*100:.2f}%\n\n")

        try:
            f.write("CONFIDENCE ANALYSIS\n")
            f.write("------------------\n")
            f.write(f"Frozen model average confidence: {frozen_conf['avg_confidence']:.4f}\n")
            f.write(f"Unfrozen model average confidence: {unfrozen_conf['avg_confidence']:.4f}\n")
            f.write(f"Frozen model confidence on correct predictions: {frozen_conf['avg_correct_conf']:.4f}\n")
            f.write(f"Unfrozen model confidence on correct predictions: {unfrozen_conf['avg_correct_conf']:.4f}\n")
            f.write(f"Frozen model confidence on incorrect predictions: {frozen_conf['avg_incorrect_conf']:.4f}\n")
            f.write(f"Unfrozen model confidence on incorrect predictions: {unfrozen_conf['avg_incorrect_conf']:.4f}\n\n")
        except:
            f.write("Confidence analysis not available.\n\n")

        f.write("CONCLUSIONS\n")
        f.write("-----------\n")
        if frozen_results['accuracy'] > unfrozen_results['accuracy']:
            f.write("The frozen BERT model performed better in terms of accuracy. This suggests that for this COVID-19 fake news detection task, the pre-trained BERT features were sufficient, and fine-tuning the entire model might have led to overfitting.\n\n")
        else:
            f.write("The unfrozen BERT model performed better in terms of accuracy. This suggests that for this COVID-19 fake news detection task, fine-tuning the entire model helped to adapt the features to the specific domain.\n\n")

        if comparison_results['frozen_correct_only'] > comparison_results['unfrozen_correct_only']:
            f.write("The frozen model correctly classified more examples that the unfrozen model missed. This suggests the frozen model might be more robust to certain types of examples.\n\n")
        else:
            f.write("The unfrozen model correctly classified more examples that the frozen model missed. This suggests the fine-tuning process helped capture patterns that the frozen model couldn't detect.\n\n")

    print("\nAnalysis complete. Results saved to files.")

    return unfrozen_model, frozen_model, unfrozen_results, frozen_results, comparison_results

# Run the analysis
if __name__ == "__main__":
    # Install required packages if running on Colab
    try:
        import google.colab
        print("Running on Colab, installing required packages...")
        !pip install -q wordcloud seaborn transformers
    except:
        print("Not running on Colab")

    # Run the analysis function
    unfrozen_model, frozen_model, unfrozen_results, frozen_results, comparison_results = analyze_trained_models()

Using GPU: True
Loading BERT tokenizer...
Sample processed tweet: the cdc currently reports 99031 deaths in general the discrepancies in death counts between different sources are small and explicable the death toll stands at roughly 100000 people today 
7,704 training samples
  856 validation samples
Running on Colab, installing required packages...
Loading trained models...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded unfrozen model state_dict successfully




Loaded frozen model state_dict successfully

Extracting feature importance for unfrozen model...
Top 30 important words for unfrozen model:
trump: 0.0529
##s: 0.0458
faint: 0.0457
[CLS]: 0.0422
is: 0.0409
infected: 0.0405
[PAD]: 0.0400
with: 0.0369
[SEP]: 0.0346
corona: 0.0332
##virus: 0.0318

Extracting feature importance for frozen model...
Top 30 important words for frozen model:
[CLS]: 0.0827
trump: 0.0454
[PAD]: 0.0410
faint: 0.0395
##s: 0.0379
is: 0.0356
infected: 0.0347
[SEP]: 0.0329
with: 0.0315
corona: 0.0285
##virus: 0.0270

Analyzing prediction confidence for unfrozen model...
Average confidence: 0.9977
Average confidence for correct predictions: 0.9985
Average confidence for incorrect predictions: 0.9835

Analyzing prediction confidence for frozen model...
Average confidence: 0.6510
Average confidence for correct predictions: 0.6665
Average confidence for incorrect predictions: 0.5974
Analyzing unfrozen model errors...

Analyzing errors for bert_unfrozen model...
Total exam