# Gemma 3 Fine-Tuning on SQuAD Dataset

This notebook demonstrates how to fine-tune the Gemma 3 4B model on the Stanford Question Answering Dataset (SQuAD).

In [None]:
# Install required packages
!pip install -q -U transformers datasets peft bitsandbytes accelerate trl immutabledict sentencepiece wandb
!git clone https://github.com/google/gemma_pytorch.git

In [None]:
# Turn off Weights & Biases logging for this run
import os
os.environ.update(WANDB_MODE="disabled")

In [None]:
import torch
import os
import kagglehub
import sys
import contextlib
from datasets import load_dataset
import transformers
from transformers import AutoTokenizer
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from trl import SFTTrainer

# Add gemma_pytorch to path
sys.path.append("/kaggle/working/gemma_pytorch/")

from gemma.config import get_model_config
from gemma.gemma3_model import Gemma3ForMultimodalLM

# Set a single GPU to use (to avoid multi-GPU issues).
# This has to happen BEFORE the first CUDA query below: the CUDA runtime reads
# CUDA_VISIBLE_DEVICES when it is initialised, so setting it afterwards has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Check CUDA capabilities of the (now single) visible device
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU capabilities: {torch.cuda.get_device_capability(0)}")

## 1. Load Gemma 3 Model and Tokenizer

In [None]:
# Which Gemma 3 checkpoint to use
VARIANT = '4b'
METHOD = 'it'  # instruction-tuned model

# Pick the device and numeric precision:
#   no CUDA                         -> cpu  / float32
#   CUDA, compute capability >= 8.x -> cuda / bfloat16
#   CUDA, older capability          -> cuda / float16
if not torch.cuda.is_available():
    MACHINE_TYPE, torch_dtype = 'cpu', torch.float32
else:
    MACHINE_TYPE = 'cuda'
    torch_dtype = (
        torch.bfloat16
        if torch.cuda.get_device_capability()[0] >= 8
        else torch.float16
    )

print(f"Using {MACHINE_TYPE} with {torch_dtype}")

In [None]:
# Download model weights
weights_dir = kagglehub.model_download(f"google/gemma-3/pytorch/gemma-3-{VARIANT}-{METHOD}/1")
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
ckpt_path = os.path.join(weights_dir, 'model.ckpt')

# Set up model config.
# The model dtype follows `torch_dtype` chosen in the precision cell, so the weights
# use the same precision as the fp16/bf16 flags handed to the trainer later on
# (previously the model was always float16 on GPU, even when bfloat16 had been selected).
model_config = get_model_config(VARIANT)
model_config.dtype = {
    torch.float32: "float32",
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
}[torch_dtype]
model_config.tokenizer = tokenizer_path

# Helper function to set default tensor type
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

In [None]:
# Load the base Gemma 3 model in the dtype configured on `model_config`.
# Note: this cell performs no quantization and adds no LoRA adapters; the
# adapters are attached in the PEFT section that follows.
device = torch.device(MACHINE_TYPE)

with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so loading the checkpoint file cannot execute arbitrary pickled code.
    model.load_state_dict(
        torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
    )

    # Initialize tokenizer
    # NOTE(review): `tokenizer_path` points at a single `tokenizer.model` file —
    # confirm AutoTokenizer accepts that in the installed transformers version.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

## 2. Prepare Model for Fine-Tuning with PEFT

In [None]:
# Get the model ready for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA settings
# Note: Target modules may need adjustment for Gemma 3 architecture
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=(
        ["q_proj", "k_proj", "v_proj", "o_proj"]      # attention projection names
        + ["gate_proj", "up_proj", "down_proj"]       # MLP projection names
    ),
)

# Wrap the model with the LoRA adapters and report how many parameters train
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

## 3. Load and Prepare SQuAD Dataset

In [None]:
# Fetch the SQuAD dataset and show its splits
squad_dataset = load_dataset(path="squad")
print(squad_dataset)

