In [None]:
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from copy import deepcopy
from search.base_classes import Test, Problem
from search.exec_utils import run_tests_per_code
from coderm.execution import smart_exec_tests_queuebatched
import json

In [None]:
# Raw HumanEval release from the Hugging Face Hub; its "test" split is used below.
data = load_dataset(path="openai/openai_humaneval")

In [None]:
d = data["test"]


def _patch_prompt(prompt: str, old: str, new: str) -> str:
    """Return `prompt` with `old` replaced by `new`, failing if `old` is absent.

    `str.replace` is a silent no-op when the pattern is missing, which would let an
    upstream change to the prompt text slip through unpatched.
    """
    assert old in prompt, f"patch target not found in prompt: {old!r}"
    return prompt.replace(old, new)


# Row 116 (sort_array): rewrite the doctest examples to the outputs below; the
# original third example is also missing its `==`. Public tests are parsed from these
# examples and later executed against the canonical solution, so they must agree with it.
new116 = _patch_prompt(
    d[116]["prompt"],
    ">>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]\n"
    "    >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]\n"
    "    >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]\n",
    ">>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 4, 3, 5]\n"
    "    >>> sort_array([-2, -3, -4, -5, -6]) == [-4, -2, -6, -5, -3]\n"
    "    >>> sort_array([1, 0, 2, 3, 4]) == [0, 1, 2, 4, 3]\n",
)
# Row 47 (median): the median of [-10, 4, 6, 10, 20, 1000] is (6 + 10) / 2 = 8.0, not 15.0.
new47 = _patch_prompt(
    d[47]["prompt"],
    "median([-10, 4, 6, 1000, 10, 20])\n    15.0",
    "median([-10, 4, 6, 1000, 10, 20])\n    8.0",
)
new_map = {47: new47, 116: new116}

# Swap in the patched prompts; every other row keeps its original prompt.
d = d.map(lambda row, idx: {"prompt": new_map.get(idx, row["prompt"])}, with_indices=True)

In [None]:
# For each problem, the stripped `assert` lines of its hidden test program.
asserts = []
for test_src in d["test"]:
    stripped_lines = (line.strip() for line in test_src.splitlines())
    asserts.append([line for line in stripped_lines if line.startswith("assert")])

In [None]:
def match_parens(s: str) -> str:
    """Return the shortest prefix of `s` whose parentheses balance.

    `s` must start with '(' (checked with an assert). Only '(' and ')' are counted;
    everything else is ignored, including parentheses inside string literals.
    Returns None when the parentheses never balance.
    """
    assert s[0] == '('
    depth = 0
    for pos, ch in enumerate(s):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        if depth == 0:
            return s[:pos + 1]
    return None

def parse_input(fn_call: str, fn_name: str) -> list:
    """Evaluate the argument list of a call written as ``fn_name(arg1, arg2, ...)``.

    Args:
        fn_call: Source text of the call; surrounding whitespace is ignored.
        fn_name: Function name the call must start with.

    Returns:
        The evaluated arguments as a list (``f()`` gives ``[]``).

    Raises:
        AssertionError: if the call does not start with `fn_name` or the remainder is
            not wrapped in parentheses.
    """
    call = fn_call.strip()
    assert call.startswith(fn_name), "Has to start with fn_name"
    # Slice the name off the front rather than splitting on it: splitting keeps only
    # the text up to a second occurrence, which breaks calls like ``f('f')``.
    wo_fn_name = call[len(fn_name):].strip()
    assert wo_fn_name.startswith("(") and wo_fn_name.endswith(")"), "Missing parens"
    # NOTE: eval runs arbitrary code. The arguments come from the HumanEval docstrings
    # and may be arbitrary Python expressions; do not feed this untrusted text.
    return eval('[' + wo_fn_name[1:-1] + ']')

