# Gemma 3 Fine-Tuning on SQuAD Dataset

This notebook demonstrates how to fine-tune the Gemma 3 4B model on the Stanford Question Answering Dataset (SQuAD).


## Introducing Gemma 3 4B-IT

**Gemma 3** is Google's latest addition to its family of lightweight, state-of-the-art open AI models, designed to deliver high performance while being resource-efficient. The **4B Instruct** version of **Gemma 3** is tailored for **instruction-based tasks**, offering developers an accessible and powerful tool for creating intelligent applications.  

Announcement: [Gemma 3 Blog Post](https://blog.google/technology/developers/gemma-3/)

Gemma 3 features a **transformer architecture** optimized with advanced techniques like **RoPE embeddings** and **GeGLU activations**, enabling sophisticated reasoning and text generation capabilities.

Key Features:
- **128K-token context window**: Allows processing and understanding of vast amounts of information.  
- **Multilingual support**: Over **140 languages**, ideal for global applications.  
- **Multimodal capabilities**: Supports text, images, and videos, enabling interactive AI solutions.  
- **Edge device optimization**: Efficiently runs on consumer hardware with a single GPU, making it accessible for developers with limited resources.

Resources:
- [Gemma 3 Model Overview](https://ai.google.dev/gemma/docs/core)  
- [Gemma 3 Technical Report](https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf)  
- [Gemma 3 Model Card](https://ai.google.dev/gemma/docs/core/model_card_3)

### Package Setup

In [1]:
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 294, done.
remote: Counting objects: 100% (177/177), done.
remote: Compressing objects: 100% (97/97), done.
remote: Total 294 (delta 127), reused 80 (delta 80), pack-reused 117 (from 1)
Receiving objects: 100% (294/294), 5.53 MiB | 19.66 MiB/s, done.
Resolving deltas: 100% (165/165), done.


In [3]:
# Install required packages
# !pip3 install -q -U bitsandbytes
# !pip3 install -q -U peft
!pip3 install -q -U trl
# !pip3 install -q -U accelerate
# !pip3 install -q -U datasets
# !pip3 install -q -U transformers


In [11]:
# Check Package Versions
# !pip freeze | grep bitsandbytes
# !pip freeze | grep peft
# !pip freeze | grep trl
# !pip freeze | grep accelerate
# !pip freeze | grep datasets
!pip freeze | grep transformers

sentence-transformers==3.4.1
transformers==4.51.1


In [4]:
# Suppress Warnings
import warnings
warnings.filterwarnings("ignore")

### Import Libraries

In [4]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import contextlib
import kagglehub
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 

import torch
import torch.nn as nn

import transformers
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer, 
                          BitsAndBytesConfig, 
                          TrainingArguments, 
                          pipeline, 
                          logging)

from transformers import Gemma3ForCausalLM  # used in section 1 to load the model

from gemma.config import get_model_config
from gemma.gemma3_model import Gemma3ForMultimodalLM

from datasets import Dataset
from peft import LoraConfig, PeftConfig, PeftModel
from trl import SFTTrainer, SFTConfig
# import bitsandbytes as bnb

### CUDA & GPU Checking

In [3]:
# Disable W&B logging for this run
import os
os.environ["WANDB_MODE"] = "disabled"

# Set Cuda Allocation
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use the first GPU
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenization parallelism


In [5]:
# Choose variant and machine type
VARIANT = '4b'
MACHINE_TYPE = 'cuda'
OUTPUT_LEN = 20
METHOD = 'it'

weights_dir = kagglehub.model_download(f"google/gemma-3/pytorch/gemma-3-{VARIANT}-{METHOD}/1")
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
ckpt_path = os.path.join(weights_dir, 'model.ckpt')

# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily set the default torch dtype, restoring float32 on exit."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the default even if model construction raises
        torch.set_default_dtype(torch.float)

# Instantiate the model
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)

# For loading the model weights (Inference) ----------------- For Reference
# device = torch.device(MACHINE_TYPE)
# with _set_default_tensor_type(model_config.get_dtype()):
#     model = Gemma3ForMultimodalLM(model_config)
#     model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])        
#     model = model.to(device).eval()

In [22]:
!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3 -q --no-cache

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for transformers (pyproject.toml) ... done


In [23]:
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

hf_token_name = "HF_TOKEN_EG"
hf_key = UserSecretsClient().get_secret(hf_token_name)
print(f"Successfully loaded {hf_token_name}!")

login(token = hf_key)
print(f"Login with {hf_token_name} complete!")

Successfully loaded HF_TOKEN_EG!
Login with HF_TOKEN_EG complete!


In [24]:
from transformers import AutoTokenizer
# Load Gemma 3‑27B‑IT’s tokenizer
MODEL_GEMMA = "google/gemma-3-27b-it"
gemma_tokenizer = AutoTokenizer.from_pretrained(MODEL_GEMMA, trust_remote_code=True)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [5]:
def define_device():
    """Determine and return the optimal PyTorch device based on availability."""
    
    print(f"PyTorch version: {torch.__version__}", end=" -- ")

    # Check if MPS (Metal Performance Shaders) is available for macOS
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        print("using MPS device on macOS")
        return torch.device("mps")

    # Check for CUDA availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using {device}")
    return device

## 1. Load Gemma 3 Model and Tokenizer

* **Compute dtype:**  
  * If the GPU supports **bfloat16** (Compute Capability **8.0+**), it is used for computations; otherwise **float16** is used.  

* **Device Selection:**  
  * The function `define_device()` selects the best available device (**CPU, CUDA, or MPS**).  

* **Model Initialization:**  
  * The model is loaded with memory-efficient configurations, including `low_cpu_mem_usage=True`, and assigned to the selected device.  

* **Tokenizer Setup:**  
  * A **tokenizer** is initialized with a **maximum sequence length of 1024**.  
  * The **end-of-sequence (EOS) token** is stored for later use.  

In [6]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/gemma-3/pytorch/gemma-3-4b-it/1/model.ckpt
/kaggle/input/gemma-3/pytorch/gemma-3-4b-it/1/tokenizer.model
/kaggle/input/stanford-question-answering-dataset/train-v1.1.json
/kaggle/input/stanford-question-answering-dataset/dev-v1.1.json


In [8]:
# Determine optimal computation dtype based on GPU capability
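# (the CUDA check avoids an error from get_device_capability() on machines without a GPU)
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)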
print(f"Using compute dtype {compute_dtype}")

# Select the best available device (CPU, CUDA, or MPS)
device = define_device()
print(f"Operating on {device}")

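# Path to the pre-trained model.
# NOTE: this Kaggle directory holds the gemma_pytorch checkpoint (only model.ckpt
# and tokenizer.model). from_pretrained() expects a Transformers-format checkpoint,
# which is why the call below raises the OSError shown in this cell's output.
# Point GEMMA_PATH at a Transformers-format Gemma 3 4B-IT checkpoint to proceed.
GEMMA_PATH = "/kaggle/input/gemma-3/pytorch/gemma-3-4b-it/1"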

# Load the model with optimized settings
model = Gemma3ForCausalLM.from_pretrained(
    GEMMA_PATH,
    torch_dtype=compute_dtype,
    attn_implementation="eager",
    low_cpu_mem_usage=True,
    device_map=device
)

# Define maximum sequence length for the tokenizer
max_seq_length = 1024

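# Load the tokenizer (it runs on CPU, so no device_map is needed;
# model_max_length is the tokenizer's name for the maximum sequence length)
tokenizer = AutoTokenizer.from_pretrained(
    GEMMA_PATH,
    model_max_length=max_seq_length,
)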

# Store the EOS token for later use
EOS_TOKEN = tokenizer.eos_token

Using compute dtype torch.float16
PyTorch version: 2.5.1+cu124 -- using cuda
Operating on cuda


OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /kaggle/input/gemma-3/pytorch/gemma-3-4b-it/1.
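
The `OSError` above names the weight files that `from_pretrained` looks for. A quick pre-flight check along the following lines can tell whether a directory looks like a Transformers-format checkpoint before attempting the load. This is only a sketch: the helper name `looks_like_hf_checkpoint` is hypothetical, the weight-file names are copied from the error message, and both the `config.json` requirement and the handling of sharded `*.safetensors` files are assumptions about the usual checkpoint layout.

```python
import os

# Weight-file names copied from the OSError message above
HF_WEIGHT_FILES = (
    "pytorch_model.bin",
    "model.safetensors",
    "tf_model.h5",
    "model.ckpt.index",
    "flax_model.msgpack",
)


def looks_like_hf_checkpoint(path):
    """Return True if `path` has a config.json and at least one recognizable weight file."""
    if not os.path.isdir(path):
        return False
    files = set(os.listdir(path))
    has_config = "config.json" in files  # assumption: standard Transformers layout
    has_weights = any(
        name in HF_WEIGHT_FILES or name.endswith(".safetensors")  # assumption: covers shards
        for name in files
    )
    return has_config and has_weights


print(looks_like_hf_checkpoint("/kaggle/input/gemma-3/pytorch/gemma-3-4b-it/1"))
```

For the directory listed earlier, which contains only `model.ckpt` and `tokenizer.model`, the check returns `False`.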

## 2. Load and Prepare SQuAD Dataset

In [None]:
# Load SQuAD dataset
from datasets import load_dataset

squad_dataset = load_dataset("squad")
print(squad_dataset)

In [None]:
# Format SQuAD examples for instruction fine-tuning
# We'll use a specific format tailored for Gemma 3's chat template

USER_CHAT_TEMPLATE = "<start_of_turn>user\nContext: {context}\n\nQuestion: {question}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{answer}<end_of_turn>\n"

def format_squad_example(example):
    user_prompt = USER_CHAT_TEMPLATE.format(
        context=example["context"],
        question=example["question"]
    )
    model_response = MODEL_CHAT_TEMPLATE.format(answer=example["answers"]["text"][0])
    return {
        "formatted_prompt": user_prompt + model_response,
        "input": user_prompt,
        "output": model_response
    }

# Apply formatting to the dataset
train_dataset = squad_dataset["train"].map(format_squad_example)
validation_dataset = squad_dataset["validation"].map(format_squad_example)

# Take a subset for faster experimentation
train_subset = train_dataset.select(range(1000))  # Adjust as needed
validation_subset = validation_dataset.select(range(100))  # Adjust as needed

print(f"Training examples: {len(train_subset)}")
print(f"Validation examples: {len(validation_subset)}")

# Display an example
print("\nExample of formatted data:")
print(train_subset[0]["formatted_prompt"])

## 3. Prepare Model for Fine-Tuning with PEFT

In the next cells, we set everything up for fine-tuning. We configure a **Supervised Fine-Tuning Trainer (SFTTrainer)** and train with **Parameter-Efficient Fine-Tuning (PEFT)**. PEFT trains only a small set of additional parameters (here, LoRA adapters) while keeping the pre-trained large language model (LLM) weights frozen, which significantly reduces compute and storage costs. Because the base weights stay fixed, PEFT also helps mitigate **catastrophic forgetting**, a common issue with full fine-tuning.

### PEFTConfig:
The `peft_config` object (a `LoraConfig` in this notebook) specifies the parameters for PEFT. The following are some of the most important parameters:

- **lora_alpha**: The scaling factor for the LoRA update; the low-rank update is multiplied by `lora_alpha / r`.
- **lora_dropout**: The dropout probability applied in the LoRA layers.
- **r**: The rank of the LoRA update matrices.
- **bias**: Which bias terms to train. Possible values are `none`, `all`, and `lora_only`.
- **task_type**: The task type the model is being trained for, for example `CAUSAL_LM`, `SEQ_2_SEQ_LM`, or `SEQ_CLS`.
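
To make `r` and `lora_alpha` concrete: LoRA adds a low-rank update to a frozen weight matrix and scales it by `lora_alpha / r`. The sketch below computes that scale and the number of trainable parameters for one adapted matrix. The helper names are hypothetical, and the 2560 x 2560 shape is only an illustrative square projection based on the Gemma 3 4B model dimension; the real projection shapes differ per module.

```python
def lora_scaling(lora_alpha, r):
    """Scale applied to the low-rank update in standard LoRA."""
    return lora_alpha / r


def lora_param_count(d_in, d_out, r):
    """Trainable parameters for one adapted matrix: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


# Values matching the LoraConfig used in this notebook
r, lora_alpha = 16, 32
# Illustrative assumption: a square projection at the model dimension
d_in = d_out = 2560

scale = lora_scaling(lora_alpha, r)
adapter_params = lora_param_count(d_in, d_out, r)
full_params = d_in * d_out

print(f"scale = {scale}")
print(f"adapter params = {adapter_params:,} vs full matrix = {full_params:,}")
print(f"fraction trained = {adapter_params / full_params:.2%}")
```

With these numbers the adapter trains 81,920 parameters against 6,553,600 in the full matrix, or 1.25% of it.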

### TrainingArguments:
The `training_arguments` object specifies the parameters for training the model. The following are some key parameters:

- **output_dir**: Directory where the training logs and checkpoints will be saved.
- **num_train_epochs**: Number of epochs to train the model for.
- **per_device_train_batch_size**: Number of samples in each batch on each device.
- **gradient_accumulation_steps**: Number of batches to accumulate gradients before updating the model parameters.
- **gradient_checkpointing**: Whether to use gradient checkpointing to reduce GPU memory usage.
- **optim**: The optimizer used for training the model.
- **save_steps**: The number of steps after which to save a checkpoint.
- **logging_steps**: The number of steps after which to log the training metrics.
- **learning_rate**: The learning rate for the optimizer.
- **weight_decay**: The weight decay parameter for the optimizer.
- **fp16**: Whether to use 16-bit floating-point precision.
- **bf16**: Whether to use BFloat16 precision.
- **max_grad_norm**: The maximum gradient norm.
- **max_steps**: The maximum number of steps to train the model for.
- **warmup_ratio**: Proportion of training steps to use for warming up the learning rate.
- **group_by_length**: Whether to group the training samples by length.
- **lr_scheduler_type**: The type of learning rate scheduler to use.
- **report_to**: The tools to report the training metrics to.
- **eval_strategy**: The strategy for evaluating the model during training (`no`, `steps`, or `epoch`); older Transformers releases call this `evaluation_strategy`.
- **eval_steps**: Number of update steps between evaluations.
- **eval_accumulation_steps**: Number of prediction steps to accumulate before moving the output to CPU.
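
Several of these arguments interact, so it can help to work out the step counts before launching a run. The sketch below is a back-of-envelope estimate, not the trainer's exact bookkeeping: the function name `training_schedule` is hypothetical, the rounding choices are assumptions, and the inputs are the values used later in this notebook (a 1,000-example subset, batch size 1, 8 accumulation steps, 3 epochs, `warmup_ratio=0.03`).

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, epochs,
                      warmup_ratio, num_devices=1):
    """Estimate optimizer-step counts for a run (rounding up is an assumption)."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


plan = training_schedule(
    num_examples=1000, per_device_batch=1, grad_accum=8, epochs=3, warmup_ratio=0.03
)
print(plan)
print("evaluations at eval_steps=112:", plan["total_steps"] // 112)
```

Under these assumptions the run takes 125 optimizer steps per epoch and 375 in total, so `eval_steps=112` and `save_steps=112` each trigger three times.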

### SFTTrainer:
The `SFTTrainer` is a trainer class from the **TRL** library. It extends the Transformers `Trainer` for supervised fine-tuning of language models and has built-in support for PEFT adapters.

The `SFTTrainer` object is initialized with the following arguments:

- **model**: The model to be trained.
- **train_dataset**: The training dataset.
- **eval_dataset**: The evaluation dataset.
- **peft_config**: The PEFT configuration.
- **processing_class**: The tokenizer used to prepare the data; older TRL releases call this argument `tokenizer`.
- **args**: The training arguments, here an `SFTConfig`.

The remaining SFT-specific options belong to the `SFTConfig` passed as `args`:

- **dataset_text_field**: The name of the text field in the dataset.
- **packing**: Whether to pack the training samples.
- **max_seq_length**: The maximum sequence length.

Once the `SFTTrainer` object is initialized, it can be used to train the model by calling the `train()` method.

In [None]:
# Apply LoRA configuration
peft_config = LoraConfig(
    task_type="CAUSAL_LM",  
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj",
                      "gate_proj", "up_proj", "down_proj"],
    r=16,                         
    lora_alpha=32,                
    lora_dropout=0.1,
    bias="none"
)

# Add LoRA adapters to the model
from peft import get_peft_model

model = get_peft_model(model, peft_config)

# Freeze all parameters except LoRA parameters
for name, param in model.named_parameters():
    if "lora" not in name:
        param.requires_grad = False  

## 4. Configure Training

In [None]:
# # Set training arguments
# training_args = transformers.TrainingArguments(
#     output_dir="./gemma3_squad_results",
#     eval_strategy="steps",
#     evaluation_strategy="steps",  # More explicit parameter name
#     per_device_train_batch_size=1,
#     gradient_accumulation_steps=4,
#     warmup_steps=2,
#     max_steps=100,  # Adjust based on available time/resources
#     learning_rate=2e-5,  # Slightly lower than with Gemma 2
#     fp16=True if torch_dtype == torch.float16 else False,
#     bf16=True if torch_dtype == torch.bfloat16 else False,
#     optim="paged_adamw_8bit",
#     save_strategy="steps",
#     save_steps=50,
#     eval_steps=25,
#     logging_dir="./logs",
#     logging_steps=10,
#     push_to_hub=False,
#     report_to="none",  # Disable reporting to wandb
#     run_name="gemma3-squad-finetune"
# )

training_arguments = SFTConfig(
    output_dir="logs",
    num_train_epochs=3,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Use non-reentrant checkpointing
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    optim="adamw_torch_fused",  # Use fused AdamW optimizer
    save_steps=112,
    load_best_model_at_end=True,
    logging_steps=25,
    learning_rate=2e-5,
    weight_decay=0.001,
    fp16=compute_dtype == torch.float16,  # Use float16 precision
    bf16=compute_dtype == torch.bfloat16,  # Use bfloat16 precision
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=False,
    eval_strategy="steps",  # named evaluation_strategy in older Transformers releases
    eval_steps=112,
    eval_accumulation_steps=1,
    lr_scheduler_type="constant",  # plain constant schedule: warmup_ratio has no effect here
    report_to="tensorboard",
    max_seq_length=max_seq_length,
    dataset_text_field="formatted_prompt",  # column produced by format_squad_example
    packing=False,
    dataset_kwargs={
        "add_special_tokens": False,  # Template with special tokens
        "append_concat_token": True,  # Add EOS token as separator token
    }
)

In [None]:
model.config.use_cache = False
model.config.pretraining_tp = 1

# Initialize SFT Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_subset,
    eval_dataset=validation_subset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)

## 5. Train the Model

In [None]:
# Start training
trainer.train()

In [None]:
# Save the fine-tuned model
model_save_path = "./gemma3_LoRA_squad_finetuned"
trainer.model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

## 6. Inference with the Fine-tuned Model

In [None]:
# Load the fine-tuned model for inference
from peft import PeftModel, PeftConfig
from transformers import Gemma3ForCausalLM

# Load the PEFT configuration
peft_config = PeftConfig.from_pretrained(model_save_path)

# Reload the base model in Transformers format. The adapter was trained on the
# Transformers model, so it is attached to that class rather than to the
# gemma_pytorch Gemma3ForMultimodalLM.
base_model = Gemma3ForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    torch_dtype=compute_dtype,
    attn_implementation="eager",
    low_cpu_mem_usage=True,
    device_map=device,
)

# Load the PEFT model
eval_model = PeftModel.from_pretrained(base_model, model_save_path)
eval_model = eval_model.eval()

In [None]:
# Function for question answering with the fine-tuned model
def answer_question(context, question, output_len=50):
    # End the prompt with the model-turn marker so generation starts at the answer
    prompt = (
        USER_CHAT_TEMPLATE.format(context=context, question=question)
        + "<start_of_turn>model\n"
    )

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate answer
    with torch.no_grad():
        outputs = eval_model.generate(
            **inputs,
            max_new_tokens=output_len,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is unreliable once special tokens have been stripped.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Keep only the text before any end-of-turn marker that survives decoding
    return answer.split("<end_of_turn>")[0].strip()

In [None]:
# Example SQuAD passages and questions for testing
examples = [
    {
        "context": "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.",
        "question": "Which NFL team won Super Bowl 50?",
        "reference_answer": "Denver Broncos"
    },
    {
        "context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm.",
        "question": "What is computational complexity theory a branch of?",
        "reference_answer": "theory of computation"
    },
    {
        "context": "Nikola Tesla (10 July 1856 – 7 January 1943) was a Serbian-American inventor, electrical engineer, mechanical engineer, and futurist best known for his contributions to the design of the modern alternating current (AC) electricity supply system. Born and raised in the Austrian Empire, Tesla studied engineering and physics in the 1870s without receiving a degree, gaining practical experience in the early 1880s working in telephony and at Continental Edison in the new electric power industry.",
        "question": "When was Nikola Tesla born?",
        "reference_answer": "10 July 1856"
    }
]

# Test the model on the examples
for idx, example in enumerate(examples):
    print(f"Example {idx+1}:")
    print(f"Context: {example['context'][:100]}...")
    print(f"Question: {example['question']}")
    print(f"Reference Answer: {example['reference_answer']}")
    
    model_answer = answer_question(example['context'], example['question'])
    print(f"Model Answer: {model_answer}")
    print("-" * 80)

## 7. Compare Gemma 2 vs Gemma 3 Performance

Now that we've fine-tuned Gemma 3 on the SQuAD dataset, let's analyze the differences in performance compared to Gemma 2.

In [None]:
# Load the previously fine-tuned Gemma 2 model (if available)
# Note: Adjust paths as needed
gemma2_path = "./results"  # Path to your Gemma 2 fine-tuned model

try:
    # Import necessary libraries for Gemma 2
    from transformers import AutoModelForCausalLM
    
    # Load Gemma 2 model and tokenizer
    gemma2_peft_config = PeftConfig.from_pretrained(gemma2_path)
    gemma2_base_model = AutoModelForCausalLM.from_pretrained(
        gemma2_peft_config.base_model_name_or_path,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    gemma2_model = PeftModel.from_pretrained(gemma2_base_model, gemma2_path)
    gemma2_tokenizer = AutoTokenizer.from_pretrained(gemma2_path)
    
    def gemma2_answer_question(context, question):
        prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
        
        inputs = gemma2_tokenizer(prompt, return_tensors="pt").to(gemma2_model.device)
        
        with torch.no_grad():
            outputs = gemma2_model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
        
        generated_text = gemma2_tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = generated_text[len(prompt):].strip()
        
        return answer
    
    print("Successfully loaded Gemma 2 model for comparison")
    has_gemma2 = True
except Exception as e:
    print(f"Couldn't load Gemma 2 model: {e}")
    has_gemma2 = False

In [None]:
if has_gemma2:
    # Compare Gemma 2 vs Gemma 3 on the examples
    print("\n==== COMPARISON: GEMMA 2 vs GEMMA 3 ====\n")
    
    for idx, example in enumerate(examples):
        print(f"Example {idx+1}:")
        print(f"Question: {example['question']}")
        print(f"Reference Answer: {example['reference_answer']}")
        
        # Get answers from both models
        gemma2_answer = gemma2_answer_question(example['context'], example['question'])
        gemma3_answer = answer_question(example['context'], example['question'])
        
        print(f"Gemma 2 Answer: {gemma2_answer}")
        print(f"Gemma 3 Answer: {gemma3_answer}")
        print("-" * 80)
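
Printed answers are hard to compare by eye, so it can help to score each prediction against its reference. The sketch below is modeled on the normalization used by the SQuAD v1.1 evaluation script (lowercasing, stripping punctuation and the articles a/an/the) and computes exact match and token-level F1. The function names are hypothetical and the predictions are made-up strings for illustration, not outputs of either model.

```python
import re
import string
from collections import Counter


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, reference):
    """True when the normalized strings are identical."""
    return normalize_answer(prediction) == normalize_answer(reference)


def f1_score(prediction, reference):
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Hypothetical predictions, for illustration only
reference = "Denver Broncos"
for prediction in ["The Denver Broncos", "Denver Broncos won the game", "Carolina Panthers"]:
    em = exact_match(prediction, reference)
    f1 = f1_score(prediction, reference)
    print(f"{prediction!r}: EM={em}, F1={f1:.2f}")
```

The same two functions could be applied to `gemma2_answer` and `gemma3_answer` inside the comparison loop to turn the side-by-side printout into per-example scores.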

## 8. Key Differences Between Gemma 2 and Gemma 3

Based on the model implementations and fine-tuning process, here are some key differences between Gemma 2 and Gemma 3:

1. **Vocabulary Size**: 
   - Gemma 2: 256,000 tokens
   - Gemma 3: 262,144 tokens (larger vocabulary)

2. **Architecture Changes**:
   - Gemma 3 includes multimodal capabilities with the `Gemma3ForMultimodalLM` class
   - Gemma 3 uses a different layer configuration (Gemma 3 4B has 34 layers vs. different configurations in Gemma 2)
   - QK normalization is enabled by default in Gemma 3

3. **Context Length**:
   - Gemma 2: 8,192 tokens
   - Gemma 3: 128K tokens for the 4B, 12B, and 27B models (32K for the 1B model), consistent with the 128K-token window listed in the introduction

4. **Attention Mechanism**:
   - Gemma 3 interleaves local sliding-window and global attention layers at a 5:1 ratio (Gemma 2 alternates them 1:1), with a shorter local window of 1,024 tokens
   - Attention window sizes in Gemma 3 4B: [1024, 1024, 1024, 1024, 1024, 32768]

5. **Model Dimensionality**:
   - Different model dimensions and hidden layer sizes
   - Gemma 3 4B has model_dim=2560 compared to Gemma 2 models

6. **Chat Template**:
   - Gemma 3 uses the `GEMMA_VLM` prompt wrapping style for multimodal capabilities

7. **Performance Expectations**:
   - Improved reasoning capabilities
   - Better handling of longer contexts
   - More robust performance on complex questions
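
The interleaved attention layout in item 4 can be made concrete with a small sketch. It assumes the six-entry window list repeats across the 34 layers, starting with the local layers, so that every sixth layer is global. The helper names are hypothetical, and the mask convention (a query attends to itself and the `window - 1` positions before it) is an assumption for illustration.

```python
def attention_layer_types(num_layers, pattern=6):
    """Label each layer 'local' or 'global'; every `pattern`-th layer is global."""
    return ["global" if (i + 1) % pattern == 0 else "local" for i in range(num_layers)]


def sliding_window_mask(seq_len, window):
    """Causal mask: query q may attend to key k when q - window < k <= q."""
    return [[(k <= q) and (q - k < window) for k in range(seq_len)]
            for q in range(seq_len)]


# 34 layers, as listed for Gemma 3 4B above
layers = attention_layer_types(34)
print(layers.count("local"), "local layers,", layers.count("global"), "global layers")
print("global layer indices:", [i for i, kind in enumerate(layers) if kind == "global"])

# Tiny mask for readability; the local layers use a window of 1024
for row in sliding_window_mask(seq_len=6, window=3):
    print("".join("x" if allowed else "." for allowed in row))
```

Under these assumptions, 29 of the 34 layers are local and 5 are global, which is where the memory savings on long contexts come from: only the global layers need keys and values for the full sequence.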