In [1]:
import xml.etree.ElementTree as ET
import copy
import string

def isLeaf(node):
    """Return True when iterating ``node`` yields no children, False otherwise."""
    for _ in node:
        # At least one child exists, so this is not a leaf.
        return False
    return True

def hasToken(node):
    """Return True when ``node.text`` is present and not purely whitespace."""
    text = node.text
    return text is not None and text.strip() != ''

def nodeIsVisited(node):
    """Return True when ``node`` carries a ``visited`` attribute."""
    marker = node.get('visited')
    return marker is not None

def visitNode(node):
    """Mark ``node`` as visited by setting its ``visited`` attribute.

    The marker is stored as the string ``'true'``: ElementTree attribute
    values are expected to be strings, and a non-string value (such as the
    bool ``True``) makes the tree impossible to serialize afterwards.
    """
    node.set('visited', 'true')

def getTag(node):
    """Return a ``'<tag>:<n>'`` label for ``node``.

    The srcML namespace prefix is stripped from the tag. The first time a
    node is seen it receives the next counter value for its tag (counting
    from 0 per tag in the module-level ``tag_id`` dict), and that value is
    stored on the node in the ``tgid`` attribute. Later calls on the same
    node return the stored label instead of the current counter, so a node
    always maps to the same label.
    """
    tag = node.tag.replace('{http://www.srcML.org/srcML/src}', '')
    node_id = node.get('tgid')
    if node_id is None:
        next_id = tag_id.get(tag, -1) + 1
        tag_id[tag] = next_id
        # Stored as a string so the attribute stays serializable.
        node_id = str(next_id)
        node.set('tgid', node_id)
    return f'{tag}:{node_id}'

# Highest id handed out so far for each (namespace-stripped) tag name.
tag_id = {}

def isPunctuation(token):
    """Return True when ``token`` is non-empty and made only of ASCII punctuation.

    Each character is checked on its own. A plain
    ``token in string.punctuation`` is a substring test: it accepts ``'()'``
    or ``'<='`` only because they happen to be adjacent in
    ``string.punctuation``, rejects ``'=='`` or ``'&&'``, and accepts the
    empty string.
    """
    return token != '' and all(ch in string.punctuation for ch in token)

def isComment(node):
    """Return True when ``node``'s tag, with the srcML namespace removed, is ``comment``."""
    return node.tag.replace('{http://www.srcML.org/srcML/src}', '') == 'comment'

def dfsPathKeep(node, keepPunctuation = False, keepComment=False):
    """Walk the subtree under ``node`` depth-first and collect its text tokens.

    Each collected entry is a dict with two keys:
      * ``'path'``  - list of ``getTag`` labels from ``node`` down to the
        element holding the token (a snapshot, safe to keep).
      * ``'token'`` - the element's ``text`` with surrounding whitespace
        stripped.

    Args:
        node: root element of the subtree to traverse.
        keepPunctuation: when False, tokens for which ``isPunctuation`` is
            true are left out.
        keepComment: when False, tokens of elements for which ``isComment``
            is true are left out.

    Returns:
        The list of entries in visiting order.

    Side effects:
        Every traversed element is marked through ``visitNode`` and labelled
        through ``getTag``. Elements that are already marked are skipped, so
        traversing the same tree a second time yields no tokens.
    """
    path_stack = []
    stack = [node]
    visit_list = []
    # One counter per ancestor on the current path that still has pending
    # children: how many of its pushed children have not been completed yet.
    visited_child_count_list = []

    while len(stack) > 0:
        currNode = stack.pop()
        if not nodeIsVisited(currNode):
            tag = getTag(currNode)
            path_stack.append(tag)
            if hasToken(currNode):
                token = currNode.text.strip()
                # Decide before recording the token. Appending first and
                # removing afterwards costs a linear scan per dropped token and
                # fails with ValueError when both filters match the same token.
                drop_comment = isComment(currNode) and not keepComment
                drop_punctuation = isPunctuation(token) and not keepPunctuation
                if not (drop_comment or drop_punctuation):
                    visit_list.append({
                        # Copy: path_stack keeps changing as the walk goes on.
                        'path': list(path_stack),
                        'token': token
                    })

            visitNode(currNode)
            visited_child_count = 0

            # Push children in reverse so they are popped in document order.
            for child in reversed(list(currNode)):
                if not nodeIsVisited(child):
                    stack.append(child)
                    visited_child_count += 1

            if visited_child_count == 0:
                # Nothing was pushed for this node: it is complete, so drop it
                # from the path and from its parent's pending count.
                if len(visited_child_count_list) > 0:
                    visited_child_count_list[-1] -= 1
                    path_stack.pop()
                # Unwind every ancestor whose last pending child just finished.
                while len(visited_child_count_list) > 0 and visited_child_count_list[-1] == 0:
                    visited_child_count_list.pop()
                    if len(visited_child_count_list) > 0:
                        visited_child_count_list[-1] -= 1
                        path_stack.pop()
            else:
                visited_child_count_list.append(visited_child_count)

    return visit_list