In [None]:
def parse_public(question: str, entrypoint: str):
    """Heuristically extract example (call, expected output) pairs from a prompt.

    Doctest mode applies when the prompt contains '>>>' and none of its '>>>' lines
    contains '=='. In that mode:
    - each '>>>' line contributes a call (the text between the first and any second
      '>>>', stripped);
    - the next non-'>>>' line is taken as that call's output, but only while exactly
      one call is pending.

    Otherwise, the arrows '==>', '=>', '->' and '➞' are rewritten to '==' and each line
    is scanned:
    - a line containing '==' is split on it, keeping the first two pieces;
    - otherwise, a line that calls `entrypoint` with arguments is split on the first of
      "should return", "# returns", "returns", " = " it contains.

    Returns a list of [call_text, output_text] string pairs. Raises AssertionError when
    nothing is found.
    """
    tests = []
    prompt_lines = question.splitlines()
    doctest_lines = [ln for ln in prompt_lines if ">>>" in ln]
    use_doctest = ">>>" in question and not any("==" in ln for ln in doctest_lines)

    if use_doctest:
        pending = []
        for ln in prompt_lines:
            if ">>>" in ln:
                pending.append(ln.split(">>>")[1].strip())
            elif len(pending) == 1:
                pending.append(ln.strip())
                tests.append(pending)
                pending = []
    else:
        markers = ("should return", "# returns", "returns", " = ")
        normalized = question
        for arrow in ("==>", "=>", "->", "➞"):
            normalized = normalized.replace(arrow, "==")
        for ln in normalized.splitlines():
            if "==" in ln:
                pieces = ln.split("==")
                tests.append([pieces[0].strip(), pieces[1].strip()])
                continue
            # Only lines that call the entry point with arguments can carry an example.
            if entrypoint + "(" not in ln or entrypoint + "()" in ln:
                continue
            for marker in markers:
                if marker in ln:
                    pieces = ln.split(marker)
                    tests.append([pieces[0].strip(), pieces[1].strip()])
                    break

    assert tests
    return tests

# Rows that get an empty public test set instead of parsed examples (presumably their
# docstrings have no usable examples; confirm against the prompts).
_NO_PUBLIC_TESTS = {38, 41, 50, 83}

