# Advanced Model Export Features Demo

This notebook demonstrates the new advanced export features:
1. **JIT Graph Dumping** - Extract module hierarchy before ONNX conversion
2. **FX Graph Export** - Alternative representation (dynamo=False)
3. **Comprehensive Analysis** - Compare different export methods

## Key Research Findings
- ✅ **Context preservation**: Module hierarchy IS available in TorchScript graphs
- ✅ **Interception point**: We can capture scope info before ONNX conversion
- ✅ **Multiple alternatives**: FX graphs provide different perspectives

In [None]:
import subprocess
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os

# Setup: make sure the directory the export writes into exists before anything runs.
os.makedirs('../temp/notebook_exports', exist_ok=True)

# Plot styling shared by every figure in this notebook.
plt.style.use('default')
sns.set_palette("husl")

print("📚 Notebook setup complete!")

## 1. Full Feature Export

Let's run the complete export with all new features enabled:

In [None]:
# Run the full export with all features.
# The subprocess runs with cwd=".." (the project root, so `uv run` picks up the
# project environment). Every path below is written relative to THIS notebook's
# directory, so resolve them to absolute paths first. Otherwise "../temp/..."
# would be interpreted relative to the parent directory, and the outputs would
# land where the later cells never look.
export_output_path = Path("../temp/notebook_exports/bert_full_demo.onnx").resolve()
export_config_path = Path("../export_config_bertmodel.json").resolve()

cmd = [
    "uv", "run", "modelexport", "--verbose", "export",
    "prajjwal1/bert-tiny",
    str(export_output_path),
    "--config", str(export_config_path),
    "--jit-graph",
    "--fx-graph", "both",
    "--strategy", "htp"
]

print("🚀 Running full export with all advanced features...")
print(f"Command: {' '.join(cmd)}")
print("\n" + "="*60)

result = subprocess.run(cmd, capture_output=True, text=True, cwd="..")

print("STDOUT:")
print(result.stdout)

if result.stderr:
    print("\nSTDERR:")
    print(result.stderr)

# Only claim success when the CLI actually succeeded.
if result.returncode == 0:
    print(f"\n✅ Export completed with return code: {result.returncode}")
else:
    print(f"\n❌ Export failed with return code: {result.returncode}")

## 2. JIT Graph Analysis

Let's analyze the captured TorchScript graph information:

In [None]:
# Load JIT graph information produced by `--jit-graph`.
jit_info_path = "../temp/notebook_exports/bert_full_demo_jit_debug/jit_graph_info.json"

# Maximum number of module types drawn in the pie chart.
MAX_PIE_LABELS = 8


def sorted_depth_counts(depth_dist):
    """Return ``(depths, counts)`` from a ``{depth: count}`` mapping, ordered by depth.

    JSON object keys are always strings, so each depth is converted to ``int``
    before sorting. Sorting the raw strings would put "10" before "2". Pairing
    keys with their own values also avoids re-looking-up ``str(int(key))``,
    which fails for keys such as "01".
    """
    ordered = sorted((int(depth), count) for depth, count in depth_dist.items())
    depths = [depth for depth, _ in ordered]
    counts = [count for _, count in ordered]
    return depths, counts


