In [40]:
from tree_sitter import Language, Parser
import tree_sitter_java as ts_java
import torch
from torch_geometric.data import Data

<h3>Get AST</h3>

In [41]:
java_LANGUAGE = Language(ts_java.language())

In [42]:
original_file = './fire14-source-code-training-dataset/java/000.java'
with open(original_file, 'r', encoding='utf-8') as file:
    code = file.read()

In [43]:
parser = Parser(java_LANGUAGE)
code = 'System.out.println("Hello World");'
tree = parser.parse(bytes(code, 'utf8'))
root_node = tree.root_node
# print(root_node)

<h3>Travers AST to get nodes and edges</h3>

In [44]:
nodes = []
edges = []

def traverse(node, parent_idx=None):
    idx = len(nodes)
    nodes.append(node.type)
    
    if parent_idx is not None:
        edges.append((parent_idx, idx))
    
    for child in node.children:
        traverse(child, idx)

traverse(root_node)

print("Nodes:")
for i, node in enumerate(nodes):
    print(f"{i}: {node}")

print("\nEdges:")
for parent_idx, child_idx in edges:
    print(f"{parent_idx} -> {child_idx}")


Nodes:
0: program
1: expression_statement
2: method_invocation
3: field_access
4: identifier
5: .
6: identifier
7: .
8: identifier
9: argument_list
10: (
11: string_literal
12: "
13: string_fragment
14: "
15: )
16: ;

Edges:
0 -> 1
1 -> 2
2 -> 3
3 -> 4
3 -> 5
3 -> 6
2 -> 7
2 -> 8
2 -> 9
9 -> 10
9 -> 11
11 -> 12
11 -> 13
11 -> 14
9 -> 15
1 -> 16


In [45]:
node_types = list(set(nodes))
node_type_to_idx = {typ: i for i, typ in enumerate(node_types)}
node_features = [node_type_to_idx[typ] for typ in nodes]

print("\nNode Features:")
for i, feature in enumerate(node_features):
    print(f"Node {i}: {feature}")


Node Features:
Node 0: 7
Node 1: 3
Node 2: 6
Node 3: 2
Node 4: 9
Node 5: 12
Node 6: 9
Node 7: 12
Node 8: 9
Node 9: 1
Node 10: 8
Node 11: 11
Node 12: 4
Node 13: 10
Node 14: 4
Node 15: 5
Node 16: 0


In [47]:
num_node_types = len(node_types)
x = torch.eye(num_node_types)[torch.tensor(node_features)]
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
data = Data(x=x, edge_index=edge_index)

print("\nGraph Data:")
print(data)
print("Number of nodes:", data.num_nodes)
print("Number of edges:", data.num_edges)
print("Node features shape:", data.x.shape)
print("Edge index shape:", data.edge_index.shape)


Graph Data:
Data(x=[17, 13], edge_index=[2, 16])
Number of nodes: 17
Number of edges: 16
Node features shape: torch.Size([17, 13])
Edge index shape: torch.Size([2, 16])
