In [None]:
# ================================
# 1. Environment and dependencies
# ================================
!pip install --no-deps evaluate protobuf<4.0

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

import torch
import numpy as np
import pandas as pd
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import Dataset, DatasetDict
from sklearn.metrics import confusion_matrix
from collections import Counter

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel,
)

import logging
import warnings
import json
import traceback

# Emit INFO-level logs through a module logger and silence library warnings.
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)

# ================
# 2. Device setup
# ================
# Use the GPU when torch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# =================
# 3. Configuration
# =================
MODEL_NAME = "google/flan-t5-small"

# NOTE: No external HF datasets are used now. The toy data defined further
# down stands in for the previously referenced GLUE "sst2" task and the
# "knkarthick/samsum" summarization set.

PROGRAM_NAME = "ift-lora"
DATASET_SIZE = "full"  # logical flag only; the actual datasets are tiny
RUN_ABLATIONS = False

# Seed torch first, then NumPy, with the same value.
RANDOM_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(RANDOM_SEED)

NUM_VIRTUAL_TOKENS = 50
MAX_POS = 512

# All artefacts (checkpoints, plots, reports) are written under this folder.
OUTPUT_DIR = f"/kaggle/working/outputs_{PROGRAM_NAME}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Run banner; the trailing empty entry reproduces the final blank line.
_RULE = "=" * 60
print(
    "\n".join(
        [
            _RULE,
            f"LoRA COMPARISON - {MODEL_NAME.split('/')[-1]}",
            _RULE,
            f"Dataset size: {DATASET_SIZE}",
            f"Model: {MODEL_NAME}",
            "Methods: LoRA",
            _RULE,
            "",
        ]
    )
)

# ==========================
# 4. Helper / utility funcs
# ==========================
def limit_dataset_size(dataset, size):
    """Return `dataset` unchanged, or only its first `size` rows.

    Args:
        dataset: object supporting ``len()`` and ``.select(indices)``.
        size: the string ``"full"`` to keep everything, or a positive int
            giving the maximum number of leading rows to keep.

    Returns:
        `dataset` itself for ``"full"``, otherwise the result of
        ``dataset.select`` over the first ``min(size, len(dataset))`` indices.

    Raises:
        ValueError: if `size` is neither ``"full"`` nor a positive int.
    """
    if size == "full":
        return dataset
    # Guard clause: anything that is not a positive integer is rejected.
    if not isinstance(size, int) or size <= 0:
        raise ValueError(f"Invalid size {size}")
    keep = min(size, len(dataset))
    return dataset.select(range(keep))

def setup_tokenizer(model_name):
    """Load the tokenizer for `model_name`, guaranteeing it has a pad token.

    When the loaded tokenizer defines no pad token, its EOS token is reused
    as the pad token. Returns the tokenizer.
    """
    loaded = AutoTokenizer.from_pretrained(model_name)
    needs_pad_token = loaded.pad_token is None
    if needs_pad_token:
        loaded.pad_token = loaded.eos_token
    return loaded