In [None]:
# Turn SQuAD records into chat-formatted text for instruction fine-tuning,
# using Gemma-style turn markers.

USER_CHAT_TEMPLATE = "<start_of_turn>user\nContext: {context}\n\nQuestion: {question}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{answer}<end_of_turn>\n"

def format_squad_example(example):
    """Render one SQuAD record as a user turn plus a model turn.

    Args:
        example: Mapping with "context", "question" and "answers" keys, where
            example["answers"]["text"] is a non-empty sequence; only its first
            entry is used as the target answer.

    Returns:
        dict with "input" (user turn), "output" (model turn) and
        "formatted_prompt" (the two concatenated).
    """
    prompt_text = USER_CHAT_TEMPLATE.format(
        context=example["context"], question=example["question"]
    )
    response_text = MODEL_CHAT_TEMPLATE.format(answer=example["answers"]["text"][0])
    return dict(
        formatted_prompt=prompt_text + response_text,
        input=prompt_text,
        output=response_text,
    )

# Run the formatter over both splits (train first, then validation)
train_dataset, validation_dataset = (
    squad_dataset[split_name].map(format_squad_example)
    for split_name in ("train", "validation")
)

# Keep only the leading rows of each split for faster experimentation
train_subset = train_dataset.select(range(1000))  # Adjust as needed
validation_subset = validation_dataset.select(range(100))  # Adjust as needed

print(f"Training examples: {len(train_subset)}")
print(f"Validation examples: {len(validation_subset)}")

# Preview the first formatted training example
print("\nExample of formatted data:")
print(train_subset[0]["formatted_prompt"])

## 4. Configure Training

In [None]:
# Set training arguments.
# Only `eval_strategy` is passed: `evaluation_strategy` is the deprecated former
# name of the same option, and supplying both names fails on transformers
# releases that accept only one of them (the install cell upgrades to the latest).
training_args = transformers.TrainingArguments(
    output_dir="./gemma3_squad_results",
    eval_strategy="steps",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    max_steps=100,  # Adjust based on available time/resources
    learning_rate=2e-5,  # Slightly lower than with Gemma 2
    # Mixed-precision flags follow the dtype picked in the precision cell;
    # both are False when running in float32 on CPU.
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,
    optim="paged_adamw_8bit",
    save_strategy="steps",
    save_steps=50,
    eval_steps=25,
    logging_dir="./logs",
    logging_steps=10,
    push_to_hub=False,
    report_to="none",  # Disable reporting to wandb
    run_name="gemma3-squad-finetune"
)

In [None]:
# Initialize SFT Trainer on the formatted SQuAD subsets.
# The text to train on is read from the "formatted_prompt" column produced by
# format_squad_example.
# NOTE(review): `model` has already been wrapped by get_peft_model earlier in the
# notebook and `peft_config` is passed here as well — confirm the installed TRL
# version tolerates receiving both.
# NOTE(review): `dataset_text_field`, `max_seq_length` and `tokenizer` are keyword
# names whose availability on SFTTrainer presumably depends on the TRL release;
# verify against the installed version.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_subset,
    eval_dataset=validation_subset,
    args=training_args,
    peft_config=peft_config,
    dataset_text_field="formatted_prompt",
    max_seq_length=512,  # Adjust based on your context length needs; presumably longer examples are truncated — TODO confirm
    tokenizer=tokenizer,
    packing=False  # Keep one example per sequence rather than packing several together
)

## 5. Train the Model

In [None]:
# Start training with the trainer built in the previous cell
# (the training arguments cap the run at max_steps=100).
trainer.train()

In [None]:
# Persist the fine-tuned model and the tokenizer side by side in one directory
model_save_path = "./gemma3_squad_finetuned"
trainer.model.save_pretrained(save_directory=model_save_path)
tokenizer.save_pretrained(save_directory=model_save_path)

## 6. Inference with the Fine-tuned Model

In [None]:
# Load the fine-tuned model for inference
from peft import PeftModel, PeftConfig