def parseXML(xml_path, *args, **kwargs):
    """Parse the XML file at ``xml_path`` and collect tokens from its root element.

    Extra positional and keyword arguments are forwarded to ``dfsPathKeep``
    (for example ``keepPunctuation`` / ``keepComment``).

    Returns:
        The list produced by ``dfsPathKeep`` for the document root. It used
        to be discarded, so every call returned None.
    """
    root = ET.parse(xml_path).getroot()
    return dfsPathKeep(root, *args, **kwargs)
    
    
# parseXML('srcxml/juliet.java.0.test.xml')
# parseXML('srcxml/helloworld.java.xml')
# parseXML('srcxml/helloworld.java.xml', keepComment=True)
# parseXML('srcxml/helloworld.java.xml', keepPunctuation=True)
# print("-----")
# parseXML('srcxml/juliet.java.0.test.xml', keepPunctuation=True)

In [2]:
def printMethod(xml_path, *args, **kwargs):
    """Print every descendant srcML ``function`` element of the file at ``xml_path``.

    Extra positional and keyword arguments are accepted but not used.
    """
    function_query = './/{http://www.srcML.org/srcML/src}function'
    document_root = ET.parse(xml_path).getroot()
    for function_element in document_root.iterfind(function_query):
        print(function_element)
        
# Print the <function> elements found in srcxml/helloworld.java.xml.
printMethod('srcxml/helloworld.java.xml')

<Element '{http://www.srcML.org/srcML/src}function' at 0x7fca700d5360>
<Element '{http://www.srcML.org/srcML/src}function' at 0x7fca700d5c70>


In [3]:
def token_generalize(t):
    """Normalise a token: delete ASCII punctuation and whitespace, then lower-case it."""
    unwanted_chars = string.punctuation + string.whitespace
    # Mapping a code point to None makes str.translate delete that character.
    deletion_table = dict.fromkeys(map(ord, unwanted_chars))
    return t.translate(deletion_table).lower()

# token_id = 0
# token_id_map = {}
# token_frequecy_map = {}

# dep_id = 0
# dep_id_map = {}
# dep_frequecy_map = {}

