# QLoRA Fine-Tuning for Policy Compliance

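Fine-tunes Llama 3.1 8B Instruct on policy-compliance question/answer data using QLoRA (a 4-bit NF4-quantized base model plus trainable LoRA adapters). Expected runtime: roughly 2–3 hours on a Colab T4 GPU. Requires a Hugging Face account with access to `meta-llama/Meta-Llama-3.1-8B-Instruct`.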

In [None]:
# Install the training stack: PyTorch, Hugging Face transformers/accelerate/datasets,
# bitsandbytes (4-bit quantization), peft (LoRA), trl, and sentencepiece.
!pip install -q torch transformers accelerate peft bitsandbytes trl datasets sentencepiece

In [None]:
import os

import torch

# Turn off Accelerate's automatic mixed precision; this run trains in fp32
# (bf16 and fp16 are also disabled in the TrainingArguments further down).
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'

# Report whether a GPU is visible and, if so, which one.
cuda_ok = torch.cuda.is_available()
print(f'CUDA: {cuda_ok}')
if cuda_ok:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# Authenticate with the Hugging Face Hub. Called without a token, login()
# prompts for one interactively in the notebook. Access is needed to download
# the meta-llama model loaded below (presumably a gated repo - confirm your
# account has accepted its terms).
from huggingface_hub import login
login()

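## Upload Training Data

Upload a JSON Lines (`.jsonl`) file: one JSON object per line with `question`/`answer` (or `instruction`/`output`) fields.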

In [None]:
from google.colab import files
import json

# Open the browser upload dialog. files.upload() returns {filename: bytes}
# and also writes each uploaded file into the current working directory.
uploaded = files.upload()

# Fail with a clear message instead of a bare IndexError when the dialog
# was cancelled or nothing was selected.
if not uploaded:
    raise RuntimeError('No file uploaded - re-run this cell and choose a JSONL training file.')
if len(uploaded) > 1:
    print(f'Multiple files uploaded; using the first: {list(uploaded)}')

DATA_FILE = next(iter(uploaded))
print(f'Uploaded: {DATA_FILE}')

## Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

MODEL_NAME = 'meta-llama/Meta-Llama-3.1-8B-Instruct'

# QLoRA quantization: 4-bit NF4 weights with double quantization.
# Compute dtype stays fp32 because mixed precision is disabled for this run.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Prefer a dedicated pad token. If pad == eos, DataCollatorForLanguageModeling
# (mlm=False) sets the label of EVERY eos position to -100, so the model is
# never trained to emit its end-of-turn token and will not stop generating.
# Llama 3.1's tokenizer ships '<|finetune_right_pad_id|>' for exactly this;
# fall back to eos only for tokenizers that lack it.
PAD_TOKEN = '<|finetune_right_pad_id|>'
if PAD_TOKEN in tokenizer.get_vocab():
    tokenizer.pad_token = PAD_TOKEN
else:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

print('Loading model...')
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map='auto',
    torch_dtype=torch.float32,
)
model.config.pad_token_id = tokenizer.pad_token_id
# The KV cache is not used in training and is incompatible with gradient
# checkpointing; disable it explicitly rather than relying on a runtime warning.
model.config.use_cache = False

# Casts non-quantized params to fp32, enables input gradients, and turns on
# gradient checkpointing, so no separate gradient_checkpointing_enable() call
# is needed.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Defensive: force any remaining bf16 params to fp32, since training runs with
# mixed precision disabled.
for param in model.parameters():
    if param.dtype == torch.bfloat16:
        param.data = param.data.to(torch.float32)

print('Model loaded!')

## Configure LoRA

In [None]:
# LoRA adapters on every attention and MLP projection.
# Rank 64 with alpha 128 gives a scaling of alpha/r = 2.
# Uses 5% dropout, and bias terms stay frozen.
lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    bias='none',
    target_modules=[
        'q_proj', 'k_proj', 'v_proj', 'o_proj',  # attention projections
        'gate_proj', 'up_proj', 'down_proj',     # MLP projections
    ],
)
model = get_peft_model(model, lora_config)

# Tally only parameters that will receive gradients (the LoRA weights).
trainable = 0
for p in model.parameters():
    if p.requires_grad:
        trainable += p.numel()
print(f'Trainable params: {trainable:,}')

## Prepare Data

In [None]:
from datasets import Dataset
import json

# Load the JSON Lines training file: one JSON object per non-blank line.
# Decode explicitly as UTF-8 and report the offending line on bad JSON.
data = []
with open(DATA_FILE, 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        if not line.strip():
            continue
        try:
            data.append(json.loads(line))
        except json.JSONDecodeError as err:
            raise ValueError(f'{DATA_FILE}: invalid JSON on line {line_no}: {err}') from err
print(f'Loaded {len(data)} examples')

# Llama 3 chat transcript written out by hand (system, user, assistant turns).
# Note: it already begins with <|begin_of_text|>; a tokenizer that also adds
# BOS will duplicate it.
PROMPT = '''<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a compliance assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

{q}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{a}<|eot_id|>'''


def _qa(ex):
    """Return (question, answer) from a record using either key convention.

    Accepts question/answer or instruction/output. Missing or null fields
    become '' (rather than the literal string 'None').
    """
    q = ex.get('question') or ex.get('instruction') or ''
    a = ex.get('answer') or ex.get('output') or ''
    return q, a


def fmt(ex):
    """Render one record as a full chat transcript under the 'text' key."""
    q, a = _qa(ex)
    return {'text': PROMPT.format(q=q, a=a)}


# Drop records with an empty question or answer. Training on them would teach
# the model to produce empty replies.
usable = [d for d in data if all(_qa(d))]
skipped = len(data) - len(usable)
if skipped:
    print(f'Skipped {skipped} records with an empty question or answer')
# train_test_split needs at least one row on each side of the split.
if len(usable) < 2:
    raise ValueError(f'Need at least 2 usable examples for a train/eval split, got {len(usable)}')

dataset = Dataset.from_list([fmt(d) for d in usable])
split = dataset.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = split['train'], split['test']
print(f'Train: {len(train_ds)}, Eval: {len(eval_ds)}')

## Train

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, default_data_collator
import gc

# Release memory left over from earlier cells before building the trainer.
gc.collect()
torch.cuda.empty_cache()

# Max tokens per example, kept short to fit T4 memory. Longer examples are
# truncated and lose their closing <|eot_id|>, so raise this if memory allows.
MAX_SEQ_LEN = 128


def tokenize(example):
    """Tokenize one formatted example into fixed-length, right-padded inputs with labels.

    - add_special_tokens=False: the formatted text already starts with
      <|begin_of_text|>. Letting the tokenizer prepend BOS again would give
      every example a double BOS.
    - Labels copy input_ids on real tokens and are -100 only where
      attention_mask is 0 (padding). The <|eot_id|> terminators are therefore
      trained even if the pad token is the same id as eos.
    """
    enc = tokenizer(
        example['text'],
        truncation=True,
        max_length=MAX_SEQ_LEN,
        padding='max_length',
        add_special_tokens=False,
    )
    enc['labels'] = [
        tok if keep else -100
        for tok, keep in zip(enc['input_ids'], enc['attention_mask'])
    ]
    return enc


train_tokenized = train_ds.map(tokenize, remove_columns=['text'])

# Labels are already built in tokenize(). default_data_collator only stacks the
# fixed-length lists into tensors; it does not re-derive labels from the pad id.
data_collator = default_data_collator

args = TrainingArguments(
    output_dir='./policy-llama',
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # effective batch size of 16
    num_train_epochs=3,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    lr_scheduler_type='cosine',
    optim='adamw_bnb_8bit',
    logging_steps=25,
    eval_strategy='no',  # eval_ds is prepared but not evaluated during training
    save_strategy='epoch',
    bf16=False,
    fp16=False,
    report_to='none',
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tokenized,
    data_collator=data_collator,
)
print('Ready to train!')

In [None]:
# Run the training loop. With save_strategy='epoch', a checkpoint is written
# under ./policy-llama/ at the end of each epoch.
trainer.train()

In [None]:
# Write the trained model and the tokenizer into one directory so it can be
# reloaded on its own. Because the model is PEFT-wrapped, save_model writes
# the LoRA adapter weights rather than a merged full model.
FINAL_DIR = './policy-llama/final'
for saver in (trainer.save_model, tokenizer.save_pretrained):
    saver(FINAL_DIR)
print('Saved!')

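## Download LoRA Adapter

The archive contains the trained LoRA adapter and tokenizer, not a merged model. Load it with PEFT on top of the base `meta-llama/Meta-Llama-3.1-8B-Instruct` model.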

In [None]:
# Bundle the saved adapter + tokenizer directory into a zip archive, then
# trigger a browser download via the Colab `files` helper imported in the
# upload cell.
!zip -r policy-llama.zip ./policy-llama/final
files.download('policy-llama.zip')