if Path(jit_info_path).exists():
    with open(jit_info_path, 'r') as f:
        jit_data = json.load(f)

    print("🔍 JIT GRAPH ANALYSIS RESULTS")
    print("="*50)

    # Basic statistics
    graph_info = jit_data['graph_info']
    scope_stats = jit_data['scope_statistics']

    print(f"📊 Graph Statistics:")
    print(f"   Total nodes: {graph_info['total_nodes']}")
    print(f"   Graph string length: {graph_info['graph_string_length']:,} chars")
    print(f"   Extraction success: {scope_stats['extraction_success']}")
    print(f"   Total scopes found: {scope_stats['total_scopes']}")

    # Scope hierarchy analysis
    hierarchy = jit_data['unified_scope_hierarchy']
    print(f"\n🌳 Scope Hierarchy:")
    print(f"   Unique scopes: {hierarchy['total_unique_scopes']}")

    # Show module-specific scopes (not execution paths)
    module_scopes = [s for s in hierarchy['scope_list'] if '__module' in s]
    print(f"\n🎯 Module Scopes Found ({len(module_scopes)}):")
    for i, scope in enumerate(module_scopes, 1):
        print(f"   {i:2d}. {scope}")

    # Coverage analysis
    if 'coverage_analysis' in scope_stats:
        coverage = scope_stats['coverage_analysis']
        print(f"\n📈 Coverage Analysis:")
        print(f"   BERT modules: {'✅' if coverage['has_bert_modules'] else '❌'}")
        print(f"   Attention modules: {'✅' if coverage['has_attention_modules'] else '❌'}")
        print(f"   Layer modules: {'✅' if coverage['has_layer_modules'] else '❌'}")
        print(f"   Unique module types: {coverage['unique_module_types']}")

    # Depth analysis visualization
    if 'scope_analysis' in hierarchy:
        depth_dist = hierarchy['scope_analysis']['depth_distribution']

        plt.figure(figsize=(10, 6))

        depths, counts = sorted_depth_counts(depth_dist)

        plt.subplot(1, 2, 1)
        plt.bar(depths, counts, alpha=0.7, color='skyblue')
        plt.xlabel('Scope Depth')
        plt.ylabel('Count')
        plt.title('Scope Depth Distribution')
        plt.grid(axis='y', alpha=0.3)

        # Module types (a distinct name: a later cell builds a dict of module op counts).
        jit_module_types = list(hierarchy['scope_analysis']['module_types'])
        if jit_module_types:
            shown_types = jit_module_types[:MAX_PIE_LABELS]
            plt.subplot(1, 2, 2)
            # One equal wedge per shown type: the pie only displays the labels.
            # Wedges and labels must have the same length. autopct is omitted
            # because autopct='' makes matplotlib evaluate '' % pct, which raises
            # TypeError.
            plt.pie([1] * len(shown_types), labels=shown_types, startangle=90)
            plt.title(f'Module Types Found ({len(jit_module_types)} total)')

        plt.tight_layout()
        plt.show()

else:
    print("❌ JIT graph information not found. Export may have failed.")

## 3. FX Graph Analysis

Let's examine the FX graph export results:

In [None]:
# Load FX graph information produced by `--fx-graph`.
fx_info_path = "../temp/notebook_exports/bert_full_demo_fx_graph.json"


def fx_method_data(fx_data, method):
    """Return the part of ``fx_data`` that holds the results for ``method``.

    Lookup order: ``fx_data[method]``, then ``fx_data[f"{method}_results"]``.
    Falls back to ``fx_data`` itself, which is the single-method layout where
    the results are stored at the top level.
    """
    if method in fx_data:
        return fx_data[method]
    return fx_data.get(f'{method}_results', fx_data)


if Path(fx_info_path).exists():
    with open(fx_info_path, 'r') as f:
        fx_data = json.load(f)

    print("🔧 FX GRAPH ANALYSIS RESULTS")
    print("="*50)

    # Check which methods succeeded (multi-method layout: '<method>_results' keys).
    methods_tried = []
    successful_methods = []

    for method in ('symbolic_trace', 'torch_export'):
        results_key = f'{method}_results'
        if results_key in fx_data:
            methods_tried.append(method)
            if fx_data[results_key].get('success'):
                successful_methods.append(method)

    # Single-method layout: a top-level 'method' key replaces the lists above entirely,
    # so a failed single method does not inherit stale successes.
    if 'method' in fx_data:
        methods_tried = [fx_data['method']]
        successful_methods = [fx_data['method']] if fx_data.get('success') else []

    print(f"📊 FX Export Summary:")
    print(f"   Methods tried: {', '.join(methods_tried)}")
    print(f"   Successful methods: {', '.join(successful_methods) if successful_methods else 'None'}")
    print(f"   Overall success: {'✅' if fx_data.get('success') or successful_methods else '❌'}")

    # Analyze successful exports
    for method in successful_methods:
        method_data = fx_method_data(fx_data, method)

        print(f"\n🎯 {method.upper()} Results:")
        print(f"   Success: {method_data.get('success', False)}")
        print(f"   Execution test: {method_data.get('execution_test', 'unknown')}")

        if 'fx_graph_analysis' in method_data:
            analysis = method_data['fx_graph_analysis']
            print(f"   Total nodes: {analysis.get('total_nodes', 'unknown')}")
            print(f"   Node types: {len(analysis.get('node_types', {}))}")
            print(f"   Nodes with metadata: {analysis.get('nodes_with_meta', 0)}")

            # Visualize node types if available
            if analysis.get('node_types'):
                node_types = analysis['node_types']

                plt.figure(figsize=(12, 5))

                plt.subplot(1, 2, 1)
                node_type_names = list(node_types.keys())
                node_type_counts = list(node_types.values())

                plt.bar(node_type_names, node_type_counts, alpha=0.7, color='lightcoral')
                plt.xlabel('Node Operation Type')
                plt.ylabel('Count')
                plt.title(f'{method}: Node Type Distribution')
                plt.xticks(rotation=45, ha='right')
                plt.grid(axis='y', alpha=0.3)

                # Sample nodes info: skip the pie when there is nothing to plot,
                # because a pie whose wedges sum to zero cannot be drawn.
                sample_nodes = analysis.get('sample_nodes') or []
                if sample_nodes:
                    with_meta = sum(1 for node in sample_nodes if node.get('has_meta'))

                    plt.subplot(1, 2, 2)
                    plt.pie([with_meta, len(sample_nodes) - with_meta],
                           labels=['With Metadata', 'Without Metadata'],
                           autopct='%1.1f%%', startangle=90)
                    plt.title('Sample Nodes: Metadata Coverage')

                plt.tight_layout()
                plt.show()

        # Show code preview if available
        if 'graph_code_preview' in method_data:
            print(f"\n📄 {method.upper()} Code Preview:")
            print("```python")
            print(method_data['graph_code_preview'])
            print("```")

    # Show errors for failed methods (sorted so the output order is stable).
    failed_methods = sorted(set(methods_tried) - set(successful_methods))
    for method in failed_methods:
        method_data = fx_method_data(fx_data, method)

        error = method_data.get('error', 'Unknown error')
        print(f"\n❌ {method.upper()} Failed:")
        print(f"   Error: {error}")

