# In [None]:

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

import torch
import numpy as np
import pandas as pd
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)

from peft import (
    get_peft_model,
    LoraConfig,
    PrefixTuningConfig,
    PromptTuningConfig,
    TaskType,
    PeftModel
)
import logging
import warnings
import json # For saving log_history if needed

# Silence library warnings and route this script's own messages through logging.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# KAGGLE TOGGLE
# The run is treated as a Kaggle kernel when KAGGLE_KERNEL_RUN_TYPE is set.
IS_KAGGLE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', None) is not None
if IS_KAGGLE:
    # NOTE(review): this install runs after the imports at the top of the file,
    # so modules that are already imported keep their loaded version until the
    # kernel is restarted -- confirm the pinned versions are actually in use.
    os.system('pip install transformers==4.57.1 peft==0.17.1 datasets==4.3.0 torch==2.9.0 evaluate==0.4.6 rouge-score==0.1.2 scikit-learn==1.7.2 accelerate==1.11.0 matplotlib==3.10.7 seaborn==0.13.2 wandb==0.22.3 tabulate==0.9.0 --no-deps')

# DEVICE DETECTION
# Prefer CUDA wherever it is available (not only on Kaggle), then Apple MPS,
# then CPU. The previous expression skipped CUDA entirely for non-Kaggle runs.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# CONFIGURATION
MODEL_NAME = "t5-small" # Switched from flan-t5-small to avoid config dim bug (num_heads=6 mismatch)
DATASET_SIZE = 500 # or 'full'
RUN_ABLATIONS = True  # Toggle to enable/disable ablation study (modular flag)

# Seed torch and NumPy so repeated runs start from the same random state.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Encoder inputs are truncated to MAX_POS - NUM_VIRTUAL_TOKENS tokens so that
# prepended virtual tokens still fit inside MAX_POS positions.
NUM_VIRTUAL_TOKENS = 20 # For truncation safety
MAX_POS = 512

print("="*60)
print("PEFT COMPARISON - T5-small")
print("="*60)
print(f"Dataset size: {DATASET_SIZE}")
print(f"Model: {MODEL_NAME}")
print("Methods: LoRA, Prefix-Tuning, Prompt-Tuning, Full FT")
if RUN_ABLATIONS:
    print("Ablations Enabled: Including ablated variants for study")
    print("Note: For LoRA ablation, using lora_alpha=0 to nullify adapter effect")
print("="*60)
print()
# UTILITIES 
def limit_dataset_size(dataset, size):
    """Return `dataset` untouched for 'full', otherwise its first `size` rows.

    Args:
        dataset: Object exposing `__len__` and `select(indices)`.
        size: The string 'full' or a positive integer row count.

    Raises:
        ValueError: If `size` is neither 'full' nor a positive integer.
    """
    if size == 'full':
        return dataset
    if not (isinstance(size, int) and size > 0):
        raise ValueError(f"Invalid size: {size}")
    # Never request more rows than the dataset actually holds.
    row_count = min(size, len(dataset))
    return dataset.select(range(row_count))
