In [1]:
# Run configuration shared by the cells below.
# Dataset configuration name passed to load_dataset("cais/mmlu", ...).
config: str = "all"
# Dataset split that is loaded and evaluated.
# NOTE(review): "dev" looks like MMLU's small few-shot exemplar split rather
# than its test split — confirm this is the intended evaluation set.
eval_split: str = "dev"
# Checkpoint identifier used to load both the model and its tokenizer.
model_id: str = "microsoft/Phi-3-mini-4k-instruct"

In [2]:
from datasets import load_dataset

# Load the configured MMLU subset and split.
dataset = load_dataset(path="cais/mmlu", name=config, split=eval_split)

# Show the dataset summary, then the first record.
display(dataset, dataset[0])

Dataset({
    features: ['question', 'subject', 'choices', 'answer'],
    num_rows: 285
})

{'question': 'Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.',
 'subject': 'abstract_algebra',
 'choices': ['0', '1', '2', '3'],
 'answer': 1}

In [3]:
# Letter labels for the answer options; position i labels choice i.
answer_map = list("ABCD")

# Hand-written sample with the same fields as a dataset record; it serves as
# the worked example embedded in the prompt.
example = dict(
    question="How many legs does a banana have?",
    subject="example",
    choices=["0", "1", "2", "3"],
    answer=0,
)

# Layout of one rendered question: question text, choices, then the answer
# slot, each separated by a blank line.
qa_template = "\n\n".join(["{question}", "{choices}", "{answer}"])

def format_question(sample, show_answer: bool = False):
    """Render one sample as question text, lettered choices and an answer slot.

    Args:
        sample: Mapping with "question", "choices" and "answer" entries.
        show_answer: When true, the answer slot holds the letter that
            corresponds to ``sample["answer"]``; otherwise it is left empty.

    Returns:
        ``qa_template`` filled in with the sample's fields.
    """
    question_text = sample["question"]

    # Prefix every choice with its letter label ("A: ...", "B: ...", ...).
    lettered_choices = []
    for position, choice_text in enumerate(sample["choices"]):
        lettered_choices.append(f"{answer_map[position]}: {choice_text}")

    if show_answer:
        answer_letter = answer_map[int(sample["answer"])]
    else:
        answer_letter = ""

    return qa_template.format(
        question=question_text,
        choices="\n".join(lettered_choices),
        answer=answer_letter,
    )

# The markers make any leading or trailing whitespace in the rendering visible.
print(f"<start>{format_question(example, show_answer=True)}<end>")

<start>How many legs does a banana have?

A: 0
B: 1
C: 2
D: 3

A<end>


In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Quantization settings: 4-bit weights of type "nf4" with double quantization
# enabled, and float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# Load the causal LM for `model_id` with the quantization settings above;
# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
# Tokenizer loaded from the same checkpoint identifier as the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Prompt template. The worked example is filled in once here, while the
# {formatted_question} placeholder is re-emitted verbatim so that each
# question can be substituted later with a second formatting pass.
user_message = "\n\n".join([
    "Answer multiple choice questions in the format shown by the example below:",
    "Example: {formatted_example}",
    "Return only the final letter choice. Do not provide any further explanations. No yapping.",
    "Question: {formatted_question}",
]).format(
    formatted_question="{formatted_question}",
    formatted_example=format_question(example, show_answer=True),
)

# Preview the fully rendered prompt for the first dataset record.
print(user_message.format_map({"formatted_question": format_question(dataset[0])}))

Answer multiple choice questions in the format shown by the example below:

Example: How many legs does a banana have?

A: 0
B: 1
C: 2
D: 3

A

Return only the final letter choice. Do not provide any further explanations. No yapping.

Question: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.

A: 0
B: 1
C: 2
D: 3




In [6]:
from more_itertools import unzip
from tqdm import tqdm

def generate_prompt(sample) -> str:
    """Build the chat-formatted prompt string for one sample.

    The sample is rendered without its answer, inserted into ``user_message``
    and passed as a single user turn through the tokenizer's chat template,
    untokenized and with the generation prompt appended.

    Args:
        sample: Record accepted by ``format_question``.

    Returns:
        The prompt text produced by the chat template.
    """
    question_block = format_question(sample)
    user_turn = {
        "role": "user",
        "content": user_message.format(formatted_question=question_block),
    }
    chat_text = tokenizer.apply_chat_template(
        [user_turn],
        add_generation_prompt=True,
        tokenize=False,
    )
    # With tokenize=False the template is expected to return plain text.
    assert isinstance(chat_text, str)
    return chat_text

