In [1]:
import numpy as np
import torch
from torch_geometric.data import Data
from transformers import RobertaTokenizer, RobertaModel

In [84]:
def get_nodes(json_data, function_index=1, graph_types=["AST", "CFG", "PDG"]):
    function = json_data["functions"][function_index]
    nodes = []
    for graph_type in graph_types:
        nodes.extend(function.get(graph_type, []))  # Lấy toàn bộ node từ các loại đồ thị
    return nodes

In [48]:
def get_edges(json_data, graph_types=["AST", "CFG", "PDG"], function_index=1):
    function_data = json_data["functions"][function_index]
    edges = []
    for graph_type in graph_types:
            for node in function_data[graph_type]:
                for edge in node.get("edges", []):
                    edges.append({
                        "in": edge["in"],
                        "out": edge["out"],
                        "type": edge["label"], 
                        "code": edge["code"]
                    })
    return edges

In [85]:
def filter_nodes(nodes):
    valid_nodes = {}
    for node in nodes:
        if "id" in node and "line_number" in node and "code" in node and "type" in node:
            valid_nodes[node["id"]] = node
        else:
            print(f"Node không hợp lệ: {node}")
    return valid_nodes


In [86]:
from collections import OrderedDict

def order_nodes(nodes, max_nodes):
    nodes_by_line = sorted(nodes.items(), key=lambda n: n[1].get("line_number", 0))
    for i, (node_id, node) in enumerate(nodes_by_line):
        node["order"] = i
    if len(nodes_by_line) > max_nodes:
        nodes_by_line = nodes_by_line[:max_nodes]
    return OrderedDict(nodes_by_line)


In [89]:
def parse_to_nodes(json_data, max_nodes=500):
    raw_nodes = get_nodes(json_data, function_index=1)
    filtered_nodes = filter_nodes(raw_nodes)
    return order_nodes(filtered_nodes, max_nodes)

In [112]:
import json
file_path = "data/json/0.json"
with open(file_path, "r") as f:
    json_data = json.load(f)

# Kiểm thử parse_to_nodes
parsed_nodes = parse_to_nodes(json_data, max_nodes=500)
print("Parsed Nodes:")
for node_id, node in parsed_nodes.items():
    print(f"ID: {node_id}, Line: {node['line_number']}, Type: {node['type']}, Code: {node['code']}, Order: {node['order']}")

# Kiểm thử get_edges
edges = get_edges(json_data, function_index=1)
print("\nEdges:")
for edge in edges:
    print(f"In: {edge['in']}, Out: {edge['out']}, Type: {edge['type']}, Code: {edge['code']}")

