In [None]:
# ============================================================================
# Prefix Tuning - Summarization
# ============================================================================
import os
import sys
import gc
import json
import warnings
import numpy as np
import torch

# On Kaggle, quietly install/upgrade the packages this notebook needs.
# pip is invoked through the running interpreter so the packages land in the
# kernel's own environment, and a failed install is reported instead of being
# silently ignored (the script still continues, as before).
if os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None:
    import subprocess

    _pip_result = subprocess.run(
        [
            sys.executable, "-m", "pip", "install", "-q", "--upgrade",
            "evaluate", "transformers", "peft", "protobuf==4.25.3",
        ],
        check=False,
    )
    if _pip_result.returncode != 0:
        print(f" Warning: pip install exited with code {_pip_result.returncode}")
# else:
#     os.system("pip install -r /home/requirements.txt")

from datasets import load_dataset
import evaluate
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from peft import PrefixTuningConfig, get_peft_model, PeftModel

import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# ============================================================================
# ENVIRONMENT DETECTION (SAFE: disable MPS for prefix tuning on Mac)
# ============================================================================
_ENV_RULE = "=" * 80
print("\n" + _ENV_RULE)
print("ENVIRONMENT DETECTION")
print(_ENV_RULE)

# Probe the runtime once; later sections branch on these flags.
IS_KAGGLE = os.environ.get("KAGGLE_KERNEL_RUN_TYPE", None) is not None
IS_MAC_MACHINE = sys.platform == "darwin"
IS_CUDA = torch.cuda.is_available()
IS_MPS_AVAILABLE = torch.backends.mps.is_available()

# Prefix Tuning runs on CUDA when it is available and on CPU otherwise.
# MPS is deliberately never selected (avoided because of cache/device bugs),
# so a Mac always ends up on CPU.
DEVICE = "cuda" if IS_CUDA else "cpu"
if IS_CUDA:
    DEVICE_NAME = f"CUDA - {torch.cuda.get_device_name(0)}"
elif IS_MAC_MACHINE and IS_MPS_AVAILABLE:
    DEVICE_NAME = "CPU Only (MPS available but DISABLED for Prefix Tuning)"
else:
    DEVICE_NAME = "CPU Only"

for _env_label, _env_value in (
    ("PyTorch Version", torch.__version__),
    ("MPS Available", IS_MPS_AVAILABLE),
    ("CUDA Available", IS_CUDA),
    ("Kaggle Environment", IS_KAGGLE),
    ("Using Device", DEVICE_NAME),
):
    print(f" {_env_label}: {_env_value}")

