Reformat the output JSON & code from the preprocessing step in `notebooks/codeio/PreprocessCode.ipynb`.

The output format will align with the data we extract from the existing CodeI/O dataset, in `notebooks/codeio.ipynb`.

In [1]:
import json
from pathlib import Path

# Load the preprocessed samples: one JSON document per line.
with open(Path("output/processed_code.jsonl"), "r") as f:
    samples = list(map(json.loads, f))

# Rewrite every sample in place so its keys and entry-point names follow the
# target layout. Keys are popped and re-inserted in the same order as before,
# so the serialised key order is unchanged.
for sample in samples:
    # "reference_code" becomes "code_sample", with `main` renamed to `main_solution`.
    sample["code_sample"] = sample.pop("reference_code").replace(
        "def main(", "def main_solution("
    )

    # The generator is renamed to take an explicit `random: Random` argument,
    # and any `import random` line (indented or not) is stripped from it.
    generator_source = sample["input_generator"].replace(
        "def input_generator()", "def generate_inputs(random: Random)"
    )
    for import_line in ("import random\n    ", "import random\n"):
        generator_source = generator_source.replace(import_line, "")
    sample["input_generator"] = generator_source

    # Plain key renames.
    sample["input_output_spec"] = sample.pop("parameters")
    sample["task_description"] = sample.pop("query")

# Persist the reformatted samples, again one JSON document per line.
with open(Path("output/formatted_code.jsonl"), "w") as f:
    f.writelines(json.dumps(sample) + "\n" for sample in samples)

Now we need to filter out unsuitable samples from the data. First we prioritise removing samples which are inherently random, reliant on external services (e.g. network requests), or whose input generators do not match the correct random usage requirements, as these could cause irreproducibility in RL training.

In [2]:
def verify_input_generator(input_generator_code):
    """Return True if the input generator source passes the substring checks.

    The source must contain one of the two accepted `generate_inputs`
    signatures (parameter named `random` or `rng`, annotated `Random`) and
    must not contain any of the substrings "import numpy", "np.random" or
    "import random".
    """
    accepted_signatures = (
        "def generate_inputs(random: Random)",
        "def generate_inputs(rng: Random)",
    )
    forbidden_fragments = ("import numpy", "np.random", "import random")

    has_signature = any(
        signature in input_generator_code for signature in accepted_signatures
    )
    uses_forbidden = any(
        fragment in input_generator_code for fragment in forbidden_fragments
    )
    return has_signature and not uses_forbidden

def verify_main_solution(main_solution_code):
    """Return True if the solution source passes the substring checks.

    The source must define `main_solution(` and must not contain any of the
    substrings that indicate use of the `random` module, `np.random`, or the
    `requests` library.
    """
    if "def main_solution(" not in main_solution_code:
        return False

    banned_fragments = (
        "import random",
        "from random import",
        "np.random",
        "import requests",
        " requests.",
        "from requests import",
    )
    return not any(fragment in main_solution_code for fragment in banned_fragments)

# Collect the indices of samples failing either check. The input generator is
# checked first; the main solution is only checked when the generator passed.
remove = set()
for index, candidate in enumerate(samples):
    if not verify_input_generator(candidate["input_generator"]):
        reason = "bad input generator"
    elif not verify_main_solution(candidate["code_sample"]):
        reason = "bad main solution"
    else:
        continue
    remove.add(index)
    print(f"Removing sample {index} due to {reason}")

# Split the samples into rejected and retained, preserving their order.
removed_samples = []
retained_samples = []
for index, candidate in enumerate(samples):
    bucket = removed_samples if index in remove else retained_samples
    bucket.append(candidate)
samples = retained_samples
print(f"Removed {len(remove)} samples")

# Write the retained samples, one JSON document per line.
with open(Path("output/filtered_code.jsonl"), "w") as f:
    f.writelines(json.dumps(kept) + "\n" for kept in samples)

Removing sample 6 due to bad input generator
Removing sample 8 due to bad input generator
Removing sample 28 due to bad input generator
Removing sample 30 due to bad input generator
Removing sample 39 due to bad main solution
Removing sample 43 due to bad main solution
Removing sample 47 due to bad main solution
Removing sample 53 due to bad input generator
Removing sample 59 due to bad input generator
Removing sample 64 due to bad main solution
Removing sample 87 due to bad main solution
Removing sample 112 due to bad main solution
Removing sample 116 due to bad main solution
Removing sample 121 due to bad input generator
Removing sample 141 due to bad main solution
Removing sample 144 due to bad main solution
Removing sample 150 due to bad main solution
Removing sample 155 due to bad main solution
Removing sample 159 due to bad main solution
Removing sample 162 due to bad input generator
Removing sample 168 due to bad input generator
Removing sample 170 due to bad main solution
Remov