Parsed Nodes:
ID: 111669149697, Line: 1, Type: METHOD, Code: page_check_range, Order: 0
ID: 115964116992, Line: 1, Type: PARAM, Code: target_ulong start, Order: 1
ID: 115964116993, Line: 1, Type: PARAM, Code: target_ulong len, Order: 2
ID: 115964116994, Line: 1, Type: PARAM, Code: int flags, Order: 3
ID: 128849018880, Line: 1, Type: METHOD_RETURN, Code: int, Order: 4
ID: 25769803777, Line: 2, Type: BLOCK,{
    PageDesc *p;
    target_ulong end;
    ta..., Code: {
    PageDesc *p;
    target_ulong end;
    ta..., Order: 5
ID: 94489280512, Line: 3, Type: LOCAL, Code: PageDesc* p: PageDesc*, Order: 6
ID: 94489280513, Line: 4, Type: LOCAL, Code: target_ulong end: target_ulong, Order: 7
ID: 94489280514, Line: 5, Type: LOCAL, Code: target_ulong addr: target_ulong, Order: 8
ID: 47244640256, Line: 6, Type: CONTROL_STRUCTURE IF, Code: if (start + len < start), Order: 9
ID: 30064771072, Line: 6, Type: <operator>.lessThan, Code: start + len < start, Order: 10
ID: 30064771073, Line: 6, Type: <oper

In [93]:
import re
import codecs
from typing import List

keywords = frozenset({'__asm', '__builtin', '__cdecl', '__declspec', '__except', '__export', '__far16', '__far32',
                      '__fastcall', '__finally', '__import', '__inline', '__int16', '__int32', '__int64', '__int8',
                      '__leave', '__optlink', '__packed', '__pascal', '__stdcall', '__system', '__thread', '__try',
                      '__unaligned', '_asm', '_Builtin', '_Cdecl', '_declspec', '_except', '_Export', '_Far16',
                      '_Far32', '_Fastcall', '_finally', '_Import', '_inline', '_int16', '_int32', '_int64',
                      '_int8', '_leave', '_Optlink', '_Packed', '_Pascal', '_stdcall', '_System', '_try', 'alignas',
                      'alignof', 'and', 'and_eq', 'asm', 'auto', 'bitand', 'bitor', 'bool', 'break', 'case',
                      'catch', 'char', 'char16_t', 'char32_t', 'class', 'compl', 'const', 'const_cast', 'constexpr',
                      'continue', 'decltype', 'default', 'delete', 'do', 'double', 'dynamic_cast', 'else', 'enum',
                      'explicit', 'export', 'extern', 'false', 'final', 'float', 'for', 'friend', 'goto', 'if',
                      'inline', 'int', 'long', 'mutable', 'namespace', 'new', 'noexcept', 'not', 'not_eq', 'nullptr',
                      'operator', 'or', 'or_eq', 'override', 'private', 'protected', 'public', 'register',
                      'reinterpret_cast', 'return', 'short', 'signed', 'sizeof', 'static', 'static_assert',
                      'static_cast', 'struct', 'switch', 'template', 'this', 'thread_local', 'throw', 'true', 'try',
                      'typedef', 'typeid', 'typename', 'union', 'unsigned', 'using', 'virtual', 'void', 'volatile',
                      'wchar_t', 'while', 'xor', 'xor_eq', 'NULL'})
main_set = frozenset({'main'})
main_args = frozenset({'argc', 'argv'})

operators3 = {'<<=', '>>='}
operators2 = {
    '->', '++', '--', '**',
    '!~', '<<', '>>', '<=', '>=',
    '==', '!=', '&&', '||', '+=',
    '-=', '*=', '/=', '%=', '&=', '^=', '|='
}
operators1 = {
    '(', ')', '[', ']', '.',
    '+', '&',
    '%', '<', '>', '^', '|',
    '=', ',', '?', ':',
    '{', '}', '!', '~'
}


def to_regex(lst):
    return r'|'.join([f"({re.escape(el)})" for el in lst])


regex_split_operators = to_regex(operators3) + to_regex(operators2) + to_regex(operators1)
def clean_gadget(gadget):
    fun_symbols = {}
    var_symbols = {}

    fun_count = 1
    var_count = 1
    
    rx_fun = re.compile(r'\b([_A-Za-z]\w*)\b(?=\s*\()')
    rx_var = re.compile(r'\b([_A-Za-z]\w*)\b((?!\s*\**\w+))(?!\s*\()')
    cleaned_gadget = []

    for line in gadget:
        ascii_line = re.sub(r'[^\x00-\x7f]', r'', line)
        hex_line = re.sub(r'0[xX][0-9a-fA-F]+', "HEX", ascii_line)
        user_fun = rx_fun.findall(hex_line)
        user_var = rx_var.findall(hex_line)
        for fun_name in user_fun:
            if len({fun_name}.difference(main_set)) != 0 and len({fun_name}.difference(keywords)) != 0:
                if fun_name not in fun_symbols.keys():
                    fun_symbols[fun_name] = 'FUN' + str(fun_count)
                    fun_count += 1
                hex_line = re.sub(r'\b(' + fun_name + r')\b(?=\s*\()', fun_symbols[fun_name], hex_line)

        for var_name in user_var:
            if len({var_name[0]}.difference(keywords)) != 0 and len({var_name[0]}.difference(main_args)) != 0:
                if var_name[0] not in var_symbols.keys():
                    var_symbols[var_name[0]] = 'VAR' + str(var_count)
                    var_count += 1
                hex_line = re.sub(r'\b(' + var_name[0] + r')\b(?:(?=\s*\w+\()|(?!\s*\w+))(?!\s*\()',
                                  var_symbols[var_name[0]], hex_line)

        cleaned_gadget.append(hex_line)
    return cleaned_gadget

def tokenizer(code, flag=False):
    gadget: List[str] = []
    tokenized: List[str] = []
    no_str_lit_line = re.sub(r'["]([^"\\\n]|\\.|\\\n)*["]', '', code)
    no_char_lit_line = re.sub(r"'.*?'", "", no_str_lit_line)
    code = no_char_lit_line

    if flag:
        code = codecs.getdecoder("unicode_escape")(no_char_lit_line)[0]

    for line in code.splitlines():
        if line == '':
            continue
        stripped = line.strip()
        gadget.append(stripped)

    clean = clean_gadget(gadget)

    for cg in clean:
        if cg == '':
            continue

        pat = re.compile(r'(/\*([^*]|(\*+[^*\/]))*\*+\/)|(\/\/.*)')
        cg = re.sub(pat, '', cg)
        cg = re.sub('(\n)|(\\\\n)|(\\\\)|(\\t)|(\\r)', '', cg)
        splitter = r' +|' + regex_split_operators + r'|(\/)|(\;)|(\-)|(\*)'
        cg = re.split(splitter, cg)
        cg = list(filter(None, cg))
        cg = list(filter(str.strip, cg))
        tokenized.extend(cg)

    return tokenized


In [94]:
sample_code = "int main() { int x = 1; if (x > 0) return x; }"
tokens = tokenizer(sample_code)
print("Tokens:", tokens)

Tokens: ['int', 'main', '(', ')', '{', 'int', 'VAR1', '=', '1', ';', 'if', '(', 'VAR1', '>', '0', ')', 'return', 'VAR1', ';', '}']


In [96]:
def encode_input(text, tokenizer):
    max_length = 512
    input = tokenizer(text, max_length=max_length, truncation=True, padding='max_length', return_tensors='pt')
#     print(input.keys())
    return input.input_ids, input.attention_mask

In [97]:
codebert_tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

input_ids, attention_mask = encode_input(tokens, codebert_tokenizer)
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)

Input IDs: tensor([[    0,  2544,     2,  ...,     1,     1,     1],
        [    0, 17894,     2,  ...,     1,     1,     1],
        [    0,  1640,     2,  ...,     1,     1,     1],
        ...,
        [    0,   846,  2747,  ...,     1,     1,     1],
        [    0,   131,     2,  ...,     1,     1,     1],
        [    0, 24303,     2,  ...,     1,     1,     1]])
Attention Mask: tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])


