# TorchScript Intermediate State Exploration

This notebook explores what information is available in the TorchScript graph before ONNX conversion,
specifically the scope information that carries the module hierarchy.

## Key Research Questions:
1. What scope information is available at the TorchScript level?
2. How can we access `node.scopeName()` properly?
3. Where exactly is context lost during ONNX conversion?

In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import json
from typing import Dict, List, Any
import os

# Create temp directory for outputs
os.makedirs("../temp", exist_ok=True)

print("✅ Imports successful")

✅ Imports successful


## 1. Load Model and Prepare Inputs

In [2]:
# Load a small model for exploration
model_name = "prajjwal1/bert-tiny"
print(f"Loading model: {model_name}")

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prepare inputs
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# Set model to eval mode
model.eval()

print(f"Model type: {type(model).__name__}")
print(f"Input shapes: {[(k, v.shape) for k, v in inputs.items()]}")
print("✅ Model loaded successfully")

Loading model: prajjwal1/bert-tiny


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Model type: BertModel
Input shapes: [('input_ids', torch.Size([1, 4])), ('token_type_ids', torch.Size([1, 4])), ('attention_mask', torch.Size([1, 4]))]
✅ Model loaded successfully


## 2. Attempt Different Tracing Approaches

We'll try multiple methods to trace the model and see which one preserves scope information.

In [3]:
# Try multiple tracing approaches
traced_model = None
tracing_method = None

print("Trying different tracing approaches...\n")

# Approach 1: Direct tracing with strict=False (to handle dict output)
try:
    print("1. Trying direct tracing with strict=False...")
    with torch.no_grad():
        traced_model = torch.jit.trace(
            model, 
            inputs, 
            strict=False,
            check_trace=False
        )
    print("  ✅ Direct tracing successful")
    tracing_method = "direct_trace"
except Exception as e:
    print(f"  ❌ Direct tracing failed: {e}")
    
    # Approach 2: Try with positional args
    try:
        print("\n2. Trying positional args tracing...")
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model, 
                (inputs['input_ids'], inputs['attention_mask']),
                strict=False,
                check_trace=False
            )
        print("  ✅ Positional args tracing successful")
        tracing_method = "positional_trace"
    except Exception as e:
        print(f"  ❌ Positional args tracing failed: {e}")
        
        # Approach 3: Wrapper approach
        print("\n3. Using wrapper as fallback...")
        class ModelWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
            
            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.last_hidden_state
        
        wrapped_model = ModelWrapper(model)
        with torch.no_grad():
            traced_model = torch.jit.trace(wrapped_model, (inputs['input_ids'], inputs['attention_mask']))
        print("  ✅ Wrapper tracing successful")
        tracing_method = "wrapper_trace"

print(f"\n🎯 Using tracing method: {tracing_method}")
print(f"Traced model type: {type(traced_model)}")
print(f"Has graph: {hasattr(traced_model, 'graph')}")

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Trying different tracing approaches...

1. Trying direct tracing with strict=False...
  ❌ Direct tracing failed: Type 'Tuple[str, str, str]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

2. Trying positional args tracing...
  ✅ Positional args tracing successful

🎯 Using tracing method: positional_trace
Traced model type: <class 'torch.jit._trace.TopLevelTracedModule'>
Has graph: True
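
The `Tuple[str, str, str]` error from approach 1 is consistent with how a dict-like object behaves when it is converted to a tuple: iteration yields its keys, not its tensors. The sketch below reproduces that behaviour with a plain `collections.UserDict`. It assumes, without confirmation from this notebook, that the tokenizer output is a `UserDict` subclass rather than a real `dict`, and that the tracer converts non-tuple example inputs with `tuple(...)`. `FakeBatch` is a hypothetical name and the nested lists are placeholders for tensors.

```python
from collections import UserDict


class FakeBatch(UserDict):
    """Hypothetical stand-in for the tokenizer output (a dict-like mapping)."""


# Placeholder lists stand in for the real tensors
batch = FakeBatch(
    input_ids=[[0, 1, 2, 3]],
    token_type_ids=[[0, 0, 0, 0]],
    attention_mask=[[1, 1, 1, 1]],
)

# Iterating a mapping yields its keys, so tuple(batch) is a tuple of strings
as_tuple = tuple(batch)
print(as_tuple)

# Picking the values explicitly, in a fixed order, gives a tuple of inputs
positional = (batch["input_ids"], batch["attention_mask"])
print(len(positional))
```

This is why approach 2, which passes the two tensors positionally, gets past the input check.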


## 3. Access TorchScript Graph and Explore Node Information

Now let's examine the graph structure and try to extract scope information.

In [4]:
# Access the TorchScript graph
print("Accessing TorchScript graph...\n")
graph = traced_model.graph
nodes = list(graph.nodes())

print(f"Total nodes in graph: {len(nodes)}")
print(f"Graph type: {type(graph)}")

# Show graph string representation (this might contain scope info)
print("\n=== Graph String Representation (first 2000 chars) ===")
graph_str = str(graph)
print(graph_str[:2000])
if len(graph_str) > 2000:
    print(f"... (truncated, total length: {len(graph_str)} chars)")

Accessing TorchScript graph...

Total nodes in graph: 74
Graph type: <class 'torch.Graph'>

=== Graph String Representation (first 2000 chars) ===
graph(%self.1 : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_78.BertModel,
      %input_ids : Long(1, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %attention_mask.1 : Long(1, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %pooler : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_77.BertPooler = prim::GetAttr[name="pooler"](%self.1)
  %encoder : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_74.BertEncoder = prim::GetAttr[name="encoder"](%self.1)
  %embeddings : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_36.BertEmbeddings = prim::GetAttr[name="embeddings"](%self.1)
  %embeddings.3 : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_36.BertEmbeddings = prim::GetAttr[name="embeddings"](%self.1)
  %token_type_ids : Tensor = prim::GetAttr[na
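
Even in the truncated graph string above, two things carry hierarchy: each `prim::GetAttr[name=...]` node names the submodule being accessed, and each type annotation contains the Python class path with a `___torch_mangle_N` segment inserted. A minimal sketch of pulling both out with regular expressions follows. The sample lines are copied from the output above; the helper names `demangle` and `get_attr_edges` are hypothetical, and the patterns are only known to fit the lines shown here.

```python
import re

# Two lines copied from the graph string printed above
GRAPH_TEXT = '''
  %pooler : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_77.BertPooler = prim::GetAttr[name="pooler"](%self.1)
  %encoder : __torch__.transformers.models.bert.modeling_bert.___torch_mangle_74.BertEncoder = prim::GetAttr[name="encoder"](%self.1)
'''

MANGLE_RE = re.compile(r"___torch_mangle_\d+\.")
GETATTR_RE = re.compile(
    r'%(?P<out>[\w.]+) : (?P<type>\S+) = '
    r'prim::GetAttr\[name="(?P<attr>[^"]+)"\]\(%(?P<owner>[\w.]+)\)'
)


def demangle(type_name: str) -> str:
    """Drop the __torch__ prefix and the ___torch_mangle_N segment."""
    name = MANGLE_RE.sub("", type_name)
    prefix = "__torch__."
    return name[len(prefix):] if name.startswith(prefix) else name


def get_attr_edges(graph_text: str):
    """Yield (owner value, attribute name, demangled type) per GetAttr node."""
    for m in GETATTR_RE.finditer(graph_text):
        yield m["owner"], m["attr"], demangle(m["type"])


edges = list(get_attr_edges(GRAPH_TEXT))
for owner, attr, cls in edges:
    print(f"%{owner}.{attr} -> {cls}")
```

Following the `owner` values back to `%self.1` would give the attribute chain for nested submodules.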

## 4. Deep Node Analysis

Let's examine individual nodes and try different methods to extract scope information.

In [5]:
# Explore node information in detail
print("Exploring node information...\n")

node_info = []
scope_hierarchy = set()

for i, node in enumerate(nodes[:10]):  # First 10 nodes for detailed analysis
    try:
        print(f"=== Node {i} ===")
        
        # Basic node information
        info = {
            "index": i,
            "kind": str(node.kind()),
            "has_scope": hasattr(node, 'scopeName'),
        }
        
        # Try different ways to get scope information
        scope_methods = {
            "scopeName()": lambda n: str(n.scopeName()) if hasattr(n, 'scopeName') else "No method",
            "scope()": lambda n: str(n.scope()) if hasattr(n, 'scope') else "No method", 
            "sourceRange()": lambda n: str(n.sourceRange()) if hasattr(n, 'sourceRange') else "No method",
            "debugName()": lambda n: str(n.debugName()) if hasattr(n, 'debugName') else "No method",
        }
        
        for method_name, method_func in scope_methods.items():
            try:
                result = method_func(node)
                info[method_name] = result
                print(f"  {method_name}: {result}")
            except Exception as e:
                info[method_name] = f"Error: {e}"
                print(f"  {method_name}: Error - {e}")
        
        # Try to get all available methods/attributes
        available_methods = [attr for attr in dir(node) if not attr.startswith('_')]
        info["available_methods"] = available_methods[:10]  # First 10 to avoid clutter
        
        print(f"  Available methods (first 10): {info['available_methods']}")
        
        node_info.append(info)
        print()
        
    except Exception as e:
        print(f"Error processing node {i}: {e}")
        node_info.append({"index": i, "error": str(e)})

print(f"✅ Analyzed {len(node_info)} nodes")

Exploring node information...

=== Node 0 ===
  scopeName(): 
  scope(): No method
  sourceRange(): 
  debugName(): No method
  Available methods (first 10): ['addBlock', 'addInput', 'addOutput', 'attributeNames', 'blocks', 'c', 'c_', 'cconv', 'copyAttributes', 'copyMetadata']

=== Node 1 ===
  scopeName(): 
  scope(): No method
  sourceRange(): 
  debugName(): No method
  Available methods (first 10): ['addBlock', 'addInput', 'addOutput', 'attributeNames', 'blocks', 'c', 'c_', 'cconv', 'copyAttributes', 'copyMetadata']

=== Node 2 ===
  scopeName(): 
  scope(): No method
  sourceRange(): 
  debugName(): No method
  Available methods (first 10): ['addBlock', 'addInput', 'addOutput', 'attributeNames', 'blocks', 'c', 'c_', 'cconv', 'copyAttributes', 'copyMetadata']

=== Node 3 ===
  scopeName(): 
  scope(): No method
  sourceRange(): 
  debugName(): No method
  Available methods (first 10): ['addBlock', 'addInput', 'addOutput', 'attributeNames', 'blocks', 'c', 'c_', 'cconv', 'copyAttribu

## 5. Try to Access the Full Graph IR

The error message we saw earlier contained rich scope information. Let's try to access that.

In [6]:
# Try to access the full IR graph
print("Trying to access detailed graph representation...\n")

# Method 1: Get the inlined graph (this might have scope info)
try:
    if hasattr(traced_model, 'inlined_graph'):
        inlined_graph = traced_model.inlined_graph
        print(f"Found inlined_graph: {type(inlined_graph)}")
        
        inlined_nodes = list(inlined_graph.nodes())
        print(f"Inlined graph has {len(inlined_nodes)} nodes")
        
        # Check first few nodes for scope info
        for i, node in enumerate(inlined_nodes[:5]):
            scope_name = str(node.scopeName()) if hasattr(node, 'scopeName') else "No scope"
            print(f"  Inlined Node {i}: {node.kind()} - Scope: {scope_name}")
    else:
        print("No inlined_graph attribute found")
except Exception as e:
    print(f"Error accessing inlined_graph: {e}")

# Method 2: Try to trigger an ONNX export to see the detailed IR in the error
print("\n--- Attempting ONNX export to reveal detailed IR ---")
try:
    onnx_file = "../temp/debug_onnx_export.onnx"
    
    if tracing_method == "positional_trace":
        torch.onnx.export(
            traced_model,
            (inputs['input_ids'], inputs['attention_mask']),
            onnx_file,
            input_names=['input_ids', 'attention_mask'],
            output_names=['output'],
            do_constant_folding=False,  # Keep constants separate
            verbose=True
        )
        print("✅ ONNX export successful (unexpectedly!)")
    else:
        print(f"Skipping ONNX export: tracing method '{tracing_method}' is not handled here")
        
except Exception as e:
    error_str = str(e)
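    print("ONNX export failed as expected. Analyzing error for scope info...")
    
    # Look for scope information in the exception message. The exporter's
    # "Torch IR graph at exception" dump is printed separately and is not part
    # of str(e), so this check can miss scope info that is visible in the output.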
    if "scope:" in error_str:
        print("\n🎯 FOUND SCOPE INFORMATION IN ERROR!")
        
        # Extract lines containing scope information
        scope_lines = [line.strip() for line in error_str.split('\n') if 'scope:' in line]
        
        print(f"Found {len(scope_lines)} lines with scope information:")
        for i, line in enumerate(scope_lines[:10]):  # Show first 10
            print(f"  {i+1}. {line}")
        
        if len(scope_lines) > 10:
            print(f"  ... and {len(scope_lines) - 10} more")
            
        # Extract unique scopes
        unique_scopes = set()
        for line in scope_lines:
            if 'scope:' in line:
                scope_part = line.split('scope:')[1].split('#')[0].strip()
                unique_scopes.add(scope_part)
        
        print(f"\n🎯 Unique scopes found: {len(unique_scopes)}")
        for scope in sorted(list(unique_scopes))[:15]:  # Show first 15
            print(f"  {scope}")
    else:
        print("No scope information found in error message.")

Trying to access detailed graph representation...

Found inlined_graph: <class 'torch.Graph'>
Inlined graph has 276 nodes
  Inlined Node 0: prim::GetAttr - Scope: 
  Inlined Node 1: prim::GetAttr - Scope: 
  Inlined Node 2: prim::GetAttr - Scope: 
  Inlined Node 3: prim::GetAttr - Scope: 
  Inlined Node 4: prim::GetAttr - Scope: 

--- Attempting ONNX export to reveal detailed IR ---
ONNX export failed as expected. Analyzing error for scope info...
No scope information found in error message.
Torch IR graph at exception: graph(%input_ids : Long(1, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %attention_mask.1 : Long(1, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %embeddings.token_type_ids : Long(1, 512, strides=[512, 1], requires_grad=0, device=cpu),
      %embeddings.position_ids : Long(1, 512, strides=[512, 1], requires_grad=0, device=cpu),
      %embeddings.word_embeddings.weight : Float(30522, 128, strides=[128, 1], requires_grad=0, device=cpu),
      %embedding



: Float(128, strides=[1], requires_grad=0, device=cpu),
      %encoder.layer.1.output.dense.weight : Float(128, 512, strides=[512, 1], requires_grad=0, device=cpu),
      %encoder.layer.1.output.LayerNorm.bias : Float(128, strides=[1], requires_grad=0, device=cpu),
      %encoder.layer.1.output.LayerNorm.weight : Float(128, strides=[1], requires_grad=0, device=cpu),
      %pooler.dense.bias : Float(128, strides=[1], requires_grad=0, device=cpu),
      %pooler.dense.weight : Float(128, 128, strides=[128, 1], requires_grad=0, device=cpu)):
  %44 : str = prim::Constant[value="pooler_output"](), scope: transformers.models.bert.modeling_bert.BertModel:: # /mnt/d/BYOM/modelexport/.venv/lib/python3.12/site-packages/torch/jit/_trace.py:1279:0
  %45 : str = prim::Constant[value="last_hidden_state"](), scope: transformers.models.bert.modeling_bert.BertModel:: # /mnt/d/BYOM/modelexport/.venv/lib/python3.12/site-packages/torch/jit/_trace.py:1279:0
  %46 : str = prim::Constant[value="none"](), scop
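
The IR dump above does contain `scope:` annotations, even though the check on the exception message reported none. A minimal sketch of extracting them from captured dump text is shown here. The two sample lines are copied from the dump above; `parse_scopes` is a hypothetical helper, and the regular expression is only known to fit lines of this shape.

```python
import re
from collections import Counter

# Lines copied from the "Torch IR graph at exception" dump above
IR_DUMP = '''
  %44 : str = prim::Constant[value="pooler_output"](), scope: transformers.models.bert.modeling_bert.BertModel:: # /mnt/d/BYOM/modelexport/.venv/lib/python3.12/site-packages/torch/jit/_trace.py:1279:0
  %45 : str = prim::Constant[value="last_hidden_state"](), scope: transformers.models.bert.modeling_bert.BertModel:: # /mnt/d/BYOM/modelexport/.venv/lib/python3.12/site-packages/torch/jit/_trace.py:1279:0
'''

# Captures the value name, node kind, and scope; the trailing
# "# file:line:col" source location is ignored.
NODE_RE = re.compile(
    r"^\s*%(?P<value>[\w.]+) : .*? = (?P<kind>\w+::\w+).*?, scope: (?P<scope>\S*)"
)


def parse_scopes(ir_text: str):
    """Return (value, kind, scope) for every line that carries a scope."""
    rows = []
    for line in ir_text.splitlines():
        m = NODE_RE.match(line)
        if m:
            rows.append((m["value"], m["kind"], m["scope"]))
    return rows


rows = parse_scopes(IR_DUMP)
scope_counts = Counter(scope for _, _, scope in rows)
for scope, count in scope_counts.items():
    print(f"{count} node(s) in scope {scope}")
```

Counting nodes per scope gives a quick check of how much of the graph is attributed to each module.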

## 6. Alternative Approach: torch.fx Tracing

Let's try torch.fx, which might preserve more context.

In [8]:
print("Trying torch.fx symbolic tracing...\n")

try:
    import torch.fx as fx
    
    # Try to symbolically trace the model
    print("1. Attempting symbolic tracing...")
    
    # This might fail for transformers models, but let's try
    try:
        fx_traced = fx.symbolic_trace(model)
        print("  ✅ Symbolic tracing successful!")
        
        # Explore FX graph
        fx_graph = fx_traced.graph
        fx_nodes = list(fx_graph.nodes)
        
        print(f"  FX graph has {len(fx_nodes)} nodes")
        
        # Show first few nodes
        for i, node in enumerate(fx_nodes[:10]):
            print(f"    Node {i}: {node.op} - {node.target} - {node.name}")
            if hasattr(node, 'meta') and node.meta:
                print(f"      Meta: {node.meta}")
                
        # Try torch.export (PyTorch 2.1+).
        # Note: this step is nested inside the symbolic-tracing try block,
        # so it only runs when symbolic tracing succeeds.
        print("\n2. Attempting torch.export...")
        try:
            exported_program = torch.export.export(model, (inputs['input_ids'], inputs['attention_mask']))
            print("  ✅ torch.export successful!")
            
            # ExportedProgram.graph is the underlying torch.fx graph
            # (ExportedProgram.module is a method, not an attribute)
            export_graph = exported_program.graph
            export_nodes = list(export_graph.nodes)
            
            print(f"  Export graph has {len(export_nodes)} nodes")
            
            # Show first few nodes with metadata
            for i, node in enumerate(export_nodes[:5]):
                print(f"    Export Node {i}: {node.op} - {node.target}")
                if hasattr(node, 'meta') and node.meta:
                    print(f"      Meta keys: {list(node.meta.keys())}")
                    if 'stack_trace' in node.meta:
                        # stack_trace is a single string, so count its lines
                        print(f"      Stack trace available: {len(node.meta['stack_trace'].splitlines())} lines")
                        
        except Exception as e:
            print(f"  ❌ torch.export failed: {e}")
                    
    except Exception as e:
        print(f"  ❌ Symbolic tracing failed: {e}")
        
except ImportError:
    print("torch.fx not available in this PyTorch version")
except Exception as e:
    print(f"Error with torch.fx: {e}")

Trying torch.fx symbolic tracing...

1. Attempting symbolic tracing...
  ❌ Symbolic tracing failed: You cannot specify both input_ids and inputs_embeds at the same time
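
A plausible reading of this failure is that symbolic tracing passes a proxy object for every `forward` argument, so a check of the form "both `input_ids` and `inputs_embeds` are not `None`" is always true. The sketch below reproduces this on a toy module and shows one way around it, pinning the unused argument with `concrete_args`. `ToyEncoder` is a hypothetical stand-in, not the BERT model, and whether the same workaround is enough for the real model is not verified here.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical toy model mimicking the input_ids / inputs_embeds check."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)

    def forward(self, input_ids=None, inputs_embeds=None):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        return inputs_embeds * 2


toy = ToyEncoder()

# Plain symbolic tracing passes a Proxy for every argument, so neither is None
try:
    fx.symbolic_trace(toy)
    plain_trace_failed = False
except ValueError:
    plain_trace_failed = True

# Pinning inputs_embeds to None lets tracing follow the input_ids branch
traced = fx.symbolic_trace(toy, concrete_args={"inputs_embeds": None})

# call_module nodes keep the qualified submodule name as their target
module_calls = [n.target for n in traced.graph.nodes if n.op == "call_module"]
print(plain_trace_failed, module_calls)
```

The `call_module` targets are the part of interest for this investigation, since they are the submodule names that flat operator graphs tend to lose.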


## 7. Save Results and Summary

In [None]:
# Save detailed results
results = {
    "model_name": model_name,
    "tracing_method": tracing_method,
    "total_nodes": len(nodes),
    "graph_string_length": len(str(graph)),
    "node_analysis": node_info,
    "exploration_summary": {
        "torch_jit_scope_available": any(info.get("scopeName()", "") not in ["", "No method"] for info in node_info),
        "methods_attempted": list(scope_methods.keys()) if 'scope_methods' in locals() else [],
        "graph_string_contains_scope": "scope:" in str(graph),
    }
}

# Save to JSON
output_file = "../temp/torchscript_exploration_notebook.json"
with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"✅ Results saved to: {output_file}")

# Summary
print("\n" + "="*60)
print("EXPLORATION SUMMARY")
print("="*60)
print(f"Model: {model_name}")
print(f"Tracing method: {tracing_method}")
print(f"Total TorchScript nodes: {len(nodes)}")
print(f"Graph string length: {len(str(graph))} chars")
print(f"Contains 'scope:' in graph string: {'scope:' in str(graph)}")

scope_available = any(info.get("scopeName()", "") not in ["", "No method"] for info in node_info)
print(f"Scope information accessible via scopeName(): {scope_available}")

print("\n🎯 KEY FINDINGS:")
if "scope:" in str(graph):
    print("✅ Scope information IS present in the graph string representation")
    print("✅ We can extract module hierarchy from TorchScript before ONNX export")
else:
    print("❌ No scope information found in graph representation")
    
if scope_available:
    print("✅ Individual node scope access working")
else:
    print("❌ Individual node scopeName() not working - need alternative approach")

print("\n💡 NEXT STEPS:")
print("1. Parse graph string representation to extract scope information")
print("2. Implement JIT graph interception before ONNX conversion")
print("3. Create hierarchy mapping from TorchScript to ONNX nodes")
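
For next step 3, a minimal sketch of turning one scope string into a dotted module path. It assumes the scope format is a `/`-separated chain of `ClassName::attribute` segments with an empty attribute for the root, which matches the single scope visible in the IR dump (`transformers.models.bert.modeling_bert.BertModel::`). The nested example string is hypothetical, written to that assumed format, and `scope_to_path` is a hypothetical helper.

```python
def scope_to_path(scope: str):
    """Split 'Class::attr/Class::attr' into (dotted module path, leaf class)."""
    attrs, leaf_class = [], ""
    for segment in filter(None, scope.split("/")):
        class_name, _, attr = segment.partition("::")
        leaf_class = class_name
        if attr:  # the root segment has an empty attribute name
            attrs.append(attr)
    return ".".join(attrs), leaf_class


# Root scope, copied from the IR dump in section 5
root = "transformers.models.bert.modeling_bert.BertModel::"

# Hypothetical nested scope written in the assumed format
nested = (
    "transformers.models.bert.modeling_bert.BertModel::/"
    "transformers.models.bert.modeling_bert.BertEmbeddings::embeddings/"
    "torch.nn.modules.sparse.Embedding::word_embeddings"
)

print(scope_to_path(root))
print(scope_to_path(nested))
```

The dotted path has the same shape as the parameter names in the IR dump (for example `pooler.dense.weight` without the final `.weight`), which suggests a possible join key between scopes and exported nodes.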