In [41]:
import os   
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

In [42]:
import torch
import numpy as np
import random
import pandas as pd
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
import traceback
from datasets import load_dataset
from accelerate import Accelerator

from sklearn.metrics import confusion_matrix
from collections import Counter

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)

from peft import (
    get_peft_model,
    PrefixTuningConfig,
    TaskType,
    PeftModel
)
import logging
import warnings
import json

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ERROR! Session/line number was not unique in database. History logging moved to new session 103


## Configurations

In [43]:
# CONFIGURATION 
MODEL_NAME = "google/flan-t5-small" # flan-t5-small model is giving issues - config dim bug (num_heads=6 mismatch)
SUMMARIZATION_DATASET = "knkarthick/samsum"

BENCHMARK_GLUE="glue"
GLUE_DATASET_TASK_SC = "sst2"  # SST-2 for sentiment classification

DATASET_SIZE = 'full' # 100 or 500 or 'full' 
RUN_ABLATIONS = True  # Toggle to enable/disable ablation study (modular flag)

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)

NUM_VIRTUAL_TOKENS = 32 # CHANGE: Increased from 20 to 50 for better adaptation in prefix/prompt - Why: Longer tokens allow stronger task-specific tuning, fixing weak/flat metrics in prefix/prompt
MAX_POS = 512

OUTPUT_DIR = './kaggle/working_v2/'

In [44]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [45]:
print("="*60)
print("Prefix-Tuning COMPARISON - T5-small")
print("="*60)
print(f"Dataset size: {DATASET_SIZE}")
print(f"Model: {MODEL_NAME}")
print("Methods: Prefix-Tuning")
if RUN_ABLATIONS:
    print("Ablations Enabled: Including ablated variants for study")
    print("Note: For prefix ablation, removing projection layer")
print("="*60)
print()

Prefix-Tuning COMPARISON - T5-small
Dataset size: full
Model: google/flan-t5-small
Methods: Prefix-Tuning
Ablations Enabled: Including ablated variants for study
Note: For prefix ablation, removing projection layer



## Utilities

In [46]:
def limit_dataset_size(dataset, size):
    if size == 'full':
        return dataset
    if isinstance(size, int) and size > 0:
        return dataset.select(range(min(size, len(dataset))))
    raise ValueError(f"Invalid size: {size}")

def setup_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def safe_cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

## Plots

In [47]:
def plot_learning_curves(log_history, exp_name, task_name, save_dir="./plots"):
    """Plot train/eval loss and task-specific metrics vs step."""
    os.makedirs(save_dir, exist_ok=True)
   
    # Extract data
    #steps = [log['step'] for log in log_history if 'step' in log and 'eval_loss' not in log] # Get train steps
    eval_steps = [log['step'] for log in log_history if 'eval_loss' in log] # Get eval steps
    train_losses = [log['loss'] for log in log_history if 'loss' in log] # 'loss' is train loss
    eval_losses = [log['eval_loss'] for log in log_history if 'eval_loss' in log]
   
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    sns.set_style("whitegrid")
   
    # Loss curve
    # Match train loss steps to eval steps for cleaner plots if they differ
    train_steps_for_loss = [log['step'] for log in log_history if 'loss' in log]
    axes[0].plot(train_steps_for_loss, train_losses, label='Train Loss', marker='o', alpha=0.7)
    if eval_losses:
        axes[0].plot(eval_steps, eval_losses, label='Eval Loss', marker='s')
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Loss')
    axes[0].set_title(f'{exp_name} - Loss Curve')
    axes[0].legend()

    # Task-specific metric
    if task_name == "classification":
        eval_accs = [log['eval_accuracy'] for log in log_history if 'eval_accuracy' in log]
        if eval_accs:
            axes[1].plot(eval_steps, eval_accs, label='Eval Accuracy', marker='o', color='green')
            axes[1].set_ylabel('Accuracy')
    else: # summarization
        eval_rouge_ls = [log['eval_rougeL'] for log in log_history if 'eval_rougeL' in log]
        if eval_rouge_ls:
            axes[1].plot(eval_steps, eval_rouge_ls, label='Eval ROUGE-L', marker='o', color='green')
            axes[1].set_ylabel('ROUGE-L')
   
    axes[1].set_xlabel('Step')
    axes[1].set_title(f'{exp_name} - {task_name.capitalize()} Metric')
    axes[1].legend()

    plt.tight_layout()
    plot_path = os.path.join(save_dir, f"{exp_name}_curves.png")
    plt.savefig(plot_path)
    plt.close()
    print(f"Learning curves saved to {plot_path}")
    return plot_path

