In [None]:
# EDS 6397 - Natural Language Processing
# Assignment #5 - Sentence Pair Classification - Task 2: Entailment Detection

# Install necessary packages (if not already installed)
!pip install transformers datasets evaluate torch scikit-learn matplotlib

# Import necessary libraries
import torch
import numpy as np
import evaluate
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import (
    RobertaTokenizer, 
    RobertaForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.model_selection import train_test_split

# Seed NumPy and PyTorch (CPU plus every CUDA device) with one shared value
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Fetch the RTE configuration of the GLUE benchmark and show its structure
dataset = load_dataset("glue", "rte")
print(dataset)

# Peek at the first training example and at the label names
first_example = dataset["train"][0]
print(first_example)
label_vocab = dataset['train'].features['label'].names
print(f"Label meaning: {label_vocab}")

# As instructed, carve the original train split 80-20 into train / validation.
# The split is seeded so the same partition is produced on every run.
train_val_split = dataset["train"].train_test_split(test_size=0.2, seed=seed)
train_dataset_raw, val_dataset_raw = train_val_split["train"], train_val_split["test"]
# The dataset's own "validation" split is held out and used as the test set
test_dataset_raw = dataset["validation"]

for split_label, split_data in (
    ("Train", train_dataset_raw),
    ("Validation", val_dataset_raw),
    ("Test", test_dataset_raw),
):
    print(f"{split_label} set size: {len(split_data)}")

# Tokenizer matching the "roberta-base" checkpoint used for both models
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Tokenize the dataset
def tokenize_function(examples):
    """Encode a batch of sentence pairs with the module-level tokenizer.

    Args:
        examples: Mapping that provides the "sentence1" and "sentence2"
            columns of the batch.

    Returns:
        The tokenizer's output for the paired inputs, with every sequence
        padded or truncated to a fixed length of 128 tokens.
    """
    first_sentences = examples["sentence1"]
    second_sentences = examples["sentence2"]
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    return tokenizer(first_sentences, second_sentences, **encoding_options)

# Run the tokenizer over the three raw splits, in train / validation / test order
train_dataset, val_dataset, test_dataset = (
    raw_split.map(tokenize_function, batched=True)
    for raw_split in (train_dataset_raw, val_dataset_raw, test_dataset_raw)
)

# GLUE metric for the RTE task; compute_metrics() below reads this global
metric = evaluate.load("glue", "rte")

def compute_metrics(eval_pred):
    """Score model outputs with the module-level GLUE RTE metric.

    Args:
        eval_pred: Pair that unpacks into (model outputs, reference labels).
            The outputs are reduced with an argmax over axis 1, so they are
            expected to be 2-D scores with one column per class.

    Returns:
        The result of ``metric.compute`` for the predicted class ids.
    """
    scores, gold_labels = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=gold_labels)

# PART 1: FINE-TUNING WITH FROZEN BASE MODEL
# ==========================================

# Start from the pre-trained checkpoint with a two-label classification head
model_frozen = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2)

# Disable gradients for every parameter under `model_frozen.roberta`.
# Parameters outside that submodule (the classification head) are not touched
# and therefore remain trainable.
model_frozen.roberta.requires_grad_(False)

