In [None]:
# Lazify Samples

In [3]:
import json, tqdm, random, concurrent
from datasets import load_dataset
from llms import generate_json
from functools import partial

with open("prompts/gsm8k_sharding_segmentation.txt", "r") as f:
    prompt_segment = f.read()

with open("prompts/gsm8k_sharding_conversational.txt", "r") as f:
    prompt_conversational = f.read()

with open("prompts/sharding_verification.txt", "r") as f:
    prompt_verification = f.read()

MODEL = "t-o3"

def shard_gsm8k(seed):

    prompt_extractive_populated = prompt_segment.replace("[[QUESTION]]", seed["question"])

    verification_status = None
    attempt = 0
    while verification_status != "complete":
        attempt += 1
        if attempt > 5:
            print(f"Failed to get valid segments after 5 attempts. Skipping {seed['id']}")
            break
        try:
            segments_res = generate_json([{"role": "user", "content": prompt_extractive_populated}], model=MODEL, step="sharding-gsm8k-extractive")

            # Now do the conversational version
            prompt_conversational_populated = prompt_conversational.replace("[[QUESTION]]", seed["question"]).replace("[[SEGMENTS]]", json.dumps(segments_res["segments"]))
            conversational_res = generate_json([{"role": "user", "content": prompt_conversational_populated}], model=MODEL, step="sharding-gsm8k-conversational")


            # Assign hint_id per item in hints, starting from 1
            shard_attempt = [{"shard_id": 1, "shard": conversational_res["initial_query"]}]
            for i, h in enumerate(conversational_res["hints"]):
                shard_attempt.append({"shard_id": i + 2, "shard": h["hint"]})

            # Now do the verification
            prompt_verification_populated = prompt_verification.replace("[[QUERY]]", seed["question"]).replace("[[SHARDS]]", json.dumps(shard_attempt))
            verification_res = generate_json([{"role": "user", "content": prompt_verification_populated}], model=MODEL, step="sharding-gsm8k-verification")

            verification_status = verification_res["coverage"]
        except Exception as e:
            print("Fatal error when generating segments: ", e)
            continue

    shards = [{"shard_id": 1, "shard": conversational_res["initial_query"]}]
    for i, h in enumerate(conversational_res["hints"]):
        shards.append({"shard_id": i + 2, "shard": h["hint"]})

    seed["task_id"] = f"sharded-gsm8k-train-{seed['id']}"
    seed["fully_specified_question"] = seed["question"]
    seed["shards"] = shards
    seed["model_name"] = MODEL

    return seed


def shard_data(output_path, shuffle=True, max_samples=500, num_workers=10):
    dataset = load_dataset("gsm8k", "main")
    seeds = list(dataset['train'])
    # add "id" to each sample
    for i, seed in enumerate(seeds):
        seed["id"] = i
        seed["task"] = "math"

    if shuffle:
        random.shuffle(seeds)
    
    seed_samples = seeds[:max_samples] if max_samples else seeds

    print("Total to convert: ", len(seed_samples))

    process_sample = partial(shard_gsm8k)

    # Process samples in parallel using ThreadPoolExecutor instead of ProcessPoolExecutor
    # This works better in Jupyter notebooks
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Use tqdm to show progress
        results = list(tqdm.tqdm(
            executor.map(process_sample, seed_samples),
            total=len(seed_samples),
            desc=f"Processing with {num_workers} threads"
        ))
    
    # Filter out None results and add to output data
    output_data = [result for result in results if result is not None and "shards" in result]
    print(f"Generated {len(output_data)} sharded instructions")

    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=2)

    print(f"Processed {len(output_data)} samples successfully out of {len(seed_samples)} total samples")
    return output_data


In [5]:
converted = shard_data("data/sharded_gsm8k_train_400.json", shuffle=True, max_samples=400, num_workers=5)

Total to convert:  400


Processing with 5 threads:   0%|          | 2/400 [01:56<5:55:01, 53.52s/it]

Fatal error when generating segments:  list indices must be integers or slices, not str
Fatal error when generating segments:  list indices must be integers or slices, not str
Failed to get valid segments after 5 attempts. Skipping 404
Failed to get valid segments after 5 attempts. Skipping 84


Processing with 5 threads:   1%|          | 4/400 [04:58<8:27:32, 76.90s/it]

Failed to get valid segments after 5 attempts. Skipping 211


Processing with 5 threads:   5%|▌         | 20/400 [06:31<1:32:14, 14.56s/it]