def safe_cleanup():
    """Release cached CUDA memory and wait for the device; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def _paired_series(log_history, key):
    """Collect (steps, values) for `key` from the log entries that carry both."""
    steps, values = [], []
    for entry in log_history:
        if key in entry and "step" in entry:
            steps.append(entry["step"])
            values.append(entry[key])
    return steps, values


def plot_learning_curves(log_history, exp_name, task_name, save_dir="plots"):
    """Plot loss and task-metric curves for one experiment and save them as PNG.

    Args:
        log_history: list of dict log entries; entries may contain "step",
            "loss", "eval_loss" and "eval_accuracy" / "eval_rougeL".
        exp_name: experiment name, used in the plot titles and the file name.
        task_name: "classification" plots eval accuracy; any other value
            plots eval ROUGE-L.
        save_dir: directory for the image; created if missing.

    Returns:
        Path of the saved ``<exp_name>_curves.png`` file.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Each series takes x and y from the SAME log entries, so the two lists
    # passed to plot() can never differ in length.
    train_steps, train_losses = _paired_series(log_history, "loss")
    eval_steps, eval_losses = _paired_series(log_history, "eval_loss")
    if task_name == "classification":
        metric_key, metric_label = "eval_accuracy", "Accuracy"
    else:
        metric_key, metric_label = "eval_rougeL", "ROUGE-L"
    metric_steps, metric_values = _paired_series(log_history, metric_key)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    sns.set(style="whitegrid")

    # Loss curve
    axes[0].plot(train_steps, train_losses, label="Train Loss", marker="o", alpha=0.7)
    if eval_losses:
        axes[0].plot(eval_steps, eval_losses, label="Eval Loss", marker="s")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Loss")
    axes[0].set_title(f"{exp_name} - Loss Curve")
    axes[0].legend()

    # Task-specific metric
    if metric_values:
        axes[1].plot(
            metric_steps,
            metric_values,
            label=f"Eval {metric_label}",
            marker="o",
            color="green",
        )
        # Only draw a legend when something was plotted; an empty axis has
        # no labelled artists to list.
        axes[1].legend()
    axes[1].set_ylabel(metric_label)
    axes[1].set_xlabel("Step")
    axes[1].set_title(f"{exp_name} - {task_name.capitalize()} Metric")

    fig.tight_layout()
    plot_path = os.path.join(save_dir, f"{exp_name}_curves.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"Learning curves saved to {plot_path}")
    return plot_path

def plot_ablation_comparisons(results, task_name, save_dir="plots"):
    """Draw side-by-side bar charts comparing all methods for one task.

    The left chart shows the percentage of trainable parameters per method,
    the right chart the test metric (accuracy for "classification", ROUGE-L
    for anything else; missing metrics count as 0).

    Args:
        results: mapping of method name to a dict with "trainable_params",
            "total_params" and "test_metrics".
        task_name: task label used for metric selection, titles and file name.
        save_dir: output directory; created if missing.

    Returns:
        Path of the saved PNG, or None when no method name contains
        "ablated" (nothing to compare).
    """
    os.makedirs(save_dir, exist_ok=True)

    names = list(results)
    if not any("ablated" in name for name in names):
        return None

    _, (params_ax, score_ax) = plt.subplots(1, 2, figsize=(14, 6))
    sns.set(style="whitegrid")
    task_title = task_name.capitalize()

    # Left panel: share of parameters that are trainable.
    trainable_share = []
    for name in names:
        entry = results[name]
        trainable_share.append(100 * entry["trainable_params"] / entry["total_params"])
    sns.barplot(x=names, y=trainable_share, ax=params_ax)
    params_ax.set_ylabel("Trainable %")
    params_ax.set_title(f"Trainable Params Comparison - {task_title}")
    params_ax.tick_params(axis="x", rotation=45)

    # Right panel: task metric on the test split.
    if task_name == "classification":
        metric_key, metric_label = "eval_accuracy", "Accuracy"
    else:
        metric_key, metric_label = "eval_rougeL", "ROUGE-L"
    scores = [results[name]["test_metrics"].get(metric_key, 0) for name in names]
    sns.barplot(x=names, y=scores, ax=score_ax)
    score_ax.set_ylabel(metric_label)
    score_ax.set_title(f"Performance Comparison - {task_title}")
    score_ax.tick_params(axis="x", rotation=45)

    plt.tight_layout()
    plot_path = os.path.join(save_dir, f"ablation_comparison_{task_name}.png")
    plt.savefig(plot_path)
    plt.close()
    print(f"Ablation comparison plot saved to {plot_path}")
    return plot_path

# ===========================
# 5. Offline toy datasets
# ===========================

print("Loading offline toy datasets (no HF Hub download)...")


def _sst2_columns(rows, first_idx):
    """Turn (sentence, label) rows into SST-2-style columns with consecutive idx."""
    return {
        "sentence": [sentence for sentence, _ in rows],
        "label": [label for _, label in rows],
        "idx": list(range(first_idx, first_idx + len(rows))),
    }


def _samsum_columns(rows):
    """Turn (id, dialogue, summary) rows into SAMSum-style columns."""
    return {
        "id": [row_id for row_id, _, _ in rows],
        "dialogue": [dialogue for _, dialogue, _ in rows],
        "summary": [summary for _, _, summary in rows],
    }


# Small SST-2-like classification dataset (label 1 = positive, 0 = negative)
sst2_train_examples = _sst2_columns(
    [
        ("I love this movie, it is fantastic!", 1),
        ("This film was terrible and boring.", 0),
        ("What a great experience, highly recommended.", 1),
        ("I hated every minute of this.", 0),
        ("The plot was interesting and engaging.", 1),
        ("The acting was awful and the script was bad.", 0),
    ],
    first_idx=0,
)

sst2_val_examples = _sst2_columns(
    [
        ("Absolutely wonderful!", 1),
        ("Not good at all.", 0),
    ],
    first_idx=100,
)

# Test labels are the ground truth used for evaluation.
sst2_test_examples = _sst2_columns(
    [
        ("It was okay, not the best.", 0),
        ("Really enjoyed it.", 1),
    ],
    first_idx=200,
)

classification_dataset = DatasetDict(
    {
        split: Dataset.from_dict(columns)
        for split, columns in (
            ("train", sst2_train_examples),
            ("validation", sst2_val_examples),
            ("test", sst2_test_examples),
        )
    }
)

# Small SAMSum-like summarization dataset
samsum_train_examples = _samsum_columns(
    [
        (
            "1",
            "A: I baked cookies.\nB: Really?\nA: Yes, do you want some?\nB: Sure!",
            "A baked cookies and B wants some.",
        ),
        (
            "2",
            "A: Who are you voting for in this election?\nB: The liberals as always.\nA: Me too!",
            "A and B are voting for the liberals in this election.",
        ),
    ]
)

samsum_val_examples = _samsum_columns(
    [
        (
            "3",
            "A: Hi, what's up?\nB: I'm in a bad mood.\nA: Why?\nB: I procrastinated.",
            "B is in a bad mood because of procrastination.",
        ),
    ]
)

samsum_test_examples = _samsum_columns(
    [
        (
            "4",
            "A: When and where are we meeting?\nB: I thought you were busy.\nA: I quit my job.",
            "A quit job and wants to meet B.",
        ),
    ]
)

summarization_dataset = DatasetDict(
    {
        split: Dataset.from_dict(columns)
        for split, columns in (
            ("train", samsum_train_examples),
            ("validation", samsum_val_examples),
            ("test", samsum_test_examples),
        )
    }
)

# One shared tokenizer for preprocessing, metrics and training.
tokenizer = setup_tokenizer(MODEL_NAME)

print("Datasets loaded (offline toy versions).")

# =========================
# 6. Preprocessing functions
# =========================
def preprocess_classification(examples):
    """Tokenize a batch of sentiment examples into seq2seq inputs and labels.

    Each sentence is prefixed with "Classify sentiment: " and padded or
    truncated to ``MAX_POS - NUM_VIRTUAL_TOKENS`` tokens. The integer label
    is rendered as the target text "negative" (0) or "positive" (anything
    else) and padded or truncated to 10 tokens.

    Args:
        examples: batch dict with "sentence" (list of str) and "label"
            (list of int) columns.

    Returns:
        The tokenizer output for the inputs with an added "labels" entry in
        which padding positions are set to -100.
    """
    inputs = [f"Classify sentiment: {text}" for text in examples["sentence"]]
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_len,
        truncation=True,
        padding="max_length",
    )
    labels_text = ["negative" if label == 0 else "positive" for label in examples["label"]]
    labels = tokenizer(
        text_target=labels_text,
        max_length=10,
        truncation=True,
        padding="max_length",
    )
    # Labels are padded to a fixed length here, so the pad ids must be masked
    # with -100 (the index the loss ignores); otherwise the model is trained
    # mostly to predict padding. Downstream decoding already maps -100 back
    # to the pad id.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok_id == pad_id else tok_id for tok_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

