In [3]:
from human_eval.data import write_jsonl, read_problems
from langchain.prompts import PromptTemplate
import requests
from tqdm.notebook import tqdm
import json
import re
import random
import ast
import black
import traceback
import textwrap

In [5]:
# Completions endpoint that generate_single() posts to.
url = 'http://0.0.0.0:8000/v1/completions'

# Read the prompt scaffold as UTF-8 explicitly so the text sent to the model
# does not depend on the locale encoding of the machine running the notebook.
with open('prompt_com.txt', 'r', encoding='utf-8') as file:
    template_content = file.read()

# The template declares `problem_text` as its only input variable; it is
# filled in by generate_single() via prompt_template.format(...).
prompt_template = PromptTemplate(
    input_variables=["problem_text"],
    template=template_content)


def clean_code(raw_output: str, entry_point: str) -> str:
    """
    Turn a raw model completion into a function body indented by four spaces.

    Steps, in order:
      1. remove a conversational lead-in at the very start ("Sure, here is ...:"),
      2. remove a leading and a trailing markdown code fence,
      3. drop a repeated ``def ...`` header line if the text starts with one,
      4. cut everything from the next top-level ``def``/``class`` onwards,
      5. dedent the remainder and re-indent every non-blank line by four spaces.

    :param raw_output: text returned by the model
    :param entry_point: name of the target function; accepted for call
        compatibility, it is not consulted by the cleaning steps
    :return: the cleaned body, without trailing whitespace ("" if nothing is left)
    """
    # Only a lead-in at the start of the output is removed, and "sure" must be
    # a whole word. An unanchored match would also eat code such as
    # "for measure in data:" or "if ensure(x):".
    code = re.sub(r"(?i)\A\s*sure\b[^\n]*?:", "", raw_output)

    # The fences are stripped after the lead-in so "Sure ...:\n```python" works.
    code = re.sub(r"\A\s*```(?:python)?", "", code)
    code = re.sub(r"```\s*\Z", "", code)

    # A fenced answer leaves a newline in front of the repeated signature.
    # Look past it, otherwise the top-level-definition cut below would match
    # at index 0 and discard the whole completion.
    without_leading_newlines = code.lstrip("\n")
    if without_leading_newlines.startswith("def "):
        if "\n" in without_leading_newlines:
            code = without_leading_newlines.split("\n", 1)[1]
        else:
            code = ""

    # Anything after a new column-0 def/class is not part of this function.
    other_func = re.search(r"\n(def |class )", code)
    if other_func:
        code = code[:other_func.start()]

    code = textwrap.dedent(code.strip("\n"))
    return "\n".join(
        "    " + line if line.strip() else ""
        for line in code.splitlines()
    ).rstrip()


def format_and_validate(code: str):
    """
    Format ``code`` with black and confirm that the result parses.

    :param code: Python source text
    :return: the black-formatted source, or None if formatting or parsing
        raised any exception
    """
    try:
        formatted = black.format_str(code, mode=black.Mode())
        # Parsing is only a validity check; the tree itself is discarded.
        ast.parse(formatted)
    except Exception:
        return None
    return formatted

def enhance_prompt(item, max_cases=2, seed=42):
    """
    Append a few boolean test cases from the problem's tests to its prompt.

    Assertions of the form ``assert candidate(...) == True/False`` are pulled
    out of ``item["test"]``. One True case and one False case are chosen when
    available, then further distinct cases are added until ``max_cases`` is
    reached. The chosen cases are appended to the prompt as comment lines with
    ``candidate`` replaced by the entry point name.

    :param item: problem mapping; reads "prompt" (required), "test" and
        "entry_point" (both optional, default "")
    :param max_cases: upper bound on the number of cases appended
    :param seed: seed of the private random generator used for the selection
    :return: the prompt with the comment block appended, or the unchanged
        prompt when no matching assertion exists or no case is selected
    """
    prompt = item["prompt"]
    test_code = item.get("test", "")
    function_name = item.get('entry_point', '')

    asserts = re.findall(r"assert (candidate\(.*?\)\s*==\s*(?:True|False))", test_code)
    if not asserts:
        return prompt

    # A private generator gives the same picks as seeding the module-level
    # one, without resetting the global random state for the rest of the run.
    rng = random.Random(seed)

    true_cases = [a for a in asserts if a.endswith("True")]
    false_cases = [a for a in asserts if a.endswith("False")]

    selected = []
    if true_cases:
        selected.append(rng.choice(true_cases))
    if false_cases:
        selected.append(rng.choice(false_cases))

    if len(selected) < max_cases:
        # Deduplicate while keeping source order. Going through a set would
        # make the sample depend on string hashing, which varies between
        # interpreter sessions, so the seed would not make it reproducible.
        extra = [a for a in dict.fromkeys(asserts) if a not in selected]
        if extra:
            selected.extend(rng.sample(extra, min(max_cases - len(selected), len(extra))))

    # The True/False pair can exceed a max_cases below 2; enforce the bound.
    selected = selected[:max(max_cases, 0)]
    if not selected:
        return prompt

    # Every case starts with "candidate(" by construction of the regex, so
    # swap only that prefix and leave any "candidate" inside arguments alone.
    prefix_len = len("candidate")
    selected = [function_name + case[prefix_len:] for case in selected]
    comment_cases = "    # some more testing cases:\n" + "\n".join(f"    # {case}" for case in selected)

    return prompt.rstrip() + "\n" + comment_cases


