# Abstract Syntax Tree (AST) Preprocessing for Machine Learning Models
Converting source code into a format suitable for machine learning models requires several transformation steps. This document outlines the comprehensive preprocessing pipeline that transforms raw code into vectorized representations that machine learning models can process effectively.

## Key Components
1. AST Flattening
   1. The `flatten_ast` function captures both node types and structural information
   2. Tracks parent-child relationships via the path parameter
   3. Extracts values from nodes when available
2. Tokenization Strategy
   1. Creates three types of tokens:
      1. Node type tokens (`TYPE_X`)
      2. Structural relationship tokens (`PARENT_X_TO_Y`)
      3. Value tokens for identifiers and literals (`VAL_X` or `LIT_type`)
   2. This preserves both syntactic structure and semantic information
3. Vectorization Options
   1. Two complementary approaches:
      1. Sequence-based: Preserves order of AST nodes using vocabulary mapping
      2. Bag-of-nodes: Creates frequency-based vector representations, useful for classification tasks
4. Vocabulary Management
   1. Creates a vocabulary with frequency thresholding
   2. Includes special tokens for padding and unknown tokens
   3. Enables consistent encoding across different code samples


In [47]:
import javalang
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from collections import defaultdict
import os
import pickle
import tensorflow as tf
import xml.etree.ElementTree as ET

In [48]:
def read_code_file(file_path):
    """Return the UTF-8 text stored at ``file_path``.

    Any exception raised while opening or reading the file is reported on
    stdout and turned into a ``None`` return value instead of propagating.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file contents as a string, or ``None`` if reading failed.
    """
    try:
        with open(file_path, mode="r", encoding="utf-8") as source:
            contents = source.read()
    except Exception as e:
        # Best-effort read: report the problem and signal failure with None.
        print(f"Error reading file {file_path}: {e}")
        return None
    else:
        return contents

In [49]:
def create_ast(code):
    """
    Creates an Abstract Syntax Tree (AST) from the given Java code.

    Both tokenizer (lexer) errors and parser (syntax) errors are reported on
    stdout and turned into a ``None`` result, so a single malformed input
    does not abort a caller that is looping over many files.

    Args:
        code (str): The Java source code to parse.

    Returns:
        javalang.tree.CompilationUnit: The AST of the code, or None if the
        code could not be tokenized or parsed.
    """
    try:
        return javalang.parse.parse(code)
    except (
        javalang.parser.JavaSyntaxError,
        javalang.tokenizer.LexerError,
    ) as e:
        # NOTE(review): javalang's JavaSyntaxError appears to keep its text in
        # `description` and leave str(e) empty -- confirm against the installed
        # javalang version. Fall back to str(e) for errors without that field.
        detail = getattr(e, "description", None) or e
        print(f"Syntax error in code: {detail}")
        return None

In [50]:
import xml.etree.ElementTree as ET


def ast_to_xml(node, parent_elem=None):
    """
    Append an XML rendering of ``node`` (and its descendants) to a container.

    The node becomes an element named after its Python class. ``name`` and
    ``value`` attributes of the node, when present, are copied as string XML
    attributes. Each entry of ``node.children`` is rendered recursively; an
    entry that is a list is wrapped in a ``<list>`` element and its items are
    rendered inside it. ``None`` entries produce nothing. Any other object is
    rendered the same way as a node, so e.g. a plain string child yields an
    empty ``<str />`` element.

    Args:
        node: The current AST node (or child value) to render
        parent_elem: Element to append to; a new ``<ast>`` element is created
            when omitted

    Returns:
        xml.etree.ElementTree.Element: The container element (``parent_elem``
        or the newly created ``<ast>`` root), or None when ``node`` is None
    """
    if node is None:
        return None

    # Top-level call: wrap everything in a fresh <ast> container.
    container = ET.Element("ast") if parent_elem is None else parent_elem

    current = ET.SubElement(container, node.__class__.__name__)

    # Copy the identifying attributes the node happens to expose.
    for attr_name in ("name", "value"):
        if hasattr(node, attr_name):
            current.set(attr_name, str(getattr(node, attr_name)))

    if not hasattr(node, "children"):
        return container

    # Every Python object has __class__, so each child is rendered
    # unconditionally; None children are dropped by the guard at the top.
    for child in node.children:
        if isinstance(child, list):
            group = ET.SubElement(current, "list")
            for member in child:
                ast_to_xml(member, group)
        else:
            ast_to_xml(child, current)

    return container

In [51]:
# Process the dataset: parse each file and serialize its AST as XML
def process_dataset(dataset_path):
    """
    Process all files in the dataset and convert ASTs to XML.

    Directory entries are visited in sorted name order so that the returned
    list has the same order on every run and platform. Entries that cannot
    be read or parsed are reported on stdout and skipped.

    Args:
        dataset_path: Path to the dataset directory

    Returns:
        List of XML strings representing the ASTs, one per entry that was
        read and parsed successfully, in sorted entry-name order
    """
    all_xml_asts = []

    # os.listdir() returns entries in arbitrary order; sort for reproducibility.
    for entry in sorted(os.listdir(dataset_path)):
        file_path = os.path.join(dataset_path, entry)

        code = read_code_file(file_path)
        if not code:
            # Covers both a failed read (None) and an empty file ("").
            print(f"Failed to read code from {entry}.")
            continue

        tree = create_ast(code)
        if not tree:
            print(f"Failed to create AST for {entry}.")
            continue

        xml_tree = ast_to_xml(tree)
        xml_string = ET.tostring(xml_tree, encoding="unicode", method="xml")
        all_xml_asts.append(xml_string)

    return all_xml_asts

In [52]:
# Save the XML data, one serialized AST per line
def save_processed_data(data, output_file):
    """Write each XML string in ``data`` to ``output_file``, one per line.

    The file is opened for writing as UTF-8 text, so any existing content is
    replaced. Every entry is followed by a newline character.

    Args:
        data: Iterable of XML strings to store.
        output_file: Path of the file to write.
    """
    with open(output_file, mode="w", encoding="utf-8") as sink:
        sink.writelines(xml_ast + "\n" for xml_ast in data)

## Data Processing

In [54]:
# Dataset directory, given relative to the current working directory.
dataset_path = "../../datasets/conplag_preprocessed"
# Convert every entry of the dataset directory into an XML-serialized AST.
processed_data = process_dataset(dataset_path)
# Persist the XML strings to ast_xml_data.txt, one per line.
save_processed_data(processed_data, "ast_xml_data.txt")