In [None]:
from collections import deque

from tree_sitter import Language, Parser
import networkx as nx
import matplotlib.pyplot as plt
from regraph import NXGraph, Rule
from regraph import plot_graph, plot_instance, plot_rule

In [None]:
# Sample Python snippets fed to the parser.
# code_1: a complete Bernoulli-trial function (import, for-loop, if, augmented assignment).
code_1 = """
import numpy as np

def perform_bernoulli_trials(n, p):

    # Initialize number of successes: n_success
    n_success = 0

    # Perform trials
    for i in range(n):
        # Choose random number between zero and one: random_number
        random_number = np.random.random()

        # If less than p, it's a success so add one to n_success
        if random_number < p:
            n_success += 1

    return n_success
"""

# code_2: the same function with the trial loop removed.
code_2 = """
import numpy as np

def perform_bernoulli_trials(n, p):

    # Initialize number of successes: n_success
    n_success = 0

    return n_success
"""

# code_3: a shortened function plus top-level statements (two imports, a print
# call, a call through `np`, and an assignment); this is the snippet parsed into G.
code_3 = """
import numpy as np
import libs as l

def perform_bernoulli_trials(n, p):

    # Initialize number of successes: n_success
    n_success = 0
    return n_success
    
print("Python is great!")

np.ndarray(size=(5,5))

x = 5 * 3
"""

In [None]:
# Load the compiled Python grammar and create a parser that uses it.
# The shared-library path is relative to the notebook's working directory.
# NOTE(review): `Language(path, name)` and `Parser.set_language` belong to the
# older py-tree-sitter API; newer releases changed both — confirm the installed version.
PY_LANGUAGE = Language('build/my-languages.so', 'python')
python_parser = Parser()
python_parser.set_language(PY_LANGUAGE)

In [None]:
def parse_py(code):
    """Parse Python source text and return its syntax tree as an NXGraph.

    The text is UTF-8 encoded and parsed with the module-level
    ``python_parser``; the resulting tree is copied node by node into a new
    ``NXGraph`` by ``bfs_tree_traverser``.
    """
    source_bytes = bytes(code, "utf8")
    syntax_tree = python_parser.parse(source_bytes)
    return bfs_tree_traverser(syntax_tree.root_node, NXGraph())

In [None]:
def bfs_tree_traverser(root_node, G):
    """Copy a tree-sitter syntax tree into ``G`` in breadth-first order.

    Nodes are numbered 0, 1, 2, ... in the order they are discovered, with the
    root as 0. Every node gets ``type`` and ``text`` attributes copied from the
    tree-sitter node. Every non-root node also stores its parent's number as
    ``parent_id``, and an edge parent -> child is added.

    Parameters
    ----------
    root_node : node to start from (e.g. ``tree.root_node``); must expose
        ``type``, ``text`` and ``children``.
    G : graph providing ``add_node(id, attrs=...)`` and ``add_edge(src, dst)``.

    Returns
    -------
    The same ``G``, populated.
    """
    G.add_node(0, attrs={"type": root_node.type, "text": root_node.text})

    # Each queue entry carries its own graph id, so children are always attached
    # to the right parent without a separate counter that must stay in sync.
    # A syntax tree has no cycles, so no visited-set is needed, and deque.popleft
    # is O(1) (list.pop(0) plus a list-based visited check made this O(n^2)).
    queue = deque([(root_node, 0)])
    next_id = 1
    while queue:
        node, node_id = queue.popleft()
        for child_node in node.children:
            child_id = next_id
            next_id += 1
            # add child node to graph, then the edge from its parent
            G.add_node(child_id, attrs={"type": child_node.type, "text": child_node.text, "parent_id": node_id})
            G.add_edge(node_id, child_id)
            queue.append((child_node, child_id))

    return G

In [None]:
# Parse the third sample snippet into the AST graph G used by the rest of the notebook.
G = parse_py(code_3)
type(G)

In [None]:
# Build a rule from the whole AST graph and plot it.
# NOTE(review): per regraph, from_transform presumably yields lhs = p = rhs = G
# (a rule that changes nothing until inject_* calls edit it) — confirm.
pattern_from_graph = G
rule3 = Rule.from_transform(pattern_from_graph)
plot_rule(rule3)

In [None]:
# Serialize the rule with to_json() and print the result.
json_rules = rule3.to_json()
print(json_rules)

--------------------------------------

<h2>Graph Simplification by Node Generalization</h2>

In [None]:
# Snapshot G as JSON; its "nodes" entry is used further down to count the
# graph's nodes. NOTE(review): the snapshot is not refreshed when later cells
# add attributes to G (the node count stays the same).
parsed_json = G.to_json()