In [48]:
def plot_ablation_comparisons(results, task_name, save_dir="./plots"):
    """Graphical analysis: Compare baselines vs ablations for a task."""
    os.makedirs(save_dir, exist_ok=True)
    methods = list(results.keys())
    baselines = [m for m in methods if "_ablated_" not in m]
    ablations = [m for m in methods if "_ablated_" in m]
    
    if not ablations:
        return None
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    sns.set_style("whitegrid")
    
    # Trainable params comparison
    trainable_pcts = [100 * results[m]["trainable_params"] / results[m]["total_params"] for m in methods]
    sns.barplot(x=methods, y=trainable_pcts, ax=axes[0])
    axes[0].set_ylabel('Trainable %')
    axes[0].set_title(f'Trainable Params Comparison - {task_name.capitalize()}')
    axes[0].tick_params(axis='x', rotation=45)
    
    # Metric comparison (use key metric)
    if task_name == "classification":
        metrics = [results[m]["test_metrics"].get("eval_accuracy", 0) for m in methods]
        metric_label = 'Accuracy'
    else:
        metrics = [results[m]["test_metrics"].get("eval_rougeL", 0) for m in methods]
        metric_label = 'ROUGE-L'
    
    sns.barplot(x=methods, y=metrics, ax=axes[1])
    axes[1].set_ylabel(metric_label)
    axes[1].set_title(f'Performance Comparison - {task_name.capitalize()}')
    axes[1].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plot_path = os.path.join(save_dir, f"ablation_comparison_{task_name}.png")
    plt.savefig(plot_path)
    plt.close()
    print(f"Ablation comparison plot saved to {plot_path}")
    return plot_path

## Load dataset

In [49]:
## LOAD DATASETS 
print("Loading datasets")
# Summarization dataset - SAMSum
summarization_dataset = load_dataset(SUMMARIZATION_DATASET)

# Load tokenizer
tokenizer = setup_tokenizer(MODEL_NAME)

