In [3]:
%%capture
import os
# Install Unsloth and its companions. Colab is detected by looking for the
# substring "COLAB_GPU" in the concatenated names of all environment variables.
if "COLAB_GPU" not in "".join(os.environ.keys()):
    # Not on Colab: a single plain install.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # --no-deps installs only the named packages and skips their dependencies
    # (presumably to leave Colab's preinstalled versions alone - confirm).
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

import pandas as pd
import numpy as np
from sklearn.metrics import (accuracy_score, classification_report, confusion_matrix,
                           precision_score, recall_score, f1_score, roc_auc_score,
                           matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score)
from transformers import TextStreamer
import torch
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import re

# Load and setup the Llama model
# `model` and `tokenizer` are module-level globals used by
# predict_multiple_choice() further down.
print("Loading Llama 3.1 8B model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = 8192,
    load_in_4bit = True,
    # token = "hf_...", # use one if using gated models
)

# Setup chat template for Llama
# The mapping lets messages be written with "from"/"value" keys and
# "human"/"gpt" speaker names instead of "role"/"content" and
# "user"/"assistant"; predict_multiple_choice() builds its messages that way.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
    mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"},
)

# Put the model into Unsloth's inference mode before any generation call.
FastLanguageModel.for_inference(model)
print("Model loaded successfully! Using Llama 3.1 8B for multiple choice questions.")

# Load your dataset
# The rest of the script reads these columns from the CSV: 'id', 'domain',
# 'sentence', 'option_A'..'option_D' and 'true_label'.
print("Loading dataset...")
df = pd.read_csv('/content/hard_120_samples.csv')
print(f"Dataset loaded: {len(df)} rows")
print(f"Columns: {df.columns.tolist()}")
print(f"True label distribution:\n{df['true_label'].value_counts()}")

# Patterns used by _extract_choice(), tried in this order.
# 1. An answer letter at the very start of the reply, optionally wrapped in
#    punctuation ("B", "b.", "(C)", "**D**", "A: ..."): the letter must be
#    followed by punctuation or the end of the text, not by another word.
_LEADING_CHOICE_RE = re.compile(r"^\W*([A-Da-d])(?=\s*[^\w\s]|\s*$)")
# 2. A capital A-D standing alone as its own token ("The answer is B").
_STANDALONE_UPPER_RE = re.compile(r"\b([A-D])\b")
# 3. Last resort: a lone lowercase a-d token ("answer: b").
_STANDALONE_LOWER_RE = re.compile(r"\b([a-d])\b")


def _extract_choice(raw_response, default='A'):
    """Pull the chosen option letter out of a free-text model reply.

    Scanning for the first A/B/C/D *character* is wrong for replies such as
    "The answer is B" (the 'A' of "answer" would win), so only letters that
    stand alone as a token are accepted.

    Args:
        raw_response: Decoded model output.
        default: Letter returned when no option letter can be recovered.

    Returns:
        One of 'A', 'B', 'C', 'D' (or `default`).
    """
    text = raw_response.strip()
    for pattern in (_LEADING_CHOICE_RE, _STANDALONE_UPPER_RE, _STANDALONE_LOWER_RE):
        match = pattern.search(text)
        if match:
            return match.group(1).upper()
    return default


