# How to Run MMLU with Generative Models (Hugging Face Transformers)
Based on: https://github.com/FranxYao/chain-of-thought-hub/blob/main/MMLU/run_mmlu_open_source.py

In [1]:
import json
import os
import time
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import gc

In [2]:
# Directory passed as ``cache_dir`` when loading Hugging Face checkpoints.
HF_CACHE_LOCATION = "/data/shk148/models/opt/cache"

# MMLU subject names. Each one is combined with "_dev.csv" / "_test.csv" to
# locate the task's data files, and is also used as the key in the results.
TASKS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Letters used to label the answer options of each question, in column order.
choices = ["A", "B", "C", "D"]

# Root directory holding the "dev" and "test" CSV sub-directories.
DATA_DIR = "/data/shk148/MMLU"

In [3]:
# Helper Functions: Largely unchanged from https://github.com/FranxYao/chain-of-thought-hub/blob/main/MMLU/run_mmlu_open_source.py

def compute_metric(output_filename):
    """Print per-task and overall accuracy for a saved results file.

    Args:
        output_filename: Path to a JSON file mapping each task name to a dict
            holding two lists, ``pred_answers`` and ``gold_answers``.

    Predictions and gold answers are paired positionally and counted as
    correct on exact equality. Each task's accuracy is the number of correct
    pairs divided by the number of gold answers; the overall accuracy pools
    the counts across tasks. A task without gold answers (or a file without
    tasks) is reported as 0.0000 instead of raising ``ZeroDivisionError``.
    """
    with open(output_filename, 'r') as f:
        run_results = json.load(f)
    total_acc = 0
    total_num = 0
    for task, result in run_results.items():
        pred_answers = result['pred_answers']
        gold_answers = result['gold_answers']
        acc = sum(1 for pred, gold in zip(pred_answers, gold_answers) if pred == gold)
        num = len(gold_answers)
        # Guard the division so an empty task does not abort the whole report.
        print("ACC-%s: %.4f" % (task, acc / num if num else 0.0))
        total_acc += acc
        total_num += num
    print("ACC-all: %.4f" % (total_acc / total_num if total_num else 0.0))


def format_subject(subject):
    """Turn an underscore-separated subject name into space-separated words.

    Every word is prefixed with a space, so the result always starts with a
    space (e.g. ``"abstract_algebra"`` becomes ``" abstract algebra"``).
    """
    return "".join(" " + word for word in subject.split("_"))

def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as one multiple-choice prompt block.

    Column 0 is used as the question text, the last column as the answer and
    every column in between as an option, labelled from the module-level
    ``choices`` list.

    Args:
        df: DataFrame of questions, addressed by position.
        idx: Positional row index of the question to render.
        include_answer: When true, append the answer and a blank line after
            "Answer:"; otherwise the text ends right after "Answer:".

    Returns:
        The formatted question as a single string.
    """
    num_options = df.shape[1] - 2
    pieces = [df.iloc[idx, 0]]
    pieces.extend(
        "\n{}. {}".format(choices[col], df.iloc[idx, col + 1])
        for col in range(num_options)
    )
    pieces.append("\nAnswer:")
    if include_answer:
        pieces.append(" {}\n\n".format(df.iloc[idx, num_options + 1]))
    return "".join(pieces)

def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prefix: an instruction line plus ``k`` solved examples.

    Args:
        train_df: DataFrame of example questions with answers.
        subject: Underscore-separated subject name, shown in the instruction.
        k: Number of leading rows of ``train_df`` to include; ``-1`` means
            every row.

    Returns:
        The instruction sentence followed by the formatted examples.
    """
    num_shots = train_df.shape[0] if k == -1 else k
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    shots = [format_example(train_df, row) for row in range(num_shots)]
    return header + "".join(shots)