def preprocess_summarization(examples):
    """Tokenize a batch of dialogue/summary pairs into seq2seq inputs and labels.

    Each dialogue is prefixed with "Summarize the following conversation: "
    and padded or truncated to ``MAX_POS - NUM_VIRTUAL_TOKENS`` tokens.
    Summaries are padded or truncated to ``128 - NUM_VIRTUAL_TOKENS`` tokens.

    Args:
        examples: batch dict with "dialogue" and "summary" columns
            (lists of str).

    Returns:
        The tokenizer output for the inputs with an added "labels" entry in
        which padding positions are set to -100.
    """
    inputs = [f"Summarize the following conversation: {dialogue}" for dialogue in examples["dialogue"]]
    max_input_len = MAX_POS - NUM_VIRTUAL_TOKENS
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_len,
        truncation=True,
        padding="max_length",
    )
    max_label_len = 128 - NUM_VIRTUAL_TOKENS
    label_batch = tokenizer(
        text_target=examples["summary"],
        max_length=max_label_len,
        truncation=True,
        padding="max_length",
    )["input_ids"]
    # Labels are padded to a fixed length here, so the pad ids must be masked
    # with -100 (the index the loss ignores); otherwise padding positions
    # contribute to the training loss. Downstream decoding already maps -100
    # back to the pad id.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok_id == pad_id else tok_id for tok_id in label_ids]
        for label_ids in label_batch
    ]
    return model_inputs