In [3]:
# Display the input generator source of the first rejected sample.
removed_samples[0]["input_generator"]

'def generate_inputs(random: Random):\n    import numpy as np\n    \n    height = random.randint(10, 20)\n    width = random.randint(10, 20)\n    image0 = np.random.rand(height, width)\n    image1 = np.random.rand(height, width)\n    num_iter = random.randint(10, 100)\n    alpha = random.uniform(0.01, 1.0) if random.choice([True, False]) else None\n\n    return {"image0": image0, "image1": image1, "num_iter": num_iter, "alpha": alpha}'

In [4]:
# Display the solution source of the rejected sample at position 43 of removed_samples.
removed_samples[43]["code_sample"]

'def main_solution(search_terms):\n    import requests\n    from bs4 import BeautifulSoup\n    from fake_useragent import UserAgent\n    import webbrowser\n\n    url = "https://www.google.com/search?q=" + " ".join(search_terms)\n    res = requests.get(url, headers={"UserAgent": UserAgent().random}, timeout=10)\n    soup = BeautifulSoup(res.text, "html.parser")\n    links = list(soup.select(".eZt8xd"))[:5]\n\n    opened_links = []\n    for link in links:\n        if link.text == "Maps":\n            opened_links.append(link.get("href"))\n            webbrowser.open(link.get("href"))\n        else:\n            opened_links.append(f"https://google.com{link.get(\'href\')}")\n            webbrowser.open(f"https://google.com{link.get(\'href\')}")\n\n    return {"opened_links": opened_links}'

In [5]:
from dotenv import load_dotenv
load_dotenv()
import asyncio
import os
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from typing import Any, Iterable

# Prompt template for the LLM determinism check. `{0}` is replaced with the
# code snippet via `str.format`; the model is asked to answer with exactly
# "True" or "False".
VERIFY_PROMPT = """
Given the following code snippet, you must verify whether it is deterministic.

It is not deterministic if it utilises potentially non-deterministic functions such as random number generators, network requests, or time functions. It also qualifies as non-deterministic if it calls another function or library which in turn produces non-deterministic outputs.

Code snippet:

{0}

If the function is deterministic, return True. Otherwise, return False. Respond only with this one word, no other content or explanation.
"""

# Cap concurrent requests. I had to set this to 1 for the DeepSeek API to work, YMMV
semaphore = asyncio.Semaphore(1)


