In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM

# Directory of the fine-tuned seq2seq checkpoint; the tokenizer is loaded
# from this same path in a later cell.
model_path = "tf_board/root_cause/"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model = model.to(device)

In [2]:
def clean_generated_str(input):
    """Collapse the model's space-separated decoder output into plain text.

    The decoded string is split on two-space separators; every single space
    inside a chunk is removed, and the resulting words are re-joined with one
    space. This presumably undoes a character-level spacing in the decoded
    text (characters separated by " ", words by "  ") -- TODO confirm against
    the tokenizer's vocabulary.

    Args:
        input: decoded text, e.g. as returned by ``tokenizer.decode``. The name
            shadows the ``input`` builtin but is kept so keyword callers still
            work.

    Returns:
        The cleaned string. Chunks that are empty or contain only spaces are
        dropped, so the result has no leading, trailing, or doubled spaces.
    """
    # Filter *after* removing spaces: a chunk such as " " (left over from three
    # or more consecutive spaces) passes an `!= ''` check but becomes "" once
    # its spaces are stripped, which would leave a stray space after joining.
    words = (chunk.replace(" ", "") for chunk in input.split("  "))
    return " ".join(word for word in words if word)

In [3]:
from transformers import AutoTokenizer
# Load the tokenizer from the same directory as the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

In [4]:
# Full C++ function, stored as a plain string, that is tokenized and explained
# by the model in the following cells.
# NOTE(review): the snippet is not valid C++ -- the last argument of
# newSkImageFromRaster(...) ends with a stray ';' before the closing ')', and
# there is no ';' after that ')'. The recorded output below was generated from
# this exact text, and sub_graph_code later copies the same statement (typo
# included), so it is left unchanged; fix both strings together if the model
# input is meant to be well-formed code.
vul_code = """
static sk_sp <SkImage> unPremulSkImageToPremul(SkImage *input){
    SkImageInfo info = SkImageInfo::Make(input->width(), input->height(),
                                         kN32_SkColorType, kPremul_SkAlphaType);
    RefPtr<Uint8Array> dstPixels = copySkImageData(input, info);
    if (!dstPixels)
        return nullptr;
    return newSkImageFromRaster(
        info, std::move(dstPixels),
        static_cast<size_t>(input->width()) * info.bytesPerPixel();
    )
}
"""

In [5]:
# Tokenize the snippet into PyTorch tensors, then move them to the model's device.
vul_code_vec = tokenizer(vul_code, return_tensors="pt")
vul_code_vec = vul_code_vec.to(device)

In [6]:
# Generate the explanation; max_length caps the length of the generated token sequence.
output = model.generate(max_length=153, **vul_code_vec)

In [7]:
# Decode the first generated sequence back to text, dropping special tokens.
explain = tokenizer.decode(token_ids=output[0], skip_special_tokens=True)

In [8]:
# Cleaned explanation for the full function (shown as this cell's output).
clean_generated_str(input=explain)

'does not consider the kernel and verify an associated with a stream system call is processes'

In [9]:
# Only the final `return newSkImageFromRaster(...)` statement of vul_code,
# presumably the sub-graph suspected of containing the root cause -- TODO
# confirm how this slice was selected.
# NOTE(review): it carries the same stray ';' inside the argument list as
# vul_code, so it is not valid C++ either; kept verbatim because the output
# below was produced from this exact text.
sub_graph_code = """
return newSkImageFromRaster(
        info, std::move(dstPixels),
        static_cast<size_t>(input->width()) * info.bytesPerPixel();
    )
"""

In [10]:
# Same tokenize -> generate -> decode pipeline as for vul_code, applied to the sub-graph snippet.
sub_code_vec = tokenizer(sub_graph_code, return_tensors="pt")
sub_code_vec = sub_code_vec.to(device)
sub_output = model.generate(max_length=153, **sub_code_vec)
sub_explain = tokenizer.decode(token_ids=sub_output[0], skip_special_tokens=True)

In [11]:
# Cleaned explanation for the sub-graph snippet (shown as this cell's output).
clean_generated_str(input=sub_explain)

'copying large amounts of kernel memory to userland'

In [12]:
# The null-check guard from vul_code, used as a control snippet that is
# presumably unrelated to the root cause.
not_relevant_graph_code = (
    "\n"
    "if (!dstPixels)\n"
    "    return nullptr;\n"
)

In [13]:
# Same tokenize -> generate -> decode pipeline, applied to the control snippet.
not_relevant_code_vec = tokenizer(not_relevant_graph_code, return_tensors="pt")
not_relevant_code_vec = not_relevant_code_vec.to(device)
not_relevant_output = model.generate(max_length=153, **not_relevant_code_vec)
not_relevant_explain = tokenizer.decode(
    token_ids=not_relevant_output[0],
    skip_special_tokens=True,
)

In [14]:
# Cleaned explanation for the control snippet (shown as this cell's output).
clean_generated_str(input=not_relevant_explain)

'mishandles reference counts'