In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM

# Local checkpoint directory; both the model and its tokenizer are loaded from here.
# NOTE(review): presumably a fine-tuned "impact" explainer (dir name) — confirm.
model_path = "tf_board/impact/"

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the seq2seq weights, then move them onto the chosen device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model = model.to(device)

In [8]:
def clean_generated_str(input):
    """Rejoin character-spaced generated text into normal words.

    The string is split on double spaces (treated as word boundaries); every
    single space inside a chunk is removed, and the resulting words are joined
    with one space. Presumably the decoded model output puts single spaces
    between characters and double spaces between words — confirm against the
    tokenizer.

    Args:
        input: text returned by ``tokenizer.decode``. The parameter name
            shadows the builtin but is kept so keyword callers still work.

    Returns:
        The cleaned string. Chunks that are empty *after* space removal are
        dropped, so odd-length runs of spaces no longer leave doubled or
        trailing spaces in the result.
    """
    # Strip spaces first, then filter: a chunk of only spaces (e.g. from a
    # run of three spaces) must not survive as an empty "word".
    words = (chunk.replace(" ", "") for chunk in input.split("  "))
    return " ".join(word for word in words if word)

In [2]:
from transformers import AutoTokenizer
# Tokenizer saved in the same directory as the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

In [45]:
# Full-function probe: a C snippet that copies `hp->h_name` into the fixed
# 64-byte `hostname` buffer with `strcpy` and never checks whether
# `gethostbyaddr` returned NULL. (The sample also assigns `inet_addr`'s
# `in_addr_t` result to an `in_addr_t *`; in this notebook it is only
# tokenized, never compiled.)
vul_code = """
void host_lookup(char *user_supplied_addr){
struct hostent *hp;
in_addr_t *addr;
char hostname[64];
in_addr_t inet_addr(const char *cp);

/*routine that ensures user_supplied_addr is in the right format for conversion */ 

validate_addr_form(user_supplied_addr);
addr = inet_addr(user_supplied_addr);
hp = gethostbyaddr( addr, sizeof(struct in_addr), AF_INET);
strcpy(hostname, hp->h_name);
}
"""

In [46]:
# Encode the full function as PyTorch tensors on the same device as the model.
vul_code_vec = tokenizer(text=vul_code, return_tensors="pt").to(device)

In [47]:
# Generate the explanation; max_length caps the generated sequence length.
# NOTE(review): 153 is unexplained here — presumably the training target length; confirm.
output = model.generate(max_length=153, **vul_code_vec)

In [48]:
# Decode the first generated sequence to text, dropping special tokens.
explain = tokenizer.decode(token_ids=output[0], skip_special_tokens=True)

In [49]:
# Display the cleaned explanation for the full function.
clean_generated_str(input=explain)

'cause a denial of service (out-of-bounds read )'

In [50]:
# Sub-graph probe: only the `strcpy` statement from `vul_code`, to compare the
# explanation for that isolated line against the one for the full function.
sub_graph_code = """
strcpy(hostname, hp->h_name);
"""

In [51]:
# Same encode -> generate -> decode pipeline, applied to the single statement.
sub_code_vec = tokenizer(text=sub_graph_code, return_tensors="pt").to(device)
sub_output = model.generate(max_length=153, **sub_code_vec)
sub_explain = tokenizer.decode(token_ids=sub_output[0], skip_special_tokens=True)

In [52]:
# Display the cleaned explanation for the isolated strcpy line.
clean_generated_str(input=sub_explain)

'cause a denial of service (out-of-bounds read and application crash )'

In [53]:
# Control probe: only the local declarations from `vul_code` (no calls, no copy).
# NOTE(review): despite the "not relevant" label, the model still reports a
# memory-corruption impact for this snippet in the output below.
not_relevant_graph_code = """
struct hostent *hp;
in_addr_t *addr;
char hostname[64];
"""

In [54]:
# Same encode -> generate -> decode pipeline, applied to the control snippet.
not_relevant_code_vec = tokenizer(text=not_relevant_graph_code, return_tensors="pt").to(device)
not_relevant_output = model.generate(max_length=153, **not_relevant_code_vec)
not_relevant_explain = tokenizer.decode(token_ids=not_relevant_output[0], skip_special_tokens=True)

In [55]:
# Display the cleaned explanation for the declarations-only control.
clean_generated_str(input=not_relevant_explain)

'execute arbitrary code or cause a denial of service (memory corruption )'