In [None]:
import os
from getpass import getpass

# Connection settings for the remote code executor (presumably read by
# exec_utils.run_tests_per_code — confirm). Credentials must never be hardcoded
# in the notebook: reuse values already present in the environment and
# otherwise prompt for them without echoing, so they never end up in the saved
# .ipynb or its outputs.
for _env_var in ("EXECUTOR_URL", "EXECUTOR_AUTH"):
    if not os.environ.get(_env_var):
        os.environ[_env_var] = getpass(f"{_env_var}: ")

In [None]:
import json
from datasets import load_dataset, DatasetDict, Dataset
from backtranslate import BackTranslateModel
from base_classes import Problem, Test
from exec_utils import run_tests_per_code
from typing import Any
import datetime


In [None]:
# ---- Experiment configuration ----
NUM_BATCHES_TO_TRY = 3  # maximum number of generate-and-test rounds
GEN_BATCH_SIZE = 2      # generations requested per still-unsolved problem per round
NUM_WORDS = 100         # forwarded to BackTranslateModel(num_words=...)

# Each run logs into its own timestamped directory.
_run_stamp = datetime.datetime.now().strftime('%m-%dT%H-%M-%S')
PATH_TO_RESULT = f"temp_backtranslate_gen_logs/{_run_stamp}"

# Source dataset. Alternatives used previously:
#   "codegenning/finetuning-taco-plain300k-rlxf-withtest"
#   "codegenning/taco-rl-tests10-easy"
DATASET_NAME = "codegenning/taco-rl-tests10-withpassingsolutions_v3"
data = load_dataset(DATASET_NAME)

In [None]:
# Raw prompt strings of the train split, materialized as a plain list.
prompts = [*data["train"]["prompt"]]
def filter_prompt(prompt: str, delimiter: str = '"""') -> str:
    """Extract the text between the first and second `delimiter` in `prompt`.

    If `delimiter` occurs only once, everything after it is returned. The
    result is whitespace-stripped.

    Args:
        prompt: Raw prompt text that wraps the problem statement in `delimiter`.
        delimiter: Marker surrounding the problem statement (a triple-quoted
            docstring by default).

    Returns:
        The stripped text of the first delimited section.

    Raises:
        ValueError: if `delimiter` does not occur in `prompt`.
    """
    _, sep, rest = prompt.partition(delimiter)
    if not sep:
        raise ValueError(
            f"delimiter {delimiter!r} not found in prompt starting {prompt[:60]!r}"
        )
    # Only the first delimited section matters, so stop at the next delimiter.
    body, _, _ = rest.partition(delimiter)
    return body.strip()

# Pull the per-row fields out of the train split as plain Python lists.
train_split = data["train"]

actual_prompts = list(map(filter_prompt, prompts))
actual_starter_code = list(train_split["starter_code"])

# The input_output column stores JSON text; decode each row into a dict.
tests = list(train_split["input_output"])
actual_tests = list(map(json.loads, tests))

actual_solutions = list(train_split["solutions"])

In [None]:
def _join_output_lines(outputs: list) -> list[str]:
    """Return `outputs` with every entry converted to a single string.

    Entries that are already `str` are kept as-is. Entries that are lists of
    lines are joined with "\\n" and given a trailing newline.

    Raises:
        AssertionError: if an entry is neither a `str` nor a `list`.
    """
    normalized: list[str] = []
    for out in outputs:
        if isinstance(out, str):
            normalized.append(out)
        else:
            assert isinstance(out, list), f"unexpected output type: {type(out).__name__}"
            normalized.append("\n".join(out) + "\n")
    return normalized


problems: list[Problem] = []
for prompt, starter_code, test, solution in zip(actual_prompts, actual_starter_code, actual_tests, actual_solutions):
    try:
        problems.append(Problem.from_coderm_item(prompt, starter_code, None, tests=test, solutions=solution))
    except Exception:
        # Some rows store expected outputs as lists of lines; retry with those
        # joined into strings. Build a normalized copy rather than mutating
        # `actual_tests` in place, so re-running this cell starts from the
        # original data. A second failure propagates.
        fixed_test = {**test, "outputs": _join_output_lines(test["outputs"])}
        problems.append(Problem.from_coderm_item(prompt, starter_code, None, tests=fixed_test, solutions=solution))

