# HF Module Hierarchy Exploration

This notebook explores the sub-module hierarchy of any HuggingFace model and builds semantic hierarchy mappings using both static analysis and dynamic tracing approaches.

In [1]:
import torch
from transformers import AutoModel

# Configuration: choose the HuggingFace checkpoint to explore.
# Uncomment one of the alternatives below to try a different architecture.
MODEL_NAME = "prajjwal1/bert-tiny"
# MODEL_NAME = "google/vit-base-patch16-224"
# MODEL_NAME = "facebook/sam-vit-base"

# Instantiate the pretrained model for MODEL_NAME
print(f"Loading {MODEL_NAME}...")
model = AutoModel.from_pretrained(MODEL_NAME)

# Report both the fully-qualified type and the bare class name
print(f"Model type: {type(model)}")
print(f"Model class: {type(model).__name__}")

Loading prajjwal1/bert-tiny...
Model type: <class 'transformers.models.bert.modeling_bert.BertModel'>
Model class: BertModel


## Static HF Module Hierarchy Builder

Build a semantic hierarchy from HF class names by statically walking the module tree, mapping each `named_modules()` path to its hierarchy tag (non-HF modules inherit their parent's tag).

In [None]:
def is_hf_class(module):
    """Return True if ``module``'s class is defined in a HuggingFace package.

    The top-level package of the class's ``__module__`` must be exactly
    ``transformers`` or ``transformers_modules`` (the package transformers uses
    for classes loaded with ``trust_remote_code=True``). A plain prefix test
    would also accept unrelated packages whose names merely start with
    "transformers" (e.g. ``transformers_stream_generator``).

    Args:
        module: Any object; typically a ``torch.nn.Module`` instance.

    Returns:
        bool
    """
    # ``__module__`` can in principle be None; treat that as "not HF".
    module_path = module.__class__.__module__ or ""
    return module_path.partition('.')[0] in ("transformers", "transformers_modules")

def build_hf_hierarchy_mapping(model):
    """Statically map every sub-module path to a HuggingFace class hierarchy tag.

    Walks the module tree via ``named_children()``. A module for which
    ``is_hf_class`` returns True appends ``/<ClassName>`` to the tag inherited
    from its parent; if the module's own attribute name is numeric (e.g. the
    ``0`` in ``encoder.layer.0``) it appends ``/<ClassName>.<index>`` instead.
    Any other module (e.g. a ``torch.nn`` layer) adds no level and is mapped to
    its parent's tag.

    Args:
        model: Root ``torch.nn.Module``. The root itself is not a key in the
            result, but when it is an HF class its name prefixes every tag.

    Returns:
        dict mapping dotted module paths (as in ``named_modules()``) to tags
        such as ``"/BertModel/BertEncoder/BertLayer.0"``.
    """
    hierarchy_mapping = {}

    def recursive_build(module, parent_tag, module_name):
        """Tag ``module`` (path ``module_name``) and recurse into its children."""
        tag = parent_tag
        if is_hf_class(module):
            tag = f"{parent_tag}/{module.__class__.__name__}"
            # Only a numeric attribute name (e.g. an nn.ModuleList entry)
            # is encoded as an index on the class name.
            index = module_name.rpartition('.')[2]
            if index.isdigit():
                tag = f"{tag}.{index}"

        if module_name:  # the root ("") is not part of the mapping
            hierarchy_mapping[module_name] = tag

        for child_name, child_module in module.named_children():
            child_full_name = f"{module_name}.{child_name}" if module_name else child_name
            recursive_build(child_module, tag, child_full_name)

    recursive_build(model, "", "")
    return hierarchy_mapping

In [None]:
# Build the static (structure-only) mapping for the loaded model
print("Building static HF hierarchy mapping...")
static_hierarchy_mapping = build_hf_hierarchy_mapping(model)

print(f"\nFound {len(static_hierarchy_mapping)} module mappings:")
print("=" * 70)

# One line per module, ordered by dotted module path
for path in sorted(static_hierarchy_mapping):
    print(f"{path:40} -> {static_hierarchy_mapping[path]}")

## Tracing-Based HF Module Hierarchy Builder

This approach uses forward hooks to trace actual execution flow and build more accurate hierarchy mappings.

In [2]:
class TracingHierarchyBuilder:
    """Tracing-based HF hierarchy builder using forward hooks.

    During one forward pass, pre-/post-forward hooks push and pop hierarchy
    tags on ``tag_stack``. Each hooked module is therefore tagged with the
    chain of hooked modules that were executing when it was entered, e.g.
    ``/BertModel/BertEncoder/BertLayer.0/BertAttention``.

    Only modules for which ``should_create_hierarchy_level`` is true are hooked:
    HuggingFace classes plus the torch.nn class names listed in
    ``HIERARCHY_TORCH_NN_CLASSES``.

    State:
        tag_stack: tags of the modules currently executing; index 0 is the
            root model's tag.
        execution_trace: ordered ``enter``/``exit`` event dicts (reset at the
            start of each ``trace_model_execution`` call).
        operation_context: module path -> tag and metadata recorded on that
            module's most recent entry (reset per trace as well).
        hooks: hook handles currently registered (empty outside a trace).
    """

    # torch.nn class names that get their own hierarchy level in addition to
    # HuggingFace classes. Override on a subclass or instance to customize.
    HIERARCHY_TORCH_NN_CLASSES = frozenset({'LayerNorm', 'Embedding'})

    # Parent attribute names whose numeric children (e.g. ``encoder.layer.0``)
    # get their index appended to the class name (``BertLayer.0``).
    INDEXED_CONTAINER_NAMES = frozenset({'layer', 'layers', 'block', 'blocks', 'h'})

    # Top-level packages whose classes count as HuggingFace classes: the
    # library itself and the package transformers uses for classes loaded
    # with ``trust_remote_code=True``.
    HF_PACKAGES = frozenset({'transformers', 'transformers_modules'})

    def __init__(self):
        self.tag_stack = []
        self.execution_trace = []
        self.operation_context = {}
        self.hooks = []

    def is_hf_class(self, module):
        """Return True if ``module``'s class is defined in a HuggingFace package.

        The top-level package is compared exactly, so unrelated packages whose
        names merely start with "transformers" are not matched.
        """
        module_path = module.__class__.__module__ or ""
        return module_path.partition('.')[0] in self.HF_PACKAGES

    def should_create_hierarchy_level(self, module):
        """Return True if ``module`` should be hooked and get its own tag level."""
        return (self.is_hf_class(module)
                or module.__class__.__name__ in self.HIERARCHY_TORCH_NN_CLASSES)

    def extract_module_info(self, module_name: str, module):
        """Collect the naming metadata the hooks need for one module.

        A module is "indexed" when its last path component is numeric and the
        component before it is in ``INDEXED_CONTAINER_NAMES``; the hooks then
        append that index to its class name.

        Returns:
            dict with keys ``class_name``, ``module_index`` (str or None),
            ``full_name``, ``is_indexed`` and ``name_parts``.
        """
        name_parts = module_name.split(".")

        module_index = None
        if (len(name_parts) >= 2
                and name_parts[-1].isdigit()
                and name_parts[-2] in self.INDEXED_CONTAINER_NAMES):
            module_index = name_parts[-1]

        return {
            'class_name': module.__class__.__name__,
            'module_index': module_index,
            'full_name': module_name,
            'is_indexed': module_index is not None,
            'name_parts': name_parts,
        }

    def create_pre_hook(self, module_info):
        """Create a pre-forward hook that pushes this module's tag onto the stack.

        The tag is the current stack top (the executing parent) plus
        ``/<ClassName>`` or ``/<ClassName>.<index>``. An ``enter`` event is
        appended to ``execution_trace`` and ``operation_context`` is updated.
        """
        def pre_hook(module, inputs):
            parent_tag = self.tag_stack[-1] if self.tag_stack else ""

            if module_info['is_indexed']:
                current_class_name = f"{module_info['class_name']}.{module_info['module_index']}"
            else:
                current_class_name = module_info['class_name']

            hierarchical_tag = f"{parent_tag}/{current_class_name}"
            self.tag_stack.append(hierarchical_tag)

            # stack_depth is recorded after the push, so it includes the root tag
            self.execution_trace.append({
                'module_name': module_info['full_name'],
                'tag': hierarchical_tag,
                'action': 'enter',
                'stack_depth': len(self.tag_stack),
                'execution_order': len(self.execution_trace)
            })

            # Re-entering a module overwrites its previous context entry
            self.operation_context[module_info['full_name']] = {
                "tag": hierarchical_tag,
                "module_class": module_info['class_name'],
                "creates_hierarchy": True,
                "stack_depth": len(self.tag_stack),
                "execution_order": len(self.execution_trace) - 1,
                "module_info": module_info
            }

        return pre_hook

    def create_post_hook(self, module_info):
        """Create a post-forward hook that records an ``exit`` event and pops the tag."""
        def post_hook(module, inputs, outputs):
            # stack_depth is recorded before the pop, matching the enter event
            self.execution_trace.append({
                'module_name': module_info['full_name'],
                'tag': self.tag_stack[-1] if self.tag_stack else "",
                'action': 'exit',
                'stack_depth': len(self.tag_stack),
                'execution_order': len(self.execution_trace)
            })

            if self.tag_stack:
                self.tag_stack.pop()

        return post_hook

    def register_hooks(self, model):
        """Register pre/post forward hooks on every module that creates a level.

        Resets ``tag_stack`` to the root model's tag; the root itself is not
        hooked. Handles are stored in ``self.hooks``.
        """
        self.tag_stack = [f"/{model.__class__.__name__}"]

        for name, module in model.named_modules():
            if not name or not self.should_create_hierarchy_level(module):
                continue
            module_info = self.extract_module_info(name, module)
            self.hooks.append(module.register_forward_pre_hook(self.create_pre_hook(module_info)))
            self.hooks.append(module.register_forward_hook(self.create_post_hook(module_info)))

    def remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def trace_model_execution(self, model, example_inputs):
        """Run one no-grad forward pass of ``model`` with tracing hooks installed.

        Args:
            model: The ``torch.nn.Module`` to trace. It runs in eval mode; every
                submodule's previous train/eval flag is restored afterwards.
            example_inputs: Positional inputs as a list/tuple, keyword inputs
                as a mapping (e.g. a dict of named tensors), or a single tensor.

        Results of any previous trace are discarded. Hooks are always removed,
        even if hook registration or the forward pass raises.
        """
        from collections.abc import Mapping

        self.execution_trace = []
        self.operation_context = {}

        if isinstance(example_inputs, Mapping):
            args, kwargs = (), dict(example_inputs)
        elif isinstance(example_inputs, torch.Tensor):
            # Unpacking a bare tensor with * would split it along dim 0
            args, kwargs = (example_inputs,), {}
        else:
            args, kwargs = tuple(example_inputs), {}

        previous_modes = [(submodule, submodule.training) for submodule in model.modules()]
        try:
            self.register_hooks(model)
            model.eval()
            with torch.no_grad():
                model(*args, **kwargs)
        finally:
            self.remove_hooks()
            for submodule, was_training in previous_modes:
                submodule.training = was_training

    def get_hierarchy_mapping(self):
        """Return a dict mapping each traced module path to its hierarchy tag."""
        return {module_name: context['tag']
                for module_name, context in self.operation_context.items()}

    def get_execution_summary(self):
        """Return counts, the max stack depth and the hierarchy mapping of the trace."""
        return {
            'total_modules_traced': len(self.operation_context),
            'execution_steps': len(self.execution_trace),
            'max_stack_depth': max((t['stack_depth'] for t in self.execution_trace), default=0),
            'hierarchy_mapping': self.get_hierarchy_mapping()
        }

### Test Tracing-Based Hierarchy Mapping

Now let's test the tracing approach and compare it with our static approach.

In [3]:
# Prepare example inputs for tracing
from modelexport.core.model_input_generator import generate_dummy_inputs_from_model_path

# Simple usage: the generator picks default inputs for this model type (via Optimum)
inputs = generate_dummy_inputs_from_model_path(MODEL_NAME)

print("Example inputs prepared:")
print(f"Generated inputs: {[*inputs.keys()]}")

# NOTE(review): the tracer passes these positionally (model(*example_inputs)),
# so this relies on the key order matching the order of model.forward's
# parameters — confirm when switching MODEL_NAME.
example_inputs = [*inputs.values()]

Example inputs prepared:
Generated inputs: ['input_ids', 'attention_mask', 'token_type_ids']


In [4]:
# Create the tracing hierarchy builder
print("\nCreating tracing hierarchy builder...")
tracer = TracingHierarchyBuilder()

# Run one forward pass with hooks installed (they are removed afterwards)
print("\nTracing model execution...")
tracer.trace_model_execution(model, example_inputs)

# Collect results
traced_mapping = tracer.get_hierarchy_mapping()
execution_summary = tracer.get_execution_summary()

print("\nTracing completed!")
# The full mapping is printed by the next cell; show only the scalar stats
# here so the output is not one huge, truncated dict.
print("Execution summary:")
for summary_key, summary_value in execution_summary.items():
    if summary_key != 'hierarchy_mapping':
        print(f"  {summary_key}: {summary_value}")


Creating tracing hierarchy builder...

Tracing model execution...

Tracing completed!
Execution summary: {'total_modules_traced': 25, 'execution_steps': 50, 'max_stack_depth': 6, 'hierarchy_mapping': {'embeddings': '/BertModel/BertEmbeddings', 'embeddings.word_embeddings': '/BertModel/BertEmbeddings/Embedding', 'embeddings.token_type_embeddings': '/BertModel/BertEmbeddings/Embedding', 'embeddings.position_embeddings': '/BertModel/BertEmbeddings/Embedding', 'embeddings.LayerNorm': '/BertModel/BertEmbeddings/LayerNorm', 'encoder': '/BertModel/BertEncoder', 'encoder.layer.0': '/BertModel/BertEncoder/BertLayer.0', 'encoder.layer.0.attention': '/BertModel/BertEncoder/BertLayer.0/BertAttention', 'encoder.layer.0.attention.self': '/BertModel/BertEncoder/BertLayer.0/BertAttention/BertSdpaSelfAttention', 'encoder.layer.0.attention.output': '/BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput', 'encoder.layer.0.attention.output.LayerNorm': '/BertModel/BertEncoder/BertLayer.0/BertAtt

In [5]:
# Print every traced module with its tag, ordered by dotted module path
print("Traced Hierarchy Mapping:")
print("=" * 70)

for path in sorted(traced_mapping):
    print(f"{path:40} -> {traced_mapping[path]}")

traced_count = len(traced_mapping)
print(f"\nFound {traced_count} modules with traced hierarchy tags")

Traced Hierarchy Mapping:
embeddings                               -> /BertModel/BertEmbeddings
embeddings.LayerNorm                     -> /BertModel/BertEmbeddings/LayerNorm
embeddings.position_embeddings           -> /BertModel/BertEmbeddings/Embedding
embeddings.token_type_embeddings         -> /BertModel/BertEmbeddings/Embedding
embeddings.word_embeddings               -> /BertModel/BertEmbeddings/Embedding
encoder                                  -> /BertModel/BertEncoder
encoder.layer.0                          -> /BertModel/BertEncoder/BertLayer.0
encoder.layer.0.attention                -> /BertModel/BertEncoder/BertLayer.0/BertAttention
encoder.layer.0.attention.output         -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput
encoder.layer.0.attention.output.LayerNorm -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSelfOutput/LayerNorm
encoder.layer.0.attention.self           -> /BertModel/BertEncoder/BertLayer.0/BertAttention/BertSdpaSelfAttention
en

### Compare Static vs Tracing Approaches

Let's compare the static and tracing-based hierarchy mappings to see the differences.

In [None]:
def compare_hierarchy_mappings(static_mapping, traced_mapping, max_examples=5):
    """Print and return a comparison of static and traced hierarchy mappings.

    Args:
        static_mapping: dict of module path -> tag from the static builder.
        traced_mapping: dict of module path -> tag from the tracing builder.
        max_examples: how many example entries to print per section
            (differences, traced-only, static-only). Defaults to 5.

    Returns:
        dict with ``static_count``, ``traced_count``, ``common_count``,
        ``tag_matches``, ``tag_differences`` (a count), ``differences`` (list
        of ``{'module', 'static', 'traced'}`` dicts sorted by module path),
        and ``only_static`` / ``only_traced`` (sorted module paths).
    """
    static_modules = set(static_mapping)
    traced_modules = set(traced_mapping)
    common_modules = static_modules & traced_modules
    only_static = sorted(static_modules - traced_modules)
    only_traced = sorted(traced_modules - static_modules)

    print("Comparison Summary:")
    print("=" * 50)
    print(f"Static mapping modules: {len(static_modules)}")
    print(f"Traced mapping modules: {len(traced_modules)}")
    print(f"Common modules: {len(common_modules)}")
    print(f"Only in static: {len(only_static)}")
    print(f"Only in traced: {len(only_traced)}")

    # Iterate in sorted order so the printed examples are stable across runs
    # (iteration order of a set of strings varies with hash randomization).
    tag_differences = [
        {'module': module, 'static': static_mapping[module], 'traced': traced_mapping[module]}
        for module in sorted(common_modules)
        if static_mapping[module] != traced_mapping[module]
    ]
    tag_matches = len(common_modules) - len(tag_differences)

    print("\nTag Comparison:")
    print(f"Matching tags: {tag_matches}/{len(common_modules)}")
    print(f"Different tags: {len(tag_differences)}")

    if tag_differences:
        print("\nFirst few tag differences:")
        for diff in tag_differences[:max_examples]:
            print(f"  {diff['module']:30}")
            print(f"    Static:  {diff['static']}")
            print(f"    Traced:  {diff['traced']}")

        if len(tag_differences) > max_examples:
            print(f"    ... and {len(tag_differences) - max_examples} more differences")

    # Modules only in the traced mapping (key insight: execution-only)
    if only_traced:
        print("\nModules only in traced mapping (execution-only):")
        for module in only_traced[:max_examples]:
            print(f"  {module:30} -> {traced_mapping[module]}")
        if len(only_traced) > max_examples:
            print(f"    ... and {len(only_traced) - max_examples} more")

    # Modules only in the static mapping (e.g. not hooked, or never executed)
    if only_static:
        print("\nModules only in static mapping:")
        for module in only_static[:max_examples]:
            print(f"  {module:30} -> {static_mapping[module]}")
        if len(only_static) > max_examples:
            print(f"    ... and {len(only_static) - max_examples} more")

    return {
        'static_count': len(static_modules),
        'traced_count': len(traced_modules),
        'common_count': len(common_modules),
        'tag_matches': tag_matches,
        'tag_differences': len(tag_differences),
        'differences': tag_differences,
        'only_static': only_static,
        'only_traced': only_traced,
    }

# Compare the two mappings; the returned summary dict is kept for inspection
comparison = compare_hierarchy_mappings(
    static_mapping=static_hierarchy_mapping,
    traced_mapping=traced_mapping,
)

### Execution Trace Analysis

Let's analyze the execution trace to understand the call flow.

In [None]:
# Walk through the recorded enter/exit events of the trace
print("Execution Trace Analysis:")
print("=" * 50)

# stack_depth counts the root tag, so top-level hooked modules (depth 2)
# get one indentation step.
print("First 10 execution steps:")
for step_no, event in enumerate(tracer.execution_trace[:10]):
    arrow = "→" if event['action'] == 'enter' else "←"
    pad = "  " * (event['stack_depth'] - 1)
    print(f"{step_no:2d}: {pad}{arrow} {event['module_name']}")

if len(tracer.execution_trace) > 10:
    print(f"... and {len(tracer.execution_trace) - 10} more steps")

# Modules ordered by the trace position stored in operation_context
print("\nModule execution order:")
execution_order = sorted(
    tracer.operation_context.items(),
    key=lambda item: item[1]['execution_order'],
)
for module_name, context in execution_order:
    print(f"{context['execution_order']:2d}: {module_name:30} -> {context['tag']}")

print(f"\nMax stack depth reached: {execution_summary['max_stack_depth']}")

### Key Insights

**Static vs Tracing Comparison:**

1. **Static Approach**: Fast, comprehensive, covers all modules in model structure
2. **Tracing Approach**: Accurate, execution-based, includes important torch.nn modules like LayerNorm
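3. **Key Difference**: Tracing captures only the modules it hooks (HF classes plus `LayerNorm`/`Embedding`) that actually execute, and gives those torch.nn modules their own hierarchy level. The static builder maps every sub-module, but non-HF modules simply inherit their parent's tag. The two also index repeated layers differently: static appends an index to any HF module with a numeric name, while tracing does so only under `layer`, `layers`, `block`, `blocks` or `h` containers.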

**When to Use:**
- **Static**: When you need fast analysis of model structure
- **Tracing**: When you need accurate operation-to-module mapping for ONNX export
- **Both**: For validation and comprehensive coverage

**Tracing Advantages:**
- Captures actual execution flow
- Includes important torch.nn modules (LayerNorm, Embedding)
- Provides execution order (enter/exit events with stack depth); no timing is recorded
- Better for dynamic module behavior

**Test with different models by changing MODEL_NAME above!**