# In[1]:
#####################################################################
#Project: Compare low-resource adaptation techniques: 
# (prefix-tuning, LoRA, prompt-tuning) on two downstream tasks 
# (classification & summarization). 
# Report parameter-efficiency vs performance curves.
#####################################################################

############## Local Working Version. Use Python 3.12.10 ##############
#File Name: lora_ft_e1.ipynb
#Create a venv using python3.12 -m venv .venv
#Activate the venv using source .venv/bin/activate
#Install dependencies using pip install -r requirements.txt
#################################################################

import os
import subprocess
import sys

# Environment switches read by the tokenizers and Weights & Biases libraries;
# they are set here, before those libraries are imported further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

KAGGLE_REQUIREMENTS_PATH = '/kaggle/input/dependencies/requirements-kaggle-v1.0.txt'

# KAGGLE TOGGLE
# The presence of KAGGLE_KERNEL_RUN_TYPE is used to detect a Kaggle kernel.
IS_KAGGLE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', None) is not None
print(f"Running on Kaggle: {IS_KAGGLE}")
if IS_KAGGLE:
    # The requirements file must be uploaded as a Kaggle Dataset and attached
    # to this notebook so that it is available at KAGGLE_REQUIREMENTS_PATH.
    if os.path.exists(KAGGLE_REQUIREMENTS_PATH):
        print(f"Installing dependencies from {KAGGLE_REQUIREMENTS_PATH}...")
        # Run pip through the interpreter that executes this script so the
        # packages land in the right environment, and pass the arguments as a
        # list (no shell). A failed install is reported instead of ignored.
        pip_result = subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", KAGGLE_REQUIREMENTS_PATH],
            check=False,
        )
        if pip_result.returncode != 0:
            print(
                f"WARNING: pip install exited with code {pip_result.returncode}; "
                "some dependencies may be missing."
            )
    else:
        print(f"WARNING: Could not find {KAGGLE_REQUIREMENTS_PATH}.")
        print("Please upload 'requirements-kaggle.txt' as a Kaggle Dataset and add it.")


# In[2]:
import torch
import numpy as np
import pandas as pd
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset

from sklearn.metrics import confusion_matrix
from collections import Counter

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)

from peft import (
    get_peft_model,
    LoraConfig,
    PrefixTuningConfig,
    PromptTuningConfig,
    TaskType,
    PeftModel
)
import logging
import warnings
import json # For saving log_history if needed

# Suppress every Python warning for cleaner output. Note that this also hides
# deprecation notices from the ML libraries.
warnings.filterwarnings("ignore")
# INFO level so the logger.info(...) diagnostics in the metric functions are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the helpers and the training loop below.
logger = logging.getLogger(__name__)

# In[3]:
# DEVICE DETECTION
# Preference order: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# CONFIGURATION
# Base checkpoint. flan-t5-small was tried and gave a config dimension
# problem (num_heads=6 mismatch), hence plain t5-small.
MODEL_NAME = "t5-small"
SUMMARIZATION_DATASET = "knkarthick/samsum"

BENCHMARK_GLUE = "glue"
GLUE_DATASET_TASK_SC = "sst2"  # SST-2 for sentiment classification

# Number of training examples to use: a positive int (e.g. 100, 500) or 'full'.
# WARNING: a value as small as 100 is only a smoke test; performance will be
# near random chance and unsuitable for a real comparison. Use 'full' or a
# larger number (e.g. 5000) for a meaningful benchmark.
DATASET_SIZE = 'full'
# Modular flag: when True the ablated variants are trained alongside the baselines.
RUN_ABLATIONS = True

# Seed torch and numpy so runs are repeatable.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Virtual-token count used by prefix/prompt tuning (raised from 20 to 50 to
# give those methods more capacity for task-specific adaptation).
NUM_VIRTUAL_TOKENS = 50
# Sequence-length budget from which the virtual tokens are subtracted
# when tokenizing inputs.
MAX_POS = 512

# All artefacts (checkpoints, saved models, plots) go below this directory.
OUTPUT_DIR = '/kaggle/working/outputs/lora-ft-d3-v3' if IS_KAGGLE else './outputs/lora-ft-d3-v3'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# In[4]:
# Assemble the run banner line by line and emit it in one go, followed by a
# blank line.
_rule = "=" * 60
_banner_lines = [
    _rule,
    "PEFT COMPARISON - T5-small",
    _rule,
    f"Dataset size: {DATASET_SIZE}",
    f"Model: {MODEL_NAME}",
    "Methods: LoRA, Prefix-Tuning, Prompt-Tuning, Full FT",
]
if RUN_ABLATIONS:
    _banner_lines.append("Ablations Enabled: Including ablated variants for study")
    _banner_lines.append("Note: For LoRA ablation, using lora_alpha=0 to nullify adapter effect")
_banner_lines.append(_rule)
print("\n".join(_banner_lines))
print()

# In[5]:
# UTILITIES 
def limit_dataset_size(dataset, size):
    """Return ``dataset`` unchanged or truncated to its first ``size`` rows.

    Args:
        dataset: Object supporting ``len()`` and ``select(indices)``.
        size: The string ``'full'`` (no truncation) or a positive int.

    Returns:
        ``dataset`` itself for ``'full'``; otherwise the result of
        ``dataset.select`` over the first ``min(size, len(dataset))`` indices.

    Raises:
        ValueError: if ``size`` is neither ``'full'`` nor a positive int.
    """
    if size == 'full':
        return dataset
    # Guard clause: anything that is not a strictly positive integer is rejected.
    if not isinstance(size, int) or size <= 0:
        raise ValueError(f"Invalid size: {size}")
    row_count = min(size, len(dataset))
    return dataset.select(range(row_count))

