### Summary

This notebook demonstrated:

1. **Static Hierarchy Building**: Fast analysis of model structure using class names
2. **Tracing-Based Hierarchy Building**: Accurate execution-based hierarchy capture
3. **Production TracingHierarchyBuilder**: The actual implementation used in HTP exporter
4. **Pretty Display**: Readable visualization using the Rich library with proper styling

The key insight is that the production TracingHierarchyBuilder provides the most accurate hierarchy information because it traces actual model execution. The Rich library then renders that hierarchy as readable output, similar to the HTP exporter's display.

### Node Count Analysis

Let's analyze how many ONNX nodes would be tagged with each hierarchy tag, similar to the HTP exporter output.
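
As a rough illustration of the counting step, here is a minimal sketch. It assumes a hypothetical `tagged_nodes` dict that maps ONNX node names to hierarchy tags; the node names and tags below are made up for the example and are not real exporter output. A node counts toward a tag when its own tag equals that tag or sits underneath it.

```python
from collections import Counter


def count_nodes_per_tag(tagged_nodes, hierarchy_tags):
    """Count nodes whose tag equals, or is nested under, each hierarchy tag."""
    counts = Counter()
    for node_tag in tagged_nodes.values():
        for tag in hierarchy_tags:
            # The trailing "/" keeps "BertLayer.1" from matching "BertLayer.10".
            if node_tag == tag or node_tag.startswith(tag + "/"):
                counts[tag] += 1
    return {tag: counts[tag] for tag in hierarchy_tags}


# Hypothetical data for illustration only (not taken from a real export).
tagged_nodes = {
    "/embeddings/Gather": "/BertModel/BertEmbeddings",
    "/embeddings/LayerNorm/Add": "/BertModel/BertEmbeddings/LayerNorm",
    "/encoder/layer.0/attention/self/MatMul":
        "/BertModel/BertEncoder/BertLayer.0/BertAttention/BertSdpaSelfAttention",
}
hierarchy_tags = [
    "/BertModel",
    "/BertModel/BertEmbeddings",
    "/BertModel/BertEncoder",
    "/BertModel/BertEncoder/BertLayer.1",
]

counts = count_nodes_per_tag(tagged_nodes, hierarchy_tags)
for tag, count in counts.items():
    print(f"{tag:45} {count} nodes")
```

Matching on `tag + "/"` rather than on a bare prefix is a deliberate choice: a parent tag collects everything below it, while sibling tags that merely share a textual prefix stay separate.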

### Pretty Display with Rich Library

Let's display the hierarchy data in a readable format using the Rich library, similar to how the HTP exporter displays it.

## Using the Actual TracingHierarchyBuilder from modelexport.core

Now let's use the actual implementation from our codebase, the production TracingHierarchyBuilder used in the HTP exporter, and display the results in a readable format.

In [1]:
from transformers import AutoModel, AutoTokenizer
import torch

# Configuration - set your model here
MODEL_NAME = "prajjwal1/bert-tiny"  # Change this to test different models

# Load the model
print(f"Loading {MODEL_NAME}...")
model = AutoModel.from_pretrained(MODEL_NAME)

print(f"Model type: {type(model)}")
print(f"Model class: {model.__class__.__name__}")

Loading prajjwal1/bert-tiny...
Model type: <class 'transformers.models.bert.modeling_bert.BertModel'>
Model class: BertModel



# HuggingFace Module Hierarchy Exploration Notebook

This notebook explores different approaches to building HuggingFace module hierarchies and demonstrates the TracingHierarchyBuilder used in the HTP exporter.

In [2]:
def is_hf_class(module):
    """Check if a module is a HuggingFace class"""
    module_path = module.__class__.__module__
    return module_path.startswith('transformers')

def build_hf_hierarchy_mapping(model):
    """Recursively build HF module hierarchy mapping"""
    hierarchy_mapping = {}
    
    def recursive_build(module, current_tag, module_name, parent_children_names=None):
        """Recursively build hierarchy for a module"""
        
        # If this is an HF class, update the tag
        if is_hf_class(module):
            class_name = module.__class__.__name__
            
            # Append the numeric index for repeated children (e.g., "0" from "layer.0").
            # Sibling names are unique, so a numeric basename is the only reliable signal.
            module_basename = module_name.split('.')[-1] if module_name else ""
            if module_basename.isdigit():
                current_tag = f"{current_tag}/{class_name}.{module_basename}"
            else:
                current_tag = f"{current_tag}/{class_name}"
        
        # Map this module to its hierarchy tag
        if module_name:  # Skip root module
            hierarchy_mapping[module_name] = current_tag
        
        # Get children names for indexing
        children_names = [name for name, _ in module.named_children()]
        
        # Recursively process children
        for child_name, child_module in module.named_children():
            child_full_name = f"{module_name}.{child_name}" if module_name else child_name
            recursive_build(child_module, current_tag, child_full_name, children_names)
    
    # Start with root model - use simple class name without duplication
    root_class = model.__class__.__name__
    initial_tag = f"/{root_class}" if is_hf_class(model) else ""
    
    # Skip the root model itself and start with its children to avoid duplication
    for child_name, child_module in model.named_children():
        recursive_build(child_module, initial_tag, child_name, [name for name, _ in model.named_children()])
    
    return hierarchy_mapping

In [3]:
# Build the static mapping
print("Building static HF hierarchy mapping...")
static_hierarchy_mapping = build_hf_hierarchy_mapping(model)

print(f"\nFound {len(static_hierarchy_mapping)} module mappings:")
print("=" * 70)

for module_name, hierarchy_tag in sorted(static_hierarchy_mapping.items()):
    print(f"{module_name:40} -> {hierarchy_tag}")

Building static HF hierarchy mapping...

Found 47 module mappings:
embeddings                               -> /BertModel/BertEmbeddings
embeddings.LayerNorm                     -> /BertModel/BertEmbeddings
embeddings.dropout                       -> /BertModel/BertEmbeddings
embeddings.position_embeddings           -> /BertModel/BertEmbeddings
embeddings.token_type_embeddings         -> /BertModel/BertEmbeddings
embeddings.word_embeddings               -> /BertModel/BertEmbeddings
encoder                                  -> /BertModel/BertEncoder
encoder.layer                            -> /BertModel/BertEncoder
encoder.layer.0                          -> /BertModel/BertEncoder/BertLayer.0
encoder.layer.0.attention                -> /BertModel/BertEncoder/BertLayer.0/BertAttention
encoder.layer.0.attention.output         -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput
encoder.layer.0.attention.output.LayerNorm -> /BertModel/BertEncoder/BertLayer.0/BertAttention/Bert

## Tracing-Based HF Module Hierarchy Builder

This approach uses forward hooks to trace actual execution flow and build more accurate hierarchy mappings.

In [4]:
class TracingHierarchyBuilder:
    """Tracing-based HF hierarchy builder using forward hooks."""
    
    def __init__(self):
        self.tag_stack = []
        self.execution_trace = []
        self.operation_context = {}
        self.hooks = []
        
    def is_hf_class(self, module):
        """Check if a module is a HuggingFace class"""
        module_path = module.__class__.__module__
        return module_path.startswith('transformers')
    
    def should_create_hierarchy_level(self, module):
        """Determine if module should create a new hierarchy level"""
        if self.is_hf_class(module):
            return True
        # Include some important torch.nn modules
        important_torch_nn = ['LayerNorm', 'Embedding']
        return module.__class__.__name__ in important_torch_nn
    
    def extract_module_info(self, module_name: str, module):
        """Extract module information for hierarchy building"""
        name_parts = module_name.split(".")
        
        # Check if this is an indexed module (e.g., layer.0)
        is_indexed_module = False
        module_index = None
        
        if len(name_parts) >= 2:
            last_part = name_parts[-1]
            second_last_part = name_parts[-2]
            
            if (last_part.isdigit() and 
                second_last_part in ['layer', 'layers', 'block', 'blocks', 'h']):
                is_indexed_module = True
                module_index = last_part
        
        return {
            'class_name': module.__class__.__name__,
            'module_index': module_index,
            'full_name': module_name,
            'is_indexed': is_indexed_module,
            'name_parts': name_parts,
        }
    
    def create_pre_hook(self, module_info):
        """Create pre-forward hook to push tag onto stack"""
        def pre_hook(module, inputs):
            # Get parent context from stack
            parent_tag = self.tag_stack[-1] if self.tag_stack else ""
            
            # Build current class name with index if needed
            if module_info['is_indexed']:
                current_class_name = f"{module_info['class_name']}.{module_info['module_index']}"
            else:
                current_class_name = module_info['class_name']
            
            # Build hierarchical tag
            hierarchical_tag = f"{parent_tag}/{current_class_name}"
            self.tag_stack.append(hierarchical_tag)
            
            # Record execution trace
            trace_entry = {
                'module_name': module_info['full_name'],
                'tag': hierarchical_tag,
                'action': 'enter',
                'stack_depth': len(self.tag_stack),
                'execution_order': len(self.execution_trace)
            }
            self.execution_trace.append(trace_entry)
            
            # Record in operation context
            self.operation_context[module_info['full_name']] = {
                "tag": hierarchical_tag,
                "module_class": module_info['class_name'],
                "creates_hierarchy": True,
                "stack_depth": len(self.tag_stack),
                "execution_order": len(self.execution_trace) - 1,
                "module_info": module_info
            }
            
        return pre_hook
    
    def create_post_hook(self, module_info):
        """Create post-forward hook to pop tag from stack"""
        def post_hook(module, inputs, outputs):
            # Record exit
            trace_entry = {
                'module_name': module_info['full_name'],
                'tag': self.tag_stack[-1] if self.tag_stack else "",
                'action': 'exit',
                'stack_depth': len(self.tag_stack),
                'execution_order': len(self.execution_trace)
            }
            self.execution_trace.append(trace_entry)
            
            # Pop the tag when module execution completes
            if self.tag_stack:
                self.tag_stack.pop()
                
        return post_hook
    
    def register_hooks(self, model):
        """Register forward hooks for tracing"""
        # Initialize stack with root module tag
        root_tag = f"/{model.__class__.__name__}"
        self.tag_stack = [root_tag]
        
        # Register hooks on all modules
        for name, module in model.named_modules():
            if name:  # Skip root module
                module_info = self.extract_module_info(name, module)
                
                # Only hook modules that should create hierarchy levels
                if self.should_create_hierarchy_level(module):
                    # Register pre-hook
                    pre_hook = module.register_forward_pre_hook(
                        self.create_pre_hook(module_info)
                    )
                    self.hooks.append(pre_hook)
                    
                    # Register post-hook
                    post_hook = module.register_forward_hook(
                        self.create_post_hook(module_info)
                    )
                    self.hooks.append(post_hook)
    
    def remove_hooks(self):
        """Remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
    
    def trace_model_execution(self, model, example_inputs):
        """Trace model execution to build hierarchy mapping"""
        self.register_hooks(model)
        
        try:
            # Run model forward pass to trigger hooks
            model.eval()
            with torch.no_grad():
                _ = model(*example_inputs)
        finally:
            self.remove_hooks()
    
    def get_hierarchy_mapping(self):
        """Get the traced hierarchy mapping"""
        hierarchy_mapping = {}
        
        for module_name, context in self.operation_context.items():
            hierarchy_mapping[module_name] = context['tag']
        
        return hierarchy_mapping
    
    def get_execution_summary(self):
        """Get summary of execution trace"""
        return {
            'total_modules_traced': len(self.operation_context),
            'execution_steps': len(self.execution_trace),
            'max_stack_depth': max((t['stack_depth'] for t in self.execution_trace), default=0),
            'hierarchy_mapping': self.get_hierarchy_mapping()
        }

### Test Tracing-Based Hierarchy Mapping

Now let's test the tracing approach and compare it with our static approach.

In [5]:
# Prepare example inputs for tracing
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text = "Hello world, this is a test sentence."
inputs = tokenizer(text, return_tensors="pt", max_length=64, padding="max_length", truncation=True)
example_inputs = (inputs["input_ids"], inputs["attention_mask"])

print(f"Example inputs prepared:")
print(f"Input IDs shape: {example_inputs[0].shape}")
print(f"Attention mask shape: {example_inputs[1].shape}")

# Create tracing hierarchy builder
print("\nCreating tracing hierarchy builder...")
tracer = TracingHierarchyBuilder()

# Trace model execution
print("\nTracing model execution...")
tracer.trace_model_execution(model, example_inputs)

# Get results
traced_mapping = tracer.get_hierarchy_mapping()
execution_summary = tracer.get_execution_summary()

print(f"\nTracing completed!")
print(f"Execution summary: {execution_summary}")

Example inputs prepared:
Input IDs shape: torch.Size([1, 64])
Attention mask shape: torch.Size([1, 64])

Creating tracing hierarchy builder...

Tracing model execution...

Tracing completed!
Execution summary: {'total_modules_traced': 25, 'execution_steps': 50, 'max_stack_depth': 6, 'hierarchy_mapping': {'embeddings': '/BertModel/BertEmbeddings', 'embeddings.word_embeddings': '/BertModel/BertEmbeddings/Embedding', 'embeddings.token_type_embeddings': '/BertModel/BertEmbeddings/Embedding', 'embeddings.position_embeddings': '/BertModel/BertEmbeddings/Embedding', 'embeddings.LayerNorm': '/BertModel/BertEmbeddings/LayerNorm', 'encoder': '/BertModel/BertEncoder', 'encoder.layer.0': '/BertModel/BertEncoder/BertLayer.0', 'encoder.layer.0.attention': '/BertModel/BertEncoder/BertLayer.0/BertAttention', 'encoder.layer.0.attention.self': '/BertModel/BertEncoder/BertLayer.0/BertAttention/BertSdpaSelfAttention', 'encoder.layer.0.attention.output': '/BertModel/BertEncoder/BertLayer.0/BertAttention/Be

In [6]:
# Display traced hierarchy mapping
print("Traced Hierarchy Mapping:")
print("=" * 70)

for module_name, hierarchy_tag in sorted(traced_mapping.items()):
    print(f"{module_name:40} -> {hierarchy_tag}")

print(f"\nFound {len(traced_mapping)} modules with traced hierarchy tags")

Traced Hierarchy Mapping:
embeddings                               -> /BertModel/BertEmbeddings
embeddings.LayerNorm                     -> /BertModel/BertEmbeddings/LayerNorm
embeddings.position_embeddings           -> /BertModel/BertEmbeddings/Embedding
embeddings.token_type_embeddings         -> /BertModel/BertEmbeddings/Embedding
embeddings.word_embeddings               -> /BertModel/BertEmbeddings/Embedding
encoder                                  -> /BertModel/BertEncoder
encoder.layer.0                          -> /BertModel/BertEncoder/BertLayer.0
encoder.layer.0.attention                -> /BertModel/BertEncoder/BertLayer.0/BertAttention
encoder.layer.0.attention.output         -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput
encoder.layer.0.attention.output.LayerNorm -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput/LayerNorm
encoder.layer.0.attention.self           -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSdpaSelfAttention
en

### Compare Static vs Tracing Approaches

Let's compare the static and tracing-based hierarchy mappings to see the differences.

In [7]:
def compare_hierarchy_mappings(static_mapping, traced_mapping):
    """Compare static and traced hierarchy mappings"""
    
    # Find common modules
    static_modules = set(static_mapping.keys())
    traced_modules = set(traced_mapping.keys())
    common_modules = static_modules & traced_modules
    
    # Find differences
    only_static = static_modules - traced_modules
    only_traced = traced_modules - static_modules
    
    print("Comparison Summary:")
    print("=" * 50)
    print(f"Static mapping modules: {len(static_modules)}")
    print(f"Traced mapping modules: {len(traced_modules)}")
    print(f"Common modules: {len(common_modules)}")
    print(f"Only in static: {len(only_static)}")
    print(f"Only in traced: {len(only_traced)}")
    
    # Check for tag differences in common modules
    tag_differences = []
    tag_matches = 0
    
    for module in common_modules:
        static_tag = static_mapping[module]
        traced_tag = traced_mapping[module]
        
        if static_tag == traced_tag:
            tag_matches += 1
        else:
            tag_differences.append({
                'module': module,
                'static': static_tag,
                'traced': traced_tag
            })
    
    print(f"\nTag Comparison:")
    print(f"Matching tags: {tag_matches}/{len(common_modules)}")
    print(f"Different tags: {len(tag_differences)}")
    
    # Show some differences
    if tag_differences:
        print(f"\nFirst few tag differences:")
        for diff in tag_differences[:5]:
            print(f"  {diff['module']:30}")
            print(f"    Static:  {diff['static']}")
            print(f"    Traced:  {diff['traced']}")
        
        if len(tag_differences) > 5:
            print(f"    ... and {len(tag_differences) - 5} more differences")
    
    # Show modules only in traced (key insight)
    if only_traced:
        print(f"\nModules only in traced mapping (execution-only):")
        for module in sorted(only_traced)[:5]:
            print(f"  {module:30} -> {traced_mapping[module]}")
        if len(only_traced) > 5:
            print(f"    ... and {len(only_traced) - 5} more")
    
    return {
        'static_count': len(static_modules),
        'traced_count': len(traced_modules),
        'common_count': len(common_modules),
        'tag_matches': tag_matches,
        'tag_differences': len(tag_differences),
        'differences': tag_differences
    }

# Compare the mappings
comparison = compare_hierarchy_mappings(static_hierarchy_mapping, traced_mapping)

Comparison Summary:
Static mapping modules: 47
Traced mapping modules: 25
Common modules: 25
Only in static: 22
Only in traced: 0

Tag Comparison:
Matching tags: 17/25
Different tags: 8

First few tag differences:
  embeddings.token_type_embeddings
    Static:  /BertModel/BertEmbeddings
    Traced:  /BertModel/BertEmbeddings/Embedding
  embeddings.position_embeddings
    Static:  /BertModel/BertEmbeddings
    Traced:  /BertModel/BertEmbeddings/Embedding
  encoder.layer.1.output.LayerNorm
    Static:  /BertModel/BertEncoder/BertLayer.1/BertOutput
    Traced:  /BertModel/BertEncoder/BertLayer.1/BertOutput/LayerNorm
  encoder.layer.0.output.LayerNorm
    Static:  /BertModel/BertEncoder/BertLayer.0/BertOutput
    Traced:  /BertModel/BertEncoder/BertLayer.0/BertOutput/LayerNorm
  encoder.layer.0.attention.output.LayerNorm
    Static:  /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput
    Traced:  /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput/LayerNorm
    ...

### Execution Trace Analysis

Let's analyze the execution trace to understand the call flow.

In [8]:
# Analyze execution trace
print("Execution Trace Analysis:")
print("=" * 50)

# Show first few trace entries
print("First 10 execution steps:")
for i, trace in enumerate(tracer.execution_trace[:10]):
    indent = "  " * (trace['stack_depth'] - 1)
    action_symbol = "→" if trace['action'] == 'enter' else "←"
    print(f"{i:2d}: {indent}{action_symbol} {trace['module_name']}")

if len(tracer.execution_trace) > 10:
    print(f"... and {len(tracer.execution_trace) - 10} more steps")

# Show execution order of modules
print(f"\nModule execution order:")
execution_order = {}
for module_name, context in tracer.operation_context.items():
    execution_order[context['execution_order']] = (module_name, context['tag'])

for order in sorted(execution_order.keys()):
    module_name, tag = execution_order[order]
    print(f"{order:2d}: {module_name:30} -> {tag}")

print(f"\nMax stack depth reached: {execution_summary['max_stack_depth']}")

Execution Trace Analysis:
First 10 execution steps:
 0:   → embeddings
 1:     → embeddings.word_embeddings
 2:     ← embeddings.word_embeddings
 3:     → embeddings.token_type_embeddings
 4:     ← embeddings.token_type_embeddings
 5:     → embeddings.position_embeddings
 6:     ← embeddings.position_embeddings
 7:     → embeddings.LayerNorm
 8:     ← embeddings.LayerNorm
 9:   ← embeddings
... and 40 more steps

Module execution order:
 0: embeddings                     -> /BertModel/BertEmbeddings
 1: embeddings.word_embeddings     -> /BertModel/BertEmbeddings/Embedding
 3: embeddings.token_type_embeddings -> /BertModel/BertEmbeddings/Embedding
 5: embeddings.position_embeddings -> /BertModel/BertEmbeddings/Embedding
 7: embeddings.LayerNorm           -> /BertModel/BertEmbeddings/LayerNorm
10: encoder                        -> /BertModel/BertEncoder
11: encoder.layer.0                -> /BertModel/BertEncoder/BertLayer.0
12: encoder.layer.0.attention      -> /BertModel/BertEncoder/Be

### Key Insights

**Static vs Tracing Comparison:**

1. **Static Approach**: Fast, comprehensive, covers all modules in model structure
2. **Tracing Approach**: Accurate, execution-based, includes important torch.nn modules like LayerNorm
3. **Key Difference**: Tracing captures modules that actually execute and includes torch.nn modules that contribute to hierarchy

**When to Use:**
- **Static**: When you need fast analysis of model structure
- **Tracing**: When you need accurate operation-to-module mapping for ONNX export
- **Both**: For validation and comprehensive coverage

**Tracing Advantages:**
- Captures actual execution flow
- Includes important torch.nn modules (LayerNorm, Embedding)
- Provides execution order and stack depth
- Better for dynamic module behavior
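
The stack discipline behind the tracing approach can be sketched without PyTorch. The snippet below is an illustrative simplification, not the notebook's class: the hypothetical helper `tags_from_events` replays a hand-written list of enter/exit events (standing in for the pre- and post-forward hooks) and shows how each module's tag is derived from whatever is on top of the stack when it starts executing.

```python
def tags_from_events(root_tag, events):
    """Replay (action, class_name) events and return the tag built at each 'enter'."""
    stack = [root_tag]
    assigned = []
    for action, class_name in events:
        if action == "enter":
            # The parent context is whatever module is currently executing.
            tag = f"{stack[-1]}/{class_name}"
            stack.append(tag)
            assigned.append(tag)
        elif action == "exit":
            # Leaving a module restores its parent's context.
            stack.pop()
    return assigned


# Hand-written event sequence for illustration; a real trace comes from hooks.
events = [
    ("enter", "BertEmbeddings"),
    ("enter", "LayerNorm"),
    ("exit", "LayerNorm"),
    ("exit", "BertEmbeddings"),
    ("enter", "BertEncoder"),
    ("enter", "BertLayer.0"),
    ("exit", "BertLayer.0"),
    ("exit", "BertEncoder"),
]

tags = tags_from_events("/BertModel", events)
for tag in tags:
    print(tag)
```

Because tags are built from the live stack rather than from attribute names, the result reflects the order in which modules actually ran.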

**Test with different models by changing MODEL_NAME above!**

In [None]:
# Import the actual TracingHierarchyBuilder from modelexport.core
import sys
sys.path.append('../..')  # Add project root to path

from modelexport.core.tracing_hierarchy_builder import TracingHierarchyBuilder
from modelexport.core.model_input_generator import generate_dummy_inputs
from rich.console import Console
from rich.tree import Tree
from rich.text import Text
from rich.table import Table
from rich.panel import Panel
import json

console = Console()

# Create inputs using the model_input_generator
print("Generating inputs using model_input_generator...")
dummy_inputs = generate_dummy_inputs(MODEL_NAME, exporter="onnx")
print(f"Generated inputs: {list(dummy_inputs.keys())}")

# Create the actual TracingHierarchyBuilder
print("\nCreating TracingHierarchyBuilder from modelexport.core...")
core_tracer = TracingHierarchyBuilder()

# Trace the model
print("Tracing model execution...")
core_tracer.trace_model_execution(model, dummy_inputs)

# Get the execution summary
core_summary = core_tracer.get_execution_summary()

print(f"\n✅ Tracing completed!")
print(f"Total modules traced: {core_summary['total_modules_traced']}")
print(f"Total modules: {core_summary['total_modules']}")
print(f"Execution steps: {core_summary['execution_steps']}")

# Get the hierarchy data
module_hierarchy = core_summary.get('module_hierarchy', {})

# Function to create styled text (main: detail)
def create_styled_text(main_text, detail_text, main_style="bold", detail_style="dim"):
    """Create styled text with main text and detail text in gray."""
    styled_text = Text()
    styled_text.append(main_text, style=main_style)
    styled_text.append(": ", style="white")
    styled_text.append(detail_text, style=detail_style)
    return styled_text

# Create a summary panel
stats_table = Table(show_header=False, box=None)
stats_table.add_column(style="bold cyan")
stats_table.add_column(style="bold yellow")

stats_table.add_row("Total Modules Traced", str(core_summary['total_modules_traced']))
stats_table.add_row("Total Modules", str(core_summary['total_modules']))
stats_table.add_row("Execution Steps", str(core_summary['execution_steps']))
stats_table.add_row("Hierarchy Entries", str(len(core_summary.get('hierarchy_mapping', {}))))

console.print(Panel(stats_table, title="[bold blue]Execution Summary[/bold blue]", border_style="blue"))

# Display hierarchy as a tree
print("\n🌳 Module Hierarchy Tree:")
print("=" * 80)

# Get root info
root_info = module_hierarchy.get("", {})
root_class = root_info.get('class_name', model.__class__.__name__)

# Create the tree
tree = Tree(f"[bold bright_magenta]{root_class}[/bold bright_magenta]")

def build_tree(tree_node, parent_path, hierarchy_data, processed=None):
    """Build tree showing the module hierarchy with proper parent-child relationships."""
    if processed is None:
        processed = set()
    
    # Find immediate children
    immediate_children = []
    
    for path, info in hierarchy_data.items():
        if path in processed or not path:  # Skip if already processed or is root
            continue
            
        if parent_path == "":
            # Root level - find paths with no dots
            if "." not in path:
                immediate_children.append((path, info))
        else:
            # Check if this path is a child of parent_path
            if path.startswith(parent_path + "."):
                suffix = path[len(parent_path + "."):]
                
                # Case 1: Direct child (no dots in suffix)
                if "." not in suffix:
                    immediate_children.append((path, info))
                # Case 2: Numbered pattern (e.g., layer.0)
                else:
                    parts = suffix.split(".")
                    # Match only exact patterns like "layer.0"
                    if len(parts) == 2 and parts[1].isdigit():
                        immediate_children.append((path, info))
    
    # Sort children for consistent display
    immediate_children.sort(key=lambda x: x[0])
    
    # Add children to tree
    for child_path, child_info in immediate_children:
        processed.add(child_path)
        
        class_name = child_info.get('class_name', 'Unknown')
        
        # Create styled text
        styled_text = create_styled_text(class_name, child_path, "bold bright_green", "bright_cyan")
        
        child_node = tree_node.add(styled_text)
        
        # Recursively add children
        build_tree(child_node, child_path, hierarchy_data, processed)

# Build the tree
build_tree(tree, "", module_hierarchy)

# Display the tree
console.print(tree)

## Testing with ResNet-50

Now let's test with ResNet-50 to see the issue with ResNetConvLayer showing 0 nodes:

In [None]:
# Test ResNet with the fixed HTP exporter
import sys
sys.path.append('../..')  # Add project root to path

from transformers import AutoModel
from modelexport.strategies.htp import HTPExporter
from modelexport.core.model_input_generator import generate_dummy_inputs
from rich.tree import Tree
from rich.console import Console
from rich.style import Style
from rich.text import Text

# Initialize
console = Console()
model_name = "microsoft/resnet-50"
model = AutoModel.from_pretrained(model_name)

# Export with hierarchy preservation AND include torch.nn children
console.print(f"\n🔧 Exporting {model_name} with hierarchy preservation (include torch.nn)...", style="bold cyan")
exporter = HTPExporter(verbose=False, include_torch_nn_children=True)
export_result = exporter.export(
    model=model,
    output_path="resnet_fixed.onnx",
    model_name_or_path=model_name
)

# Display summary
console.print("\n📊 Export Summary", style="bold green")
console.print(f"Total modules in hierarchy: {len(exporter._hierarchy_data)}")
console.print(f"Total ONNX operations: {len(exporter._tagged_nodes)}")

# Build hierarchy tree focusing on ResNetConvLayer
console.print("\n🌳 Module Hierarchy (ResNetConvLayer focus):", style="bold blue")
tree = Tree("ResNetModel", style="bold magenta")

def add_resnet_conv_layers(node, path="", hierarchy_data=None, tagged_nodes=None, depth=0, max_depth=4):
    """Add ResNetConvLayer nodes to tree with counts."""
    if depth > max_depth:
        return
    
    # Find immediate children
    immediate_children = []
    for child_path, child_info in hierarchy_data.items():
        if child_path.startswith(path + ".") and child_path != path:
            parts = child_path[len(path)+1:].split('.')
            if len(parts) == 1 or (len(parts) == 2 and parts[1].isdigit()):
                immediate_children.append((child_path, child_info))
    
    # Filter to show ResNetConvLayer and its parents
    for child_path, child_info in sorted(immediate_children):
        class_name = child_info.get('class_name', 'Unknown')
        traced_tag = child_info.get('traced_tag', '')
        
        # Count nodes with correct logic
        node_count = 0
        for tag in tagged_nodes.values():
            if tag == traced_tag or tag.startswith(traced_tag + "/"):
                node_count += 1
        
        # Show ResNetConvLayer and important parents
        if class_name in ['ResNetConvLayer', 'ResNetBottleNeckLayer', 'ResNetStage', 'ResNetEncoder']:
            # Style based on node count
            if node_count == 0:
                style = "red"
                issue = " ❌"
            elif node_count < 5:
                style = "yellow" 
                issue = " ✅"
            else:
                style = "green"
                issue = " ✅"
            
            # Create text with styled count
            text = Text(f"{class_name}: {child_path.split('.')[-1]}")
            text.append(f" ({node_count} nodes){issue}", style=style)
            
            child_node = node.add(text)
            
            # Recurse for important nodes
            if class_name != 'ResNetConvLayer':  # Don't recurse into ResNetConvLayer
                add_resnet_conv_layers(child_node, child_path, hierarchy_data, tagged_nodes, depth + 1, max_depth)

# Show only ResNetEncoder branch
if 'encoder' in exporter._hierarchy_data:
    node = tree.add("ResNetEncoder")
    add_resnet_conv_layers(node, 'encoder', exporter._hierarchy_data, exporter._tagged_nodes)

console.print(tree)

# Show the fix explanation
console.print("\n📝 Fix Explanation:", style="bold yellow")
console.print("The issue was that ONNX node names had a double 'layer' pattern:")
console.print("  ❌ ONNX: /encoder/stages.0/layers.0/layer/layer.0/convolution/Conv")
console.print("  ✅ Hierarchy: encoder.stages.0.layers.0.layer.0.convolution")
console.print("\nThe fix in _extract_scope_from_node() now handles this pattern:")
console.print("  - Detects 'layer/layer.N' patterns")
console.print("  - Converts to 'layer.N' to match hierarchy data")
console.print("  - Operations are now correctly tagged to ResNetConvLayer modules")

# Verify some specific ResNetConvLayer nodes
console.print("\n🔍 Sample ResNetConvLayer Verification:", style="bold cyan")
sample_paths = [
    "encoder.stages.0.layers.0.layer.0",
    "encoder.stages.0.layers.0.layer.1",
    "encoder.stages.0.layers.0.layer.2"
]

for path in sample_paths:
    if path in exporter._hierarchy_data:
        info = exporter._hierarchy_data[path]
        tag = info.get('traced_tag')
        node_count = sum(1 for t in exporter._tagged_nodes.values() if t == tag or t.startswith(tag + "/"))
        console.print(f"\n{path}:")
        console.print(f"  Nodes: {node_count} {'✅' if node_count > 0 else '❌'}")

## ResNet Issue Fix Demonstration

This section demonstrates the fix for the ResNet issue where `ResNetConvLayer` modules were showing 0 nodes due to ONNX node naming mismatches.
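
To make the naming mismatch concrete, here is a minimal sketch of the kind of normalization involved. The helper `normalize_onnx_scope` is hypothetical and written only for illustration; it is not the exporter's `_extract_scope_from_node`. It assumes the last segment of an ONNX node name is the operation name, and that a container segment such as `layer` immediately followed by `layer.N` should collapse into `layer.N`.

```python
import re


def normalize_onnx_scope(node_name: str) -> str:
    """Turn an ONNX node name into a dotted module path (illustrative only)."""
    parts = [p for p in node_name.split("/") if p]
    scope_parts = parts[:-1]  # Drop the trailing operation name, e.g. "Conv".

    merged = []
    for part in scope_parts:
        # Collapse "layer" + "layer.0" into a single "layer.0" segment.
        if merged and re.fullmatch(rf"{re.escape(merged[-1])}\.\d+", part):
            merged[-1] = part
        else:
            merged.append(part)
    return ".".join(merged)


onnx_name = "/encoder/stages.0/layers.0/layer/layer.0/convolution/Conv"
module_path = normalize_onnx_scope(onnx_name)
print(module_path)
```

Names without the doubled segment pass through unchanged apart from the separator swap, so the same helper can be applied to every node name.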

In [None]:
# Test with ResNet-50
print("Testing with ResNet-50...")
print("=" * 80)

# Load ResNet-50
resnet_model = AutoModel.from_pretrained("microsoft/resnet-50")
print(f"Model class: {resnet_model.__class__.__name__}")

# Generate inputs
resnet_inputs = generate_dummy_inputs("microsoft/resnet-50", exporter="onnx")
print(f"Generated inputs: {list(resnet_inputs.keys())}")

# Create tracer
resnet_tracer = TracingHierarchyBuilder()

# Trace the model
print("\nTracing ResNet-50 execution...")
resnet_tracer.trace_model_execution(resnet_model, resnet_inputs)

# Get summary
resnet_summary = resnet_tracer.get_execution_summary()
resnet_hierarchy = resnet_summary['module_hierarchy']

print(f"\nTotal modules traced: {len(resnet_hierarchy)}")

# Look for ResNetConvLayer modules
print("\nLooking for ResNetConvLayer modules:")
print("-" * 80)

conv_layers = []
for path, info in resnet_hierarchy.items():
    if info.get('class_name') == 'ResNetConvLayer':
        conv_layers.append((path, info))

print(f"Found {len(conv_layers)} ResNetConvLayer modules")

# Show first few
for path, info in conv_layers[:3]:
    print(f"\nPath: {path}")
    print(f"  Class: {info.get('class_name')}")
    print(f"  Tag: {info.get('traced_tag')}")
    
    # Look for child modules of this ResNetConvLayer
    print(f"  Looking for children of {path}:")
    child_count = 0
    for child_path, child_info in resnet_hierarchy.items():
        if child_path.startswith(path + ".") and child_path != path:
            print(f"    - {child_path}: {child_info.get('class_name')} (tag: {child_info.get('traced_tag')})")
            child_count += 1
    
    if child_count == 0:
        print(f"    ❌ No children found in hierarchy! This explains the 0 nodes issue.")
        
        # Let's check the actual module structure
        print(f"\n  Checking actual module structure of {path}:")
        module = resnet_model.get_submodule(path)
        
        print(f"  Module has the following torch.nn children:")
        for name, child in module.named_children():
            print(f"    - {name}: {type(child).__name__}")

print("\n🔴 Issue: ResNetConvLayer has torch.nn children (Conv2d, BatchNorm2d, ReLU) that can go missing")
print("    from the hierarchy. The current TracingHierarchyBuilder includes ALL modules.")
print("    However, when filtered by should_include_in_hierarchy for MUST-002 compliance,")
print("    these torch.nn modules would be excluded, causing ResNetConvLayer to show 0 nodes.")