# Module ↔ Hierarchy Mapping: Static vs Dynamic Approaches

This notebook compares the current static approach with the proposed dynamic hook-based approach for creating module-to-hierarchy mappings.

In [2]:
import sys
sys.path.append('/mnt/d/BYOM/modelexport')

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Optional
import json
from pathlib import Path

# Load BERT-tiny for demonstration
model = AutoModel.from_pretrained("prajjwal1/bert-tiny")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
inputs = tokenizer("Hello world", return_tensors="pt")

print("Setup complete!")

Setup complete!


## 🏗️ Current Approach: Static Analysis

Builds a hierarchy tag for every module from its dotted path in `named_modules()`, without running a forward pass:

In [8]:
def static_hierarchy_mapping(model: nn.Module) -> Dict[str, str]:
    """Current approach: Static module path → hierarchy tag mapping"""
    
    mapping = {}
    
    # Static analysis - no forward pass needed
    for name, module in model.named_modules():
        if name:  # Skip root
            # Parse path and generate hierarchy tag
            hierarchy_tag = generate_hierarchy_tag(name, module)
            mapping[name] = hierarchy_tag
    
    return mapping

def generate_hierarchy_tag(module_path: str, module: nn.Module) -> str:
    """Simplified tag generation for demo (`module` is unused here)"""
    
    # Simplified version of our algorithm: walk the dotted path segment by segment
    segments = module_path.split('.')
    hierarchy_parts = ["BertModel"]
    
    for segment in segments:
        if segment == "encoder":
            hierarchy_parts.append("BertEncoder")
        elif segment == "layer":
            # Skip the ModuleList container; its index adds the next level
            pass
        elif segment.isdigit():
            # Instance number
            hierarchy_parts.append(f"BertLayer.{segment}")
        elif segment == "attention":
            hierarchy_parts.append("BertAttention")
        elif segment == "embeddings":
            hierarchy_parts.append("BertEmbeddings")
        # ... etc
    
    return "/" + "/".join(hierarchy_parts)

# Demo static mapping
static_mapping = static_hierarchy_mapping(model)

print("📊 Static Mapping Examples:")
print("=" * 80)
examples = list(static_mapping.items())[:]
for path, tag in examples:
    print(f"{path:<40} → {tag}")
    
print(f"\nTotal mappings: {len(static_mapping)}")

📊 Static Mapping Examples:
embeddings                               → /BertModel/BertEmbeddings
embeddings.word_embeddings               → /BertModel/BertEmbeddings
embeddings.position_embeddings           → /BertModel/BertEmbeddings
embeddings.token_type_embeddings         → /BertModel/BertEmbeddings
embeddings.LayerNorm                     → /BertModel/BertEmbeddings
embeddings.dropout                       → /BertModel/BertEmbeddings
encoder                                  → /BertModel/BertEncoder
encoder.layer                            → /BertModel/BertEncoder
encoder.layer.0                          → /BertModel/BertEncoder/BertLayer.0
encoder.layer.0.attention                → /BertModel/BertEncoder/BertLayer.0/BertAttention
encoder.layer.0.attention.self           → /BertModel/BertEncoder/BertLayer.0/BertAttention
encoder.layer.0.attention.self.query     → /BertModel/BertEncoder/BertLayer.0/BertAttention
encoder.layer.0.attention.self.key       → /BertModel/BertEncoder/BertLay
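The `if`/`elif` chain in `generate_hierarchy_tag` needs a new branch for every module type. As a minimal sketch (assuming the same BERT-only naming rules as the demo; `SEGMENT_CLASSES`, `SKIPPED_SEGMENTS` and `tag_from_path` are hypothetical names), the same rule can be driven by a lookup table:

```python
from typing import Dict

# Hypothetical table: path segment -> hierarchy class name (BERT-only, mirrors the demo)
SEGMENT_CLASSES: Dict[str, str] = {
    "embeddings": "BertEmbeddings",
    "encoder": "BertEncoder",
    "attention": "BertAttention",
}

# Container segments (nn.ModuleList) that add no level of their own
SKIPPED_SEGMENTS = {"layer"}


def tag_from_path(module_path: str, root: str = "BertModel") -> str:
    """Map a dotted module path to a hierarchy tag using table lookups."""
    parts = [root]
    for segment in module_path.split("."):
        if segment in SKIPPED_SEGMENTS:
            continue
        if segment.isdigit():
            # Numeric segment = index into the preceding ModuleList
            parts.append(f"BertLayer.{segment}")
        elif segment in SEGMENT_CLASSES:
            parts.append(SEGMENT_CLASSES[segment])
        # Other segments (query, key, dropout, ...) add no level
    return "/" + "/".join(parts)


print(tag_from_path("encoder.layer.0.attention.self.query"))
# /BertModel/BertEncoder/BertLayer.0/BertAttention
```

Supporting another architecture then means adding table entries instead of branches, although numeric segments are still hard-coded to `BertLayer` here, as in the demo.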

## 🎣 Proposed Approach: Dynamic Hook-Based Mapping

Uses forward hooks to record which modules actually execute, and in what order; the goal is to use the same mechanism to map ONNX operations to modules during export:

In [7]:
class DynamicHierarchyMapper:
    """Hook-based dynamic hierarchy mapping inspired by HTP"""
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.hierarchy_stack = ["BertModel"]  # Start with root (not used yet in this demo)
        self.module_to_hierarchy = {}  # module_id → hierarchy_tag
        self.execution_trace = []  # Track execution order
        self.hooks = []
        
        # Create module name mapping: id(module) → dotted name (not used yet in this demo)
        self.module_names = {}
        for name, module in model.named_modules():
            self.module_names[id(module)] = name
        
        self._register_hooks()
    
    def _register_hooks(self):
        """Register forward hooks on all modules"""
        
        for name, module in self.model.named_modules():
            hook = module.register_forward_hook(self._create_hook(name, module))
            self.hooks.append(hook)
    
    def _create_hook(self, module_name: str, module: nn.Module):
        """Create a forward hook for a specific module"""
        
        def forward_hook(module, input, output):
            # Determine hierarchy tag for this module
            hierarchy_tag = self._get_current_hierarchy_tag(module_name, module)
            
            # Record mapping
            module_id = id(module)
            self.module_to_hierarchy[module_id] = hierarchy_tag
            
            # Record execution trace
            self.execution_trace.append({
                'module_name': module_name,
                'module_class': module.__class__.__name__,
                'hierarchy_tag': hierarchy_tag,
                'execution_order': len(self.execution_trace)
            })
        
        return forward_hook
    
    def _get_current_hierarchy_tag(self, module_name: str, module: nn.Module) -> str:
        """Generate hierarchy tag based on current execution context"""
        
        if not module_name:  # Root module
            return "/BertModel"
        
        # Use same logic as static approach but could be enhanced
        # with execution context information
        return generate_hierarchy_tag(module_name, module)
    
    def trace_forward_pass(self, *args, **kwargs):
        """Run forward pass and capture hierarchy mappings"""
        
        # Clear previous traces
        self.execution_trace.clear()
        self.module_to_hierarchy.clear()
        
        # Run forward pass (this triggers all hooks)
        with torch.no_grad():
            output = self.model(*args, **kwargs)
        
        return output
    
    def get_execution_mapping(self) -> Dict[str, str]:
        """Return module name → hierarchy tag for each executed module, in the order their forward hooks fired"""
        
        mapping = {}
        for trace in self.execution_trace:
            mapping[trace['module_name']] = trace['hierarchy_tag']
        
        return mapping
    
    def cleanup(self):
        """Remove all hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

# Demo dynamic mapping
dynamic_mapper = DynamicHierarchyMapper(model)

print("🎣 Dynamic Mapping with Forward Hooks:")
print("=" * 80)

# Trace forward pass
output = dynamic_mapper.trace_forward_pass(
    input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']
)
dynamic_mapping = dynamic_mapper.get_execution_mapping()

print(f"Executed {len(dynamic_mapper.execution_trace)} modules")
print("\nExecuted modules (in forward-hook order):")
for i, trace in enumerate(dynamic_mapper.execution_trace):
    print(f"{i+1:2d}. {trace['module_name']:<35} → {trace['hierarchy_tag']}")

dynamic_mapper.cleanup()
print(f"\nTotal dynamic mappings: {len(dynamic_mapping)}")

🎣 Dynamic Mapping with Forward Hooks:
Executed 45 modules

Executed modules (in forward-hook order):
 1. embeddings.word_embeddings          → /BertModel/BertEmbeddings
 2. embeddings.token_type_embeddings    → /BertModel/BertEmbeddings
 3. embeddings.position_embeddings      → /BertModel/BertEmbeddings
 4. embeddings.LayerNorm                → /BertModel/BertEmbeddings
 5. embeddings.dropout                  → /BertModel/BertEmbeddings
 6. embeddings                          → /BertModel/BertEmbeddings
 7. encoder.layer.0.attention.self.query → /BertModel/BertEncoder/BertLayer.0/BertAttention
 8. encoder.layer.0.attention.self.key  → /BertModel/BertEncoder/BertLayer.0/BertAttention
 9. encoder.layer.0.attention.self.value → /BertModel/BertEncoder/BertLayer.0/BertAttention
10. encoder.layer.0.attention.self      → /BertModel/BertEncoder/BertLayer.0/BertAttention
11. encoder.layer.0.attention.output.dense → /BertModel/BertEncoder/BertLayer.0/BertAttention
12. encoder.layer.0.attention.output.dropout →
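The trace is in post-order: a forward hook runs after its module's `forward` returns, so `embeddings.word_embeddings` is recorded before `embeddings`. The `hierarchy_stack` attribute of `DynamicHierarchyMapper` is never used. One way to populate it, sketched here on a hypothetical toy model rather than as the notebook's implementation, is to push in a forward pre-hook and pop in a forward hook, so each tag reflects the live call nesting instead of the path string:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)

    def forward(self, x):
        return self.inner(x)


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x)


toy = Toy()
stack = []     # class names of modules whose forward() is currently running
tags = {}      # module name -> tag built from the live call stack
finished = []  # module names in the order their forward() returned


def make_pre_hook(name):
    def pre_hook(module, args):
        stack.append(type(module).__name__)  # entering forward(): push
        tags[name] = "/" + "/".join(stack)
    return pre_hook


def make_post_hook(name):
    def post_hook(module, args, output):
        stack.pop()  # leaving forward(): pop
        finished.append(name)
    return post_hook


handles = []
for name, module in toy.named_modules():
    handles.append(module.register_forward_pre_hook(make_pre_hook(name)))
    handles.append(module.register_forward_hook(make_post_hook(name)))

with torch.no_grad():
    toy(torch.randn(1, 4))

for handle in handles:
    handle.remove()

print(tags)      # {'': '/Toy', 'block': '/Toy/Block', 'block.inner': '/Toy/Block/Linear'}
print(finished)  # ['block.inner', 'block', '']
```

In this sketch `tags` keeps only the latest call per module; a module called more than once would need a list of tags per name.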

## 🤔 Comparison: Static vs Dynamic

Let's compare the two approaches:

In [5]:
print("🔍 STATIC vs DYNAMIC COMPARISON")
print("=" * 80)

# Compare coverage
static_modules = set(static_mapping.keys())
dynamic_modules = set(dynamic_mapping.keys())

only_static = static_modules - dynamic_modules
only_dynamic = dynamic_modules - static_modules
both = static_modules & dynamic_modules

print(f"📊 Coverage Comparison:")
print(f"  Static only:  {len(only_static):3d} modules")
print(f"  Dynamic only: {len(only_dynamic):3d} modules")
print(f"  Both:         {len(both):3d} modules")
print(f"  Total static: {len(static_modules):3d} modules")
print(f"  Total dynamic:{len(dynamic_modules):3d} modules")

print(f"\n🔍 Modules only found in static analysis:")
for module in sorted(only_static)[:10]:
    print(f"  - {module}")
if len(only_static) > 10:
    print(f"  ... and {len(only_static) - 10} more")

print(f"\n🎯 Modules only found during execution:")
for module in sorted(only_dynamic)[:10]:
    print(f"  - {module}")
if len(only_dynamic) > 10:
    print(f"  ... and {len(only_dynamic) - 10} more")

🔍 STATIC vs DYNAMIC COMPARISON
📊 Coverage Comparison:
  Static only:    3 modules
  Dynamic only:   1 modules
  Both:          44 modules
  Total static:  47 modules
  Total dynamic: 45 modules

🔍 Modules only found in static analysis:
  - encoder.layer
  - encoder.layer.0.attention.self.dropout
  - encoder.layer.1.attention.self.dropout

🎯 Modules only found during execution:
  - 
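The two set differences have mundane causes. The single dynamic-only entry is the empty string, i.e. the root `BertModel`: hooks are registered on every module, but `static_hierarchy_mapping` skips the root with `if name:`. `encoder.layer` is an `nn.ModuleList`, which `BertEncoder` iterates over but never calls, so its hook never fires. The `attention.self.dropout` modules are also never called; a likely explanation (an assumption, not verified in this notebook) is that the attention implementation in use applies dropout inside a fused attention call rather than through the submodule. The hypothetical toy model below reproduces the root and `ModuleList` effects plus a registered-but-unused submodule:

```python
import torch
import torch.nn as nn


class TwoLayerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(2, 2) for _ in range(2)])
        self.unused = nn.Dropout(0.1)  # registered but never called in forward()

    def forward(self, x):
        for layer in self.layers:  # the ModuleList is iterated, not called
            x = layer(x)
        return x


net = TwoLayerNet()
called = set()
handles = [
    m.register_forward_hook(lambda mod, args, out, name=name: called.add(name))
    for name, m in net.named_modules()
]
with torch.no_grad():
    net(torch.randn(1, 2))
for handle in handles:
    handle.remove()

# Same root filter as static_hierarchy_mapping
static_names = {name for name, _ in net.named_modules() if name}
print(sorted(static_names - called))  # ['layers', 'unused']
print(sorted(called - static_names))  # ['']
```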


## 💡 Advantages of Each Approach

### 🏗️ Static Analysis (Current)

**Pros:**
- ✅ **Fast**: No forward pass execution needed
- ✅ **Complete coverage**: Finds ALL modules in model
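- ✅ **Deterministic**: The same model always yields the same mapping, independent of the input
- ✅ **Simple**: No hooks or execution tracing
- ✅ **Works for unused modules**: Also tags modules that never execute, such as `encoder.layer` and the `attention.self.dropout` modules in the comparison above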

**Cons:**
- ❌ **No operation mapping**: Doesn't map to actual ONNX operations
- ❌ **No execution context**: Misses dynamic behavior
- ❌ **No execution order**: Can't track execution flow

### 🎣 Dynamic Hook-Based (Proposed)

**Pros:**
- ✅ **Real execution**: Captures actual forward pass behavior
- ✅ **Operation mapping**: Can map to ONNX operations during export
- ✅ **Execution order**: Tracks the order modules are called
- ✅ **Dynamic models**: Records which branches of conditional code actually ran for the traced input
- ✅ **Precise**: Only includes actually executed modules

**Cons:**
- ❌ **Slower**: Requires forward pass execution
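- ❌ **Hook overhead**: A Python callback runs on every module call during export
- ❌ **Incomplete**: Misses modules that are never called for the traced input (3 modules in the BERT-tiny comparison above)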
- ❌ **Complex**: More implementation complexity

## 🎯 Proposed Hybrid Solution

Based on this analysis, I think we should use **both approaches**:

### Phase 1: Static Hierarchy Analysis
```python
# Get complete model structure
static_hierarchy = analyze_model_hierarchy(model)
```

### Phase 2: Dynamic Operation Mapping
```python
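# During ONNX export, use hooks to map operations.
# Sketch only: assumes DynamicHierarchyMapper gains a static_hierarchy argument,
# __enter__/__exit__, and get_operation_hierarchy_map(); none exist yet.
# Note: the TorchScript exporter (dynamo=False) returns None, not a model.
with DynamicHierarchyMapper(model, static_hierarchy) as mapper:
    onnx_model = torch.onnx.export(model, inputs, ...)
    operation_tags = mapper.get_operation_hierarchy_map()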
```
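The Phase 2 sketch uses `DynamicHierarchyMapper(model, static_hierarchy)` as a context manager, but the class defined earlier takes only `model` and has no `__enter__`/`__exit__`. A minimal sketch of that shape, using a hypothetical `HookedTagger` class and a toy `nn.Sequential` in place of BERT: hooks exist only inside the `with` block, are removed even if the traced call raises, and tags are looked up in the static mapping rather than recomputed.

```python
import torch
import torch.nn as nn
from typing import Dict, List


class HookedTagger:
    """Hypothetical context manager: forward hooks are active only inside `with`."""

    def __init__(self, model: nn.Module, static_tags: Dict[str, str]):
        self.model = model
        self.static_tags = static_tags  # module name -> tag, from the static phase
        self.executed: List[str] = []   # module names in forward-hook (post-)order
        self._handles = []

    def __enter__(self):
        for name, module in self.model.named_modules():
            self._handles.append(module.register_forward_hook(self._make_hook(name)))
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always detach hooks, even if the traced call raised
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False  # do not suppress exceptions

    def _make_hook(self, name: str):
        def hook(module, args, output):
            self.executed.append(name)
        return hook

    def executed_tags(self) -> Dict[str, str]:
        """Static-phase tags for the modules that actually ran."""
        return {n: self.static_tags[n] for n in self.executed if n in self.static_tags}


seq_model = nn.Sequential(nn.Linear(3, 3), nn.ReLU())
seq_tags = {"": "/Sequential", "0": "/Sequential/Linear", "1": "/Sequential/ReLU"}

with HookedTagger(seq_model, seq_tags) as tagger:
    with torch.no_grad():
        seq_model(torch.randn(1, 3))  # stand-in for the export call

print(tagger.executed)  # ['0', '1', '']
```

Whether module hooks fire during a real export depends on how the chosen exporter runs the model, so that step is an assumption to verify.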

### Benefits:
- ✅ **Complete coverage** from static analysis
- ✅ **Precise operation tagging** from dynamic hooks
- ✅ **Best of both worlds**
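The "best of both worlds" claim amounts to a merge keyed by module name: static analysis supplies every name and its tag, and the execution trace adds whether and when each one ran. A minimal stdlib sketch; `merge_mappings` and `ModuleRecord` are hypothetical names, and the small inputs mimic `static_mapping` and `dynamic_mapper.execution_trace` from the cells above:

```python
from typing import Dict, List, Optional, TypedDict


class ModuleRecord(TypedDict):
    hierarchy_tag: str
    executed: bool
    execution_order: Optional[int]


def merge_mappings(static_map: Dict[str, str], trace: List[dict]) -> Dict[str, ModuleRecord]:
    """Combine full static coverage with the order observed during execution."""
    order = {t["module_name"]: t["execution_order"] for t in trace}
    merged: Dict[str, ModuleRecord] = {}
    for name, tag in static_map.items():
        merged[name] = {
            "hierarchy_tag": tag,
            "executed": name in order,
            "execution_order": order.get(name),  # None if the module never ran
        }
    return merged


static_example = {
    "encoder.layer": "/BertModel/BertEncoder",
    "encoder.layer.0.attention.self.query": "/BertModel/BertEncoder/BertLayer.0/BertAttention",
}
# execution_order is 0-based, as in DynamicHierarchyMapper
trace_example = [{"module_name": "encoder.layer.0.attention.self.query", "execution_order": 6}]

for name, rec in merge_mappings(static_example, trace_example).items():
    print(name, rec["executed"], rec["execution_order"])
# encoder.layer False None
# encoder.layer.0.attention.self.query True 6
```

Modules that appear only in the trace (such as the root `""`) are dropped by this sketch; a fuller version could add them.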

In [6]:
print("💡 RECOMMENDATION")
print("=" * 80)
print()
print("Use HYBRID approach:")
print()
print("1. 🏗️  Static Analysis Phase:")
print("   - Build complete module hierarchy mapping")
print("   - Generate hierarchy tags for all modules")
print("   - Handle torch.nn filtering")
print()
print("2. 🎣  Dynamic Hook Phase (during ONNX export):")
print("   - Use forward hooks to capture execution")
print("   - Map ONNX operations to executing modules")
print("   - Apply hierarchy tags from static analysis")
print()
print("This gives us:")
print("✅ Complete model understanding (static)")
print("✅ Precise operation mapping (dynamic)")
print("✅ Both structure and execution info")

💡 RECOMMENDATION

Use HYBRID approach:

1. 🏗️  Static Analysis Phase:
   - Build complete module hierarchy mapping
   - Generate hierarchy tags for all modules
   - Handle torch.nn filtering

2. 🎣  Dynamic Hook Phase (during ONNX export):
   - Use forward hooks to capture execution
   - Map ONNX operations to executing modules
   - Apply hierarchy tags from static analysis

This gives us:
✅ Complete model understanding (static)
✅ Precise operation mapping (dynamic)
✅ Both structure and execution info
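For the "Map ONNX operations to executing modules" step, node names can complement the hooks. A stdlib sketch, assuming exported nodes carry a `/`-separated module scope followed by the op type (for example `/encoder/layer.0/attention/self/query/MatMul`), the pattern recent TorchScript-based `torch.onnx.export` runs typically produce; check it against your own export before relying on it. `module_path_from_node_name` and `tag_for_node` are hypothetical helpers that reuse the static tags:

```python
from typing import Dict, Optional


def module_path_from_node_name(node_name: str) -> Optional[str]:
    """'/encoder/layer.0/attention/self/query/MatMul' -> 'encoder.layer.0.attention.self.query'.

    Assumes '/'-separated module scopes followed by the op name.
    Returns None when the node has no module scope (e.g. '/Shape').
    """
    scopes = [s for s in node_name.split("/") if s]
    if len(scopes) < 2:
        return None
    return ".".join(scopes[:-1])


def tag_for_node(node_name: str, static_tags: Dict[str, str], root_tag: str = "/BertModel") -> str:
    """Look up the static tag, walking up to the nearest tagged parent module."""
    path = module_path_from_node_name(node_name)
    while path:
        if path in static_tags:
            return static_tags[path]
        path = path.rpartition(".")[0]  # drop the last segment and try the parent
    return root_tag


example_tags = {
    "encoder.layer.0.attention.self.query": "/BertModel/BertEncoder/BertLayer.0/BertAttention",
    "encoder.layer.0.attention": "/BertModel/BertEncoder/BertLayer.0/BertAttention",
}
print(tag_for_node("/encoder/layer.0/attention/self/Softmax", example_tags))
# /BertModel/BertEncoder/BertLayer.0/BertAttention
print(tag_for_node("/Shape", example_tags))
# /BertModel
```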
