# LLM 02 – Prompt Tuning with PEFT

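This notebook demonstrates prompt tuning with the Hugging Face PEFT library, following the Databricks Academy example: a small set of trainable virtual-token embeddings is learned for a GPT-2 model on WikiText-2, the tuned prompt is saved, and it is then reloaded for text generation.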

In [None]:
# Install required packages
!pip install peft transformers datasets accelerate

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from peft import PromptTuningConfig, get_peft_model, PromptEncoder, get_peft_model_state_dict, PeftModel

## Load model and tokenizer

In [None]:
model_name = "gpt2"

# The weights and the tokenizer are both resolved from the same checkpoint name.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse the end-of-sequence token as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

## Configure prompt tuning (PEFT)

In [None]:
# Prompt-tuning configuration: `num_virtual_tokens` trainable prompt embeddings
# are added for a causal language-modelling task.
#
# `PromptTuningConfig` is a dataclass with no `prompt_encoder_hidden_size`
# field, so passing that keyword raises TypeError at construction time. An
# encoder hidden size only applies to methods that use a prompt encoder network
# (e.g. `PromptEncoderConfig(encoder_hidden_size=...)` for P-tuning), not to
# plain prompt tuning, so the argument is simply omitted here.
peft_config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,
)

# NOTE: this rebinds `model` to the PEFT-wrapped model. Re-running this cell
# without first re-running the model-loading cell would wrap the model again.
model = get_peft_model(model, peft_config)

# Report how many parameters are trainable versus the total.
model.print_trainable_parameters()

## Load and tokenize a dataset

In [None]:
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# The raw WikiText split includes blank lines. They tokenize to zero-length
# sequences, which carry no training signal and can yield degenerate batches,
# so drop them before tokenizing.
dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)


def tokenize_fn(examples):
    """Tokenize a batch of rows from the dataset.

    Args:
        examples: A batch mapping that contains a "text" entry.

    Returns:
        The tokenizer output for the batch, with each text truncated to at
        most 512 tokens.
    """
    return tokenizer(examples["text"], truncation=True, max_length=512)


# The raw "text" column is removed so only the tokenizer outputs remain.
tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# mlm=False selects causal (next-token) language-modelling collation rather
# than masked-language-modelling.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

## Trainer setup

In [None]:
from transformers import Trainer, TrainingArguments

# Hyperparameters for the run: where checkpoints go, how long and how fast to
# train, and how often to log and checkpoint.
training_args = TrainingArguments(
    output_dir="./peft-prompt-tuning",
    # optimisation
    learning_rate=5e-4,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    # bookkeeping (both measured in optimisation steps)
    logging_steps=100,
    save_steps=500,
)

# Wire the PEFT-wrapped model, the tokenized dataset and the collator together.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized,
)

## Train!

In [None]:
# Run the training loop configured in the Trainer above.
trainer.train()

# Persist the trained PEFT model to disk so it can be reloaded for inference.
model.save_pretrained("./peft-prompt-tuned-model")

## Inference with the fine-tuned prompt

In [None]:
# Load a fresh copy of the base model and attach the saved prompt adapter.
base_model = AutoModelForCausalLM.from_pretrained(model_name)
peft_model = PeftModel.from_pretrained(base_model, "./peft-prompt-tuned-model")

# Switch to evaluation mode before generating.
peft_model.eval()

inputs = tokenizer("Once upon a time", return_tensors="pt")

# Pass the pad token id explicitly. Without it, `generate` falls back to the
# EOS token on its own and emits a warning on every call.
outputs = peft_model.generate(
    **inputs,
    max_new_tokens=50,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))