# Abstract Syntax Tree (AST) Preprocessing for Machine Learning Models
Converting source code into a format suitable for machine learning models requires several transformation steps. This document outlines the comprehensive preprocessing pipeline that transforms raw code into vectorized representations that machine learning models can process effectively.

## Key Components
1. AST Flattening
   1. The `flatten_ast` function captures both node types and structural information
   2. Tracks parent-child relationships via the path parameter
   3. Extracts values from nodes when available
2. Tokenization Strategy
   1. Creates three types of tokens:
      1. Node type tokens (`TYPE_X`)
      2. Structural relationship tokens (`PARENT_X_TO_Y`)
      3. Value tokens for identifiers and literals (`VAL_X` or `LIT_type`)
   2. This preserves both syntactic structure and semantic information
3. Vectorization Options
   1. Two complementary approaches:
      1. Sequence-based: Preserves order of AST nodes using vocabulary mapping
      2. Bag-of-nodes: Creates frequency-based vector representations, useful for classification tasks
4. Vocabulary Management:
   1. Creates a vocabulary with frequency thresholding
   2. Includes special tokens for padding and unknown tokens
   3. Enables consistent encoding across different code samples


In [50]:
import javalang
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from collections import defaultdict
import os
import pickle
import tensorflow as tf

In [51]:
def read_code_file(file_path):
    """Read code from a file."""
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None

In [52]:
def create_ast(code):
    """
    Creates an Abstract Syntax Tree (AST) from the given code.

    Args:
        code (str): The code to parse.

    Returns:
        javalang.tree.CompilationUnit: The AST of the code.
    """
    try:
        tree = javalang.parse.parse(code)
        return tree
    except javalang.parser.JavaSyntaxError as e:
        print(f"Syntax error in code: {e}")
        return None

In [53]:
def flatten_ast(node, token_types=None, path="", result=None):
    """
    Flattens an AST into a sequence of node types with their paths.

    Args:
        node: The current AST node
        token_types: Dictionary to track seen token types
        path: Current path in the AST
        result: List to collect flattened nodes

    Returns:
        List of tuples: (node_type, path, value)
    """
    if result is None:
        result = []

    if token_types is None:
        token_types = {}

    # Skip if node is None
    if node is None:
        return result

    # Process the current node
    node_type = node.__class__.__name__

    # Track token types
    if node_type not in token_types:
        token_types[node_type] = len(token_types)

    # Extract value if available (for identifiers, literals, etc.)
    value = None
    if hasattr(node, "name"):
        value = node.name
    elif hasattr(node, "value"):
        value = node.value

    # Add the current node to the result
    result.append((node_type, path, value))

    # Recursively process children
    if hasattr(node, "children"):
        for i, child in enumerate(node.children):
            child_path = f"{path}/{node_type}_{i}"
            if isinstance(child, list):
                for j, item in enumerate(child):
                    if hasattr(item, "__class__"):
                        flatten_ast(item, token_types, f"{child_path}_{j}", result)
            elif hasattr(child, "__class__"):
                flatten_ast(child, token_types, child_path, result)

    return result

In [54]:
def tokenize_ast(flattened_ast):
    """
    Convert a flattened AST to a sequence of tokens.

    Args:
        flattened_ast: List of (node_type, path, value) tuples

    Returns:
        List of string tokens
    """
    tokens = []

    for node_type, path, value in flattened_ast:
        # Add node type as token
        tokens.append(f"TYPE_{node_type}")

        # Add simplified path to capture structural information
        path_components = path.split("/")
        if len(path_components) > 1:
            parent = path_components[-2].split("_")[0]
            tokens.append(f"PARENT_{parent}_TO_{node_type}")

        # Add value if present (for identifiers, literals, etc.)
        if value is not None:
            # Handle different types of values
            if isinstance(value, str):
                # For identifiers, method names, etc.
                tokens.append(f"VAL_{value}")
            elif isinstance(value, (int, float, bool)):
                # For numeric literals
                tokens.append(f"LIT_{type(value).__name__}")

    return tokens

In [55]:
def create_vocabulary(all_tokens, min_freq=2):
    """
    Create a vocabulary from all tokens.

    Args:
        all_tokens: List of token lists
        min_freq: Minimum frequency for a token to be included

    Returns:
        Dictionary mapping tokens to indices
    """
    # Count token frequencies
    token_counts = defaultdict(int)
    for tokens in all_tokens:
        for token in tokens:
            token_counts[token] += 1

    # Create vocabulary with tokens that meet minimum frequency
    vocabulary = {"<PAD>": 0, "<UNK>": 1}
    for token, count in token_counts.items():
        if count >= min_freq:
            vocabulary[token] = len(vocabulary)

    return vocabulary

In [56]:
def vectorize_tokens(tokens, vocabulary, max_length=None):
    """
    Convert token list to vector using vocabulary.

    Args:
        tokens: List of tokens
        vocabulary: Token to index mapping
        max_length: Maximum length of vector (pad/truncate)

    Returns:
        Numpy array of token indices
    """
    if max_length is None:
        max_length = len(tokens)

    # Convert tokens to indices
    vector = []
    for i, token in enumerate(tokens[:max_length]):
        if token in vocabulary:
            vector.append(vocabulary[token])
        else:
            vector.append(vocabulary["<UNK>"])

    # Pad if necessary
    if len(vector) < max_length:
        vector.extend([vocabulary["<PAD>"]] * (max_length - len(vector)))

    return np.array(vector)