# Rows whose docstring examples are transcribed by hand, presumably because the
# heuristic parser mishandles them. Each value is (expected entry point, inputs, outputs).
# The entry point is checked against the dataset, so an index shift fails loudly instead
# of attaching tests to the wrong problem. Row 81's docstring calls the function
# `grade_equation`, not its entry point `numerical_letter_grade`.
_MANUAL_PUBLIC_TESTS = {
    67: ("fruit_distribution",
         [['5 apples and 6 oranges', 19], ['0 apples and 1 oranges', 3], ['2 apples and 3 oranges', 100], ['100 apples and 1 oranges', 120]],
         [8, 2, 95, 19]),
    68: ("pluck",
         [[[4, 2, 3]], [[1, 2, 3]], [[]], [[5, 0, 3, 0, 4, 2]]],
         [[2, 1], [2, 1], [], [0, 1]]),
    78: ("hex_key",
         [['AB'], ['1077E'], ['ABED1A33'], ['123456789ABCDEF0'], ['2020']],
         [1, 2, 4, 6, 2]),
    80: ("is_happy",
         [['a'], ['aa'], ['abcd'], ['aabb'], ['adb'], ['xyy']],
         [False, False, True, False, True, False]),
    81: ("numerical_letter_grade",
         [[[4.0, 3, 1.7, 2, 3.5]]],
         [['A+', 'B', 'C-', 'C', 'A-']]),
    84: ("solve",
         [[1000], [150], [147]],
         ["1", "110", "1100"]),
    87: ("get_row",
         [[[[1,2,3,4,5,6], [1,2,3,4,1,6], [1,2,3,4,5,1]], 1], [[], 1], [[[], [1], [1, 2, 3]], 3]],
         [[(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)], [], [(2, 2)]]),
    94: ("skjkasdkd",
         [[[0,3,2,1,3,5,7,4,5,5,5,2,181,32,4,32,3,2,32,324,4,3]],
          [[1,0,1,8,2,4597,2,1,3,40,1,2,1,2,4,2,5,1]],
          [[1,3,1,32,5107,34,83278,109,163,23,2323,32,30,1,9,3]],
          [[0,724,32,71,99,32,6,0,5,91,83,0,5,6]],
          [[0,81,12,3,1,21]],
          [[0,8,1,2,1,7]]],
         [10, 25, 13, 11, 3, 7]),
    105: ("by_length",
          [[[2, 1, 1, 4, 5, 8, 2, 3]], [[]], [[1, -1 , 55]]],
          [["Eight", "Five", "Four", "Three", "Two", "Two", "One", "One"], [], ["One"]]),
    107: ("even_odd_palindrome",
          [[3], [12]],
          [(1, 2), (4, 6)]),
    112: ("reverse_delete",
          [['abcde', 'ae'], ['abcdef', 'b'], ['abcdedcba', 'ab']],
          [('bcd', False), ('acdef', False), ('cdedc', True)]),
    113: ("odd_count",
          [[['1234567']], [['3', '11111111']]],
          [["the number of odd elements 4n the str4ng 4 of the 4nput."],
           ["the number of odd elements 1n the str1ng 1 of the 1nput.", "the number of odd elements 8n the str8ng 8 of the 8nput."]]),
    115: ("max_fill",
          [[[[0, 0, 1, 0], [0, 1, 0, 0], [1, 1, 1, 1]], 1], [[[0,0,1,1], [0,0,0,0], [1,1,1,1], [0,1,1,1]], 2], [[[0,0,0], [0,0,0]], 5]],
          [6, 5, 0]),
    120: ("maximum",
          [[[-3, -4, 5], 3], [[4, -4, 4], 2], [[-3, 2, 1, 2, -1, -2, 1], 1]],
          [[-4, -3, 5], [4, 4], [2]]),
    122: ("add_elements",
          [[[111,21,3,4000,5,6,7,8,9], 4]],
          [24]),
    129: ("minPath",
          [[[ [1,2,3], [4,5,6], [7,8,9]], 3], [[ [5,9,3], [4,1,6], [7,8,2]], 1]],
          [[1, 2, 1], [1]]),
    130: ("tri",
          [[3]],
          [[1, 3, 2, 8]]),
    133: ("sum_squares",
          [[[1, 2, 3]], [[1, 4, 9]], [[1, 3, 5, 7]], [[1.4, 4.2, 0]], [[-2.4, 1, 1]]],
          [14, 98, 84, 29, 6]),
    141: ("file_name_check",
          [['example.txt'], ['1example.dll']],
          ["Yes", "No"]),
    142: ("sum_squares",
          [[[1, 2, 3]], [[]], [[-1, -5, 2, -1, -5]]],
          [6, 0, -126]),
    143: ("words_in_sentence",
          [['This is a test'], ['lets go for swimming']],
          ["is", "go for"]),
    147: ("get_max_triples",
          [[5]],
          [1]),
    149: ("sorted_list_sum",
          [[['aa', 'a', 'aaa']], [['ab', 'a', 'aaa', 'cd']]],
          [["aa"], ["ab", "cd"]]),
    160: ("do_algebra",
          [[['+', '*', '-'], [2, 3, 4, 5]]],
          [9]),
}

# Rows whose parsed calls are used verbatim. Presumably their arguments contain
# parentheses inside string literals, which match_parens does not understand. Confirm.
_SKIP_CALL_REPAIR = {61, 119}


def _repair_call(i: int, j: int, raw_call: str, entry_point: str) -> str:
    """Normalise parsed example call `j` of row `i` to start exactly with ``entry_point(...)``.

    Rows in _SKIP_CALL_REPAIR are returned untouched. Otherwise, two repairs are applied,
    each printing a warning:
    1. If the text does not start with the entry point, keep the text after the first
       occurrence of the name (up to any second occurrence) and re-prefix the name.
    2. If anything follows the balanced argument list, truncate it with match_parens.
    """
    call = raw_call
    if i in _SKIP_CALL_REPAIR:
        return call
    if not call.strip().startswith(entry_point):
        call = entry_point + call.split(entry_point)[1]
        print(f"Warning: ({i}, {j}) doesn't start with")
        print("Old:", raw_call, "Fixed:", call)
    args = call.split(entry_point)[1]
    balanced = match_parens(args)
    if balanced != args:
        call = entry_point + balanced
        print(f"Warning: ({i}, {j}) doesn't match parens")
        print("Fixed:", call)
    return call