print("\nPreprocessing...")

# Drop every raw column so only the tokenizer outputs and labels remain.
_cls_raw_columns = classification_dataset["train"].column_names
_sum_raw_columns = summarization_dataset["train"].column_names

tokenized_classification = classification_dataset.map(
    preprocess_classification,
    batched=True,
    remove_columns=_cls_raw_columns,
)
tokenized_summarization = summarization_dataset.map(
    preprocess_summarization,
    batched=True,
    remove_columns=_sum_raw_columns,
)

print("Preprocessing complete.\n")

# =========================
# 7. Decoding helper
# =========================
def decode_example(example: dict, tokenizer, task: str):
    """Render one tokenized example back into text plus short id previews.

    Args:
        example: dict with "input_ids" and "labels" token-id sequences.
        tokenizer: object providing ``decode``, ``eos_token`` and
            ``pad_token_id``.
        task: task name; accepted for call-site symmetry but not used here.

    Returns:
        Dict with "input_text" (decoded input up to the first EOS token),
        "label_text" (decoded labels, with -100 mapped to the pad id and
        special tokens skipped), "input_ids_preview" (first 30 input ids)
        and "label_ids_preview" (first 15 label ids), ids space-separated.
    """
    token_ids = example["input_ids"]
    raw_input = tokenizer.decode(token_ids, skip_special_tokens=False)
    # Keep only what precedes the first EOS marker (drops EOS and padding).
    visible_input, _, _ = raw_input.partition(tokenizer.eos_token)

    # -100 marks ignored label positions; swap in the pad id so decode works.
    restored_labels = [
        tokenizer.pad_token_id if tok_id == -100 else tok_id
        for tok_id in example["labels"]
    ]
    label_text = tokenizer.decode(restored_labels, skip_special_tokens=True)

    return {
        "input_text": visible_input,
        "label_text": label_text,
        "input_ids_preview": " ".join(str(tok_id) for tok_id in token_ids[:30]),
        "label_ids_preview": " ".join(str(tok_id) for tok_id in restored_labels[:15]),
    }

# Print up to five decoded training examples per task as a sanity check.
for _heading, _train_split, _task, _target_tag in (
    ("Classification post-preprocessing\n", tokenized_classification["train"], "classification", "LABEL:"),
    ("Summarisation post-preprocessing\n", tokenized_summarization["train"], "summarization", "SUMMARY:"),
):
    print(_heading)
    _preview_rows = _train_split.select(range(min(5, len(_train_split))))
    for _number, _row in enumerate(_preview_rows, start=1):
        _decoded = decode_example(_row, tokenizer, task=_task)
        print(f"--- Example {_number} ---")
        print("INPUT:", _decoded["input_text"])
        print(_target_tag, _decoded["label_text"])
        print("input_ids first 30:", _decoded["input_ids_preview"])
        print("label_ids first 15:", _decoded["label_ids_preview"])
        print()