def generate_single(problem_text, temp=.2, n=1, tk=20, tp=.85):
    """
    Request completions for one problem from the completions endpoint.

    :param problem_text: the code problem text, substituted into the prompt template
    :param temp: generation temperature; when it is 0 a single sample is requested
    :param n: number of samples to request
    :param tk: top k
    :param tp: top p
    :return: list with the text of every returned choice, or ``['return']``
        when the server answers with a status other than 200
    """
    # At temperature 0 only one sample is requested, whatever n was passed.
    sample_count = 1 if temp == 0 else n

    payload = {
        "prompt": prompt_template.format(problem_text=problem_text),
        "max_tokens": 512,
        "temperature": temp,
        "stop": ["\n\n", "\nclass ", "\ndef ", "<|im_end|>", "</s>"],
        "top_k": tk,
        "top_p": tp,
        "n": sample_count,
        "logits_processor": [
            {"type": "ban_tokens", "banned_ids": [6385, 750, 1112]}],
    }
    reply = requests.post(
        url,
        headers={"Content-Type": "application/json"},
        data=json.dumps(payload),
    )

    # Any non-200 answer falls back to a bare placeholder completion.
    if reply.status_code != 200:
        return ['return']

    return [choice['text'] for choice in reply.json()["choices"]]


def solve(problem, temp, n, tk=20, tp=.85):
    """
    Generate completions for one problem and patch them up in two stages.

    Stage 1: if prompt + cleaned completion cannot be formatted and parsed,
    one replacement is sampled with fixed settings (temp .6, top k 40,
    top p .9) and cleaned; the replacement is not validated again.
    Stage 2: while the body contains no ``'    return'``, the model is asked
    to continue it, at most twice; each continuation is appended as returned.

    :param problem: problem mapping; must contain 'entry_point' and whatever
        enhance_prompt reads
    :param temp: generation temperature for the first call and the continuations
    :param n: number of samples requested in the first call
    :param tk: top k
    :param tp: top p
    :return: list with one completion body per text returned by the first call
    """
    entry_point = problem['entry_point']
    problem_text = enhance_prompt(problem)

    completions = []
    for raw_text in generate_single(problem_text, temp, n, tk, tp):
        body = clean_code(raw_text, entry_point)

        # stage 1: check formatting
        if format_and_validate(f"{problem_text}\n{body}") is None:
            retry_text = generate_single(problem_text, temp=.6, n=1, tk=40, tp=.9)[0]
            body = clean_code(retry_text, entry_point)

        # stage 2: check if there is a return
        for _ in range(2):
            if '    return' in body:
                break
            continued_prompt = problem_text + '\n' + body + '\n'
            continuation = generate_single(continued_prompt, temp, 1, tk, tp)[0]
            body = f"{body}\n{continuation}"

        completions.append(body)
    return completions

In [6]:
# Run every HumanEval problem once per temperature and write one JSONL file
# of completions per temperature (named complex_<temp>_<ns>.jsonl).
problems = read_problems('human-eval/data/HumanEval.jsonl.gz')
temps = [0, .05, .1, .2, .4]
ns = 1
for t in temps:
    samples = []
    for task_id in tqdm(problems):
        problem = problems[task_id]

        # NOTE(review): function_head is not read again in this cell;
        # solve() derives the same enhanced prompt on its own.
        function_head = enhance_prompt(problem)
        generated = solve(problem, t, ns, tk=20, tp=.85)
        samples.extend(
            {"task_id": task_id, "completion": text} for text in generated
        )
    write_jsonl(f"complex_{t}_{ns}.jsonl", samples)

  0%|          | 0/164 [00:00<?, ?it/s]

s2 triggered
s2 triggered
s1 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered


  0%|          | 0/164 [00:00<?, ?it/s]

s2 triggered
s2 triggered
s1 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered


  0%|          | 0/164 [00:00<?, ?it/s]

s2 triggered
s2 triggered
s1 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered


  0%|          | 0/164 [00:00<?, ?it/s]

s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered


  0%|          | 0/164 [00:00<?, ?it/s]

s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
s1 triggered
s2 triggered
s2 triggered
s2 triggered
s2 triggered