def _parse_expected_output(i: int, j: int, raw: str):
    """Turn the textual expected output of example `j` in row `i` into a Python value.

    Spellings that are not valid Python literals are mapped explicitly: "" -> None,
    "true"/"True." -> True, "false"/"False." -> False. A few rows get dedicated fixes;
    everything else goes through eval. NOTE: eval is only acceptable because the text
    comes from the HumanEval docstrings. Do not use this on untrusted input.
    """
    if raw == "":
        print(f"None on ({i}, {j})")
        return None
    if raw == "true":
        print(f"true on ({i}, {j})")
        return True
    if raw == "false":
        print(f"false on ({i}, {j})")
        return False
    if raw == "True.":
        print(f"True on ({i}, {j})")
        return True
    if raw == "False.":
        print(f"False. on ({i}, {j})")
        return False
    if i == 148 and raw.endswith(")"):
        # Insert a trailing comma so a parenthesised single value evaluates to a tuple
        # (presumably the docstring omits it for 1-tuples).
        return eval(raw[:-1] + ",)")
    if i == 151 and "=" in raw:
        # Keep only the text after the last '=' (presumably the docstring shows the
        # arithmetic before the final result).
        print(f"= in ({i}, {j}), {raw}")
        return eval(raw.split("=")[-1])
    if i == 158 and raw.strip() == '""aaaaaaa"':
        return "aaaaaaa"
    return eval(raw)


public_tests = []
for i, row in enumerate(d):
    entry_point = row["entry_point"]
    if i in _NO_PUBLIC_TESTS:
        public_tests.append({"inputs": [], "outputs": [], "fn_name": entry_point})
        continue
    if i in _MANUAL_PUBLIC_TESTS:
        expected_name, manual_inputs, manual_outputs = _MANUAL_PUBLIC_TESTS[i]
        assert entry_point == expected_name, (i, entry_point, expected_name)
        # deepcopy so the emitted records never alias the table above.
        public_tests.append({"inputs": deepcopy(manual_inputs), "outputs": deepcopy(manual_outputs), "fn_name": entry_point})
        continue
    # Reset per row. This way the error report never shows a stale index from an earlier
    # row, and does not hit NameError when parse_public fails before any test is seen.
    j = None
    try:
        inputs, outputs = [], []
        for j, (raw_call, raw_output) in enumerate(parse_public(row["prompt"], entry_point)):
            call = _repair_call(i, j, raw_call, entry_point)
            inputs.append(parse_input(call, entry_point))
            outputs.append(_parse_expected_output(i, j, raw_output))
    except Exception as e:
        # Stop loudly. A silent `break` would leave public_tests shorter than `d`,
        # misaligning every index-based lookup in later cells.
        raise RuntimeError(f"Failed to build public tests for row {i} (test {j}): {e!r}") from e
    public_tests.append({"inputs": inputs, "outputs": outputs, "fn_name": entry_point})

assert len(public_tests) == len(d), (len(public_tests), len(d))

In [None]:
# Each canonical program is the prompt (signature + docstring) followed by its body.
all_codes = [f"{row['prompt']}{row['canonical_solution']}" for row in d]
results = smart_exec_tests_queuebatched(all_codes, public_tests)
# Display the indices of problems whose canonical solution fails its public tests.
[idx for idx, (passed, _) in enumerate(results) if not passed]