if DATASET_SIZE != 'full':
    print(f"Limiting dataset size to {DATASET_SIZE} for train.")
    
    summarization_dataset['train'] = limit_dataset_size(summarization_dataset['train'], DATASET_SIZE)
    summarization_dataset['validation'] = limit_dataset_size(summarization_dataset['validation'], DATASET_SIZE // 4)
    summarization_dataset['test'] = limit_dataset_size(summarization_dataset['test'], DATASET_SIZE // 4)

print("Datasets loaded\n")

Loading datasets
Datasets loaded



In [50]:
# Print 10 samples from each train dataset before preprocessing
print("Original Sample Datasets")

print("\nSummarization Train Samples (Before Preprocessing):")
for i in range(min(10, len(summarization_dataset['train']))):
    print(summarization_dataset["train"][i])

Original Sample Datasets

Summarization Train Samples (Before Preprocessing):
{'id': '13818513', 'dialogue': "Amanda: I baked  cookies. Do you want some?\nJerry: Sure!\nAmanda: I'll bring you tomorrow :-)", 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.'}
{'id': '13728867', 'dialogue': 'Olivia: Who are you voting for in this election? \nOliver: Liberals as always.\nOlivia: Me too!!\nOliver: Great', 'summary': 'Olivia and Olivier are voting for liberals in this election. '}
{'id': '13681000', 'dialogue': "Tim: Hi, what's up?\nKim: Bad mood tbh, I was going to do lots of stuff but ended up procrastinating\nTim: What did you plan on doing?\nKim: Oh you know, uni stuff and unfucking my room\nKim: Maybe tomorrow I'll move my ass and do everything\nKim: We were going to defrost a fridge so instead of shopping I'll eat some defrosted veggies\nTim: For doing stuff I recommend Pomodoro technique where u use breaks for doing chores\nTim: It really helps\nKim: thanks, maybe 

## Pre-process

In [51]:
def preprocess_summarization(examples):
    # Create input dialogues with the required prefix
    inputs = [f"Summarize the following conversation:\n{dialogue}" for dialogue in examples["dialogue"]]
    
    # Define max length for inputs
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS
    
    # Tokenize inputs with truncation and padding
    model_inputs = tokenizer(inputs, max_length=max_input_len, truncation=True, padding="max_length")
    
    # Define max length for summaries
    max_label_len = 128

    labels_tokenized = tokenizer(text_target=examples["summary"],
                                 max_length=max_label_len,
                                 truncation=True,
                                 padding="max_length")

    # Convert pad token ids in labels -> -100 so loss ignores padding
    labels = []
    for seq in labels_tokenized["input_ids"]:
        labels.append([tok if tok != tokenizer.pad_token_id else -100 for tok in seq])
    model_inputs["labels"] = labels
    
    return model_inputs

In [52]:
# Apply preprocessing
print("\nApplying preprocessing...")
tokenized_summarization = summarization_dataset.map(preprocess_summarization, batched=True, remove_columns=summarization_dataset["train"].column_names)

# Print samples from each post preprocessing
POST_PROCESS_SAMPLES = 5

print("\nPost-Preprocessing Sample Datasets")


Applying preprocessing...


Map: 100%|██████████████████████████████████████████████████████████████████| 819/819 [00:00<00:00, 1168.05 examples/s]


Post-Preprocessing Sample Datasets





## Decode

In [53]:
# Decode a single example (input + label)
def _decode_example(example: dict, tokenizer, task: str) -> dict:
    """
    Returns a dict with:
        - "input_text"   : the original prompt (e.g. "Classify sentiment: …")
        - "label_text"   : the gold label (positive/negative or the full summary)
        - "input_ids"    : first 30 tokens (for sanity check)
        - "label_ids"    : first 15 tokens of the label
    """
    # 1. Decode the **input** (skip special tokens, keep the prompt)
    input_txt = tokenizer.decode(example["input_ids"], skip_special_tokens=False)
    # remove the padding part after the EOS token
    input_txt = input_txt.split(tokenizer.eos_token)[0] + tokenizer.eos_token

    # 2. Decode the **label**
    # Labels contain -100 for ignored positions → replace with pad token first
    label_ids = [
        tok_id if tok_id != -100 else tokenizer.pad_token_id for tok_id in example["labels"]
    ]
    label_txt = tokenizer.decode(label_ids, skip_special_tokens=True)

    # 3. Short token previews (optional, makes the output tidy)
    input_preview = " ".join(map(str, example["input_ids"][:30]))
    label_preview = " ".join(map(str, label_ids[:15]))

    return {
        "input_text": input_txt,
        "label_text": label_txt,
        "input_ids_preview": input_preview,
        "label_ids_preview": label_preview,
    }

In [54]:
# Print summarisation samples
print("\n=== Summarisation – post-preprocessing (5 examples) ===")
for i, ex in enumerate(tokenized_summarization["train"].select(range(min(POST_PROCESS_SAMPLES, len(tokenized_summarization["train"]))))):
    decoded = _decode_example(ex, tokenizer, task="summarization")
    print(f"\n--- Example {i+1} ---")
    print(f"INPUT  : {decoded['input_text']}")
    print(f"SUMMARY: {decoded['label_text']}")
    # print(f"input_ids  (first 30) : {decoded['input_ids_preview']}")
    # print(f"label_ids  (first 15) : {decoded['label_ids_preview']}")

print("\nPreprocessing complete\n")


=== Summarisation – post-preprocessing (5 examples) ===

--- Example 1 ---
INPUT  : Summarize the following conversation: Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)</s>
SUMMARY: Amanda baked cookies and will bring Jerry some tomorrow.

--- Example 2 ---
INPUT  : Summarize the following conversation: Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great</s>
SUMMARY: Olivia and Olivier are voting for liberals in this election. 

--- Example 3 ---
INPUT  : Summarize the following conversation: Tim: Hi, what's up? Kim: Bad mood tbh, I was going to do lots of stuff but ended up procrastinating Tim: What did you plan on doing? Kim: Oh you know, uni stuff and unfucking my room Kim: Maybe tomorrow I'll move my ass and do everything Kim: We were going to defrost a fridge so instead of shopping I'll eat some defrosted veggies Tim: For doing stuff I recommend Pomodoro technique where u use brea

## Metrics

In [55]:
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
rouge_metric = evaluate.load("rouge")

In [56]:
def compute_summarization_metrics(eval_pred):
    try:
        predictions, labels = eval_pred
        
        # Handling prediction tensors
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        if len(predictions.shape) == 3:
            predictions = np.argmax(predictions, axis=-1)
        
        # Replace -100 in predictions/labels with pad_token_id
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        # Validate predictions and labels for negative values
        if np.any(predictions < 0) or np.any(labels < 0):
            logger.warning(f"Found negative values in predictions or labels. Clamping to 0.")
            predictions = np.clip(predictions, 0, None)
            labels = np.clip(labels, 0, None)
        
        # Decode the predictions and labels
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        
        # CHANGE: Added sample logging for debug - Why: To inspect poor generations causing decreasing ROUGE
        logger.info(f"Sample pred: {decoded_preds[0]}, label: {decoded_labels[0]}")  # Log first sample
        
        # Normalize the decoded texts
        decoded_preds = [p.strip() if p.strip() else "empty" for p in decoded_preds]
        decoded_labels = [l.strip() if l.strip() else "empty" for l in decoded_labels]
        
        # Compute ROUGE scores
        result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        
        # CHANGE: Ensure keys always returned - Why: Fixes empty plots by guaranteeing 'eval_rougeL' in logs
        return {
            "rouge1": result.get("rouge1", 0.0),
            "rouge2": result.get("rouge2", 0.0),
            "rougeL": result.get("rougeL", 0.0),
            "rougeLsum": result.get("rougeLsum", 0.0)
        }
    
    except Exception as e:
        # CHANGE: More verbose error logging - Why: Catches silent failures in metrics computation
        logger.error(f"Summarization metrics error: {e}. Returning defaults.")
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0, "rougeLsum": 0.0}

In [57]:
# Plot confusion matrix - @TODO: Integrate into main flow
def plot_confusion_matrix(y_true, y_pred, classes=None, title='Confusion matrix', cmap=plt.cm.Blues):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, cbar=False, xticklabels=classes, yticklabels=classes)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.title(title)
    plt.show()

def compute_and_plot_confusion_matrix_classification(decoded_labels, decoded_preds):
    # Convert text labels to binary 0/1
    label_binary = [1 if 'positive' in l else 0 for l in decoded_labels]
    pred_binary = [1 if 'positive' in p else 0 for p in decoded_preds]
    plot_confusion_matrix(label_binary, pred_binary, classes=['negative', 'positive'], title='Classification Confusion Matrix')

def compute_and_plot_confusion_matrix_summarization(decoded_labels, decoded_preds, tokenizer):
    # For summarization, generate token-level confusion matrix based on token matches
    label_tokens = [tokenizer.tokenize(l) for l in decoded_labels]
    pred_tokens = [tokenizer.tokenize(p) for p in decoded_preds]

    true_tokens = []
    pred_tokens_flat = []
    for lt, pt in zip(label_tokens, pred_tokens):
        min_len = min(len(lt), len(pt))
        true_tokens.extend(lt[:min_len])
        pred_tokens_flat.extend(pt[:min_len])

    # Limit to top 10 tokens for visualization
    all_tokens = list(set(true_tokens + pred_tokens_flat))
    if len(all_tokens) > 10:
        all_tokens = all_tokens[:10]

## Training Configurations

In [60]:
def get_training_args(method_name, task_name):
    is_peft = method_name in ["prefix"] or "_ablated_" in method_name
    # CHANGE: Lowered LR for PEFT/ablation to 1e-3, Full FT to 1e-4 - Why: High LR caused instability/overfitting/decreasing metrics; matches t5-small recommendations
    
    if DATASET_SIZE == 'full':
        # CHANGE: Increased epochs to 5 for summarization - Why: Smaller dataset needs more passes for convergence, fixing underfitting/low ROUGE
        epochs = 5
        batch, eval_steps = 8, 500
    elif DATASET_SIZE <= 500:
        # Use more epochs for very small datasets to allow for learning
        epochs, batch, eval_steps = 10, 4, 20 # Eval more frequently
    else:
        epochs, batch, eval_steps = 3, 8, 100

    # Adjust steps based on actual dataset size
    if DATASET_SIZE != 'full':
        total_steps = (DATASET_SIZE // batch) * epochs
        # Ensure eval_steps is not 0 and is reasonable
        eval_steps = max(1, min(total_steps // 5, 50)) # Eval 5 times per run, max 50
        logging_steps = max(1, eval_steps // 2)
        save_steps = eval_steps
        eval_strategy = "steps"
        save_strategy = "steps"
    else:
        eval_strategy = "epoch"
        save_strategy = "epoch"
        logging_steps = 100
        save_steps = None
        eval_steps = None

    if "no_proj" in method_name:
        lr = 1e-2  # High LR for ablation
    elif is_peft:
        lr = 5e-3 # CHANGED: Increased from 5e-4 to 1e-2. Soft prompts need high LR.
        epochs = 20
    else:
        lr = 1e-4
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    # CHANGE: Set fp16=True if not bf16 - Why: Faster training/mixed precision, fixing slow runs/low metrics if GPU supports
    use_fp16 = not use_bf16 and torch.cuda.is_available()  # Enable fp16 on CUDA if bf16 unavailable
    
    # For prompt tuning in PEFT can cause errors
    load_best = False
    
    return Seq2SeqTrainingArguments(
        output_dir=f"{OUTPUT_DIR}/results/{task_name}/{method_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=batch * 2,
        learning_rate=lr,
        # CHANGE: Increased warmup_steps to 1000 - Why: Smoother optimization start, fixing oscillation/stuck loss in full FT/ablations
        #warmup_steps=1000 if DATASET_SIZE == 'full' else min(100, DATASET_SIZE // 10),
        # CHANGE: Increased weight_decay to 0.1 - Why: Stronger regularization prevents overfitting, fixing loss→0 but metrics drop
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        eval_strategy=eval_strategy,
        eval_steps=eval_steps,
        save_strategy=save_strategy,
        save_steps=save_steps,
        load_best_model_at_end=load_best,
        metric_for_best_model="rougeLsum",
        save_total_limit=2,
        logging_steps=logging_steps,
        bf16=use_bf16,
        fp16=use_fp16,
        dataloader_num_workers=0,
        dataloader_drop_last=True, # Avoid incomplete batches for stability
        report_to="none",
        predict_with_generate=True,
        max_grad_norm=1.0,  # Added to prevent gradient explosions
        # CHANGE: Added gradient_accumulation_steps=4 - Why: Stabilizes training with small effective batches, fixing oscillation in ablations
        gradient_accumulation_steps=4,
        label_smoothing_factor=0.1,
        # CHANGE: Set optim to 'adamw_torch' - Why: More robust for PEFT, fixing instability in ablations/Full FT
        optim='adamw_torch',
        # CHANGE: Set gradient_checkpointing=False - Why: Avoids grad flow issues in PEFT/T5, fixing "no grad_fn" error; trade memory for stability
        gradient_checkpointing=False
    )

ERROR! Session/line number was not unique in database. History logging moved to new session 104


In [61]:
class CustomTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = Accelerator().unwrap_model(model)
            if isinstance(unwrapped_model, PeftModel):
                model_base = unwrapped_model.base_model
                if hasattr(model_base, "model"):
                    model_name = model_base.model._get_name()
                else:
                    model_name = model_base._get_name()
                if any(name in model_name for name in ["GPT", "opt", "bloom", "llama", "gemma"]):
                    loss = self.label_smoother(outputs, labels, shift_labels=True)
                else:
                    loss = self.label_smoother(outputs, labels)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

## Training loop

In [62]:
# MAIN TRAINING LOOP 
base_methods = ["prefix"]
ablation_methods = ["prefix_ablated_no_proj"] if RUN_ABLATIONS else []
methods_to_run = base_methods + ablation_methods
tasks = {
    "summarization": (tokenized_summarization, compute_summarization_metrics)
}

results = {}
os.makedirs(f"{OUTPUT_DIR}/results", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/models", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/plots", exist_ok=True) 

for method_name in methods_to_run:
    for task_name, (dataset, compute_metrics) in tasks.items():
        print(f"\n{'='*60}")
        print(f"EXPERIMENT: {method_name.upper()} on {task_name.upper()}")
        print(f"{'='*60}\n")
        try:
            config = AutoConfig.from_pretrained(MODEL_NAME)
            use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_NAME,
                config=config,
                torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
            )

            model.to(device)
            
            # Note: t5-small has correct dims (num_heads=8, head_dim=64); PEFT handles DynamicCache natively.
            # Create PEFT configs dynamically from model.config
            d_model = model.config.d_model
            num_heads = model.config.num_heads
            total_layers = model.config.num_layers
            effective_token_dim = num_heads * model.config.d_kv
            peft_configs_local = {
                "prefix": PrefixTuningConfig(
                    task_type=TaskType.SEQ_2_SEQ_LM,
                    inference_mode=False,
                    num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                    token_dim=effective_token_dim,
                    num_transformer_submodules=2,
                    num_attention_heads=num_heads,
                    num_layers=total_layers,
                    encoder_hidden_size=d_model,
                    prefix_projection=True  # Baseline with projection
                ),
                "prefix_ablated_no_proj": PrefixTuningConfig(  # Ablation: Remove projection layer
                    task_type=TaskType.SEQ_2_SEQ_LM,
                    inference_mode=False,
                    num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                    token_dim=effective_token_dim,
                    num_transformer_submodules=2,
                    num_attention_heads=num_heads,
                    num_layers=total_layers,
                    encoder_hidden_size=d_model,
                    prefix_projection=False  # Ablated
                )
            }
            model = get_peft_model(model, peft_configs_local[method_name])
            model.print_trainable_parameters()
            
            training_args = get_training_args(method_name, task_name)
            data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
            trainer = CustomTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["validation"],
                data_collator=data_collator,
                compute_metrics=compute_metrics,
                tokenizer=tokenizer
            )

            print("Training...")
            train_result = trainer.train()
            
            # Manual load best model for prefix/prompt methods
            if not training_args.load_best_model_at_end and trainer.state.best_model_checkpoint:
                print(f"Loading best checkpoint manually: {trainer.state.best_model_checkpoint}")
                base_model = AutoModelForSeq2SeqLM.from_pretrained(
                    MODEL_NAME,
                    config=config,
                    torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
                )
                base_model.to(device)
                model = PeftModel.from_pretrained(base_model, trainer.state.best_model_checkpoint)
                trainer.model = model
                model.to(device)  # Ensure the full PEFT model is on device
            
            print("Evaluating...")
            test_dataset = dataset.get("test", dataset["validation"])
            gen_kwargs = {
                # CHANGE: For classification, max_length=5; summ=128; num_beams=6 - Why: Short for classification enforces concise labels (fixes verbose outputs/low acc); more beams improves quality (fixes poor ROUGE)
                "max_length": 128,
                "num_beams": 4,
                # "repetition_penalty": 2.5,
                # "no_repeat_ngram_size": 3,
                "early_stopping": True,
                "do_sample": False,
                # "top_p": 0.95,
                # "temperature": 0.7
            }
            
            # Set generation kwargs for trainer.evaluate
            training_args.generation_max_length = gen_kwargs["max_length"]
            training_args.generation_num_beams = gen_kwargs["num_beams"]
            test_metrics = trainer.evaluate(test_dataset)
            # CHANGE: Added trainer.predict for sample logging post-eval - Why: Debugs generations, fixing empty/low metrics
            predictions = trainer.predict(dataset["validation"])
            # CHANGE: Clean predictions before decoding - Why: Handles -100/invalid IDs, fixing OverflowError in batch_decode
            cleaned_predictions = np.where(predictions.predictions != -100, predictions.predictions, tokenizer.pad_token_id)
            cleaned_predictions = np.clip(cleaned_predictions, 0, tokenizer.vocab_size - 1)
            logger.info(f"Sample generations: {tokenizer.batch_decode(cleaned_predictions[:5], skip_special_tokens=True)}")
            exp_name = f"{method_name}_{task_name}"
            trainable = model.num_parameters(only_trainable=True) if hasattr(model, 'num_parameters') else sum(p.numel() for p in model.parameters() if p.requires_grad)
            total = model.num_parameters() if hasattr(model, 'num_parameters') else sum(p.numel() for p in model.parameters())
            
            results[exp_name] = {
                "train_metrics": train_result.metrics,
                "test_metrics": test_metrics,
                "trainable_params": trainable,
                "total_params": total,
                "log_history": trainer.state.log_history # Collect for plotting
            }
            
            save_path = f"{OUTPUT_DIR}models/{task_name}/{method_name}"
            os.makedirs(save_path, exist_ok=True)
            trainer.save_model(save_path)
            print(f"Completed and saved to {save_path}\n")
            del model, trainer
            safe_cleanup()
        except Exception as e:
            logger.error(f"ERROR in {method_name}_{task_name}: {e}")
            logger.error(traceback.format_exc())
            try:
                del model, trainer
            except:
                pass
            safe_cleanup()


EXPERIMENT: PREFIX on SUMMARIZATION

trainable params: 3,361,280 || all params: 80,322,432 || trainable%: 4.1847
Training...


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,3.4681,3.47511,0.072427,0.001091,0.066198,0.066153
2,3.7129,3.534191,0.074126,0.0037,0.065851,0.065656
3,6.7083,6.644226,0.00771,0.000223,0.00759,0.007546
4,6.4541,6.670243,0.009754,0.000187,0.009701,0.009692
5,5.8377,4.511928,0.051689,0.000457,0.049105,0.048851
6,4.6174,4.028275,0.054418,0.000168,0.05441,0.05429
7,3.8046,3.415081,0.064918,0.000707,0.062715,0.062717
8,3.9893,3.529185,0.067315,7.9e-05,0.063905,0.063847
9,3.7299,3.779026,0.00048,0.0,0.00048,0.000479
10,3.6571,3.36015,0.053936,0.0,0.05131,0.051277


INFO:__main__:Sample pred:                    , label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred:   s         is  Tom Tom Tom   , label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: ssss .... .........., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred:                    , label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: sss                , labe

Loading best checkpoint manually: ./kaggle/working_v2//results/summarization/prefix\checkpoint-461
Evaluating...


INFO:__main__:Sample pred:          for     and and is and  l and      Larry and    a a a a a a a a a a a and a a a and Larry calls Larry calls Larry called Larry called Larry called Larry called Larry called Larry called Larry called Larry called Larry called Larry called Larry calls Larry and Larry called Larry called Larry called her last time to her last time a her last time a to her last time a to her last time they a a Larry., label: Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred:          for                      ., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample generations: ['         for                      .', '         for Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma Emma

Completed and saved to ./kaggle/working_v2/models/summarization/prefix


EXPERIMENT: PREFIX_ABLATED_NO_PROJ on SUMMARIZATION

trainable params: 196,608 || all params: 77,157,760 || trainable%: 0.2548
Training...


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,3.41,3.240115,0.29926,0.099158,0.235816,0.235502
2,3.368,3.197887,0.29059,0.092733,0.231382,0.231657
3,3.357,3.186075,0.292785,0.091461,0.232314,0.232355
4,3.3512,3.182191,0.290766,0.090851,0.231529,0.231805
5,3.3354,3.180396,0.290458,0.090104,0.231682,0.231754


INFO:__main__:Sample pred: B., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: B: Is as in the animal shelter., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: B: Is as in the ., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: B: Is as in the ., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: B: Is as in the ., label: A will go to the ani

Loading best checkpoint manually: ./kaggle/working_v2//results/summarization/prefix_ablated_no_proj\checkpoint-461
Evaluating...


INFO:__main__:Sample pred: Betty's number is at the park with Betty has Betty's Betty's has Betty has Betty has Betty has Betty's number is not found. Amanda's Betty's's number. Larry's number Larry's's number Larry's's phone number. Amanda't., label: Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.
INFO:absl:Using default tokenizer.
INFO:__main__:Sample pred: B., label: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
INFO:absl:Using default tokenizer.
INFO:__main__:Sample generations: ['B.', 'Lauren wants for her kids are filled with chocolates are all the advent calendar is going to buy an advent calendar is looking for the advent calendar is looking for Emma is looking for her children.', "Jackie doesn's is pregnant Madison is pregnant, Jackie doesn's pregnant Madison is pregnant Madison is pregnant Madison is pregnant Madison is pregnant Madison is preg

Completed and saved to ./kaggle/working_v2/models/summarization/prefix_ablated_no_proj



In [63]:
if results:
    print("\nRESULTS SUMMARY:")
    print("="*60)
    for exp_name, exp_data in results.items():
        # Handle cases where task name might have underscores
        method_task_split = exp_name.split('_', 1)
        method = method_task_split[0]
        task = method_task_split[1] if len(method_task_split) > 1 else 'unknown'
        
        metrics = exp_data["test_metrics"]
        pct = 100 * exp_data["trainable_params"] / exp_data["total_params"]
        print(f"\n{method.upper()} - {task.capitalize()}:")
        print(f" Trainable: {pct:.2f}%")
        if task == "classification":
            print(f" Accuracy: {metrics.get('eval_accuracy', 0):.4f}")
            print(f" F1: {metrics.get('eval_f1', 0):.4f}")
        else:
            print(f" ROUGE-1: {metrics.get('eval_rouge1', 0):.4f}")
            print(f" ROUGE-L: {metrics.get('eval_rougeL', 0):.4f}")

    # Ablation deltas if enabled
    if RUN_ABLATIONS:
        print("\nABLATION DELTAS:")
        for exp_name, exp_data in results.items():
            if "_ablated_" in exp_name:
                method_task_split = exp_name.split('_ablated_')[0]
                task = exp_name.split('_', 1)[1] # Get task name
                base_method_name = f"{method_task_split}_{task}"
                
                if base_method_name in results:
                    base_metrics = results[base_method_name]["test_metrics"]
                    delta = {k: exp_data["test_metrics"].get(k, 0) - base_metrics.get(k, 0) for k in base_metrics if "eval_" in k}
                    print(f"Delta for {exp_name.upper()}: {delta}")

    # Plot learning curves for each experiment
    print("\nGenerating learning curves...")
    plot_paths = {}
    plot_save_dir = f"{OUTPUT_DIR}/plots" # [FIX] Define plot save dir
    for exp_name, exp_data in results.items():
        task_name = exp_name.split("_", 1)[1]
        # [FIX] Pass the correct save_dir to the plotting function
        plot_path = plot_learning_curves(exp_data["log_history"], exp_name, task_name, save_dir=plot_save_dir)
        plot_paths[exp_name] = plot_path
    
    # Graphical ablation comparisons per task
    ablation_plot_paths = {}
    if RUN_ABLATIONS:
        print("\nGenerating ablation comparison plots...")
        for task_name in tasks.keys():
            task_results = {k: v for k, v in results.items() if k.endswith(f"_{task_name}")}
            if task_results:
                # [FIX] Pass the correct save_dir to the plotting function
                ablation_plot_path = plot_ablation_comparisons(task_results, task_name, save_dir=plot_save_dir)
                if ablation_plot_path:
                    ablation_plot_paths[task_name] = ablation_plot_path

    # --- Results DataFrame ---
    results_df = []
    for exp_name, exp_data in results.items():
        method, task = exp_name.split("_", 1)
        results_df.append({
            "Method": method.upper(),
            "Task": task.capitalize(),
            "Trainable %": 100 * exp_data["trainable_params"] / exp_data["total_params"],
            **{k: v for k, v in exp_data["test_metrics"].items() if isinstance(v, (int, float))}
        })
    
    df = pd.DataFrame(results_df)
    cols = ["Method", "Task", "Trainable %"]
    metric_cols = [c for c in df.columns if c.startswith("eval_")]
    cols.extend(sorted(metric_cols))
    df = df[cols]
    df.to_csv(f"{OUTPUT_DIR}/prefix_results.csv", index=False)
    print(f"\nResults saved to '{OUTPUT_DIR}/prefix_results.csv'")
    
    # --- Final Report --- 
    # Use relative paths for plots in the markdown report
    report_path = f"{OUTPUT_DIR}/prefix_final_report.md"
    report_dir = os.path.dirname(report_path)

    with open(report_path, "w") as f:
        f.write(f"# Prefix-Tuning Adaptation Results - T5-small\n\n")
        f.write(f"## Configuration\n")
        f.write(f"- Model: {MODEL_NAME} (switched from flan-t5-small to fix config dim bug)\n")
        f.write(f"- Dataset Size: {DATASET_SIZE}\n")
        f.write(f"- Methods: Prefix-Tuning\n")
        if RUN_ABLATIONS:
            f.write(f"- Ablations: Enabled (including ablated variants); prefix ablation removes projection layer\n")
        f.write(f"- Special: Native DynamicCache support; correct dims (num_heads=8, head_dim=64)\n\n")
        f.write(f"## Summary Table\n\n")
        f.write(df.to_markdown(index=False))
        f.write("\n\n## Learning Curves\n")
        for exp_name, plot_path in plot_paths.items():
            relative_plot_path = os.path.relpath(plot_path, start=report_dir)
            f.write(f"- [{exp_name}]({relative_plot_path})\n")
        if RUN_ABLATIONS and ablation_plot_paths:
            f.write("\n## Ablation Comparisons\n")
            for task_name, plot_path in ablation_plot_paths.items():
                relative_plot_path = os.path.relpath(plot_path, start=report_dir)
                f.write(f"- [{task_name.capitalize()} Ablation Comparison]({relative_plot_path})\n")
    
    print(f"Report saved to '{report_path}' (includes plot links)")

    # Generate dynamic outcome insights based on results
    print("\nOUTCOME INSIGHTS:")
    if results:
        # General insights from trainable params and metrics
        for task in tasks.keys():
            task_exps = {k: v for k, v in results.items() if k.endswith(task)}
            if task_exps:
                # Find method with lowest trainable %
                min_trainable_method = min(task_exps, key=lambda k: 100 * task_exps[k]["trainable_params"] / task_exps[k]["total_params"])
                min_pct = 100 * task_exps[min_trainable_method]["trainable_params"] / task_exps[min_trainable_method]["total_params"]
                print(f"- For {task.capitalize()}, {min_trainable_method.split('_')[0].upper()} has the lowest trainable params ({min_pct:.2f}%).")
                
                # Find best performing method (use key metric)
                key_metric = 'eval_accuracy' if task == 'classification' else 'eval_rougeL'
                best_method = max(task_exps, key=lambda k: task_exps[k]["test_metrics"].get(key_metric, 0))
                best_score = task_exps[best_method]["test_metrics"].get(key_metric, 0)
                print(f"- {best_method.split('_')[0].upper()} achieves the highest {key_metric.replace('eval_', '').upper()} score ({best_score:.4f}) on {task.capitalize()}.")
        
        # Ablation-specific insights
        if RUN_ABLATIONS:
            for exp_name, exp_data in results.items():
                if "_ablated_" in exp_name:
                    method_task_split = exp_name.split('_ablated_')[0]
                    task = exp_name.split('_', 1)[1] # Get task name
                    base_method_name = f"{method_task_split}_{task}"
                    
                    if base_method_name in results:
                        base_metrics = results[base_method_name]["test_metrics"]
                        delta = {k: exp_data["test_metrics"].get(k, 0) - base_metrics.get(k, 0) for k in base_metrics if "eval_" in k}
                        key_delta = delta.get('eval_accuracy' if task == 'classification' else 'eval_rougeL', 0)
                        impact = "degradation" if key_delta < 0 else "improvement" if key_delta > 0 else "no change"
                        print(f"- Ablation in {exp_name.upper()} leads to {impact} in performance (delta: {key_delta:.4f}).")
        
        print(f"View plots in {OUTPUT_DIR}/plots/ for detailed curves (loss/metric vs step) and comparisons.")
else:
    print("\nNo results were generated. Check the training loop for errors.")

# In[15]:
print("\n" + "="*60)
print("SUCCESS - Prefix-Tuning method completed!" + (" With ablations!" if RUN_ABLATIONS else ""))
print("="*60)


RESULTS SUMMARY:

PREFIX - Summarization:
 Trainable: 0.00%
 ROUGE-1: 0.1155
 ROUGE-L: 0.0935

PREFIX - Ablated_no_proj_summarization:
 Trainable: 0.00%
 ROUGE-1: 0.2645
 ROUGE-L: 0.2096

ABLATION DELTAS:
Delta for PREFIX_ABLATED_NO_PROJ_SUMMARIZATION: {'eval_loss': 0.0, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 0.0, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0}

Generating learning curves...
Learning curves saved to ./kaggle/working_v2//plots\prefix_summarization_curves.png
Learning curves saved to ./kaggle/working_v2//plots\prefix_ablated_no_proj_summarization_curves.png

Generating ablation comparison plots...
Ablation comparison plot saved to ./kaggle/working_v2//plots\ablation_comparison_summarization.png

Results saved to './kaggle/working_v2//prefix_results.csv'
Report saved to './kaggle/working_v2//prefix_final_report.md' (includes plot links)

OUTCOME INSIGHTS:
- For Summarization, PREFIX has the lowest 