# AST Autoencoder Evaluation

This notebook evaluates the autoencoder's performance by:
1. Loading sample methods from the test set
2. Passing their ASTs through the trained autoencoder
3. Converting the reconstructed ASTs back to Ruby code with the pretty printer (the original code is the raw source stored in the dataset)
4. Comparing the results side-by-side

The goal is to assess how well the autoencoder preserves the structure and semantics of Ruby methods.

In [None]:
import json
import os
import subprocess
import sys
import tempfile

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data

# Add src directory to path
sys.path.insert(0, os.path.join(os.path.dirname(os.getcwd()), 'src'))

from data_processing import RubyASTDataset
from models import ASTAutoencoder

## Setup and Load Data

In [None]:
# --- Test data ---------------------------------------------------------------
print("Loading test dataset...")
test_dataset = RubyASTDataset("../dataset/test.jsonl")
print(f"Loaded {len(test_dataset)} test samples")

# --- Model -------------------------------------------------------------------
# Encoder weights come from the encoder checkpoint and are frozen (per the
# `freeze_encoder` / `encoder_weights_path` arguments); the decoder may be
# replaced with trained weights below.
print("\nInitializing autoencoder...")
autoencoder = ASTAutoencoder(
    encoder_weights_path="../best_model.pt",
    freeze_encoder=True,
    conv_type='GCN',
    num_layers=3,
    hidden_dim=64,
    encoder_input_dim=74,
    node_output_dim=74,
)

# Optional trained decoder; without it the decoder keeps its initial weights.
decoder_path = "../best_decoder.pt"
if not os.path.exists(decoder_path):
    print("No trained decoder found - using randomly initialized decoder")
else:
    print(f"Loading trained decoder from {decoder_path}")
    # map_location='cpu' allows a checkpoint saved on GPU to load on CPU.
    decoder_state = torch.load(decoder_path, map_location='cpu')
    autoencoder.decoder.load_state_dict(decoder_state)

# Switch to inference mode (affects layers such as dropout, if present).
autoencoder.eval()
print("Autoencoder ready for evaluation")

## Helper Functions

In [None]:
def convert_sample_to_torch(sample):
    """Build a single-graph ``Data`` object from one dataset sample.

    Args:
        sample: mapping with ``'x'`` (per-node feature rows) and
            ``'edge_index'`` entries, as returned by the dataset.

    Returns:
        ``Data`` with float node features, long edge indices and an all-zero
        ``batch`` vector (every node belongs to graph 0).
    """
    node_feats = torch.tensor(sample['x'], dtype=torch.float)
    return Data(
        x=node_feats,
        edge_index=torch.tensor(sample['edge_index'], dtype=torch.long),
        batch=torch.zeros(node_feats.size(0), dtype=torch.long),
    )

def reconstruct_ast_from_features(node_features, num_nodes_per_graph):
    """Map decoder node-feature predictions to a simplified placeholder AST.

    This is deliberately simplified: only the predicted type of the first
    (root) node is used, and it selects a fixed template. The remaining nodes
    and the tree structure are not reconstructed.

    Args:
        node_features: tensor whose last dimension holds per-node feature
            scores. Leading dimensions (e.g. a batch dimension) are flattened
            and the first node overall is treated as the root.
        num_nodes_per_graph: currently unused; kept for interface compatibility.

    Returns:
        dict with ``'type'`` and ``'children'`` keys. A root index outside the
        known name list maps to ``'unknown'``; an empty input yields
        ``{'type': 'unknown', 'children': []}``.
    """
    # Simplified index -> node-type mapping; indices beyond this list are
    # reported as 'unknown'. A full mapping would mirror the feature encoding.
    type_names = [
        'def', 'args', 'begin', 'send', 'block', 'self', 'nil', 'true', 'false',
        'str', 'int', 'float', 'sym', 'lvar', 'ivar', 'cvar', 'gvar', 'const',
        'if', 'unless', 'while', 'until', 'for', 'case', 'when', 'return',
        'break', 'next', 'yield', 'and', 'or', 'not', 'array', 'hash', 'pair'
    ]

    if node_features.numel() == 0:
        return {'type': 'unknown', 'children': []}

    # Argmax over the feature (last) axis, then flatten. Unlike
    # `squeeze()` + `dim=1`, this also works for a single node, where squeeze
    # would collapse the node axis and leave a 1-D tensor.
    node_types = torch.argmax(node_features, dim=-1).reshape(-1)

    root_type_idx = node_types[0].item()
    root_type = type_names[root_type_idx] if root_type_idx < len(type_names) else 'unknown'

    if root_type == 'def':
        # Minimal method template: name, empty args, one call in the body.
        return {
            'type': 'def',
            'children': [
                'reconstructed_method',
                {'type': 'args', 'children': []},
                {'type': 'begin', 'children': [
                    {'type': 'send', 'children': [None, 'reconstructed_call']}
                ]}
            ]
        }
    return {
        'type': root_type,
        'children': ['reconstructed_content']
    }

def ast_to_ruby_code(ast_json):
    """Render an AST as Ruby source using the project's Ruby pretty printer.

    The AST is written to a private temporary JSON file whose path is passed to
    ``../scripts/pretty_print_ast.rb`` (relative to the current working
    directory). The temporary file is always removed afterwards.

    Args:
        ast_json: JSON-serializable AST, e.g. ``{'type': ..., 'children': [...]}``.

    Returns:
        The printer's stripped stdout on success. On any failure (JSON
        serialization error, ``ruby`` missing, non-zero exit status) a string
        starting with ``"Error: "``. Callers detect failures via that marker,
        so this function does not raise.
    """
    temp_path = None
    try:
        # A unique temp file (instead of a fixed /tmp path) avoids clashes
        # between concurrent runs and works where /tmp does not exist.
        with tempfile.NamedTemporaryFile(
            'w', suffix='.json', prefix='ast_', delete=False, encoding='utf-8'
        ) as tmp:
            temp_path = tmp.name
            json.dump(ast_json, tmp)

        result = subprocess.run(
            ['ruby', '../scripts/pretty_print_ast.rb', temp_path],
            capture_output=True,
            text=True,
            # Prepends a hard-coded user gem bin directory to PATH
            # (environment-specific; presumably where the printer's gem
            # executables were installed).
            env=dict(os.environ, PATH=f"/home/runner/.local/share/gem/ruby/3.2.0/bin:{os.environ.get('PATH', '')}")
        )

        if result.returncode == 0:
            return result.stdout.strip()
        return f"Error: {result.stderr}"
    except Exception as e:
        # Best-effort: report the failure in-band rather than raising.
        return f"Error: {str(e)}"
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass

def _read_original_record(sample_idx, jsonl_path):
    """Return ``(raw_source, ast_dict)`` for line ``sample_idx`` of a JSONL file.

    Each line is a JSON object with a ``'raw_source'`` string and an
    ``'ast_json'`` string holding the JSON-encoded AST. Returns
    ``(None, None)`` when the file has no line at that index.
    """
    # Explicit UTF-8: Ruby sources may contain non-ASCII text, and the
    # platform default encoding is not UTF-8 everywhere (e.g. Windows).
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f):
            if line_no == sample_idx:
                record = json.loads(line)
                return record['raw_source'], json.loads(record['ast_json'])
    return None, None


def evaluate_sample(sample, sample_idx, jsonl_path='../dataset/test.jsonl'):
    """Run one test sample through the autoencoder and collect comparison data.

    Uses the notebook-level ``autoencoder`` together with the helper functions
    defined in this cell.

    Args:
        sample: dataset item with ``'x'`` and ``'edge_index'`` entries (as
            consumed by ``convert_sample_to_torch``).
        sample_idx: index of the sample. It is also used as the line number in
            ``jsonl_path`` to fetch the original source and AST.
        jsonl_path: JSONL file to read the original record from.
            NOTE(review): assumes dataset index == line number, i.e. the
            dataset does not skip lines — confirm.

    Returns:
        dict with the sample index, embedding dimension, original and
        reconstructed code/ASTs, and original/reconstructed node counts.
        ``original_code`` and ``original_ast`` are ``None`` if the line could
        not be found.
    """
    data = convert_sample_to_torch(sample)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        result = autoencoder(data)
    embedding = result['embedding']
    reconstruction = result['reconstruction']

    original_code, original_ast = _read_original_record(sample_idx, jsonl_path)
    if original_ast is None:
        print(f"  Warning: no line {sample_idx} in {jsonl_path}; original code/AST unavailable")

    reconstructed_ast = reconstruct_ast_from_features(
        reconstruction['node_features'],
        reconstruction['num_nodes_per_graph'],
    )
    reconstructed_code = ast_to_ruby_code(reconstructed_ast)

    return {
        'sample_idx': sample_idx,
        'embedding_dim': embedding.shape[1],
        'original_code': original_code,
        'reconstructed_code': reconstructed_code,
        'original_ast': original_ast,
        'reconstructed_ast': reconstructed_ast,
        'original_nodes': len(sample['x']),
        # NOTE(review): assumes node_features is (batch, nodes, features) so
        # shape[1] is the node count — confirm against the decoder output.
        'reconstructed_nodes': reconstruction['node_features'].shape[1]
    }


print("Helper functions defined")

## Evaluate Sample Methods

In [None]:
# Select a few representative samples from the test set
sample_indices = [0, 1, 2, 5, 10, 20, 50, 100]  # Various samples
evaluation_results = []

print("Evaluating selected samples...")
for idx in sample_indices:
    if not 0 <= idx < len(test_dataset):
        # Small test splits may not contain every requested index; report it
        # instead of silently evaluating fewer samples.
        print(f"\nSkipping sample {idx}: test set has only {len(test_dataset)} samples")
        continue
    print(f"\nEvaluating sample {idx}...")
    sample = test_dataset[idx]
    result = evaluate_sample(sample, idx)
    evaluation_results.append(result)
    print(f"  Original nodes: {result['original_nodes']}, Reconstructed nodes: {result['reconstructed_nodes']}")

print(f"\nEvaluated {len(evaluation_results)} samples")

## Side-by-Side Comparison

In [None]:
def display_comparison(result, max_ast_chars=500):
    """Print a side-by-side style report of original vs reconstructed code.

    Args:
        result: dict produced by ``evaluate_sample``.
        max_ast_chars: the indented JSON of the original AST is truncated to
            this many characters (followed by ``'...'``) when longer.
    """
    print(f"\n{'='*80}")
    print(f"SAMPLE {result['sample_idx']} COMPARISON")
    print(f"{'='*80}")

    print(f"\nEmbedding dimension: {result['embedding_dim']}")
    print(f"Original nodes: {result['original_nodes']}, Reconstructed nodes: {result['reconstructed_nodes']}")

    print(f"\n{'-'*40} ORIGINAL {'-'*40}")
    print(result['original_code'])

    print(f"\n{'-'*38} RECONSTRUCTED {'-'*38}")
    print(result['reconstructed_code'])

    print(f"\n{'-'*35} ORIGINAL AST {'-'*35}")
    original_ast_text = json.dumps(result['original_ast'], indent=2)
    # Measure the text that is actually printed (indented JSON), not the
    # Python repr of the dict, so long ASTs are reliably truncated.
    if len(original_ast_text) > max_ast_chars:
        original_ast_text = original_ast_text[:max_ast_chars] + '...'
    print(original_ast_text)

    print(f"\n{'-'*33} RECONSTRUCTED AST {'-'*33}")
    print(json.dumps(result['reconstructed_ast'], indent=2))

# Display comparisons for all evaluated samples
# (one full report per entry in evaluation_results, in evaluation order)
for result in evaluation_results:
    display_comparison(result)

## Analysis and Metrics

In [None]:
def analyze_reconstruction_quality(results):
    """Aggregate simple reconstruction-quality statistics.

    Args:
        results: list of dicts produced by ``evaluate_sample``. The keys used
            are ``original_nodes``, ``reconstructed_nodes``,
            ``reconstructed_code``, ``original_ast`` and ``reconstructed_ast``.

    Returns:
        dict with the following keys:
        - ``total_samples``
        - ``avg_original_nodes`` and ``avg_reconstructed_nodes`` (NaN for empty input)
        - ``node_count_differences`` (per-sample absolute differences)
        - ``syntactically_valid`` (count of reconstructions without a printer error)
        - ``structural_similarity`` (per-sample 1.0/0.0 root-type agreement)
    """
    original_counts = [r['original_nodes'] for r in results]
    reconstructed_counts = [r['reconstructed_nodes'] for r in results]

    def _mean(values):
        # np.mean([]) emits a RuntimeWarning; return NaN quietly instead.
        return np.mean(values) if values else float('nan')

    # Basic validity check: the pretty printer reported no error. Searching
    # for 'def '/'end' substrings is not used because error messages can
    # contain them too (e.g. "send"), which would count failures as valid.
    syntactically_valid = sum('Error:' not in r['reconstructed_code'] for r in results)

    # Simplified structural similarity: does the root node type match?
    # `or {}` tolerates a missing original AST (None).
    structural_similarity = [
        1.0 if (r['original_ast'] or {}).get('type') == (r['reconstructed_ast'] or {}).get('type') else 0.0
        for r in results
    ]

    return {
        'total_samples': len(results),
        'avg_original_nodes': _mean(original_counts),
        'avg_reconstructed_nodes': _mean(reconstructed_counts),
        'node_count_differences': [abs(o - c) for o, c in zip(original_counts, reconstructed_counts)],
        'syntactically_valid': syntactically_valid,
        'structural_similarity': structural_similarity,
    }

# Perform analysis
analysis = analyze_reconstruction_quality(evaluation_results)

print("\n" + "="*60)
print("RECONSTRUCTION QUALITY ANALYSIS")
print("="*60)
print(f"Total samples evaluated: {analysis['total_samples']}")
print(f"Average original nodes: {analysis['avg_original_nodes']:.1f}")
print(f"Average reconstructed nodes: {analysis['avg_reconstructed_nodes']:.1f}")
print(f"Average node count difference: {np.mean(analysis['node_count_differences']):.1f}")
print(f"Syntactically valid reconstructions: {analysis['syntactically_valid']}/{analysis['total_samples']} ({100*analysis['syntactically_valid']/analysis['total_samples']:.1f}%)")
print(f"Root type match rate: {np.mean(analysis['structural_similarity']):.1f} ({100*np.mean(analysis['structural_similarity']):.1f}%)")

# Create a summary table
summary_data = []
for result in evaluation_results:
    summary_data.append({
        'Sample': result['sample_idx'],
        'Original Nodes': result['original_nodes'],
        'Reconstructed Nodes': result['reconstructed_nodes'],
        'Node Diff': abs(result['original_nodes'] - result['reconstructed_nodes']),
        'Syntactically Valid': 'Yes' if 'Error:' not in result['reconstructed_code'] else 'No',
        'Root Type Match': 'Yes' if result['original_ast'].get('type') == result['reconstructed_ast'].get('type') else 'No'
    })

summary_df = pd.DataFrame(summary_data)
print("\nDETAILED SUMMARY:")
print(summary_df.to_string(index=False))

## Conclusion

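This evaluation notebook exercises the autoencoder pipeline end to end:

1. **Encode Ruby ASTs** into fixed-size embeddings (the embedding dimension is printed for each sample; the model is configured with `hidden_dim=64`)
2. **Decode embeddings** into per-node feature predictions. `reconstruct_ast_from_features` currently maps these to a simplified template AST, chosen only by the predicted root node type.
3. **Pretty-print the reconstructed ASTs** as Ruby code. Because the reconstructed ASTs are templates, a syntactically valid result says little about reconstruction quality.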

### Key Observations:

- The autoencoder pipeline runs on Ruby method ASTs of varying complexity (node counts are reported per sample)
- The pretty-printing script converts the reconstructed ASTs to readable Ruby code; the original code shown alongside is the raw source from the dataset
- Structural similarity is currently measured only as root-type agreement (see the root type match rate above), which is a very coarse signal
- The approach is a first step toward testing whether learned representations capture code structure

### Future Improvements:

1. **Enhanced reconstruction**: Improve edge prediction to better preserve AST tree structure
2. **Better metrics**: Develop more sophisticated similarity metrics for AST comparison
3. **Semantic preservation**: Ensure reconstructed code maintains the same functionality
4. **Training optimization**: Further tune the autoencoder for better reconstruction quality

This evaluation establishes a foundation for assessing GNN-based code generation models. Demonstrating automated code synthesis from learned representations will additionally require a decoder that reconstructs the full AST structure rather than a root-type template.