def setup_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` and make sure it has a pad token.

    Args:
        model_name: Checkpoint identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer; when it defines no pad token, its EOS token is
        assigned as the pad token.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is not None:
        return tok
    # No dedicated pad token: fall back to the end-of-sequence token.
    tok.pad_token = tok.eos_token
    return tok

def safe_cleanup():
    """Free Python garbage and cached accelerator memory between experiments.

    Runs the garbage collector first so that tensors only kept alive by
    reference cycles are actually released, then empties the CUDA cache
    (after synchronizing pending work) or, when CUDA is unavailable and MPS
    is, the MPS cache. On CPU-only machines only the garbage collection runs.

    The accelerator is probed directly instead of reading a module-level
    ``device`` variable, so the helper can be called from anywhere.
    """
    import gc

    gc.collect()
    if torch.cuda.is_available():
        # Let outstanding kernels finish before handing cached blocks back.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        return
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        torch.mps.empty_cache()

# In[6]:
# PLOTS - can be moved down
def plot_learning_curves(log_history, exp_name, task_name, save_dir="./plots"):
    """Plot train/eval loss and the task metric against the training step.

    Args:
        log_history: List of dict log entries (as stored in
            ``trainer.state.log_history`` by the training loop). Entries are
            read for the keys ``step``, ``loss``, ``eval_loss`` and, depending
            on the task, ``eval_accuracy`` or ``eval_rougeL``.
        exp_name: Experiment name used in the titles and the file name.
        task_name: ``"classification"`` plots ``eval_accuracy``; any other
            value is treated as summarization and plots ``eval_rougeL``.
        save_dir: Directory the PNG is written to (created if missing).

    Returns:
        Path of the saved figure, ``<save_dir>/<exp_name>_curves.png``.
    """

    def _series(key):
        """Return (steps, values) taken from the same entries that carry ``key``."""
        points = [
            (entry['step'], entry[key])
            for entry in log_history
            if key in entry and 'step' in entry
        ]
        return [step for step, _ in points], [value for _, value in points]

    os.makedirs(save_dir, exist_ok=True)

    # The style must be set before the axes are created, otherwise it only
    # takes effect for figures created afterwards.
    sns.set_style("whitegrid")
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss curve: each series is paired with the steps of its own entries.
    train_steps, train_losses = _series('loss')
    eval_steps, eval_losses = _series('eval_loss')
    loss_ax.plot(train_steps, train_losses, label='Train Loss', marker='o', alpha=0.7)
    if eval_losses:
        loss_ax.plot(eval_steps, eval_losses, label='Eval Loss', marker='s')
    loss_ax.set_xlabel('Step')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title(f'{exp_name} - Loss Curve')
    loss_ax.legend()

    # Task-specific metric. Steps come from the entries holding the metric, so
    # x and y always have equal length even if an eval entry lacks the metric.
    if task_name == "classification":
        metric_key, metric_label = 'eval_accuracy', 'Accuracy'
    else:  # summarization
        metric_key, metric_label = 'eval_rougeL', 'ROUGE-L'
    metric_steps, metric_values = _series(metric_key)
    if metric_values:
        metric_ax.plot(metric_steps, metric_values, label=f'Eval {metric_label}', marker='o', color='green')
        metric_ax.set_ylabel(metric_label)
        # Only build a legend when something was drawn (avoids an empty-legend warning).
        metric_ax.legend()

    metric_ax.set_xlabel('Step')
    metric_ax.set_title(f'{exp_name} - {task_name.capitalize()} Metric')

    fig.tight_layout()
    plot_path = os.path.join(save_dir, f"{exp_name}_curves.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"✓ Learning curves saved to {plot_path}")
    return plot_path


def plot_ablation_comparisons(results, task_name, save_dir="./plots"):
    """Bar-chart comparison of baselines and ablations for one task.

    Args:
        results: Mapping of experiment name to a dict holding
            ``trainable_params``, ``total_params`` and ``test_metrics``.
            Every key in the mapping is plotted.
            NOTE(review): nothing here filters by task, so the caller
            presumably passes only the experiments of ``task_name`` - confirm.
        task_name: ``"classification"`` compares ``eval_accuracy``; any other
            value compares ``eval_rougeL``.
        save_dir: Directory the PNG is written to (created if missing).

    Returns:
        Path of the saved figure, or ``None`` when ``results`` contains no
        experiment whose name includes ``"_ablated_"``.
    """
    os.makedirs(save_dir, exist_ok=True)
    methods = list(results.keys())

    # Nothing to compare when no ablated variant was run.
    if not any("_ablated_" in name for name in methods):
        return None

    # The style must be set before the axes are created, otherwise it only
    # takes effect for figures created afterwards.
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Trainable params comparison (guarded against a zero total).
    trainable_pcts = [
        100 * results[m]["trainable_params"] / results[m]["total_params"]
        if results[m]["total_params"] else 0.0
        for m in methods
    ]
    sns.barplot(x=methods, y=trainable_pcts, ax=axes[0])
    axes[0].set_ylabel('Trainable %')
    axes[0].set_title(f'Trainable Params Comparison - {task_name.capitalize()}')
    axes[0].tick_params(axis='x', rotation=45)

    # Metric comparison (use key metric); a missing metric is plotted as 0.
    if task_name == "classification":
        metric_key, metric_label = "eval_accuracy", 'Accuracy'
    else:
        metric_key, metric_label = "eval_rougeL", 'ROUGE-L'
    metrics = [results[m]["test_metrics"].get(metric_key, 0) for m in methods]

    sns.barplot(x=methods, y=metrics, ax=axes[1])
    axes[1].set_ylabel(metric_label)
    axes[1].set_title(f'Performance Comparison - {task_name.capitalize()}')
    axes[1].tick_params(axis='x', rotation=45)

    fig.tight_layout()
    plot_path = os.path.join(save_dir, f"ablation_comparison_{task_name}.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"✓ Ablation comparison plot saved to {plot_path}")
    return plot_path

# In[7]:
## LOAD DATASETS 
print("Loading datasets")
# Classification dataset - SST-2
classification_dataset = load_dataset(BENCHMARK_GLUE, GLUE_DATASET_TASK_SC)
# Summarization dataset - SAMSum
summarization_dataset = load_dataset(SUMMARIZATION_DATASET)

# GLUE publishes its test split without gold labels (the label column is -1).
# Scoring against it is meaningless: the preprocessing would turn every -1
# into "positive". When that placeholder is present, reuse the labelled
# validation split as the held-out test set instead.
if "test" in classification_dataset and -1 in classification_dataset["test"].unique("label"):
    print("Classification test split is unlabeled; using the validation split as test set.")
    classification_dataset["test"] = classification_dataset["validation"]

# Load tokenizer
tokenizer = setup_tokenizer(MODEL_NAME)

if DATASET_SIZE != 'full':
    print(f"Limiting dataset size to {DATASET_SIZE} for train.")
    # Evaluation splits get a quarter of the training budget, but never fewer
    # than one example (DATASET_SIZE < 4 would otherwise request size 0, which
    # limit_dataset_size rejects).
    eval_size = max(1, DATASET_SIZE // 4)

    classification_dataset['train'] = limit_dataset_size(classification_dataset['train'], DATASET_SIZE)
    classification_dataset['validation'] = limit_dataset_size(classification_dataset['validation'], eval_size)
    classification_dataset['test'] = limit_dataset_size(classification_dataset.get('test', classification_dataset['validation']), eval_size)

    summarization_dataset['train'] = limit_dataset_size(summarization_dataset['train'], DATASET_SIZE)
    summarization_dataset['validation'] = limit_dataset_size(summarization_dataset['validation'], eval_size)
    summarization_dataset['test'] = limit_dataset_size(summarization_dataset['test'], eval_size)

print("Datasets loaded\n")

# In[8]:
# Print 10 samples from each train dataset before preprocessing
# Show up to ten raw training rows per task so the untouched data can be
# inspected before any preprocessing is applied.
print("Original Sample Datasets")

for _heading, _train_split in (
    ("Classification Train Samples (Before Preprocessing):", classification_dataset["train"]),
    ("\nSummarization Train Samples (Before Preprocessing):", summarization_dataset["train"]),
):
    print(_heading)
    _n_shown = min(10, len(_train_split))
    for _row_idx in range(_n_shown):
        print(_train_split[_row_idx])

# In[9]:
# Preprocessing for Classification
def preprocess_classification(examples):
    """Tokenize a batch of SST-2 rows into text-to-text training features.

    Args:
        examples: Batch dict with the columns ``sentence`` (list of str) and
            ``label`` (list of int).

    Returns:
        The tokenizer output for the prompted sentences, extended with a
        ``labels`` entry holding the token ids of "negative"/"positive".
        Padded label positions are set to -100 so they are ignored by the
        loss.
    """
    # Create input sentences with the required prefix
    inputs = [f"Classify sentiment: {text}" for text in examples["sentence"]]

    # Input budget: the length limit minus the virtual tokens
    # (presumably to leave room for the prefix/prompt tokens).
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS

    # Tokenize inputs with truncation and padding
    model_inputs = tokenizer(inputs, max_length=max_input_len, truncation=True, padding="max_length")

    # Convert labels from numerical to text. Any label other than 0 becomes
    # "positive" (this includes a -1 placeholder, should one be present).
    labels_text = ["negative" if label == 0 else "positive" for label in examples["label"]]

    # Tokenize labels similar to inputs
    labels = tokenizer(text_target=labels_text, max_length=10, truncation=True, padding="max_length")

    # Labels are padded to a fixed length; replace the pad ids with -100 so
    # the loss is computed on the real label tokens only and not on padding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok_id if tok_id != pad_id else -100 for tok_id in label_ids]
        for label_ids in labels["input_ids"]
    ]

    return model_inputs

# Preprocessing for Summarization
def preprocess_summarization(examples):
    """Tokenize a batch of dialogue/summary rows into seq2seq training features.

    Args:
        examples: Batch dict with the columns ``dialogue`` and ``summary``
            (lists of str).

    Returns:
        The tokenizer output for the prompted dialogues, extended with a
        ``labels`` entry holding the summary token ids. Padded label
        positions are set to -100 so they are ignored by the loss.
    """
    # Create input dialogues with the required prefix
    inputs = [f"Summarize the following conversation:\n{dialogue}" for dialogue in examples["dialogue"]]

    # Input budget: the length limit minus the virtual tokens
    # (presumably to leave room for the prefix/prompt tokens).
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS

    # Tokenize inputs with truncation and padding
    model_inputs = tokenizer(inputs, max_length=max_input_len, truncation=True, padding="max_length")

    # Summary budget: 128 minus the virtual tokens.
    max_label_len = 128 - NUM_VIRTUAL_TOKENS

    # Tokenize summaries with truncation and padding
    label_ids_batch = tokenizer(
        text_target=examples["summary"],
        max_length=max_label_len,
        truncation=True,
        padding="max_length",
    ).input_ids

    # Summaries are padded to a fixed length; replace the pad ids with -100 so
    # the loss is computed on the real summary tokens only and not on padding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok_id if tok_id != pad_id else -100 for tok_id in label_ids]
        for label_ids in label_ids_batch
    ]

    return model_inputs

# Apply preprocessing
print("\nApplying preprocessing...")
# The raw columns are dropped so that only the tokenized features remain.
_cls_raw_columns = classification_dataset["train"].column_names
tokenized_classification = classification_dataset.map(
    preprocess_classification,
    batched=True,
    remove_columns=_cls_raw_columns,
)
_sum_raw_columns = summarization_dataset["train"].column_names
tokenized_summarization = summarization_dataset.map(
    preprocess_summarization,
    batched=True,
    remove_columns=_sum_raw_columns,
)

# How many tokenized examples per task are decoded and printed below.
POST_PROCESS_SAMPLES = 5

print("\nPost-Preprocessing Sample Datasets")


# Decode a single example (input + label)
def _decode_example(example: dict, tokenizer, task: str) -> dict:
    """
    Returns a dict with:
        - "input_text"   : the original prompt (e.g. "Classify sentiment: …")
        - "label_text"   : the gold label (positive/negative or the full summary)
        - "input_ids"    : first 30 tokens (for sanity check)
        - "label_ids"    : first 15 tokens of the label
    """
    # 1. Decode the **input** (skip special tokens, keep the prompt)
    input_txt = tokenizer.decode(example["input_ids"], skip_special_tokens=False)
    # remove the padding part after the EOS token
    input_txt = input_txt.split(tokenizer.eos_token)[0] + tokenizer.eos_token

    # 2. Decode the **label**
    # Labels contain -100 for ignored positions → replace with pad token first
    label_ids = [
        tok_id if tok_id != -100 else tokenizer.pad_token_id for tok_id in example["labels"]
    ]
    label_txt = tokenizer.decode(label_ids, skip_special_tokens=True)

    # 3. Short token previews (optional, makes the output tidy)
    input_preview = " ".join(map(str, example["input_ids"][:30]))
    label_preview = " ".join(map(str, label_ids[:15]))

    return {
        "input_text": input_txt,
        "label_text": label_txt,
        "input_ids_preview": input_preview,
        "label_ids_preview": label_preview,
    }

# Print classification samples
# One pass per task: (section banner, tokenized train split, task name, label tag).
for _banner, _tok_split, _task, _label_tag in (
    ("\n=== Classification – post-preprocessing ===", tokenized_classification["train"], "classification", "LABEL  :"),
    ("\n=== Summarisation – post-preprocessing (5 examples) ===", tokenized_summarization["train"], "summarization", "SUMMARY:"),
):
    print(_banner)
    _preview = _tok_split.select(range(min(POST_PROCESS_SAMPLES, len(_tok_split))))
    for _number, _example in enumerate(_preview, start=1):
        _decoded = _decode_example(_example, tokenizer, task=_task)
        print(f"\n--- Example {_number} ---")
        print(f"INPUT  : {_decoded['input_text']}")
        print(f"{_label_tag} {_decoded['label_text']}")
        print(f"input_ids  (first 30) : {_decoded['input_ids_preview']}")
        print(f"label_ids  (first 15) : {_decoded['label_ids_preview']}")

print("\nPreprocessing complete\n")

# In[10]:
# Load metrics
# Metric objects used by the compute_*_metrics functions below:
# accuracy and F1 for classification, ROUGE for summarization.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
rouge_metric = evaluate.load("rouge")

def compute_classification_metrics(eval_pred):
    """Compute accuracy and weighted F1 from generated sentiment labels.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` may be a
            tuple (its first element is used) and may be 3-dimensional, in
            which case the argmax over the last axis is taken; -100 entries
            in either array are replaced by the pad token id before decoding.

    Returns:
        ``{"accuracy": float, "f1": float}``. If anything fails the error is
        logged with its traceback and ``{"accuracy": 0.0, "f1": 0.0}`` is
        returned so that training is not interrupted.

    A prediction only counts as class 1/0 when it decodes to exactly
    "positive"/"negative" (after strip + lower-case). Anything else is mapped
    to -1, which can never equal a reference label. Previously such outputs
    were silently counted as "negative", which credited a model that produced
    garbage with every negative example.
    """
    try:
        predictions, labels = eval_pred

        # Handling prediction tensors
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        if len(predictions.shape) == 3:
            predictions = np.argmax(predictions, axis=-1)

        # Replace -100 with pad_token_id so the ids can be decoded
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)

        # Any other negative id cannot be decoded either; clamp it to 0
        if np.any(predictions < 0) or np.any(labels < 0):
            logger.warning("Found negative values in predictions or labels. Clamping to 0.")
            predictions = np.clip(predictions, 0, None)
            labels = np.clip(labels, 0, None)

        # Decode the predictions and labels
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Log the first sample to help diagnose poor generations
        logger.info(f"Sample pred: {decoded_preds[0]}, label: {decoded_labels[0]}")

        # Normalize the decoded texts
        decoded_preds = [p.strip().lower() for p in decoded_preds]
        decoded_labels = [l.strip().lower() for l in decoded_labels]

        # Exact match only; unparseable predictions become -1 (always wrong)
        class_ids = {"negative": 0, "positive": 1}
        pred_binary = [class_ids.get(p, -1) for p in decoded_preds]
        label_binary = [1 if l == 'positive' else 0 for l in decoded_labels]

        invalid_count = pred_binary.count(-1)
        if invalid_count:
            logger.warning(
                "%d of %d predictions were neither 'positive' nor 'negative'; scored as wrong.",
                invalid_count,
                len(pred_binary),
            )

        # Compute metrics
        acc = accuracy_metric.compute(predictions=pred_binary, references=label_binary)
        f1 = f1_metric.compute(predictions=pred_binary, references=label_binary, average="weighted")

        return {"accuracy": acc["accuracy"], "f1": f1["f1"]}

    except Exception as e:
        # Deliberate best-effort: keep training alive, but record the full
        # traceback so that 0.0 metrics can be traced back to their cause.
        logger.exception(f"Classification metrics error: {e}. Returning defaults.")
        return {"accuracy": 0.0, "f1": 0.0}

def compute_summarization_metrics(eval_pred):
    """Compute ROUGE scores from generated summaries.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` may be a
            tuple (its first element is used) and may be 3-dimensional, in
            which case the argmax over the last axis is taken.

    Returns:
        Dict with the keys ``rouge1``, ``rouge2``, ``rougeL`` and
        ``rougeLsum``. On any exception the error is logged and every score
        is returned as 0.0.
    """
    rouge_keys = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    try:
        preds, refs = eval_pred

        # Unwrap tuple outputs and collapse logits to token ids.
        if isinstance(preds, tuple):
            preds = preds[0]
        if len(preds.shape) == 3:
            preds = np.argmax(preds, axis=-1)

        # -100 marks ignored positions; swap it for the pad id so decoding works.
        pad_id = tokenizer.pad_token_id
        preds = np.where(preds != -100, preds, pad_id)
        refs = np.where(refs != -100, refs, pad_id)

        # Any remaining negative id cannot be decoded; clamp to 0.
        has_negative = np.any(preds < 0) or np.any(refs < 0)
        if has_negative:
            logger.warning("Found negative values in predictions or labels. Clamping to 0.")
            preds = np.clip(preds, 0, None)
            refs = np.clip(refs, 0, None)

        pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
        ref_texts = tokenizer.batch_decode(refs, skip_special_tokens=True)

        # First sample is logged to make poor generations visible.
        logger.info(f"Sample pred: {pred_texts[0]}, label: {ref_texts[0]}")

        # Strip whitespace; blank strings are replaced by the word "empty".
        pred_texts = [text.strip() or "empty" for text in pred_texts]
        ref_texts = [text.strip() or "empty" for text in ref_texts]

        scores = rouge_metric.compute(predictions=pred_texts, references=ref_texts, use_stemmer=True)
        return {key: scores[key] for key in rouge_keys}

    except Exception as e:
        # Best-effort: report the failure and fall back to zero scores.
        logger.error(f"Summarization metrics error: {e}. Returning defaults.")
        return dict.fromkeys(rouge_keys, 0.0)

# In[11]:
# Plot confusion matrix - @TODO: Integrate into main flow
def plot_confusion_matrix(y_true, y_pred, classes=None, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion-matrix heatmap and show it.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        classes: Tick labels used on both axes (passed through to seaborn).
        title: Figure title.
        cmap: Colormap for the heatmap.
    """
    matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(6, 5))
    # heatmap draws on the current axes and returns them; label those directly.
    heatmap_ax = sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap=cmap,
        cbar=False,
        xticklabels=classes,
        yticklabels=classes,
    )
    heatmap_ax.set_ylabel('True label')
    heatmap_ax.set_xlabel('Predicted label')
    heatmap_ax.set_title(title)
    plt.show()

def compute_and_plot_confusion_matrix_classification(decoded_labels, decoded_preds):
    """Plot a 2x2 sentiment confusion matrix from decoded label/prediction texts.

    A text counts as class 1 when it contains the substring 'positive',
    otherwise as class 0.
    NOTE(review): compute_classification_metrics uses an exact, lower-cased
    match instead of this substring test - confirm which rule is intended.
    """

    def _to_binary(texts):
        # Substring test, mapped to 1/0.
        return [int('positive' in text) for text in texts]

    true_ids = _to_binary(decoded_labels)
    pred_ids = _to_binary(decoded_preds)
    plot_confusion_matrix(
        true_ids,
        pred_ids,
        classes=['negative', 'positive'],
        title='Classification Confusion Matrix',
    )

def compute_and_plot_confusion_matrix_summarization(decoded_labels, decoded_preds, tokenizer):
    """Prepare position-aligned token pairs for a token-level confusion matrix.

    Each label/prediction pair is tokenized and truncated to the shorter of
    the two, and the aligned tokens are collected. The plotting call itself
    is currently disabled, so the function only does this preparation and
    returns ``None``.
    """
    reference_tokens = list(map(tokenizer.tokenize, decoded_labels))
    hypothesis_tokens = list(map(tokenizer.tokenize, decoded_preds))

    aligned_true = []
    aligned_pred = []
    for ref, hyp in zip(reference_tokens, hypothesis_tokens):
        # Only positions present in both sequences can be compared.
        shared_len = min(len(ref), len(hyp))
        aligned_true += ref[:shared_len]
        aligned_pred += hyp[:shared_len]

    # Keep at most 10 distinct tokens for visualization.
    all_tokens = list(set(aligned_true + aligned_pred))[:10]

    # Plotting is disabled for now:
    #plot_confusion_matrix(aligned_true, aligned_pred, classes=all_tokens, title='Summarization Token-level Confusion Matrix')

# In[12]:
# TRAINING ARGS 
def get_training_args(method_name, task_name):
    """Build the ``Seq2SeqTrainingArguments`` for one experiment.

    Args:
        method_name: Adaptation method; "lora", "prefix", "prompt" and any
            name containing "_ablated_" are treated as PEFT runs, everything
            else as full fine-tuning.
        task_name: "classification" or "summarization"; selects the epoch
            count for full-size runs and the output sub-directory.

    Returns:
        A configured ``Seq2SeqTrainingArguments`` instance.

    The schedule depends on the module-level ``DATASET_SIZE``: 'full' runs
    evaluate and save once per epoch, sized runs evaluate and save roughly
    five times per run (at most every 50 optimizer steps).
    """
    is_peft = method_name in ["lora", "prefix", "prompt"] or "_ablated_" in method_name
    # PEFT/ablation runs use 3e-4, full fine-tuning 1e-5 (higher rates caused
    # instability and decreasing metrics).
    lr = 3e-4 if is_peft else 1e-5

    # Gradients are accumulated over this many batches per optimizer step.
    grad_accum_steps = 4

    if DATASET_SIZE == 'full':
        # Summarization gets more epochs: its dataset is smaller and needs
        # more passes to converge.
        epochs = 5 if task_name == 'summarization' else 3
        batch = 8
    elif DATASET_SIZE <= 500:
        # Use more epochs for very small datasets to allow for learning
        epochs, batch = 10, 4
    else:
        epochs, batch = 3, 8

    if DATASET_SIZE != 'full':
        # Trainer steps are optimizer updates, so the budget must account for
        # gradient accumulation; ignoring it made eval_steps about 4x too
        # large and small runs were evaluated only once or twice.
        updates_per_epoch = max(1, (DATASET_SIZE // batch) // grad_accum_steps)
        total_steps = updates_per_epoch * epochs
        # Eval about 5 times per run, at least every step and at most every 50
        eval_steps = max(1, min(total_steps // 5, 50))
        logging_steps = max(1, eval_steps // 2)
        save_steps = eval_steps
        eval_strategy = "steps"
        save_strategy = "steps"
    else:
        eval_strategy = "epoch"
        save_strategy = "epoch"
        logging_steps = 100
        save_steps = None
        eval_steps = None

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_fp16 = False  # Disabled to avoid NaN losses
    # Prefix/prompt runs do not use the Trainer's automatic best-model reload.
    load_best = method_name == "full_ft" or "lora" in method_name

    return Seq2SeqTrainingArguments(
        output_dir=f"{OUTPUT_DIR}/results/{task_name}/{method_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=batch * 2,
        learning_rate=lr,
        warmup_steps=min(100, DATASET_SIZE // 10) if DATASET_SIZE != 'full' else 500,
        # Stronger regularization to counter overfitting (loss -> 0 while metrics drop)
        weight_decay=0.1,
        eval_strategy=eval_strategy,
        eval_steps=eval_steps,
        save_strategy=save_strategy,
        save_steps=save_steps,
        load_best_model_at_end=load_best,
        metric_for_best_model="eval_loss",
        save_total_limit=2,
        logging_steps=logging_steps,
        bf16=use_bf16,
        fp16=use_fp16,
        dataloader_num_workers=0,
        dataloader_drop_last=True,  # Avoid incomplete batches for stability
        report_to="none",
        predict_with_generate=True,
        max_grad_norm=1.0,  # Gradient clipping to prevent gradient explosions
        # Accumulation stabilizes training with small per-device batches
        gradient_accumulation_steps=grad_accum_steps,
        # Label smoothing was tried and is currently disabled:
        #label_smoothing_factor=0.1,
        optim='adamw_torch'
    )

# In[13]:
# MAIN TRAINING LOOP 
# Methods always run, plus the ablated variants that are added on demand.
base_methods = ["lora", "prefix", "prompt", "full_ft"]
ablation_methods = ["lora_ablated_alpha0", "prefix_ablated_no_proj", "prompt_ablated_short"]
methods_to_run = list(base_methods)
if RUN_ABLATIONS:
    methods_to_run += ablation_methods

# Task name -> (tokenized dataset, metric function).
tasks = {
    "classification": (tokenized_classification, compute_classification_metrics),
    "summarization": (tokenized_summarization, compute_summarization_metrics)
}

# Collected per-experiment results, keyed by "<method>_<task>".
results = {}
for _subdir in ("results", "models", "plots"):
    os.makedirs(f"{OUTPUT_DIR}/{_subdir}", exist_ok=True)

for method_name in methods_to_run:
    for task_name, (dataset, compute_metrics) in tasks.items():
        print(f"\n{'='*60}")
        print(f"EXPERIMENT: {method_name.upper()} on {task_name.upper()}")
        print(f"{'='*60}\n")
        try:
            config = AutoConfig.from_pretrained(MODEL_NAME)
            use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_NAME,
                config=config,
                dtype=torch.bfloat16 if use_bf16 else torch.float32,
            )

            model.to(device)
            
            # Note: t5-small has correct dims (num_heads=8, head_dim=64); PEFT handles DynamicCache natively.
            # Create PEFT configs dynamically from model.config
            if method_name != "full_ft":
                d_model = model.config.d_model
                num_heads = model.config.num_heads
                total_layers = model.config.num_layers + model.config.num_decoder_layers
                peft_configs_local = {
                    "lora": LoraConfig(
                        r=16,
                        lora_alpha=32,
                        target_modules=["q", "v"],
                        lora_dropout=0.05,
                        bias="none",
                        task_type=TaskType.SEQ_2_SEQ_LM
                    ),
                    "lora_ablated_alpha0": LoraConfig(
                        r=16,
                        lora_alpha=0,  # Ablation: zero scaling, no effect from adapter
                        target_modules=["q", "v"],
                        lora_dropout=0.05,
                        bias="none",
                        task_type=TaskType.SEQ_2_SEQ_LM
                    ),
                    "prefix": PrefixTuningConfig(
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        inference_mode=False,
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                        token_dim=d_model,
                        num_transformer_submodules=2,
                        num_attention_heads=num_heads,
                        num_layers=total_layers,
                        encoder_hidden_size=d_model,
                        prefix_projection=True  # Baseline with projection
                    ),
                    "prefix_ablated_no_proj": PrefixTuningConfig(  # Ablation: Remove projection layer
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        inference_mode=False,
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                        token_dim=d_model,
                        num_transformer_submodules=2,
                        num_attention_heads=num_heads,
                        num_layers=total_layers,
                        encoder_hidden_size=d_model,
                        prefix_projection=False  # Ablated
                    ),
                    "prompt": PromptTuningConfig(
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        prompt_tuning_init="RANDOM"
                    ),
                    "prompt_ablated_short": PromptTuningConfig(  # Ablation: Fewer tokens (e.g., half)
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS // 2,
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        prompt_tuning_init="RANDOM"
                    )
                }
                model = get_peft_model(model, peft_configs_local[method_name])
                model.print_trainable_parameters()
            else:
                trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
                total = sum(p.numel() for p in model.parameters())
                print(f"trainable params: {trainable:,} || all params: {total:,} || trainable%: 100.00")
            
            training_args = get_training_args(method_name, task_name)
            data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["validation"],
                data_collator=data_collator,
                compute_metrics=compute_metrics,
                tokenizer=tokenizer
            )

            print("Training...")
            train_result = trainer.train()
            
            # Manual load best model for prefix/prompt methods
            if not training_args.load_best_model_at_end and trainer.state.best_model_checkpoint:
                print(f"Loading best checkpoint manually: {trainer.state.best_model_checkpoint}")
                base_model = AutoModelForSeq2SeqLM.from_pretrained(
                    MODEL_NAME,
                    config=config,
                    dtype=torch.bfloat16 if use_bf16 else torch.float32,
                )
                base_model.to(device)
                model = PeftModel.from_pretrained(base_model, trainer.state.best_model_checkpoint)
                trainer.model = model
                model.to(device)  # Ensure the full PEFT model is on device
            
            print("Evaluating...")
            test_dataset = dataset.get("test", dataset["validation"])
            gen_kwargs = {
                # CHANGE: For classification, max_length=5; summ=128; num_beams=6 - Why: Short for classification enforces concise labels (fixes verbose outputs/low acc); more beams improves quality (fixes poor ROUGE)
                "max_length": 5 if task_name == "classification" else 128,
                "num_beams": 6,
                "early_stopping": True,
            }
            
            # Set generation kwargs for trainer.evaluate
            training_args.generation_max_length = gen_kwargs["max_length"]
            training_args.generation_num_beams = gen_kwargs["num_beams"]
            test_metrics = trainer.evaluate(test_dataset)
            exp_name = f"{method_name}_{task_name}"
            trainable = model.num_parameters(only_trainable=True) if hasattr(model, 'num_parameters') else sum(p.numel() for p in model.parameters() if p.requires_grad)
            total = model.num_parameters() if hasattr(model, 'num_parameters') else sum(p.numel() for p in model.parameters())
            
            results[exp_name] = {
                "train_metrics": train_result.metrics,
                "test_metrics": test_metrics,
                "trainable_params": trainable,
                "total_params": total,
                "log_history": trainer.state.log_history # Collect for plotting
            }
            
            save_path = f"{OUTPUT_DIR}/models/{task_name}/{method_name}"
            os.makedirs(save_path, exist_ok=True)
            trainer.save_model(save_path)
            print(f"Completed and saved to {save_path}\n")
            del model, trainer
            safe_cleanup()
        except Exception as e:
            logger.error(f"ERROR in {method_name}_{task_name}: {e}")
            import traceback
            logger.error(traceback.format_exc())
            try:
                del model, trainer
            except:
                pass
            safe_cleanup()

# Banner marking the end of the experiment loop (blank line, rule, title, rule).
print("\n" + "=" * 60, "ALL EXPERIMENTS COMPLETED", "=" * 60, sep="\n")

# In[14]:
# RESULTS 
if results:
    # Per-experiment summary: trainable-parameter share plus the task's headline
    # metrics (accuracy/F1 for classification, ROUGE-1/ROUGE-L otherwise).
    print("\nRESULTS SUMMARY:")
    print("="*60)
    for exp_name, exp_data in results.items():
        # Experiment names are "<method>_<task>", and the method part may itself
        # contain underscores (e.g. "<method>_ablated"). Resolve the task by
        # matching the name's suffix against the known task keys (longest match
        # wins) so ablated runs are reported with the metrics of their real task
        # instead of being treated as task "ablated_<task>".
        task = max(
            (t for t in tasks if exp_name.endswith(f"_{t}")),
            key=len,
            default=None,
        )
        if task is not None:
            # Drop the trailing "_<task>" to recover the full method label.
            method = exp_name[:-(len(task) + 1)]
        else:
            # Unknown task: fall back to splitting on the first underscore.
            method, sep, rest = exp_name.partition('_')
            task = rest if sep else 'unknown'

        metrics = exp_data["test_metrics"]
        pct = 100 * exp_data["trainable_params"] / exp_data["total_params"]
        print(f"\n{method.upper()} - {task.capitalize()}:")
        print(f" Trainable: {pct:.2f}%")
        if task == "classification":
            print(f" Accuracy: {metrics.get('eval_accuracy', 0):.4f}")
            print(f" F1: {metrics.get('eval_f1', 0):.4f}")
        else:
            print(f" ROUGE-1: {metrics.get('eval_rouge1', 0):.4f}")
            print(f" ROUGE-L: {metrics.get('eval_rougeL', 0):.4f}")

    # Ablation deltas if enabled: for every "<method>_ablated_<task>" run, print
    # the per-metric difference against its "<method>_<task>" baseline.
    if RUN_ABLATIONS:
        print("\nABLATION DELTAS:")
        for exp_name, exp_data in results.items():
            # Split around the ablation marker. Rebuilding the baseline name from
            # everything after the first underscore would reproduce the ablated
            # name itself and compare the run with itself (all deltas 0).
            base_method, marker, task = exp_name.partition('_ablated_')
            if not marker:
                continue  # not an ablated experiment

            base_method_name = f"{base_method}_{task}"
            if base_method_name not in results:
                continue  # baseline did not run (or failed); nothing to compare

            base_metrics = results[base_method_name]["test_metrics"]
            ablated_metrics = exp_data["test_metrics"]
            # Only the evaluation metrics reported for the baseline are compared.
            delta = {
                k: ablated_metrics.get(k, 0) - base_metrics.get(k, 0)
                for k in base_metrics
                if "eval_" in k
            }
            print(f"Delta for {exp_name.upper()}: {delta}")

    # Plot learning curves for each experiment and remember where each plot was
    # saved so the report can link to it.
    print("\nGenerating learning curves...")
    plot_paths = {}
    plot_save_dir = f"{OUTPUT_DIR}/plots"
    for exp_name, exp_data in results.items():
        # Resolve the task from the known task keys (longest suffix match) so an
        # ablated run "<method>_ablated_<task>" is plotted for "<task>" rather
        # than for the non-existent task "ablated_<task>".
        task_name = max(
            (t for t in tasks if exp_name.endswith(f"_{t}")),
            key=len,
            default=None,
        )
        if task_name is None:
            # Unknown task: keep the previous "everything after the first
            # underscore" behaviour, without raising when there is no underscore.
            task_name = exp_name.split("_", 1)[-1]

        plot_path = plot_learning_curves(exp_data["log_history"], exp_name, task_name, save_dir=plot_save_dir)
        # NOTE(review): plot_learning_curves is defined elsewhere; a falsy return
        # is treated as "no plot produced", mirroring how the ablation
        # comparison plots are handled. Confirm against its implementation.
        if plot_path:
            plot_paths[exp_name] = plot_path
    
    # Graphical ablation comparisons per task: one comparison plot for each task
    # that has at least one experiment whose name ends in "_<task>".
    ablation_plot_paths = {}
    if RUN_ABLATIONS:
        print("\nGenerating ablation comparison plots...")
        for task_name in tasks.keys():
            name_suffix = f"_{task_name}"
            task_results = {}
            for k, v in results.items():
                if k.endswith(name_suffix):
                    task_results[k] = v
            if not task_results:
                continue  # no experiment ran for this task

            ablation_plot_path = plot_ablation_comparisons(
                task_results,
                task_name,
                save_dir=plot_save_dir,
            )
            # Only record plots that were actually produced.
            if ablation_plot_path:
                ablation_plot_paths[task_name] = ablation_plot_path

    # --- Results DataFrame ---
    # One row per experiment: method, task, trainable-parameter share and every
    # numeric test metric. The table is written to CSV and reused by the report.
    results_df = []
    for exp_name, exp_data in results.items():
        # Names are "<method>_<task>" where the method may contain underscores
        # (e.g. "<method>_ablated"). Match the suffix against the known task keys
        # so ablated runs land in the right Task group with a distinct Method
        # label, instead of Method "<METHOD>" / Task "Ablated_<task>".
        task = max(
            (t for t in tasks if exp_name.endswith(f"_{t}")),
            key=len,
            default=None,
        )
        if task is not None:
            method = exp_name[:-(len(task) + 1)]
        else:
            # Unknown task: split on the first underscore; partition() does not
            # raise when the name has no underscore at all.
            method, _, task = exp_name.partition("_")

        numeric_metrics = {
            k: v
            for k, v in exp_data["test_metrics"].items()
            if isinstance(v, (int, float))
        }
        results_df.append({
            "Method": method.upper(),
            "Task": task.capitalize(),
            "Trainable %": 100 * exp_data["trainable_params"] / exp_data["total_params"],
            **numeric_metrics,
        })

    df = pd.DataFrame(results_df)
    # Fixed identification columns first, then the eval_* metrics alphabetically.
    cols = ["Method", "Task", "Trainable %"]
    metric_cols = [c for c in df.columns if c.startswith("eval_")]
    cols.extend(sorted(metric_cols))
    df = df[cols]
    df.to_csv(f"{OUTPUT_DIR}/peft_results.csv", index=False)
    print(f"\nResults saved to '{OUTPUT_DIR}/peft_results.csv'")
    
    # --- Final Report ---
    # Markdown report with the run configuration, the summary table and links to
    # the generated plots.
    # Use relative paths for plots in the markdown report
    report_path = f"{OUTPUT_DIR}/final_report.md"
    report_dir = os.path.dirname(report_path)

    # Explicit UTF-8 so the report does not depend on the platform's default
    # encoding.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("# PEFT Comparison Results - T5-small\n\n")
        f.write("## Configuration\n")
        f.write(f"- Model: {MODEL_NAME} (switched from flan-t5-small to fix config dim bug)\n")
        f.write(f"- Dataset Size: {DATASET_SIZE}\n")
        f.write("- Methods: LoRA, Prefix-Tuning, Prompt-Tuning, Full Fine-Tuning\n")
        if RUN_ABLATIONS:
            f.write("- Ablations: Enabled (including ablated variants); LoRA ablation uses lora_alpha=0 for no adaptation effect\n")
        f.write("- Special: Native DynamicCache support; correct dims (num_heads=8, head_dim=64)\n\n")
        f.write("## Summary Table\n\n")
        f.write(df.to_markdown(index=False))
        f.write("\n\n## Learning Curves\n")
        for exp_name, plot_path in plot_paths.items():
            # os.path.relpath raises on an empty/None path; skip experiments for
            # which no plot file was recorded instead of aborting the report.
            if not plot_path:
                continue
            relative_plot_path = os.path.relpath(plot_path, start=report_dir)
            f.write(f"- [{exp_name}]({relative_plot_path})\n")
        if RUN_ABLATIONS and ablation_plot_paths:
            f.write("\n## Ablation Comparisons\n")
            for task_name, plot_path in ablation_plot_paths.items():
                if not plot_path:
                    continue
                relative_plot_path = os.path.relpath(plot_path, start=report_dir)
                f.write(f"- [{task_name.capitalize()} Ablation Comparison]({relative_plot_path})\n")

    print(f"Report saved to '{report_path}' (includes plot links)")

    # Generate dynamic outcome insights based on results
    print("\nOUTCOME INSIGHTS:")
    if results:
        # General insights from trainable params and metrics
        for task in tasks.keys():
            task_exps = {k: v for k, v in results.items() if k.endswith(task)}
            if task_exps:
                # Find method with lowest trainable %
                min_trainable_method = min(task_exps, key=lambda k: 100 * task_exps[k]["trainable_params"] / task_exps[k]["total_params"])
                min_pct = 100 * task_exps[min_trainable_method]["trainable_params"] / task_exps[min_trainable_method]["total_params"]
                print(f"- For {task.capitalize()}, {min_trainable_method.split('_')[0].upper()} has the lowest trainable params ({min_pct:.2f}%).")
                
                # Find best performing method (use key metric)
                key_metric = 'eval_accuracy' if task == 'classification' else 'eval_rougeL'
                best_method = max(task_exps, key=lambda k: task_exps[k]["test_metrics"].get(key_metric, 0))
                best_score = task_exps[best_method]["test_metrics"].get(key_metric, 0)
                print(f"- {best_method.split('_')[0].upper()} achieves the highest {key_metric.replace('eval_', '').upper()} score ({best_score:.4f}) on {task.capitalize()}.")
        
        # Ablation-specific insights
        if RUN_ABLATIONS:
            for exp_name, exp_data in results.items():
                if "_ablated_" in exp_name:
                    method_task_split = exp_name.split('_ablated_')[0]
                    task = exp_name.split('_', 1)[1] # Get task name
                    base_method_name = f"{method_task_split}_{task}"
                    
                    if base_method_name in results:
                        base_metrics = results[base_method_name]["test_metrics"]
                        delta = {k: exp_data["test_metrics"].get(k, 0) - base_metrics.get(k, 0) for k in base_metrics if "eval_" in k}
                        key_delta = delta.get('eval_accuracy' if task == 'classification' else 'eval_rougeL', 0)
                        impact = "degradation" if key_delta < 0 else "improvement" if key_delta > 0 else "no change"
                        print(f"- Ablation in {exp_name.upper()} leads to {impact} in performance (delta: {key_delta:.4f}).")
        
        print(f"View plots in {OUTPUT_DIR}/plots/ for detailed curves (loss/metric vs step) and comparisons.")
else:
    print("\nNo results were generated. Check the training loop for errors.")

# In[15]:
# Closing banner. The training loop logs and swallows per-experiment errors, so
# success is only claimed for experiments that actually produced results.
print("\n" + "="*60)
if results:
    print(f"SUCCESS - {len(results)} experiment(s) completed!" + (" With ablations!" if RUN_ABLATIONS else ""))
else:
    print("FAILED - no experiment produced results. Check the training loop for errors.")
print("="*60)