In [1]:
import sys
import os

# gemma_pytorch is used from a local clone (see the git clone cell), not a pip
# install. Register its absolute path exactly once so re-running this cell does
# not pile up duplicate sys.path entries and a later os.chdir() cannot break the
# `gemma.*` imports.
GEMMA_PYTORCH_DIR = os.path.abspath('gemma_pytorch')
if GEMMA_PYTORCH_DIR not in sys.path:
    sys.path.append(GEMMA_PYTORCH_DIR)

In [2]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '9b', '9b-it', '27b', '27b-it']
MACHINE_TYPE = 'cpu' #@param ['cuda', 'cpu']


def gemma_config_name(variant):
    """Map a Kaggle Gemma 2 variant name to its gemma_pytorch config name.

    The model size is the part before the first '-' ('27b-it' -> '27b'). Taking
    the first two characters instead would wrongly turn '27b' into '27'. The 2B
    size maps to '2b-v2' (presumably because plain '2b' names the original
    Gemma 1 config in gemma_pytorch -- confirm against gemma/config.py).

    Args:
        variant: One of the VARIANT choices above, e.g. '2b-it' or '27b'.

    Returns:
        The config name to pass to ``get_model_config``.
    """
    size = variant.split('-')[0]
    return '2b-v2' if size == '2b' else size


CONFIG = gemma_config_name(VARIANT)

In [3]:
import torch
import os
from datasets import load_dataset
from evaluate import load  # Use evaluate instead of load_metric
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
from gemma.config import get_model_config
import kagglehub


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Download weights directory from Kaggle (kagglehub caches it locally).
try:
    weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
except Exception as e:
    # Chain the original error so its type and traceback remain visible.
    raise RuntimeError(f"Failed to download model weights: {e}") from e

# Kept outside the try so only the download itself is reported as a download failure.
print(f"Downloaded weights to: {weights_dir}")

Downloaded weights to: /Users/uochuba/.cache/kagglehub/models/google/gemma-2/pyTorch/gemma-2-2b-it/1


In [None]:
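# NOTE: gemma_pytorch is "installed" by cloning its repo next to this notebook; run the clone once, before the import cells above, so sys.path can find the `gemma` package.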
# !git clone https://github.com/google/gemma_pytorch.git

In [5]:
def _require_file(path, message):
    """Raise FileNotFoundError(message) unless ``path`` is an existing regular file."""
    if os.path.isfile(path):
        return
    raise FileNotFoundError(message)


# The model cannot be built without both artifacts from the downloaded weights
# directory; check each one right after resolving its path (tokenizer first).
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
_require_file(tokenizer_path, f"Tokenizer not found at: {tokenizer_path}")

ckpt_path = os.path.join(weights_dir, 'model.ckpt')
_require_file(ckpt_path, f"PyTorch checkpoint not found at: {ckpt_path}")

print("Tokenizer and checkpoint files verified.")

Tokenizer and checkpoint files verified.


In [6]:
# Load SQuAD dataset
dataset = load_dataset("squad")

# Build the gemma_pytorch config from the variant selected in the config cell
# instead of hard-coding "2b-v2", so choosing another size loads that model.
MODEL_VARIANT = CONFIG
model_config = get_model_config(MODEL_VARIANT)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Respect MACHINE_TYPE; fall back to CPU when CUDA was requested but is unavailable.
use_cuda = MACHINE_TYPE == "cuda" and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Build and load the model in its configured dtype, then restore the previous
# global default dtype so tensors created later in this kernel are unaffected
# (and re-running this cell starts from the same state).
previous_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(model_config.get_dtype())
try:
    model = GemmaForCausalLM(model_config)
    model.load_weights(ckpt_path)
finally:
    torch.set_default_dtype(previous_default_dtype)

model.to(device).eval()

GemmaForCausalLM(
  (embedder): Embedding()
  (model): GemmaModel(
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear()
          (o_proj): Linear()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear()
          (up_proj): Linear()
          (down_proj): Linear()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (pre_feedforward_layernorm): RMSNorm()
        (post_feedforward_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (sampler): Sampler()
)

In [None]:
# Helper functions
def format_prompt(context, question):
    """Build a single-turn Gemma chat prompt that asks for a concise answer.

    The user turn carries the instruction, the passage and the question; the
    string ends with an open model turn so generation continues as the answer.

    Args:
        context: Passage the answer should be drawn from.
        question: Question to ask about ``context``.

    Returns:
        The formatted prompt string.
    """
    user_turn = (
        "<start_of_turn>user\n"
        f"Answer as concisely as possible. Context: {context}\n"
        f"Question: {question}<end_of_turn><eos>\n"
    )
    model_turn_opening = "<start_of_turn>model\n"
    return user_turn + model_turn_opening

def generate_answer(model, prompt, max_length=128, device="cuda", temperature=None):
    """Generate an answer for ``prompt`` and strip Gemma chat-template markers.

    Args:
        model: Object exposing ``generate(prompt, device=..., output_len=...,
            temperature=...)``, here a gemma_pytorch ``GemmaForCausalLM``.
        prompt: Fully formatted prompt string (see ``format_prompt``).
        max_length: Maximum number of new tokens, forwarded as ``output_len``.
        device: Device forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``. The
            default ``None`` requests greedy decoding in gemma_pytorch (verify
            against the installed version), so evaluation runs are reproducible;
            the library's own default samples at temperature 1.0.

    Returns:
        The text of the model's first turn with surrounding whitespace removed.
    """
    outputs = model.generate(
        prompt, device=device, output_len=max_length, temperature=temperature
    )
    # Keep only the first model turn: cut at <end_of_turn>, and drop any echoed
    # prefix up to (and including) the model-turn opener if one is present.
    answer = outputs.split("<end_of_turn>")[0].split("<start_of_turn>model\n")[-1]
    return answer.strip()

def evaluate_squad(model, dataset, device="cuda", num_examples=3):
    """
    Evaluate Gemma on SQuAD using Exact Match (EM) and F1 metrics.

    Runs the first ``num_examples`` questions of ``dataset["validation"]``
    through ``format_prompt`` / ``generate_answer``. It prints each context,
    question, reference answers and prediction, then scores all predictions
    with the ``evaluate`` SQuAD metric and prints the aggregate scores.

    Args:
        model: Model forwarded to ``generate_answer``.
        dataset: Object whose ``"validation"`` split supports ``len`` and
            ``.select`` and yields SQuAD-style records (``context``,
            ``question``, ``answers``; ``id`` is used when present), e.g. the
            result of ``load_dataset("squad")``.
        device: Device forwarded to ``generate_answer``.
        num_examples: Number of validation examples to score, capped at the
            size of the split. Defaults to 3, the previously hard-coded value.

    Returns:
        dict: The metric's scores, with ``"exact_match"`` and ``"f1"`` keys
        (printed as percentages).

    Raises:
        ValueError: If the selection contains no examples, leaving nothing to score.
    """
    validation = dataset["validation"]
    subset = validation.select(range(min(num_examples, len(validation))))
    if len(subset) == 0:
        raise ValueError(
            f"num_examples must select at least one example, got {num_examples}"
        )

    metric = load("squad")
    predictions = []
    references = []

    for index, example in enumerate(subset):
        context = example["context"]
        question = example["question"]
        answers = example["answers"]["text"]

        # Generate model's answer
        prompt = format_prompt(context, question)
        prediction = generate_answer(model, prompt, max_length=128, device=device)

        print("Context:", context)
        print("Question:", question)
        print("Expected Answers:", answers)
        print("Model Prediction:", prediction)

        # The prediction and its reference must share an id so the metric can
        # pair them; prefer the dataset's own id, else the position in the subset.
        example_id = str(example.get("id", index))
        predictions.append({"prediction_text": prediction, "id": example_id})
        references.append({
            "answers": {
                "text": answers,
                # Use the real character offsets when the record has them;
                # otherwise fall back to zeros as a placeholder.
                "answer_start": example["answers"].get(
                    "answer_start", [0] * len(answers)
                ),
            },
            "id": example_id,
        })

    scores = metric.compute(predictions=predictions, references=references)

    print(f"Exact Match (EM): {scores['exact_match']:.2f}%")
    print(f"F1 Score: {scores['f1']:.2f}%")
    return scores

In [27]:
# Score the loaded model on the SQuAD validation examples (prints per-example output and EM/F1).
evaluate_squad(model=model, dataset=dataset, device=device)

Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
Question: Which NFL team represented the AFC at Super Bowl 50?
Expected Answers: ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']
Model Prediction: Denver Broncos
Context: Super Bowl 50 was an American footbal