In [None]:
CHECK_CALL = "def check(candidate):"
# Keep only the `check` function from each hidden test (anything before it is
# dropped) and append a call that runs it against the entry point.
new_tests = []
for row in d:
    before_and_after = row["test"].split(CHECK_CALL)
    assert len(before_and_after) == 2
    check_body = before_and_after[1]
    new_tests.append(f"{CHECK_CALL}{check_body}check({row['entry_point']})")

In [None]:
# Re-run the canonical programs against the rewritten hidden test programs.
results = smart_exec_tests_queuebatched(all_codes, new_tests)
failed_hidden = [idx for idx, (passed, _) in enumerate(results) if not passed]
failed_hidden

In [None]:
def remove_docstrings(s: str) -> str:
    """Drop every triple-quoted string from `s`, keeping the code around them.

    Triple single quotes are first normalised to triple double quotes, so both
    styles count; the total must be even and at least two (asserted). The text
    before the first string is stripped and kept. Every piece following a closing
    quote is stripped and appended on a new line indented by four spaces. The result
    ends with a newline plus four spaces so a function body can be appended directly.
    """
    normalized = s.replace("'''", '"""')
    quote_count = normalized.count('"""')
    assert quote_count % 2 == 0 and quote_count >= 2
    pieces = normalized.split('"""')
    head = pieces[0].strip()
    # Odd-indexed pieces are the string contents; even ones (after the first) are code.
    tails = "".join("\n    " + piece.strip() for piece in pieces[2::2])
    return head + tails + "\n    "

    
# Prompts that do not contain exactly one pair of triple double quotes; remove_docstrings
# strips all of them. This was previously a bare mid-cell expression whose value was
# silently discarded, so it is bound and printed instead.
multi_docstring_rows = [i for i, p in enumerate(d["prompt"]) if p.count('"""') != 2]
print("Prompts without exactly one docstring:", multi_docstring_rows)

questions = []
starter_codes = []
for row in d:
    prob = row["prompt"]
    if "FIX" in prob:
        # Drop the first four lines. Presumably these hold a module-level `FIX = ...` note
        # rather than part of the problem; confirm if the prompts change.
        prob = '\n'.join(prob.splitlines()[4:])
    questions.append(
        f"Unfinished code:\n```\n{prob.strip()}\n    pass\n```\n\n"
        f"Your task is to implement `{row['entry_point']}` as described in the docstring."
    )
    starter_codes.append(remove_docstrings(prob))

In [None]:
# Reference solutions: docstring-free starter code followed by the canonical body.
solutions = []
for row, start in zip(d, starter_codes):
    body = row["canonical_solution"].strip()
    solutions.append(start.strip() + "\n    " + body)
results = smart_exec_tests_queuebatched(solutions, new_tests)
# Display the indices of reference solutions that fail the hidden tests.
[idx for idx, (passed, _) in enumerate(results) if not passed]

In [None]:
# Every derived column is looked up by row index, so all lists must align with `d`.
assert (
    len(questions) == len(starter_codes) == len(solutions)
    == len(public_tests) == len(new_tests) == len(d)
), "derived columns are not aligned with the dataset rows"


def _derived_columns(row, idx):
    """Build the exported columns for row `idx` in a single pass."""
    return {
        "question": questions[idx],
        "starter_code": starter_codes[idx],
        "solutions": [solutions[idx]],
        "public_input_output": json.dumps(public_tests[idx]),
        "input_output": json.dumps({"inputs": [], "outputs": [], "fn_name": row["entry_point"], "exec_string": new_tests[idx]}),
    }


# rename_column/map/remove_columns each return a new Dataset, so `d` is never mutated.
# No deepcopy is needed, and one map pass replaces five.
new_test_data = (
    d.rename_column("task_id", "id")
    .map(_derived_columns, with_indices=True)
    .remove_columns(["prompt", "canonical_solution", "test", "entry_point"])
)

In [None]:
# Publish the processed data as a single "test" split.
splits = {"test": new_test_data}
dd = DatasetDict(splits)
# private=True requests a private repository on the Hub.
dd.push_to_hub("codegenning/F_human_eval", private=True)