In [None]:
# Back-translation model; logs go under this run's PATH_TO_RESULT and query
# results are cached in a shared JSON file.
btm = BackTranslateModel(
    model_name="gpt-4o-mini",
    experiment_directory=PATH_TO_RESULT,
    cache_file="caches/temp_backtranslate_cache.json",
    num_words=NUM_WORDS,
)

In [None]:
# Cap on how many problems go through back-translation in this run.
# Set to None to process every problem.
MAX_PROBLEMS = 20
tproblems = problems[:MAX_PROBLEMS]

In [None]:
# Split each problem into one copy per reference solution (each copy carries a
# single solution), and record index maps in both directions:
#   problem_to_expand_idx[p] -> indices in expanded_problems derived from tproblems[p]
#   expand_to_problem_idx[e] -> index in tproblems that expanded_problems[e] came from
expanded_problems: list[Problem] = []
problem_to_expand_idx: list[list[int]] = []
expand_to_problem_idx: list[int] = []
for prob_idx, prob in enumerate(tproblems):
    expand_ids = []
    problem_to_expand_idx.append(expand_ids)
    for sol in prob.solutions:
        expand_ids.append(len(expanded_problems))
        expanded_problems.append(
            Problem(prob.problem_str, prob.starter_code, prob.public_tests, prob.private_tests, [sol])
        )
        expand_to_problem_idx.append(prob_idx)

In [None]:
# Per-program timeout handed to run_tests_per_code (units not visible here —
# presumably seconds).
TEST_TIMEOUT = 30


def _load_nl_solutions(query_path: str) -> list[str]:
    """Return the completion texts from every `solution*` log file in `query_path`.

    Files are read in sorted filename order so the result order is
    deterministic (os.listdir order is arbitrary). Each file is expected to
    hold a JSON list of records with a ["completion"]["text"] field.
    NOTE(review): plain sorting assumes the file names order the same way the
    querier wrote them (e.g. zero-padded counters) — confirm.

    Raises:
        AssertionError: if no matching file exists.
    """
    solution_paths = [
        os.path.join(query_path, name)
        for name in sorted(os.listdir(query_path))
        if name.startswith("solution")
    ]
    assert solution_paths, f"no solution files found in {query_path}"
    texts: list[str] = []
    for path in solution_paths:
        print(f"Found solution file: {path}")
        with open(path, "r") as fh:
            texts.extend(entry["completion"]["text"] for entry in json.load(fh))
    return texts


# One slot per expanded problem; filled with the first round's passing
# generation (later passing generations overwrite it).
selected_codes = [None] * len(expanded_problems)
selected_nl_sols = [None] * len(expanded_problems)
unsolved_idxs = list(range(len(expanded_problems)))

for iter_num in range(NUM_BATCHES_TO_TRY):
    # unsolved_idxs index into expanded_problems (one reference solution each),
    # not into the unexpanded `problems` list.
    unsolved_problems = [expanded_problems[i] for i in unsolved_idxs]
    # Repeat the whole unsolved list GEN_BATCH_SIZE times, so position k of the
    # tiled list corresponds to unsolved_idxs[k % len(unsolved_problems)].
    tiled_problems = unsolved_problems * GEN_BATCH_SIZE

    # Each round logs into its own directory so its solution files can be
    # matched one-to-one with this round's generations.
    query_path = os.path.join(PATH_TO_RESULT, f"iter_{iter_num}")
    btm.querier.set_log_directory(query_path)
    generated = btm.generate_solutions(tiled_problems, requery=True)
    assert len(generated) == len(tiled_problems)

    results = run_tests_per_code(
        generated,
        [p.private_tests for p in tiled_problems],
        [TEST_TIMEOUT] * len(tiled_problems),
    )

    # Assumes the logged NL solutions come out in the same order as `generated`.
    nl_solutions = _load_nl_solutions(query_path)
    assert len(nl_solutions) == len(results) == len(generated)

    for k, (result, gen_code, gen_nl_sol) in enumerate(zip(results, generated, nl_solutions)):
        passed, _ = result
        if passed:
            original_idx = unsolved_idxs[k % len(unsolved_problems)]
            selected_codes[original_idx] = gen_code
            selected_nl_sols[original_idx] = gen_nl_sol

    unsolved_idxs = [i for i, code in enumerate(selected_codes) if code is None]
    print(f"Remaining 'unsolved' problems: {len(unsolved_idxs)}")

    if not unsolved_idxs:
        break

