# MMLU Benchmark Smoke Test

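Verify that `llama_jax.benchmarks.mmlu` behaves as expected. The notebook:
- downloads and loads the MMLU dataset;
- renders the prompts for a few selected questions;
- compares option scores from a manual forward pass against the answers produced by the benchmark generator.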

# Parameters

In [1]:
# Model checkpoint to evaluate.
checkpoint: str = "Llama3.2-3B-Instruct"

# Number of worked examples included in each prompt (0 = zero-shot).
n_shots: int = 0

# Batch size passed to the MMLU generator.
bs: int = 32

# Setup

In [2]:
import os
from pathlib import Path
from random import sample
import sys
from time import perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax.nn import softmax
import rich
from rich.columns import Columns
from matplotlib import pyplot as plt
from tqdm.auto import trange, tqdm

import llama_jax as ll
from llama_jax.benchmarks.mmlu import (
    display_questions,
    download_dataset, 
    load_dataset,
    evaluate_generator,
)

In [3]:
# Configure: datasets live under $PROJECT_ROOT/build/datasets.
# PROJECT_ROOT must be set in the environment (a KeyError is raised otherwise).
datasets_path = Path(os.environ["PROJECT_ROOT"]).joinpath("build", "datasets")
mmlu_dataset_path = datasets_path.joinpath("mmlu")

In [4]:
# jax.device_count() counts devices on JAX's default backend, which is not
# necessarily a GPU (it may be CPU or Metal), so report the backend name too.
print(f"Devices ({jax.default_backend()}): {jax.device_count()}")



Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

GPU Devices: 1


I0000 00:00:1740331587.338119 77311374 service.cc:145] XLA service 0x142b2c140 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1740331587.338131 77311374 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1740331587.339653 77311374 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1740331587.339666 77311374 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


# Dataset

In [5]:
# Fetch the MMLU dataset into build/datasets/mmlu.
# NOTE(review): presumably a no-op when the files are already present — confirm in download_dataset.
download_dataset(mmlu_dataset_path)

In [6]:
dataset = load_dataset(mmlu_dataset_path)

# Report the size of each collection the dataset exposes.
n_questions = len(dataset.questions)
n_examples = len(dataset.examples)
n_categories = len(dataset.categories)
print(f"Loaded {n_questions} questions, {n_examples} examples, {n_categories} categories")

Loaded 14042 questions, 285 examples, 57 categories


In [7]:
# Display sample
# NOTE(review): the rendered table shows 5 rows whose qids are not in order, so
# display_questions appears to show a random sample rather than the head — confirm.
display_questions(dataset.questions)

Unnamed: 0,qid,category,question,A,B,C,D,answer
0,9144,moral scenarios,For which of these two scenarios does the main...,"Wrong, Wrong","Wrong, Not wrong","Not wrong, Wrong","Not wrong, Not wrong",C
1,9314,moral scenarios,For which of these two scenarios does the main...,"Wrong, Wrong","Wrong, Not wrong","Not wrong, Wrong","Not wrong, Not wrong",C
2,10207,prehistory,"The construction of large-scale features, such...",the practice of slavery.,social and political complexity.,a Mesolithic tradition.,the shift from Paleolithic to Neolithic.,B
3,10893,professional law,A husband was about to leave his home for work...,"prevail, because the husband is strictly liabl...","prevail, because the statute was designed to p...","not prevail, because the driver had the last c...","not prevail, because the husband was acting re...",D
4,9835,philosophy,The control condition claims:,it is morally wrong to ever lose control of on...,it is morally wrong to try to control the live...,one cannot be morally assessed for what is due...,one cannot be morally blamed for taking contro...,C


# Categories

In [8]:
# Show every MMLU category name alphabetically, laid out in columns.
# sorted() accepts any iterable, so no intermediate list() is needed.
rich.print(Columns(sorted(dataset.categories)))

# Model

In [9]:
# Load the checkpoint's config (max_tokens capped at 1024), then its tokenizer and weights.
max_tokens = 1024
config = ll.checkpoint.load_config(checkpoint, max_tokens=max_tokens)
tokenizer = ll.checkpoint.load_tokenizer(config)
params = ll.checkpoint.load_parameters(config)

# Build the model from the weights and wrap it in the MMLU answer generator.
model = ll.model.create(config, params)
generator = ll.benchmarks.mmlu.generator(
    config,
    model=model,
    n_shots=n_shots,
    examples=dataset.examples,
    bs=bs,
)

In [10]:
# Look up the vocabulary id of each MMLU option letter, in OPTIONS order.
# `.item()` assumes each letter encodes to exactly one token.
mmlu_token_ids = jnp.array([
    tokenizer.encode(letter, bos=False).item()
    for letter in ll.benchmarks.mmlu.OPTIONS
])
mmlu_token_ids

Array([32, 33, 34, 35], dtype=int32)

# US History

In [11]:
# Every question in the "high school us history" category.
questions = tuple(filter(lambda q: q.category == "high school us history", dataset.questions))
display_questions(questions)

Unnamed: 0,qid,category,question,A,B,C,D,answer
0,5804,high school us history,This question refers to the following informat...,Dred Scott v. Sanford (1857),Ex parte Milligan (1866),Plessy v. Ferguson (1896),Brown v. Board of Education of Topeka (1954),A
1,5694,high school us history,This question refers to the following informat...,the policy of containment.,"the principle of ""massive retaliation.""",participation in the Atlantic Charter.,"embarking on a ""roll-back"" of communism.",A
2,5770,high school us history,This question refers to the following informat...,many Vietnamese viewed the United States as a ...,most Vietnamese were opposed to Communism,most Vietnamese favored Communism,some Viet Cong fought alongside American troop...,A
3,5676,high school us history,This question refers to the following informat...,Debates about access to voting rights.,Debates about the role of federal government i...,Debates about discrimination in employment.,Debates about the legal status of women.,D
4,5763,high school us history,This question refers to the following informat...,"In the 1990s, anti-government sentiment, in re...",In 1993 Attorney General Janet Reno launched a...,Fearing a rise in anti-government sentiment af...,Attorney General Janet Reno resigned from offi...,A


In [12]:
# Select question 5719 for a closer look. Raise a clear error (instead of a bare
# StopIteration) if that id is not in the current `questions` selection.
qid = 5719
question = next((q for q in questions if q.qid == qid), None)
if question is None:
    raise LookupError(f"Question {qid} not found among {len(questions)} selected questions")

In [13]:
display_questions([question])

# Build the chat messages for this question and render them to one prompt string.
messages = ll.benchmarks.mmlu.generate_prompt(
    question,
    n_shots=n_shots,
    examples=dataset.examples,
)
prompt = ll.chat.render_prompt(messages)
print(prompt)

# Tokenize the prompt; row 0 of the result holds this prompt's token ids.
token_ids = tokenizer.encode(prompt)
n_tokens = len(token_ids[0])
print(f"Split prompt into {n_tokens} tokens")

# Run the model to get next-token logits, keep only the option-letter columns,
# and normalize them into a probability distribution over the options.
logits = ll.model.forward(config, model, token_ids)
mmlu_logits = logits.take(mmlu_token_ids, axis=-1)
scores = softmax(mmlu_logits, axis=-1)
print(f"Scores: {scores}")

Unnamed: 0,qid,category,question,A,B,C,D,answer
0,5719,high school us history,This question refers to the following informat...,The Declaration of Independence,The Albany Plan,The Boston Tea Party,The Constitution of the United States,A


<|start_header_id|>system<|end_header_id|>

You are a student answering multiple choice questions on an exam. Each question has 4 options: A, B, C, D. There will be 0 example questions followed by a test question. Your job is to answer the test question. Your answer MUST be one of {A, B, C, D}.<|eot_id|>
<|start_header_id|>user<|end_header_id|>

# Instructions

The following are multiple choice questions (with answers) about high school us history.

# Question

This question refers to the following information.
Let us not, I beseech you sir, deceive ourselves. Sir, we have done everything that could be done, to avert the storm which is now coming on. We have petitioned; we have remonstrated; we have supplicated; we have prostrated ourselves before the throne, and have implored its interposition to arrest the tyrannical hands of the ministry and Parliament. Our petitions have been slighted; our remonstrances have produced additional violence and insult; our supplications have been disre

In [14]:
answer = next(generator([question]))[0]

# Smoke-test invariants: the answer belongs to the requested question, its
# `correct` flag agrees with expected/actual, and `actual` is the top-scoring option.
assert answer.qid == question.qid
assert answer.correct == (answer.actual == answer.expected)
assert answer.actual == max(answer.scores, key=answer.scores.get)
answer

Answer(qid=5719, expected='A', actual='A', scores={'A': 0.70703125, 'B': 0.2294921875, 'C': 0.012939453125, 'D': 0.051025390625}, correct=True)

In [15]:
# Select question 5770 for a closer look. Raise a clear error (instead of a bare
# StopIteration) if that id is not in the current `questions` selection.
qid = 5770
question = next((q for q in questions if q.qid == qid), None)
if question is None:
    raise LookupError(f"Question {qid} not found among {len(questions)} selected questions")

In [16]:
display_questions([question])

# Build the chat messages for this question and render them to one prompt string.
messages = ll.benchmarks.mmlu.generate_prompt(
    question,
    n_shots=n_shots,
    examples=dataset.examples,
)
prompt = ll.chat.render_prompt(messages)
print(prompt)

# Tokenize the prompt; row 0 of the result holds this prompt's token ids.
token_ids = tokenizer.encode(prompt)
n_tokens = len(token_ids[0])
print(f"Split prompt into {n_tokens} tokens")

# Run the model to get next-token logits, keep only the option-letter columns,
# and normalize them into a probability distribution over the options.
logits = ll.model.forward(config, model, token_ids)
mmlu_logits = logits.take(mmlu_token_ids, axis=-1)
scores = softmax(mmlu_logits, axis=-1)
print(f"Scores: {scores}")

Unnamed: 0,qid,category,question,A,B,C,D,answer
0,5770,high school us history,This question refers to the following informat...,many Vietnamese viewed the United States as a ...,most Vietnamese were opposed to Communism,most Vietnamese favored Communism,some Viet Cong fought alongside American troop...,A


<|start_header_id|>system<|end_header_id|>

You are a student answering multiple choice questions on an exam. Each question has 4 options: A, B, C, D. There will be 0 example questions followed by a test question. Your job is to answer the test question. Your answer MUST be one of {A, B, C, D}.<|eot_id|>
<|start_header_id|>user<|end_header_id|>

# Instructions

The following are multiple choice questions (with answers) about high school us history.

# Question

This question refers to the following information.
"We found that not only was it a civil war, an effort by a people who had for years been seeking their liberation from any colonial influence whatsoever, but also we found that the Vietnamese whom we had enthusiastically molded after our own image were hard put to take up the fight against the threat we were supposedly saving them from.
"We found most people didn't even know the difference between communism and democracy. They only wanted to work in rice paddies without helicopt

In [17]:
answer = next(generator([question]))[0]

# Smoke-test invariants: the answer belongs to the requested question, its
# `correct` flag agrees with expected/actual, and `actual` is the top-scoring option.
# (The model may still answer incorrectly — that is not asserted.)
assert answer.qid == question.qid
assert answer.correct == (answer.actual == answer.expected)
assert answer.actual == max(answer.scores, key=answer.scores.get)
answer

Answer(qid=5770, expected='A', actual='B', scores={'A': 0.1904296875, 'B': 0.314453125, 'C': 0.216796875, 'D': 0.279296875}, correct=False)

# Microeconomics

In [18]:
# Every question in the "high school microeconomics" category.
questions = tuple(filter(lambda q: q.category == "high school microeconomics", dataset.questions))
display_questions(questions)

Unnamed: 0,qid,category,question,A,B,C,D,answer
0,4628,high school microeconomics,"Education makes Chris a better worker, voter, ...",increasing marginal utility and should be subs...,externalities and should be taxed,decreasing marginal utility and should be taxed,externalities and should be subsidized,D
1,4549,high school microeconomics,If the demand for grapes increases simultaneou...,"equilibrium quantity rises, but the price chan...","equilibrium quantity falls, but the price chan...","equilibrium quantity rises, and the price rises.","equilibrium quantity falls, and the price falls.",A
2,4610,high school microeconomics,Relatively free or easy entry (low or nonexist...,"More consumer choices, greater price elasticit...","More consumer choices, lower price elasticity ...","More consumer choices, greater price elasticit...","Fewer consumer choices, lower price elasticity...",A
3,4594,high school microeconomics,Suppose the market for roses is currently in e...,Price and quantity both rise.,"Price rises, but the change in quantity is amb...",Price and quantity both fall.,"Quantity rises, but the change in price is amb...",B
4,4669,high school microeconomics,If the market price is above the perfectly com...,the industry contracts as firms exit the market.,the industry expands as firms exit the market.,the industry contracts as firms enter the market.,the industry expands as firms enter the market.,D


In [19]:
# Select question 4594 for a closer look. Raise a clear error (instead of a bare
# StopIteration) if that id is not in the current `questions` selection.
qid = 4594
question = next((q for q in questions if q.qid == qid), None)
if question is None:
    raise LookupError(f"Question {qid} not found among {len(questions)} selected questions")

In [20]:
display_questions([question])

# Build the chat messages for this question and render them to one prompt string.
messages = ll.benchmarks.mmlu.generate_prompt(
    question,
    n_shots=n_shots,
    examples=dataset.examples,
)
prompt = ll.chat.render_prompt(messages)
print(prompt)

# Tokenize the prompt; row 0 of the result holds this prompt's token ids.
token_ids = tokenizer.encode(prompt)
n_tokens = len(token_ids[0])
print(f"Split prompt into {n_tokens} tokens")

# Run the model to get next-token logits, keep only the option-letter columns,
# and normalize them into a probability distribution over the options.
logits = ll.model.forward(config, model, token_ids)
mmlu_logits = logits.take(mmlu_token_ids, axis=-1)
scores = softmax(mmlu_logits, axis=-1)
print(f"Scores: {scores}")

Unnamed: 0,qid,category,question,A,B,C,D,answer
0,4594,high school microeconomics,Suppose the market for roses is currently in e...,Price and quantity both rise.,"Price rises, but the change in quantity is amb...",Price and quantity both fall.,"Quantity rises, but the change in price is amb...",B


<|start_header_id|>system<|end_header_id|>

You are a student answering multiple choice questions on an exam. Each question has 4 options: A, B, C, D. There will be 0 example questions followed by a test question. Your job is to answer the test question. Your answer MUST be one of {A, B, C, D}.<|eot_id|>
<|start_header_id|>user<|end_header_id|>

# Instructions

The following are multiple choice questions (with answers) about high school microeconomics.

# Question

Suppose the market for roses is currently in equilibrium. If the supply of roses falls, while at the same time the demand for roses rises, what can you say about the price and quantity of roses in the market?

A) Price and quantity both rise.
B) Price rises, but the change in quantity is ambiguous.
C) Price and quantity both fall.
D) Quantity rises, but the change in price is ambiguous.

Answer: <|eot_id|>
<|start_header_id|>assistant<|end_header_id|>


Split prompt into 193 tokens
Scores: [[0.00212097 0.96875 0.00509644 0.0

In [21]:
answer = next(generator([question]))[0]

# Smoke-test invariants: the answer belongs to the requested question, its
# `correct` flag agrees with expected/actual, and `actual` is the top-scoring option.
assert answer.qid == question.qid
assert answer.correct == (answer.actual == answer.expected)
assert answer.actual == max(answer.scores, key=answer.scores.get)
answer

Answer(qid=4594, expected='B', actual='B', scores={'A': 0.0021209716796875, 'B': 0.96875, 'C': 0.005096435546875, 'D': 0.0257568359375}, correct=True)