def predict_multiple_choice(row):
    """
    Predict the correct answer for a multiple choice question using Llama.

    Reads 'sentence' and 'option_A'..'option_D' from `row`, builds a single
    user turn, and greedily decodes at most 10 new tokens with the
    module-level `model` and `tokenizer`.

    Args:
        row: Mapping-like object (e.g. a pandas Series) with the keys above.

    Returns:
        tuple: (prediction, raw_response). `prediction` is always one of
        'A', 'B', 'C', 'D'; when no option letter can be recovered from the
        reply it falls back to 'A', so unparseable replies count as 'A'.
        `raw_response` is the decoded, stripped generation.
    """
    sentence = row['sentence']
    option_A = row['option_A']
    option_B = row['option_B']
    option_C = row['option_C']
    option_D = row['option_D']

    # Create a comprehensive prompt optimized for Llama
    prompt = f"""You are an expert annotator for multiple-choice classification.

Given the sentence below, select the most appropriate answer.

Sentence: "{sentence}"

Options:
• A: {option_A}
• B: {option_B}
• C: {option_C}
• D: {option_D}

**Respond with only a single letter (A, B, C, or D) and nothing else.**

Answer:"""

    # "from"/"value"/"human" follow the key mapping configured on the tokenizer.
    messages = [
        {"from": "human", "value": prompt}
    ]

    # Send the prompt to whatever device the model lives on instead of
    # assuming "cuda".
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Greedy decoding keeps the evaluation deterministic. No `temperature` is
    # passed: it only applies when do_sample=True and would otherwise just
    # trigger a generation-config warning.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=10,       # Very limited tokens for single letter answer
            use_cache=True,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the generated part (everything after the prompt tokens)
    prompt_length = inputs.shape[-1]
    raw_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    prediction = _extract_choice(raw_response)

    return prediction, raw_response

def evaluate_dataset(df, sample_size=None):
    """
    Evaluate the entire dataset or a random sample of it.

    Args:
        df: DataFrame with at least 'sentence', 'domain' and 'true_label'
            columns, plus whatever predict_multiple_choice() reads.
        sample_size: When truthy, evaluate a reproducible random sample
            (random_state=42) of this many rows, capped at len(df).
            Otherwise evaluate every row.

    Returns:
        tuple: (predictions, true_labels, raw_responses, df_eval) where the
        three lists are aligned with the rows of `df_eval`. A row whose
        prediction raises is recorded as 'A' with an "ERROR: ..." raw
        response instead of aborting the run.
    """
    if sample_size:
        # DataFrame.sample raises when n exceeds the row count; cap it so a
        # small dataset is evaluated in full rather than crashing.
        n_rows = min(sample_size, len(df))
        df_eval = df.sample(n=n_rows, random_state=42).reset_index(drop=True)
        print(f"Evaluating sample of {n_rows} questions...")
    else:
        df_eval = df.copy()
        print(f"Evaluating all {len(df_eval)} questions...")

    predictions = []
    raw_responses = []
    true_labels = df_eval['true_label'].tolist()
    total = len(df_eval)

    # Count by position, not by index label, so the progress display is
    # correct even when df carries a non-default index.
    for position, (_, row) in enumerate(df_eval.iterrows(), start=1):
        # str() keeps a missing (NaN) sentence from crashing the progress line.
        print(f"\nEvaluating {position}/{total}: {str(row['sentence'])[:50]}...")
        print(f"Domain: {row['domain']}")

        try:
            prediction, raw_response = predict_multiple_choice(row)
            predictions.append(prediction)
            raw_responses.append(raw_response)
            print(f"Raw Response: {raw_response}")
            print(f"Prediction: {prediction}, True: {row['true_label']}")
        except Exception as e:
            # Deliberate best-effort: keep the lists aligned with df_eval and
            # carry on with the remaining questions.
            print(f"Error processing question {position}: {e}")
            predictions.append('A')  # Default to A on error
            raw_responses.append(f"ERROR: {str(e)}")

    return predictions, true_labels, raw_responses, df_eval