In [110]:
class NodesEmbedding:
    def __init__(self, nodes_dim: int):
        self.tokenizer_bert = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
        self.bert_model = RobertaModel.from_pretrained("microsoft/codebert-base").to("cuda")
        self.nodes_dim = nodes_dim
        assert self.nodes_dim >= 0
        self.target = torch.zeros(self.nodes_dim, 769).float()  # 768 là kích thước của CodeBERT

        # Ánh xạ node type sang ID
        self.type_to_id = {}
        self.next_type_id = 0

    def get_type_id(self, node_type):
        """
        Lấy ID cho loại node. Nếu chưa có, thêm vào từ điển.
        """
        if node_type not in self.type_to_id:
            self.type_to_id[node_type] = self.next_type_id
            self.next_type_id += 1
        return self.type_to_id[node_type]

    def __call__(self, nodes):
        embedded_nodes = self.embed_nodes(nodes)
        nodes_tensor = torch.from_numpy(embedded_nodes).float()
        self.target[:nodes_tensor.size(0), :] = nodes_tensor
        return self.target

    def embed_nodes(self, nodes):
        embeddings = []
        for n_id, node in nodes.items():
            node_code = node.get('code', '')
            if not node_code:
                print(f"Node ID {n_id} has no code. Skipping.")
                continue

            # Tokenize và encode code
            tokenized_code = self.tokenizer_bert.tokenize(node_code)
            inputs = self.tokenizer_bert.encode_plus(
                tokenized_code,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding="max_length"
            )
            input_ids, attention_mask = inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda")

            # Lấy embedding từ CodeBERT
            with torch.no_grad():
                cls_embedding = self.bert_model(input_ids, attention_mask=attention_mask)[0][:, 0].cpu().numpy()

            # Đảm bảo cls_embedding là 1D
            cls_embedding = cls_embedding.squeeze()

            # Lấy ID cho node type
            node_type = node.get('type', 'Unknown').split(',')[0]
            type_id = self.get_type_id(node_type)
            type_embedding = np.array([type_id], dtype=np.float32)

            # Ghép type embedding với code embedding
            embedding = np.concatenate((type_embedding, cls_embedding), axis=0)
            embeddings.append(embedding)

        return np.array(embeddings)


