In [1]:
import torch
import transformers
import tree_sitter_java as tsjava
from tree_sitter import Language, Parser
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tree-sitter grammar handle for Java; the parser in the next cell is built from it.
JAVA_LANGUAGE = Language(tsjava.language())

In [3]:
parser = Parser(JAVA_LANGUAGE)

# Read and parse the Java file.
# The encoding is pinned so decoding does not depend on the platform's locale
# default; the text is re-encoded as UTF-8 for the parser below.
with open('ant-ivy/src/java/org/apache/ivy/Ivy.java', 'r', encoding='utf-8') as f:
    java_code = f.read()

# Parse the code to get the syntax tree
tree = parser.parse(bytes(java_code, "utf8"))
root_node = tree.root_node

# Traverse and print the parsed tree structure
# def print_tree(node, indent=0):
#     print('  ' * indent + node.type, node.start_point, node.end_point)
#     for child in node.children:
#         print_tree(child, indent + 1)

# print_tree(root_node)

In [8]:
# Helper: collect the source text of every syntax-tree node that matches the requested granularity level
def extract_code_fragments(node, code, level: str):
    """Collect the source text of every node under ``node`` that matches ``level``.

    Args:
        node: Root of the (sub)tree to search. Only ``type``, ``children``,
            ``start_byte`` and ``end_byte`` are read from each node.
        code: The parsed source as ``bytes``; fragments are sliced from it by
            byte offset and decoded as UTF-8.
        level: Granularity to extract:
            ``"package"`` -> ``package_declaration`` nodes,
            ``"class"``   -> ``class_declaration`` nodes,
            ``"method"``  -> ``method_declaration`` nodes,
            ``"token"``   -> every childless node except comments.
            Any other value yields an empty list.

    Returns:
        list[str]: Matching fragments in pre-order. Nested matches (for
        example a class inside a class) are reported in addition to the
        enclosing fragment that already contains their text.
    """
    declaration_types = {
        "package": "package_declaration",
        "class": "class_declaration",
        "method": "method_declaration",
    }
    comment_types = ("block_comment", "line_comment")

    if level == "token":
        def matches(candidate):
            # A token is a leaf of the tree that is not a comment.
            return len(candidate.children) == 0 and candidate.type not in comment_types
    elif level in declaration_types:
        wanted_type = declaration_types[level]

        def matches(candidate):
            return candidate.type == wanted_type
    else:
        # Unknown level: nothing can match, so skip the traversal entirely.
        return []

    fragments = []
    # Explicit stack instead of recursion: a deeply nested syntax tree would
    # otherwise exceed Python's recursion limit.
    pending = [node]
    while pending:
        current = pending.pop()
        if matches(current):
            fragments.append(code[current.start_byte:current.end_byte].decode('utf8'))
        # Push children reversed so they are popped left-to-right (pre-order).
        pending.extend(reversed(current.children))

    return fragments

# Parse one Java file and return its code fragments at the requested level
def parse_java_file(file_path, level):
    """Parse one Java source file and return its fragments at ``level``.

    The file is read as raw bytes, parsed with a freshly built tree-sitter
    Java parser, and the root of the resulting tree is handed to
    ``extract_code_fragments``.

    Args:
        file_path: Path of the Java source file to read.
        level: Granularity forwarded unchanged to ``extract_code_fragments``.

    Returns:
        The list produced by ``extract_code_fragments`` for the file's root node.
    """
    # Binary mode: both the parser and the fragment slicing work on bytes.
    with open(file_path, 'rb') as source_file:
        source_bytes = source_file.read()

    # A new grammar handle and parser are created on every call.
    java_parser = Parser(Language(tsjava.language()))
    syntax_tree = java_parser.parse(source_bytes)

    return extract_code_fragments(syntax_tree.root_node, source_bytes, level)

# Walk a project directory and collect fragments from every .java file, grouped by level
def extract_from_project(directory, levels=("class", "method", "token")):
    """Extract code fragments from every ``.java`` file under ``directory``.

    Args:
        directory: Root directory; it is walked recursively with ``os.walk``.
        levels: Iterable of granularity names passed to
            ``extract_code_fragments``. Defaults to class, method and token.

    Returns:
        dict: Maps each requested level to one list holding the fragments of
        all files, in traversal order (directories and file names sorted).
    """
    # Materialise once: the levels are iterated again for every file, which a
    # one-shot iterator would not survive.
    levels = list(levels)
    project_fragments = {level: [] for level in levels}

    # Build the parser once and reuse it for every file in the project.
    java_parser = Parser(Language(tsjava.language()))

    # Traverse directory to find all .java files
    for dirpath, dirnames, filenames in os.walk(directory):
        # os.walk yields entries in arbitrary order; sorting the directory
        # list in place and the file names makes the output reproducible.
        dirnames.sort()
        for filename in sorted(filenames):
            if not filename.endswith(".java"):
                continue

            file_path = os.path.join(dirpath, filename)
            print(f"Parsing {file_path}")

            # Read and parse each file a single time, then reuse the tree for
            # every requested level instead of re-parsing per level.
            with open(file_path, 'rb') as source_file:
                source_bytes = source_file.read()
            tree = java_parser.parse(source_bytes)

            for level in levels:
                fragments = extract_code_fragments(tree.root_node, source_bytes, level)
                project_fragments[level].extend(fragments)

    return project_fragments

In [None]:
# print(extract_code_fragments(root_node, bytes(java_code, "utf8"), level = "token"))


In [10]:
levels = ["package", "class", "method", "token"]
# NOTE(review): "package" only produces fragments if extract_code_fragments
# recognises that level — confirm, otherwise its bucket stays empty.

# One bucket per granularity level, filled from the tree parsed earlier.
src_fragments = {}
for level in levels:
    fragments = extract_code_fragments(root_node, java_code.encode("utf8"), level)
    src_fragments[level] = [*fragments]

In [None]:
# Output each level of granularity
for level, fragments in src_fragments.items():
    # Banner first, then every fragment of that level on its own line.
    banner = f"\n--- {level.capitalize()} Level ---"
    print(banner)
    if fragments:
        print("\n".join(fragments))