def _tally_choice_metrics(predictions, true_labels, labels):
    """Compute accuracy, confusion matrix and label distributions (no I/O).

    Args:
        predictions: Predicted labels.
        true_labels: Gold labels, aligned with `predictions`.
        labels: Ordered list of valid labels; fixes the confusion-matrix axes.

    Returns:
        tuple: (metrics, class_correct). `metrics` is the dictionary returned
        by calculate_comprehensive_metrics(); `class_correct` maps each label
        to the number of correctly predicted rows with that gold label.

    Raises:
        ValueError: If a gold or predicted label is not in `labels`.
    """
    predictions = list(predictions)
    true_labels = list(true_labels)
    position = {label: i for i, label in enumerate(labels)}

    # Rows are true labels, columns are predictions.
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for true, pred in zip(true_labels, predictions):
        if true not in position or pred not in position:
            raise ValueError(
                f"Unexpected label pair (true={true!r}, predicted={pred!r}); "
                f"expected each to be one of {labels}"
            )
        cm[position[true], position[pred]] += 1

    total = len(predictions)
    correct = sum(1 for p, t in zip(predictions, true_labels) if p == t)
    # An empty evaluation reports 0.0 instead of dividing by zero.
    accuracy = correct / total if total else 0.0

    pred_distribution = {label: predictions.count(label) for label in labels}
    true_distribution = {label: true_labels.count(label) for label in labels}

    # Per-class accuracy = correct predictions for a gold label / rows with
    # that gold label; 0.0 when the label never occurs.
    class_correct = {label: int(cm[position[label], position[label]]) for label in labels}
    class_accuracies = {
        label: (class_correct[label] / true_distribution[label]
                if true_distribution[label] > 0 else 0.0)
        for label in labels
    }

    metrics = {
        'accuracy': accuracy,
        'correct_predictions': correct,
        'total_questions': total,
        'class_accuracies': class_accuracies,
        'confusion_matrix': cm,
        'pred_distribution': pred_distribution,
        'true_distribution': true_distribution
    }
    return metrics, class_correct


