# Fine-tuning Llama-3-8B-Instruct with QLoRA

For this tutorial, we’ll fine-tune the Llama 3 8B-Instruct model on the ruslanmv/ai-medical-chatbot dataset, which contains 250k dialogues between patients and doctors. QLoRA (Quantized Low-Rank Adaptation) is a method for fine-tuning LLMs with much less GPU memory while sacrificing little performance.

QLoRA combines two main ideas:
1. Quantization (specifically 4-bit)
    - It loads the base model weights in 4-bit precision instead of 16- or 32-bit.
    - This cuts the memory needed for the weights roughly 4x compared with 16-bit (e.g., you can fine-tune LLaMA 13B on a single 24GB GPU).

2. LoRA (Low-Rank Adaptation)
    - Instead of updating all model weights, LoRA freezes them and learns a small low-rank update (two trainable matrices, often called "adapters") for selected layers.
    - Only these adapter parameters are trained, so training is fast and cheap, and the optimizer state stays small.
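
The LoRA idea in point 2 can be written as y = Wx + (alpha / r) * B(Ax), where W is frozen and only the low-rank factors A (r x d_in) and B (d_out x r) are trained; `r` and `alpha` correspond to the `r` and `lora_alpha` values in the LoRA config. The sketch below is a toy, pure-Python illustration with made-up numbers (the helper names `matvec` and `lora_forward` are ours), not how PEFT implements it. It shows one design detail: because B starts at zero, the adapted layer initially behaves exactly like the frozen one.

```python
# Toy, pure-Python sketch of a LoRA-adapted linear layer (illustration only).
# Frozen weight W (d_out x d_in) is never modified; the trainable update is
# the low-rank product B @ A, scaled by alpha / r.

def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * B (A x)."""
    base = matvec(W, x)                  # frozen path
    low_rank = matvec(B, matvec(A, x))   # A: r x d_in, B: d_out x r
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, low_rank)]

d_in, d_out, r, alpha = 4, 3, 2, 4
# Toy frozen weights: picks out the first three inputs
W = [[1.0 if i == j else 0.0 for j in range(d_in)] for i in range(d_out)]
# A is randomly initialized in practice; fixed values here for reproducibility
A = [[0.1 * (i + j) for j in range(d_in)] for i in range(r)]
# B starts at zero, so the adapter contributes nothing before training
B = [[0.0] * r for _ in range(d_out)]
x = [1.0, 2.0, 3.0, 4.0]

print(lora_forward(W, A, B, x, alpha, r))  # same as the frozen layer's output

# After "training" changes B, the output shifts while W stays untouched
B_trained = [[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(lora_forward(W, A, B_trained, x, alpha, r))
```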

In [1]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())

2.6.0+cu118
True


In [2]:
%pip show torch  # Check the installed PyTorch build; the "+cu118" suffix in the version shows which CUDA version it was compiled against

Name: torch
Version: 2.6.0+cu118
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /home/stefany/interview_Material/interviews/interview_env/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-runtime-cu11, nvidia-cudnn-cu11, nvidia-cufft-cu11, nvidia-curand-cu11, nvidia-cusolver-cu11, nvidia-cusparse-cu11, nvidia-nccl-cu11, nvidia-nvtx-cu11, sympy, triton, typing-extensions
Required-by: accelerate, auto_gptq, bitsandbytes, peft, torchaudio, torchvision
Note: you may need to restart the kernel to use updated packages.
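
`pip show` reports the version string, and the `+cu118` local tag in it encodes the CUDA version the wheel was built against. The NVIDIA driver must support at least that CUDA version (the header of `nvidia-smi` shows the highest version the driver supports). A small sketch that extracts the tag; `cuda_from_torch_version` is a hypothetical helper of ours, and with PyTorch installed `torch.version.cuda` gives the same information directly (it is `None` for CPU-only builds).

```python
# Hypothetical helper: read the CUDA build tag from a PyTorch version string
# such as "2.6.0+cu118". Returns None when there is no "cuXYZ" local tag.

def cuda_from_torch_version(version: str):
    """Return e.g. '11.8' for '2.6.0+cu118', or None for CPU/untagged builds."""
    _, _, local = version.partition("+")
    if not local.startswith("cu") or not local[2:].isdigit():
        return None
    digits = local[2:]
    # The last digit is the minor version: "118" -> "11.8", "124" -> "12.4"
    return f"{digits[:-1]}.{digits[-1]}"

print(cuda_from_torch_version("2.6.0+cu118"))
```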


### Validate setup
Let's start by validating the setup to make sure the model, the GPU, and the libraries are configured correctly. Here’s what the following code does:
- Confirms that CUDA is available and functioning.
- Verifies that the GPU is correctly detected.
- Ensures the model and tokenizer load without errors.

In [3]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig



In [5]:
# Check GPU availability
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Data type and attention implementation
torch_dtype = torch.float16
attn_implementation = "eager"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)
print("QLoRA setup loaded successfully!")

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
print("LLaMA 3 loaded successfully!")

CUDA Available: True
GPU: NVIDIA RTX A5500 Laptop GPU
QLoRA setup loaded successfully!


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.16s/it]

LLaMA 3 loaded successfully!
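
Why 4-bit loading makes an 8B model fit on a laptop GPU: a rough estimate of weight memory alone at different precisions. The sketch assumes 8.0 billion parameters all stored at the same precision, and NF4 block sizes of 64 (weights) and 256 (second-level constants) as described in the QLoRA paper. It ignores activations, LoRA weights, optimizer state, and layers that stay in 16-bit, so real usage is higher. For the actual figure, `model.get_memory_footprint()` in transformers returns the model's parameter and buffer size in bytes.

```python
# Rough weight-memory estimate for an ~8B-parameter model (illustration only).
# Assumptions: 8.0e9 parameters, every weight stored at the same precision,
# NF4 blocks of 64 weights with one absmax constant each.

PARAMS = 8.0e9
GB = 1e9  # decimal gigabytes

def weight_gb(bits_per_param: float) -> float:
    """Memory for all weights at the given average bits per parameter."""
    return PARAMS * bits_per_param / 8 / GB

# Plain NF4: one FP32 absmax per 64 weights -> 32 / 64 = 0.5 extra bits per weight.
overhead_single = 32 / 64
# Double quantization: absmax constants stored in 8 bits, plus one FP32
# second-level constant per 256 first-level blocks -> 8/64 + 32/(64*256) bits.
overhead_double = 8 / 64 + 32 / (64 * 256)

estimates = {
    "fp32": weight_gb(32),
    "fp16/bf16": weight_gb(16),
    "nf4": weight_gb(4 + overhead_single),
    "nf4 + double quant": weight_gb(4 + overhead_double),
}
for label, gb in estimates.items():
    print(f"{label:>20}: {gb:5.2f} GB")
```

Under these assumptions the weights drop from about 16 GB in 16-bit to a little over 4 GB, and double quantization (`bnb_4bit_use_double_quant=True`) trims the per-weight overhead from 0.5 to about 0.127 bits.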





In [7]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [8]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)
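
How small is the adapter? Each adapted linear layer with input width d_in and output width d_out gets r * (d_in + d_out) new parameters (A is r x d_in, B is d_out x r). The sketch below applies that to the seven `target_modules` with `r=16`, assuming the published Meta-Llama-3-8B dimensions (hidden size 4096, MLP size 14336, 32 layers, 8 key/value heads of width 128 under grouped-query attention). On the wrapped model, PEFT's `model.print_trainable_parameters()` reports the actual count.

```python
# LoRA parameter count for the target_modules above (illustration only).
# Assumed model dimensions, taken from the published Llama 3 8B config.
hidden, inter, layers = 4096, 14336, 32
kv_dim = 8 * 128  # k_proj / v_proj output width: 8 KV heads x head_dim 128
r = 16

# (in_features, out_features) of each adapted linear layer
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, inter),
    "up_proj": (hidden, inter),
    "down_proj": (inter, hidden),
}

# LoRA adds r * (in + out) parameters per adapted module
per_layer = sum(r * (i + o) for i, o in shapes.values())
lora_params = per_layer * layers
# Size of the frozen weights those adapters sit next to
frozen_params = sum(i * o for i, o in shapes.values()) * layers

print(f"LoRA params per layer: {per_layer:,}")
print(f"LoRA params total:     {lora_params:,}")
print(f"Adapted frozen weights: {frozen_params:,}")
```

Under these assumptions this comes to about 42M trainable parameters, roughly half a percent of the model's ~8B total.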

In [9]:
# Load the dataset
dataset = load_dataset("ruslanmv/ai-medical-chatbot", split="all")
dataset = dataset.shuffle(seed=65).select(range(1000)) # Only use 1000 samples
print(dataset[1])  # Inspect one sample (index 1, i.e. the second row)

def format_chat_template(row):
    row_json = [{"role": "user", "content": row["Patient"]},
                {"role": "assistant", "content": row["Doctor"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc=4,
)
print(dataset[1])  # Check that the chat template was applied correctly

{'Description': 'What causes blood in urine?', 'Patient': "Dr, My daughter is 5yrs old.i saw stains in her trousers a few days ago.the stains were light red colour.later i found some pus like liquid near her urinary tract.yesterday i saw light brick coloured liquid along her urine.feeling panic,gave  urine for culter and routine test,culter result not yet recd.her routine test say,pus cells:4-8 and epithiall cells :2-4.What's wrong with my daughter? (she goes to urine only 4 to 5 times a day,drinking water too not sufficient)", 'Doctor': 'Thanks for contacting HCMYou are concerned that your daughter may have a urinary tract infection. Your description of her urine and findings in her panties does suggest urinary tract infection. The urine analysis though is not very convincing for a urinary tract infection. The sample shows 2-4 epithelial cells and only 4-8 puss cells. The counts are normal and do not indicate infection. I recommend you wait for the culture results. I would recommend t
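
The printed sample above is cut off, so here is roughly what the `text` field produced by `apply_chat_template(..., tokenize=False)` should look like. This is a hypothetical, simplified stand-in for the Llama 3 Instruct chat template, assuming the template shipped with the model wraps each turn in `<|start_header_id|>role<|end_header_id|>` headers, strips the content, ends each turn with `<|eot_id|>`, and prepends `<|begin_of_text|>` once. `render_llama3` is our own helper and the dialogue is invented. Since the text already starts with `<|begin_of_text|>`, it is worth checking that the tokenized training sequences don't end up with two of them.

```python
# Hypothetical, simplified stand-in for the Llama 3 Instruct chat template
# (illustration only; the real template is a Jinja string stored with the tokenizer).
BOS = "<|begin_of_text|>"

def render_llama3(messages):
    """Render a list of {"role", "content"} dicts the way we assume Llama 3 does."""
    text = BOS  # prepended once, before the first turn
    for m in messages:
        text += (
            f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
            f"{m['content'].strip()}<|eot_id|>"
        )
    return text

# Placeholder dialogue (invented, not from the dataset)
example = [
    {"role": "user", "content": "I have had a headache for two days."},
    {"role": "assistant", "content": "How severe is it, and do you have a fever?"},
]
print(render_llama3(example))
```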

In [10]:
dataset = dataset.train_test_split(test_size=0.1)  # Split the dataset into training and validation sets
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor', 'text'],
        num_rows: 900
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor', 'text'],
        num_rows: 100
    })
})


In [11]:
new_model = "llama-3-8b-chat-doctor"

sft_config = SFTConfig(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    max_seq_length=512,  
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=100,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    # report_to="wandb"
)
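
With `per_device_train_batch_size=1` and `gradient_accumulation_steps=2`, each optimizer step sees 2 examples, so one epoch over the 900 training rows is 450 steps, which matches `global_step=450` in the training output, and `eval_steps=100` gives evaluations at steps 100 to 400. The sketch below reproduces that arithmetic and the learning rate implied by `warmup_steps=10`, assuming a single GPU and the default linear scheduler (`lr_scheduler_type` is not set here). `linear_lr` is our own helper, not a library function.

```python
# Rough step and learning-rate arithmetic for the config above (single GPU assumed).
import math

train_rows = 900          # size of the training split
per_device_batch = 1
grad_accum = 2
epochs = 1
eval_steps = 100
warmup_steps = 10
peak_lr = 2e-4

effective_batch = per_device_batch * grad_accum
# Optimizer updates per epoch: dataloader batches // gradient accumulation steps
steps_per_epoch = max(math.ceil(train_rows / per_device_batch) // grad_accum, 1)
total_steps = steps_per_epoch * epochs
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))

def linear_lr(step):
    """Linear warmup to peak_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

print("effective batch:", effective_batch)
print("total steps:", total_steps, "eval at:", eval_points)
print([linear_lr(s) for s in (0, 5, 10, 230, 450)])
```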

In [12]:
tokenizer.pad_token = tokenizer.eos_token
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    processing_class=tokenizer, 
    args=sft_config,
)

Converting train dataset to ChatML: 100%|██████████████████████████████████████████████████| 900/900 [00:00<00:00, 22011.82 examples/s]
Applying chat template to train dataset: 100%|█████████████████████████████████████████████| 900/900 [00:00<00:00, 37988.06 examples/s]
Tokenizing train dataset: 100%|█████████████████████████████████████████████████████████████| 900/900 [00:00<00:00, 2468.91 examples/s]
Truncating train dataset: 100%|███████████████████████████████████████████████████████████| 900/900 [00:00<00:00, 123422.38 examples/s]
Converting eval dataset to ChatML: 100%|███████████████████████████████████████████████████| 100/100 [00:00<00:00, 15781.71 examples/s]
Applying chat template to eval dataset: 100%|██████████████████████████████████████████████| 100/100 [00:00<00:00, 20921.31 examples/s]
Tokenizing eval dataset: 100%|██████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 2169.03 examples/s]
Truncating eval dataset: 100%|██████████████████

In [13]:
# Release cached GPU memory left over from a previous (e.g. failed) run
torch.cuda.empty_cache()
trainer.train()

| Step | Training Loss | Validation Loss |
|-----:|--------------:|----------------:|
| 100  | 2.2513        | 2.526964        |
| 200  | 2.5343        | 2.475965        |
| 300  | 2.2869        | 2.436862        |
| 400  | 2.7472        | 2.415262        |


TrainOutput(global_step=450, training_loss=2.5127355739805433, metrics={'train_runtime': 835.7442, 'train_samples_per_second': 1.077, 'train_steps_per_second': 0.538, 'total_flos': 9315430879100928.0, 'train_loss': 2.5127355739805433})
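
The validation loss falls steadily from about 2.53 to 2.42 over the run. Because the reported loss is a mean token-level cross-entropy (in nats), exponentiating it gives perplexity, which can be easier to reason about. A quick sketch using the validation losses from the table above; treat the result as rough, since the eval set has only 100 examples, and note that the training-loss column is logged every step (`logging_steps=1`), so individual values are noisy.

```python
import math

# Validation losses from the training table above
val_loss = {100: 2.526964, 200: 2.475965, 300: 2.436862, 400: 2.415262}

for step, loss in val_loss.items():
    # Perplexity = exp(mean cross-entropy per token)
    print(f"step {step}: loss {loss:.3f} -> perplexity {math.exp(loss):.2f}")

drop = (val_loss[100] - val_loss[400]) / val_loss[100]
print(f"relative drop in validation loss: {drop:.1%}")
```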