else:
    print("❌ FX graph information not found. Export may have failed.")

## 4. ONNX Export Analysis

Let's analyze the standard ONNX export results; Section 5 then compares them with the JIT and FX methods:

In [None]:
# Load ONNX hierarchy information written alongside the exported model.
onnx_hierarchy_path = "../temp/notebook_exports/bert_full_demo_hierarchy.json"


def coverage_pct(tagged, total):
    """Return ``tagged`` as a percentage of ``total``, or 0.0 when ``total`` is 0."""
    return tagged / total * 100 if total else 0.0


if Path(onnx_hierarchy_path).exists():
    with open(onnx_hierarchy_path, 'r') as f:
        onnx_data = json.load(f)

    print("🎯 ONNX EXPORT ANALYSIS")
    print("="*50)

    # Basic export info
    summary = onnx_data['summary']
    exporter_info = onnx_data['exporter']

    print(f"📊 Export Summary:")
    print(f"   Strategy: {exporter_info['strategy']}")
    print(f"   Total operations: {summary['total_operations']}")
    print(f"   Tagged operations: {summary['tagged_operations']}")
    print(f"   Coverage: {coverage_pct(summary['tagged_operations'], summary['total_operations']):.1f}%")
    print(f"   Unique tags: {summary['unique_tags']}")
    print(f"   Operation trace length: {summary['operation_trace_length']}")

    # Tag distribution analysis
    tag_stats = onnx_data['tag_statistics']

    print(f"\n🏷️ Tag Distribution:")
    sorted_tags = sorted(tag_stats.items(), key=lambda x: x[1], reverse=True)

    for tag, count in sorted_tags[:10]:  # Top 10
        print(f"   {tag}: {count} ops")

    if len(sorted_tags) > 10:
        print(f"   ... and {len(sorted_tags) - 10} more tags")

    # Visualization
    plt.figure(figsize=(15, 10))

    # Tag distribution pie chart: top 8 tags, remainder folded into 'Others'.
    plt.subplot(2, 2, 1)
    top_tags = dict(sorted_tags[:8])
    other_count = sum(count for _, count in sorted_tags[8:])
    if other_count > 0:
        top_tags['Others'] = other_count

    if top_tags:  # a pie with no wedges cannot be drawn
        plt.pie(top_tags.values(), labels=[t.split('/')[-1] for t in top_tags.keys()],
               autopct='%1.1f%%', startangle=90)
    plt.title('Operation Distribution by Module')

    # Coverage comparison
    plt.subplot(2, 2, 2)
    coverage_data = {
        'Tagged': summary['tagged_operations'],
        'Untagged': summary['total_operations'] - summary['tagged_operations']
    }
    plt.bar(coverage_data.keys(), coverage_data.values(),
           color=['lightgreen', 'lightcoral'], alpha=0.7)
    plt.title('Operation Coverage')
    plt.ylabel('Number of Operations')

    # Tag hierarchy depth (depth = number of '/' separators in the tag).
    plt.subplot(2, 2, 3)
    tag_depths = [tag.count('/') for tag in tag_stats]
    if tag_depths:  # max() of an empty list raises ValueError
        plt.hist(tag_depths, bins=range(max(tag_depths) + 2), alpha=0.7, color='orange')
    plt.xlabel('Hierarchy Depth')
    plt.ylabel('Number of Tags')
    plt.title('Tag Hierarchy Depth Distribution')

    # Module type analysis: aggregate op counts by the last path segment of each tag.
    # (Distinct names so this dict does not shadow the JIT cell's list of module types.)
    plt.subplot(2, 2, 4)
    module_op_counts = {}
    for tag, op_count in tag_stats.items():
        if '/' in tag:
            module_type = tag.split('/')[-1]
            module_op_counts[module_type] = module_op_counts.get(module_type, 0) + op_count

    if module_op_counts:
        top_modules = sorted(module_op_counts.items(), key=lambda x: x[1], reverse=True)[:8]
        module_names, module_counts = zip(*top_modules)

        plt.bar(module_names, module_counts, alpha=0.7, color='lightblue')
        plt.xlabel('Module Type')
        plt.ylabel('Operations')
        plt.title('Operations by Module Type')
        plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