# print clear view of all nodes and their edges
print("List of nodes: ")
for n, attrs in G.nodes(data=True):
    print("\t", n, attrs)
print("List of edges: ")
for s, t, attrs in G.edges(data=True):
    print("\t{}->{}".format(s, t), attrs)

In [None]:
def remove_descendants(node_type, instances, rule, graph=None):
    """Inject removal of every descendant of each matched node into ``rule``.

    The matched nodes themselves are kept. Progress (the node type, each
    matched id and its descendants) is printed for inspection.

    Parameters
    ----------
    node_type : name of the pattern node to read from each match, i.e. the key
        into every instance (e.g. "import" for a pattern built with that name).
    instances : matches from ``find_matching`` (dicts pattern node -> graph node).
    rule : rule to modify via ``inject_remove_node``.
    graph : graph whose descendants are looked up; defaults to the notebook
        global ``G`` (previously an implicit dependency).

    Returns
    -------
    The same ``rule``.
    """
    graph = G if graph is None else graph
    print(node_type)
    for match in instances:
        node_id = match[node_type]
        print(str(node_id) + ":")
        desc = graph.descendants(node_id)
        print(desc)
        # `desc_id` rather than `id`, which would shadow the builtin.
        for desc_id in list(desc):
            rule.inject_remove_node(desc_id)
    return rule


def remove_everything_else(ids, rule, num_nodes):
    """Inject removal of every node in ``range(num_nodes)`` that is not in ``ids``.

    Assumes the graph's node ids are the consecutive integers 0..num_nodes-1,
    which is how ``bfs_tree_traverser`` numbers them.

    Parameters
    ----------
    ids : iterable of node ids to keep (duplicates are harmless).
    rule : rule to modify via ``inject_remove_node``.
    num_nodes : number of nodes in the graph.

    Returns
    -------
    The same ``rule``.
    """
    # A set makes each membership test O(1); with a list this was
    # O(num_nodes * len(ids)), and the caller's ids contain duplicates.
    keep = set(ids)
    for node_id in range(num_nodes):
        if node_id not in keep:
            rule.inject_remove_node(node_id)
    return rule


def get_ids(node_type, instances):
    """Return the graph node id bound to pattern node ``node_type`` in each match.

    ``instances`` are the match dicts returned by ``find_matching``; the
    result keeps their order. A match lacking ``node_type`` raises KeyError.
    """
    return [match[node_type] for match in instances]

def create_simple_pattern(attr_name, node_type):
    """Build a one-node pattern graph for matching AST nodes by type.

    The single pattern node is named ``attr_name`` and requires the attribute
    ``type`` to equal ``node_type``. Matches found with it are keyed by
    ``attr_name``, which is the name to pass to ``get_ids``.
    """
    single_node = NXGraph()
    single_node.add_node(attr_name)
    single_node.add_node_attrs(attr_name, {"type" : node_type})
    return single_node

def add_attrs_from_patterns(ids, patterns, is_import, graph=None):
    """Copy the text of matching sub-nodes onto each node in ``ids`` as attributes.

    For every node id, a subgraph is built from the node plus either all of its
    descendants (``is_import=True``) or only its direct children
    (``is_import=False``). Each ``(attr_name, pattern)`` in ``patterns`` is
    searched in that subgraph, and the ``text`` of the first match is stored on
    the node under ``attr_name``. Patterns with no match are skipped.

    Parameters
    ----------
    ids : node ids in ``graph`` to annotate.
    patterns : list of ``(attr_name, pattern)`` pairs; ``attr_name`` must be the
        name of the pattern's node, since matches are keyed by it.
    is_import : search all descendants (True) or only direct children (False).
    graph : graph to read and modify; defaults to the notebook global ``G``
        (previously an implicit dependency).
    """
    graph = G if graph is None else graph
    for node_id in ids:
        if is_import:
            subg_nodes = list(graph.descendants(node_id))
        else:
            subg_nodes = list(graph.successors(node_id))
        subg_nodes.append(node_id)
        subgraph = graph.generate_subgraph(graph, subg_nodes)

        for attr_name, pattern in patterns:
            matched_ids = get_ids(attr_name, subgraph.find_matching(pattern))
            if not matched_ids:
                # Nothing to copy for this pattern; skip instead of raising
                # IndexError on matched_ids[0] and aborting the remaining nodes.
                continue
            # NOTE(review): only the first match is used; if several nodes match
            # (e.g. several identifiers under one import), which one comes first
            # depends on find_matching — verify it is the intended one.
            graph.add_node_attrs(
                node_id, attrs={attr_name: subgraph.get_node(matched_ids[0])["text"]})
            
            
