In [None]:
from utils import load_model

# Model identifier; later cells use it to pick the chat-message format and to name output files.
model_id = 'gemma-3'
# Project-local loader (utils); returns the model together with its processor/tokenizer.
model, processor = load_model(model_id)
# Put the model in inference mode. The trailing ';' stops Jupyter from rendering the
# object eval() returns (for a torch nn.Module, the entire architecture repr).
model.eval();

In [None]:
from datasets import load_dataset

# MMLU with all subjects combined, from the Hugging Face Hub.
dataset_dict = load_dataset("cais/mmlu", "all")

# Split to draw evaluation questions from.
split_name = 'test'
if split_name not in dataset_dict:
    # Raise rather than call exit(): in IPython/Jupyter exit() terminates the kernel
    # instead of simply stopping this cell.
    raise ValueError(
        f"Split '{split_name}' not found. Please choose from: {list(dataset_dict.keys())}"
    )

ds_split = dataset_dict[split_name]
print(f"\nSelected split '{split_name}' with {len(ds_split)} examples.")

# Number of random examples to evaluate, and the shuffle seed that makes the draw reproducible.
n = 500
SEED = 42

# Never request more examples than the split actually contains.
if n > len(ds_split):
    print(f"Warning: n ({n}) is larger than the dataset size ({len(ds_split)}).")
    print(f"Sampling all {len(ds_split)} available items instead.")
n_samples_to_take = min(n, len(ds_split))

# Shuffle deterministically, then keep the first n_samples_to_take rows.
random_sample_dataset = ds_split.shuffle(seed=SEED).select(range(n_samples_to_take))

In [None]:
system_prompt = "You are a helpful assistant."

# Instruction appended to every prompt before it is sent to the model.
ANSWER_INSTRUCTION = " Give me your best guess and answer as concisely as possible."


def _build_messages(prompt, model_id):
    """Build the (system, user) chat for `prompt` in the message format `model_id` expects."""
    user_text = prompt + ANSWER_INSTRUCTION
    if model_id == 'gemma-3':
        # Gemma 3's processor takes structured content: a list of typed parts per turn.
        return [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": [{"type": "text", "text": user_text}]},
        ]
    # Other models get plain-string content.
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]


def generate_and_decode_new_tokens(prompt, model, processor, model_id, max_new_tokens=256, device='cuda'):
    """
    Generate a greedy response to a prompt and decode only the newly generated tokens.

    The prompt is wrapped in a (system, user) chat built from the module-level
    ``system_prompt`` and ``ANSWER_INSTRUCTION``, rendered with the processor's chat
    template, tokenized, moved to ``device`` and passed to ``model.generate``.

    Args:
        prompt (str): The input prompt text.
        model: The language model to use for generation. If its type name contains
            "intervenable", it is called with the encoded batch positionally and must
            return a pair; the second element is used as the generation.
        processor: The tokenizer/processor for encoding/decoding.
        model_id (str): Model identifier; 'gemma-3' selects the typed-content message
            format and the ``text=`` processor call.
        max_new_tokens (int): Maximum number of new tokens to generate.
        device (str): Device the encoded inputs are moved to (default 'cuda').

    Returns:
        str: The decoded text from only the newly generated tokens (special tokens skipped).
    """
    # Local import: the notebook otherwise only imports torch in a later cell.
    import torch

    messages = _build_messages(prompt, model_id)
    prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # Chat templates typically render BOS/special tokens themselves (Gemma's starts with
    # <bos>). Re-tokenizing with the default add_special_tokens=True would prepend a second
    # BOS, so disable it, as the HF chat-templating docs recommend.
    if model_id == 'gemma-3':
        inputs = processor(text=prompt_text, return_tensors="pt", add_special_tokens=False)
    else:
        inputs = processor(prompt_text, return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        if 'intervenable' in str(type(model)).lower():
            # Intervenable wrappers take the batch positionally and return a pair; keep the second.
            _, generation = model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
        else:
            generation = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # generate() returns prompt + continuation; drop the prompt tokens of the first sequence.
    new_tokens = generation[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)

In [None]:
# Pull the evaluation columns out of the sampled split.
questions_test = random_sample_dataset['question']
choices_test = random_sample_dataset['choices']
# `answer` is the index of the correct option within each row's `choices`.
answers_idxs = random_sample_dataset['answer']
# Resolve each index to the text of the correct option.
answers_test = [options[correct] for options, correct in zip(choices_test, answers_idxs)]

In [None]:
from tqdm.auto import tqdm
import torch  # generate_and_decode_new_tokens relies on torch.inference_mode

# Generate one answer per sampled question, in dataset order.
# NOTE(review): only the question text is passed — the MMLU answer options are not part
# of the prompt. Confirm this open-ended setup is intended.
predictions = [
    generate_and_decode_new_tokens(question, model, processor, model_id)
    for question in tqdm(questions_test)
]

# Compute Score

In [None]:
from utils import to_request, create_anthropic_batch_job

import json

from openai import OpenAI

# One request per (question, reference answer, prediction) triple; the id string encodes
# model_id and the row index so results can be traced back to their example.
requests = [
    to_request(f'mmlu_{model_id}_base-{i}', q, a, p)
    for i, (q, a, p) in enumerate(zip(questions_test, answers_test, predictions))
]

# Write the requests as JSON Lines: one JSON object per line.
batch_file_path = f"mmlu-{model_id}_base.jsonl"
with open(batch_file_path, 'w', encoding='utf-8') as f:
    for item in requests:
        f.write(json.dumps(item) + '\n')

# The client reads the key from the OPENAI_API_KEY environment variable; never hard-code
# credentials (an empty string here only fails later with an authentication error).
client = OpenAI()
# Close the upload handle once the file has been sent.
with open(batch_file_path, "rb") as batch_file:
    batch_input_file = client.files.create(file=batch_file, purpose="batch")
batch_input_file_id = batch_input_file.id

batch = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"description": f"mmlu-{model_id}_base"},
)
batch