else:
    print("❌ ONNX hierarchy information not found. Export may have failed.")

## 5. Comparative Analysis

Let's compare the information captured by different methods:

In [None]:
print("🔬 COMPARATIVE ANALYSIS")
print("="*60)


def summarize_fx(fx_data, successful_methods):
    """Return ``(success, node_count, key_feature)`` for the FX comparison row.

    ``success`` is true if the FX JSON reports success or any method succeeded.
    The node count and feature text come from the first successful method's
    ``fx_graph_analysis``. That method's results are looked up under
    ``fx_data[method]``, then ``fx_data[f"{method}_results"]``, and finally
    ``fx_data`` itself (the single-method layout).
    """
    success = bool(fx_data.get('success', False) or successful_methods)
    node_count = 'Unknown'
    feature = 'Graph representation analysis'

    if successful_methods:
        method = successful_methods[0]
        if method in fx_data:
            method_data = fx_data[method]
        else:
            method_data = fx_data.get(f'{method}_results', fx_data)
        if isinstance(method_data, dict) and 'fx_graph_analysis' in method_data:
            node_count = method_data['fx_graph_analysis'].get('total_nodes', 'Unknown')
            feature = f'{method.title()} graph representation'

    return success, node_count, feature


# Create comparison dataframe: one row per export method whose results were loaded above.
comparison_data = []

# ONNX Export Analysis
if 'onnx_data' in globals():
    onnx_summary = onnx_data['summary']
    onnx_total_ops = onnx_summary['total_operations']
    onnx_tagged_ops = onnx_summary['tagged_operations']
    # Guard against an export with zero operations.
    onnx_coverage_pct = onnx_tagged_ops / onnx_total_ops * 100 if onnx_total_ops else 0.0
    comparison_data.append({
        'Method': 'ONNX Export',
        'Success': '✅',
        'Node Count': onnx_total_ops,
        'Tagged Count': onnx_tagged_ops,
        'Unique Tags': onnx_summary['unique_tags'],
        'Coverage %': f"{onnx_coverage_pct:.1f}%",
        'Key Feature': 'Hierarchy preservation in production format',
        'Context Preserved': 'Partial (via tagging)'
    })

# JIT Graph Analysis
if 'jit_data' in globals():
    comparison_data.append({
        'Method': 'JIT Graph',
        'Success': '✅' if jit_data['scope_statistics']['extraction_success'] else '❌',
        'Node Count': jit_data['graph_info']['total_nodes'],
        'Tagged Count': '-',
        'Unique Tags': jit_data['scope_statistics']['total_scopes'],
        'Coverage %': 'N/A',
        'Key Feature': 'Pre-ONNX module hierarchy capture',
        'Context Preserved': 'Full (before conversion)'
    })

# FX Graph Analysis (successful_methods comes from the FX analysis cell, if it ran).
if 'fx_data' in globals():
    fx_success, fx_node_count, fx_feature = summarize_fx(
        fx_data, globals().get('successful_methods', [])
    )

    comparison_data.append({
        'Method': 'FX Graph',
        'Success': '✅' if fx_success else '❌',
        'Node Count': fx_node_count,
        'Tagged Count': '-',
        'Unique Tags': '-',
        'Coverage %': 'N/A',
        'Key Feature': fx_feature,
        'Context Preserved': 'Alternative (metadata-based)'
    })

