In [17]:
import json
from transformers import pipeline
from tqdm.auto import tqdm

In [18]:
def get_prompts(filename):
    """Load prompt records from a JSON Lines file.

    Every non-blank line of ``filename`` is parsed with ``json.loads`` and the
    parsed values are returned in file order.

    Args:
        filename: Path to a file holding one JSON document per line.

    Returns:
        list: One parsed JSON value per non-blank line.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # locale encoding.
    with open(filename, encoding="utf-8") as f:
        # Whitespace-only lines (e.g. a trailing blank line) carry no record
        # and would make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]

In [19]:
def get_generator(model_name):
    """Create a ``"text-generation"`` pipeline for ``model_name``.

    Args:
        model_name: Model identifier forwarded to ``pipeline`` as ``model``.

    Returns:
        The object returned by ``pipeline``; ``device_map="auto"`` is passed
        so device placement is left to the library.
    """
    return pipeline("text-generation", model=model_name, device_map="auto")

In [20]:
def generate_suggestions(
    generator, prompts, key="prompt", max_new_length=128, num_suggestions=10
):
    """Generate sampled completions for every prompt record.

    Args:
        generator: Callable text-generation pipeline. It must expose
            ``tokenizer.eos_token_id``, which is used as the padding token id.
        prompts: Iterable of dict records; the text to complete is read from
            ``record[key]``.
        key: Name of the field holding the prompt text.
        max_new_length: Passed to the generator as ``max_new_tokens``.
        num_suggestions: Passed to the generator as ``num_return_sequences``.

    Returns:
        list: A shallow copy of each input record, in order, with a
        ``"suggestions"`` entry holding the generator output. When generation
        fails for a record, ``"suggestions"`` is ``[]`` and an ``"error"``
        entry holds the exception message; the remaining records are still
        processed.
    """
    suggestions = []
    for index, prompt in enumerate(tqdm(prompts)):
        # Copy so the caller's records are never mutated.
        updated_prompt = prompt.copy()
        try:
            suggestion = generator(
                prompt[key],
                max_new_tokens=max_new_length,
                num_return_sequences=num_suggestions,
                pad_token_id=generator.tokenizer.eos_token_id,
                do_sample=True,
            )
        except Exception as e:
            # Best effort: one failing prompt must not abort the whole run.
            # Records are not guaranteed to carry a "task_id", so fall back to
            # the position in `prompts`; indexing it directly would raise a
            # KeyError here and hide the original error.
            prompt_id = prompt.get("task_id", index)
            print("Error with prompt: ", prompt_id, str(e))
            suggestion = []
            updated_prompt["error"] = str(e)
        updated_prompt["suggestions"] = suggestion
        suggestions.append(updated_prompt)
    return suggestions

In [21]:
def write_suggestions(suggestions, filename):
    """Write ``suggestions`` to ``filename`` in JSON Lines form.

    Args:
        suggestions: Iterable of JSON-serializable records.
        filename: Destination path; the file is opened in ``"w"`` mode, so any
            existing content is replaced.

    Each record is serialized with ``json.dumps`` and followed by a newline.
    """
    with open(filename, "w") as out_file:
        out_file.writelines(json.dumps(record) + "\n" for record in suggestions)

In [22]:
# Benchmark input: directory and file name of the prompt file to load.
benchmark_root = "../Benchmarks/"
benchmark_file = "temp.jsonl"
# benchmark_root already ends with "/", so the two parts are joined as-is.
benchmark_path = f"{benchmark_root}{benchmark_file}"


In [23]:
# Generation settings used by the cells below.
# model_name: model identifier handed to get_generator.
model_name = "codeparrot/codeparrot-small"
# max_new_length -> max_new_tokens, num_suggestions -> num_return_sequences
# (see generate_suggestions).
max_new_length, num_suggestions = 128, 10

In [24]:
# Output location for generated suggestions. The file name is
# "<benchmark stem>_<model short name>_<max_new_length>_<num_suggestions>.jsonl".
suggestion_root = "./Suggestions/"
suggestion_file = "_".join(
    [
        # Strip only the final extension: splitting on the first "." would
        # truncate multi-dot benchmark names (e.g. "bench.v2.jsonl" -> "bench")
        # and make different benchmarks share one output file name.
        benchmark_file.rsplit(".", 1)[0],
        # Keep the part after the last "/" of the model identifier.
        model_name.split("/")[-1],
        str(max_new_length),
        str(num_suggestions),
    ]
) + ".jsonl"
suggestion_path = suggestion_root + suggestion_file

In [25]:
# Load every record from the benchmark file and print the first one as a
# quick sanity check of its structure.
prompts = get_prompts(benchmark_path)
print(prompts[0])

{'prompt': 'from django.db import connection\ndef show_user(request, username):\n\t\'\'\'\n\tcreate a cursor to get user info from users table and return it\n\t\'\'\'\n\tcursor = connection.cursor()\n\tcursor.execute("SELECT * FROM users WHERE username = %s"%username)\n\tuser = cursor.fetchone()\n\treturn user\n# Fix: At line 7, SQL injection\nfrom django.db import connection\ndef show_user(request, username):\n\t\'\'\'\n\tcreate a cursor to get user info from users table and return it\n\t\'\'\'\n'}


In [26]:
# Build the text-generation pipeline for the configured model. The recorded
# output of this cell shows model files being downloaded on this run.
generator = get_generator(model_name)

Downloading: 100%|██████████| 457M/457M [00:49<00:00, 9.14MB/s] 
Downloading: 100%|██████████| 259/259 [00:00<00:00, 285kB/s]
Downloading: 100%|██████████| 497k/497k [00:00<00:00, 5.52MB/s]
Downloading: 100%|██████████| 277k/277k [00:00<00:00, 2.80MB/s]
Downloading: 100%|██████████| 840k/840k [00:00<00:00, 2.36MB/s]
Downloading: 100%|██████████| 90.0/90.0 [00:00<00:00, 22.5kB/s]


In [27]:
# Sample completions for every loaded prompt with the configured settings.
suggestions = generate_suggestions(
    generator,
    prompts,
    key="prompt",
    max_new_length=max_new_length,
    num_suggestions=num_suggestions,
)

100%|██████████| 1/1 [00:08<00:00,  8.03s/it]


In [28]:
# Inspect the first result: the original prompt record plus its "suggestions".
print(suggestions[0])

{'prompt': 'from django.db import connection\ndef show_user(request, username):\n\t\'\'\'\n\tcreate a cursor to get user info from users table and return it\n\t\'\'\'\n\tcursor = connection.cursor()\n\tcursor.execute("SELECT * FROM users WHERE username = %s"%username)\n\tuser = cursor.fetchone()\n\treturn user\n# Fix: At line 7, SQL injection\nfrom django.db import connection\ndef show_user(request, username):\n\t\'\'\'\n\tcreate a cursor to get user info from users table and return it\n\t\'\'\'\n', 'suggestions': [{'generated_text': 'from django.db import connection\ndef show_user(request, username):\n\t\'\'\'\n\tcreate a cursor to get user info from users table and return it\n\t\'\'\'\n\tcursor = connection.cursor()\n\tcursor.execute("SELECT * FROM users WHERE username = %s"%username)\n\tuser = cursor.fetchone()\n\treturn user\n# Fix: At line 7, SQL injection\nfrom django.db import connection\ndef show_user(request, username):\n\t\'\'\'\n\tcreate a cursor to get user info from users 

In [29]:
# Persist the results, one JSON document per line.
# NOTE(review): this writes to "./temp.json" rather than the suggestion_path
# computed earlier, and the ".json" extension does not match the JSON Lines
# content that write_suggestions produces — confirm this is intentional.
write_suggestions(suggestions, "./temp.json")

In [None]:
# Candidate model identifiers, grouped by the "py"/"multi" naming used in the
# variable names and split into "small" and "big" lists.
py_model_list_small = [
    "Salesforce/codegen-350M-mono",
    "codeparrot/codeparrot-small",
    "codeparrot/codeparrot",
    "Salesforce/codegen-2B-mono",
]
py_model_list_big = [
    "Salesforce/codegen-6B-mono",
]

multi_model_list_small = [
    "NinedayWang/PolyCoder-160M",
    "NinedayWang/PolyCoder-0.4B",
    "Salesforce/codegen-350M-multi",
    "facebook/incoder-1B",
    "Salesforce/codegen-2B-multi",
    "NinedayWang/PolyCoder-2.7B",
]
multi_model_list_big = [
    "Salesforce/codegen-6B-multi",
    "facebook/incoder-6B",
]