In [None]:
import json
import os
from datasets import Dataset, DatasetDict
import os.path as osp
import csv

class MMLUDataset:
    """Loader and prompt formatter for a single MMLU subject.

    Reads ``<path>/dev/<name>_dev.csv`` and ``<path>/test/<name>_test.csv``.
    Every CSV row is treated as data (no header is skipped) and must have
    exactly six fields: question, options A-D, and the answer letter.
    """

    # Subject names accepted by ``__init__``.
    mmlu_all_sets = [
        "college_biology",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_physics",
        "electrical_engineering",
        "astronomy",
        "anatomy",
        "abstract_algebra",
        "machine_learning",
        "clinical_knowledge",
        "global_facts",
        "management",
        "nutrition",
        "marketing",
        "professional_accounting",
        "high_school_geography",
        "international_law",
        "moral_scenarios",
        "computer_security",
        "high_school_microeconomics",
        "professional_law",
        "medical_genetics",
        "professional_psychology",
        "jurisprudence",
        "world_religions",
        "philosophy",
        "virology",
        "high_school_chemistry",
        "public_relations",
        "high_school_macroeconomics",
        "human_sexuality",
        "elementary_mathematics",
        "high_school_physics",
        "high_school_computer_science",
        "high_school_european_history",
        "business_ethics",
        "moral_disputes",
        "high_school_statistics",
        "miscellaneous",
        "formal_logic",
        "high_school_government_and_politics",
        "prehistory",
        "security_studies",
        "high_school_biology",
        "logical_fallacies",
        "high_school_world_history",
        "professional_medicine",
        "high_school_mathematics",
        "college_medicine",
        "high_school_us_history",
        "sociology",
        "econometrics",
        "high_school_psychology",
        "human_aging",
        "us_foreign_policy",
        "conceptual_physics",
    ]
    # Separator placed between consecutive questions in a few-shot prompt.
    question_sep = '\n\n'

    def __init__(self, path: str, name: str):
        """Load subject ``name`` from the MMLU data root ``path``."""
        assert name in self.mmlu_all_sets, f'unknown MMLU subset: {name!r}'
        self.name = name
        self.dataset = self.load(path, name)

    @staticmethod
    def load(path: str, name: str):
        """Read the dev and test CSVs of ``name`` into a ``DatasetDict``.

        Raises:
            ValueError: if a CSV record does not have exactly six fields.
        """
        dataset = DatasetDict()
        for split in ['dev', 'test']:
            raw_data = []
            filename = osp.join(path, split, f'{name}_{split}.csv')
            # newline='' is required by the csv module so that line breaks
            # inside quoted fields (multi-line question text) are parsed
            # correctly instead of being mangled by universal-newline mode.
            with open(filename, encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                for row in reader:
                    if len(row) != 6:
                        raise ValueError(
                            f'{filename}:{reader.line_num}: expected 6 fields, '
                            f'got {len(row)}')
                    raw_data.append({
                        'input': row[0],
                        'A': row[1],
                        'B': row[2],
                        'C': row[3],
                        'D': row[4],
                        'target': row[5],
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset

    @property
    def test(self):
        """The test split (the questions being asked)."""
        return self.dataset['test']

    @property
    def fewshot(self):
        """The dev split, used as the pool of in-context examples."""
        return self.dataset['dev']

    def get_fewshot_prefix(self, nshot=5):
        """Return the first ``nshot`` dev examples, answers included.

        The examples are joined by ``question_sep`` and followed by one
        trailing ``question_sep``, so a test question can be appended directly.
        ``nshot=0`` yields an empty prefix (zero-shot).

        Raises:
            ValueError: if ``nshot`` is negative or exceeds the dev split size.
        """
        num_dev = len(self.dataset['dev'])
        if not 0 <= nshot <= num_dev:
            raise ValueError(
                f'nshot must be in [0, {num_dev}] for {self.name}, got {nshot}')
        if nshot == 0:
            return ''
        examples = [
            self.format_read(split='dev', idx=i, with_target=True)[0]
            for i in range(nshot)
        ]
        return self.question_sep.join(examples) + self.question_sep

    def get_question(self, idx):
        """Return ``(prompt, target)`` for test row ``idx``, answer not shown."""
        return self.format_read(split='test', idx=idx, with_target=False)

    def format_read(self, split: str, idx: int, with_target=False):
        """Render row ``idx`` of ``split`` as a prompt.

        Returns:
            ``(prompt, target)``. ``prompt`` is the subject hint, the question,
            the four options and a trailing ``'Answer: '``. When
            ``with_target`` is true, the answer letter is appended after a
            newline. ``target`` is the answer letter from the CSV.
        """
        item = self.dataset[split][idx]
        _hint = f'There is a single choice question about {self.name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
        prompt = f"{_hint}\nQuestion: {item['input']}\nA. {item['A']}\nB. {item['B']}\nC. {item['C']}\nD. {item['D']}\nAnswer: "
        target = item['target']
        if with_target:
            # NOTE(review): this produces "Answer: \nA" (a trailing space, then
            # a newline) in the few-shot examples. Confirm that this layout is
            # intended; changing it alters every measured token count.
            prompt += f"\n{target}"
        return prompt, target

    def get_num_questions(self):
        """Number of rows in the test split."""
        return len(self.dataset['test'])

In [None]:
mmlu_root = '/mnt/lustrenew/zhulei1/ssd_cache/opencompass/data/mmlu'
high_school_european_history = MMLUDataset(mmlu_root, name='high_school_european_history')

print('------------\n')
print(high_school_european_history.fewshot)
print('------------\n')
print(high_school_european_history.fewshot[4]['input'])
print('------------\n')
print(high_school_european_history.get_fewshot_prefix(2))
print('------------\n')
print(high_school_european_history.get_question(4)[0])
print('------------\n')
# get_fewshot_prefix already ends with question_sep, so plain concatenation
# puts exactly one separator between the last example and the test question.
# Joining with question_sep again would insert the separator twice.
print(high_school_european_history.get_fewshot_prefix(2)
      + high_school_european_history.get_question(4)[0])

In [None]:
from transformers import (AutoTokenizer,
                          PreTrainedTokenizerBase)
import numpy as np

# Checkpoint whose tokenizer is used to measure prompt lengths.
model = '/mnt/cachenew2/zhulei1/huggingface/local/Llama-2-7b-hf'
mmlu_root = '/mnt/lustrenew/zhulei1/ssd_cache/opencompass/data/mmlu'
nshot = 5

tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)


def _num_tokens(text):
    """Length of the ``input_ids`` the tokenizer returns for ``text``."""
    # NOTE(review): the tokenizer is called with default arguments, which
    # presumably add special tokens (e.g. BOS), so prefix and prompt counts may
    # each include one. Confirm this if the counts are summed.
    return len(tokenizer(text).input_ids)


prefix_lens = []  # one entry per subset: token count of its nshot-shot prefix
prompt_lens = []  # one entry per test question, all subsets concatenated
for subset in MMLUDataset.mmlu_all_sets:
    print(f"Processing {subset}")
    dataset = MMLUDataset(path=mmlu_root, name=subset)
    num_questions = dataset.get_num_questions()
    # get_question returns (prompt, target); only the prompt text is measured.
    all_items = list(map(dataset.get_question, range(num_questions)))
    prompts = [item[0] for item in all_items]
    prefix = dataset.get_fewshot_prefix(nshot)
    prefix_lens.append(_num_tokens(prefix))
    prompt_lens += map(_num_tokens, prompts)

In [None]:
# Token-length range of the few-shot prefixes, then of the test prompts.
prefix_min = np.min(prefix_lens)
prefix_max = np.max(prefix_lens)
prompt_min = np.min(prompt_lens)
prompt_max = np.max(prompt_lens)

for lo, hi in ((prefix_min, prefix_max), (prompt_min, prompt_max)):
    print(lo, hi)

In [None]:
import json
import os

save_dir = '../stat/dataset'
# Create the output directory on a fresh checkout instead of failing in open().
os.makedirs(save_dir, exist_ok=True)

# The shot count in the file name comes from `nshot` (set in the tokenization
# cell), so the name stays accurate if nshot changes. With nshot=5 it is
# still 'mmlu_5shotprefix_stat.json'.
prefix_stat_file = os.path.join(save_dir, f'mmlu_{nshot}shotprefix_stat.json')
prompt_stat_file = os.path.join(save_dir, 'mmlu_prompt_stat.json')

with open(prefix_stat_file, 'w', encoding='utf-8') as fp:
    json.dump(prefix_lens, fp)

with open(prompt_stat_file, 'w', encoding='utf-8') as fp:
    json.dump(prompt_lens, fp)