# Load the saved PEFT configuration. It is kept under its own name so that the
# LoraConfig stored in `peft_config` (used to build the trainer) is not overwritten.
eval_peft_config = PeftConfig.from_pretrained(model_save_path)

# Rebuild the base model and reload the original checkpoint weights
with _set_default_tensor_type(model_config.get_dtype()):
    eval_model = Gemma3ForMultimodalLM(model_config)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so loading the checkpoint file cannot execute arbitrary pickled code.
    eval_model.load_state_dict(
        torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
    )

# Attach the fine-tuned adapter weights saved in `model_save_path`,
# then move to the target device and switch to eval mode
eval_model = PeftModel.from_pretrained(eval_model, model_save_path)
eval_model = eval_model.to(device).eval()

In [None]:
# Function for question answering with the fine-tuned model
def answer_question(context, question, output_len=50):
    """Generate an answer to `question` from `context` with the fine-tuned model.

    Args:
        context: Passage text inserted into the user turn of the prompt.
        question: Question text inserted into the user turn of the prompt.
        output_len: Maximum number of new tokens to generate.

    Returns:
        The generated answer as a stripped string. If the generated text
        contains a "<start_of_turn>model" header, only the text between that
        header and the next "<end_of_turn>" is returned.
    """
    user_prompt = USER_CHAT_TEMPLATE.format(context=context, question=question)

    # Tokenize input and remember how many tokens belong to the prompt
    inputs = tokenizer(user_prompt, return_tensors="pt").to(device)
    prompt_token_count = inputs.input_ids.shape[-1]

    # Generate answer (sampling is enabled, so repeated calls may differ)
    with torch.no_grad():
        outputs = eval_model.generate(
            inputs.input_ids,
            max_new_tokens=output_len,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

    # Decode only the newly generated tokens. Cutting the decoded string at
    # len(user_prompt) is unreliable: decoding with skip_special_tokens=True
    # does not reproduce the prompt text character for character (special
    # markers are dropped), so a character offset lands in the wrong place.
    # Assumes generate() returns the prompt tokens followed by the
    # continuation, as the original character-based slicing also did.
    model_part = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    )

    # Remove the model chat template if present
    if "<start_of_turn>model\n" in model_part:
        answer = model_part.split("<start_of_turn>model\n")[1].split("<end_of_turn>")[0].strip()
    else:
        answer = model_part.strip()

    return answer

In [None]:
# Hand-picked SQuAD-style passages, each with a question and its reference answer
examples = [
    {
        "context": "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.",
        "question": "Which NFL team won Super Bowl 50?",
        "reference_answer": "Denver Broncos"
    },
    {
        "context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm.",
        "question": "What is computational complexity theory a branch of?",
        "reference_answer": "theory of computation"
    },
    {
        "context": "Nikola Tesla (10 July 1856 – 7 January 1943) was a Serbian-American inventor, electrical engineer, mechanical engineer, and futurist best known for his contributions to the design of the modern alternating current (AC) electricity supply system. Born and raised in the Austrian Empire, Tesla studied engineering and physics in the 1870s without receiving a degree, gaining practical experience in the early 1880s working in telephony and at Continental Edison in the new electric power industry.",
        "question": "When was Nikola Tesla born?",
        "reference_answer": "10 July 1856"
    }
]

# Ask the fine-tuned model each question and print its answer next to the reference
for number, example in enumerate(examples, start=1):
    print(f"Example {number}:")
    print(f"Context: {example['context'][:100]}...")
    print(f"Question: {example['question']}")
    print(f"Reference Answer: {example['reference_answer']}")

    model_answer = answer_question(example['context'], example['question'])
    print(f"Model Answer: {model_answer}")
    print("-" * 80)

## 7. Compare Gemma 2 vs Gemma 3 Performance

Now that we've fine-tuned Gemma 3 on the SQuAD dataset, let's analyze the differences in performance compared to Gemma 2.