def extract_dep_path(srcT, srcP, desT, desP, rs_obj, max_dep_len = None):
    """Build the path string that leads from one token to another.

    ``srcP`` and ``desP`` are lists of ``'<tag>:<id>'`` labels. The labels
    present in both lists form the shared part; the last shared label is the
    turning point. The resulting string is::

        <src-only tags, deepest first, each followed by '↑'>
        -<turning point tag>-
        <des-only tags, in order, each followed by '↓'>

    with the ``:<id>`` suffix removed from every label.

    Args:
        srcT: source token, returned unchanged.
        srcP: label path of the source token.
        desT: destination token, returned unchanged.
        desP: label path of the destination token.
        rs_obj: dict holding ``'dep_id'``, ``'dep_id_map'`` and
            ``'dep_frequecy_map'``; updated in place when the path is within
            ``max_dep_len``.
        max_dep_len: largest number of path segments to record, or None for
            no limit. Longer paths are still returned but not recorded.

    Returns:
        ``(dep_len, srcT, desT, path_string)`` where ``dep_len`` is the number
        of segments, the turning point included.

    Raises:
        IndexError: when the two paths share no label.
    """
    # Sets keep the membership tests cheap; list order still comes from the
    # original paths, so the result is unchanged.
    des_labels = set(desP)
    shared = [p for p in srcP if p in des_labels]
    turning_point = shared[-1]
    shared_labels = set(shared)

    up_part = [p for p in srcP if p not in shared_labels]
    down_part = [p for p in desP if p not in shared_labels]

    # The source side is climbed from the token towards the turning point.
    up_part.reverse()

    # Strip the ids and add the direction markers.
    segments = [f'{p.split(":")[0]}↑' for p in up_part]
    segments.append('-' + turning_point.split(":")[0] + '-')
    segments.extend(f'{p.split(":")[0]}↓' for p in down_part)

    dep_len = len(segments)
    t1_2_t2_path = ''.join(segments)

    if max_dep_len is None or dep_len <= max_dep_len:
        if t1_2_t2_path not in rs_obj['dep_id_map']:
            rs_obj['dep_id_map'][t1_2_t2_path] = rs_obj['dep_id']
            rs_obj['dep_id'] += 1
        frequency_map = rs_obj['dep_frequecy_map']
        frequency_map[t1_2_t2_path] = frequency_map.get(t1_2_t2_path, 0) + 1

    return (dep_len, srcT, desT, t1_2_t2_path)
    

def parseMethodXML(xml_path, *args, rs_obj, window_size = None, max_dep_len = None, **kwargs):
    """Collect token and dependency-path statistics for every function in a srcML file.

    For each descendant ``function`` element of the document, the tokens are
    gathered with ``dfsPathKeep`` and normalised with ``token_generalize``.
    Tokens that normalise to the empty string are ignored. Every remaining
    token is paired with its neighbours in the token list, and
    ``extract_dep_path`` builds the path between the two.

    Args:
        xml_path: path of the srcML XML file.
        *args, **kwargs: forwarded to ``dfsPathKeep``.
        rs_obj: result dict updated in place. Keys used here:
            ``'token_id'``, ``'token_id_map'``, ``'token_frequecy_map'`` and
            ``'sentence_triples'``; ``extract_dep_path`` additionally updates
            ``'dep_id'``, ``'dep_id_map'`` and ``'dep_frequecy_map'``.
        window_size: pair each token with at most this many tokens before it
            and this many tokens after it; None pairs it with every other
            token of the function.
        max_dep_len: paths with more segments than this are not kept; None
            keeps all of them.

    Returns:
        ``rs_obj``, with one list of ``(src_token, des_token, path)`` triples
        appended to ``rs_obj['sentence_triples']`` per function element.

    NOTE(review): ``dfsPathKeep`` skips elements that are already marked as
    visited, so a ``function`` element nested inside a function processed
    earlier yields no tokens and contributes an empty triple list.
    """
    # Tag ids restart for every file.
    tag_id.clear()
    root = ET.parse(xml_path).getroot()

    token_id_map = rs_obj['token_id_map']
    token_frequency_map = rs_obj['token_frequecy_map']

    for function in root.findall('.//{http://www.srcML.org/srcML/src}function'):
        visit_list = dfsPathKeep(function, *args, **kwargs)

        # Normalise each token once instead of once per pair. Paths are only
        # read by extract_dep_path, so they can be shared without copying.
        tokens = [token_generalize(v['token']) for v in visit_list]
        paths = [v['path'] for v in visit_list]

        for token in tokens:
            if token == '':
                continue
            if token not in token_id_map:
                token_id_map[token] = rs_obj['token_id']
                rs_obj['token_id'] += 1
            token_frequency_map[token] = token_frequency_map.get(token, 0) + 1

        dep_triples = []
        token_count = len(visit_list)

        for i in range(token_count):
            src_token = tokens[i]
            if src_token == '':
                continue

            if window_size is None:
                first, stop = 0, token_count
            else:
                # Same number of neighbours on both sides. The lower bound
                # used to be i - 1 - window_size, which took one extra
                # preceding token.
                first = max(0, i - window_size)
                stop = min(token_count, i + 1 + window_size)

            # Preceding neighbours first, then following ones, both in list order.
            for k in range(first, stop):
                if k == i or tokens[k] == '':
                    continue

                dep_len, srcT, desT, t1_2_t2_path = extract_dep_path(
                    src_token, paths[i], tokens[k], paths[k], rs_obj, max_dep_len=max_dep_len
                )
                if max_dep_len is None or dep_len <= max_dep_len:
                    dep_triples.append((srcT, desT, t1_2_t2_path))

        rs_obj['sentence_triples'].append(dep_triples)
    return rs_obj
            