def calculate_comprehensive_metrics(predictions, true_labels):
    """
    Calculate and display comprehensive evaluation metrics for multiple choice.

    Prints overall accuracy, per-class accuracy, the confusion matrix and the
    true/predicted label distributions, and saves the confusion-matrix
    heatmap to 'llama_confusion_matrix_mcq.png' before showing it.

    Args:
        predictions: Sequence of predicted labels ('A'..'D').
        true_labels: Sequence of gold labels ('A'..'D'), aligned with
            `predictions`.

    Returns:
        dict with keys 'accuracy', 'correct_predictions', 'total_questions',
        'class_accuracies', 'confusion_matrix' (4x4 int array, rows = true,
        columns = predicted), 'pred_distribution' and 'true_distribution'.

    Raises:
        ValueError: If any label is not one of 'A', 'B', 'C', 'D'.
    """
    labels = ['A', 'B', 'C', 'D']
    # Accept any sequence (list, Series, array), not just lists.
    predictions = list(predictions)
    true_labels = list(true_labels)

    metrics, class_correct = _tally_choice_metrics(predictions, true_labels, labels)
    accuracy = metrics['accuracy']
    correct = metrics['correct_predictions']
    cm = metrics['confusion_matrix']
    class_accuracies = metrics['class_accuracies']
    pred_distribution = metrics['pred_distribution']
    true_distribution = metrics['true_distribution']

    print(f"\n{'='*60}")
    print("COMPREHENSIVE EVALUATION METRICS - LLAMA 3.1")
    print(f"{'='*60}")

    print(f"\n--- Overall Metrics ---")
    print(f"Total Questions: {len(predictions)}")
    print(f"Correct Predictions: {correct}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    print(f"\n--- Per-Class Accuracy ---")
    for label in labels:
        print(f"Option {label}: {class_accuracies[label]:.4f} ({class_correct[label]}/{true_distribution[label]})")

    print(f"\n--- Confusion Matrix ---")
    print("True\\Pred\tA\tB\tC\tD")
    for i, label in enumerate(labels):
        cells = "".join(f"{cm[i][j]}\t" for j in range(len(labels)))
        print(f"{label}\t\t{cells}")

    # Create confusion matrix visualization
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels,
                yticklabels=labels)
    plt.title('Confusion Matrix - Llama 3.1 8B Multiple Choice')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig('llama_confusion_matrix_mcq.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Denominators of at least 1 keep the percentage lines printable for an
    # empty evaluation.
    n_true = max(len(true_labels), 1)
    n_pred = max(len(predictions), 1)

    print(f"\n--- Label Distribution ---")
    print("True Label Distribution:")
    for label in labels:
        count = true_distribution[label]
        print(f"  Option {label}: {count} ({count/n_true*100:.2f}%)")

    print("\nPredicted Label Distribution:")
    for label in labels:
        count = pred_distribution[label]
        print(f"  Option {label}: {count} ({count/n_pred*100:.2f}%)")

    return metrics

# Test with a few examples first
# Smoke test: run the model on a handful of rows and print everything, so
# prompt or parsing problems show up before the long evaluation runs.
print("\n" + "="*60)
print("TESTING LLAMA 3.1 WITH SAMPLE QUESTIONS")
print("="*60)

# Test with first 3 questions
test_df = df.head(3)
for idx, row in test_df.iterrows():
    # idx is the DataFrame index label; with the default index from read_csv
    # this prints Test 1..3.
    print(f"\nTest {idx+1}:")
    print(f"Domain: {row['domain']}")
    print(f"Sentence: {row['sentence']}")
    print(f"Options:")
    print(f"  A. {row['option_A']}")
    print(f"  B. {row['option_B']}")
    print(f"  C. {row['option_C']}")
    print(f"  D. {row['option_D']}")
    print(f"True Answer: {row['true_label']}")

    result, raw_response = predict_multiple_choice(row)
    print(f"Model Prediction: {result}")
    print(f"Raw Response: '{raw_response}'")
    print(f"Correct: {'✓' if result == row['true_label'] else '✗'}")

# Evaluate on a sample first (optional for testing)
print("\n" + "="*60)
print("EVALUATING DATASET SAMPLE WITH LLAMA 3.1")
print("="*60)

# 20 randomly sampled rows; evaluate_dataset() fixes random_state=42, so the
# same rows are drawn on every run.
sample_predictions, sample_true_labels, sample_raw_responses, sample_df_eval = evaluate_dataset(df, sample_size=20)

# Save sample raw predictions
# The three lists are aligned row-for-row with sample_df_eval; the scalar
# 'model' and 'timestamp' values are repeated on every row.
sample_raw_predictions_df = pd.DataFrame({
    'id': sample_df_eval['id'],
    'domain': sample_df_eval['domain'],
    'sentence': sample_df_eval['sentence'],
    'true_label': sample_true_labels,
    'predicted_label': sample_predictions,
    'raw_model_response': sample_raw_responses,
    'correct': [p == t for p, t in zip(sample_predictions, sample_true_labels)],
    'model': 'Llama-3.1-8B',
    'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})

sample_raw_predictions_df.to_csv('llama_raw_model_predictions_sample_mcq.csv', index=False)
print("\nSample predictions saved to 'llama_raw_model_predictions_sample_mcq.csv'")

# Calculate metrics for sample
# NOTE: this also writes 'llama_confusion_matrix_mcq.png', which the
# full-dataset metrics call later overwrites with the same filename.
sample_metrics = calculate_comprehensive_metrics(sample_predictions, sample_true_labels)
print(f"\nSample evaluation completed with {sample_metrics['accuracy']*100:.2f}% accuracy!")

# Run on FULL DATASET
print("\n" + "="*60)
print("EVALUATING FULL DATASET WITH LLAMA 3.1 8B")
print("="*60)

# No sample_size: every row of df is evaluated, in its original order.
full_predictions, full_true_labels, full_raw_responses, df_eval = evaluate_dataset(df)

# Save full raw predictions
# One output row per question: the question and its options, the gold and
# predicted labels, the raw model text and a per-row correctness flag. The
# scalar 'model' and 'timestamp' values are repeated on every row.
full_raw_predictions_df = pd.DataFrame({
    'id': df_eval['id'],
    'domain': df_eval['domain'],
    'sentence': df_eval['sentence'],
    'option_A': df_eval['option_A'],
    'option_B': df_eval['option_B'],
    'option_C': df_eval['option_C'],
    'option_D': df_eval['option_D'],
    'true_label': full_true_labels,
    'predicted_label': full_predictions,
    'raw_model_response': full_raw_responses,
    'correct': [p == t for p, t in zip(full_predictions, full_true_labels)],
    'model': 'Llama-3.1-8B',
    'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})

full_raw_predictions_df.to_csv('llama_mcq_predictions_full.csv', index=False)
print("\nFull predictions saved to 'llama_mcq_predictions_full.csv'")

# Calculate comprehensive metrics for full dataset
# full_metrics is reused below for the distribution plot and the summary CSVs.
print("\n" + "="*60)
print("FULL DATASET METRICS")
print("="*60)
full_metrics = calculate_comprehensive_metrics(full_predictions, full_true_labels)

# Analyze errors by domain
print("\n" + "="*60)
print("DOMAIN-WISE ANALYSIS")
print("="*60)

# domain -> {'accuracy', 'total', 'correct'}; key order fixes the CSV columns.
domain_results = {}
for domain in df_eval['domain'].unique():
    # Boolean mask aligned row-for-row with the prediction lists. A missing
    # (NaN) domain never compares equal, so its mask is empty and the guard
    # below skips it.
    domain_mask = df_eval['domain'] == domain
    domain_preds = [p for p, m in zip(full_predictions, domain_mask) if m]
    domain_true = [t for t, m in zip(full_true_labels, domain_mask) if m]

    if domain_preds:
        # Count the hits once and reuse the figure for both the accuracy and
        # the stored 'correct' value.
        domain_correct = sum(1 for p, t in zip(domain_preds, domain_true) if p == t)
        domain_acc = domain_correct / len(domain_preds)
        domain_results[domain] = {
            'accuracy': domain_acc,
            'total': len(domain_preds),
            'correct': domain_correct
        }
        print(f"{domain}: {domain_acc:.4f} ({domain_correct}/{len(domain_preds)})")

# Save domain analysis (one row per domain, domain name as the index column)
domain_analysis_df = pd.DataFrame.from_dict(domain_results, orient='index')
domain_analysis_df.to_csv('llama_domain_wise_analysis.csv')

# Analyze misclassified examples
# Keep only the rows whose per-row 'correct' flag is False.
misclassified = full_raw_predictions_df[full_raw_predictions_df['correct'] == False]
print(f"\n" + "="*60)
print("ERROR ANALYSIS")
print(f"="*60)
print(f"Total misclassifications: {len(misclassified)}")
print(f"Misclassification rate: {len(misclassified)/len(full_raw_predictions_df)*100:.2f}%")

# Save misclassified examples
misclassified.to_csv('llama_misclassified_examples_mcq.csv', index=False)

# Show sample of misclassified questions
if len(misclassified) > 0:
    print("\nSample of misclassified questions (showing first 5):")
    sample_errors = misclassified.head(5)
    for idx, row in sample_errors.iterrows():
        # idx is the row's index label in the full results frame, so the
        # number printed here identifies the question rather than counting
        # errors 1, 2, 3, ...
        print(f"\n--- Error {idx+1} ---")
        print(f"Domain: {row['domain']}")
        print(f"Sentence: {row['sentence'][:100]}...")
        print(f"True Label: {row['true_label']}, Predicted: {row['predicted_label']}")
        print(f"Raw Response: '{row['raw_model_response']}'")

# Create distribution visualization
# Two side-by-side bar charts: how often each option is the gold answer vs.
# how often the model predicted it.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

labels = ['A', 'B', 'C', 'D']
# Counts come from the full-dataset metrics, in A-D order.
true_counts = [full_metrics['true_distribution'][l] for l in labels]
pred_counts = [full_metrics['pred_distribution'][l] for l in labels]

# True labels
ax1.bar(labels, true_counts, color='blue', alpha=0.7)
ax1.set_title('True Label Distribution')
ax1.set_xlabel('Option')
ax1.set_ylabel('Count')

# Predicted labels
ax2.bar(labels, pred_counts, color='green', alpha=0.7)
ax2.set_title('Predicted Label Distribution')
ax2.set_xlabel('Option')
ax2.set_ylabel('Count')

plt.tight_layout()
# Save before show() so the file is written from the still-open figure.
plt.savefig('llama_label_distribution_mcq.png', dpi=300, bbox_inches='tight')
plt.show()

# Save metrics summary
# Collect every (Metric, Value) row first and build the DataFrame once,
# instead of growing it with pd.concat inside a loop.
metrics_rows = [
    {'Metric': 'Accuracy', 'Value': full_metrics['accuracy']},
    {'Metric': 'Correct Predictions', 'Value': full_metrics['correct_predictions']},
    {'Metric': 'Total Questions', 'Value': full_metrics['total_questions']},
]

# Add per-class accuracies
for label in ['A', 'B', 'C', 'D']:
    metrics_rows.append({
        'Metric': f'Class {label} Accuracy',
        'Value': full_metrics['class_accuracies'][label]
    })

metrics_summary = pd.DataFrame(metrics_rows)
metrics_summary.to_csv('llama_evaluation_metrics_summary_mcq.csv', index=False)

# Convenience wrapper for ad-hoc questions
def evaluate_single_question(sentence, option_A, option_B, option_C, option_D):
    """
    Run the model on one hand-written multiple choice question.

    Packs the arguments into a row with the keys predict_multiple_choice()
    reads, prints the question, the options and the model's answer, and
    returns (prediction, raw_response).
    """
    choices = {'A': option_A, 'B': option_B, 'C': option_C, 'D': option_D}

    # Same key layout as a dataset row: 'sentence' followed by option_A..option_D.
    fields = {'sentence': sentence}
    for letter, text in choices.items():
        fields[f'option_{letter}'] = text
    question = pd.Series(fields)

    result, raw_response = predict_multiple_choice(question)

    print(f"Sentence: {sentence}")
    print("Options:")
    for letter, text in choices.items():
        print(f"  {letter}. {text}")
    print(f"Model Prediction: {result}")
    print(f"Raw Response: '{raw_response}'")

    return result, raw_response

# Final console recap of the full-dataset run.
print("\n" + "="*60)
print("ALL ANALYSIS COMPLETE!")
print("="*60)
print(f"Model: Llama-3.1-8B")
print(f"Total Questions Evaluated: {len(df)}")
print(f"Overall Accuracy: {full_metrics['accuracy']*100:.2f}%")
print(f"Correct Predictions: {full_metrics['correct_predictions']}/{full_metrics['total_questions']}")

# Create final summary report (a single row describing this run)
summary_report = {
    'Model': 'Llama-3.1-8B',
    'Dataset': 'hard_120_samples.csv',
    'Total Questions': full_metrics['total_questions'],
    'Correct Predictions': full_metrics['correct_predictions'],
    'Overall Accuracy': full_metrics['accuracy'],
    'Timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}

# Add per-class accuracies to summary
for label in ['A', 'B', 'C', 'D']:
    summary_report[f'Class_{label}_Accuracy'] = full_metrics['class_accuracies'][label]

summary_df = pd.DataFrame([summary_report])
summary_df.to_csv('llama_evaluation_summary_mcq.csv', index=False)

# List every artifact this script writes, including the one-row summary CSV
# saved just above.
print("\nResults saved to:")
print("  - 'llama_mcq_predictions_full.csv' (full predictions with details)")
print("  - 'llama_misclassified_examples_mcq.csv' (incorrect predictions)")
print("  - 'llama_domain_wise_analysis.csv' (accuracy by domain)")
print("  - 'llama_evaluation_metrics_summary_mcq.csv' (summary metrics)")
print("  - 'llama_evaluation_summary_mcq.csv' (one-row run summary)")
print("  - 'llama_confusion_matrix_mcq.png' (confusion matrix visualization)")
print("  - 'llama_label_distribution_mcq.png' (label distribution visualization)")
print("  - 'llama_raw_model_predictions_sample_mcq.csv' (sample predictions)")

print("\nUse evaluate_single_question() to test individual questions.")