In [1]:
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
import json
from search.python_utils import stringify, unstringify
from copy import deepcopy

In [2]:
data = load_dataset("google-research-datasets/mbpp")

In [3]:
all_data = []
for split in data:
    all_data.append(data[split])
comb_data = concatenate_datasets(all_data)

In [4]:
test_data = data["test"]

In [5]:
def parse_test(test: str):
    assert test.startswith("assert ")
    fn_name = test.split("assert ")[1].split("(")[0].strip()
    
    assert test.count("=") >= 2
    assert test.count("==") == 1
    # assert test.count(" == ") == 1
    
    start_idx = test.index(fn_name) + len(fn_name)  # Start after the function name and the opening parenthesis
    end_idx = test.index("==")  # End before the equality check
    parsed_input = test[start_idx:end_idx].strip()[1:-1]
    parsed_output = test[end_idx + len("=="):].strip()
   
    return (parsed_input, parsed_output, fn_name)

def parse_tests(tests: list[str]):
    inputs = []
    outputs = []
    
    fn_names = set()
    for test in tests:
        parsed_input, parsed_output, fn_name = parse_test(test)
        fn_names.add(fn_name)
        inputs.append(eval('[' + parsed_input + ']'))
        outputs.append(eval(parsed_output))
    
    assert len(fn_names) == 1, "Function names are not consistent across tests"
    fn_name = fn_names.pop()
    return {"inputs": inputs, "outputs": outputs, "fn_name": fn_name}

test_dicts = []
public_test_dicts = []
for i in range(len(test_data)):
    if i == 356:
        test_dicts.append({"inputs": [[True]], "outputs": [False], "fn_name": "is_tree_balanced"})
        public_test_dicts.append({"inputs": [[True]], "outputs": [False], "fn_name": "is_tree_balanced"})
    else:
        public_test_dicts.append(parse_tests(test_data["test_list"][i]))
        test_dicts.append(parse_tests(test_data["test_list"][i] # + test_data["challenge_test_list"][i]
                                      ))

In [50]:
[i for i, fn in enumerate(test_dicts) if fn["fn_name"] == "find_Product"]

[14]

In [51]:
def retrieve_starter_code(s: str, fn_name: str) -> str:
    starter_code = None
    for line in s.splitlines():
        if fn_name in line:
            starter_code = line.split(":")[0].strip() + ":\n    "
            break
    assert starter_code is not None and starter_code.startswith("def ")
    return starter_code

In [52]:
def publicify(dic: dict) -> dict:
    return {"inputs": dic["inputs"][:1], "outputs": dic["outputs"][:1], "fn_name": dic["fn_name"]}

In [10]:
from coderm.execution import smart_exec_tests_queuebatched
results = smart_exec_tests_queuebatched(test_data["code"], test_dicts)

Executing tests: 100%|██████████| 500/500 [00:06<00:00, 78.19it/s] 


In [17]:
test_dicts[351]

{'inputs': [[[1, 2, 3, 1, 2, 3, 12, 4, 2]],
  [[1, 2, 6, 7, 0, 1, 0, 1, 0]],
  [[1, 2, 3, 1, 2, 4, 1]]],
 'outputs': [2, (1, 0), 1],
 'fn_name': 'max_occurrences'}

In [18]:
test_data[425]

{'task_id': 436,
 'text': 'Write a python function to print negative numbers in a list.',
 'code': 'def neg_nos(list1):\r\n  for num in list1: \r\n    if num < 0: \r\n       return num ',
 'test_list': ['assert neg_nos([-1,4,5,-6]) == -1,-6',
  'assert neg_nos([-1,-2,3,4]) == -1,-2',
  'assert neg_nos([-7,-6,8,9]) == -7,-6'],
 'test_setup_code': '',
 'challenge_test_list': []}

In [15]:
assert (1, 2) == 1, 2

AssertionError: 2

In [11]:
[i for i, (stat, _) in enumerate(results) if not stat]

[302, 351, 356, 425]

In [53]:
newline = '\n'
new_test_data = deepcopy(test_data)
new_test_data= new_test_data.rename_column("task_id", "id")
new_test_data = new_test_data.map(lambda row, idx: {"starter_code": retrieve_starter_code(row["code"], test_dicts[idx]["fn_name"])}, with_indices=True)
new_test_data = new_test_data.map(lambda row, idx: {"solutions": [row["code"]]}, with_indices=True)
new_test_data = new_test_data.map(lambda row, idx: {"public_input_output": repr(public_test_dicts[idx])}, with_indices=True)
new_test_data = new_test_data.map(lambda row, idx: {"input_output": repr(test_dicts[idx])}, with_indices=True)
new_test_data = new_test_data.map(lambda row, idx: {"question": row['text'] + f"\n\nYour code should pass these tests:\n```\n{newline.join(row['test_list'])}\n```"}, with_indices=True)
new_test_data = new_test_data.remove_columns(["challenge_test_list", "text", "test_list", "test_setup_code", "code"])

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [55]:
dd = DatasetDict({"test": new_test_data})
dd.push_to_hub("codegenning/F_mbpp", private=True)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/codegenning/F_mbpp/commit/42408286872a802ebc91c07f286fa81e03ec7b82', commit_message='Upload dataset', commit_description='', oid='42408286872a802ebc91c07f286fa81e03ec7b82', pr_url=None, pr_revision=None, pr_num=None)

In [57]:
[len(unstringify(test)["inputs"]) for test in load_dataset("codegenning/F_mbpp", split="test")["input_output"]]

[3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