def connect_parent_and_children(ids, rule, graph=None):
    """Inject edges that bypass each node in ``ids``.

    For every node, an edge from its parent (its first predecessor in
    ``graph``) to each of its children (successors) is injected into ``rule``.
    The node can then be removed without disconnecting its subtree. Nodes with
    no parent, such as the root, are skipped because there is nothing to link
    them to; previously they raised IndexError.

    Parameters
    ----------
    ids : node ids to bypass.
    rule : rule to modify via ``inject_add_edge``.
    graph : graph the parent/children are read from; defaults to the notebook
        global ``G`` (previously an implicit dependency).

    Returns
    -------
    The same ``rule``.
    """
    graph = G if graph is None else graph
    for node_id in ids:
        parents = list(graph.predecessors(node_id))
        if not parents:
            continue
        parent_id = parents[0]
        for child_id in list(graph.successors(node_id)):
            rule.inject_add_edge(parent_id, child_id)
    return rule

In [None]:
# Exploration: look up the recorded parent of node 17 (a hard-coded id).
# Stored under a descriptive name; assigning to `id` shadowed the builtin id().
node_17_parent_id = G.get_node(17)["parent_id"]
node_17_parent_id

In [None]:
# Exploration: direct children of node 3 (hard-coded id).
list(G.successors(3))

----------

<b>Create patterns to select the "relevant" nodes:</b>

In [None]:
# One single-node pattern per tree-sitter node type of interest. The first
# argument names the pattern node; find_matching results are keyed by that
# name, so get_ids(...) must be called with the same string.
pattern1 = create_simple_pattern("import", "import_statement")
pattern2 = create_simple_pattern("function_def", "function_definition")
pattern3 = create_simple_pattern("function_call", "call")
pattern4 = create_simple_pattern("var_assignment", "assignment")
pattern5 = create_simple_pattern("first_node", "module")
pattern6 = create_simple_pattern("code_block", "block")
pattern7 = create_simple_pattern("if", "if_statement")
pattern8 = create_simple_pattern("for", "for_statement")
pattern9 = create_simple_pattern("expr_statement", "expression_statement")

<h3>Approach: Specify "relevant" nodes and delete the rest</h3>

<b>Get the IDs of all the "relevant" nodes:</b>

In [None]:
# Fresh rule over G that the following cells edit via inject_* calls;
# all_ids collects the ids of the nodes to keep.
rule = Rule.from_transform(G)
all_ids = []

In [None]:
# pattern matching of import_statements
# Keep every import and annotate it with its library name and alias.
instances = G.find_matching(pattern1)
this_ids = get_ids("import", instances)
all_ids += this_ids

patterns = []

# pattern to get the library name
patt = create_simple_pattern("library", "dotted_name")
patterns.append(("library", patt))

# pattern to get the alias
# NOTE(review): the library's dotted_name probably contains an identifier too,
# so this may match the library name rather than the alias — verify the result.
patt = create_simple_pattern("alias", "identifier")
patterns.append(("alias", patt))

# is_import=True: search all descendants, not just direct children.
add_attrs_from_patterns(this_ids, patterns, True)

In [None]:
# pattern matching of function_definition
# Keep every function definition and annotate it with its name and parameters.
instances = G.find_matching(pattern2)
this_ids = get_ids("function_def", instances)
all_ids += this_ids

patterns = []

# pattern to get the function name
patt = create_simple_pattern("function", "identifier")
patterns.append(("function", patt))

# pattern to get the parameters
patt = create_simple_pattern("parameters", "parameters")
patterns.append(("parameters", patt))

# is_import=False: look only at the definition's direct children.
add_attrs_from_patterns(this_ids, patterns, False)

In [None]:
# pattern matching of function call
# Keep every call and annotate it with its argument list.
instances = G.find_matching(pattern3)
this_ids = get_ids("function_call", instances)
all_ids += this_ids

patterns = []

# pattern to get the function arguments
patt = create_simple_pattern("arguments", "argument_list")
patterns.append(("arguments", patt))

# is_import=False: look only at the call's direct children.
add_attrs_from_patterns(this_ids, patterns, False)

In [None]:
# Bypass `block` nodes: inject edges from each block's parent straight to the
# block's children. Blocks are not added to all_ids, so the "delete everything
# not relevant" step below removes them.
instances = G.find_matching(pattern6)
this_ids = get_ids("code_block", instances)
#all_ids += this_ids

rule = connect_parent_and_children(this_ids, rule)

# Disabled alternative: remove the block nodes directly here.
# for id in this_ids:
#     print(id)
#     rule.inject_remove_node(id)