Failed to get valid segments after 5 attempts. Skipping 2239
Failed to get valid segments after 5 attempts. Skipping 3954


Processing with 5 threads:   7%|▋         | 27/400 [06:45<1:02:37, 10.07s/it]

Failed to get valid segments after 5 attempts. Skipping 4413


Processing with 5 threads:   8%|▊         | 34/400 [08:22<1:15:11, 12.33s/it]

Failed to get valid segments after 5 attempts. Skipping 991


Processing with 5 threads:  12%|█▎        | 50/400 [10:28<1:14:02, 12.69s/it]

Failed to get valid segments after 5 attempts. Skipping 1423


Processing with 5 threads:  13%|█▎        | 52/400 [11:01<1:18:32, 13.54s/it]

Fatal error when generating segments:  list indices must be integers or slices, not str


Processing with 5 threads:  14%|█▍        | 55/400 [11:37<1:15:00, 13.05s/it]

Failed to get valid segments after 5 attempts. Skipping 7201


Processing with 5 threads:  16%|█▌        | 62/400 [13:37<1:24:07, 14.93s/it]

Failed to get valid segments after 5 attempts. Skipping 3535


Processing with 5 threads:  16%|█▌        | 63/400 [14:12<1:36:54, 17.25s/it]

Failed to get valid segments after 5 attempts. Skipping 3366


Processing with 5 threads:  19%|█▉        | 75/400 [14:41<40:38,  7.50s/it]  

Failed to get valid segments after 5 attempts. Skipping 4242


Processing with 5 threads:  19%|█▉        | 77/400 [16:57<1:38:12, 18.24s/it]

Failed to get valid segments after 5 attempts. Skipping 3156


Processing with 5 threads:  22%|██▏       | 87/400 [17:06<44:43,  8.57s/it]  

Failed to get valid segments after 5 attempts. Skipping 5490


Processing with 5 threads:  22%|██▎       | 90/400 [17:59<53:18, 10.32s/it]

Failed to get valid segments after 5 attempts. Skipping 638


Processing with 5 threads:  24%|██▍       | 96/400 [19:22<58:47, 11.61s/it]

Failed to get valid segments after 5 attempts. Skipping 5987


Processing with 5 threads:  26%|██▌       | 104/400 [20:37<52:40, 10.68s/it]

Failed to get valid segments after 5 attempts. Skipping 1459
Failed to get valid segments after 5 attempts. Skipping 3053


Processing with 5 threads:  27%|██▋       | 107/400 [22:19<1:12:22, 14.82s/it]

Failed to get valid segments after 5 attempts. Skipping 1363


Processing with 5 threads:  32%|███▏      | 127/400 [25:19<48:20, 10.63s/it]  

Failed to get valid segments after 5 attempts. Skipping 1944
Fatal error when generating segments:  list indices must be integers or slices, not str


Processing with 5 threads:  38%|███▊      | 151/400 [28:11<42:22, 10.21s/it]

Failed to get valid segments after 5 attempts. Skipping 2265


Processing with 5 threads:  39%|███▉      | 157/400 [29:16<42:10, 10.41s/it]

Failed to get valid segments after 5 attempts. Skipping 2376


Processing with 5 threads:  40%|████      | 161/400 [30:26<52:24, 13.16s/it]

Failed to get valid segments after 5 attempts. Skipping 2597


Processing with 5 threads:  42%|████▏     | 168/400 [31:49<48:45, 12.61s/it]

Fatal error when generating segments:  list indices must be integers or slices, not str


Processing with 5 threads:  43%|████▎     | 173/400 [33:17<53:41, 14.19s/it]

Failed to get valid segments after 5 attempts. Skipping 4880


Processing with 5 threads:  46%|████▌     | 183/400 [33:21<27:31,  7.61s/it]

Fatal error when generating segments:  list indices must be integers or slices, not str


Processing with 5 threads:  46%|████▌     | 184/400 [34:37<43:02, 11.96s/it]

Failed to get valid segments after 5 attempts. Skipping 1765


Processing with 5 threads:  47%|████▋     | 189/400 [35:36<41:58, 11.94s/it]

Failed to get valid segments after 5 attempts. Skipping 1558
Failed to get valid segments after 5 attempts. Skipping 777


Processing with 5 threads:  48%|████▊     | 191/400 [36:27<48:42, 13.98s/it]

Failed to get valid segments after 5 attempts. Skipping 3378


