In [1]:
### Parse general AST using Tree Sitter

In [None]:
import json
import os
import re
import shutil

import matplotlib.pyplot as plt
import networkx as nx
from tqdm.notebook import tqdm

from process_utils import dfs_graph

In [None]:
from tree_sitter import Language, Parser
from networkx.readwrite import json_graph

def init_parser(language):
    """Compile the tree-sitter grammar for ``language`` and return a parser for it.

    The grammar sources are read from ``../../tree_sitter/tree-sitter-<language>``
    and compiled into ``tree_sitter_build/<language>.so``. That library is then
    loaded and attached to a fresh ``Parser``.
    """
    library_path = f"tree_sitter_build/{language}.so"
    grammar_sources = [f"../../tree_sitter/tree-sitter-{language}"]
    Language.build_library(library_path, grammar_sources)

    grammar = Language(library_path, language)
    parser = Parser()
    parser.set_language(grammar)
    return parser


def clean_code(code):
    """Undo the DCNL/DCSP escaping of a one-line code snippet.

    ``DCNL`` stands for a newline and each ``DCSP`` for one indentation step
    (rendered as a tab). Tokens are separated by spaces in the escaped form.

    Each token is replaced together with any spaces around it. Adjacent tokens
    that share a single separating space (e.g. ``" DCNL DCSP DCSP x"``) are
    therefore all converted (``"\\n\\t\\tx"``); fixed-string replacement would
    consume the shared space and leave a literal ``DCSP`` behind.
    """
    # \b keeps identifiers that merely contain the marker (e.g. MY_DCNL_X) intact.
    code = re.sub(r" *\bDCNL\b *", "\n", code)
    return re.sub(r" *\bDCSP\b *", "\t", code)