In [57]:
def create_embeddings(vocabulary, embedding_dim=100):
    """Create initial random embeddings for tokens."""
    np.random.seed(42)  # For reproducibility
    vocab_size = len(vocabulary)
    embeddings = np.random.normal(0, 1, (vocab_size, embedding_dim))
    return embeddings

In [None]:
def process_dataset(dataset_path, embedding_dim=100):
    """
    Process all files in the dataset.

    Args:
        dataset_path: Path to the dataset directory
        embedding_dim: Dimension for token embeddings

    Returns:
        Dictionary with processed data
    """
    token_types = {}
    all_flattened_asts = []
    all_tokenized_asts = []
    file_paths = []

    for file in os.listdir(dataset_path):
        file_path = os.path.join(dataset_path, file)
        code = read_code_file(file_path)

        if code:
            tree = create_ast(code)
            if tree:
                flattened = flatten_ast(tree, token_types)
                all_flattened_asts.append(flattened)

                # Tokenize AST
                tokens = tokenize_ast(flattened)
                all_tokenized_asts.append(tokens)

                file_paths.append(file_path)
            else:
                print(f"Failed to create AST for {file}.")
        else:
            print(f"Failed to read code from {file}.")

    # Create vocabulary from all tokens
    vocabulary = create_vocabulary(all_tokenized_asts)

    # Create embeddings for vocabulary
    embeddings = create_embeddings(vocabulary, embedding_dim)

    # Get max sequence length for padding
    max_length = max(len(tokens) for tokens in all_tokenized_asts)

    # Vectorize all token sequences
    vectorized_sequences = [
        vectorize_tokens(tokens, vocabulary, max_length)
        for tokens in all_tokenized_asts
    ]

    return {
        "token_types": token_types,
        "vocabulary": vocabulary,
        "embeddings": embeddings,
        "flattened_asts": all_flattened_asts,
        "tokenized_asts": all_tokenized_asts,
        "sequence_vectors": vectorized_sequences,
        "file_paths": file_paths,
    }

In [59]:
def prepare_for_deep_learning(processed_data, batch_size=32):
    """
    Prepare data for deep learning models.

    Args:
        processed_data: Dictionary with processed data
        batch_size: Batch size for training

    Returns:
        Dataset ready for deep learning training
    """
    # Convert to appropriate format
    sequence_vectors = np.array(processed_data["sequence_vectors"])

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices(sequence_vectors)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset, processed_data["embeddings"]

In [60]:
def save_processed_data(data, output_file):
    """Save processed data to disk."""
    with open(output_file, "wb") as f:
        pickle.dump(data, f)

## Data Processing

In [61]:
dataset_path = "../../datasets/conplag_preprocessed"
processed_data = process_dataset(dataset_path)

print(f"Processed {len(processed_data['file_paths'])} files")
print(f"Found {len(processed_data['token_types'])} unique token types")
print(f"Vocabulary size: {len(processed_data['vocabulary'])}")

# Save the processed data
save_processed_data(processed_data, "ast_processed_data.pkl")

# Example: accessing the first tokenized AST
if processed_data["tokenized_asts"]:
    print("\nSample of first tokenized AST:")
    print(processed_data["tokenized_asts"][0][:20])  # First 20 tokens

30306
Processed 975 files
Found 57 unique token types
Vocabulary size: 2666

Sample of first tokenized AST:
['TYPE_CompilationUnit', 'TYPE_Import', 'PARENT__TO_Import', 'TYPE_str', 'PARENT_CompilationUnit_TO_str', 'TYPE_bool', 'PARENT_CompilationUnit_TO_bool', 'TYPE_bool', 'PARENT_CompilationUnit_TO_bool', 'TYPE_Import', 'PARENT__TO_Import', 'TYPE_str', 'PARENT_CompilationUnit_TO_str', 'TYPE_bool', 'PARENT_CompilationUnit_TO_bool', 'TYPE_bool', 'PARENT_CompilationUnit_TO_bool', 'TYPE_Import', 'PARENT__TO_Import', 'TYPE_str']


In [62]:
# For deep learning workflows
dataset, embeddings = prepare_for_deep_learning(processed_data)

print(f"Dataset prepared with {len(processed_data['file_paths'])} samples")
print(f"Embedding matrix shape: {embeddings.shape}")
print(
    f"Embedding example: {embeddings[0][:20]}"
)  # First 5 values of the first embedding

Dataset prepared with 975 samples
Embedding matrix shape: (2666, 100)
Embedding example: [ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337 -0.23413696
  1.57921282  0.76743473 -0.46947439  0.54256004 -0.46341769 -0.46572975
  0.24196227 -1.91328024 -1.72491783 -0.56228753 -1.01283112  0.31424733
 -0.90802408 -1.4123037 ]