In [None]:
# Load the previously fine-tuned Gemma 2 model (if available)
# Note: Adjust paths as needed
gemma2_path = "./results"  # Path to your Gemma 2 fine-tuned model

try:
    # Import necessary libraries for Gemma 2
    from transformers import AutoModelForCausalLM

    # Load Gemma 2 model and tokenizer
    gemma2_peft_config = PeftConfig.from_pretrained(gemma2_path)
    gemma2_base_model = AutoModelForCausalLM.from_pretrained(
        gemma2_peft_config.base_model_name_or_path,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    gemma2_model = PeftModel.from_pretrained(gemma2_base_model, gemma2_path)
    gemma2_tokenizer = AutoTokenizer.from_pretrained(gemma2_path)

    def gemma2_answer_question(context, question):
        """Answer `question` from `context` with the fine-tuned Gemma 2 model.

        Args:
            context: Passage text placed after "Context:" in the prompt.
            question: Question text placed after "Question:" in the prompt.

        Returns:
            The generated continuation (at most 50 new tokens), stripped of
            surrounding whitespace.
        """
        prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

        inputs = gemma2_tokenizer(prompt, return_tensors="pt").to(gemma2_model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = gemma2_model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

        # Decode only the newly generated tokens: slicing the decoded string at
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # text character for character.
        new_tokens = outputs[0][prompt_token_count:]
        answer = gemma2_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return answer

    print("Successfully loaded Gemma 2 model for comparison")
    has_gemma2 = True
except Exception as e:
    # Best-effort: the comparison is optional, so any load failure is reported
    # and the comparison cell is skipped via `has_gemma2`.
    print(f"Couldn't load Gemma 2 model: {e}")
    has_gemma2 = False

In [None]:
if has_gemma2:
    # Print both models' answers to every example, side by side with the reference
    print("\n==== COMPARISON: GEMMA 2 vs GEMMA 3 ====\n")

    for number, example in enumerate(examples, start=1):
        print(f"Example {number}:")
        print(f"Question: {example['question']}")
        print(f"Reference Answer: {example['reference_answer']}")

        # Query Gemma 2 first, then Gemma 3
        gemma2_answer = gemma2_answer_question(example['context'], example['question'])
        gemma3_answer = answer_question(example['context'], example['question'])

        print(f"Gemma 2 Answer: {gemma2_answer}")
        print(f"Gemma 3 Answer: {gemma3_answer}")
        print("-" * 80)

## 8. Key Differences Between Gemma 2 and Gemma 3

Based on the model implementations and fine-tuning process, here are some key differences between Gemma 2 and Gemma 3:

1. **Vocabulary Size**: 
   - Gemma 2: 256,000 tokens
   - Gemma 3: 262,144 tokens (larger vocabulary)

2. **Architecture Changes**:
   - Gemma 3 includes multimodal capabilities with the `Gemma3ForMultimodalLM` class
   - Gemma 3 uses a different layer configuration (Gemma 3 4B has 34 layers; for comparison, Gemma 2 2B has 26 and Gemma 2 9B has 42)
   - QK normalization is enabled by default in Gemma 3

3. **Context Length**:
   - Gemma 2: 8,192 tokens
   - Gemma 3: 32,768 tokens (4x longer context window)

4. **Attention Mechanism**:
   - Gemma 3 uses interleaved local/global attention with larger window sizes
   - Attention window sizes in Gemma 3 4B: [1024, 1024, 1024, 1024, 1024, 32768]

5. **Model Dimensionality**:
   - Different model dimensions and hidden layer sizes
   - Gemma 3 4B has model_dim=2560 (for comparison, Gemma 2 2B uses 2304 and Gemma 2 9B uses 3584)

6. **Chat Template**:
   - Gemma 3 uses the `GEMMA_VLM` prompt wrapping style for multimodal capabilities

7. **Performance Expectations**:
   - Improved reasoning capabilities
   - Better handling of longer contexts
   - More robust performance on complex questions