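Datasets to make. For each one, the notebook shuffles it with a fixed seed and formats every example as prompt/answer text (promptsource templates, or a hand-written multiple-choice format for MMLU). It then saves 1000 `train` and 5000 `validation` examples to `datasets/<name>.json`: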

- [x] cola
- [x] qnli
- [x] qqp
- [x] sst2
- [x] ag_news
- [x] commonsense_qa
- [x] mnli
- [x] mmlu


In [39]:
import datasets
from datasets import load_dataset, concatenate_datasets
from promptsource.templates import DatasetTemplates
from pathlib import Path
import json
from tqdm import tqdm
from collections import Counter

In [40]:
def load_prompt(dataset_name, config=None, prompt_idx=0):
    """Return one promptsource template for a dataset (optionally a dataset config).

    Also serves as a quick check that promptsource ships templates for the dataset.

    Args:
        dataset_name: dataset name as promptsource knows it, e.g. "glue".
        config: optional subset/config name, e.g. "cola". Leave as None for
            datasets without configs. Note that a second *positional* argument
            binds here, not to ``prompt_idx`` -- pass ``prompt_idx`` by keyword.
        prompt_idx: position of the template in the dataset's template-name
            list (the key order of ``name_to_id_mapping``). Negative values
            index from the end, as with a list.

    Returns:
        The template object stored under that name in ``DatasetTemplates``.

    Raises:
        IndexError: if ``prompt_idx`` is outside the available templates.
    """
    if config is not None:
        all_prompts = DatasetTemplates(dataset_name, config)
    else:
        all_prompts = DatasetTemplates(dataset_name)

    prompt_name_list = list(all_prompts.name_to_id_mapping.keys())
    n_prompts = len(prompt_name_list)
    if not -n_prompts <= prompt_idx < n_prompts:
        label = dataset_name if config is None else f"{dataset_name}/{config}"
        raise IndexError(
            f"prompt_idx={prompt_idx} is out of range: {label!r} has {n_prompts} template(s)"
        )
    return all_prompts[prompt_name_list[prompt_idx]]

# cola

In [41]:
# Split-tagging callback used with `Dataset.map` below.
def add_original_split(example, split_name):
    """Record which split a row came from.

    Args:
        example: a single row as a dict.
        split_name: name of the source split, e.g. "train".

    Returns:
        The same dict (modified in place) with key 'original_dataset_subset'
        set to ``split_name``.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [42]:
# Fetch the GLUE "cola" configuration and show its splits.
cola_dataset = load_dataset(
    "nyu-mll/glue",
    "cola",
)
print(cola_dataset)

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})


In [43]:
# Stamp every row with its source split. The split name is handed to the
# callback via `fn_kwargs` rather than captured by a lambda.
for split in ['train', 'validation', 'test']:
    cola_dataset[split] = cola_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [44]:
# Confirm the 'original_dataset_subset' column now exists in every split.
print(cola_dataset)

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'original_dataset_subset'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'original_dataset_subset'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'original_dataset_subset'],
        num_rows: 1063
    })
})


In [63]:
# Shuffle every split with a fixed seed.
# NOTE(review): cola uses seed 17, while the other GLUE sections use 42. Kept
# as-is so cola.json stays reproducible. Re-running this cell shuffles the
# already-shuffled data again, so run it once per kernel.
cola_dataset = cola_dataset.shuffle(
    seed=17,
)

In [64]:
# Template #0 from promptsource's GLUE/cola collection.
cola_prompt_template = load_prompt("glue", config="cola", prompt_idx=0)

In [65]:
# Training examples: the first `num_training_samples` rows of the shuffled
# train split, rendered through the prompt template.
train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = cola_dataset['train'][row_idx]
    prompt_in, prompt_out = cola_prompt_template.apply(row)
    train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [66]:
# cola's own validation split has only 1043 rows. The evaluation examples are
# therefore taken from the shuffled *train* split, starting right after the
# rows used for training, so the two sets do not overlap.
val_set = []
num_eval_samples = 5000

for row_idx in range(num_training_samples, num_training_samples + num_eval_samples):
    row = cola_dataset['train'][row_idx]
    prompt_in, prompt_out = cola_prompt_template.apply(row)
    val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [67]:
# Bundle both splits and write them to datasets/cola.json. The output
# directory is created first so the cell also works from a clean checkout.
combined_cola_dataset = {
    "train": train_set,
    "validation": val_set
}

save_path = Path("datasets/cola.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_cola_dataset, f)

In [68]:
# Label balance per split: tally the rendered answer strings.
cola_train_labels = [example['output'] for example in combined_cola_dataset["train"]]
cola_validation_labels = [example['output'] for example in combined_cola_dataset["validation"]]

cola_train_counter = Counter(cola_train_labels)
cola_validation_counter = Counter(cola_validation_labels)

print(cola_train_counter, cola_validation_counter, sep="\n")

Counter({'yes': 717, 'no': 283})
Counter({'yes': 3540, 'no': 1460})


# qnli

In [70]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [71]:
# Fetch the GLUE "qnli" configuration and show its splits.
qnli_dataset = load_dataset(
    "nyu-mll/glue",
    "qnli",
)
print(qnli_dataset)

DatasetDict({
    train: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 104743
    })
    validation: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 5463
    })
    test: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 5463
    })
})


In [72]:
# Stamp every row with its source split (split name passed via `fn_kwargs`).
for split in ['train', 'validation', 'test']:
    qnli_dataset[split] = qnli_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [73]:
# Shuffle every split with a fixed seed. Re-running this cell reshuffles the
# already-shuffled data, so run it once per kernel.
qnli_dataset = qnli_dataset.shuffle(
    seed=42,
)

In [74]:
# Template #0 from promptsource's GLUE/qnli collection.
qnli_prompt_template = load_prompt("glue", config="qnli", prompt_idx=0)

In [75]:
# Training examples: first `num_training_samples` shuffled train rows.
# NOTE(review): this rebinds `train_set` from the cola section. That is safe
# because combined_cola_dataset already holds its own reference to the old
# list, but a qnli-specific name would be clearer.
train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = qnli_dataset['train'][row_idx]
    prompt_in, prompt_out = qnli_prompt_template.apply(row)
    train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [76]:
# Evaluation examples: first `num_eval_samples` rows of the shuffled
# validation split (5463 rows, per the printout above).
qnli_val_set = []
num_eval_samples = 5000

for row_idx in range(num_eval_samples):
    row = qnli_dataset['validation'][row_idx]
    prompt_in, prompt_out = qnli_prompt_template.apply(row)
    qnli_val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [77]:
# Bundle both splits and write them to datasets/qnli.json, creating the output
# directory first so the cell also works from a clean checkout.
combined_qnli_dataset = {
    "train": train_set,
    "validation": qnli_val_set
}

save_path = Path("datasets/qnli.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_qnli_dataset, f)

In [78]:
# Label balance per split: tally the rendered answer strings.
qnli_train_labels = [example['output'] for example in combined_qnli_dataset["train"]]
qnli_validation_labels = [example['output'] for example in combined_qnli_dataset["validation"]]

qnli_train_counter = Counter(qnli_train_labels)
qnli_validation_counter = Counter(qnli_validation_labels)

print(qnli_train_counter, qnli_validation_counter, sep="\n")

Counter({'yes': 507, 'no': 493})
Counter({'no': 2501, 'yes': 2499})


# qqp

In [81]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [82]:
# Fetch the GLUE "qqp" configuration and show its splits.
qqp_dataset = load_dataset(
    "nyu-mll/glue",
    "qqp",
)
print(qqp_dataset)

DatasetDict({
    train: Dataset({
        features: ['question1', 'question2', 'label', 'idx'],
        num_rows: 363846
    })
    validation: Dataset({
        features: ['question1', 'question2', 'label', 'idx'],
        num_rows: 40430
    })
    test: Dataset({
        features: ['question1', 'question2', 'label', 'idx'],
        num_rows: 390965
    })
})


In [83]:
# Stamp every row with its source split (split name passed via `fn_kwargs`).
for split in ['train', 'validation', 'test']:
    qqp_dataset[split] = qqp_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [84]:
# Shuffle every split with a fixed seed. Re-running this cell reshuffles the
# already-shuffled data, so run it once per kernel.
qqp_dataset = qqp_dataset.shuffle(
    seed=42,
)

In [85]:
# Template #0 from promptsource's GLUE/qqp collection.
qqp_prompt_template = load_prompt("glue", config="qqp", prompt_idx=0)

In [86]:
# Training examples: first `num_training_samples` shuffled train rows.
qqp_train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = qqp_dataset['train'][row_idx]
    prompt_in, prompt_out = qqp_prompt_template.apply(row)
    qqp_train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [87]:
# Evaluation examples: first `num_eval_samples` rows of the shuffled
# validation split.
qqp_val_set = []
num_eval_samples = 5000

for row_idx in range(num_eval_samples):
    row = qqp_dataset['validation'][row_idx]
    prompt_in, prompt_out = qqp_prompt_template.apply(row)
    qqp_val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [88]:
# Bundle both splits and write them to datasets/qqp.json, creating the output
# directory first so the cell also works from a clean checkout.
combined_qqp_dataset = {
    "train": qqp_train_set,
    "validation": qqp_val_set
}

save_path = Path("datasets/qqp.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_qqp_dataset, f)

In [89]:
# Label balance per split: tally the rendered answer strings.
qqp_train_labels = [example['output'] for example in combined_qqp_dataset["train"]]
qqp_validation_labels = [example['output'] for example in combined_qqp_dataset["validation"]]

qqp_train_counter = Counter(qqp_train_labels)
qqp_validation_counter = Counter(qqp_validation_labels)

print(qqp_train_counter, qqp_validation_counter, sep="\n")

Counter({'no': 628, 'yes': 372})
Counter({'no': 3147, 'yes': 1853})


# sst2

In [90]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [91]:
# Fetch the GLUE "sst2" configuration and show its splits.
sst2_dataset = load_dataset(
    "nyu-mll/glue",
    "sst2",
)
print(sst2_dataset)

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})


In [92]:
# Stamp every row with its source split (split name passed via `fn_kwargs`).
for split in ['train', 'validation', 'test']:
    sst2_dataset[split] = sst2_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [93]:
# Shuffle every split with a fixed seed. Re-running this cell reshuffles the
# already-shuffled data, so run it once per kernel.
sst2_dataset = sst2_dataset.shuffle(
    seed=42,
)

In [94]:
# Template #0 from promptsource's GLUE/sst2 collection.
sst2_prompt_template = load_prompt("glue", config="sst2", prompt_idx=0)

In [95]:
# Training examples: first `num_training_samples` shuffled train rows.
sst2_train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = sst2_dataset['train'][row_idx]
    prompt_in, prompt_out = sst2_prompt_template.apply(row)
    sst2_train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [96]:
# sst2's own validation split has only 872 rows. The evaluation examples are
# therefore taken from the shuffled *train* split, right after the training
# rows, so the two sets do not overlap.
sst2_val_set = []
num_eval_samples = 5000

for row_idx in range(num_training_samples, num_training_samples + num_eval_samples):
    row = sst2_dataset['train'][row_idx]
    prompt_in, prompt_out = sst2_prompt_template.apply(row)
    sst2_val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_idx": row['idx'],
    })

In [97]:
# Bundle both splits and write them to datasets/sst2.json, creating the output
# directory first so the cell also works from a clean checkout.
combined_sst2_dataset = {
    "train": sst2_train_set,
    "validation": sst2_val_set
}

save_path = Path("datasets/sst2.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_sst2_dataset, f)

In [98]:
# Label balance per split: tally the rendered answer strings.
sst2_train_labels = [example['output'] for example in combined_sst2_dataset["train"]]
sst2_validation_labels = [example['output'] for example in combined_sst2_dataset["validation"]]

sst2_train_counter = Counter(sst2_train_labels)
sst2_validation_counter = Counter(sst2_validation_labels)

print(sst2_train_counter, sst2_validation_counter, sep="\n")

Counter({'positive': 557, 'negative': 443})
Counter({'positive': 2747, 'negative': 2253})


In [99]:
# Inspect one saved sst2 record to check the prompt/answer formatting.
print(combined_sst2_dataset["train"][0])

{'input': 'klein , charming in comedies like american pie and dead-on in election , \nQuestion: Was that sentence positive or negative? Answer:', 'output': 'positive', 'combined': 'klein , charming in comedies like american pie and dead-on in election , \nQuestion: Was that sentence positive or negative? Answer:\npositive', 'original_dataset_subset': 'train', 'original_idx': 32326}


# ag_news

In [100]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [102]:
# Fetch AG News (no config name) and show its splits.
ag_news_dataset = load_dataset(
    "fancyzhx/ag_news",
)
print(ag_news_dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})


In [104]:
# AG News has only train/test splits; stamp each row with its source split.
for split in ['train', 'test']:
    ag_news_dataset[split] = ag_news_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

Map:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

In [105]:
# Shuffle every split with a fixed seed. Re-running this cell reshuffles the
# already-shuffled data, so run it once per kernel.
ag_news_dataset = ag_news_dataset.shuffle(
    seed=42,
)

In [106]:
# Template #0 for ag_news (a dataset without configs). `prompt_idx` must be
# passed by keyword: a bare second positional argument binds to `config`,
# which previously produced config=0 and only worked because 0 is falsy.
ag_news_prompt_template = load_prompt("ag_news", prompt_idx=0)

In [108]:
# Training examples: first `num_training_samples` shuffled train rows.
# AG News rows have only 'text' and 'label' (see the printout above), so no
# source index is recorded.
ag_news_train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = ag_news_dataset['train'][row_idx]
    prompt_in, prompt_out = ag_news_prompt_template.apply(row)
    ag_news_train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
    })

In [110]:
# Evaluation examples come from the shuffled *test* split (AG News has no
# validation split). There is no index column to record.
ag_news_val_set = []
num_eval_samples = 5000

for row_idx in range(num_eval_samples):
    row = ag_news_dataset['test'][row_idx]
    prompt_in, prompt_out = ag_news_prompt_template.apply(row)
    ag_news_val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
    })

In [111]:
# Bundle both splits and write them to datasets/ag_news.json, creating the
# output directory first so the cell also works from a clean checkout.
combined_ag_news_dataset = {
    "train": ag_news_train_set,
    "validation": ag_news_val_set
}

save_path = Path("datasets/ag_news.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_ag_news_dataset, f)

In [112]:
# Label balance per split: tally the rendered answer strings.
ag_news_train_labels = [example['output'] for example in combined_ag_news_dataset["train"]]
ag_news_validation_labels = [example['output'] for example in combined_ag_news_dataset["validation"]]

ag_news_train_counter = Counter(ag_news_train_labels)
ag_news_validation_counter = Counter(ag_news_validation_labels)

print(ag_news_train_counter, ag_news_validation_counter, sep="\n")

Counter({'Science and technology': 271, 'World politics': 244, 'Sports': 243, 'Business': 242})
Counter({'Sports': 1263, 'World politics': 1255, 'Business': 1248, 'Science and technology': 1234})


# mnli

In [164]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [165]:
# Fetch the GLUE "mnli" configuration and show its splits.
mnli_dataset = load_dataset(
    "nyu-mll/glue",
    "mnli",
)
print(mnli_dataset)

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})


In [166]:
# Only the two splits used below are tagged with their source split name.
for split in ['train', 'validation_matched']:
    mnli_dataset[split] = mnli_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [167]:
# Shuffle every split with a fixed seed. Re-running this cell reshuffles the
# already-shuffled data, so run it once per kernel.
mnli_dataset = mnli_dataset.shuffle(
    seed=42,
)

In [168]:
# Template #2 from promptsource's GLUE/mnli collection. Uses the same
# (dataset, config) form as the other GLUE sections instead of passing the
# combined path "glue/mnli" as the dataset name.
mnli_prompt_template = load_prompt("glue", "mnli", prompt_idx=2)

In [169]:
# Training examples: first `num_training_samples` shuffled train rows.
# NOTE(review): unlike the other GLUE sections, no 'original_idx' is stored,
# even though mnli rows carry an 'idx' column. Confirm this is intended.
mnli_train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = mnli_dataset['train'][row_idx]
    prompt_in, prompt_out = mnli_prompt_template.apply(row)
    mnli_train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
    })

In [170]:
# Evaluation examples: first `num_eval_samples` rows of the shuffled
# validation_matched split.
# NOTE(review): no 'original_idx' is stored here either. Confirm this is intended.
mnli_val_set = []
num_eval_samples = 5000

for row_idx in range(num_eval_samples):
    row = mnli_dataset['validation_matched'][row_idx]
    prompt_in, prompt_out = mnli_prompt_template.apply(row)
    mnli_val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
    })

In [171]:
# Bundle both splits and write them to datasets/mnli.json, creating the output
# directory first so the cell also works from a clean checkout.
combined_mnli_dataset = {
    "train": mnli_train_set,
    "validation": mnli_val_set
}

save_path = Path("datasets/mnli.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_mnli_dataset, f)

In [172]:
# Label balance per split: tally the rendered answer strings.
mnli_train_labels = [example['output'] for example in combined_mnli_dataset["train"]]
mnli_validation_labels = [example['output'] for example in combined_mnli_dataset["validation"]]

mnli_train_counter = Counter(mnli_train_labels)
mnli_validation_counter = Counter(mnli_validation_labels)

print(mnli_train_counter, mnli_validation_counter, sep="\n")

Counter({'Maybe': 363, 'Yes': 337, 'No': 300})
Counter({'Yes': 1765, 'Maybe': 1629, 'No': 1606})


# commonsense_qa

In [189]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [190]:
# Fetch CommonsenseQA and show its splits.
commonsense_qa_dataset = load_dataset(
    "tau/commonsense_qa",
)
print(commonsense_qa_dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 9741
    })
    validation: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 1221
    })
    test: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 1140
    })
})


In [191]:
# Only the train split is used below, so only it gets the split tag.
for split in ['train']:
    commonsense_qa_dataset[split] = commonsense_qa_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [192]:
# Shuffle every split with a fixed seed. Re-running this cell reshuffles the
# already-shuffled data, so run it once per kernel.
commonsense_qa_dataset = commonsense_qa_dataset.shuffle(
    seed=42,
)

In [193]:
# Template #2 for commonsense_qa (a dataset without configs).
commonsense_qa_prompt_template = load_prompt(dataset_name="commonsense_qa", prompt_idx=2)

In [194]:
# Preview: render the first shuffled training row to eyeball the template.
commonsense_qa_prompt_template.apply(
    commonsense_qa_dataset['train'][0]
)

["Given the following options, what do you think is the correct answer to the question below:\n\nI needed to send a piece of mail, where did I go?\n\nOptions:\n\n- A: table\n\n- B: post office\n\n- C: neighbor's house\n\n- D: railway station",
 'B']

In [195]:
# NOTE(review): this is a second shuffle of data that was already shuffled
# with seed 42 a few cells earlier. It re-permutes the rows, so the preview
# above no longer shows the row that becomes train[0]; compare the preview
# output with the 'combined' print further down. It is kept because
# commonsense_qa.json was generated with both shuffles. Drop one only if the
# file is meant to be regenerated.
commonsense_qa_dataset = commonsense_qa_dataset.shuffle(
    seed=42,
)

In [196]:
# Training examples: first `num_training_samples` shuffled train rows. This
# dataset has a string 'id' column, so it is stored as 'original_id'.
commonsense_qa_train_set = []
num_training_samples = 1000

for row_idx in range(num_training_samples):
    row = commonsense_qa_dataset['train'][row_idx]
    prompt_in, prompt_out = commonsense_qa_prompt_template.apply(row)
    commonsense_qa_train_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_id": row['id'],
    })

In [197]:
# The validation split has only 1221 rows, so the evaluation examples are
# taken from the shuffled *train* split, right after the training rows (no
# overlap with the training set).
commonsense_qa_val_set = []
num_eval_samples = 5000

for row_idx in range(num_training_samples, num_training_samples + num_eval_samples):
    row = commonsense_qa_dataset['train'][row_idx]
    prompt_in, prompt_out = commonsense_qa_prompt_template.apply(row)
    commonsense_qa_val_set.append({
        "input": prompt_in,
        "output": prompt_out,
        "combined": "\n".join((prompt_in, prompt_out)),
        "original_dataset_subset": row['original_dataset_subset'],
        "original_id": row['id'],
    })

In [198]:
# Bundle both splits and write them to datasets/commonsense_qa.json, creating
# the output directory first so the cell also works from a clean checkout.
combined_commonsense_qa_dataset = {
    "train": commonsense_qa_train_set,
    "validation": commonsense_qa_val_set
}

save_path = Path("datasets/commonsense_qa.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_commonsense_qa_dataset, f)

In [199]:
# Label balance per split: tally the answer letters.
commonsense_qa_train_labels = [example['output'] for example in combined_commonsense_qa_dataset["train"]]
commonsense_qa_validation_labels = [example['output'] for example in combined_commonsense_qa_dataset["validation"]]

commonsense_qa_train_counter = Counter(commonsense_qa_train_labels)
commonsense_qa_validation_counter = Counter(commonsense_qa_validation_labels)

print(commonsense_qa_train_counter, commonsense_qa_validation_counter, sep="\n")

Counter({'D': 211, 'B': 210, 'A': 203, 'C': 198, 'E': 178})
Counter({'D': 1038, 'B': 1013, 'C': 995, 'E': 988, 'A': 966})


In [200]:
# Show one full training example: prompt, newline, answer letter.
print(commonsense_qa_train_set[0]['combined'])

Given the following options, what do you think is the correct answer to the question below:

A school is necessary for every one of these. What are they?

Options:

- A: every city

- B: community

- C: playground

- D: residential neighborhood
B


# mmlu

In [251]:
# Re-declares the split-tagging helper with the same behaviour as the cola
# section's copy, presumably so this section can be run on its own.
def add_original_split(example, split_name):
    """Set example['original_dataset_subset'] to ``split_name`` and return the row.

    Intended as a ``Dataset.map`` callback; the dict is modified in place.
    """
    example.update(original_dataset_subset=split_name)
    return example

In [252]:
# Fetch every MMLU subject (config 'all') and show its splits.
mmlu_dataset = load_dataset(
    "cais/mmlu",
    'all',
)
print(mmlu_dataset)

DatasetDict({
    test: Dataset({
        features: ['question', 'subject', 'choices', 'answer'],
        num_rows: 14042
    })
    validation: Dataset({
        features: ['question', 'subject', 'choices', 'answer'],
        num_rows: 1531
    })
    dev: Dataset({
        features: ['question', 'subject', 'choices', 'answer'],
        num_rows: 285
    })
    auxiliary_train: Dataset({
        features: ['question', 'subject', 'choices', 'answer'],
        num_rows: 99842
    })
})


In [253]:
# Both the training and evaluation rows below come from 'test', so only that
# split is tagged.
for split in ['test']:
    mmlu_dataset[split] = mmlu_dataset[split].map(
        add_original_split,
        fn_kwargs={"split_name": split},
    )

In [254]:
# Shuffle every split with a fixed seed. MMLU uses 50, unlike the 42 used
# elsewhere. Re-running this cell reshuffles the already-shuffled data, so
# run it once per kernel.
mmlu_dataset = mmlu_dataset.shuffle(
    seed=50,
)

In [255]:
# Peek at one raw shuffled row: question, subject, choices list, and integer
# answer index.
print(mmlu_dataset['test'][0])

{'question': "The demand curve for a perfectly competitive firm's product is", 'subject': 'high_school_microeconomics', 'choices': ['downward sloping and equal to the market demand curve.', 'perfectly elastic.', 'perfectly inelastic.', 'kinked at the going market price.'], 'answer': 1, 'original_dataset_subset': 'test'}


In [256]:
# MMLU is formatted by hand rather than with promptsource. Layout: the
# question, a blank line, one "<letter>: <choice>" line per option, then
# "Answer:". The target is the letter of the correct option.
# Training rows are the first `num_training_samples` of the shuffled *test* split.
mmlu_train_set = []
num_training_samples = 1000

for i in range(num_training_samples):
    dataset_element = mmlu_dataset['test'][i]
    choices = dataset_element['choices']

    # Letters follow option order (A, B, C, ...), so the format is not tied to
    # exactly four choices; for four options the text is unchanged.
    letters = [chr(ord('A') + k) for k in range(len(choices))]
    options_txt = "\n".join(f"{letter}: {choice}" for letter, choice in zip(letters, choices))
    input_txt = f"{dataset_element['question']}\n\n{options_txt}\nAnswer:"

    # 'answer' is an index into 'choices' (e.g. 1 -> "B").
    output_txt = letters[dataset_element['answer']]

    dataset_obj = {
        "input": input_txt,
        "output": output_txt,
        "combined": input_txt + "\n" + output_txt,
        "original_dataset_subset": dataset_element['original_dataset_subset'],
        "subject": dataset_element['subject']
    }

    mmlu_train_set.append(dataset_obj)

In [257]:
# Show the hand-built prompt and answer for the first training example.
print(mmlu_train_set[0]['combined'])

The demand curve for a perfectly competitive firm's product is

A: downward sloping and equal to the market demand curve.
B: perfectly elastic.
C: perfectly inelastic.
D: kinked at the going market price.
Answer:
B


In [258]:
# Evaluation rows use the same hand-built format as the training rows. They
# are the next `num_val_samples` rows of the shuffled *test* split, after the
# training slice, so the two sets do not overlap.
mmlu_val_set = []
num_val_samples = 5000

for i in range(num_val_samples):
    dataset_element = mmlu_dataset['test'][i + num_training_samples]
    choices = dataset_element['choices']

    # Letters follow option order (A, B, C, ...), so the format is not tied to
    # exactly four choices; for four options the text is unchanged.
    letters = [chr(ord('A') + k) for k in range(len(choices))]
    options_txt = "\n".join(f"{letter}: {choice}" for letter, choice in zip(letters, choices))
    input_txt = f"{dataset_element['question']}\n\n{options_txt}\nAnswer:"

    # 'answer' is an index into 'choices'.
    output_txt = letters[dataset_element['answer']]

    dataset_obj = {
        "input": input_txt,
        "output": output_txt,
        "combined": input_txt + "\n" + output_txt,
        "original_dataset_subset": dataset_element['original_dataset_subset'],
        "subject": dataset_element['subject']
    }

    mmlu_val_set.append(dataset_obj)

In [259]:
# Bundle both splits and write them to datasets/mmlu.json, creating the output
# directory first so the cell also works from a clean checkout.
combined_mmlu_dataset = {
    "train": mmlu_train_set,
    "validation": mmlu_val_set
}

save_path = Path("datasets/mmlu.json")
save_path.parent.mkdir(parents=True, exist_ok=True)
with save_path.open("w", encoding="utf-8") as f:
    json.dump(combined_mmlu_dataset, f)

In [260]:
# Label balance per split: tally the answer letters.
mmlu_train_labels = [example['output'] for example in combined_mmlu_dataset["train"]]
mmlu_validation_labels = [example['output'] for example in combined_mmlu_dataset["validation"]]

mmlu_train_counter = Counter(mmlu_train_labels)
mmlu_validation_counter = Counter(mmlu_validation_labels)

print(mmlu_train_counter, mmlu_validation_counter, sep="\n")

Counter({'B': 263, 'C': 252, 'D': 243, 'A': 242})
Counter({'D': 1346, 'C': 1281, 'B': 1207, 'A': 1166})