In [None]:
def convert_test_list(tests: list[Test]) -> dict[str, Any]:
    """Convert a non-empty list of `Test` objects into a column-style test dict.

    The result holds parallel "inputs" (from `get_input_no_kwargs()`) and
    "outputs" (from `.output`) lists, plus an "fn_name" entry only when the
    shared function name is neither None nor "".

    Raises:
        AssertionError: if `tests` is empty or the tests disagree on `fn_name`.
    """
    assert len(tests)
    fn_name = tests[0].fn_name
    for t in tests:
        assert t.fn_name == fn_name, "All tests must have the same fn_name"

    converted: dict[str, Any] = {"inputs": [], "outputs": []}
    if fn_name is not None and fn_name != "":
        converted["fn_name"] = fn_name

    for t in tests:
        converted["inputs"].append(t.get_input_no_kwargs())
        converted["outputs"].append(t.output)
    return converted

In [None]:
# Build one row per original problem. "code_solutions" holds the reference
# solutions (from expanded_problems) whose back-translated NL solution passed;
# "nl_solutions" holds those NL solutions in the same order. Note that the
# generated code stored in selected_codes is not written out.
columns: dict[str, list] = {
    key: [] for key in ("problem_str", "starter_code", "tests", "code_solutions", "nl_solutions")
}

for orig_idx, expand_idxs in enumerate(problem_to_expand_idx):
    base = tproblems[orig_idx]
    kept_nl, kept_code = [], []
    for idx in expand_idxs:
        nl_sol = selected_nl_sols[idx]
        # The loop always fills code and NL solution together.
        assert (nl_sol is None) == (selected_codes[idx] is None)
        if nl_sol is None:
            continue
        kept_nl.append(nl_sol)
        kept_code.append(expanded_problems[idx].solutions[0])

    columns["problem_str"].append(base.problem_str)
    columns["starter_code"].append(base.starter_code)
    columns["tests"].append(convert_test_list(base.private_tests))
    columns["code_solutions"].append(kept_code)
    columns["nl_solutions"].append(kept_nl)

new_problems_dataset = Dataset.from_dict(columns)

In [None]:
# Sanity check: number of kept code solutions for the first problem.
len(new_problems_dataset[0]["code_solutions"])

In [None]:
# Display the original dataset (loaded from DATASET_NAME) for comparison with new_problems_dataset.
data

In [None]:
# Publish the new dataset as the "train" split of a sibling Hub repository.
target_repo = f"{DATASET_NAME}_with_nlsols"
ds = DatasetDict({"train": new_problems_dataset})
ds.push_to_hub(target_repo, commit_message="With NL solutions")

In [None]:

# NOTE(review): `impls` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel; it looks like a leftover from an earlier
# version. TODO: define `impls` (presumably one program per entry of
# expanded_problems) or delete this cell.
# It also rebinds `results`, overwriting the last round's results from the
# generation loop.
exec_results = run_tests_per_code(impls, [problem.private_tests for problem in expanded_problems], [30] * len(expanded_problems))
# Split each result pair: the first element is the pass/fail status (the
# generation loop unpacks results the same way); the second is kept in `check`.
results = [stat for stat, _ in exec_results]
check = [c for _, c in exec_results]

In [None]:
# NOTE(review): stale cell. The generation loop above logs to
# PATH_TO_RESULT/iter_<n>, not PATH_TO_RESULT/queries. Unless something else
# (e.g. the querier's default log directory — confirm) writes a "queries"
# directory, os.listdir below raises FileNotFoundError. Delete or update this cell.
query_path = os.path.join(PATH_TO_RESULT, "queries")
solution_files = [f for f in os.listdir(query_path) if f.startswith("solution")]

# Only the LAST matching file (in arbitrary os.listdir order) is kept; earlier
# matches are printed but otherwise ignored.
solution_path = None
for solution_file in solution_files:
    solution_path = os.path.join(query_path, solution_file)
    print(f"Found solution file: {solution_path}")

assert solution_path is not None

with open(solution_path, "r") as solution_file:
    nl_solutions = json.load(solution_file)

# Keep just the generated text of each logged completion record.
nl_solutions = [e["completion"]["text"] for e in nl_solutions]
# `results` comes from whichever cell last assigned it (the cell above, or the
# generation loop if that cell was skipped).
assert len(nl_solutions) == len(results)