def read_file(file_name):
    """Return the lines of a text file as ``readlines`` does (line endings kept).

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default locale encoding.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        return f.readlines()
    
def get_root(graph):
    """Return the first node, in ``graph.nodes()`` order, that has no incoming edge.

    Returns None when every node has a predecessor (or the graph is empty).
    """
    parentless = (node for node in graph.nodes() if graph.in_degree(node) == 0)
    return next(parentless, None)

        
def generate_pairs(origin_code_file, tp, language='python'):
    """Parse every snippet of a code file into a compact JSON AST string.

    Args:
        origin_code_file: text file with one escaped snippet per line. Each line
            is unescaped with ``clean_code`` before parsing.
        tp: split name (e.g. ``"test"``). It is used only as the progress-bar label.
        language: tree-sitter grammar to load via ``init_parser``. It is also
            passed on to ``dfs_graph``. Defaults to ``'python'``.

    Returns:
        A list with one JSON string per input line; the full set is used and no
        line is skipped. Each string is the ``G2json`` flattening of the graph
        that ``dfs_graph`` builds, rooted at the node ``get_root`` returns.
    """
    origin_code_list = read_file(origin_code_file)
    parser = init_parser(language)
    ast_list = []
    # Iterate the list itself (not an enumerate object) so tqdm knows the total.
    for origin_code in tqdm(origin_code_list, desc=tp):
        origin_code = clean_code(origin_code)
        data_lines = origin_code.splitlines()
        tree = parser.parse(bytes(origin_code, 'utf-8'))

        G = nx.DiGraph()
        node_lst = []
        # Positional arguments follow dfs_graph's signature in process_utils
        # (not visible here); the two zeros are passed through unchanged.
        dfs_graph(origin_code, data_lines, tree.walk().node, G, 0, node_lst, 0, language)
        root = get_root(G)
        json_tree = G2json(G, [], root)
        ast_list.append(json.dumps(json_tree, separators=(",", ":"), ensure_ascii=False))
    return ast_list

# Prune the AST: keep only identifier ('idt') nodes and their single-parent ancestor chains
def prune_ast(G):
    """Remove every node that is not on an identifier's upward single-parent chain.

    For each node whose label's first ``:``-separated field is ``idt``, walk
    upward while the current node has at most one predecessor. Every node
    visited on the way is kept: the identifier itself, then its ancestors. The
    walk ends at a parentless node (kept) or at a node with several
    predecessors (not kept). All other nodes are removed from ``G`` in place,
    and ``G`` is returned.
    """
    identifiers = [node for node in G.nodes() if node.split(':')[0] == 'idt']

    survivors = set()
    for ident in identifiers:
        current = ident
        while True:
            parents = list(G.predecessors(current))
            if len(parents) > 1:
                break
            survivors.add(current)
            if not parents:
                break
            current = parents[0]

    doomed = set(G.nodes) - survivors
    G.remove_nodes_from(list(doomed))
    return G

# if the child's name is same as parent, remove all below the child
def prune_ast_remove_redundant(G):
    """Drop identifier nodes that duplicate their parent, together with their subtrees.

    A node ``idt:<name>:...`` whose single predecessor is ``nont:<name>:...``
    (same ``<name>`` field) is redundant. It and every node reachable below it
    are removed from ``G`` in place, and ``G`` is returned. Identifier nodes
    that are not redundant are left untouched, children included.
    """
    remove_node = set()
    for e in G.nodes():
        parts = e.split(':')
        if parts[0] != 'idt':
            continue
        parents = list(G.predecessors(e))
        if len(parents) != 1:
            continue
        parent_parts = parents[0].split(':')
        if parent_parts[0] == 'nont' and parent_parts[1] == parts[1]:
            remove_node.add(e)
            # Collect the whole subtree, not just direct children. Otherwise
            # grandchildren are left as orphaned in-degree-0 nodes that
            # get_root could later mistake for the tree root.
            stack = list(G.successors(e))
            while stack:
                child = stack.pop()
                if child not in remove_node:
                    remove_node.add(child)
                    stack.extend(G.successors(child))
    G.remove_nodes_from(list(remove_node))
    return G

# if the child's name is same as parent, remove just the child
def prune_ast_remove_redundant_strict(G):
    """Splice out identifier nodes that duplicate their parent, keeping their children.

    A node ``idt:<name>:...`` whose single predecessor is ``nont:<name>:...``
    (same ``<name>`` field) is redundant. It alone is removed, and each of its
    children is re-attached to that parent. The new edges are appended after
    the parent's existing successors. ``G`` is modified in place and returned.
    """
    remove_node = set()
    reattach = []
    for e in G.nodes():
        parts = e.split(':')
        if parts[0] != 'idt':
            continue
        parents = list(G.predecessors(e))
        if len(parents) != 1:
            continue
        parent_node = parents[0]
        parent_parts = parent_node.split(':')
        if parent_parts[0] == 'nont' and parent_parts[1] == parts[1]:
            remove_node.add(e)
            # Re-attach every child (not only the first) so none is orphaned.
            reattach.extend((parent_node, child) for child in G.successors(e))

    # Edges are added after the scan so the loop above always sees the
    # original tree rather than a partially rewired one.
    for parent_node, child_node in reattach:
        G.add_edge(parent_node, child_node)
    G.remove_nodes_from(list(remove_node))
    return G

def draw_ast(root_first_seq, color_nodes=None, figsize=(60, 60)):
    """Draw a tree of nodes with graphviz's ``dot`` layout and show it.

    Args:
        root_first_seq: iterable of node objects exposing ``label``, ``num`` and
            ``parent``. ``parent`` is another such node, or None at the root.
            Each node is drawn as ``"<label>:<num>"``.
        color_nodes: collection of ``num`` values to fill black; every other
            node is white. Defaults to highlighting nothing.
        figsize: size in inches of the matplotlib figure (figure number 1).
    """
    if color_nodes is None:
        color_nodes = ()

    def node_key(node):
        return node.label + ':' + str(node.num)

    G = nx.DiGraph()
    fill = {}
    for node in root_first_seq:
        key = node_key(node)
        G.add_node(key)
        fill[key] = 'black' if node.num in color_nodes else 'white'
        if node.parent is not None:
            G.add_edge(node_key(node.parent), key)

    # nx.draw pairs colours with G.nodes order. A parent first added via
    # add_edge, or a repeated node, would otherwise shift a sequence-ordered list.
    node_colors = [fill.get(key, 'white') for key in G.nodes]
    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog='dot')
    plt.figure(1, figsize)
    nx.draw(G, pos, node_color=node_colors, with_labels=True)
    plt.show()
    
def G2json(graph, json_tree, root_node, parent=None, new_idx=0):
    """Flatten the tree below ``root_node`` into a pre-order list of dicts.

    Each dict has these keys, in this order:

    - ``'children'``: labels of its successors, in ``graph.successors`` order.
    - ``'label'``: the node itself.
    - ``'parent'``: the parent's label; present on every node except a
      parentless root.

    If ``parent`` (a dict of the same shape) is given, the root is attached to
    it and its label is appended to ``parent['children']``.

    ``json_tree`` and ``new_idx`` are ignored; they are accepted only so
    existing call sites keep working.

    The walk is iterative, so deeply nested ASTs (e.g. long chains of binary
    operators) do not hit Python's recursion limit. It assumes the graph
    reachable from ``root_node`` is acyclic.
    """
    result = []
    # Each stack entry is (node label, parent dict or None).
    stack = [(root_node, parent)]
    while stack:
        label, parent_json = stack.pop()
        node = {'children': [], 'label': label}
        if parent_json is not None:
            node['parent'] = parent_json['label']
            parent_json['children'].append(label)
        result.append(node)
        # Push in reverse so successors are popped, and listed, in their
        # original order.
        kids = list(graph.successors(label))
        stack.extend((kid, node) for kid in reversed(kids))
    return result

# Pruning can leave gaps in the node indexes, so this function renumbers the labels consecutively
def reorder_label(json_tree):
    """Renumber node labels to 1..N in list order, keeping references consistent.

    Labels have the form ``<prefix>:<index>``. The trailing index of the i-th
    entry (0-based) becomes ``i + 1``. The ``children`` and ``parent``
    references are rewritten through the same old-index -> new-index map, so
    they keep pointing at the same entries. Entries are modified in place, and
    the same list is returned.
    """
    idx2idx = {}
    # Pass 1: assign the new index of every node.
    for new_idx, e in enumerate(json_tree, start=1):
        prefix, _, old_idx = e['label'].rpartition(':')
        idx2idx[old_idx] = str(new_idx)
        e['label'] = f'{prefix}:{new_idx}'

    def relabel(label):
        prefix, _, old_idx = label.rpartition(':')
        return f'{prefix}:{idx2idx[old_idx]}'

    # Pass 2: rewrite references now that every new index is known.
    for e in json_tree:
        e['children'] = [relabel(child) for child in e['children']]
        parent = e.get('parent')
        # A root attached to a parent outside this list keeps its label as is.
        if parent is not None and parent.rpartition(':')[2] in idx2idx:
            e['parent'] = relabel(parent)
    return json_tree


def save_file(data, file_name):
    """Write each string in ``data`` to ``file_name`` as one line, UTF-8 encoded.

    A trailing newline is appended to any item that lacks one. The encoding is
    explicit because callers write JSON produced with ``ensure_ascii=False``,
    which may contain non-ASCII characters that a non-UTF-8 locale cannot encode.
    """
    with open(file_name, "w", encoding="utf-8") as f:
        f.writelines(d if d.endswith("\n") else d + "\n" for d in data)

In [None]:
work_dir = "./py"
out_dir = "./tree_sitter_python"
data_sets = ["test", "dev", "train"]

# Parse each split's code file and write one serialized AST per line.
# NOTE(review): paths are built by plain concatenation, e.g. "./py" + "test"
# -> "./pytest/code.original"; confirm this matches the on-disk layout.
for data_set in data_sets:
    code_file = work_dir + data_set + "/code.original"
    ast_list = generate_pairs(code_file, data_set)
    split_out_dir = out_dir + data_set
    # Create the output directory on a fresh checkout instead of failing in save_file.
    os.makedirs(split_out_dir, exist_ok=True)
    save_file(ast_list, split_out_dir + "/ast.original")
    


In [None]:
work_dir = "./py"
out_dir = "./tree_sitter_python"
data_sets = ["test", "dev", "train"]

# Copy each split's javadoc.original into the output tree as nl.original.
for data_set in data_sets:
    src = work_dir + data_set + '/javadoc.original'
    dst = out_dir + data_set + '/nl.original'
    # Remove an existing target first (lexists also matches a dangling symlink),
    # so shutil.copy never writes through a link. Skip the removal when the file
    # is absent instead of raising FileNotFoundError on a fresh output directory.
    if os.path.lexists(dst):
        os.remove(dst)
    shutil.copy(src, dst)
    