def prepare_input(tokenizer, prompts, device='cuda'):
    """Tokenize a batch of prompts and move the tensors to ``device``.

    Args:
        tokenizer: Callable tokenizer; it is invoked with the list of prompts
            plus ``return_tensors="pt"`` and ``padding=True``.
        prompts: List of prompt strings forming one batch.
        device: Target device for the tensors. Defaults to ``'cuda'``, which
            matches the previous hard-coded behavior.

    Returns:
        A dict containing only the ``input_ids`` and ``attention_mask``
        entries produced by the tokenizer (any other entries are dropped).
    """
    encoded = tokenizer(prompts, return_tensors="pt", padding=True)
    input_tokens = {}
    for key in ("input_ids", "attention_mask"):
        if key not in encoded:
            continue
        value = encoded[key]
        # Only tensors can be moved; anything else is passed through untouched.
        input_tokens[key] = value.to(device) if torch.is_tensor(value) else value
    return input_tokens

def batch_split(prompts, batch_num):
    """Group ``prompts`` into consecutive lists of ``batch_num`` items.

    The final list holds whatever is left over and may be shorter. An empty
    input yields an empty list.
    """
    groups = [[]]
    for item in prompts:
        groups[-1].append(item)
        # Open a fresh group as soon as the current one reaches the target size.
        if len(groups[-1]) == batch_num:
            groups.append([])
    # The trailing group is empty when the input was empty or divided evenly.
    if not groups[-1]:
        groups.pop()
    return groups

def batch_infer(model, tokenizer, prompts, batch_size=8):
    """Generate one token per prompt and return the last character of each output.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompts and decode the outputs;
            its ``pad_token_id`` is forwarded to ``generate``.
        prompts: List of prompt strings.
        batch_size: Number of prompts sent to the model at once. Defaults to 8,
            the value that used to be hard-coded.

    Returns:
        A list with one entry per prompt: the final character of the decoded
        output text, or an empty string if the decoded text is empty.
    """
    answers = []
    for batch_input in tqdm(batch_split(prompts, batch_size)):
        encode_inputs = prepare_input(tokenizer, batch_input)
        outputs = model.generate(**encode_inputs, max_new_tokens=1, pad_token_id=tokenizer.pad_token_id)
        answers.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    # Slice instead of index so an empty decoded string gives "" rather than
    # raising IndexError.
    return [answer[-1:] for answer in answers]


In [4]:
def load(checkpoint, model_type):
    """Load a causal language model and its tokenizer for evaluation.

    Args:
        checkpoint: Model identifier or path handed to ``from_pretrained``.
        model_type: Currently unused; kept so existing callers keep working.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to CUDA and
        switched to eval mode; the tokenizer pads on the left.
    """
    model = AutoModelForCausalLM.from_pretrained(checkpoint, cache_dir=HF_CACHE_LOCATION).cuda()
    # Left padding keeps each prompt's last token at the end of its row, so
    # the single generated token directly follows the prompt in a padded batch.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=HF_CACHE_LOCATION, padding_side='left')
    model.eval()
    return model, tokenizer