In [None]:
# Bypass `expression_statement` nodes the same way as blocks: link parent to
# children, and let the later removal step delete the statements themselves.
# NOTE(review): parents are read from the original G, so a statement directly
# inside a block is linked to that block, which is itself removed — that link
# is then lost. Verify this is intended.
instances = G.find_matching(pattern9)
this_ids = get_ids("expr_statement", instances)
#all_ids += this_ids

rule = connect_parent_and_children(this_ids, rule)

# Disabled alternative: remove the expression statements directly here.
# for id in this_ids:
#     print(id)
#     rule.inject_remove_node(id)

In [None]:
# Keep assignments, the root module node, and if-/for-statements.
# Each kind's ids are appended to all_ids exactly once. The previous version did
# `this_ids += ...; all_ids += this_ids`, which re-appended earlier kinds on
# every step. this_ids still ends up holding the combined ids of these four kinds.
this_ids = []
for kind_name, kind_pattern in [
    ("var_assignment", pattern4),
    ("first_node", pattern5),
    ("if", pattern7),
    ("for", pattern8),
]:
    new_ids = get_ids(kind_name, G.find_matching(kind_pattern))
    this_ids += new_ids
    all_ids += new_ids


<b>Delete all nodes that are not "relevant":</b>

In [None]:
# Remove every node whose id is not in all_ids. The node count comes from the
# parsed_json snapshot taken earlier; node ids are 0..N-1 as assigned during parsing.
rule = remove_everything_else(all_ids, rule, len(parsed_json["nodes"]))

In [None]:
# Plot the simplification rule built above.
plot_rule(rule)

In [None]:
# The rule's right-hand side (the graph after simplification), serialized.
rule.rhs.to_json()

--------

<h3>Approach: Specify "relevant" nodes and remove all their descendants</h3>

In [None]:
# Start a fresh rule over G and, for every import, function definition and call,
# remove everything below it (the matched nodes themselves are kept).
rule = Rule.from_transform(G)

# The key passed to remove_descendants must be the pattern's node name, as given
# to create_simple_pattern: "import", "function_def", "function_call". The
# previous keys "function" and "expression" raised KeyError.
instances1 = G.find_matching(pattern1)
rule = remove_descendants("import", instances1, rule)

instances2 = G.find_matching(pattern2)
rule = remove_descendants("function_def", instances2, rule)

instances3 = G.find_matching(pattern3)
rule = remove_descendants("function_call", instances3, rule)

plot_rule(rule)

<b>This approach won't work: removing every descendant of a function definition deletes the whole function body, so we can no longer analyze what happens inside a function (for example, nested functions would be lost).</b>

----------------------------------

<h2>Playground</h2>

In [None]:
# Playground: a small toy graph of people for trying out regraph's API.
# Create an empty graph object
graph = NXGraph()

# Add a list of nodes, optionally with attributes
graph.add_nodes_from(
    [
        'Alice',
        ('Bob', {'age': 15, 'gender': 'male'}),
        ('Jane', {'age': 40, 'gender': 'female'}),
        ('Eric', {'age': 55, 'gender': 'male'})
])

# Add a list of edges, optionally with attributes
graph.add_edges_from([
    ("Alice", "Bob"),
    ("Jane", "Bob", {"type": "parent", "since": 1993}),
    ("Eric", "Jane", {"type": "friend", "since": 1985}),
    ("Eric", "Alice", {"type": "parent", "since": 1992}),
])

# JSON serialization of the toy graph (for inspection).
graph_json = graph.to_json()

In [None]:
# List the toy graph's nodes and edges together with their attributes.
print("List of nodes: ")
for n, attrs in graph.nodes(data=True):
    print("\t", n, attrs)
print("List of edges: ")
for s, t, attrs in graph.edges(data=True):
    print("\t{}->{}".format(s, t), attrs)

In [None]:
# Build a rule from the toy graph and plot it.
rule = Rule.from_transform(graph)
plot_rule(rule)

In [None]:
# One-node pattern "x" constrained to age 15 (in the toy graph only 'Bob' has that age).
pattern = NXGraph()
pattern.add_nodes_from(["x"])
pattern.add_node_attrs("x", {"age" : 15})

In [None]:
# Find all occurrences of the pattern in the toy graph and display them.
instances = graph.find_matching(pattern)
instances

In [None]:
# List the toy graph's nodes and edges again, after running the pattern match.
print("List of nodes: ")
for n, attrs in graph.nodes(data=True):
    print("\t", n, attrs)
print("List of edges: ")
for s, t, attrs in graph.edges(data=True):
    print("\t{}->{}".format(s, t), attrs)