def printData(rs_obj):
    """Print the four maps stored in ``rs_obj``, one ``key value`` pair per line.

    Each section is preceded by two blank lines and a title; a dashed rule is
    printed before the first title.
    """
    sections = (
        ('Token ID Map', 'token_id_map'),
        ('Token Fec Map', 'token_frequecy_map'),
        ('Dep ID Map', 'dep_id_map'),
        ('Dep Fec Map', 'dep_frequecy_map'),
    )
    for position, (title, map_key) in enumerate(sections):
        print()
        print()
        if position == 0:
            print('----------')
        print(title)
        for entry, value in rs_obj[map_key].items():
            print(entry, value)
        
# parseMethodXML('srcxml/helloworld.java.xml')
# Run parseMethodXML on one file with a freshly initialised result dict
# (counters at 0, empty maps, no triples yet), then print the collected maps.
rs_obj = parseMethodXML('srcxml/juliet/juliet.java.76.xml', rs_obj = dict(
    token_id = 0,
    token_id_map = {},
    token_frequecy_map = {},
    dep_id = 0,
    dep_id_map = {},
    dep_frequecy_map = {},
    sentence_triples = []
))

printData(rs_obj)




----------
Token ID Map
public 0
void 1
action 2
string 3
data 4
httpservletrequest 5
request 6
httpservletresponse 7
response 8
throws 9
throwable 10
if 11
null 12
uri 13
try 14
new 15
catch 16
urisyntaxexception 17
excepturisyntax 18
getwriter 19
write 20
invalidredirecturl 21
return 22
sendredirect 23


Token Fec Map
public 1
void 1
action 1
string 1
data 4
httpservletrequest 1
request 1
httpservletresponse 1
response 3
throws 1
throwable 1
if 1
null 1
uri 4
try 1
new 1
catch 1
urisyntaxexception 1
excepturisyntax 1
getwriter 1
write 1
invalidredirecturl 1
return 2
sendredirect 1


Dep ID Map
specifier↑-type-name↓ 0
specifier↑type↑-function-name↓ 1
specifier↑type↑-function-parameter_list↓parameter↓decl↓type↓name↓ 2
specifier↑type↑-function-parameter_list↓parameter↓decl↓name↓ 3
specifier↑type↑-function-throws↓ 4
specifier↑type↑-function-throws↓argument↓expr↓name↓ 5
specifier↑type↑-function-block↓block_content↓if_stmt↓if↓ 6
specifier↑type↑-function-block↓block_content↓if_stmt↓if↓con

In [4]:
import os
from tqdm import tqdm

def getWordGCNDataForDataSet(data_set_name, *args, **kwargs):
    """Run ``parseMethodXML`` over every ``.xml`` file of one data set.

    The files are read from ``srcxml/<data_set_name>`` and all of them update
    the same result dict.

    Args:
        data_set_name: name of the directory under ``srcxml``.
        *args, **kwargs: forwarded to ``parseMethodXML`` (for example
            ``window_size`` and ``max_dep_len``).

    Returns:
        The result dict with the keys ``token_id``, ``token_id_map``,
        ``token_frequecy_map``, ``dep_id``, ``dep_id_map``,
        ``dep_frequecy_map`` and ``sentence_triples``.
    """
    data_set_xml_root_path = os.path.join('srcxml', data_set_name)

    rs_obj = dict(
        token_id = 0,
        token_id_map = {},
        token_frequecy_map = {},
        dep_id = 0,
        dep_id_map = {},
        dep_frequecy_map = {},
        sentence_triples = []
    )

    # os.listdir returns names in arbitrary order and ids are handed out in
    # first-seen order, so sort to make the assigned ids reproducible.
    xml_names = sorted(
        name for name in os.listdir(data_set_xml_root_path) if name.endswith('.xml')
    )
    for xml_name in tqdm(xml_names):
        xml_path = os.path.join(data_set_xml_root_path, xml_name)
        parseMethodXML(xml_path, *args, rs_obj=rs_obj, **kwargs)

    return rs_obj