# ============================================================================
# ENVIRONMENT CONFIGURATION
# ============================================================================
# Only on a CUDA host that is not a Mac: set the W&B-disable flag and the
# CUDA allocator configuration through environment variables.
if IS_CUDA and not IS_MAC_MACHINE:
    os.environ.update(
        {
            "WANDB_DISABLED": "true",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
    )

# ============================================================================
# MEMORY MANAGEMENT HELPERS
# ============================================================================
def clear_gpu_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Garbage collection runs first so that tensors which have only just become
    unreachable are freed before the CUDA caching allocator is asked to give
    back its unused blocks; in the opposite order that memory would stay
    cached until the next call. MPS caches are not touched.

    This is best-effort cleanup: a failure while emptying the CUDA cache is
    reported but never raised.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
    except Exception as exc:  # best-effort: cleanup must never abort the run
        print(f" Warning: could not empty CUDA cache: {exc}")

def get_gpu_memory_usage():
    """Return ``torch.cuda.memory_allocated()`` in GiB, or 0.0 when CUDA is unavailable."""
    if not IS_CUDA:
        return 0.0
    allocated_bytes = torch.cuda.memory_allocated()
    return float(allocated_bytes / 1024**3)

def print_gpu_status():
    """Print used/total memory and the name of CUDA device 0, or a CPU notice."""
    if not IS_CUDA:
        print(" Using CPU - no GPU available")
        return
    device_props = torch.cuda.get_device_properties(0)
    total_gib = device_props.total_memory / 1024**3
    used_gib = get_gpu_memory_usage()
    print(f" GPU Memory Used: {used_gib:.2f} GB / {total_gib:.2f} GB")
    print(f" GPU Device: {torch.cuda.get_device_name(0)}")

# ============================================================================
# CONFIGURATION (you can change model/task here)
# ============================================================================
model_name = "google/flan-t5-small"
task = "summarization"  # or "classification"
# Macs get a 100-example training subset; every other platform uses all data.
dataset_size = "full" if not IS_MAC_MACHINE else 100

_config_rule = "=" * 80
print("\n" + _config_rule)
print("CONFIGURATION")
print(_config_rule)
print(
    f" Model: {model_name}\n"
    f" Task: {task}\n"
    f" Dataset Size: {dataset_size} (auto-reduced on Mac)"
)

# ============================================================================
# LOAD TOKENIZER
# ============================================================================
# The tokenizer is a module-level global used by the preprocessing functions
# and by compute_metrics.
print("\n Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(" Tokenizer loaded")

# ============================================================================
# PREPROCESSING
# ============================================================================
def preprocess_summarization(examples):
    """Tokenize a batch of texts and their reference summaries.

    Reads ``examples["text"]`` and ``examples["summary"]``. Each text is
    prefixed with ``"summarize: "``; inputs are truncated to 512 tokens and
    summaries to 128, with no padding applied here.

    Returns the tokenized inputs with the summary token ids stored under
    ``"labels"``.
    """
    prompts = ["summarize: {}".format(dialogue) for dialogue in examples["text"]]
    batch = tokenizer(prompts, truncation=True, padding=False, max_length=512)
    targets = tokenizer(
        examples["summary"], truncation=True, padding=False, max_length=128
    )
    batch["labels"] = targets["input_ids"]
    return batch

def preprocess_classification(examples):
    """Tokenize a batch of texts and turn integer labels into target words.

    Reads ``examples["text"]`` and ``examples["label"]``. Each text is prefixed
    with ``"classify sentiment: "`` and truncated to 512 tokens. Label 0 maps
    to ``"negative"`` and 1 to ``"positive"``; the word is tokenized
    (truncated to 8 tokens) and its ids are stored under ``"labels"``.
    No padding is applied here.
    """
    id_to_word = {0: "negative", 1: "positive"}
    prompts = [
        "classify sentiment: {}".format(review) for review in examples["text"]
    ]
    batch = tokenizer(prompts, truncation=True, padding=False, max_length=512)
    target_words = [id_to_word[label_id] for label_id in examples["label"]]
    encoded_targets = tokenizer(
        target_words, truncation=True, padding=False, max_length=8
    )
    batch["labels"] = encoded_targets["input_ids"]
    return batch

# ============================================================================
# LOAD DATASET
# ============================================================================
# Validate the task up front: raised inside the try below, this error would be
# caught by the broad handler and misreported as a connectivity problem.
if task not in ("summarization", "classification"):
    raise ValueError("Task must be 'summarization' or 'classification'")

print(f"\n Loading {task} dataset...")
clear_gpu_memory()
try:
    if task == "summarization":
        dataset = load_dataset("knkarthick/samsum")
        # Normalise the validation split name when it is exposed as "val".
        if "val" in dataset:
            dataset["validation"] = dataset.pop("val")
        # The preprocessing function reads the input from a "text" column.
        dataset = dataset.rename_column("dialogue", "text")
        preprocess_function = preprocess_summarization
        metric = evaluate.load("rouge")
    else:  # "classification" (the only other value accepted above)
        dataset = load_dataset("imdb")
        if "unsupervised" in dataset:
            del dataset["unsupervised"]
        # Carve a 10% validation split out of the training data.
        train_val = dataset["train"].train_test_split(test_size=0.1, seed=42)
        dataset["train"] = train_val["train"]
        dataset["validation"] = train_val["test"]
        preprocess_function = preprocess_classification
        metric = evaluate.load("accuracy")
    splits = ["train", "validation", "test"]
    print(f" Dataset loaded. Splits: {list(dataset.keys())}")
except Exception as e:
    print(f" Error loading dataset: {e}")
    print(" Please check your internet connection")
    # Chain the cause so the original traceback stays attached to the exit.
    raise SystemExit(1) from e

# ============================================================================
# SUBSAMPLE DATASET (if requested)
# ============================================================================
print(f"\n Subsampling dataset (size={dataset_size})...")
if dataset_size != "full":
    for split_name in splits:
        if split_name not in dataset:
            continue
        # Train keeps dataset_size rows; the other splits keep half of that,
        # but never fewer than 50. Both are capped by the split's length.
        if split_name == "train":
            quota = dataset_size
        else:
            quota = max(50, dataset_size // 2)
        num_samples = min(quota, len(dataset[split_name]))
        shuffled = dataset[split_name].shuffle(seed=42)
        dataset[split_name] = shuffled.select(range(num_samples))
        print(f" {split_name}: {num_samples} samples")

# ============================================================================
# TOKENIZE DATASETS
# ============================================================================
print("\n Tokenizing datasets...")
clear_gpu_memory()
# The raw columns (names taken from the train split) are dropped so that only
# the output of the preprocessing function remains.
_raw_columns = dataset["train"].column_names
_map_options = {
    "batched": True,
    "batch_size": 32,
    "remove_columns": _raw_columns,
    "desc": "Tokenizing",
}
tokenized_datasets = dataset.map(preprocess_function, **_map_options)
print(" Datasets tokenized")

# ============================================================================
# MODEL LOADING (base model), print config JSON, then lock PrefixTuningConfig
# ============================================================================
# Load the base seq2seq model on CPU, derive the prefix-tuning dimensions from
# its config, wrap it with a PEFT prefix adapter, and move the wrapped model to
# DEVICE. Any failure in this section prints the error and exits with status 1.
print(f"\n Loading model {model_name} with Prefix Tuning...")
clear_gpu_memory()
try:
    # Load base model into CPU first (safe)
    print(" Loading base model on CPU (device_map='cpu')...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        device_map="cpu",
        low_cpu_mem_usage=True,
    )
    print(" Base model loaded on CPU")

    # Print model.config JSON (so you have a record); if the config cannot be
    # serialised to JSON, fall back to printing the config object itself.
    try:
        cfg_dict = base_model.config.to_dict()
        print("\n--- MODEL CONFIG (start) ---")
        print(json.dumps(cfg_dict, indent=2))
        print("--- MODEL CONFIG (end) ---\n")
    except Exception:
        print("\nmodel.config:", base_model.config)

    # --- Derive safe, exact PrefixTuning constants from base_model.config ---
    cfg = base_model.config

    # num_layers: prefer decoder layers. The first attribute in this order that
    # exists and is not None wins.
    num_layers = None
    for attr in ("num_decoder_layers", "decoder_layers", "num_layers", "num_hidden_layers"):
        if hasattr(cfg, attr) and getattr(cfg, attr) is not None:
            num_layers = int(getattr(cfg, attr))
            break
    # Only reached when none of the attributes above yielded a value; 12 is the
    # default used when the config has no "num_layers" attribute at all.
    if num_layers is None:
        num_layers = int(getattr(cfg, "num_layers", 12))

    # num_attention_heads: same first-match lookup, defaulting to 8.
    num_attention_heads = None
    for attr in ("num_heads", "num_attention_heads"):
        if hasattr(cfg, attr) and getattr(cfg, attr) is not None:
            num_attention_heads = int(getattr(cfg, attr))
            break
    if num_attention_heads is None:
        num_attention_heads = 8

    # Per-head dimension: prefer d_kv (T5), else d_model split evenly across the
    # heads, else a default of 64.
    if hasattr(cfg, "d_kv") and cfg.d_kv is not None:
        raw_d_kv = int(cfg.d_kv)
    elif hasattr(cfg, "d_model") and cfg.d_model is not None:
        raw_d_kv = int(cfg.d_model) // num_attention_heads
    else:
        raw_d_kv = 64

    # For PEFT, token_dim is head_dim * num_attention_heads
    # NOTE(review): this can differ from d_model when d_kv * num_heads != d_model;
    # presumably that is intentional so the per-head size PEFT derives from
    # token_dim equals d_kv - confirm against the installed PEFT version.
    token_dim = raw_d_kv * num_attention_heads

    # choose num_virtual_tokens scaled for Mac (10 on Mac, 20 elsewhere)
    num_virtual_tokens = 10 if IS_MAC_MACHINE else 20

    # Lock and create exact PrefixTuningConfig: every dimension is passed
    # explicitly rather than left for PEFT to infer.
    prefix_config = PrefixTuningConfig(
        task_type="SEQ_2_SEQ_LM",
        num_virtual_tokens=num_virtual_tokens,
        prefix_projection=True,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        token_dim=token_dim,
    )

    print(
        f" Applying PrefixTuningConfig -> num_virtual_tokens={num_virtual_tokens},"
        f" num_layers={num_layers}, num_attention_heads={num_attention_heads}, token_dim={token_dim}"
    )

    # Wrap base model with PEFT prefix adapter (wrapped on CPU initially)
    model = get_peft_model(base_model, prefix_config)
    print(" PEFT wrapper created (currently on CPU)")

    # Move PEFT-wrapped model to device (single move)
    print(f" Moving PEFT-wrapped model to device: {DEVICE} ...")
    model = model.to(DEVICE)
    print(" Model moved to device")

    # Defensive: disable use_cache during training. A failure to set the flag
    # is ignored on purpose.
    try:
        model.config.use_cache = False
    except Exception:
        pass

    # Optional check: print trainable params (elements with requires_grad)
    # against the total parameter count of the wrapped model.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Prefix Tuning Applied | Trainable Params: {trainable} / {total}")

except Exception as e:
    # Any error above ends the script; only the message is printed here.
    print(f" Error loading/applying prefix tuning: {e}")
    raise SystemExit(1)

# ============================================================================
# DATA COLLATOR
# ============================================================================
# Batch collator built from the tokenizer and the PEFT-wrapped model, with
# padding enabled.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

# ============================================================================
# METRICS
# ============================================================================
def compute_metrics(eval_pred):
    """Decode predicted token ids and score them against the references.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``predictions`` may itself
            be a tuple, in which case its first element is used. Both may be
            tensors or arrays of token ids; ``-100`` entries are treated as
            padding.

    Returns:
        When ``task == "summarization"``: a dict with ``rouge1``, ``rouge2``
        and ``rougeL`` rounded to 4 decimals (all ``0.0`` if ROUGE fails).
        Otherwise: the result of the accuracy metric, or
        ``{"accuracy": 0.0}`` when no reference decodes to a known label.
    """
    predictions, labels = eval_pred
    # Unwrap first, so a tensor inside the tuple still gets converted below.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    # Replace -100 (ignore_index) with pad_token_id on both sides: it is not a
    # token id and cannot be decoded.
    pad_id = tokenizer.pad_token_id
    labels = np.where(labels != -100, labels, pad_id)
    predictions = np.where(predictions != -100, predictions, pad_id)

    decode_options = {
        "skip_special_tokens": True,
        "clean_up_tokenization_spaces": True,
    }
    try:
        decoded_preds = tokenizer.batch_decode(predictions, **decode_options)
    except OverflowError:
        # Other out-of-range ids: clamp into the vocabulary and retry.
        predictions = np.clip(predictions, 0, tokenizer.vocab_size - 1)
        decoded_preds = tokenizer.batch_decode(predictions, **decode_options)
    decoded_labels = tokenizer.batch_decode(labels, **decode_options)

    if task == "summarization":
        rouge_keys = ("rouge1", "rouge2", "rougeL")
        try:
            result = metric.compute(
                predictions=decoded_preds, references=decoded_labels, use_stemmer=True
            )
            return {key: round(result[key], 4) for key in rouge_keys}
        except Exception as e:
            print(f" Error computing ROUGE: {e}")
            # Report every key so each evaluation logs the same metric set.
            return {key: 0.0 for key in rouge_keys}

    label_map_rev = {"negative": 0, "positive": 1}
    pred_ids = [label_map_rev.get(p.strip(), -1) for p in decoded_preds]
    ref_ids = [label_map_rev.get(r.strip(), -1) for r in decoded_labels]
    # Keep every example whose reference is a known label. An unrecognised
    # prediction stays -1, which matches no label and therefore counts as
    # wrong; dropping such examples would inflate the accuracy.
    pairs = [(p, r) for p, r in zip(pred_ids, ref_ids) if r != -1]
    if not pairs:
        return {"accuracy": 0.0}
    pred_ids, ref_ids = zip(*pairs)
    return metric.compute(predictions=pred_ids, references=ref_ids)

# ============================================================================
# TRAINING ARGUMENTS
# ============================================================================
print("\n Configuring training arguments...")
# Macs run a shorter schedule with less gradient accumulation; fp16 is only
# switched on for CUDA hosts that are not Macs.
if IS_MAC_MACHINE:
    num_epochs, grad_accum = 2, 2
else:
    num_epochs, grad_accum = 5, 4
fp16_enabled = IS_CUDA and not IS_MAC_MACHINE

_train_tag = "mac" if IS_MAC_MACHINE else "gpu"
training_args = Seq2SeqTrainingArguments(
    # Output and logging
    output_dir=f"./{task}_flan_t5_small_prefix_{dataset_size}_{_train_tag}",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    # Schedule and batching
    num_train_epochs=num_epochs,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=grad_accum,
    # Optimisation
    optim="adamw_torch",
    learning_rate=5e-5,
    warmup_steps=50,
    weight_decay=0.01,
    seed=42,
    # Precision and memory
    fp16=fp16_enabled,
    bf16=False,
    gradient_checkpointing=False,
    # Evaluation and checkpointing: once per epoch, keeping the checkpoint
    # with the lowest eval_loss.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Data loading
    dataloader_pin_memory=False,
    dataloader_num_workers=0,
    remove_unused_columns=False,
    # Generation during evaluation
    predict_with_generate=True,
    generation_max_length=128,
    generation_num_beams=1,
)
print(" Training arguments configured")

# ============================================================================
# INITIALIZE TRAINER
# ============================================================================
print("\n Initializing Trainer...")
# Evaluate on the validation split when there is one, otherwise on train.
_train_split = tokenized_datasets["train"]
_eval_split = tokenized_datasets.get("validation", tokenized_datasets["train"])
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=_train_split,
    eval_dataset=_eval_split,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
print(" Trainer initialized")

# ============================================================================
# TRAIN
# ============================================================================
_train_rule = "=" * 80
print("\n" + _train_rule)
print("STARTING TRAINING")
print(_train_rule)
print_gpu_status()
print()
clear_gpu_memory()
try:
    model.train()
    trainer.train()
    print("\n Training completed successfully!")
except KeyboardInterrupt:
    # A manual interrupt is not fatal: the script carries on with whatever
    # training state exists.
    print("\n Training interrupted by user")
except RuntimeError as e:
    print("\n RuntimeError during training:", e)
    _error_text = str(e)
    # Print a targeted hint for the two recognised failure messages, then
    # re-raise the error unchanged.
    if "does not require grad" in _error_text:
        print(" Check that PEFT parameters are trainable and gradient_checkpointing is False.")
    if "out of memory" in _error_text.lower():
        print(" Out of memory. Try reducing dataset_size or lowering num_virtual_tokens.")
    raise

# ============================================================================
# PLOTTING METRICS (unchanged)
# ============================================================================
print("\n" + "=" * 80)
print("PLOTTING METRICS")
print("=" * 80)
log_history = trainer.state.log_history

# "eval_*" keys that are the loss or bookkeeping rather than task metrics.
_NON_METRIC_EVAL_KEYS = {
    "eval_loss",
    "eval_runtime",
    "eval_samples_per_second",
    "eval_steps_per_second",
}


def _save_line_plot(x_values, y_values, label, xlabel, ylabel, title, filename):
    """Draw one labelled line chart, save it to *filename*, show it, close it."""
    plt.figure(figsize=(10, 5))
    plt.plot(x_values, y_values, label=label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(filename)
    plt.show()
    plt.close()


# Extract data
steps = []
train_losses = []
epochs = []
eval_losses = []
eval_metrics = {}        # metric name -> logged values
eval_metric_epochs = {}  # metric name -> epochs at which that metric was logged
for log in log_history:
    if "loss" in log:
        steps.append(log["step"])
        train_losses.append(log["loss"])
    if "eval_loss" in log:
        epochs.append(log["epoch"])
        eval_losses.append(log["eval_loss"])
        for key, value in log.items():
            if key.startswith("eval_") and key not in _NON_METRIC_EVAL_KEYS:
                eval_metrics.setdefault(key, []).append(value)
                # Track epochs per metric: a metric missing from one evaluation
                # must not be plotted against the full epoch list, or the x and
                # y lengths would differ and plotting would fail.
                eval_metric_epochs.setdefault(key, []).append(log["epoch"])

if train_losses:
    _save_line_plot(
        steps, train_losses, "Training Loss",
        "Steps", "Loss", "Training Loss over Steps", "training_loss.png",
    )

if eval_losses:
    _save_line_plot(
        epochs, eval_losses, "Evaluation Loss",
        "Epochs", "Loss", "Evaluation Loss over Epochs", "eval_loss.png",
    )

for metric_name, values in eval_metrics.items():
    _save_line_plot(
        eval_metric_epochs[metric_name], values, metric_name,
        "Epochs", metric_name[len("eval_"):],
        f"{metric_name} over Epochs", f"{metric_name}.png",
    )

# ============================================================================
# EVALUATION
# ============================================================================
_eval_rule = "=" * 80
print("\n" + _eval_rule)
print("EVALUATING")
print(_eval_rule)
clear_gpu_memory()
try:
    # Score the test split when present, otherwise the validation split.
    _held_out = tokenized_datasets.get("test", tokenized_datasets["validation"])
    test_results = trainer.evaluate(_held_out)
    print("\n Test Results:")
    for result_name, result_value in test_results.items():
        print(f" {result_name}: {result_value}")
except Exception as e:
    # Evaluation is best-effort: report the error and continue with no results.
    print(f"\n Evaluation error: {e}")
    test_results = {}

# ============================================================================
# SAVE PEFT ADAPTER & TOKENIZER (deterministic)
# ============================================================================
print("\n Saving PEFT adapter and tokenizer...")
_save_tag = "mac" if IS_MAC_MACHINE else "gpu"
peft_save_dir = f"./peft_prefix_flan_t5_small_{dataset_size}_{_save_tag}"
os.makedirs(peft_save_dir, exist_ok=True)
try:
    # `model` is the PEFT-wrapped model; it and the tokenizer are written to
    # the same directory, model first.
    for _artifact in (model, tokenizer):
        _artifact.save_pretrained(peft_save_dir)
    print(f" Saved PEFT adapter + tokenizer to: {peft_save_dir}")
except Exception as e:
    # A failed save is reported as a warning; the script continues.
    print(f" Warning when saving: {e}")

# ============================================================================
# OPTIONAL: example of loading the PEFT adapter back for inference
# ============================================================================
print("\n Demonstrating reload of base model + PEFT adapter for inference...")
try:
    # Same recipe as for training: load a fresh base model on CPU, attach the
    # saved adapter, then move the wrapped model to DEVICE.
    _cpu_load_options = {"device_map": "cpu", "low_cpu_mem_usage": True}
    base_for_load = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, **_cpu_load_options
    )
    loaded_peft = PeftModel.from_pretrained(base_for_load, peft_save_dir)
    loaded_peft = loaded_peft.to(DEVICE)
    loaded_tokenizer = AutoTokenizer.from_pretrained(peft_save_dir)
    print(" Reloaded PEFT model and tokenizer successfully.")
except Exception as e:
    # The reload is only a demonstration, so a failure is a warning.
    print(" Warning: could not reload PEFT adapter automatically:", e)

# ============================================================================
# SAVE METRICS & SUMMARY
# ============================================================================
print("\n Saving metrics & summary...")
_metrics_tag = "mac" if IS_MAC_MACHINE else "gpu"
metrics_file = f"{task}_prefix_metrics_{dataset_size}_{_metrics_tag}.json"

# Mac takes precedence over CUDA when naming the platform.
if IS_MAC_MACHINE:
    _metrics_platform = "Mac (CPU)"
elif IS_CUDA:
    _metrics_platform = "GPU (CUDA)"
else:
    _metrics_platform = "CPU"

_training_config = {
    "epochs": training_args.num_train_epochs,
    "batch_size": training_args.per_device_train_batch_size,
    "gradient_accumulation": training_args.gradient_accumulation_steps,
    "fp16": fp16_enabled,
    "optimizer": training_args.optim,
}
metrics_dict = {
    "model": model_name,
    "task": task,
    "platform": _metrics_platform,
    "dataset_size": dataset_size,
    "test_results": test_results,
    "training_config": _training_config,
}
with open(metrics_file, "w") as f:
    json.dump(metrics_dict, f, indent=4)
print(f" Metrics saved to {metrics_file}")

# ============================================================================
# FINAL SUMMARY
# ============================================================================
_summary_rule = "=" * 80
# Mac takes precedence over CUDA when naming the platform.
if IS_MAC_MACHINE:
    _summary_platform = "Mac (CPU)"
elif IS_CUDA:
    _summary_platform = "GPU (CUDA)"
else:
    _summary_platform = "CPU"

_summary_lines = (
    f" Platform: {_summary_platform}",
    f" Task: {task}",
    f" Model: {model_name}",
    f" Dataset Size: {dataset_size}",
    f" Training Epochs: {training_args.num_train_epochs}",
    f" Gradient Accumulation: {training_args.gradient_accumulation_steps}",
    f" Test Results: {test_results}",
    f" PEFT saved to: {peft_save_dir}",
)

print("\n" + _summary_rule)
print("TRAINING SUMMARY")
print(_summary_rule)
for _summary_line in _summary_lines:
    print(_summary_line)
print_gpu_status()
print(_summary_rule)
print("\n Script complete.\n")