# ==================
# 8. Metrics setup
# ==================
# Load the three metrics in order: accuracy, F1, ROUGE.
accuracy_metric, f1_metric, rouge_metric = (
    evaluate.load(_metric_id) for _metric_id in ("accuracy", "f1", "rouge")
)

def compute_classification_metrics(eval_pred):
    """Compute accuracy and weighted F1 from generated sentiment text.

    Predictions and labels are decoded to text; a string equal to
    "positive" (after strip/lower) maps to class 1 and anything else to
    class 0.

    Args:
        eval_pred: ``(predictions, labels)`` pair. Predictions may be wrapped
            in a tuple and may be 3-D scores, which are arg-maxed over the
            last axis.

    Returns:
        ``{"accuracy": ..., "f1": ...}``; both are 0.0 if anything fails
        (the error is logged rather than raised).
    """
    try:
        preds, refs = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        if preds.ndim == 3:
            # Raw scores: pick the highest-scoring token id per position.
            preds = np.argmax(preds, axis=-1)

        # -100 marks ignored positions and is not a decodable token id.
        pad_id = tokenizer.pad_token_id
        refs = np.where(refs != -100, refs, pad_id)
        preds = np.where(preds != -100, preds, pad_id)

        has_negative = np.any(preds < 0) or np.any(refs < 0)
        if has_negative:
            logger.warning("Found negative values in predictions or labels. Clamping to 0.")
            preds = np.clip(preds, 0, None)
            refs = np.clip(refs, 0, None)

        pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
        ref_texts = tokenizer.batch_decode(refs, skip_special_tokens=True)

        logger.info(f"Sample pred: {pred_texts[0]}, label: {ref_texts[0]}")

        pred_classes = [int(text.strip().lower() == "positive") for text in pred_texts]
        ref_classes = [int(text.strip().lower() == "positive") for text in ref_texts]

        accuracy_result = accuracy_metric.compute(predictions=pred_classes, references=ref_classes)
        f1_result = f1_metric.compute(predictions=pred_classes, references=ref_classes, average="weighted")
        return {
            "accuracy": accuracy_result.get("accuracy", 0.0),
            "f1": f1_result.get("f1", 0.0),
        }
    except Exception as e:
        # Best effort: a metric failure must not abort the training run.
        logger.error(f"Classification metrics error: {e}. Returning defaults.")
        return {"accuracy": 0.0, "f1": 0.0}

def compute_summarization_metrics(eval_pred):
    """Compute ROUGE scores between generated and reference summaries.

    Args:
        eval_pred: ``(predictions, labels)`` pair. Predictions may be wrapped
            in a tuple and may be 3-D scores, which are arg-maxed over the
            last axis.

    Returns:
        Dict with "rouge1", "rouge2", "rougeL" and "rougeLsum"; all are 0.0
        if anything fails (the error is logged rather than raised).
    """
    rouge_keys = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    try:
        preds, refs = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        if preds.ndim == 3:
            # Raw scores: pick the highest-scoring token id per position.
            preds = np.argmax(preds, axis=-1)

        # -100 marks ignored positions and is not a decodable token id.
        pad_id = tokenizer.pad_token_id
        preds = np.where(preds != -100, preds, pad_id)
        refs = np.where(refs != -100, refs, pad_id)

        has_negative = np.any(preds < 0) or np.any(refs < 0)
        if has_negative:
            logger.warning("Found negative values in predictions or labels. Clamping to 0.")
            preds = np.clip(preds, 0, None)
            refs = np.clip(refs, 0, None)

        pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
        ref_texts = tokenizer.batch_decode(refs, skip_special_tokens=True)

        logger.info(f"Sample pred: {pred_texts[0]}, label: {ref_texts[0]}")

        pred_texts = [text.strip() for text in pred_texts]
        ref_texts = [text.strip() for text in ref_texts]

        scores = rouge_metric.compute(predictions=pred_texts, references=ref_texts, use_stemmer=True)
        return {key: scores.get(key, 0.0) for key in rouge_keys}
    except Exception as e:
        # Best effort: a metric failure must not abort the training run.
        logger.error(f"Summarization metrics error: {e}. Returning defaults.")
        return {key: 0.0 for key in rouge_keys}

