In [None]:
# import transformers
# import torch

# model_id = "mlabonne/Meta-Llama-3-8B"

# pipeline = transformers.pipeline(
#     "text-generation", model=model_id, model_kwargs={"dtype": torch.bfloat16}, device_map="auto"
# )
# pipeline("Hey how are you doing today?")


# Llama-3-8B Finetuning Tutorial

This notebook demonstrates how to finetune Llama-3-8B on a toy dataset using LoRA (Low-Rank Adaptation) for efficient training.


## Step 1: Import Libraries


In [1]:
import warnings

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset
from trl.trainer.sft_trainer import SFTTrainer

warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.8.0+cu128
CUDA available: True
GPU: NVIDIA GeForce RTX 4090
GPU Memory: 25.24 GB


## Step 2: Create a Toy Dataset

We'll create a simple instruction-following dataset with examples teaching the model to respond in a specific style.


In [2]:
# Create a toy dataset for finetuning
# In this example, we'll teach the model to be a helpful math tutor

toy_data = [
    {
        "instruction": "What is 15 + 27?",
        "response": "Let me help you solve this step by step:\n15 + 27 = 42\n\nTo break it down: 15 + 20 = 35, then 35 + 7 = 42."
    },
    {
        "instruction": "Explain what a prime number is.",
        "response": "A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself. For example, 2, 3, 5, 7, 11 are prime numbers. The number 4 is not prime because it can be divided by 2."
    },
    {
        "instruction": "How do I calculate the area of a circle?",
        "response": "To calculate the area of a circle, use the formula:\n\nArea = π × r²\n\nWhere r is the radius of the circle and π (pi) is approximately 3.14159. For example, if the radius is 5, the area would be π × 5² = π × 25 ≈ 78.54 square units."
    },
    {
        "instruction": "What is 8 × 7?",
        "response": "8 × 7 = 56\n\nYou can think of this as adding 8 seven times: 8 + 8 + 8 + 8 + 8 + 8 + 8 = 56"
    },
    {
        "instruction": "What is the Pythagorean theorem?",
        "response": "The Pythagorean theorem states that in a right-angled triangle, the square of the hypotenuse (the longest side) equals the sum of squares of the other two sides.\n\nFormula: a² + b² = c²\n\nWhere c is the hypotenuse and a and b are the other two sides."
    },
    {
        "instruction": "How do you find the mean of a set of numbers?",
        "response": "To find the mean (average) of a set of numbers:\n\n1. Add all the numbers together\n2. Divide by how many numbers there are\n\nExample: For [4, 8, 6, 5, 10]\nMean = (4 + 8 + 6 + 5 + 10) ÷ 5 = 33 ÷ 5 = 6.6"
    },
    {
        "instruction": "What is 100 - 37?",
        "response": "100 - 37 = 63\n\nOne way to solve this: 100 - 30 = 70, then 70 - 7 = 63"
    },
    {
        "instruction": "Explain what a fraction is.",
        "response": "A fraction represents a part of a whole. It consists of two numbers:\n\n- Numerator (top number): tells you how many parts you have\n- Denominator (bottom number): tells you how many parts make up the whole\n\nFor example, 3/4 means you have 3 parts out of 4 total parts."
    },
]

# Convert to Llama-3 chat format
def format_chat_template(example):
    """Format the data in Llama-3 instruction format"""
    text = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{example['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{example['response']}<|eot_id|>"""
    return {"text": text}

# Create dataset
dataset = Dataset.from_list(toy_data)
dataset = dataset.map(format_chat_template)

print(f"Dataset size: {len(dataset)} examples")
print("\nExample formatted prompt:")
print(dataset[0]['text'])


Map: 100%|██████████| 8/8 [00:00<00:00, 2153.13 examples/s]

Dataset size: 8 examples

Example formatted prompt:
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

What is 15 + 27?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Let me help you solve this step by step:
15 + 27 = 42

To break it down: 15 + 20 = 35, then 35 + 7 = 42.<|eot_id|>





## Step 3: Configure Model Loading with Quantization

We'll use 4-bit quantization to reduce memory usage and enable efficient training on a single GPU.


In [3]:
# Model configuration
model_id = "mlabonne/Meta-Llama-3-8B"  # You can also use "mlabonne/Meta-Llama-3-8B"

# 4-bit quantization config to reduce memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,
# )

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)



In [4]:
tokenizer.SPECIAL_TOKENS_ATTRIBUTES

['bos_token',
 'eos_token',
 'unk_token',
 'sep_token',
 'pad_token',
 'cls_token',
 'mask_token',
 'additional_special_tokens']

In [5]:
tokenizer.bos_token

'<|begin_of_text|>'

In [6]:
tokenizer.eos_token

'<|end_of_text|>'

In [4]:
# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

print(f"Model loaded successfully!")
print(f"Model dtype: {model.dtype}")


Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Model loaded successfully!
Model dtype: torch.float32


## Step 4: Configure LoRA

LoRA (Low-Rank Adaptation) allows efficient finetuning by only training a small number of additional parameters.


In [5]:
# LoRA configuration
peft_config = LoraConfig(
    r=16,                      # Rank of the low-rank matrices
    lora_alpha=32,             # Scaling factor
    lora_dropout=0.05,         # Dropout probability
    bias="none",               # Bias type
    task_type="CAUSAL_LM",     # Task type
    target_modules=[           # Which modules to apply LoRA to
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# Apply LoRA to the model
model = get_peft_model(model, peft_config)

# Print trainable parameters
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params:,} || "
        f"all params: {all_param:,} || "
        f"trainable%: {100 * trainable_params / all_param:.2f}%"
    )

print_trainable_parameters(model)


trainable params: 41,943,040 || all params: 4,582,543,360 || trainable%: 0.92%


## Step 5: Configure Training Arguments


In [6]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./llama3-finetuned",           # Output directory
    num_train_epochs=10,                         # Number of training epochs
    per_device_train_batch_size=1,              # Batch size per device
    gradient_accumulation_steps=4,              # Gradient accumulation steps
    learning_rate=2e-4,                         # Learning rate
    lr_scheduler_type="cosine",                 # Learning rate scheduler
    warmup_steps=10,                            # Warmup steps
    logging_steps=1,                            # Log every N steps
    save_strategy="epoch",                      # Save checkpoint every epoch
    optim="paged_adamw_8bit",                   # Optimizer
    fp16=False,                                 # Use FP16 precision
    bf16=True,                                  # Use BF16 precision
    max_grad_norm=0.3,                          # Max gradient norm
    report_to="none",                           # Don't report to any platform
)