In [5]:
def main(ckpt_dir: str, param_size: str, model_type: str):
    """Evaluate one checkpoint on every task in ``TASKS`` and report accuracy.

    For each task, the first five rows of the dev CSV serve as few-shot
    examples and every row of the test CSV becomes one prompt. Predictions and
    gold answers are written to ``run_results_<model_type>_<param_size>b.json``
    in the current directory, after which the accuracies are printed.

    Args:
        ckpt_dir: Model identifier or path passed on to ``load``.
        param_size: Size label used only in the output file name.
        model_type: Model family label, used in the output file name and
            passed on to ``load``.

    Returns:
        Elapsed wall-clock seconds, measured from just after the model is
        loaded until the metrics have been printed.
    """
    max_tokens = 2048  # prompt budget, including one slot for the bos token
    num_shots = 5

    run_results = {}
    output_filename = 'run_results_%s_%sb.json' % (model_type, param_size)

    model, tokenizer = load(ckpt_dir, model_type)
    start_time = time.time()
    for task in TASKS:
        print('Testing %s ...' % task)
        records = []
        dev_df = pd.read_csv(os.path.join(DATA_DIR, "dev", task + "_dev.csv"), header=None)[:5]
        test_df = pd.read_csv(os.path.join(DATA_DIR, "test", task + "_test.csv"), header=None)

        # The instruction line and the solved examples are identical for every
        # test question of a task, so build them once instead of per question.
        header = gen_prompt(dev_df, task, 0)
        shots = [format_example(dev_df, j) for j in range(min(num_shots, dev_df.shape[0]))]

        for i in range(test_df.shape[0]):
            # get prompt and make sure it fits
            prompt_end = format_example(test_df, i, include_answer=False)
            first_shot = 0
            prompt = header + "".join(shots) + prompt_end
            while len(tokenizer.tokenize(prompt)) + 1 > max_tokens:  # bos token
                if first_shot >= len(shots):
                    # Nothing left to drop. Keep the question itself rather
                    # than truncating it away, and make the overflow visible.
                    print('Warning: prompt %d of %s exceeds %d tokens without few-shot examples'
                          % (i, task, max_tokens))
                    break
                # Drop whole examples from the front, oldest first.
                first_shot += 1
                prompt = header + "".join(shots[first_shot:]) + prompt_end
            label = test_df.iloc[i, test_df.shape[1]-1]
            records.append({'prompt':prompt, 'answer':label})

        pred_answers = batch_infer(model, tokenizer, [record['prompt'] for record in records])
        gold_answers = [record['answer'] for record in records]
        run_results[task] = {'pred_answers':pred_answers, 'gold_answers':gold_answers}
    with open(output_filename, 'w') as f:
        json.dump(run_results, f, ensure_ascii=False, indent=2)

    compute_metric(output_filename)
    end_time = time.time()
    print("total run time %.2f" % (end_time - start_time))
    # Drop the local references so the caller's gc/cache cleanup can reclaim them.
    del model
    del tokenizer
    return (end_time - start_time)

In [6]:

# OPT checkpoints to evaluate, in the order they are run.
draft_models = [
    "facebook/opt-125m",
    "facebook/opt-350m",
    "facebook/opt-1.3b",
    "facebook/opt-2.7b",
    "facebook/opt-6.7b",
]

# Maps each checkpoint name to the run time in seconds returned by main().
rslt = {}

for draft_model in draft_models:
    # The size suffix (e.g. "1.3b") only ends up in the results file name.
    size_tag = draft_model.replace("facebook/opt-", "")
    rslt[draft_model] = main(draft_model, size_tag, "OPT")
    # Free the finished model's memory before the next checkpoint is loaded.
    gc.collect()
    torch.cuda.empty_cache()
print(rslt)


Testing abstract_algebra ...


100%|██████████| 13/13 [00:00<00:00, 15.21it/s]


ACC-abstract_algebra: 0.2900
ACC-all: 0.2900
total run time 0.95
Testing abstract_algebra ...


100%|██████████| 13/13 [00:01<00:00, 10.44it/s]


ACC-abstract_algebra: 0.2200
ACC-all: 0.2200
total run time 1.33
Testing abstract_algebra ...


100%|██████████| 13/13 [00:04<00:00,  2.86it/s]


ACC-abstract_algebra: 0.2600
ACC-all: 0.2600
total run time 4.63
Testing abstract_algebra ...


100%|██████████| 13/13 [00:08<00:00,  1.48it/s]


ACC-abstract_algebra: 0.1900
ACC-all: 0.1900
total run time 8.89


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Testing abstract_algebra ...


100%|██████████| 13/13 [00:22<00:00,  1.74s/it]


ACC-abstract_algebra: 0.2500
ACC-all: 0.2500
total run time 22.72
{'facebook/opt-125m': 0.94742751121521, 'facebook/opt-350m': 1.3316192626953125, 'facebook/opt-1.3b': 4.629910469055176, 'facebook/opt-2.7b': 8.892475128173828, 'facebook/opt-6.7b': 22.72069764137268}
