In [6]:
from transformers import AutoModel,LlamaTokenizer,LlamaConfig,LlamaForCausalLM,AutoTokenizer,AutoModelForCausalLM,AutoConfig,GenerationConfig
import torch
import json
from transformers import AutoTokenizer
import pandas as pd
from config import *
from pathlib import Path
from string import Template
from tqdm import tqdm

# Load the causal LM under evaluation and its tokenizer.
# model_name, proxies, hf_token and cache_dir are not defined in this cell;
# presumably they come from the star-import of `config` -- confirm there.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    proxies=proxies,
    use_auth_token=hf_token,
    cache_dir=cache_dir,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    proxies=proxies,
    use_auth_token=hf_token,
    cache_dir=cache_dir,
)

# Read the test split through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open('../vul4c_dataset/test.json', mode='r', encoding='utf-8') as test_file:
    test_set = json.load(test_file)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [73]:
# Keep only the test records whose function body has fewer than 30 lines.
small_test_set = [
    record for record in test_set
    if len(record['func'].splitlines()) < 30
]

small_test_set_df = pd.DataFrame(small_test_set)
# Cell output: every record of the short subset that is labelled vul == 1.
small_test_set_df.loc[small_test_set_df['vul'].eq(1)].to_dict('records')

[{'cve': 'CVE-2016-6663',
  'cwe_list': ['CWE-362'],
  'repo_name': 'MariaDB/server',
  'commit_hash': '347eeefbfc658c8531878218487d729f4e020805',
  'git_url': 'https://github.com/MariaDB/server/commit/347eeefbfc658c8531878218487d729f4e020805',
  'func': 'int my_redel(const char *org_name, const char *tmp_name,\n             time_t backup_time_stamp, myf MyFlags)\n{\n  int error=1;\n  DBUG_ENTER("my_redel");\n  DBUG_PRINT("my",("org_name: \'%s\' tmp_name: \'%s\'  MyFlags: %d",\n\t\t   org_name,tmp_name,MyFlags));\n\n  if (my_copystat(org_name,tmp_name,MyFlags) < 0)\n    goto end;\n  if (MyFlags & MY_REDEL_MAKE_BACKUP)\n  {\n    char name_buff[FN_REFLEN + MY_BACKUP_NAME_EXTRA_LENGTH];    \n    my_create_backup_name(name_buff, org_name, backup_time_stamp);\n    if (my_rename(org_name, name_buff, MyFlags))\n      goto end;\n  }\n  else if (my_delete(org_name, MyFlags))\n      goto end;\n  if (my_rename(tmp_name,org_name,MyFlags))\n    goto end;\n\n  error=0;\nend:\n  DBUG_RETURN(error);\n

In [80]:
# Draw 100 rows with vul == 0 and 100 rows with vul == 1 (each with its own
# fixed seed), stack them, then shuffle the combined rows with a third seed.
non_vulnerable_rows = small_test_set_df.loc[small_test_set_df['vul'].eq(0)]
vulnerable_rows = small_test_set_df.loc[small_test_set_df['vul'].eq(1)]

sampled_halves = [
    non_vulnerable_rows.sample(n=100, random_state=2020),
    vulnerable_rows.sample(n=100, random_state=2021),
]
balanced_small_test_set = pd.concat(sampled_halves).sample(frac=1, random_state=2022)

# Cell output: the shuffled, balanced records.
balanced_small_test_set.to_dict('records')

[{'cve': 'CVE-2014-0064',
  'cwe_list': ['CWE-189'],
  'repo_name': 'postgres',
  'commit_hash': '31400a673325147e1205326008e32135a78b4d8a',
  'git_url': 'https://github.com/postgres/postgres/commit/31400a673325147e1205326008e32135a78b4d8a',
  'func': "static void\nfindoprnd(ITEM *ptr, int32 *pos)\n{\n\tif (ptr[*pos].type == VAL || ptr[*pos].type == VALTRUE)\n\t{\n\t\tptr[*pos].left = 0;\n\t\t(*pos)++;\n\t}\n\telse if (ptr[*pos].val == (int32) '!')\n\t{\n\t\tptr[*pos].left = 1;\n\t\t(*pos)++;\n\t\tfindoprnd(ptr, pos);\n\t}\n\telse\n\t{\n\t\tITEM\t   *curitem = &ptr[*pos];\n\t\tint32\t\ttmp = *pos;\n\n\t\t(*pos)++;\n\t\tfindoprnd(ptr, pos);\n\t\tcuritem->left = *pos - tmp;\n\t\tfindoprnd(ptr, pos);\n\t}\n}",
  'graph_idx': 64081,
  'vul': 1,
  'id': 36335},
 {'cve': 'CVE-2019-19536',
  'cwe_list': ['CWE-909'],
  'repo_name': 'torvalds/linux',
  'commit_hash': 'ead16e53c2f0ed946d82d4037c630e2f60f4ab69',
  'git_url': 'https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/comm

In [85]:


# Decoding settings shared by every prompt template.
#
# The previous version of this cell also passed temperature/top_p/top_k while
# leaving do_sample unset. Those values only take effect when sampling is
# enabled, so generation was already greedy; listing them was misleading and
# triggers a validation warning in newer transformers releases. Greedy decoding
# is now stated explicitly, which also keeps the evaluation deterministic.
# To sample instead, set do_sample=True and add temperature/top_p/top_k here.
generation_config = GenerationConfig(
    max_new_tokens=512,  # upper bound on generated tokens per test function
    do_sample=False,     # greedy decoding
)

# Zero-shot prompt variants to compare. Each entry is a string.Template-style
# string whose `$func` placeholder is replaced with the source of the function
# under test. The position of a template in this list is the prompt id used in
# the name of the result file written by the evaluation loop.
# NOTE: leading/trailing whitespace and blank lines inside the triple-quoted
# literals are part of the prompt text that is sent to the model.
prompt_templates = [
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable, I will give you a function, answer `yes` if the function is vulnerable, answer `no` if there is no vulnerability, no other information needs to be output. The first function is

$func""",
    """
I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, answer "yes" if the function is vulnerable, answer "no" if there is no vulnerability, no other information needs to be output.
$func
    """,
    """
Takes a function as input and outputs "yes" if the function is vulnerable and "no" if it is not.
$func
    """,
    """
takes a function as input and outputs "yes" if the function is vulnerable and "no" if it is not. Please note that to determine whether a function is vulnerable, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks. You should pay particular attention to any input or output processing, file handling, and network communication.
$func
    """,
    """As a vulnerability detector, your goal is to determine whether a given function is vulnerable. You will be provided with a function and your task is to identify whether it contains a vulnerability. "return 1" if the function is vulnerable, "return 0" if there is no vulnerability.
$func
    """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks.

$func

Is there any vulnerability in the above code? Answer "yes" or "no"
    """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks.

$func

Is there any vulnerability in the above code?
"return 1" if the function is vulnerable, "return 0" if there is no vulnerability.
    """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks. Answer "yes" if the function is vulnerable, answer "no" if there is no vulnerability, fill answer in #vulnerable

# function
$func

# vulnerable(yes or no)
    """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks. Output "yes" if the function is vulnerable, output "no" if there is no vulnerability.

Input:
$func

Output:
    """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks. Output "yes" if the function is vulnerable, output "no" if there is no vulnerability.

$func

The answer (Yes or No) is
    """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks.
1. If the function is vulnerable, return Yes.
2. If the function is not vulnerable, return No.

$func
    """,
    """As a vulnerability detector, your goal is to determine whether a given function is vulnerable. You will be provided with a function and your task is to identify whether it contains a vulnerability. Return "yes" if the function is vulnerable, or "no" if there is no vulnerability.

$func
Return """,
    """I want you to act as a vulnerability detector, your objective is to detect if a function is vulnerable. I will give you a function, you should analyze its code for potential security vulnerabilities, such as input validation issues, buffer overflow vulnerabilities, SQL injection vulnerabilities, and other security risks. Write a python function named "vulnerable" if the function is vulnerable, or "non_vulnerable" if there is no vulnerability.

$func
    """
]

def prompt_postprocessing(model_name: str, prompt: str) -> str:
    """Append the model-specific response marker to a filled-in prompt.

    Args:
        model_name: Model identifier; only "WizardLM/WizardLM-7B-V1.0" is
            currently supported.
        prompt: Prompt text with the function source already substituted in.

    Returns:
        ``prompt`` followed by a blank line and a ``### Response:`` line
        (including its trailing newline).

    Raises:
        AssertionError: If no response format is defined for ``model_name``.
    """
    if model_name == "WizardLM/WizardLM-7B-V1.0":
        return f"{prompt}\n\n### Response:\n"

    # Raised explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which would make this function silently return None.
    # The exception type is unchanged; the message now names the model.
    raise AssertionError(
        f"no prompt post-processing defined for model {model_name!r}"
    )


def _banner(label) -> str:
    """Return ``label`` centred in a 20-character field between two runs of 40 asterisks."""
    return '*' * 40 + f'{label:^20}' + '*' * 40


# The output directory and the list of test records do not depend on the
# template, so they are built once instead of once per template.
save_dir = Path(f"result/zero_shot/{model_name.replace('/','-')}")
save_dir.mkdir(parents=True, exist_ok=True)
test_records = balanced_small_test_set.to_dict('records')

# Run every prompt template over the balanced subset and log the raw model
# output, one text file per template (prompt_<id>.txt).
for prompt_id, prompt_template in enumerate(prompt_templates):
    template = Template(prompt_template)

    # `with` guarantees the log is flushed and closed even if generation raises
    # part-way through (e.g. CUDA out-of-memory); explicit UTF-8 so non-ASCII
    # characters in the functions or generations cannot fail on other locales.
    with open(save_dir / f'prompt_{prompt_id}.txt', mode='w', encoding='utf-8') as f:
        f.write(_banner("Prompt Template") + '\n')
        f.write(prompt_template)
        f.write('\n\n')

        for test_id, test_set_item in enumerate(tqdm(test_records)):
            prompt = template.substitute(func=test_set_item['func'])
            prompt = prompt_postprocessing(model_name, prompt)

            # Put the inputs on the device the model reports rather than a
            # hard-coded 'cuda'; the model is placed with device_map='auto'.
            inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
            output_ids = model.generate(
                **inputs,
                generation_config=generation_config,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
            )
            # NOTE(review): the whole returned sequence is decoded, so the logged
            # text presumably contains the prompt followed by the completion --
            # confirm before parsing these files.
            decode_result = tokenizer.decode(output_ids[0], skip_special_tokens=True)

            f.write('\n' + _banner(test_id) + '\n')
            f.write(prompt)
            f.write('\n' + _banner("Generate") + '\n')
            f.write(decode_result)
            f.write('\n' + _banner("END") + '\n' + '*' * 100)
            f.write('\n' * 5)


# prompts = ["Write a python function sum two number", "write a multiprocess function"]
#
# for i in range(100):
#     x = tokenizer(prompts[i % 2],return_tensors='pt').to('cuda')
#     y = model.generate(**x,generation_config=generation_config,eos_token_id = tokenizer.eos_token_id,bos_token_id=tokenizer.bos_token_id )
#     print(tokenizer.batch_decode(y, skip_special_tokens=True))



100%|██████████| 200/200 [20:18<00:00,  6.09s/it]
100%|██████████| 200/200 [11:56<00:00,  3.58s/it]
100%|██████████| 200/200 [15:23<00:00,  4.62s/it]
100%|██████████| 200/200 [13:29<00:00,  4.05s/it]
100%|██████████| 200/200 [17:34<00:00,  5.27s/it]
100%|██████████| 200/200 [13:47<00:00,  4.14s/it]
100%|██████████| 200/200 [16:24<00:00,  4.92s/it]
100%|██████████| 200/200 [22:29<00:00,  6.75s/it]
100%|██████████| 200/200 [20:42<00:00,  6.21s/it]
100%|██████████| 200/200 [18:01<00:00,  5.41s/it]
100%|██████████| 200/200 [13:47<00:00,  4.14s/it]
100%|██████████| 200/200 [17:46<00:00,  5.33s/it]
100%|██████████| 200/200 [21:05<00:00,  6.33s/it]


In [49]:

# Scratch check: `^20` centres the value in a 20-character wide field.
x = 20
print('{:^20}'.format(x))

         20         