print("Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")


Training configuration:
  Epochs: 10
  Batch size: 1
  Gradient accumulation: 4
  Effective batch size: 4
  Learning rate: 0.0002


## Step 6: Train the Model


In [7]:
# Create SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    args=training_args,
)

print("Starting training...")
print("=" * 50)

# Train the model
trainer.train()

print("=" * 50)
print("Training completed!")


Adding EOS to train dataset: 100%|██████████| 8/8 [00:00<00:00, 4191.16 examples/s]
Tokenizing train dataset: 100%|██████████| 8/8 [00:00<00:00, 1982.65 examples/s]
Truncating train dataset: 100%|██████████| 8/8 [00:00<00:00, 8174.04 examples/s]


Starting training...


Step,Training Loss
1,3.5384
2,2.5693
3,2.5916
4,3.1721
5,2.4317
6,2.4427
7,2.3395
8,1.5748
9,1.4558


KeyboardInterrupt: 

## Step 7: Save the Model


In [8]:
# Save the fine-tuned model (LoRA adapters only)
model.save_pretrained("./llama3-finetuned-lora")
tokenizer.save_pretrained("./llama3-finetuned-lora")

print("Model saved to ./llama3-finetuned-lora")


Model saved to ./llama3-finetuned-lora


## Step 8: Test the Fine-tuned Model


In [9]:
# Reload model properly for inference to fix dtype issues
from peft import PeftModel
from transformers import AutoModelForCausalLM


model_id = "mlabonne/Meta-Llama-3-8B"

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
# )

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

print("Reloading model for inference...")
inference_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
inference_model = PeftModel.from_pretrained(inference_model, "./llama3-finetuned/checkpoint-20")
inference_model.eval()
print("Model reloaded successfully!")

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Test the model with a new question
def test_model(question):
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    
    inputs = tokenizer(prompt, return_tensors="pt").to(inference_model.device)
    
    with torch.inference_mode():
        outputs = inference_model.generate(  # type: ignore
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract just the assistant's response
    response = response.split("assistant")[-1].strip()
    
    return response

# Test with a few questions
test_questions = [
    "What is 15 + 27?",
    "What is a prime number?",
    "How do you calculate the perimeter of a rectangle?",
]

print("Testing the fine-tuned model:\n")
for q in test_questions:
    print(f"Q: {q}")
    print(f"A: {test_model(q)}")
    print("-" * 80)


Reloading model for inference...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.85s/it]


Model reloaded successfully!
Testing the fine-tuned model:

Q: What is 15 + 27?
A: Let me help you solve this step by step:
15 + 27 = 42

To break it down: 15 + 20 = 35, then 35 + 7 = 42. Gi�://

To break it down: 15 + 20 = 35, then 35 + 7 = 42.・━・━
--------------------------------------------------------------------------------
Q: What is a prime number?
A: A prime number (or a prime) is a natural number greater than 1 that has no positive divisors other than 1 and itself. For example, 2, 3, 5, 7, 11 are prime numbers. The number 4 is not prime because it can be divided by 2.网刊
--------------------------------------------------------------------------------
Q: How do you calculate the perimeter of a rectangle?
A: To calculate the perimeter of a rectangle, use the formula:

Perimeter = 2 × (length + width)

For example, if the length is 5 and the width is 3, the perimeter would be:

Perimeter = 2 × (5 + 3) = 2 × 8 = 16

So the perimeter is 16.vinfos
------------------------------------