Processing with 5 threads:  48%|████▊     | 193/400 [37:19<55:26, 16.07s/it]

Failed to get valid segments after 5 attempts. Skipping 1246


Processing with 5 threads:  48%|████▊     | 194/400 [38:39<1:20:10, 23.35s/it]

Failed to get valid segments after 5 attempts. Skipping 606


Processing with 5 threads:  50%|█████     | 201/400 [39:50<54:03, 16.30s/it]  

Failed to get valid segments after 5 attempts. Skipping 6075


Processing with 5 threads:  51%|█████     | 204/400 [40:53<57:12, 17.51s/it]

Failed to get valid segments after 5 attempts. Skipping 5368


Processing with 5 threads:  53%|█████▎    | 213/400 [41:54<37:26, 12.02s/it]

Failed to get valid segments after 5 attempts. Skipping 1783


Processing with 5 threads:  55%|█████▌    | 220/400 [44:11<44:08, 14.72s/it]

Failed to get valid segments after 5 attempts. Skipping 2863
Failed to get valid segments after 5 attempts. Skipping 1428
Failed to get valid segments after 5 attempts. Skipping 7364
Failed to get valid segments after 5 attempts. Skipping 1170


Processing with 5 threads:  59%|█████▉    | 236/400 [47:40<35:03, 12.83s/it]

Failed to get valid segments after 5 attempts. Skipping 6715
Fatal error when generating segments:  'coverage'


Processing with 5 threads:  65%|██████▌   | 261/400 [53:13<38:57, 16.82s/it]

Failed to get valid segments after 5 attempts. Skipping 665
Fatal error when generating segments:  'coverage'
Failed to get valid segments after 5 attempts. Skipping 1901


Processing with 5 threads:  68%|██████▊   | 270/400 [55:43<36:15, 16.74s/it]

Failed to get valid segments after 5 attempts. Skipping 160
Fatal error when generating segments:  list indices must be integers or slices, not str


Processing with 5 threads:  73%|███████▎  | 293/400 [57:25<16:38,  9.33s/it]

Fatal error when generating segments:  list indices must be integers or slices, not str
Failed to get valid segments after 5 attempts. Skipping 1361


Processing with 5 threads:  74%|███████▎  | 294/400 [59:31<28:39, 16.22s/it]

Failed to get valid segments after 5 attempts. Skipping 1776
Failed to get valid segments after 5 attempts. Skipping 5671


Processing with 5 threads:  75%|███████▌  | 301/400 [1:00:57<24:23, 14.78s/it]

Fatal error when generating segments:  list indices must be integers or slices, not str
Failed to get valid segments after 5 attempts. Skipping 2191


Processing with 5 threads:  76%|███████▌  | 303/400 [1:00:59<20:56, 12.95s/it]

Failed to get valid segments after 5 attempts. Skipping 5179


Processing with 5 threads:  77%|███████▋  | 309/400 [1:02:37<23:29, 15.48s/it]

Failed to get valid segments after 5 attempts. Skipping 4149
Fatal error when generating segments:  list indices must be integers or slices, not str
Failed to get valid segments after 5 attempts. Skipping 817
Fatal error when generating segments:  list indices must be integers or slices, not str


Processing with 5 threads:  78%|███████▊  | 311/400 [1:07:07<57:19, 38.65s/it]

Failed to get valid segments after 5 attempts. Skipping 2456
Cannot parse
|{
    "initial_segment": "How many more pens will Alex have than Jane",
    "initial_query": "pens difference between Alex and Jane",
    "hints": [
        {"segment": "Every week her pen collection doubles.", "hint": "Alex doubles her pen collection each week"},
        {"segment": "Alex has 4 pens in the first week of a month.", "hint": "Alex starts the month with 4 pens"},
        {"segment": "Jane will have 16 pens after a month", "hint": "Jane ends the month with 16 pens"},
    ]
}|
Fatal error when generating segments:  'NoneType' object is not subscriptable


Processing with 5 threads:  82%|████████▏ | 328/400 [1:07:34<15:10, 12.65s/it]

Failed to get valid segments after 5 attempts. Skipping 1118


Processing with 5 threads:  86%|████████▌ | 342/400 [1:09:46<12:04, 12.49s/it]

Failed to get valid segments after 5 attempts. Skipping 5717


Processing with 5 threads:  86%|████████▌ | 343/400 [1:10:00<11:57, 12.59s/it]

Failed to get valid segments after 5 attempts. Skipping 1231


