In [None]:
%%capture

! pip install transformers==4.50.2 datasets

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from tqdm import tqdm
import torch
import re

# Load Gemma model and tokenizer
model_path = "google/gemma-3-4b-it"
# Alternative checkpoint, not exercised in this notebook.
# NOTE(review): it may need a different model class / more memory — verify before switching.
# model_path = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# generate_answer samples (temperature/top_p/top_k), so seed torch's RNG
# (CPU and all CUDA devices) to make outputs repeatable on a fresh kernel,
# as far as the underlying kernels are deterministic.
SEED = 42
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# System turn sent with every prompt at inference time (see generate_answer).
system_prompt = "Give final answer only without steps."

# Generate answer using chat-style prompting and Gemma-3 sampling settings
def generate_answer(prompt: str, max_tokens: int = 1024) -> str:
    """Generate the model's reply to a single user prompt.

    The module-level ``system_prompt`` is sent as the system turn and
    ``prompt`` as the user turn, rendered with the tokenizer's chat template.

    Args:
        prompt: User message text.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        Decoded text of the newly generated tokens only, with special tokens
        stripped. The echoed chat prompt is excluded, so callers that search
        the reply for an answer cannot accidentally match the question text.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    chat_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    # Chat templates typically render the BOS token themselves; letting the
    # tokenizer add special tokens again would prepend a duplicate BOS.
    inputs = tokenizer(
        chat_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Some tokenizers (e.g. Llama-style) define no pad token; fall back to EOS.
        pad_id = tokenizer.eos_token_id

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        # Enable sampling explicitly; otherwise temperature/top_p/top_k only
        # take effect if the checkpoint's generation_config turns sampling on.
        do_sample=True,
        temperature=1.0,
        top_p=0.95,
        top_k=64,
        pad_token_id=pad_id,
    )
    # generate() returns prompt + continuation for decoder-only models;
    # keep only the continuation.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

In [None]:
# Smoke test: a trivial arithmetic prompt through the full chat pipeline.
generate_answer(prompt="2+2?")

'user\nGive final answer only without steps.\n\n2+2?\nmodel\n4\n'

In [None]:
from datasets import load_dataset

In [None]:
# AIME 2024 problems; rows are read below via their "Problem" and "Answer" fields.
ds_aime = load_dataset(path="Maxwell-Jia/AIME_2024", split="train")

README.md:   0%|          | 0.00/1.78k [00:00<?, ?B/s]

aime_2024_problems.parquet:   0%|          | 0.00/37.2k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Quick look at the first few AIME problems (evaluate_model below wraps the same loop).
n_eval = min(10, len(ds_aime))
correct = 0
for i in range(n_eval):
    curr_question = ds_aime[i]['Problem'] + "\nGenerate final answers only."
    curr_truth = ds_aime[i]['Answer']
    curr_pred = generate_answer(curr_question, max_tokens=20480)
    print(curr_question)
    print(curr_truth)
    print("--------------")
    print(curr_pred)
    # Standalone-token match: a plain `in` check would count "3" as found in "33"
    # or "33" as found in "133".
    truth_text = str(curr_truth).strip()
    if truth_text and re.search(rf"(?<![\w.]){re.escape(truth_text)}(?!\w|\.\d)", curr_pred):
        correct += 1
        print("Correct!")
    else:
        print("Incorrect!")
    print("--------------")

# Divide by the number of problems actually evaluated, not the dataset size.
print(f"Accuracy: {correct / n_eval if n_eval else 0.0}")

Let $x,y$ and $z$ be positive real numbers that satisfy the following system of equations: 
\[\log_2\left({x \over yz}\right) = {1 \over 2}\]
\[\log_2\left({y \over xz}\right) = {1 \over 3}\]
\[\log_2\left({z \over xy}\right) = {1 \over 4}\]
Then the value of $\left|\log_2(x^4y^3z^2)\right|$ is $\tfrac{m}{n}$ where $m$ and $n$ are relatively prime positive integers. Find $m+n$.
Generate final answers only.
33
--------------
user
Give final answer only without steps.

Let $x,y$ and $z$ be positive real numbers that satisfy the following system of equations: 
\[\log_2\left({x \over yz}\right) = {1 \over 2}\]
\[\log_2\left({y \over xz}\right) = {1 \over 3}\]
\[\log_2\left({z \over xy}\right) = {1 \over 4}\]
Then the value of $\left|\log_2(x^4y^3z^2)\right|$ is $\tfrac{m}{n}$ where $m$ and $n$ are relatively prime positive integers. Find $m+n$.
Generate final answers only.
model
From the given equations, we have:
\begin{align*} \label{eq:1} \log_2\left(\frac{x}{yz}\right) = \frac{1}{2} &\i

In [None]:
from typing import Any, Literal, Sequence

def _answer_in_prediction(truth: Any, prediction: str) -> bool:
    """Return True if ``truth`` occurs in ``prediction`` as a standalone token.

    Plain substring search over-counts (``"3" in "33"``, ``"33" in "133"``).
    Here the match must not be glued to a word character or a decimal point
    on the left, nor to a word character or a fractional part on the right.
    This is still a heuristic, not a full answer normalizer.
    """
    target = str(truth).strip()
    if not target:
        # An empty reference would trivially "match" any prediction.
        return False
    pattern = r"(?<![\w.])" + re.escape(target) + r"(?!\w|\.\d)"
    return re.search(pattern, prediction) is not None


def evaluate_model(
    ds: Sequence[dict[str, Any]],
    prompt_key: str = "Problem",
    answer_key: str = "Answer",
    range_size: int = 10,
    verbose: bool = True,
    max_tokens: int = 20480,
    max_samples: "int | None" = None,
) -> float:
    """
    Evaluates the model's accuracy on the leading rows of a dataset.

    Each row's prompt gets the instruction "Generate final answers only." appended
    on a new line, is answered with ``generate_answer``, and counts as correct
    when the reference answer appears in the reply as a standalone token.

    Args:
        ds: The dataset to evaluate on; must support ``len()`` and integer
            indexing that returns a mapping.
        prompt_key: The key in each row for the prompt.
        answer_key: The key in each row for the reference answer.
        range_size: The number of leading samples to evaluate; clamped to
            ``len(ds)``.
        verbose: Whether to print each question, reference, prediction and
            verdict, plus the final accuracy.
        max_tokens: The maximum number of new tokens to generate per sample.
        max_samples: Alias for ``range_size``; takes precedence when given.

    Returns:
        The fraction of evaluated samples answered correctly, or 0.0 when no
        samples were evaluated.
    """
    requested = range_size if max_samples is None else max_samples
    # Clamp so datasets shorter than the request don't raise IndexError.
    n_eval = max(0, min(requested, len(ds)))

    correct = 0
    for i in range(n_eval):
        row = ds[i]
        curr_question = row[prompt_key] + "\nGenerate final answers only."
        curr_truth = row[answer_key]
        curr_pred = generate_answer(curr_question, max_tokens=max_tokens)
        is_correct = _answer_in_prediction(curr_truth, curr_pred)
        if is_correct:
            correct += 1
        if verbose:
            print(curr_question)
            print(curr_truth)
            print("--------------")
            print(curr_pred)
            print("Correct!" if is_correct else "Incorrect!")
            print("--------------")

    # Normalize by the number of samples actually evaluated, not len(ds).
    accuracy = correct / n_eval if n_eval else 0.0
    if verbose:
        print(f"Accuracy: {accuracy}")
    return accuracy

In [None]:
%%time

# === Load and evaluate datasets ===
# Evaluate at most N_SAMPLES leading rows per dataset; splits shorter than that
# are evaluated in full (evaluate_model indexes rows 0..range_size-1 directly).
N_SAMPLES = 50

# Dataset 1: Maxwell-Jia/AIME_2024
ds_aime = load_dataset("Maxwell-Jia/AIME_2024", split="train")
acc_aime = evaluate_model(
    ds_aime, prompt_key="Problem", answer_key="Answer",
    range_size=min(N_SAMPLES, len(ds_aime)),
)

# Dataset 2: HuggingFaceH4/MATH-500
ds_math = load_dataset("HuggingFaceH4/MATH-500", split="test")
acc_math = evaluate_model(
    ds_math, prompt_key="problem", answer_key="answer",
    range_size=min(N_SAMPLES, len(ds_math)),
)

# Dataset 3: Idavidrein/gpqa
# NOTE(review): this dataset is gated on the Hub — confirm access / HF login first.
# NOTE(review): GPQA is multiple-choice, but the prompt here omits the options and
# scoring is a substring match on the free-text "Pre-Revision Correct Answer";
# confirm these are the intended columns and metric.
ds_gpqa = load_dataset("Idavidrein/gpqa", name="gpqa_diamond", split="train")
acc_gpqa = evaluate_model(
    ds_gpqa, prompt_key="Pre-Revision Question", answer_key="Pre-Revision Correct Answer",
    range_size=min(N_SAMPLES, len(ds_gpqa)),
)

# Print results
print(f"\n📊 AIME Accuracy:     {acc_aime:.3f}")
print(f"📊 MATH-500 Accuracy: {acc_math:.3f}")
print(f"📊 GPQA Accuracy:     {acc_gpqa:.3f}")