# MMLU Benchmark

[Massive Multitask Language Understanding (MMLU)](https://github.com/hendrycks/test) is a popular benchmark for evaluating language models' world knowledge and problem-solving abilities. The MMLU dataset contains 14,042 multiple-choice questions (MCQs) from 57 categories, including mathematics, history, biology, and business. Each question has 4 options (A, B, C, D) and exactly one correct answer, so random guessing scores about 25%. In addition, each category includes 5 example questions designed for few-shot experiments. When MMLU was first published in 2020, only the largest GPT models could do better than random guessing. By 2024, models from OpenAI, Anthropic, Meta, and Tencent had all reported MMLU accuracies over 88%.
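To make the question format concrete, here is a minimal sketch of how a question and its few-shot examples could be rendered into a single prompt. The `Question` dataclass, `format_question`, and `build_prompt` are hypothetical names introduced for illustration; they are not the `llama_jax` API, and the header wording is an assumption modeled on the style commonly used for MMLU prompts.

```python
from dataclasses import dataclass

OPTIONS = ("A", "B", "C", "D")


@dataclass(frozen=True)
class Question:
    """Hypothetical container for one MCQ (not the llama_jax type)."""

    category: str
    text: str
    choices: tuple  # four answer strings, in A-D order
    answer: str  # one of "A", "B", "C", "D"


def format_question(q: Question, include_answer: bool) -> str:
    """Render one question; few-shot examples also show the answer."""
    lines = [q.text]
    lines += [f"{label}. {choice}" for label, choice in zip(OPTIONS, q.choices)]
    lines.append(f"Answer: {q.answer}" if include_answer else "Answer:")
    return "\n".join(lines)


def build_prompt(question: Question, examples: list, n_shots: int) -> str:
    """Join a header, up to n_shots solved examples, and the open question."""
    header = (
        "The following are multiple choice questions (with answers) "
        f"about {question.category}."
    )
    shots = [format_question(e, include_answer=True) for e in examples[:n_shots]]
    return "\n\n".join([header, *shots, format_question(question, include_answer=False)])
```

With `n_shots=0` the prompt contains only the header and the open question; with `n_shots=5` it would first include all five example questions for the category, each with its answer.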

In this experiment, we'll measure Llama's MMLU performance ourselves. Our goal is to reproduce Meta's published MMLU score:

* Llama 3.2 3B: 58% ([model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md))

# Setup

In [1]:
import logging
from logging import Formatter, StreamHandler
import os
from pathlib import Path
from random import sample
import sys
from time import perf_counter_ns as timer

from matplotlib import pyplot as plt
from tqdm.auto import trange, tqdm

import llama_jax as ll
from llama_jax.benchmarks.mmlu import (
    display_questions,
    download_dataset,
    load_dataset,
    evaluate_generator,
)

In [2]:
# Configure
datasets_path = Path(os.environ["PROJECT_ROOT"]) / "build" / "datasets"
mmlu_dataset_path = datasets_path / "mmlu"

# formatter = Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
# handler = StreamHandler(stream=sys.stderr)
# handler.setFormatter(formatter)
# logging.root.addHandler(handler)
# logging.root.setLevel(logging.DEBUG)

# Load Dataset

In [3]:
download_dataset(mmlu_dataset_path)

In [None]:
dataset = load_dataset(mmlu_dataset_path)
print(f"Loaded {len(dataset.questions)} questions, {len(dataset.examples)} examples, {len(dataset.categories)} categories")

In [None]:
# Display sample
display_questions(dataset.questions)

# Zero-Shot, Sampled

Before we run the end-to-end MMLU benchmark, we first measure accuracy on a small random sample of questions with no examples in the prompt (0-shot).

In [6]:
checkpoint = "Llama3.2-3B-Instruct"
n_iterations = 3
n_questions = 128
n_shots = 0
bs = 32
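A sample of 128 questions gives only a rough accuracy estimate, which is why the run is repeated over several iterations. As a back-of-the-envelope check, the sketch below computes the binomial standard error of an accuracy estimate. It assumes questions are answered independently and uses the published 58% score as the true accuracy; `accuracy_standard_error` is a hypothetical helper, not part of `llama_jax`.

```python
from math import sqrt


def accuracy_standard_error(p: float, n: int) -> float:
    """Standard error of an accuracy estimate from n independent questions."""
    return sqrt(p * (1.0 - p) / n)


# Assumed true accuracy of 0.58 and the sample size used in this stage
se = accuracy_standard_error(p=0.58, n=128)
print(f"standard error: {se:.3f}")
print(f"approx. 95% interval: {0.58 - 1.96 * se:.2f} to {0.58 + 1.96 * se:.2f}")
```

Under these assumptions the standard error is a little over 4 percentage points, so individual iterations can land several points away from the published score without indicating a problem.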

In [None]:
# Initialize mmlu generator from checkpoint
config = ll.checkpoint.load_config(checkpoint, max_tokens=1024)
generator = ll.benchmarks.mmlu.generator(config, n_shots=n_shots, examples=dataset.examples, bs=bs)

In [None]:
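# Pre-compile model: run a single question through the generator so that
# compilation time is excluded from the timed benchmark loop below
next(generator(sample(dataset.questions, k=1)))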

In [None]:
start_time = timer()

scores = []
for _ in trange(n_iterations, desc="Iterations"):

    # Randomly sample questions
    questions = sample(dataset.questions, k=n_questions)

    # Track progress
    progress = tqdm(total=n_questions, desc="Questions", leave=False)
    
    score = evaluate_generator(
        generator,
        questions=questions,
        progress=progress,
    )
    scores.append(score)

duration = (timer() - start_time) / 1e9  # nanoseconds to seconds

In [None]:
# Average wall-clock time per question across all iterations
t = duration / (n_iterations * n_questions)
print(f"Average {t:0.2f} s/q")

In [None]:
plt.boxplot(scores)
plt.ylabel("MMLU Score")
plt.show()
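With only a few iterations, a box plot is easier to interpret next to the raw numbers. The sketch below summarizes per-iteration scores as a mean and sample standard deviation. It assumes each score is a float accuracy between 0 and 1, which the notebook does not confirm; `summarize` is a hypothetical helper and the example values are placeholders, not measured results.

```python
from statistics import mean, stdev


def summarize(scores: list) -> tuple:
    """Return the mean and sample standard deviation of per-iteration scores."""
    # stdev needs at least two data points; report 0.0 for a single iteration
    spread = stdev(scores) if len(scores) > 1 else 0.0
    return mean(scores), spread


# Placeholder values for illustration only; not measured results
example_scores = [0.50, 0.60, 0.70]
avg, spread = summarize(example_scores)
print(f"mean={avg:.2f}, stdev={spread:.2f}")
```

In the notebook, calling `summarize(scores)` on the list built in the benchmark loop would give the corresponding summary for the actual run.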