async def llm_generate(
    client: "AsyncOpenAI",
    messages: "Iterable[ChatCompletionMessageParam]",
    sampling_params: dict[str, Any],
    retry_empty_response: bool = True,
    max_retries: int = 3,
) -> "ChatCompletion":
    """Request a chat completion, retrying on errors and on empty replies.

    Each attempt runs while holding the module-level ``semaphore``, which
    caps the number of concurrent requests. Back-off sleeps also happen while
    the semaphore is held.

    The openai annotations are quoted so they are not evaluated when the
    function is defined.

    Args:
        client: Client whose ``chat.completions.create`` coroutine is awaited.
        messages: Chat messages forwarded as ``messages=``.
        sampling_params: Extra keyword arguments forwarded to ``create``.
        retry_empty_response: When true, a completion whose first choice has
            empty content triggers another attempt instead of being returned.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        The first completion with non-empty content. If every successful
        attempt came back empty, the last empty completion is returned so
        callers can handle it, rather than ``None``.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: Whatever the final attempt raised.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    last_empty_completion = None
    for trial in range(max_retries):
        is_last_trial = trial == max_retries - 1
        async with semaphore:
            try:
                completion = await client.chat.completions.create(
                    messages=messages, **sampling_params
                )
                if completion.choices[0].message.content or not retry_empty_response:
                    return completion
                last_empty_completion = completion
                # No point waiting after the final attempt.
                if not is_last_trial:
                    await asyncio.sleep(5)
            except Exception as e:
                print(f"Failure response (trial {trial}):", e)
                # Re-raise immediately on the final attempt instead of
                # sleeping through a back-off that nothing follows.
                if is_last_trial:
                    raise
                await asyncio.sleep(3 * (trial + 1))

    # Reached only when the final attempt returned an empty completion.
    return last_empty_completion

# API client; the endpoint and key are read from the environment
# variables API_BASE_URL and API_KEY.
client = AsyncOpenAI(
    timeout=120.0,
    api_key=os.getenv("API_KEY"),
    base_url=os.getenv("API_BASE_URL"),
)

# Keyword arguments for the chat completion requests.
sampling_params = dict(
    model="deepseek-chat",  # For DeepSeek API
    # model="deepseek/deepseek-chat:free",  # For OpenRouter
    max_tokens=8192,
)

In [6]:
from tqdm import tqdm

remove_nondeterministic = set()
for i, sample in tqdm(enumerate(samples)):
    messages = [
        {"role": "user", "content": VERIFY_PROMPT.format(sample["code_sample"])},
    ]
    completion = await llm_generate(client, messages, sampling_params)
    content = completion.choices[0].message.content
    if not content or content.strip() not in ["True", "False"]:
        print(f"Sample {i} failed to verify")
        print(content)
    elif content.strip() == "False":
        print(f"Sample {i} is non-deterministic")
        remove_nondeterministic.add(i)

removed_samples = [sample for i, sample in enumerate(samples) if i in remove]
samples = [sample for i, sample in enumerate(samples) if i not in remove]
print(f"Removed {len(remove)} samples")

with open(Path("output/filtered_code_2.jsonl"), "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")

33it [04:49,  8.14s/it]

Sample 32 is non-deterministic


58it [08:49,  9.66s/it]

Sample 57 is non-deterministic


147it [23:40, 12.39s/it]

Sample 146 is non-deterministic


152it [24:19,  8.55s/it]

Sample 151 is non-deterministic


158it [25:30, 10.53s/it]

Sample 157 is non-deterministic


172it [27:33,  7.87s/it]

Sample 171 is non-deterministic


173it [27:47,  9.64s/it]

Sample 172 is non-deterministic


231it [37:31,  9.87s/it]

Sample 230 is non-deterministic


285it [48:06, 10.91s/it]

Sample 284 is non-deterministic


343it [58:49, 15.48s/it]

Sample 342 is non-deterministic


363it [1:02:19, 11.92s/it]

Sample 362 is non-deterministic


374it [1:04:16, 11.96s/it]

Sample 373 is non-deterministic


394it [1:07:47, 11.56s/it]

Sample 393 is non-deterministic


429it [1:14:50, 11.54s/it]

Sample 428 is non-deterministic


451it [1:19:16, 12.64s/it]

Sample 450 is non-deterministic


555it [1:40:31,  9.80s/it]

Sample 554 is non-deterministic


603it [1:48:46,  9.54s/it]

Sample 602 is non-deterministic


634it [1:53:27, 10.77s/it]

Sample 633 is non-deterministic


638it [1:53:59,  8.85s/it]

Sample 637 is non-deterministic


685it [2:01:43, 10.44s/it]

Sample 684 is non-deterministic


689it [2:02:21,  9.03s/it]

Sample 688 is non-deterministic


782it [2:19:05, 10.67s/it]

Removed 81 samples





In [7]:
# Display the solution source of the first entry in removed_samples.
removed_samples[0]["code_sample"]

'def main_solution(message, word_percentage=20, letter_percentage=85):\n    ENGLISH_WORDS = {}\n    with open("dictionary.txt") as dictionary_file:\n        for word in dictionary_file.read().split("\\n"):\n            ENGLISH_WORDS[word] = None\n\n    def remove_non_letters(message):\n        return "".join(symbol for symbol in message if symbol in ascii_letters + " \\t\\n")\n\n    def get_english_count(message):\n        message = message.upper()\n        message = remove_non_letters(message)\n        possible_words = message.split()\n        matches = len([word for word in possible_words if word in ENGLISH_WORDS])\n        return float(matches) / len(possible_words)\n\n    words_match = get_english_count(message) * 100 >= word_percentage\n    num_letters = len(remove_non_letters(message))\n    message_letters_percentage = (float(num_letters) / len(message)) * 100\n    letters_match = message_letters_percentage >= letter_percentage\n    is_english = words_match and letters_match\n\n 

Note: following the above steps, two further filtering steps were taken:

- manually reviewed every code snippet for security issues, dependencies on libraries, or non-determinism missed by the LLM classification
- ran every code snippet and input generator 100 times, dropping any which caused an error