In [None]:
from datasets import load_dataset
import os
import json

from rllm.system_prompts import LCB_FORMATTING_MESSAGE_WITH_STARTER_CODE, LCB_SYSTEM_MESSAGE_GENERIC

# Load the "train" split of the KodCode/KodCode-V1 dataset and print its summary.
train_dataset = load_dataset("KodCode/KodCode-V1", split="train")
print("Training set:", train_dataset)


## Filter Rules

`style`: instruct

`subset`: Leetcode, Codeforces, Code_Contests, Taco, Apps

`GPT4o Pass Count`: < 9

`Benchmark Similarity`: < 0.9

Test count: >= 8 (approximated by counting occurrences of `def` in `test_code`)

In [2]:
def kodcode_filter(x):
    """
    Decide whether a KodCode row passes the filter rules.

    A row is kept only when every rule holds:
      * ``subset`` is one of Leetcode, Codeforces, Code_Contests, Apps, Taco
      * ``style`` is ``instruct``
      * ``gpt_pass_trial_num`` is below 9
      * ``benchmark_similarity`` is below 0.9
      * ``test_code`` contains the substring ``def`` at least 8 times

    Args:
        x: Mapping holding the fields listed above. Fields are read in the
           order listed and only until the first rule fails.

    Returns:
        True when the row satisfies every rule, otherwise False.
    """
    accepted_subsets = ['Leetcode', 'Codeforces', 'Code_Contests', 'Apps', 'Taco']

    if x['subset'] not in accepted_subsets:
        return False
    if x['style'] != 'instruct':
        return False
    if not (x['gpt_pass_trial_num'] < 9):
        return False
    if not (x['benchmark_similarity'] < 0.9):
        return False
    # NOTE(review): this is a plain substring count, so words such as
    # "default" also count towards the threshold -- confirm this is intended.
    if not (x['test_code'].count('def') >= 8):
        return False
    return True

In [None]:
# Keep only the rows accepted by kodcode_filter.
def_filtered_dataset = train_dataset.filter(kodcode_filter)

## Bad Data Removal

In [None]:
# Question IDs listed as bad data; any row whose question_id appears here is dropped.
bad_ids = ["Codeforces_12376_I"]
error_filtered_dataset = def_filtered_dataset.filter(lambda x: x['question_id'] not in bad_ids)

In [None]:
from pprint import pprint
def format_test_info(entry):
    """
    Format test_info into Python starter code with docstring and function declaration.
    
    Args:
        test_info: List of test info dictionaries containing function declaration and docstring
        
    Returns:
        Formatted Python starter code string
    """
    # Return empty if no test info
    test_info = entry.get('test_info', [])
    if not test_info:
        return ""
        
    # Get the function declaration and docstring from first test info
    tests = entry['test_code']
    solution_import = '\n'.join([line for line in tests.split('\n') if line.strip().startswith('from solution import')])
    
    solution_funcs = []
    # Get all the solution functions from solution import
    if solution_import:
        # Extract function names from the import statement
        import_parts = solution_import.replace("from solution import ", "").strip()
        # Split by commas and strip whitespace
        solution_funcs = [func.strip() for func in import_parts.split(',')]

    for t in test_info:
        if t['function_name'] in tests and t['function_name'] not in solution_funcs:
            solution_funcs.append(t['function_name'])
    # Check test infos
    relevant_test_infos = {}
    for t_info in test_info:
        func_dec = t_info.get('function_name', '')
        if not func_dec:
            continue
        
        if func_dec in solution_funcs:
            relevant_test_infos[func_dec] = t_info

    func_decl_instruction = "The code you write must contain the following functions or classes:\n\n"
    for func_name in solution_funcs:
        func_decl_instruction += f"{func_name}\n"
    
    instruction = f"{LCB_FORMATTING_MESSAGE_WITH_STARTER_CODE}\n\n"
    for func_name in solution_funcs:
        if func_name in relevant_test_infos:
            
            instruction += f"{relevant_test_infos[func_name]['function_declaration']}\n"
            doc_string = relevant_test_infos[func_name].get('docstring', '')
            if doc_string:
                indented_docstring = '\n'.join(f"    {line}" for line in doc_string.split('\n'))
                instruction += f"    \"\"\"\n{indented_docstring}\n    \"\"\"\n\n"
            instruction += "\n"

    return func_decl_instruction + '\n\n' + instruction
    

# Spot-check the formatter on one sample row (the index 1251 is hard-coded).
inst = format_test_info(error_filtered_dataset[1251])
print(inst)

## Dedupe Questions

We dedupe questions by cosine similarity: for each kept question, any of its nearest neighbours with a similarity score above 0.792 is removed.

In [None]:
from rllm.utils import RAG
from pprint import pprint
from tqdm import tqdm
# Collect the question text of every remaining row, in dataset order.
questions = [entry['question'] for entry in error_filtered_dataset]

# Build the RAG helper over the questions; it is queried below to find similar questions.
rag = RAG(docs=questions)

In [None]:
# Similarity score above which a neighbour is treated as a duplicate.
rag_cutoff = 0.792

# Indices marked for removal; later used to filter error_filtered_dataset by row index.
indices_to_remove = set()

for i in tqdm(range(len(error_filtered_dataset))):
    # A row already marked as a duplicate does not get to mark further rows.
    if i in indices_to_remove:
        continue
    similars = rag.top_k(error_filtered_dataset[i]['question'], k=5)
    # NOTE(review): skipping the first hit presumably drops the query's own
    # self-match -- confirm that top_k always returns the query itself first.
    for entry in similars[1:]:
        if entry['score'].item() > rag_cutoff:
            indices_to_remove.add(entry['index'])
print(len(indices_to_remove))

In [None]:
# Drop every row whose positional index was marked as a duplicate.
deduped_dataset = error_filtered_dataset.filter(lambda x, idx: idx not in indices_to_remove, with_indices=True)

# Report the sizes before and after deduplication.
print(f"Original dataset size: {len(error_filtered_dataset)}")
print(f"Number of indices to remove: {len(indices_to_remove)}")
print(f"Filtered dataset size: {len(deduped_dataset)}")

In [None]:
# Convert every deduplicated row into a {"problem", "solutions", "tests"} record.
dataset = []
for entry in deduped_dataset:
    # Drop the `from solution import ...` lines; the remaining text is stored as the tests.
    tests = '\n'.join(
        line for line in entry['test_code'].split('\n')
        if not line.strip().startswith('from solution import')
    )
    # Skip rows with no test code left (including whitespace-only leftovers)
    # before spending any work on building the prompt.
    if not tests.strip():
        continue

    instruction = format_test_info(entry)

    problem = f"""{LCB_SYSTEM_MESSAGE_GENERIC} Make sure your code consists of just standalone classes and functions, which can then be tested in a pytest suite for the correctness of your function/class using assertions on return values.
No reading from stdin or writing to stdout is allowed.
    
{entry['question'].strip()}

{instruction}
"""

    new_entry = {
        "problem": problem,
        "solutions": entry["solution"],
        "tests": tests,
    }

    dataset.append(new_entry)

print(f'Dataset size: {len(dataset)}')

# The output path is resolved relative to the notebook's working directory.
output_dir = os.path.abspath("../../train/code")
# Create the directory up front so open() does not fail when it is missing.
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "kodcode.json")

# Write the records as indented JSON with an explicit encoding.
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(dataset, f, indent=4)