# Week 6 — GPT‑2 Lyrics Fine‑Tuning (LoRA) + ONNX + Gradio

This notebook is organized to run top‑to‑bottom in VS Code notebooks or Jupyter.

**Outputs created** (for submission):
- `artifacts/<run_id>/run_metrics.json`
- `artifacts/<run_id>/trainer_log_history.json`
- `artifacts/<run_id>/sweep_results.json` (minimal tuning evidence)
- `artifacts/<run_id>/onnx_export.json`
- `artifacts/<run_id>/artifact_paths.json`
- `artifacts/<run_id>/bleu_samples.json`


## 0) Install dependencies
Assignment note: the course instructions show `!pip install ...`. If you are in VS Code notebooks, `%pip install ...` also works and installs into the active kernel.


In [None]:
# If you prefer, replace !pip with %pip in VS Code notebooks
!pip -q install datasets transformers peft accelerate evaluate nltk onnx onnxruntime gradio optimum[onnxruntime]

## 1) Imports and configuration

In [None]:
import os, json, time, math, random, socket
from pathlib import Path

import numpy as np
import torch

from datasets import load_dataset
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    set_seed,
)

from peft import LoraConfig, get_peft_model


In [None]:
# ---- Reproducibility ----
SEED = 42
set_seed(SEED)

# ---- Data source and sequence length ----
DATASET_ID = "halaction/song-lyrics"
SUBSET = "train[:1000]"      # required subset used in your run
MAX_LENGTH = 128

# ---- Per-run artifact folder: ./artifacts/<timestamp> ----
ARTIFACTS_DIR = Path("./artifacts")
ARTIFACTS_DIR.mkdir(exist_ok=True)

RUN_ID = time.strftime("%Y%m%d_%H%M%S")  # local time, second resolution
RUN_DIR = ARTIFACTS_DIR / RUN_ID
RUN_DIR.mkdir(parents=True, exist_ok=True)

# ---- Model output folders (shared across runs, relative to the working directory) ----
ADAPTER_DIR = Path("./gpt2-lyrics-lora-adapter")
MERGED_DIR  = Path("./gpt2-lyrics-merged")
ONNX_DIR    = Path("./gpt2-lyrics-onnx")

print(f"RUN_DIR: {RUN_DIR.resolve()}")


## 2) Load dataset and preprocess

In [None]:
dataset = load_dataset(DATASET_ID, split=SUBSET)
print(f"Columns: {dataset.column_names}")
print(f"Rows: {len(dataset)}")

# Choose the first known text-like column; fall back to the dataset's first column.
candidate_columns = ["lyrics", "text", "song", "content"]
for text_col in candidate_columns:
    if text_col in dataset.column_names:
        break
else:
    text_col = dataset.column_names[0]


def _has_text(row):
    """Return True when the chosen text column holds a non-blank value."""
    value = row.get(text_col)
    return value is not None and str(value).strip() != ""


# Drop empty rows, then keep only the text column.
dataset = dataset.filter(_has_text).select_columns([text_col])

print(f"Using text column: {text_col}")
print(f"Rows after cleaning: {len(dataset)}")


## 3) Tokenizer + tokenization
Important details:
- GPT‑2 has no pad token by default, so we set `pad_token = eos_token`.
- We keep `raw_text` so BLEU can read held‑out text even if we later set the dataset format to torch.


In [None]:
# Load the stock GPT-2 tokenizer and reuse its EOS token for padding, because the
# tokenization step below pads every example to a fixed length.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(batch):
    """Turn one batch of lyrics into fixed-length causal-LM examples.

    Every text is cast to ``str`` and gets the tokenizer's EOS token appended, then is
    truncated/padded to ``MAX_LENGTH`` tokens. The result carries ``labels`` (a copy of
    ``input_ids``) and ``raw_text`` (the untouched source values) alongside the
    tokenizer's own outputs.
    """
    originals = batch[text_col]

    with_eos = []
    for lyric in originals:
        with_eos.append(str(lyric) + tokenizer.eos_token)

    encoded = tokenizer(
        with_eos,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
    )
    encoded["labels"] = list(encoded["input_ids"])
    encoded["raw_text"] = originals
    return encoded

tokenized = dataset.map(tokenize_function, batched=True, remove_columns=[text_col])

# 90/10 split, seeded so the held-out rows are the same on every run.
splits = tokenized.train_test_split(test_size=0.1, seed=SEED)
train_dataset, eval_dataset = splits["train"], splits["test"]

# Expose only the model inputs as torch tensors; raw_text is not part of the formatted view.
for _split in (train_dataset, eval_dataset):
    _split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

print(f"Train rows: {len(train_dataset)} | Eval rows: {len(eval_dataset)}")


## 4) Load GPT‑2 and apply LoRA

In [None]:
base_model = GPT2LMHeadModel.from_pretrained("gpt2")
# Keep the embedding matrix in step with the tokenizer's vocabulary size.
base_model.resize_token_embeddings(len(tokenizer))

# LoRA adapters on GPT-2's `c_attn` / `c_proj` layers.
# GPT-2 implements these as Conv1D modules that store weights as (fan_in, fan_out); PEFT's
# LoraConfig documents `fan_in_fan_out=True` for that layout, so it is set explicitly rather
# than relying on PEFT to warn and correct it.
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj"],
    fan_in_fan_out=True,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()


## 5) Training setup (weight decay + early stopping)
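This matches your run settings (effective batch size 16, LR 2e‑4, weight decay 0.01, up to 5 epochs). Early stopping with a patience of 2 evaluations on `eval_loss` can end training before the fifth epoch.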


In [None]:
# Causal-LM collation (mlm=False).
# NOTE(review): pad_token is set to eos_token in this notebook; this collator is understood to
# replace labels at pad-token positions with -100, which would also hide real EOS tokens from
# the loss — confirm against the installed transformers version.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Arguments whose names are the same across transformers releases.
_final_training_kwargs = dict(
    output_dir=str(ADAPTER_DIR),
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch ~ 16
    learning_rate=2e-4,
    weight_decay=0.01,
    logging_steps=20,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to=[],
)

# `evaluation_strategy` was renamed to `eval_strategy`; recent transformers releases reject the
# old keyword and older ones reject the new one. Try the current name first and fall back,
# mirroring the version fallback used for the Gradio interface.
try:
    training_args = TrainingArguments(eval_strategy="epoch", **_final_training_kwargs)
except TypeError:
    training_args = TrainingArguments(evaluation_strategy="epoch", **_final_training_kwargs)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    # Stop when eval_loss has not improved for 2 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)


## 6) Minimal hyperparameter tuning evidence (quick LR sweep)
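This cell runs one short trial per learning rate (`2e-4` and `1e-4`): each trial trains a fresh LoRA model for 150 optimizer steps, then records its eval loss, perplexity, and wall-clock time in `sweep_results.json`. If you already ran this once, you can skip this cell.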


In [None]:
def quick_lr_trial(lr, max_steps=150):
    """Train a fresh LoRA model briefly at one learning rate and report its eval metrics.

    Args:
        lr: Learning rate for this trial.
        max_steps: Number of optimizer steps to train; evaluation is also scheduled at
            this step count.

    Returns:
        dict with ``learning_rate``, ``eval_loss``, ``perplexity`` and ``seconds``
        (wall-clock time for training plus the final evaluation, rounded to 2 decimals).
        ``eval_loss`` and ``perplexity`` are NaN if the evaluation reports no loss.
    """
    common_kwargs = dict(
        output_dir=str(RUN_DIR / f"trial_lr_{lr}"),
        overwrite_output_dir=True,
        num_train_epochs=1,
        max_steps=max_steps,  # a positive max_steps takes precedence over num_train_epochs
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=lr,
        weight_decay=0.01,
        logging_steps=50,
        eval_steps=max_steps,
        save_strategy="no",
        load_best_model_at_end=False,
        report_to=[],
    )
    # `evaluation_strategy` was renamed to `eval_strategy`; support both spellings so the
    # sweep runs on old and new transformers releases.
    try:
        tmp_args = TrainingArguments(eval_strategy="steps", **common_kwargs)
    except TypeError:
        tmp_args = TrainingArguments(evaluation_strategy="steps", **common_kwargs)

    # fresh model for fair comparison
    bm = GPT2LMHeadModel.from_pretrained("gpt2")
    bm.resize_token_embeddings(len(tokenizer))
    m = get_peft_model(bm, lora_cfg)

    t = Trainer(
        model=m,
        args=tmp_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    t0 = time.time()
    t.train()
    metrics = t.evaluate()
    secs = time.time() - t0

    # Handle a missing loss in one place so eval_loss and perplexity always agree.
    raw_loss = metrics.get("eval_loss")
    if raw_loss is None:
        eval_loss = ppl = float("nan")
    else:
        eval_loss = float(raw_loss)
        try:
            ppl = float(math.exp(eval_loss))
        except OverflowError:
            # A diverged trial can produce a loss too large for exp().
            ppl = float("inf")

    return {"learning_rate": lr, "eval_loss": eval_loss, "perplexity": ppl, "seconds": round(secs, 2)}

# Run one short trial per (label, learning rate) pair and collect the metrics.
sweep_results = []
for _trial_name, _trial_lr in (("lr_2e-4", 2e-4), ("lr_1e-4", 1e-4)):
    sweep_results.append({"name": _trial_name, **quick_lr_trial(_trial_lr, max_steps=150)})

# Persist the comparison as tuning evidence for this run.
with (RUN_DIR / "sweep_results.json").open("w") as f:
    json.dump(sweep_results, f, indent=2)

sweep_results


## 7) Train the final LoRA model

In [None]:
# Time the full training run (wall clock).
train_start = time.time()
train_output = trainer.train()
train_seconds = time.time() - train_start

# Final evaluation on the held-out split; perplexity is exp(eval_loss).
eval_metrics = trainer.evaluate()
perplexity = float(math.exp(eval_metrics["eval_loss"]))

print(f"Eval loss: {eval_metrics['eval_loss']}")
print(f"Perplexity: {perplexity}")
print(f"Train seconds: {round(train_seconds, 2)}")


## 8) Save adapter, merge weights, export ONNX

In [None]:
import copy

# Save the LoRA adapter weights together with the tokenizer.
trainer.model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)

# Merge LoRA into the base weights and save a standard Transformers folder.
# merge_and_unload() rewrites the model it is called on, so it runs on a deep copy:
# trainer.model keeps its adapter layers and this cell can be re-run without saving
# an adapter that no longer contains LoRA weights.
merged_model = copy.deepcopy(trainer.model).merge_and_unload()
MERGED_DIR.mkdir(parents=True, exist_ok=True)
merged_model.save_pretrained(MERGED_DIR)
tokenizer.save_pretrained(MERGED_DIR)

# Record where the model folders live for this run.
artifact_paths = {
    "adapter_dir": str(ADAPTER_DIR.resolve()),
    "merged_dir": str(MERGED_DIR.resolve()),
}
with open(RUN_DIR / "artifact_paths.json", "w") as f:
    json.dump(artifact_paths, f, indent=2)

print("Saved adapter + merged model.")


In [None]:
# ONNX export through Optimum's ONNX Runtime integration
from optimum.onnxruntime import ORTModelForCausalLM

ONNX_DIR.mkdir(parents=True, exist_ok=True)

# Load the merged checkpoint with export enabled, then write model and tokenizer side by side.
onnx_model = ORTModelForCausalLM.from_pretrained(MERGED_DIR, export=True)
for _saveable in (onnx_model, tokenizer):
    _saveable.save_pretrained(ONNX_DIR)

# Record the export location for this run.
onnx_export = {"onnx_dir": str(ONNX_DIR.resolve())}
with (RUN_DIR / "onnx_export.json").open("w") as f:
    json.dump(onnx_export, f, indent=2)

print(f"ONNX saved to: {ONNX_DIR.resolve()}")


## 9) Generation helper (ONNX preferred, PyTorch fallback)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

_tok = None
_pt_model = None
_onnx_model = None

def _get_tokenizer():
    """Return the cached tokenizer, loading it on the first call.

    Loading is attempted from ``ONNX_DIR`` first; if that raises, ``MERGED_DIR`` is used
    instead. When the loaded tokenizer has no pad token, EOS is assigned as the pad token.
    """
    global _tok
    if _tok is not None:
        return _tok

    try:
        _tok = AutoTokenizer.from_pretrained(ONNX_DIR)
    except Exception:
        _tok = AutoTokenizer.from_pretrained(MERGED_DIR)

    if _tok.pad_token is None:
        _tok.pad_token = _tok.eos_token
    return _tok

def _get_pt_model():
    """Return the cached PyTorch model, loading it from ``MERGED_DIR`` on the first call.

    The model is switched to eval mode and moved to CUDA when a GPU is available.
    """
    global _pt_model
    if _pt_model is not None:
        return _pt_model

    _pt_model = AutoModelForCausalLM.from_pretrained(MERGED_DIR)
    _pt_model.eval()
    if torch.cuda.is_available():
        _pt_model.to("cuda")
    return _pt_model

def _get_onnx_model():
    """Return the cached ONNX Runtime model, loading it from ``ONNX_DIR`` on the first call."""
    global _onnx_model
    if _onnx_model is not None:
        return _onnx_model

    # Imported here so the Optimum import only happens once the ONNX backend is requested.
    from optimum.onnxruntime import ORTModelForCausalLM

    _onnx_model = ORTModelForCausalLM.from_pretrained(ONNX_DIR)
    return _onnx_model

def generate_lyrics(prompt, backend="onnx", max_new_tokens=60, temperature=0.9, top_p=0.95, repetition_penalty=1.1):
    """Sample a lyric continuation for ``prompt`` and return the decoded text.

    Args:
        prompt: Text to continue. An empty, blank or ``None`` prompt is replaced by the
            tokenizer's EOS token so generation starts from a non-empty input.
        backend: ``"pytorch"`` uses the merged PyTorch model; any other value tries the
            ONNX model first and falls back to PyTorch (with a warning) if that fails.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty passed through to ``generate``.

    Returns:
        The decoded output sequence with special tokens removed.

    Raises:
        Whatever the PyTorch backend raises; ONNX failures are reported as a
        ``RuntimeWarning`` and trigger the fallback instead of propagating.
    """
    import warnings

    tok = _get_tokenizer()
    # A zero-length input gives generate() nothing to condition on, so seed with EOS.
    seed_text = prompt if (prompt and prompt.strip()) else tok.eos_token
    inputs = tok(seed_text, return_tensors="pt", truncation=True)

    # One set of sampling options shared by both backends.
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        pad_token_id=tok.eos_token_id,
    )

    if backend != "pytorch":
        try:
            out = _get_onnx_model().generate(**inputs, **gen_kwargs)
            return tok.decode(out[0], skip_special_tokens=True)
        except Exception as exc:
            # Best-effort fallback, but make it visible: callers comparing backends
            # need to know the output did not come from ONNX.
            warnings.warn(
                f"ONNX backend failed ({type(exc).__name__}: {exc}); falling back to PyTorch.",
                RuntimeWarning,
                stacklevel=2,
            )

    model = _get_pt_model()
    if torch.cuda.is_available():
        # _get_pt_model moves the model to CUDA when available; inputs must follow.
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)
    return tok.decode(out[0], skip_special_tokens=True)

# Smoke test on the PyTorch backend; only the first 300 characters are printed.
_smoke_text = generate_lyrics("late night drive, neon lights", backend="pytorch", max_new_tokens=40)
print(_smoke_text[:300])


## 10) BLEU evaluation (multiple samples)
BLEU is a lightweight sanity check here. Lyric generation is open-ended, so exact n‑gram overlap with a single reference is rare.


In [None]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def split_prompt_reference(text, prompt_words=18, ref_max_words=80):
    """Split ``text`` on whitespace into a prompt and a reference continuation.

    Args:
        text: Source text; ``None`` or an empty string yields ``("", "")``.
        prompt_words: Number of leading words used as the prompt when the text has at
            least ``prompt_words + 5`` words.
        ref_max_words: Maximum number of words kept in the reference continuation.

    Returns:
        ``(prompt, reference)`` as space-joined strings. For texts shorter than
        ``prompt_words + 5`` words the prompt is the first ``max(5, n // 2)`` words, so
        texts of five words or fewer produce an empty reference.
    """
    words = (text or "").split()

    if len(words) < prompt_words + 5:
        # Too short for the requested prompt length: take about half, but at least 5 words.
        cut = max(5, len(words) // 2)
    else:
        cut = prompt_words

    prompt = " ".join(words[:cut])
    # The length cap applies on both paths, so ref_max_words is honoured for short texts too.
    ref = " ".join(words[cut:cut + ref_max_words])
    return prompt, ref

NUM_BLEU_SAMPLES = 10
PROMPT_WORDS = 18

# Read the held-out texts through the dataset API rather than from `eval_dataset.data`.
# `.data` is the underlying Arrow table, which after train_test_split still contains every
# row of the pre-split dataset, so indexing it would score training rows. A python-format
# view is used so the torch formatting of `eval_dataset` itself is left untouched.
if "raw_text" in eval_dataset.column_names:
    eval_texts = list(eval_dataset.with_format(None)["raw_text"])
else:
    eval_texts = []

samples = []
bleu_scores = []
smooth = SmoothingFunction().method1

for i in range(min(NUM_BLEU_SAMPLES, len(eval_texts))):
    text = eval_texts[i]
    prompt, reference_text = split_prompt_reference(text, prompt_words=PROMPT_WORDS)
    if not reference_text.strip():
        # Nothing left to compare against (text too short); skip this sample.
        continue

    generated_full = generate_lyrics(prompt, backend="onnx", max_new_tokens=60)

    # Strip the echoed prompt when the output starts with it (case-insensitive);
    # otherwise score the whole output.
    if generated_full.lower().startswith(prompt.lower()):
        continuation = generated_full[len(prompt):].strip()
    else:
        continuation = generated_full.strip()

    reference_tokens = [reference_text.split()]
    candidate_tokens = continuation.split()

    bleu = sentence_bleu(reference_tokens, candidate_tokens, smoothing_function=smooth)
    bleu_scores.append(float(bleu))

    samples.append({
        "i": i,
        "prompt": prompt,
        "reference_continuation": reference_text,
        "generated_full": generated_full,
        "generated_continuation": continuation,
        "bleu": float(bleu),
    })

# NaN signals that no sample could be scored.
avg_bleu = float(sum(bleu_scores) / len(bleu_scores)) if bleu_scores else float("nan")
print("Average BLEU:", avg_bleu)

with open(RUN_DIR / "bleu_samples.json", "w") as f:
    json.dump({"avg_bleu": avg_bleu, "n": len(samples), "samples": samples}, f, indent=2)

print("Saved:", (RUN_DIR / "bleu_samples.json").resolve())


## 11) Save final run metrics + logs

In [None]:
# PEFT may hold target_modules as a set (or a single string pattern); json cannot encode a
# set, so any non-string collection is converted to a sorted list for a stable, valid file.
_target_modules = lora_cfg.target_modules
if _target_modules is not None and not isinstance(_target_modules, str):
    _target_modules = sorted(_target_modules)

# Summary of this run: data, LoRA settings, training arguments and final metrics.
run_metrics = {
    "run_id": RUN_ID,
    "timestamp_local": time.strftime("%Y-%m-%dT%H:%M:%S"),
    "base_model": "gpt2",
    "dataset": f"{DATASET_ID} ({SUBSET})",
    "seed": SEED,
    "train_rows": len(train_dataset),
    "eval_rows": len(eval_dataset),
    "max_length": MAX_LENGTH,
    "lora_config": {
        "r": lora_cfg.r,
        "lora_alpha": lora_cfg.lora_alpha,
        "lora_dropout": lora_cfg.lora_dropout,
        "target_modules": _target_modules,
    },
    "training_args": training_args.to_dict(),
    # Keep only numeric entries so every value is a plain float.
    "eval_metrics": {k: float(v) for k, v in eval_metrics.items() if isinstance(v, (int, float))},
    "perplexity": perplexity,
    "train_seconds": round(train_seconds, 2),
}

with open(RUN_DIR / "run_metrics.json", "w") as f:
    # default=str keeps an unexpected non-JSON value from discarding the whole report.
    json.dump(run_metrics, f, indent=2, default=str)

with open(RUN_DIR / "trainer_log_history.json", "w") as f:
    json.dump(trainer.state.log_history, f, indent=2)

print("Saved run_metrics.json and trainer_log_history.json to:", RUN_DIR.resolve())


## 12) Gradio app (local deployment)
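VS Code notebooks note: older Gradio versions do not support `block=`, so the launch call below does not pass it and uses `prevent_thread_lock=True` to keep the notebook responsive. To avoid port conflicts, the cell asks the OS for a free port on `127.0.0.1` before launching, and it falls back from `flagging_mode` to `allow_flagging` on Gradio versions that do not accept the newer argument.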


In [None]:
import gradio as gr

def gr_generate(backend, prompt, max_new_tokens, temperature, top_p):
    """Gradio callback: coerce the widget values and delegate to ``generate_lyrics``.

    ``max_new_tokens`` is cast to ``int``; ``temperature`` and ``top_p`` are cast to
    ``float``. Returns the generated text.
    """
    sampling = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "top_p": float(top_p),
    }
    return generate_lyrics(prompt=prompt, backend=backend, **sampling)

def _build_demo(**flagging_kwargs):
    """Build the lyrics-generator Interface, forwarding the version-specific flagging option.

    Components are created inside the function so each construction attempt gets fresh ones.
    """
    return gr.Interface(
        fn=gr_generate,
        inputs=[
            gr.Dropdown(choices=["onnx", "pytorch"], value="onnx", label="Backend"),
            gr.Textbox(lines=2, label="Prompt"),
            gr.Slider(20, 200, value=80, step=1, label="max_new_tokens"),
            gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="temperature"),
            gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="top_p"),
        ],
        outputs=gr.Textbox(lines=10, label="Generated Lyrics"),
        title="GPT-2 Lyrics Generator (LoRA Fine-Tuned)",
        description="Runs locally. ONNX backend is recommended when available; PyTorch is the fallback.",
        **flagging_kwargs,
    )


# Try the `flagging_mode` keyword first; on a TypeError retry with `allow_flagging`.
try:
    demo = _build_demo(flagging_mode="never")
except TypeError:
    demo = _build_demo(allow_flagging="never")

# Avoid proxy surprises for localhost
for k in ["HTTP_PROXY","HTTPS_PROXY","ALL_PROXY","http_proxy","https_proxy","all_proxy"]:
    os.environ.pop(k, None)
os.environ["NO_PROXY"] = "127.0.0.1,localhost"
os.environ["no_proxy"] = "127.0.0.1,localhost"

# Pick an open port (Windows safe): bind to port 0 so the OS assigns a free one.
# The context manager closes the probe socket even if bind() fails.
with socket.socket() as sock:
    sock.bind(("127.0.0.1", 0))
    port = sock.getsockname()[1]

print("Launching Gradio on port:", port)
# prevent_thread_lock=True returns control to the notebook instead of blocking the cell.
demo.launch(server_name="127.0.0.1", server_port=port, prevent_thread_lock=True)
print(f"Open: http://127.0.0.1:{port}/")
