# QLoRA Fine-tuning (Gemma) on fine_tune_data.json

This basic notebook fine-tunes a small Gemma Instruct model with QLoRA on pairs from `fine_tuning_data/fine_tune_data.json`.

- Model: adjustable (defaults to `google/gemma-2-2b-it`).
- Method: QLoRA (4-bit quantization via bitsandbytes + PEFT LoRA) using TRL's SFTTrainer.
- Data: list of objects with keys `llm_message` (input/prompt) and `user_message` (target/response).

Run top-to-bottom on a GPU runtime (e.g., Runpod).


In [None]:
pip install -q transformers datasets accelerate peft trl bitsandbytes einops


In [None]:
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
import json

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, TaskType

@dataclass
class TrainCfg:
    """Model id, output location and training hyperparameters for this notebook's fine-tuning run."""

    # Model id passed to both the tokenizer and the model loader.
    base_model: str = "google/gemma-2-2b-it"  # small instruct model; adjust for VRAM
    # Used as the trainer's output_dir and as the adapter/tokenizer save location.
    out_dir: str = "./outputs/gemma-qlora"
    # Selects bfloat16 (else float16) as the 4-bit compute dtype and sets the trainer's bf16 flag.
    bf16: bool = True
    # The fields below are forwarded to SFTConfig under the same names.
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    max_steps: int = 500
    learning_rate: float = 1e-4
    warmup_ratio: float = 0.03
    logging_steps: int = 10
    # Maximum sequence length handed to SFTConfig.
    max_seq_len: int = 1024

# Default run configuration; later cells read every setting from `cfg`.
cfg = TrainCfg()

# 4-bit quantization
# NF4 quant type with double quantization; the compute dtype is bfloat16 when
# cfg.bf16 is set and float16 otherwise.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if cfg.bf16 else torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Passed as `device_map` to every from_pretrained model load in this notebook.
device_map = "auto"

tokenizer = AutoTokenizer.from_pretrained(cfg.base_model, use_fast=True)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the base model with the 4-bit quantization config defined above.
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally; presumably unnecessary for the default Gemma
# checkpoint — confirm, and consider dropping it unless a custom model needs it.
model = AutoModelForCausalLM.from_pretrained(
    cfg.base_model,
    device_map=device_map,
    quantization_config=bnb_cfg,
    trust_remote_code=True,
)

# LoRA adapter settings (rank 16, alpha 32, dropout 0.05) applied to the
# projection modules listed in target_modules.
peft_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)



In [None]:
# Load fine_tune_data.json
data_path = Path("fine_tuning_data/fine_tune_data.json")
# Read as UTF-8 explicitly so the platform's default encoding cannot garble
# non-ASCII text in the messages.
with open(data_path, "r", encoding="utf-8") as f:
    pairs = json.load(f)

# Expect a non-empty list of {"llm_message": str, "user_message": str}.
# Validate with explicit errors rather than `assert` (asserts are stripped under
# `python -O`), and reject empty or malformed data here instead of failing later
# with an opaque IndexError/AttributeError.
if not isinstance(pairs, list) or not pairs:
    raise ValueError(f"{data_path} must contain a non-empty JSON list of records")
for idx, row in enumerate(pairs):
    if not isinstance(row, dict) or not all(
        isinstance(row.get(key), str) for key in ("llm_message", "user_message")
    ):
        raise ValueError(
            f"{data_path}: record {idx} must be an object with string "
            '"llm_message" and "user_message" fields'
        )

# Each pair becomes one supervised "text" sample: prompt, newline, target, EOS.
BOS = ""
EOS = tokenizer.eos_token

def format_row(r):
    """Return ``{"text": ...}`` for one record.

    The text is the stripped ``llm_message``, a newline, the stripped
    ``user_message``, wrapped in the module-level BOS/EOS markers.
    """
    stripped = [r[key].strip() for key in ("llm_message", "user_message")]
    return dict(text=f"{BOS}{stripped[0]}\n{stripped[1]}{EOS}")

# Build the training dataset and preview the first 200 characters of sample 0.
formatted_rows = list(map(format_row, pairs))
train_ds = Dataset.from_list(formatted_rows)
first_text = train_ds[0]["text"]
print(first_text[:200])


In [None]:
import dataclasses
import inspect

# TRL renamed two arguments across releases: SFTConfig's `max_seq_length` became
# `max_length`, and SFTTrainer's `tokenizer` became `processing_class`. The
# install cell does not pin a TRL version, so use whichever name the installed
# release exposes instead of hard-coding one of them.
_sft_cfg_fields = {fld.name for fld in dataclasses.fields(SFTConfig)}
_max_len_kw = "max_length" if "max_length" in _sft_cfg_fields else "max_seq_length"
_trainer_params = inspect.signature(SFTTrainer.__init__).parameters
_tokenizer_kw = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

# Trainer settings come from `cfg`. Samples are read from the "text" column and
# packed up to cfg.max_seq_len tokens.
training_args = SFTConfig(
    output_dir=cfg.out_dir,
    bf16=cfg.bf16,
    per_device_train_batch_size=cfg.per_device_train_batch_size,
    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
    max_steps=cfg.max_steps,
    learning_rate=cfg.learning_rate,
    warmup_ratio=cfg.warmup_ratio,
    logging_steps=cfg.logging_steps,
    # No intermediate checkpoints: the adapter is saved explicitly after training.
    # (Documented replacement for the previous `save_steps=0`.)
    save_strategy="no",
    dataset_text_field="text",
    packing=True,
    **{_max_len_kw: cfg.max_seq_len},
)

# Passing peft_config makes the trainer wrap the quantized base model with the
# LoRA adapter defined in peft_cfg.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    peft_config=peft_cfg,
    args=training_args,
    **{_tokenizer_kw: tokenizer},
)
trainer.train()


In [None]:
# Write the trained model (adapter) and the tokenizer into cfg.out_dir.
save_dir = Path(cfg.out_dir)
save_dir.mkdir(parents=True, exist_ok=True)
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained(cfg.out_dir)
print(f"Saved adapter to {cfg.out_dir}")

# Quick inference: reload the quantized base model and attach the saved LoRA
# adapter to it (the adapter is loaded on top of the base model, not merged
# into its weights).
# NOTE(review): `model`/`trainer` from the training cell are still referenced
# here, so two copies of the base model are presumably resident at once —
# confirm there is enough GPU memory, or free the training objects first.
from peft import PeftModel
base = AutoModelForCausalLM.from_pretrained(
    cfg.base_model,
    device_map=device_map,
    quantization_config=bnb_cfg,
    trust_remote_code=True,
)
adapted = PeftModel.from_pretrained(base, cfg.out_dir)
adapted.eval()

prompt = "hey how are you?"
# Training samples are formatted as "<prompt>\n<target><eos>", so terminate the
# prompt with the same newline separator before asking for the continuation.
inputs = tokenizer(f"{prompt}\n", return_tensors="pt").to(adapted.device)
with torch.no_grad():
    out = adapted.generate(**inputs, max_new_tokens=40, do_sample=True, temperature=0.7)
print(tokenizer.decode(out[0], skip_special_tokens=True))
