In [46]:
from datasets import load_dataset
import random
import json

## Test on math subsets

In [3]:
# Load the three MMLU mathematics subsets from the same Hub repository.
# Numbering follows the subject level: 1 = elementary, 2 = high school, 3 = college.
MMLU_REPO = "cais/mmlu"
mmlu_math1 = load_dataset(path=MMLU_REPO, name="elementary_mathematics")
mmlu_math2 = load_dataset(path=MMLU_REPO, name="high_school_mathematics")
mmlu_math3 = load_dataset(path=MMLU_REPO, name="college_mathematics")

Generating test split: 100%|██████████| 378/378 [00:00<00:00, 14046.79 examples/s]
Generating validation split: 100%|██████████| 41/41 [00:00<00:00, 27781.34 examples/s]
Generating dev split: 100%|██████████| 5/5 [00:00<00:00, 4374.53 examples/s]
Generating test split: 100%|██████████| 270/270 [00:00<00:00, 87293.77 examples/s]
Generating validation split: 100%|██████████| 29/29 [00:00<00:00, 16428.26 examples/s]
Generating dev split: 100%|██████████| 5/5 [00:00<00:00, 3339.95 examples/s]
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 28360.97 examples/s]
Generating validation split: 100%|██████████| 11/11 [00:00<00:00, 4101.46 examples/s]
Generating dev split: 100%|██████████| 5/5 [00:00<00:00, 2997.64 examples/s]


In [6]:
# Check the dataset format: print the first record of the elementary-mathematics
# test split to see which fields each example carries.
print(mmlu_math1['test'][0])


{'question': 'What is the value of p in 24 = 2p?', 'subject': 'elementary_mathematics', 'choices': ['p = 4', 'p = 8', 'p = 12', 'p = 24'], 'answer': 2}


In [58]:
def gen_iccl_single_example(example):
    """Render one multiple-choice record as a prompt string.

    Args:
        example: Mapping with a 'question' string, a 'choices' sequence of
            option strings, and an 'answer' value.

    Returns:
        A ``(prompt, answer)`` tuple. ``prompt`` is the question, an
        instruction line, the options numbered from 0 (one per line) and a
        trailing "Answer: " cue; ``answer`` is ``example['answer']`` unchanged.
    """
    question_text = example['question']

    # Number the options starting at 0, matching how they are enumerated.
    numbered_options = []
    for position, choice_text in enumerate(example['choices']):
        numbered_options.append(f"{position}. {choice_text}")

    prompt_lines = [
        question_text,
        "Choose the best answer from the following options:",
        "\n".join(numbered_options),
        "Answer: ",
    ]
    return "\n".join(prompt_lines), example['answer']


In [60]:
# Preview the formatted prompt text (element 0 of the returned tuple) for the
# first elementary-mathematics test record.
print(gen_iccl_single_example(mmlu_math1['test'][0])[0])

What is the value of p in 24 = 2p?
Choose the best answer from the following options:
0. p = 4
1. p = 8
2. p = 12
3. p = 24
Answer: 


In [53]:
def gen_all_iccl_prompts(easy_dataset, medium_dataset, hard_dataset, n_prompts, topic,
                         output_path="iccl_prompts.json"):
    """Generate curriculum-ordered few-shot prompts and store them in a JSON file.

    Each prompt consists of three solved demonstrations (one easy, one medium,
    one hard, in that order) followed by an unsolved hard test question. Every
    hard example is used at most once across all prompts, either as a
    demonstration or as a test question, so ``2 * n_prompts`` distinct hard
    examples are required. Easy and medium demonstrations are drawn with
    replacement.

    Args:
        easy_dataset: Mapping whose 'test' entry is an iterable of examples
            accepted by ``gen_iccl_single_example``.
        medium_dataset: Same layout as ``easy_dataset``.
        hard_dataset: Same layout as ``easy_dataset``.
        n_prompts: Number of prompts to generate.
        topic: Top-level key under which the prompts are stored in the JSON.
        output_path: JSON file that is read (if it exists), updated and
            rewritten. Defaults to "iccl_prompts.json".

    Raises:
        ValueError: If prompts are requested but the easy or medium split is
            empty, or the hard split has fewer than ``2 * n_prompts`` examples.
            Raised before any file is read or written.
    """
    # Convert datasets to lists for random access and index tracking.
    easy_examples = list(easy_dataset['test'])
    medium_examples = list(medium_dataset['test'])
    hard_examples = list(hard_dataset['test'])

    # Fail fast with a clear error instead of an IndexError from
    # random.choice([]) part-way through the loop.
    if n_prompts > 0:
        if not easy_examples or not medium_examples:
            raise ValueError("Easy and medium datasets must each contain at least one example")
        if 2 * n_prompts > len(hard_examples):
            raise ValueError("Not enough unique hard examples available")

    try:
        with open(output_path, 'r', encoding="utf-8") as f:
            iccl_prompts = json.load(f)
    except FileNotFoundError:
        iccl_prompts = {}
    # Entries already stored for this topic under keys that are not
    # regenerated below are left untouched.
    topic_prompts = iccl_prompts.setdefault(topic, {})

    # Indices of hard examples not yet used. The list stays in ascending order,
    # so each draw selects from the same candidates as rebuilding the list of
    # unused indices from scratch would.
    available_hard_indices = list(range(len(hard_examples)))

    def _take_hard_example():
        # Draw one unused hard example and mark it as used.
        index = random.choice(available_hard_indices)
        available_hard_indices.remove(index)
        return hard_examples[index]

    for prompt_idx in range(n_prompts):
        # Curriculum examples for demonstrations.
        demo_easy = random.choice(easy_examples)
        demo_medium = random.choice(medium_examples)
        demo_hard = _take_hard_example()
        # The test question is drawn after the demonstration was removed from
        # the pool, so it can never repeat a hard demonstration.
        test_hard = _take_hard_example()

        demonstrations = [
            gen_iccl_single_example(demo_easy),
            gen_iccl_single_example(demo_medium),
            gen_iccl_single_example(demo_hard),
        ]
        test_question, test_answer = gen_iccl_single_example(test_hard)

        # Solved demonstrations are separated by blank lines; the test question
        # is left ending with its answer cue.
        prompt = "".join(f"{demo_q}" + str(demo_a) + "\n\n" for demo_q, demo_a in demonstrations)
        prompt += f"{test_question}"

        topic_prompts[f"{prompt_idx}"] = {
            "question": prompt,
            "answer": test_answer
        }

    # Write the updated JSON back to the file.
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(iccl_prompts, f, indent=4)


In [57]:
# Build prompts for the "math" topic with elementary -> high-school -> college
# demonstrations. Each prompt consumes two distinct college ("hard") examples
# (one demonstration and one test question), so at most half of the hard test
# split can be turned into prompts.
# NOTE(review): no random.seed(...) call is visible in this part of the notebook,
# so the sampled prompts presumably differ between runs — confirm this is intended.
gen_all_iccl_prompts(mmlu_math1, mmlu_math2, mmlu_math3, int(len(mmlu_math3['test']) // 2), "math")

## Generate ICCL prompts for all topics