# Fine-tuning Gemma-3-1b on GSM8K Dataset

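This notebook demonstrates how to fine-tune the instruction-tuned Gemma 3 1B model (`google/gemma-3-1b-it`) on the GSM8K grade-school math dataset. It uses TRL's `SFTTrainer` with LoRA (Low-Rank Adaptation), so only a small set of adapter weights is trained instead of the full model.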

In [None]:
# Install required libraries
!pip install -qqq "transformers>=4.55.0" "trl>=0.22.1" "datasets" "torch"
!pip install -qqq "accelerate" "peft" "huggingface_hub"

In [None]:
# Import required libraries
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from peft import LoraConfig

In [None]:
# Load the instruction-tuned Gemma 3 1B checkpoint (the "-it" suffix) and its tokenizer.
# This is NOT the base/pretrained-only model; that checkpoint is "google/gemma-3-1b-pt".
# A single constant keeps the model and tokenizer from drifting apart.
MODEL_ID = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
# Load and format the GSM8K dataset
dataset = load_dataset("openai/gsm8k", "main", split="train")

# Defined once at module level (instead of on every call) so the exact same
# prompt can be reused when querying the fine-tuned model later.
SYSTEM_INSTRUCTION = """You are a highly logical and analytical problem-solving engine. When presented with a complex math word problem, your primary objective is to generate a comprehensive, step-by-step thinking process. Each step must clearly state the calculation performed and the resulting intermediate value. Ensure the final answer is extracted from the solution steps and is the last output."""


def format_to_messages(example: dict) -> dict:
    """Convert one GSM8K row into a conversational ``messages`` record.

    Args:
        example: A dataset row with ``"question"`` and ``"answer"`` string fields.

    Returns:
        A dict with a single ``"messages"`` key: the system prompt, the question
        as the user turn, and the reference solution as the assistant turn.
    """
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_INSTRUCTION},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
    }


# Drop the raw "question"/"answer" columns so the trainer only receives the
# conversational "messages" column.
gsm8k_formatted = dataset.map(format_to_messages, remove_columns=dataset.column_names)
assert gsm8k_formatted.column_names == ["messages"], gsm8k_formatted.column_names

In [None]:
# Define LoRA and SFT configurations
# LoRA: low-rank adapters on every linear layer; only these adapter weights are trained.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,  # adapter scaling = lora_alpha / r = 2
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)

sft_config = SFTConfig(
    output_dir="./gemma-lora-demo",
    # The GSM8K train split has ~7.5k problems. Running 20 epochs over it at
    # lr=2e-4 mostly memorises the reference solutions and takes a very long
    # time; a few epochs are enough.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,  # effective batch size of 4 per device
    learning_rate=2e-4,
    logging_steps=10,
    # No intermediate checkpoints are written; persist the final adapter
    # explicitly after training (e.g. with trainer.save_model()).
    save_strategy="no",
    report_to="none",
    # Packing concatenates several examples into one sequence. TRL only keeps
    # packed examples from attending to each other when flash-attention is used,
    # and the model is loaded with the default attention implementation, so
    # train unpacked to avoid cross-example contamination.
    packing=False,
)

In [None]:
# Build the supervised fine-tuning trainer from the pieces prepared above:
# the conversational GSM8K data, the model with its tokenizer, the LoRA adapter
# config (the trainer attaches the adapter to the model), and the training
# hyper-parameters.
trainer = SFTTrainer(
    model=model,
    train_dataset=gsm8k_formatted,
    processing_class=tokenizer,
    peft_config=peft_config,
    args=sft_config,
)

In [None]:
# Train the model, then write the result to the trainer's output_dir.
# save_strategy="no" in the SFTConfig means no checkpoint is written during
# training, so without this explicit save the fine-tuned adapter would be lost
# when the kernel stops. With processing_class set, the tokenizer is saved too.
train_result = trainer.train()
trainer.save_model()
train_result.metrics