In [1]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model and Tokenizer
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

In [None]:
# QLoRA Parameters
# NF4 (NormalFloat), Double Quantization, and Bfloat16 compute dtype
# bfloat16 ensures numerical stability
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16, # Use bfloat16 for computation
)

In [None]:
# LoRA Parameters
# lora_alpha=16 is the scaling factor
LORA_CONFIG = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Training Arguments (Adjust these based on your application/data)
TRAINING_ARGS = TrainingArguments(
    output_dir="./qlora_results",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4, 
    optim="paged_adamw_8bit", # Uses Paged Optimizers for memory spike management
    logging_steps=10,
    learning_rate=2e-4,
    fp16=False,
    bf16=True, # Use bfloat16 training precision
    max_grad_norm=0.3, # Standard for QLoRA to prevent exploding gradients
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
)

In [None]:
print("Loading model and tokenizer...")

# Load model with 4-bit quantization configuration
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BNB_CONFIG,
    device_map="auto"
)

Loading model and tokenizer...


Loading checkpoint shards: 100%|██████████| 3/3 [00:15<00:00,  5.17s/it]


In [7]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token # Set padding token

In [None]:
# Enable gradient checkpointing and prepare for k-bit training
# This is a key memory-saving technique that recomputes intermediate activation layers
# during the backward pass instead of storing them, saving significant VRAM.
model = prepare_model_for_kbit_training(model)

In [None]:
# Inject LoRA adapters
model = get_peft_model(model, LORA_CONFIG)

In [10]:
# Print trainable parameter summary
model.print_trainable_parameters() 


trainable params: 167,772,160 || all params: 7,409,504,256 || trainable%: 2.2643


In [None]:

# Using a small, public instruction dataset for demonstration
dataset = load_dataset("yahma/alpaca-cleaned", split="train[:100]") # Use a small slice for quick testing


In [12]:
def formatting_func(example):
    text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}{tokenizer.eos_token}"
    return {"text": text}

In [13]:
# Apply formatting
dataset = dataset.map(formatting_func, remove_columns=['instruction', 'output', 'input'])


Map: 100%|██████████| 100/100 [00:00<00:00, 5926.00 examples/s]


In [14]:
print("Starting QLoRA fine-tuning...")

# Use the Hugging Face TRL SFTTrainer for simple fine-tuning
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=LORA_CONFIG,
    args=TRAINING_ARGS
    )

trainer.train()

Starting QLoRA fine-tuning...


Adding EOS to train dataset: 100%|██████████| 100/100 [00:00<00:00, 12345.99 examples/s]
Tokenizing train dataset: 100%|██████████| 100/100 [00:00<00:00, 1487.04 examples/s]
Truncating train dataset: 100%|██████████| 100/100 [00:00<00:00, 15206.12 examples/s]
  return fn(*args, **kwargs)


Step,Training Loss
10,1.2223
20,0.972
30,0.784
40,0.53
50,0.5739
60,0.2857
70,0.2862


TrainOutput(global_step=75, training_loss=0.6391031630833943, metrics={'train_runtime': 585.0486, 'train_samples_per_second': 0.513, 'train_steps_per_second': 0.128, 'total_flos': 2325459105792000.0, 'train_loss': 0.6391031630833943, 'entropy': 0.49892324656248094, 'num_tokens': 53250.0, 'mean_token_accuracy': 0.8992553889751435, 'epoch': 3.0})