def setup_tokenizer(model_name):
    """Load the pretrained tokenizer for `model_name`.

    When the tokenizer defines no pad token, its EOS token is reused for
    padding so that padded batches can still be built.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    has_pad_token = tok.pad_token is not None
    if not has_pad_token:
        tok.pad_token = tok.eos_token
    return tok
def safe_cleanup():
    """Release cached accelerator memory between experiments.

    With CUDA available the CUDA cache is emptied and the device synchronized;
    otherwise, when the selected `device` is MPS, the MPS cache is emptied.
    On plain CPU this does nothing.
    """
    if not torch.cuda.is_available():
        if device.type == 'mps':
            torch.mps.empty_cache()
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def plot_learning_curves(log_history, exp_name, task_name, save_dir="./plots"):
    """Plot train/eval loss and the task-specific eval metric against step.

    Args:
        log_history: List of log dicts (e.g. `trainer.state.log_history`).
            Only entries that carry a 'step' key are plotted.
        exp_name: Experiment name; used in plot titles and the file name.
        task_name: "classification" plots 'eval_accuracy'; any other value
            plots 'eval_rougeL'.
        save_dir: Directory the PNG is written to (created if missing).

    Returns:
        Path of the saved PNG file.
    """
    os.makedirs(save_dir, exist_ok=True)

    def series(key):
        """Return (steps, values) taken from the entries that contain `key`."""
        points = [
            (entry['step'], entry[key])
            for entry in log_history
            if key in entry and 'step' in entry
        ]
        return [step for step, _ in points], [value for _, value in points]

    # Each series is paired with the step of the entry it came from. Slicing
    # one shared list of all steps (as before) shifted every curve onto the
    # first N logging steps instead of the steps where it was recorded.
    # The Trainer logs the running training loss under 'loss'; 'train_loss'
    # only appears once, in the end-of-training summary, so it is a fallback.
    train_steps, train_losses = series('loss')
    if not train_losses:
        train_steps, train_losses = series('train_loss')
    eval_steps, eval_losses = series('eval_loss')

    # Set the style before the axes are created so it applies to them.
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Loss curve
    if train_losses:
        axes[0].plot(train_steps, train_losses, label='Train Loss', marker='o')
    if eval_losses:
        axes[0].plot(eval_steps, eval_losses, label='Eval Loss', marker='s')
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Loss')
    axes[0].set_title(f'{exp_name} - Loss Curve')

    # Task-specific metric
    if task_name == "classification":
        metric_key, metric_label = 'eval_accuracy', 'Accuracy'
    else: # summarization
        metric_key, metric_label = 'eval_rougeL', 'ROUGE-L'
    metric_steps, metric_values = series(metric_key)
    if metric_values:
        axes[1].plot(metric_steps, metric_values, label=f'Eval {metric_label}', marker='o', color='green')
        axes[1].set_ylabel(metric_label)

    axes[1].set_xlabel('Step')
    axes[1].set_title(f'{exp_name} - {task_name.capitalize()} Metric')

    # Only draw a legend on axes that actually have labelled curves.
    for ax in axes:
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

    plt.tight_layout()
    plot_path = f"{save_dir}/{exp_name}_curves.png"
    plt.savefig(plot_path)
    plt.close()
    print(f"✓ Learning curves saved to {plot_path}")
    return plot_path


def plot_ablation_comparisons(results, task_name, save_dir="./plots"):
    """Draw side-by-side bar charts comparing every experiment of one task.

    The left panel shows the trainable-parameter percentage per experiment and
    the right panel the key metric ('eval_accuracy' for "classification",
    'eval_rougeL' otherwise). Nothing is drawn, and None is returned, when
    `results` contains no key with "_ablated_" in it.

    Args:
        results: Mapping of experiment name to a dict with "trainable_params",
            "total_params" and "test_metrics".
        task_name: Task the experiments belong to; used for titles/file name.
        save_dir: Directory the PNG is written to (created if missing).

    Returns:
        Path of the saved PNG, or None when there are no ablation experiments.
    """
    os.makedirs(save_dir, exist_ok=True)
    exp_names = list(results)
    if not any("_ablated_" in name for name in exp_names):
        return None

    fig, (params_ax, metric_ax) = plt.subplots(1, 2, figsize=(14, 6))
    sns.set_style("whitegrid")
    task_title = task_name.capitalize()

    # Left panel: share of parameters that are trainable.
    trainable_share = [
        100 * results[name]["trainable_params"] / results[name]["total_params"]
        for name in exp_names
    ]
    sns.barplot(x=exp_names, y=trainable_share, ax=params_ax)
    params_ax.set_ylabel('Trainable %')
    params_ax.set_title(f'Trainable Params Comparison - {task_title}')
    params_ax.tick_params(axis='x', rotation=45)

    # Right panel: key metric of the task (0 when an experiment lacks it).
    if task_name == "classification":
        metric_key, metric_label = "eval_accuracy", 'Accuracy'
    else:
        metric_key, metric_label = "eval_rougeL", 'ROUGE-L'
    scores = [results[name]["test_metrics"].get(metric_key, 0) for name in exp_names]
    sns.barplot(x=exp_names, y=scores, ax=metric_ax)
    metric_ax.set_ylabel(metric_label)
    metric_ax.set_title(f'Performance Comparison - {task_title}')
    metric_ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plot_path = f"{save_dir}/ablation_comparison_{task_name}.png"
    plt.savefig(plot_path)
    plt.close()
    print(f"✓ Ablation comparison plot saved to {plot_path}")
    return plot_path


# LOAD DATASETS 
print("Loading datasets...")
classification_dataset = load_dataset("glue", "sst2")
summarization_dataset = load_dataset("knkarthick/samsum")
tokenizer = setup_tokenizer(MODEL_NAME)

# The GLUE SST-2 test split is published without gold labels (label == -1).
# preprocess_classification maps every non-zero label to "positive", so
# evaluating on that split would score against all-"positive" references.
# Carve a labelled test set out of the validation split instead: first half
# stays validation, second half becomes test.
sst2_validation = classification_dataset['validation']
sst2_midpoint = len(sst2_validation) // 2
classification_dataset['validation'] = sst2_validation.select(range(sst2_midpoint))
classification_dataset['test'] = sst2_validation.select(range(sst2_midpoint, len(sst2_validation)))

if DATASET_SIZE != 'full':
    # Keep at least one evaluation example; DATASET_SIZE // 4 is 0 for sizes
    # below 4 and limit_dataset_size rejects non-positive sizes.
    eval_size = max(1, DATASET_SIZE // 4)
    for corpus in (classification_dataset, summarization_dataset):
        corpus['train'] = limit_dataset_size(corpus['train'], DATASET_SIZE)
        corpus['validation'] = limit_dataset_size(corpus['validation'], eval_size)
        corpus['test'] = limit_dataset_size(corpus['test'], eval_size)
print("✓ Datasets loaded\n")
# PREPROCESSING 
def preprocess_classification(examples):
    """Tokenize a batch of SST-2 examples into text-to-text model inputs.

    Each sentence is prefixed with "Classify sentiment: " and the integer
    label is rendered as the target word "negative" (label 0) or "positive"
    (any other label).

    Args:
        examples: Batch dict with "sentence" and "label" lists.

    Returns:
        The tokenizer output for the inputs with an added "labels" entry.
    """
    inputs = [f"Classify sentiment: {text}" for text in examples["sentence"]]
    # Leave room for the virtual tokens that prompt/prefix tuning prepend.
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS
    model_inputs = tokenizer(inputs, max_length=max_input_len, truncation=True, padding="max_length")
    labels_text = ["negative" if label == 0 else "positive" for label in examples["label"]]
    labels = tokenizer(text_target=labels_text, max_length=10, truncation=True, padding="max_length")
    # Replace label padding with -100 so the loss ignores it. Left as pad ids,
    # most of the 10 target positions would be padding and would dominate the
    # loss. The metric functions already map -100 back to the pad id.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs
def preprocess_summarization(examples):
    """Tokenize a batch of SAMSum examples into text-to-text model inputs.

    Each dialogue is prefixed with a summarization instruction; the reference
    summary becomes the target sequence.

    Args:
        examples: Batch dict with "dialogue" and "summary" lists.

    Returns:
        The tokenizer output for the inputs with an added "labels" entry.
    """
    inputs = [f"Summarize the following conversation:\n{dialogue}" for dialogue in examples["dialogue"]]
    # Leave room for the virtual tokens that prompt/prefix tuning prepend.
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS
    model_inputs = tokenizer(inputs, max_length=max_input_len, truncation=True, padding="max_length")
    max_label_len = 128 - NUM_VIRTUAL_TOKENS
    label_ids = tokenizer(text_target=examples["summary"], max_length=max_label_len, truncation=True, padding="max_length").input_ids
    # Replace label padding with -100 so the loss is computed on real summary
    # tokens only. The metric functions already map -100 back to the pad id.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in sequence]
        for sequence in label_ids
    ]
    return model_inputs
# Tokenize every split of both corpora. The original columns of the train
# split are removed so only the tokenizer outputs and labels remain.
tokenized_classification = classification_dataset.map(
    preprocess_classification,
    batched=True,
    remove_columns=classification_dataset["train"].column_names,
)
tokenized_summarization = summarization_dataset.map(
    preprocess_summarization,
    batched=True,
    remove_columns=summarization_dataset["train"].column_names,
)
print("✓ Preprocessing complete\n")

# METRICS
# Loaded once at module level and shared by the compute_*_metrics functions.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
rouge_metric = evaluate.load("rouge")
def compute_classification_metrics(eval_pred):
    """Compute accuracy and weighted F1 from generated sentiment words.

    Predictions and labels are decoded to text; a text containing "positive"
    counts as class 1, anything else as class 0.

    Args:
        eval_pred: (predictions, labels) pair. Predictions may be a tuple
            (first element is used) and may be 3-D scores, in which case the
            argmax over the last axis is taken.

    Returns:
        {"accuracy": float, "f1": float}. On any error the error is logged
        and both values are reported as 0.0.
    """
    try:
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        if len(predictions.shape) == 3:
            predictions = np.argmax(predictions, axis=-1)
        # -100 is not a decodable token id. Map it to the pad id in BOTH
        # arrays (as compute_summarization_metrics does); previously only the
        # labels were cleaned, so a -100 in the predictions made decoding
        # fail and the except branch silently reported 0.0.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_preds = [p.strip().lower() for p in decoded_preds]
        decoded_labels = [l.strip().lower() for l in decoded_labels]
        pred_binary = [1 if 'positive' in p else 0 for p in decoded_preds]
        label_binary = [1 if 'positive' in l else 0 for l in decoded_labels]
        acc = accuracy_metric.compute(predictions=pred_binary, references=label_binary)
        f1 = f1_metric.compute(predictions=pred_binary, references=label_binary, average="weighted")
        return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
    except Exception as e:
        # Best effort: a metric failure must not abort the training run.
        logger.error(f"Metrics error: {e}")
        return {"accuracy": 0.0, "f1": 0.0}
def compute_summarization_metrics(eval_pred):
    """Compute ROUGE-1/2/L/Lsum between decoded predictions and references.

    Args:
        eval_pred: (predictions, labels) pair. Predictions may be a tuple
            (first element is used) and may be 3-D scores, in which case the
            argmax over the last axis is taken.

    Returns:
        Dict with "rouge1", "rouge2", "rougeL" and "rougeLsum". On any error
        the error is logged and every score is reported as 0.0.
    """
    rouge_keys = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    try:
        preds, refs = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        if len(preds.shape) == 3:
            preds = np.argmax(preds, axis=-1)

        # -100 cannot be decoded; substitute the pad id before decoding.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        refs = np.where(refs != -100, refs, tokenizer.pad_token_id)
        pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
        ref_texts = tokenizer.batch_decode(refs, skip_special_tokens=True)

        # Blank texts are replaced by the placeholder word "empty".
        pred_texts = [text.strip() or "empty" for text in pred_texts]
        ref_texts = [text.strip() or "empty" for text in ref_texts]

        scores = rouge_metric.compute(predictions=pred_texts, references=ref_texts, use_stemmer=True)
        return {key: scores[key] for key in rouge_keys}
    except Exception as e:
        logger.error(f"Metrics error: {e}")
        return {key: 0.0 for key in rouge_keys}
# TRAINING ARGS 
def get_training_args(method_name, task_name):
    """Build the Seq2SeqTrainingArguments for one (method, task) experiment.

    The schedule depends on the module-level DATASET_SIZE: the full dataset
    evaluates and saves once per epoch, a subset does so every fixed number
    of steps. PEFT methods (including ablated variants) use a learning rate
    of 1e-3, full fine-tuning uses 5e-5.

    Args:
        method_name: Method key, e.g. "lora", "prefix", "prompt", "full_ft"
            or an "*_ablated_*" variant.
        task_name: Task key; only used in the output directory.

    Returns:
        A configured Seq2SeqTrainingArguments instance.
    """
    on_full_data = DATASET_SIZE == 'full'
    peft_run = "_ablated_" in method_name or method_name in ("lora", "prefix", "prompt")
    learning_rate = 1e-3 if peft_run else 5e-5

    # (epochs, train batch size, eval/save interval in steps)
    if on_full_data:
        n_epochs, batch_size, step_interval = 3, 8, 500
    elif DATASET_SIZE <= 500:
        n_epochs, batch_size, step_interval = 5, 4, 50
    else:
        n_epochs, batch_size, step_interval = 3, 8, 100

    if on_full_data:
        strategy, interval = "epoch", None
        warmup, log_every = 500, 100
    else:
        strategy, interval = "steps", step_interval
        warmup, log_every = min(100, DATASET_SIZE // 10), 20

    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    # The best checkpoint is restored by the Trainer only for full fine-tuning
    # and LoRA variants.
    restore_best = method_name == "full_ft" or "lora" in method_name

    return Seq2SeqTrainingArguments(
        output_dir=f"./results/{task_name}/{method_name}",
        num_train_epochs=n_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=learning_rate,
        warmup_steps=warmup,
        weight_decay=0.01,
        eval_strategy=strategy,
        eval_steps=interval,
        save_strategy=strategy,
        save_steps=interval,
        load_best_model_at_end=restore_best,
        metric_for_best_model="eval_loss",
        save_total_limit=2,
        logging_steps=log_every,
        bf16=bf16_ok,
        fp16=False,  # Disabled to avoid NaN losses
        dataloader_num_workers=0,
        dataloader_drop_last=True, # Avoid incomplete batches for stability
        report_to="none",
        predict_with_generate=True,
        max_grad_norm=1.0,  # Added to prevent gradient explosions
    )
# MAIN TRAINING LOOP 
# Every method is trained and evaluated on every task; one failed experiment
# is logged and skipped so the remaining ones still run.
base_methods = ["lora", "prefix", "prompt", "full_ft"]
ablation_methods = ["lora_ablated_alpha0", "prefix_ablated_no_proj", "prompt_ablated_short"]
methods_to_run = base_methods + (ablation_methods if RUN_ABLATIONS else [])
tasks = {
    "classification": (tokenized_classification, compute_classification_metrics),
    "summarization": (tokenized_summarization, compute_summarization_metrics)
}
# Keyed by "{method_name}_{task_name}".
results = {}
os.makedirs("./results", exist_ok=True)
os.makedirs("./models", exist_ok=True)
os.makedirs("./plots", exist_ok=True) # For curves
for method_name in methods_to_run:
    for task_name, (dataset, compute_metrics) in tasks.items():
        print(f"\n{'='*60}")
        print(f"EXPERIMENT: {method_name.upper()} on {task_name.upper()}")
        print(f"{'='*60}\n")
        # Pre-bind so the cleanup in `finally` is valid wherever a failure occurs.
        model = trainer = base_model = None
        try:
            config = AutoConfig.from_pretrained(MODEL_NAME)
            use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_NAME,
                config=config,
                dtype=torch.bfloat16 if use_bf16 else torch.float32,
            )
            model.to(device)
            # Note: t5-small has correct dims (num_heads=8, head_dim=64); PEFT handles DynamicCache natively.
            # Create PEFT configs dynamically from model.config
            if method_name != "full_ft":
                d_model = model.config.d_model
                num_heads = model.config.num_heads
                # NOTE(review): encoder + decoder layer count is passed as
                # num_layers to prefix tuning -- confirm this is what PEFT
                # expects for an encoder-decoder model.
                total_layers = model.config.num_layers + model.config.num_decoder_layers
                peft_configs_local = {
                    "lora": LoraConfig(
                        r=16,
                        lora_alpha=32,
                        target_modules=["q", "v"],
                        lora_dropout=0.05,
                        bias="none",
                        task_type=TaskType.SEQ_2_SEQ_LM
                    ),
                    "lora_ablated_alpha0": LoraConfig(
                        r=16,
                        lora_alpha=0,  # Ablation: zero scaling, no effect from adapter
                        target_modules=["q", "v"],
                        lora_dropout=0.05,
                        bias="none",
                        task_type=TaskType.SEQ_2_SEQ_LM
                    ),
                    "prefix": PrefixTuningConfig(
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        inference_mode=False,
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                        token_dim=d_model,
                        num_transformer_submodules=2,
                        num_attention_heads=num_heads,
                        num_layers=total_layers,
                        encoder_hidden_size=d_model,
                        prefix_projection=True  # Baseline with projection
                    ),
                    "prefix_ablated_no_proj": PrefixTuningConfig(  # Ablation: Remove projection layer
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        inference_mode=False,
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                        token_dim=d_model,
                        num_transformer_submodules=2,
                        num_attention_heads=num_heads,
                        num_layers=total_layers,
                        encoder_hidden_size=d_model,
                        prefix_projection=False  # Ablated
                    ),
                    "prompt": PromptTuningConfig(
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS,
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        prompt_tuning_init="RANDOM"
                    ),
                    "prompt_ablated_short": PromptTuningConfig(  # Ablation: Fewer tokens (e.g., half)
                        num_virtual_tokens=NUM_VIRTUAL_TOKENS // 2,
                        task_type=TaskType.SEQ_2_SEQ_LM,
                        prompt_tuning_init="RANDOM"
                    )
                }
                model = get_peft_model(model, peft_configs_local[method_name])
                model.print_trainable_parameters()
            # Count parameters now, on the object that is actually trained.
            # Counting after training via num_parameters() reported 0
            # trainable parameters for prefix/prompt tuning: that call is
            # forwarded to the frozen base model, and the best checkpoint is
            # later reloaded below without trainable weights.
            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total = sum(p.numel() for p in model.parameters())
            if method_name == "full_ft":
                print(f"trainable params: {trainable:,} || all params: {total:,} || trainable%: 100.00")
            training_args = get_training_args(method_name, task_name)
            data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["validation"],
                data_collator=data_collator,
                compute_metrics=compute_metrics,
                # `tokenizer=` is deprecated in favour of `processing_class=`.
                processing_class=tokenizer
            )
            print("Training...")
            train_result = trainer.train()
            # Manual load best model for prefix/prompt methods
            if not training_args.load_best_model_at_end and trainer.state.best_model_checkpoint:
                print(f"Loading best checkpoint manually: {trainer.state.best_model_checkpoint}")
                base_model = AutoModelForSeq2SeqLM.from_pretrained(
                    MODEL_NAME,
                    config=config,
                    dtype=torch.bfloat16 if use_bf16 else torch.float32,
                )
                base_model.to(device)
                model = PeftModel.from_pretrained(base_model, trainer.state.best_model_checkpoint)
                trainer.model = model
                model.to(device)  # Ensure the full PEFT model is on device
            print("Evaluating...")
            test_dataset = dataset.get("test", dataset["validation"])
            # trainer.args is this same object, so the generation settings
            # below are picked up by the final evaluation.
            training_args.generation_max_length = 128 if task_name == "summarization" else 10
            training_args.generation_num_beams = 4
            test_metrics = trainer.evaluate(test_dataset)
            exp_name = f"{method_name}_{task_name}"
            results[exp_name] = {
                "train_metrics": train_result.metrics,
                "test_metrics": test_metrics,
                "trainable_params": trainable,
                "total_params": total,
                "log_history": trainer.state.log_history # Collect for plotting
            }
            save_path = f"./models/{task_name}/{method_name}"
            os.makedirs(save_path, exist_ok=True)
            trainer.save_model(save_path)
            print(f"✓ Completed and saved to {save_path}\n")
        except Exception as e:
            # Log message plus traceback, then continue with the next experiment.
            logger.exception(f"ERROR in {method_name}_{task_name}: {e}")
        finally:
            # Drop every reference to the models before emptying the cache,
            # on success and failure alike.
            model = trainer = base_model = None
            safe_cleanup()
print("\n" + "="*60)
print("ALL EXPERIMENTS COMPLETED")
print("="*60)
# RESULTS
# Experiment keys have the form "{method}_{task}". Method names contain
# underscores ("full_ft", "lora_ablated_alpha0") while task names do not, so
# keys are split on the LAST underscore. Splitting on the first one turned
# "full_ft_classification" into method "full" / task "ft_classification",
# which selected the wrong metric branch and hid every ablated run.
if results:
    print("\nRESULTS SUMMARY:")
    print("="*60)
    for exp_name, exp_data in results.items():
        method, task = exp_name.rsplit("_", 1)
        metrics = exp_data["test_metrics"]
        pct = 100 * exp_data["trainable_params"] / exp_data["total_params"]
        print(f"\n{method.upper()} - {task.capitalize()}:")
        print(f" Trainable: {pct:.2f}%")
        if task == "classification":
            print(f" Accuracy: {metrics.get('eval_accuracy', 0):.4f}")
            print(f" F1: {metrics.get('eval_f1', 0):.4f}")
        else:
            print(f" ROUGE-1: {metrics.get('eval_rouge1', 0):.4f}")
            print(f" ROUGE-L: {metrics.get('eval_rougeL', 0):.4f}")

    # Ablated-minus-baseline metric differences, computed once and reused by
    # the delta listing and the outcome insights below.
    ablation_deltas = []
    if RUN_ABLATIONS:
        for exp_name, exp_data in results.items():
            method, task = exp_name.rsplit("_", 1)
            if "_ablated_" not in method:
                continue
            base_method = method.split("_ablated_")[0] + "_" + task
            if base_method not in results:
                continue
            base_metrics = results[base_method]["test_metrics"]
            delta = {k: exp_data["test_metrics"].get(k, 0) - base_metrics.get(k, 0) for k in base_metrics if "eval_" in k}
            ablation_deltas.append((method, task, delta))

    # Ablation deltas if enabled
    if RUN_ABLATIONS:
        print("\nABLATION DELTAS:")
        for method, task, delta in ablation_deltas:
            print(f"Delta for {method.upper()} - {task.capitalize()}: {delta}")

    # Plot learning curves for each experiment
    print("\nGenerating learning curves...")
    plot_paths = {}
    for exp_name, exp_data in results.items():
        task_name = exp_name.rsplit("_", 1)[1]
        plot_path = plot_learning_curves(exp_data["log_history"], exp_name, task_name)
        plot_paths[exp_name] = plot_path

    # Graphical ablation comparisons per task
    ablation_plot_paths = {}
    if RUN_ABLATIONS:
        print("\nGenerating ablation comparison plots...")
        for task_name in tasks:
            task_results = {k: v for k, v in results.items() if k.endswith(f"_{task_name}")}
            if task_results:
                ablation_plot_path = plot_ablation_comparisons(task_results, task_name)
                if ablation_plot_path:
                    ablation_plot_paths[task_name] = ablation_plot_path

    # One table row per experiment: identifiers plus every numeric test metric.
    results_df = []
    for exp_name, exp_data in results.items():
        method, task = exp_name.rsplit("_", 1)
        results_df.append({
            "Method": method.upper(),
            "Task": task.capitalize(),
            "Trainable %": 100 * exp_data["trainable_params"] / exp_data["total_params"],
            **{k: v for k, v in exp_data["test_metrics"].items() if isinstance(v, (int, float))}
        })
    df = pd.DataFrame(results_df)
    cols = ["Method", "Task", "Trainable %"]
    metric_cols = [c for c in df.columns if c.startswith("eval_")]
    cols.extend(sorted(metric_cols))
    df = df[cols]
    df.to_csv("peft_results.csv", index=False)
    print(f"\n✓ Results saved to 'peft_results.csv'")

    # Explicit encoding keeps the report independent of the platform default.
    with open("final_report.md", "w", encoding="utf-8") as f:
        f.write(f"# PEFT Comparison Results - T5-small\n\n")
        f.write(f"## Configuration\n")
        f.write(f"- Model: {MODEL_NAME} (switched from flan-t5-small to fix config dim bug)\n")
        f.write(f"- Dataset Size: {DATASET_SIZE}\n")
        f.write(f"- Methods: LoRA, Prefix-Tuning, Prompt-Tuning, Full Fine-Tuning\n")
        if RUN_ABLATIONS:
            f.write(f"- Ablations: Enabled (including ablated variants); LoRA ablation uses lora_alpha=0 for no adaptation effect\n")
        f.write(f"- Special: Native DynamicCache support; correct dims (num_heads=8, head_dim=64)\n\n")
        f.write(f"## Summary Table\n\n")
        f.write(df.to_markdown(index=False))
        f.write("\n\n## Learning Curves\n")
        for exp_name, plot_path in plot_paths.items():
            f.write(f"- [{exp_name}]({plot_path})\n")
        if RUN_ABLATIONS and ablation_plot_paths:
            f.write("\n## Ablation Comparisons\n")
            for task_name, plot_path in ablation_plot_paths.items():
                f.write(f"- [{task_name.capitalize()} Ablation Comparison]({plot_path})\n")
    print("✓ Report saved to 'final_report.md' (includes plot links)")

    # Generate dynamic outcome insights based on results
    print("\nOUTCOME INSIGHTS:")
    # General insights from trainable params and metrics
    for task in tasks:
        task_exps = {k: v for k, v in results.items() if k.endswith(f"_{task}")}
        if task_exps:
            # Find method with lowest trainable %
            min_trainable_method = min(task_exps, key=lambda k: 100 * task_exps[k]["trainable_params"] / task_exps[k]["total_params"])
            min_pct = 100 * task_exps[min_trainable_method]["trainable_params"] / task_exps[min_trainable_method]["total_params"]
            print(f"- For {task.capitalize()}, {min_trainable_method.rsplit('_', 1)[0].upper()} has the lowest trainable params ({min_pct:.2f}%).")

            # Find best performing method (use key metric)
            key_metric = 'eval_accuracy' if task == 'classification' else 'eval_rougeL'
            best_method = max(task_exps, key=lambda k: task_exps[k]["test_metrics"].get(key_metric, 0))
            best_score = task_exps[best_method]["test_metrics"].get(key_metric, 0)
            print(f"- {best_method.rsplit('_', 1)[0].upper()} achieves the highest {key_metric.replace('eval_', '').upper()} score ({best_score:.4f}) on {task.capitalize()}.")

    # Ablation-specific insights
    for method, task, delta in ablation_deltas:
        key_delta = delta.get('eval_accuracy' if task == 'classification' else 'eval_rougeL', 0)
        impact = "degradation" if key_delta < 0 else "improvement" if key_delta > 0 else "no change"
        print(f"- Ablation in {method.upper()} on {task.capitalize()} leads to {impact} in performance (delta: {key_delta:.4f}).")

    print(f"View plots in ./plots/ for detailed curves (loss/metric vs step) and comparisons.")

# Only claim success when every scheduled experiment produced a result;
# failed experiments are logged by the training loop and leave no entry.
expected_runs = len(methods_to_run) * len(tasks)
completed_runs = len(results)
print("\n" + "="*60)
if completed_runs == expected_runs:
    print("SUCCESS - All 4 PEFT methods completed!" + (" With ablations!" if RUN_ABLATIONS else ""))
else:
    print(f"INCOMPLETE - {completed_runs}/{expected_runs} experiments completed; check the error log for failures.")
print("="*60)
print("\nKey Features:")
print("LoRA: Most efficient and reliable")
print("Prefix-Tuning: Fully compatible with correct T5 config (no reshape errors)")
print("Prompt-Tuning: Ultra parameter-efficient")
print("Full Fine-Tuning: Baseline comparison")
if RUN_ABLATIONS:
    print("Ablation Study: Modular, toggle with RUN_ABLATIONS flag; includes deltas and comparison plots")
print("="*60)