In [116]:
# Kiểm thử NodesEmbedding
nodes_dim = 128
embedding_model = NodesEmbedding(nodes_dim)

# Nhận embedding cho các node
embeddings = embedding_model(parsed_nodes)

print(f"Embedding Shape: {embeddings.shape}")
print(f"Embedding Example:\n{embeddings[:100]}")


Embedding Shape: torch.Size([128, 769])
Embedding Example:
tensor([[ 0.0000, -0.1113,  0.2505,  ..., -0.1603, -0.2753,  0.2782],
        [ 1.0000, -0.1119,  0.3147,  ..., -0.2378, -0.3043,  0.3429],
        [ 1.0000, -0.1403,  0.2893,  ..., -0.2206, -0.2935,  0.3933],
        ...,
        [ 5.0000, -0.0895,  0.1100,  ..., -0.1717, -0.3729,  0.3513],
        [19.0000, -0.1081,  0.1599,  ..., -0.2375, -0.2775,  0.3235],
        [14.0000, -0.1513,  0.2139,  ..., -0.2190, -0.3627,  0.3414]])


In [122]:
class GraphsEmbedding:
    def __init__(self, edge_type):
        self.edge_type = edge_type

    def __call__(self, nodes):
        connections = self.nodes_connectivity(nodes)
        return torch.tensor(connections).long()

    def nodes_connectivity(self, nodes):
        coo = [[], []]
        for node_idx, (node_id, node) in enumerate(nodes.items()):
            if node_idx != node.get('order', -1):
                raise Exception("Order mismatch in node list")

            for edge in node.get("edges", []):
                if edge['label'] != self.edge_type:
                    continue

                # Add edges to connectivity matrix
                if edge['in'] in nodes and edge['in'] != node_id:
                    coo[0].append(nodes[edge['in']]['order'])
                    coo[1].append(node_idx)

                if edge['out'] in nodes and edge['out'] != node_id:
                    coo[0].append(node_idx)
                    coo[1].append(nodes[edge['out']]['order'])

        return coo


In [124]:
edge_type = "CFG"
graph_embedding = GraphsEmbedding(edge_type)

# Tạo ma trận kết nối
connections = graph_embedding(parsed_nodes)

print("Connectivity Matrix (Edge Index):")
print(connections)

Connectivity Matrix (Edge Index):
tensor([[  0, 117,  16,  52,  63,  78, 115,  94, 112,  10,  10,  11,  11,   0,
          16,  17,  10,  17,  19,  21,  21,  22,  10,  22,  25,  27,  19,  27,
          25,  32,  35,  35,  38,  32,  38,  82,  42,  44,  44,  45,  35,  45,
          42,  49,  49,  52,  53,  49,  53,  56,  56,  57,  57,  58,  58,  60,
          49,  60,  63,  64,  56,  64,  67,  67,  68,  71,  56,  68,  68,  71,
          72,  72,  73,  73,  75,  68,  75,  78,  79,  67,  79,  67,  82,  82,
          87,  87,  88,  88,  89,  89,  91,  82,  91,  94,  95,  87,  95,  98,
          98,  99,  99, 100, 100, 102,  87, 102, 106, 106, 107,  98, 107, 112,
         113, 106, 113, 115,  98, 106, 117,  35],
        [ 11,   4,   4,   4,   4,   4,   4,   4,   4,  17,  22,  10,  10,  11,
           4,  16,  17,  16,  27,  19,  19,  21,  22,  21,  32,  25,  27,  25,
          32,  35,  45, 117,  35,  35,  35,  38,  49,  42,  42,  44,  45,  44,
          49,  53,  60,   4,  52,  53,  52,  64

In [None]:
def nodes_to_input(nodes, target, nodes_dim, edge_type):
    nodes_embedding = NodesEmbedding(nodes_dim)
    graphs_embedding = GraphsEmbedding(edge_type)
    label = torch.tensor([target]).float()

    return Data(x=nodes_embedding(nodes), edge_index=graphs_embedding(nodes), y=label)