# =====================
# 9. Training arguments
# =====================
def get_training_args(method_name, task_name):
    """Build the ``Seq2SeqTrainingArguments`` for one method/task experiment.

    Args:
        method_name: experiment method. Names containing "lora" or "ablated"
            use the PEFT learning rate (3e-4), others 1e-5; names containing
            "lora" also reload the best checkpoint at the end of training.
        task_name: "summarization" trains for 5 epochs in "full" mode, any
            other task for 3. Also used in the output directory.

    Returns:
        A configured ``Seq2SeqTrainingArguments`` writing to
        ``{OUTPUT_DIR}/results/{task_name}/{method_name}``.
    """
    is_peft = ("lora" in method_name) or ("ablated" in method_name)
    lr = 3e-4 if is_peft else 1e-5

    if DATASET_SIZE == "full":
        # Evaluate and checkpoint once per epoch.
        epochs = 5 if task_name == "summarization" else 3
        batch_size = 2
        eval_steps = None
        eval_strategy = "epoch"
        save_strategy = "epoch"
        logging_steps = 10
        save_steps = None
    else:
        # Step-based schedule: evaluate/save every 10 steps, log twice as often.
        epochs, batch_size, eval_steps = 3, 2, 10
        eval_strategy = "steps"
        save_strategy = "steps"
        logging_steps = max(1, eval_steps // 2)
        save_steps = eval_steps

    # Prefer bf16 when the GPU supports it, otherwise fp16 on GPU; neither on CPU.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_fp16 = (not use_bf16) and torch.cuda.is_available()

    load_best = "lora" in method_name

    # NOTE(review): with the toy datasets the total number of optimizer steps
    # looks far smaller than the warmup length below, so the learning rate
    # would never reach `lr` - confirm this is intended.
    return Seq2SeqTrainingArguments(
        output_dir=f"{OUTPUT_DIR}/results/{task_name}/{method_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        warmup_steps=200 if DATASET_SIZE == "full" else 50,
        weight_decay=0.1,
        eval_strategy=eval_strategy,
        eval_steps=eval_steps,
        save_strategy=save_strategy,
        save_steps=save_steps,
        load_best_model_at_end=load_best,
        metric_for_best_model="eval_loss",
        save_total_limit=2,
        logging_steps=logging_steps,
        bf16=use_bf16,
        fp16=use_fp16,
        dataloader_num_workers=0,
        # Must stay False: the flag also applies to evaluation dataloaders,
        # and a split smaller than the batch size (the one-example
        # summarization validation/test sets) would be dropped entirely.
        dataloader_drop_last=False,
        report_to="none",
        predict_with_generate=True,
        max_grad_norm=1.0,
        gradient_accumulation_steps=1,
        # The field is `label_smoothing_factor`; the misspelt
        # `labels_smoothing_factor` is rejected as an unexpected keyword.
        label_smoothing_factor=0.0,
        optim="adamw_torch",
        gradient_checkpointing=False,
    )

# ==========================
# 10. Main training loop
# ==========================
# Methods to run: plain LoRA, plus the alpha=0 ablation when enabled.
base_methods = ["lora"]
ablation_methods = []
if RUN_ABLATIONS:
    ablation_methods = ["lora_ablated_alpha0"]
methods_to_run = [*base_methods, *ablation_methods]

# Each task maps to its tokenized dataset and its metric function.
tasks = dict(
    classification=(tokenized_classification, compute_classification_metrics),
    summarization=(tokenized_summarization, compute_summarization_metrics),
)

results = {}
for _subdir in ("results", "models", "plots"):
    os.makedirs(f"{OUTPUT_DIR}/{_subdir}", exist_ok=True)

# Run every (method, task) pair: load a fresh base model, wrap it with LoRA,
# train, evaluate on the test split, record the results and save the model.
# A failure in one experiment is logged and the loop moves on to the next.
for method_name in methods_to_run:
    for task_name, (dataset, compute_metrics) in tasks.items():
        print("=" * 60)
        print(f"EXPERIMENT {method_name.upper()} on {task_name.upper()}")
        print("=" * 60)

        try:
            # Load the checkpoint with its own published config. Do NOT
            # override `num_heads` or pass `ignore_mismatched_sizes=True`:
            # if the checkpoint's head count differs from the forced value,
            # the attention projections change shape and the pretrained
            # weights are silently replaced by freshly initialised ones.
            # NOTE(review): flan-t5-small is believed to ship with 6 heads,
            # so forcing 8 would hit exactly this case - confirm on the
            # model card.
            use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
            )

            model.to(device)

            # The two LoRA variants differ only in alpha; alpha=0 is the
            # ablation configuration.
            lora_alpha_by_method = {"lora": 32, "lora_ablated_alpha0": 0}
            if method_name in lora_alpha_by_method:
                peft_config = LoraConfig(
                    r=32,
                    lora_alpha=lora_alpha_by_method[method_name],
                    target_modules=["q", "v"],
                    lora_dropout=0.05,
                    bias="none",
                    task_type=TaskType.SEQ_2_SEQ_LM,
                )
                model = get_peft_model(model, peft_config)
                model.print_trainable_parameters()

            training_args = get_training_args(method_name, task_name)

            data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)

            trainer = Seq2SeqTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["validation"],
                data_collator=data_collator,
                compute_metrics=compute_metrics,
                tokenizer=tokenizer,
            )

            print("Training...")
            train_result = trainer.train()

            print("Evaluating...")
            test_dataset = dataset["test"]
            # Generation settings for the final evaluation: short outputs for
            # the one-word class label, longer ones for summaries. They are
            # set on the args object the trainer was built with; presumably
            # the trainer reads them at evaluate/predict time - verify.
            training_args.generation_max_length = 5 if task_name == "classification" else 64
            training_args.generation_num_beams = 4

            test_metrics = trainer.evaluate(test_dataset)

            # Log a few validation generations as a qualitative check.
            predictions = trainer.predict(dataset["validation"])
            cleaned_predictions = np.where(
                predictions.predictions != -100,
                predictions.predictions,
                tokenizer.pad_token_id,
            )
            # Keep ids inside the tokenizer's range so batch_decode cannot fail.
            cleaned_predictions = np.clip(cleaned_predictions, 0, tokenizer.vocab_size - 1)
            logger.info(
                f"Sample generations: {tokenizer.batch_decode(cleaned_predictions[:5], skip_special_tokens=True)}"
            )

            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total = sum(p.numel() for p in model.parameters())

            exp_name = f"{method_name}_{task_name}"
            results[exp_name] = {
                "train_metrics": train_result.metrics,
                "test_metrics": test_metrics,
                "trainable_params": trainable,
                "total_params": total,
                "log_history": trainer.state.log_history,
            }

            save_path = f"{OUTPUT_DIR}/models/{task_name}/{method_name}"
            os.makedirs(save_path, exist_ok=True)
            trainer.save_model(save_path)
            print(f"Completed and saved to {save_path}")

            del model, trainer
            safe_cleanup()

        except Exception as e:
            # Experiment boundary: record the failure and continue the sweep.
            logger.error(f"ERROR in {method_name} {task_name}: {e}")
            logger.error(traceback.format_exc())
            try:
                del model, trainer
            except NameError:
                # The failure happened before one of the names was bound.
                pass
            safe_cleanup()