# Display comparison table
if comparison_data:
    df = pd.DataFrame(comparison_data)

    print("📊 Method Comparison:")
    print(df.to_string(index=False))

    # Success rate visualization
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    success_counts = df['Success'].value_counts()
    plt.pie(success_counts.values, labels=success_counts.index, autopct='%1.1f%%', startangle=90)
    plt.title('Export Method Success Rate')

    # Node count comparison (only rows with a numeric node count)
    plt.subplot(1, 2, 2)
    node_data = []
    labels = []

    for _, row in df.iterrows():
        if str(row['Node Count']).isdigit():
            node_data.append(int(row['Node Count']))
            labels.append(row['Method'])

    if node_data:
        plt.bar(labels, node_data, alpha=0.7, color=['skyblue', 'lightgreen', 'orange'])
        plt.ylabel('Node Count')
        plt.title('Graph Size Comparison')
        plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

print("\n🎯 KEY INSIGHTS:")
print("1. ✅ JIT graphs preserve FULL module context before ONNX conversion")
print("2. ✅ ONNX export achieves high coverage with hierarchy tagging")
print("3. ⚠️  FX graphs may face compatibility issues with complex models")
print("4. 🚀 Combined approach provides comprehensive model analysis")

print("\n💡 RECOMMENDED WORKFLOW:")
print("• Use --jit-graph for debugging and hierarchy analysis")
print("• Use ONNX export for production deployment")
print("• Use --fx-graph for alternative analysis when supported")

## 6. File Structure Overview

Let's examine what files were generated by our export:

In [None]:
import os
from pathlib import Path

export_dir = Path("../temp/notebook_exports")


def shorten_name(name, width=15):
    """Return ``name`` truncated to ``width`` characters plus '...' when it is longer."""
    return name[:width] + '...' if len(name) > width else name


print("📁 GENERATED FILES OVERVIEW")
print("="*50)

if export_dir.exists():
    all_files = []

    # Walk through all files. os.walk yields only files in `files`, so a missing
    # suffix means "no extension", not "directory".
    for root, dirs, files in os.walk(export_dir):
        for file in files:
            file_path = Path(root) / file
            relative_path = file_path.relative_to(export_dir)
            file_size = file_path.stat().st_size

            all_files.append({
                'File': str(relative_path),
                'Size': f"{file_size:,} bytes",
                'Type': file_path.suffix or '(no extension)',
                # Raw byte count kept for plotting/totals instead of re-parsing 'Size'.
                'Bytes': file_size,
            })

    if all_files:
        # Sort by file type and name
        all_files.sort(key=lambda x: (x['Type'], x['File']))

        # Display file structure (the raw 'Bytes' column is for computation only)
        df_files = pd.DataFrame(all_files)
        print(df_files.drop(columns='Bytes').to_string(index=False))

        # File type distribution
        plt.figure(figsize=(10, 6))

        plt.subplot(1, 2, 1)
        type_counts = df_files['Type'].value_counts()
        plt.pie(type_counts.values, labels=type_counts.index, autopct='%1.1f%%', startangle=90)
        plt.title('Generated File Types')

        # File sizes. Path(...).name is used rather than splitting on '/' so the
        # labels are also correct on Windows.
        plt.subplot(1, 2, 2)
        sizes = df_files['Bytes'].tolist()
        file_names = [shorten_name(Path(f).name) for f in df_files['File']]

        plt.bar(range(len(sizes)), sizes, alpha=0.7)
        plt.xlabel('Files')
        plt.ylabel('Size (bytes)')
        plt.title('File Sizes')
        plt.xticks(range(len(file_names)), file_names, rotation=45, ha='right')

        plt.tight_layout()
        plt.show()

        print(f"\n📊 Summary: {len(all_files)} files generated")
        total_size = sum(sizes)
        print(f"📦 Total size: {total_size:,} bytes ({total_size/1024:.1f} KB)")
    else:
        # An empty DataFrame has no 'Type'/'Size' columns, so stop here.
        print("⚠️  Export directory is empty")

else:
    print("❌ Export directory not found")

print("\n🎉 NOTEBOOK ANALYSIS COMPLETE!")
print("\nNext steps:")
print("• Explore individual JSON files for detailed analysis")
print("• Use the generated .py files to understand graph structure")
print("• Compare ONNX model with original PyTorch model performance")