# Count trainable parameters
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Args:
        model: Object exposing ``parameters()``; each parameter must provide
            ``requires_grad`` and ``numel()``.

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is
        true (0 when there are none).
    """
    trainable_total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable_total += tensor.numel()
    return trainable_total

# Report how many weights are still trainable in the frozen-encoder model
frozen_params = count_parameters(model_frozen)
print(f"Number of trainable parameters (frozen base model): {frozen_params}")

# Define training arguments
training_args_frozen = TrainingArguments(
    output_dir="./results_rte_frozen",
    eval_strategy="epoch",
    save_strategy="epoch",
    # Record the training loss once per epoch. The Trainer default is to log
    # every 500 optimizer steps, which on a short run leaves `log_history`
    # (and the loss curve plotted from it) with almost no points.
    logging_strategy="epoch",
    learning_rate=3e-4,  # Higher learning rate for fewer trainable parameters
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=15,  # Min 10, max 20 epochs as per instructions
    weight_decay=0.01,
    push_to_hub=False,
    # Reload the checkpoint with the best validation accuracy when training
    # ends; this requires the matching eval/save strategies set above.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU exists
    report_to="none",  # Disable wandb
    run_name="rte_frozen"  # Explicitly set run name
)

# Set up early stopping: give up after 3 consecutive evaluations in which the
# monitored metric fails to improve by more than 0.01
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

# Initialize trainer
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
# of `processing_class=`; confirm against the installed version before changing.
trainer_frozen = Trainer(
    model=model_frozen,
    args=training_args_frozen,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback]
)

# Fine-tune the classification head on top of the frozen encoder
print("Training frozen model...")
trainer_frozen.train()

# Score the trained model on the held-out 20% validation split
eval_results_frozen = trainer_frozen.evaluate()
print(f"Validation results (frozen model): {eval_results_frozen}")

# Score it on the test split and report the headline accuracy
test_results_frozen = trainer_frozen.predict(test_dataset)
frozen_test_accuracy = test_results_frozen.metrics['test_accuracy']
print("Test results (frozen model):")
print(f"Accuracy: {frozen_test_accuracy}")

# Per-class precision / recall / F1: once as a dict (kept for the final
# comparison table) and once as formatted text for display
test_labels = test_results_frozen.label_ids
test_preds = np.argmax(test_results_frozen.predictions, axis=1)
frozen_label_names = dataset['train'].features['label'].names
class_report = classification_report(
    test_labels,
    test_preds,
    target_names=frozen_label_names,
    output_dict=True,
)
print(classification_report(test_labels, test_preds, target_names=frozen_label_names))

# Plot training history
frozen_history = trainer_frozen.state.log_history
plt.figure(figsize=(12, 4))

# Left panel: training loss against the optimizer step at which it was logged,
# so the "Step" axis shows real step numbers rather than list positions.
plt.subplot(1, 2, 1)
frozen_loss_entries = [entry for entry in frozen_history if 'loss' in entry]
loss_values = [entry['loss'] for entry in frozen_loss_entries]
# Markers keep the curve visible even when only one value was logged
plt.plot([entry['step'] for entry in frozen_loss_entries], loss_values, marker="o")
plt.title("Training Loss (Frozen)")
plt.xlabel("Step")
plt.ylabel("Loss")

# Right panel: validation accuracy against the epoch at which it was measured.
# Only the first evaluation recorded at each global step is kept: the
# standalone trainer.evaluate() call made after training appends another
# entry at the final step, which would otherwise show up as an extra epoch.
plt.subplot(1, 2, 2)
frozen_accuracy_by_step = {}
for entry in frozen_history:
    if 'eval_accuracy' in entry:
        frozen_accuracy_by_step.setdefault(
            entry['step'], (entry['epoch'], entry['eval_accuracy'])
        )
eval_epochs = [epoch for epoch, _ in frozen_accuracy_by_step.values()]
eval_results = [accuracy for _, accuracy in frozen_accuracy_by_step.values()]
plt.plot(eval_epochs, eval_results, marker="o")
plt.title("Validation Accuracy (Frozen)")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.tight_layout()
plt.savefig("rte_frozen_training.png")
plt.show()

# PART 2: FINE-TUNING WITH UNFROZEN MODEL
# =======================================

# Load pre-trained model again (all weights trainable)
model_unfrozen = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=2
)

# Count trainable parameters
unfrozen_params = count_parameters(model_unfrozen)
print(f"Number of trainable parameters (unfrozen model): {unfrozen_params}")

# Define training arguments (smaller learning rate for all parameters)
training_args_unfrozen = TrainingArguments(
    output_dir="./results_rte_unfrozen",
    eval_strategy="epoch",
    save_strategy="epoch",
    # Record the training loss once per epoch. The Trainer default is to log
    # every 500 optimizer steps, which on a run this short leaves
    # `log_history` (and the loss curve plotted from it) with almost no points.
    logging_strategy="epoch",
    learning_rate=8e-6,  # Lower learning rate when fine-tuning all parameters
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=4,  # Min 3, max 5 epochs as per instructions
    weight_decay=0.01,
    push_to_hub=False,
    # Reload the checkpoint with the best validation accuracy when training
    # ends; this requires the matching eval/save strategies set above.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU exists
    report_to="none",  # Disable wandb
    run_name="rte_unfrozen"  # Explicitly set run name
)

# Initialize trainer (no early stopping here: the run is only a few epochs)
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
# of `processing_class=`; confirm against the installed version before changing.
trainer_unfrozen = Trainer(
    model=model_unfrozen,
    args=training_args_unfrozen,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Fine-tune the whole network (encoder and classification head)
print("Training unfrozen model...")
trainer_unfrozen.train()

# Score the trained model on the held-out 20% validation split
eval_results_unfrozen = trainer_unfrozen.evaluate()
print(f"Validation results (unfrozen model): {eval_results_unfrozen}")

# Score it on the test split and report the headline accuracy
test_results_unfrozen = trainer_unfrozen.predict(test_dataset)
unfrozen_test_accuracy = test_results_unfrozen.metrics['test_accuracy']
print("Test results (unfrozen model):")
print(f"Accuracy: {unfrozen_test_accuracy}")

# Per-class precision / recall / F1: once as a dict (kept for the final
# comparison table) and once as formatted text for display
test_labels_unfrozen = test_results_unfrozen.label_ids
test_preds_unfrozen = np.argmax(test_results_unfrozen.predictions, axis=1)
unfrozen_label_names = dataset['train'].features['label'].names
class_report_unfrozen = classification_report(
    test_labels_unfrozen,
    test_preds_unfrozen,
    target_names=unfrozen_label_names,
    output_dict=True,
)
print(classification_report(test_labels_unfrozen, test_preds_unfrozen, target_names=unfrozen_label_names))

# Plot training history
unfrozen_history = trainer_unfrozen.state.log_history
plt.figure(figsize=(12, 4))

# Left panel: training loss against the optimizer step at which it was logged,
# so the "Step" axis shows real step numbers rather than list positions.
plt.subplot(1, 2, 1)
unfrozen_loss_entries = [entry for entry in unfrozen_history if 'loss' in entry]
loss_values = [entry['loss'] for entry in unfrozen_loss_entries]
# Markers keep the curve visible even when only one value was logged
plt.plot([entry['step'] for entry in unfrozen_loss_entries], loss_values, marker="o")
plt.title("Training Loss (Unfrozen)")
plt.xlabel("Step")
plt.ylabel("Loss")

# Right panel: validation accuracy against the epoch at which it was measured.
# Only the first evaluation recorded at each global step is kept: the
# standalone trainer.evaluate() call made after training appends another
# entry at the final step, which would otherwise show up as an extra epoch.
plt.subplot(1, 2, 2)
unfrozen_accuracy_by_step = {}
for entry in unfrozen_history:
    if 'eval_accuracy' in entry:
        unfrozen_accuracy_by_step.setdefault(
            entry['step'], (entry['epoch'], entry['eval_accuracy'])
        )
eval_epochs = [epoch for epoch, _ in unfrozen_accuracy_by_step.values()]
eval_results = [accuracy for _, accuracy in unfrozen_accuracy_by_step.values()]
plt.plot(eval_epochs, eval_results, marker="o")
plt.title("Validation Accuracy (Unfrozen)")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.tight_layout()
plt.savefig("rte_unfrozen_training.png")
plt.show()

# COMPARE RESULTS
comparison_label_names = dataset['train'].features['label'].names


def _collect_test_metrics(test_results, report):
    """Flatten test accuracy and per-class precision/recall/F1 into one dict.

    Args:
        test_results: Output of ``Trainer.predict``; its ``metrics`` dict must
            contain ``'test_accuracy'``.
        report: ``classification_report(..., output_dict=True)`` result keyed
            by label name.

    Returns:
        Dict with ``"accuracy"`` followed by ``"<label>_precision"``,
        ``"<label>_recall"`` and ``"<label>_f1"`` for each label name, in
        label order.
    """
    collected = {"accuracy": test_results.metrics['test_accuracy']}
    for label_name in comparison_label_names:
        collected[f"{label_name}_precision"] = report[label_name]["precision"]
        collected[f"{label_name}_recall"] = report[label_name]["recall"]
        collected[f"{label_name}_f1"] = report[label_name]["f1-score"]
    return collected


frozen_metrics = _collect_test_metrics(test_results_frozen, class_report)
unfrozen_metrics = _collect_test_metrics(test_results_unfrozen, class_report_unfrozen)

# Print comparison
print("\nModel Comparison:")
print(f"{'Metric':<25} {'Frozen Base':<15} {'Unfrozen':<15}")
print("-" * 55)
# The loop variable is deliberately not called `metric`: that name is the
# module-level GLUE metric object read by compute_metrics(), and rebinding it
# to a string would break any later evaluation with either trainer.
for metric_name, frozen_value in frozen_metrics.items():
    print(f"{metric_name:<25} {frozen_value:<15.4f} {unfrozen_metrics[metric_name]:<15.4f}")

# Save models (optional)
trainer_frozen.save_model("./final_rte_frozen")
trainer_unfrozen.save_model("./final_rte_unfrozen")