In [46]:
from datasets import load_dataset, concatenate_datasets

In [5]:
# Names of the 57 MMLU subject configurations, kept in alphabetical order.
mmlu_task_list = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    # college-level subjects
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    # high-school subjects
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    # professional subjects
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions',
]

In [19]:
# School-level subset: task names beginning with 'high' or 'elementary'.
mmlu_hs_list = [name for name in mmlu_task_list if name.startswith(('high', 'elementary'))]
# Advanced subset: task names beginning with 'college' or 'professional'.
mmlu_adv_list = [name for name in mmlu_task_list if name.startswith(('college', 'professional'))]

# Placeholder mappings (task name -> None); the values are filled in when the datasets are loaded.
mmlu_hs_dict = dict.fromkeys(mmlu_hs_list)
mmlu_adv_dict = dict.fromkeys(mmlu_adv_list)

In [None]:
# Load every selected MMLU subject from 'cais/mmlu' and store the result under
# its task name. Both groups go through the same loop; the existing dicts are
# filled in place, so their key order is unchanged.
for task_names, loaded_by_task in (
    (mmlu_hs_list, mmlu_hs_dict),
    (mmlu_adv_list, mmlu_adv_dict),
):
    for task in task_names:
        loaded_by_task[task] = load_dataset('cais/mmlu', task)

In [50]:
# Pool the 'validation' split of every loaded subject into one dataset per group.
hs_val = concatenate_datasets([splits['validation'] for splits in mmlu_hs_dict.values()])
adv_val = concatenate_datasets([splits['validation'] for splits in mmlu_adv_dict.values()])

In [53]:
# Draw a fixed-size sample from each pooled validation set:
# shuffle with a fixed seed, then keep the first `num_samples` rows.
num_samples = 350
seed = 42

# Same index list for both groups: positions 0 .. num_samples - 1 after shuffling.
leading_rows = list(range(num_samples))

hs_val = hs_val.shuffle(seed=seed).select(leading_rows)
adv_val = adv_val.shuffle(seed=seed).select(leading_rows)

In [56]:
# Evaluate on Pythia

In [None]:
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Pythia comes as two sets of eight models: 70M, 160M, 410M, 1B, 1.4B, 2.8B, 6.9B and 12B.
# Each size has one model trained on the Pile and one trained on the Pile
# after the dataset was globally deduplicated.
# There are 143 evenly-spaced checkpoints, from step1000 to step143000.

# Model and tokenizer are fetched from the same repo, revision and cache directory.
pythia_repo = "EleutherAI/pythia-70m-deduped"
checkpoint_kwargs = {
    "revision": "step3000",
    "cache_dir": "./pythia-70m-deduped/step3000",
}

model = GPTNeoXForCausalLM.from_pretrained(pythia_repo, **checkpoint_kwargs)
tokenizer = AutoTokenizer.from_pretrained(pythia_repo, **checkpoint_kwargs)

# Smoke test: tokenize a short prompt, generate a continuation, and show the decoded text.
inputs = tokenizer("Hello, I am", return_tensors="pt")
tokens = model.generate(**inputs)
tokenizer.decode(tokens[0])