Processing with 5 threads:  90%|████████▉ | 358/400 [1:12:52<07:42, 11.01s/it]

Failed to get valid segments after 5 attempts. Skipping 667


Processing with 5 threads:  92%|█████████▏| 366/400 [1:14:13<07:02, 12.41s/it]

Failed to get valid segments after 5 attempts. Skipping 3032


Processing with 5 threads:  92%|█████████▏| 368/400 [1:15:38<09:48, 18.39s/it]

Failed to get valid segments after 5 attempts. Skipping 6981


Processing with 5 threads:  94%|█████████▎| 374/400 [1:16:11<05:23, 12.42s/it]

Failed to get valid segments after 5 attempts. Skipping 123


Processing with 5 threads:  97%|█████████▋| 387/400 [1:18:37<02:48, 12.99s/it]

Failed to get valid segments after 5 attempts. Skipping 1304


Processing with 5 threads: 100%|██████████| 400/400 [1:20:16<00:00, 12.04s/it]

Failed to get valid segments after 5 attempts. Skipping 1149
Generated 400 sharded instructions
Processed 400 samples successfully out of 400 total samples





In [4]:
import json
with open('data/sharded_gsm8k_train_400.json', 'r') as f:
    data = json.load(f)

# {'question': 'Saturday at the ice cream shop, there were twice as many people who ordered vanilla ice cream as ordered chocolate ice cream.  If 220 people ordered ice cream on Saturday, and 20% of those ordered vanilla ice cream, how many people ordered chocolate ice cream?', 'answer': 'If 20% of 220 customers ordered vanilla, then there were 0.2*220=<<0.2*220=44>>44 customers who ordered vanilla.\nIf there were twice as many people who ordered vanilla as chocolate, then 44/2=<<44/2=22>>22 customers ordered chocolate.\n#### 22', 'id': 2765, 'task': 'math', 'task_id': 'sharded-gsm8k-train-2765', 'fully_specified_question': 'Saturday at the ice cream shop, there were twice as many people who ordered vanilla ice cream as ordered chocolate ice cream.  If 220 people ordered ice cream on Saturday, and 20% of those ordered vanilla ice cream, how many people ordered chocolate ice cream?', 'shards': [{'shard_id': 2, 'shard': 'I see that twenty percent of the customers chose vanilla ice cream'}, {'shard_id': 3, 'shard': 'The vanilla orders are double the chocolate orders'}, {'shard_id': 1, 'shard': 'number of chocolate ice cream orders'}, {'shard_id': 4, 'shard': 'I know 220 people ordered ice cream on Saturday'}], 'model_name': 't-o3', 'verifications': {'full-avg': 1.0, 'full-all': {'t-gpt-4.1': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, 'concat-avg': 1.0, 'concat-all': {'t-gpt-4.1': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, 'shuffle-concat-avg': 1.0, 'shuffle-concat-all': {'t-gpt-4.1': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, 'sharded-avg': 0.875, 'sharded-all': {'t-gpt-4.1': [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]}}}
# select samples with above 0.7 on concat, full and shuffle-concat, and then any number between 0.0 < sharded < 1.0

for sample in data:
    validation_keys = ['full-avg', 'concat-avg', 'shuffle-concat-avg']
    is_valid = all(sample['verifications'][key] >= 0.7 for key in validation_keys)
    shard_ok = 0.0 < sample['verifications']['sharded-avg'] < 0.8
    sample["selected"] = is_valid and shard_ok

selected_samples = [sample for sample in data if sample["selected"]]

print(len(selected_samples))

# save the valid samples to a new file
with open('data/sharded_gsm8k_train_400_selected.json', 'w') as f:
    json.dump(selected_samples, f, indent=4)

67


In [2]:
# merge with code
from collections import Counter
import json

output_fn = "data/sharded_train_synthetic_0.1.json"
synthetic_fns = ["data/code_synthetic_problems_0.1_final_transformed.json", "data/sharded_gsm8k_train_400_selected.json"]

all_samples = []

for fn in synthetic_fns:
    with open(fn, "r") as f:
        data = json.load(f)
    all_samples += data

all_samples = [s for s in all_samples if len(s["shards"]) == 4]

print(Counter([sample["task"] for sample in all_samples]))
print(Counter([len(sample["shards"]) for sample in all_samples]))

with open(output_fn, "w") as f:
    json.dump(all_samples, f, indent=4)

Counter({'math': 66, 'code': 43})
Counter({4: 109})