def generate_answers_batched(dataset, batch_size=8, max_new_tokens=10):
    """Generate a model answer for every sample and attach it as a column.

    Prompts are built with ``generate_prompt``, sorted by length so that each
    batch needs little padding, generated batch by batch, and finally put back
    into the original sample order.

    Args:
        dataset: Dataset whose records are accepted by ``generate_prompt``.
        batch_size: Number of prompts sent through the model per batch.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Returns:
        The dataset with a "prediction" column holding the decoded
        continuation for each sample. An existing "prediction" column is
        replaced, so the function can be re-run on its own output.
    """
    # add_column rejects duplicate column names; drop stale predictions so
    # that re-running a cell which rebinds `dataset` to this result works.
    if "prediction" in dataset.column_names:
        dataset = dataset.remove_columns("prediction")

    # Keep the original index next to each prompt so order can be restored.
    # Length is measured in characters, as a cheap proxy for token count.
    indexed_prompts = sorted(
        enumerate(map(generate_prompt, dataset)),
        key=lambda i_prompt: len(i_prompt[1]),
    )

    all_predictions: dict[int, str] = {}
    for start in tqdm(range(0, len(indexed_prompts), batch_size), desc="Batch"):
        # Every slice produced by this range is non-empty, so zip(*...) is safe.
        batch_idxs, batch_prompts = zip(*indexed_prompts[start: start + batch_size])

        # NOTE(review): batched generation with a decoder-only model expects
        # left padding — confirm tokenizer.padding_side == "left".
        batch_inputs = tokenizer(list(batch_prompts), padding=True, return_tensors="pt")
        batch_input_length = batch_inputs["input_ids"].shape[1]

        with torch.no_grad():
            batch_outputs = model.generate(
                **batch_inputs.to(model.device),
                max_new_tokens=max_new_tokens,
                # Greedy decoding: deterministic, and avoids scaling logits by
                # 1/0.001 under float16 compute as the former near-zero
                # temperature sampling did.
                do_sample=False,
            )

        for j, output in zip(batch_idxs, batch_outputs):
            # Generated sequences start with the padded prompt; keep only the
            # tokens that follow it.
            all_predictions[j] = tokenizer.decode(
                output[batch_input_length:], skip_special_tokens=True
            )

    return dataset.add_column(
        "prediction", [all_predictions[i] for i in range(len(dataset))]
    )

# Smoke test: run the pipeline on the first four samples, two per batch.
example_preds = generate_answers_batched(
    dataset=dataset.select(indices=range(4)),
    batch_size=2,
)
display(example_preds["prediction"])

Batch:   0%|          | 0/2 [00:00<?, ?it/s]

Batch: 100%|██████████| 2/2 [00:00<00:00,  2.11it/s]


['C', 'D', 'D', 'D']

In [7]:
def _with_answer_letter(record):
    """Return the letter label for the record's reference answer index."""
    return {"answer_letter": answer_map[record["answer"]]}


# Predict on the whole dataset, then add the reference answers as letters.
dataset = generate_answers_batched(dataset).map(_with_answer_letter)
display(dataset.to_pandas())

Batch:   0%|          | 0/36 [00:00<?, ?it/s]

Batch: 100%|██████████| 36/36 [00:15<00:00,  2.26it/s]


Map:   0%|          | 0/285 [00:00<?, ? examples/s]

Unnamed: 0,question,subject,choices,answer,prediction,answer_letter
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,abstract_algebra,"[0, 1, 2, 3]",1,C,B
1,Statement 1 | If aH is an element of a factor ...,abstract_algebra,"[True, True, False, False, True, False, False,...",1,D,B
2,Statement 1 | Every element of a group generat...,abstract_algebra,"[True, True, False, False, True, False, False,...",2,D,C
3,Statement 1| Every function from a finite set ...,abstract_algebra,"[True, True, False, False, True, False, False,...",0,D,A
4,Find the characteristic of the ring 2Z.,abstract_algebra,"[0, 3, 12, 30]",0,A,A
...,...,...,...,...,...,...
280,What is the sign of the covenant for Jewish m...,world_religions,"[The rainbow, Circumcision, A son, Bar mitzvah]",1,B,B
281,What is the Second Gem in Buddhism?,world_religions,"[The Dharma, The Sangha, The Buddha, The Bodhi...",0,C,A
282,"In which dynasty was the ""Mandate of Heaven"" ...",world_religions,"[Shang, Zhou, Han, Xia]",1,B,B
283,Which Japanese government promoted a kind of ...,world_religions,"[Honen, Tanaka, Tokugawa, Meiji]",3,D,D


In [8]:
# Share of predictions that, once stripped, are exactly one answer letter.
fmt_acc = sum(
    1 for pred in dataset["prediction"] if pred.strip() in answer_map
) / len(dataset)
# Share of predictions whose stripped text equals the reference letter.
ans_acc = sum(
    1
    for pred, ans in zip(dataset["prediction"], dataset["answer_letter"])
    if pred.strip() == ans.strip()
) / len(dataset)
print(f"Format Accuracy: {fmt_acc}", f"Answer Accuracy: {ans_acc}", sep="\n")

Format Accuracy: 0.9964912280701754
Answer Accuracy: 0.656140350877193