print("=" * 60)
print("ALL EXPERIMENTS COMPLETED")
print("=" * 60)

# ==========================
# 11. Results summary/report
# ==========================
def _split_exp_name(exp_name):
    """Split "<method>_<task>" on the LAST underscore into (method, task).

    Method names may themselves contain underscores (for example
    "lora_ablated_alpha0") while task names do not, so splitting on the first
    underscore would mis-assign part of the method to the task. Names without
    any underscore yield the task "unknown".
    """
    method, sep, task = exp_name.rpartition("_")
    if not sep:
        return exp_name, "unknown"
    return method, task


# Summarise all successful experiments: console summary, learning-curve
# plots, optional ablation plots, a CSV table and a Markdown report.
if results:
    print("\nSUMMARY")
    print("=" * 60)
    for exp_name, exp_data in results.items():
        method, task = _split_exp_name(exp_name)

        metrics = exp_data["test_metrics"]
        pct = 100 * exp_data["trainable_params"] / exp_data["total_params"]
        print(f"{method.upper()} - {task.capitalize()}")
        print(f"  Trainable: {pct:.2f}%")
        if task == "classification":
            print(f"  Accuracy: {metrics.get('eval_accuracy', 0):.4f}")
            print(f"  F1:       {metrics.get('eval_f1', 0):.4f}")
        else:
            print(f"  ROUGE-1:  {metrics.get('eval_rouge1', 0):.4f}")
            print(f"  ROUGE-L:  {metrics.get('eval_rougeL', 0):.4f}")
        print()

    print("Learning curves...")
    plot_paths = {}
    plot_save_dir = f"{OUTPUT_DIR}/plots"
    os.makedirs(plot_save_dir, exist_ok=True)

    for exp_name, exp_data in results.items():
        _, task = _split_exp_name(exp_name)
        plot_path = plot_learning_curves(exp_data["log_history"], exp_name, task, save_dir=plot_save_dir)
        plot_paths[exp_name] = plot_path

    ablation_plot_paths = {}
    if RUN_ABLATIONS:
        print("Ablation comparison plots...")
        for task_name in tasks:
            # Group the experiments of one task; keys end with the task name.
            task_results = {k: v for k, v in results.items() if k.endswith(task_name)}
            if task_results:
                ablation_plot_path = plot_ablation_comparisons(task_results, task_name, save_dir=plot_save_dir)
                if ablation_plot_path:
                    ablation_plot_paths[task_name] = ablation_plot_path

    # One table row per experiment, keeping only numeric test metrics.
    results_df_rows = []
    for exp_name, exp_data in results.items():
        method, task = _split_exp_name(exp_name)
        row = {
            "Method": method.upper(),
            "Task": task.capitalize(),
            "Trainable %": 100 * exp_data["trainable_params"] / exp_data["total_params"],
        }
        for k, v in exp_data["test_metrics"].items():
            if isinstance(v, (int, float)):
                row[k] = v
        results_df_rows.append(row)

    df = pd.DataFrame(results_df_rows)
    # Fixed identifying columns first, then the eval_* metrics alphabetically.
    cols = ["Method", "Task", "Trainable %"]
    metric_cols = [c for c in df.columns if c.startswith("eval_")]
    cols.extend(sorted(metric_cols))
    df = df[cols]
    df.to_csv(f"{OUTPUT_DIR}/lora_results.csv", index=False)
    print(f"Results saved to {OUTPUT_DIR}/lora_results.csv")

    report_path = f"{OUTPUT_DIR}/lora_final_report.md"
    report_dir = os.path.dirname(report_path)
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("# LoRA Adaptation Results - T5-small\n\n")
        f.write("## Configuration\n")
        f.write(f"- Model: {MODEL_NAME}\n")
        f.write(f"- Dataset Size: {DATASET_SIZE}\n")
        f.write("- Methods: LoRA\n")
        if RUN_ABLATIONS:
            f.write("- Ablations: enabled\n")
        f.write("\n## Summary Table\n")
        f.write(df.to_markdown(index=False))
        f.write("\n\n## Learning Curves\n")
        # Plot paths are written relative to the report's own directory.
        for exp_name, plot_path in plot_paths.items():
            relative_plot_path = os.path.relpath(plot_path, start=report_dir)
            f.write(f"- {exp_name}: {relative_plot_path}\n")
        if RUN_ABLATIONS and ablation_plot_paths:
            f.write("\n## Ablation Comparisons\n")
            for task_name, plot_path in ablation_plot_paths.items():
                relative_plot_path = os.path.relpath(plot_path, start=report_dir)
                f.write(f"- {task_name.capitalize()} Ablation Comparison: {relative_plot_path}\n")

    print(f"Report saved to {report_path}")
