In [None]:
import torch
from transformers import RobertaForSequenceClassification, XLNetForSequenceClassification, RobertaTokenizer, XLNetTokenizer
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import pandas as pd
from torch.nn import functional as F

# Step 1: Load the Dataset
df = pd.read_csv('balanced1_dataset.csv')
# Rows with a missing text or label would make the tokenizer (non-str input)
# or torch.tensor (NaN label) fail, so drop them up front.
df = df.dropna(subset=['text', 'label'])
texts = df['text'].tolist()
labels = df['label'].tolist()

# Step 2: Load Teacher and Student Models
# NOTE(review): 'roberta-base' is a plain pretrained encoder; loading it into
# RobertaForSequenceClassification initializes the classification head
# randomly, so the teacher's logits carry no task signal. Point
# TEACHER_CHECKPOINT at a RoBERTa checkpoint fine-tuned on this dataset
# before relying on the distilled student.
TEACHER_CHECKPOINT = 'roberta-base'
STUDENT_CHECKPOINT = 'xlnet-base-cased'
# Explicit cap for both tokenizers: XLNet's tokenizer has no built-in maximum
# length, so truncation=True alone would not shorten long texts for the student.
MAX_LENGTH = 512
# Size the classification heads from the data rather than relying on the
# library default of 2; stays 2 for binary data.
# Assumes labels are integer class ids 0..K-1 — TODO confirm against the CSV.
num_labels = max(2, len(set(labels)))

teacher_model = RobertaForSequenceClassification.from_pretrained(TEACHER_CHECKPOINT, num_labels=num_labels)
teacher_tokenizer = RobertaTokenizer.from_pretrained(TEACHER_CHECKPOINT)

student_model = XLNetForSequenceClassification.from_pretrained(STUDENT_CHECKPOINT, num_labels=num_labels)
student_tokenizer = XLNetTokenizer.from_pretrained(STUDENT_CHECKPOINT)

teacher_model.eval()  # Set teacher model to evaluation mode
student_model.train()  # Set student model to training mode

# Step 3: Tokenize the Dataset
# Each model gets its own tokenization of the same texts (row i of both
# encodings corresponds to texts[i]); padding=True pads to the longest text.
teacher_inputs = teacher_tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
student_inputs = student_tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

labels = torch.tensor(labels)

# Step 4: Create DataLoader
# Batches yield (student input_ids, student attention_mask, labels).
dataset = TensorDataset(student_inputs['input_ids'], student_inputs['attention_mask'], labels)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Step 5: Define Distillation Loss Function
def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Return the temperature-scaled KL divergence from teacher to student.

    Both logit tensors are softened by dividing by ``temperature`` before the
    softmax over the last dimension. ``KL(teacher || student)`` is averaged
    over the batch (``reduction='batchmean'``) and multiplied by
    ``temperature ** 2``, the usual correction that keeps the loss scale
    comparable across temperatures.

    Args:
        student_logits: raw student scores; expected to match the shape of
            ``teacher_logits``.
        teacher_logits: raw teacher scores used as the soft target.
        temperature: softening factor applied to both sets of logits.

    Returns:
        A scalar loss tensor that backpropagates into ``student_logits``.
    """
    scale = temperature ** 2
    # kl_div expects its input in log-space and its target as probabilities.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * scale

# Step 6: Train the Student Model
# The teacher must be fed RoBERTa token ids. The shared `dataloader` only
# carries XLNet ids, which index the wrong embedding table when passed to the
# RoBERTa teacher, so build a loader that keeps both tokenizations
# row-aligned: each shuffled batch then gives teacher and student the same
# examples, each in its own vocabulary.
NUM_EPOCHS = 3  # Example: 3 epochs
DISTILL_TEMPERATURE = 2.0
# Weight on the soft (teacher) loss; the remaining 1 - DISTILL_ALPHA goes to
# cross-entropy against the true labels, so the student also learns from the
# ground truth rather than from the teacher alone.
DISTILL_ALPHA = 0.5

distill_dataloader = DataLoader(
    TensorDataset(
        student_inputs['input_ids'],
        student_inputs['attention_mask'],
        teacher_inputs['input_ids'],
        teacher_inputs['attention_mask'],
        labels,  # full label tensor built in Step 3
    ),
    batch_size=16,
    shuffle=True,
)

optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-5)

# Move each batch to wherever the models live (CPU or GPU).
student_device = next(student_model.parameters()).device
teacher_device = next(teacher_model.parameters()).device

teacher_model.eval()   # Teacher is frozen: no dropout, no gradient updates
student_model.train()  # Re-assert in case an earlier cell switched to eval

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    for s_ids, s_mask, t_ids, t_mask, batch_labels in distill_dataloader:
        # Forward pass through the student; passing labels makes the model
        # also return its cross-entropy loss against the ground truth.
        student_outputs = student_model(
            input_ids=s_ids.to(student_device),
            attention_mask=s_mask.to(student_device),
            labels=batch_labels.to(student_device),
        )

        # Forward pass through the teacher on its own tokenization.
        with torch.no_grad():
            teacher_logits = teacher_model(
                input_ids=t_ids.to(teacher_device),
                attention_mask=t_mask.to(teacher_device),
            ).logits.to(student_device)

        soft_loss = distillation_loss(
            student_outputs.logits, teacher_logits, temperature=DISTILL_TEMPERATURE
        )
        hard_loss = student_outputs.loss
        loss = DISTILL_ALPHA * soft_loss + (1.0 - DISTILL_ALPHA) * hard_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Report the epoch mean rather than the last batch's noisy value; the
    # max() guard avoids dividing by zero on an empty dataset.
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / max(len(distill_dataloader), 1)}")

# Step 7: Evaluation Metrics and ROC Curve
student_model.eval()  # Set student model to evaluation mode
eval_device = next(student_model.parameters()).device

# NOTE(review): this reuses the training `dataloader`, so every metric below
# measures fit on data the student was trained on, not generalization.
# Evaluate on a held-out split for meaningful numbers.
predictions = []
true_labels = []
positive_scores = []  # softmax probability of class 1, collected for ROC

with torch.no_grad():
    for input_ids, attention_mask, batch_labels in dataloader:
        logits = student_model(
            input_ids=input_ids.to(eval_device),
            attention_mask=attention_mask.to(eval_device),
        ).logits

        predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        true_labels.extend(batch_labels.cpu().numpy())
        # ROC needs a continuous score per example, not the argmax label.
        if logits.shape[-1] == 2:
            positive_scores.extend(F.softmax(logits, dim=-1)[:, 1].cpu().numpy())

# Step 8: Calculate Evaluation Metrics
accuracy = accuracy_score(true_labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='weighted')

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

# Step 9: Plot ROC Curve (for binary classification)
# Scores come from class-1 probabilities; using hard 0/1 predictions would
# collapse the curve to a single operating point and understate the AUC.
if len(set(true_labels)) == 2 and len(positive_scores) == len(true_labels):
    roc_auc = roc_auc_score(true_labels, positive_scores)
    print(f"ROC-AUC Score: {roc_auc}")

    fpr, tpr, _ = roc_curve(true_labels, positive_scores)
    plt.figure()
    plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')  # Random classifier line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc='lower right')
    plt.show()