In [13]:
def outputData(data_set_name, rs_obj):
    """Write the collected maps and triples to ``srcxml/<data_set_name>_out``.

    Files written (every line ends with ``\\r\\n``):
      * ``voc2id.txt``  - ``<token>\\t<token id>``
      * ``id2feq.txt``  - ``<token id>\\t<token frequency>``
      * ``de2id.txt``   - ``<path>\\t<path id>``
      * ``de2feq.txt``  - ``<path frequency>\\t\\t<path>``
      * ``data.txt``    - one line per entry of ``rs_obj['sentence_triples']``:
        the number of distinct token ids, the number of triples, the token
        ids, then one ``<src>|<des>|<path id>`` item per triple, where
        ``<src>`` and ``<des>`` are positions in that line's token id list.

    Args:
        data_set_name: used to build the output directory name.
        rs_obj: result dict as filled by ``parseMethodXML``.
    """
    out_dir = os.path.join('srcxml', f'{data_set_name}_out')
    # Also creates missing parent directories and tolerates an existing one.
    os.makedirs(out_dir, exist_ok=True)

    def open_output(file_name):
        # UTF-8 because the paths contain the non-ASCII arrows. newline=''
        # writes the explicit '\r\n' verbatim on every platform instead of
        # letting text mode translate the '\n' a second time.
        return open(os.path.join(out_dir, file_name), 'w', encoding='utf-8', newline='')

    with open_output('voc2id.txt') as f:
        for _tk, _id in rs_obj['token_id_map'].items():
            f.write(f'{_tk}\t{_id}\r\n')

    with open_output('id2feq.txt') as f:
        for _tk, _feq in rs_obj['token_frequecy_map'].items():
            _id = rs_obj['token_id_map'][_tk]
            f.write(f'{_id}\t{_feq}\r\n')

    with open_output('de2id.txt') as f:
        for _dep, _id in rs_obj['dep_id_map'].items():
            f.write(f'{_dep}\t{_id}\r\n')

    with open_output('de2feq.txt') as f:
        for _dep, _feq in rs_obj['dep_frequecy_map'].items():
            f.write(f'{_feq}\t\t{_dep}\r\n')

    with open_output('data.txt') as f:
        for sentence_triples in rs_obj['sentence_triples']:
            tk_set = set()
            id_triples = []
            for srcT, desT, dep_path in sentence_triples:
                srcT_id = rs_obj['token_id_map'][srcT]
                desT_id = rs_obj['token_id_map'][desT]
                dep_path_id = rs_obj['dep_id_map'][dep_path]
                tk_set.add(srcT_id)
                tk_set.add(desT_id)
                id_triples.append((srcT_id, desT_id, dep_path_id))

            tk_id_list = list(tk_set)
            # Position lookup built once, instead of list.index per triple.
            position = {tk_id: idx for idx, tk_id in enumerate(tk_id_list)}
            dep_path_list = [
                f"{position[src_id]}|{position[des_id]}|{path_id}"
                for src_id, des_id, path_id in id_triples
            ]

            f.write(f'{len(tk_id_list)} ')
            f.write(f'{len(dep_path_list)} ')
            f.write(' '.join([str(_id) for _id in tk_id_list]))
            f.write(' ')
            f.write(' '.join(dep_path_list))
            f.write('\r\n')

In [14]:
# Process every .xml file under srcxml/juliet with no window or path-length
# limit, then write the result files to srcxml/juliet_new_out.
rs_obj = getWordGCNDataForDataSet('juliet', window_size = None, max_dep_len = None)
outputData('juliet_new', rs_obj)

100%|██████████| 115/115 [00:11<00:00, 10.26it/s]


In [None]:
# Process every .xml file under srcxml/owasp with the default settings, then
# write the result files to srcxml/owasp_out.
rs_obj = getWordGCNDataForDataSet('owasp')
outputData('owasp', rs_obj)