# Understanding OnnxConfig: From HuggingFace Config to ONNX Export

This notebook explains:
1. What is OnnxConfig and why it's needed
2. The structure and specification of OnnxConfig
3. How to generate OnnxConfig from a HuggingFace model config
4. Practical examples of creating custom OnnxConfigs

## 1. What is OnnxConfig?

**OnnxConfig** is a configuration class that tells the ONNX exporter:
- What inputs the model expects (names, shapes, types)
- What outputs the model produces
- How to generate dummy inputs for tracing
- Which axes should be dynamic (variable batch size, sequence length)
- Which ONNX opset version to use

Think of it as a "recipe" for converting a PyTorch model to ONNX.
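
To make the "recipe" idea concrete, here is a minimal sketch that gathers the pieces an OnnxConfig supplies and turns them into the keyword arguments `torch.onnx.export` accepts (`input_names`, `output_names`, `dynamic_axes`, `opset_version`). The `ExportRecipe` class and its `to_export_kwargs` method are hypothetical names used only for illustration. Dummy-input generation is left out because it depends on the model.

```python
from dataclasses import dataclass
from typing import Dict

# Axis index -> symbolic dimension name, e.g. {0: "batch_size", 1: "sequence_length"}
AxisSpec = Dict[int, str]


@dataclass
class ExportRecipe:
    """Hypothetical container for what an OnnxConfig tells the exporter."""
    inputs: Dict[str, AxisSpec]
    outputs: Dict[str, AxisSpec]
    opset: int = 14

    def to_export_kwargs(self) -> dict:
        """Translate the recipe into torch.onnx.export keyword arguments."""
        return {
            "input_names": list(self.inputs),
            "output_names": list(self.outputs),
            # torch.onnx.export takes a single dict covering inputs and outputs
            "dynamic_axes": {**self.inputs, **self.outputs},
            "opset_version": self.opset,
        }


text_axes = {0: "batch_size", 1: "sequence_length"}
recipe = ExportRecipe(
    inputs={"input_ids": text_axes, "attention_mask": text_axes},
    outputs={"last_hidden_state": text_axes},
)
print(recipe.to_export_kwargs()["input_names"])  # ['input_ids', 'attention_mask']
```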

## 2. OnnxConfig Specification

### Class Hierarchy
```
OnnxConfig (base class)
├── OnnxConfigWithPast (for decoder models with KV cache)
│   ├── OnnxSeq2SeqConfigWithPast (for encoder-decoder models)
│   └── Decoder model configs (GPT2OnnxConfig, etc.)
└── Encoder model configs (BertOnnxConfig, etc.)
```

### Key Components

1. **inputs** property: Defines input names and dynamic axes
2. **outputs** property: Defines output names and dynamic axes
3. **generate_dummy_inputs()**: Creates example inputs for tracing
4. **DEFAULT_ONNX_OPSET**: ONNX operator set version (default: 11)
5. **ATOL_FOR_VALIDATION**: Absolute tolerance used when comparing the exported model's outputs with the original PyTorch outputs
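
The tolerance only matters after export. In Optimum, the exported model is run with ONNX Runtime on the same inputs as the PyTorch model, and the outputs are compared within this absolute tolerance. Below is a minimal NumPy sketch of such a check. It assumes both sets of outputs are already available as arrays keyed by output name. The helper name `check_outputs` is hypothetical, and it uses a pure max-absolute-difference test (unlike `np.allclose`, there is no relative term).

```python
from typing import Dict

import numpy as np


def check_outputs(reference: Dict[str, np.ndarray],
                  candidate: Dict[str, np.ndarray],
                  atol: float = 1e-5) -> Dict[str, float]:
    """Compare outputs by name and return the max absolute difference per output.

    Raises ValueError if an output is missing, has a different shape,
    or differs by more than `atol` anywhere.
    """
    max_diffs = {}
    for name, ref in reference.items():
        if name not in candidate:
            raise ValueError(f"missing output: {name}")
        out = candidate[name]
        if ref.shape != out.shape:
            raise ValueError(f"{name}: shape {out.shape} != {ref.shape}")
        diff = float(np.max(np.abs(ref - out))) if ref.size else 0.0
        if diff > atol:
            raise ValueError(f"{name}: max abs diff {diff:.2e} > atol {atol:.0e}")
        max_diffs[name] = diff
    return max_diffs


ref = {"logits": np.array([[0.10, -0.25]], dtype=np.float32)}
onnx_out = {"logits": np.array([[0.1000004, -0.2500003]], dtype=np.float32)}
print(check_outputs(ref, onnx_out, atol=1e-4))
```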

In [None]:
# Let's look at the OnnxConfig interface
from typing import Dict
import torch

class OnnxConfigInterface:
    """Simplified interface showing what OnnxConfig provides"""
    
    def __init__(self, config, task="default"):
        self.config = config  # HuggingFace PretrainedConfig
        self.task = task      # Task like "text-classification"
    
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Define input names and their dynamic axes
        
        Returns:
            Dict mapping input names to axis definitions
            Example: {
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"}
            }
        """
        raise NotImplementedError  # implemented by concrete configs
    
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Define output names and their dynamic axes"""
        raise NotImplementedError  # implemented by concrete configs
    
    def generate_dummy_inputs(self,
                              batch_size: int = 1,
                              seq_length: int = 128,
                              **kwargs) -> Dict[str, torch.Tensor]:
        """Generate dummy inputs for model tracing
        
        Returns:
            Dict of input tensors
            Example: {
                "input_ids": torch.tensor([[101, 2023, ...]]),
                "attention_mask": torch.tensor([[1, 1, ...]])
            }
        """
        raise NotImplementedError  # implemented by concrete configs

## 3. How to Generate OnnxConfig from HF Config

The process involves:
1. Load the HuggingFace model config
2. Detect the model type and task
3. Determine input/output specifications
4. Create appropriate OnnxConfig

Let's implement this step by step:

In [None]:
# Step 1: Load a HuggingFace config
from transformers import AutoConfig

# Example: Load BERT config
config = AutoConfig.from_pretrained("bert-base-uncased")

print(f"Model type: {config.model_type}")
print(f"Hidden size: {config.hidden_size}")
print(f"Num layers: {config.num_hidden_layers}")
print(f"Architectures: {config.architectures}")

In [None]:
# Step 2: Detect task from architecture
def detect_task_from_config(config):
    """Detect the task from model config"""
    
    # Check architectures field
    if hasattr(config, 'architectures') and config.architectures:
        arch = config.architectures[0]
        
        # Map architecture patterns to tasks
        if "ForSequenceClassification" in arch:
            return "text-classification"
        elif "ForTokenClassification" in arch:
            return "token-classification"
        elif "ForQuestionAnswering" in arch:
            return "question-answering"
        elif "ForCausalLM" in arch or "LMHead" in arch:
            return "text-generation"
        elif "ForMaskedLM" in arch:
            return "fill-mask"
        elif "ForImageClassification" in arch:
            return "image-classification"
    
    # Default to feature extraction
    return "feature-extraction"

task = detect_task_from_config(config)
print(f"Detected task: {task}")
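
The if/elif chain in `detect_task_from_config` can also be written as an ordered rule table, which keeps the match order explicit and makes new rules a one-line change. This is a sketch of that design choice, not Optimum's own logic. The extra `ForConditionalGeneration` rule (mapped to `text2text-generation`) is an assumption added for encoder-decoder checkpoints such as `t5-small`, which the chain above reports as `feature-extraction`. Note that speech and vision encoder-decoders use the same class-name suffix, so the rule is only a rough heuristic.

```python
from types import SimpleNamespace
from typing import List, Tuple

# Ordered (substring, task) rules; the first match wins.
TASK_RULES: List[Tuple[str, str]] = [
    ("ForSequenceClassification", "text-classification"),
    ("ForTokenClassification", "token-classification"),
    ("ForQuestionAnswering", "question-answering"),
    ("ForCausalLM", "text-generation"),
    ("LMHead", "text-generation"),
    ("ForMaskedLM", "fill-mask"),
    ("ForImageClassification", "image-classification"),
    # Assumption: treat seq2seq checkpoints as text-to-text generation
    ("ForConditionalGeneration", "text2text-generation"),
]


def detect_task(config, default: str = "feature-extraction") -> str:
    """Return the task for the first architecture name that matches a rule."""
    architectures = getattr(config, "architectures", None) or []
    if not architectures:
        return default
    arch = architectures[0]
    for pattern, task in TASK_RULES:
        if pattern in arch:
            return task
    return default


print(detect_task(SimpleNamespace(architectures=["GPT2LMHeadModel"])))  # text-generation
print(detect_task(SimpleNamespace(architectures=["T5ForConditionalGeneration"])))  # text2text-generation
```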

In [None]:
# Step 3: Determine input specifications based on model type
def get_input_specs_for_model_type(model_type: str, config):
    """Get input specifications for a model type"""
    
    # Common text model inputs
    text_inputs = {
        "bert": ["input_ids", "attention_mask", "token_type_ids"],
        "roberta": ["input_ids", "attention_mask"],
        "gpt2": ["input_ids", "attention_mask"],
        "t5": ["input_ids", "attention_mask", "decoder_input_ids"],
    }
    
    # Vision model inputs
    vision_inputs = {
        "vit": ["pixel_values"],
        "resnet": ["pixel_values"],
    }
    
    # Get inputs for model type
    if model_type in text_inputs:
        return text_inputs[model_type]
    elif model_type in vision_inputs:
        return vision_inputs[model_type]
    else:
        # Default text inputs
        return ["input_ids", "attention_mask"]

input_names = get_input_specs_for_model_type(config.model_type, config)
print(f"Input names: {input_names}")
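
A hand-written table like `text_inputs` has to be updated for every new architecture. One alternative is to intersect a list of known input names with the parameters of the model's `forward` method using `inspect.signature`. This is offered as an assumption, not as the way Optimum works. `KNOWN_INPUTS` and `infer_input_names` are hypothetical names, and `ToyModel` stands in for a real `transformers` model so the example runs without downloading weights.

```python
import inspect
from typing import List

# Candidate tensor inputs, in the order we would like to export them.
KNOWN_INPUTS = [
    "input_ids", "attention_mask", "token_type_ids",
    "decoder_input_ids", "pixel_values",
]


def infer_input_names(model) -> List[str]:
    """Keep the known input names that appear in model.forward's signature."""
    params = inspect.signature(model.forward).parameters
    return [name for name in KNOWN_INPUTS if name in params]


class ToyModel:
    """Stand-in with a BERT-like forward signature (no real weights)."""

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, output_attentions=None):
        return None


print(infer_input_names(ToyModel()))  # ['input_ids', 'attention_mask', 'token_type_ids']
```

A signature only tells you which inputs are *accepted*, not which ones are useful. GPT-2's `forward`, for example, also accepts `token_type_ids`, which the table above deliberately leaves out. So the inferred list still needs a review step.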

## 4. Creating a Universal OnnxConfig

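Now let's combine these steps into a single `UniversalOnnxConfig` that covers several common text and vision architectures and falls back to generic text inputs for everything else: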

In [None]:
import torch
from typing import Dict, Any, List

class UniversalOnnxConfig:
    """Universal OnnxConfig that works with any HF model"""
    
    DEFAULT_ONNX_OPSET = 14  # Modern ONNX opset
    
    def __init__(self, config, task="default"):
        self.config = config
        self.task = task if task != "default" else self._detect_task()
        self.model_type = config.model_type
        
    def _detect_task(self) -> str:
        """Auto-detect task from config"""
        if hasattr(self.config, 'architectures') and self.config.architectures:
            arch = self.config.architectures[0]
            if "ForSequenceClassification" in arch:
                return "text-classification"
            elif "ForCausalLM" in arch or "LMHead" in arch:  # e.g. GPT2LMHeadModel
                return "text-generation"
        return "feature-extraction"
    
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Define inputs with dynamic axes"""
        
        # Base dynamic axes for text
        dynamic_axes = {0: "batch_size", 1: "sequence_length"}
        
        # Determine input names based on model type
        if self.model_type == "bert":
            return {
                "input_ids": dynamic_axes,
                "attention_mask": dynamic_axes,
                "token_type_ids": dynamic_axes,
            }
        elif self.model_type in ["gpt2", "llama", "mistral"]:
            return {
                "input_ids": dynamic_axes,
                "attention_mask": dynamic_axes,
            }
        elif self.model_type in ["vit", "resnet"]:
            # Vision models
            return {
                "pixel_values": {0: "batch_size"},
            }
        else:
            # Default
            return {
                "input_ids": dynamic_axes,
                "attention_mask": dynamic_axes,
            }
    
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Define outputs with dynamic axes"""
        
        if self.task == "text-classification":
            return {"logits": {0: "batch_size"}}
        elif self.task == "text-generation":
            return {
                "logits": {0: "batch_size", 1: "sequence_length"},
            }
        elif self.task == "question-answering":
            return {
                "start_logits": {0: "batch_size"},
                "end_logits": {0: "batch_size"},
            }
        else:
            # Feature extraction
            return {
                "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
            }
    
    def generate_dummy_inputs(self,
                              batch_size: int = 1,
                              seq_length: int = 128,
                              **kwargs) -> Dict[str, torch.Tensor]:
        """Generate dummy inputs for tracing (names must match `inputs`)"""
        
        dummy_inputs = {}
        
        if self.model_type in ["vit", "resnet"]:
            # Vision models: (batch, channels, height, width)
            image_size = getattr(self.config, "image_size", 224)
            num_channels = getattr(self.config, "num_channels", 3)
            
            dummy_inputs["pixel_values"] = torch.randn(
                (batch_size, num_channels, image_size, image_size)
            )
        else:
            # Text models, including the default branch of `inputs`
            # (otherwise e.g. "mistral" or "t5" would get no dummy tensors)
            dummy_inputs["input_ids"] = torch.randint(
                0, self.config.vocab_size, (batch_size, seq_length)
            )
            dummy_inputs["attention_mask"] = torch.ones(
                (batch_size, seq_length), dtype=torch.long
            )
            
            if self.model_type == "bert":
                dummy_inputs["token_type_ids"] = torch.zeros(
                    (batch_size, seq_length), dtype=torch.long
                )
        
        return dummy_inputs
    
    def get_input_names(self) -> List[str]:
        """Get list of input names"""
        return list(self.inputs.keys())
    
    def get_output_names(self) -> List[str]:
        """Get list of output names"""
        return list(self.outputs.keys())
    
    def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]:
        """Get combined dynamic axes for inputs and outputs"""
        return {**self.inputs, **self.outputs}
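
Because `inputs` and `generate_dummy_inputs` are written separately, they can drift apart. An input might be declared with no dummy tensor, or a dynamic axis might lie beyond a tensor's rank, and `torch.onnx.export` then fails with a hard-to-read error. The small pre-export check sketched below catches that early. It only relies on each dummy value having a `.shape`, so it works for PyTorch tensors and NumPy arrays alike. The name `check_spec_matches_dummies` is hypothetical.

```python
from typing import Dict, List

import numpy as np


def check_spec_matches_dummies(inputs: Dict[str, Dict[int, str]],
                               dummies: Dict[str, object]) -> List[str]:
    """Return a list of problems; an empty list means spec and dummies agree."""
    problems = []
    for name, axes in inputs.items():
        if name not in dummies:
            problems.append(f"declared input '{name}' has no dummy tensor")
            continue
        rank = len(dummies[name].shape)
        for axis in axes:
            if axis >= rank:
                problems.append(f"'{name}': dynamic axis {axis} >= rank {rank}")
    for name in dummies:
        if name not in inputs:
            problems.append(f"dummy tensor '{name}' is not a declared input")
    return problems


axes = {0: "batch_size", 1: "sequence_length"}
spec = {"input_ids": axes, "attention_mask": axes}
dummies = {"input_ids": np.zeros((1, 128), dtype=np.int64)}
print(check_spec_matches_dummies(spec, dummies))
# ["declared input 'attention_mask' has no dummy tensor"]
```

Calling it as `check_spec_matches_dummies(onnx_config.inputs, onnx_config.generate_dummy_inputs())` right before export should return an empty list for every model type you intend to support.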

## 5. Using the Universal OnnxConfig

Let's see how to use this universal config to export a model:

In [None]:
# Example 1: BERT model
from transformers import AutoConfig, AutoModel

# Load config and model
config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

# Create universal OnnxConfig
onnx_config = UniversalOnnxConfig(config)

print("Task:", onnx_config.task)
print("Input names:", onnx_config.get_input_names())
print("Output names:", onnx_config.get_output_names())
print("\nInputs with dynamic axes:")
for name, axes in onnx_config.inputs.items():
    print(f"  {name}: {axes}")

# Generate dummy inputs
dummy_inputs = onnx_config.generate_dummy_inputs()
print("\nDummy input shapes:")
for name, tensor in dummy_inputs.items():
    print(f"  {name}: {tensor.shape}")

In [None]:
# Example 2: Export to ONNX using the config
import os
import tempfile

import onnx
import torch

def export_with_universal_config(model, config, output_path="model.onnx"):
    """Export a model to ONNX using UniversalOnnxConfig"""
    
    # Create OnnxConfig
    onnx_config = UniversalOnnxConfig(config)
    
    # Generate dummy inputs
    dummy_inputs = onnx_config.generate_dummy_inputs()
    
    # Export to ONNX. A dict as the last element of `args` is treated as
    # keyword arguments, so inputs are matched to forward() by name rather
    # than by position (GPT-2's forward, for example, takes past_key_values
    # before attention_mask).
    torch.onnx.export(
        model,
        (dummy_inputs,),
        output_path,
        input_names=onnx_config.get_input_names(),
        output_names=onnx_config.get_output_names(),
        dynamic_axes=onnx_config.get_dynamic_axes(),
        opset_version=onnx_config.DEFAULT_ONNX_OPSET,
        do_constant_folding=True,
    )
    
    print(f"✅ Model exported to {output_path}")
    print(f"   File size: {os.path.getsize(output_path) / 1024 / 1024:.2f} MB")
    
    return onnx_config

# Export into a temporary directory that is removed afterwards
# (this also avoids reopening a still-open NamedTemporaryFile on Windows)
with tempfile.TemporaryDirectory() as tmp_dir:
    onnx_path = os.path.join(tmp_dir, "model.onnx")
    export_with_universal_config(model, config, onnx_path)
    
    # Verify the export (structural check only; it does not compare
    # numerical outputs against PyTorch)
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("✅ ONNX model validation passed!")

## 6. Comparison with Optimum's Approach

HuggingFace Optimum uses model-specific OnnxConfig classes. Let's compare:

In [None]:
# Optimum's approach (if available)
try:
    from optimum.exporters.onnx.model_configs import BertOnnxConfig
    
    # Model-specific config
    optimum_config = BertOnnxConfig(config)
    
    print("Optimum BertOnnxConfig:")
    print("  Inputs:", list(optimum_config.inputs.keys()))
    print("  Outputs:", list(optimum_config.outputs.keys()))
    
except ImportError:
    print("Optimum not installed. To see the comparison, install with:")
    print('pip install "optimum[exporters]"')

print("\nOur UniversalOnnxConfig:")
universal_config = UniversalOnnxConfig(config)
print("  Inputs:", universal_config.get_input_names())
print("  Outputs:", universal_config.get_output_names())

print("\nAdvantages of the universal approach:")
print("✅ Handles several model families with a few type-based rules")
print("✅ Automatically detects the task from config.architectures")
print("✅ No per-model OnnxConfig class to write for the covered families")

## 7. Advanced: Handling Special Cases

Some models need special handling. Here's how to extend the universal config:

In [None]:
class AdvancedUniversalOnnxConfig(UniversalOnnxConfig):
    """Extended version with special case handling"""
    
    def __init__(self, config, task="default"):
        super().__init__(config, task)
        self.use_past = self._should_use_past()
        
    def _should_use_past(self) -> bool:
        """Check if model uses past key values (for generation)"""
        return (
            self.task == "text-generation" and 
            self.model_type in ["gpt2", "llama", "mistral", "falcon"]
        )
    
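    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Extended inputs: encoder-decoder, multimodal and past key values"""
        base_inputs = super().inputs
        
        # Encoder-decoder models (T5, BART) also take decoder inputs
        if self.model_type in ["t5", "bart", "mbart"]:
            base_inputs["decoder_input_ids"] = {
                0: "batch_size", 1: "decoder_sequence_length"
            }
        # Multimodal models (CLIP) take images alongside text
        elif self.model_type == "clip":
            base_inputs["pixel_values"] = {0: "batch_size"}
        
        # Add past key values for generation models
        if self.use_past:
            num_layers = getattr(self.config, "n_layer", 
                               getattr(self.config, "num_hidden_layers", 12))
            
            for i in range(num_layers):
                base_inputs[f"past_key_values.{i}.key"] = {
                    0: "batch_size", 2: "past_sequence_length"
                }
                base_inputs[f"past_key_values.{i}.value"] = {
                    0: "batch_size", 2: "past_sequence_length"
                }
        
        return base_inputs
    
    def generate_dummy_inputs(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Extended dummy inputs with special handling"""
        batch_size = kwargs.get("batch_size", 1)
        
        # Multimodal models (CLIP): vocab and image sizes live in the text and
        # vision sub-configs, so build all inputs here instead of calling super()
        if self.model_type == "clip":
            seq_length = kwargs.get("seq_length", 77)  # CLIP default
            text_cfg = self.config.text_config
            vision_cfg = self.config.vision_config
            return {
                "input_ids": torch.randint(
                    0, text_cfg.vocab_size, (batch_size, seq_length)
                ),
                "attention_mask": torch.ones(
                    (batch_size, seq_length), dtype=torch.long
                ),
                "pixel_values": torch.randn(
                    (batch_size, vision_cfg.num_channels,
                     vision_cfg.image_size, vision_cfg.image_size)
                ),
            }
        
        dummy_inputs = super().generate_dummy_inputs(**kwargs)
        
        # Encoder-decoder models (T5, BART): make sure encoder inputs exist,
        # then add decoder inputs of the same length
        if self.model_type in ["t5", "bart", "mbart"]:
            seq_length = kwargs.get("seq_length", 128)
            vocab_size = self.config.vocab_size
            dummy_inputs.setdefault("input_ids", torch.randint(
                0, vocab_size, (batch_size, seq_length)
            ))
            dummy_inputs.setdefault("attention_mask", torch.ones(
                (batch_size, seq_length), dtype=torch.long
            ))
            dummy_inputs["decoder_input_ids"] = torch.randint(
                0, vocab_size, (batch_size, seq_length)
            )
        
        # Note: matching past_key_values tensors are not generated here, so
        # exporting with use_past=True still needs them added separately
        return dummy_inputs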

# Test with different model types
test_models = [
    "gpt2",           # Decoder-only
    "t5-small",       # Encoder-decoder
    "openai/clip-vit-base-patch32",  # Multimodal
]

for model_name in test_models:
    try:
        config = AutoConfig.from_pretrained(model_name)
        adv_config = AdvancedUniversalOnnxConfig(config)
        
        print(f"\n{model_name}:")
        print(f"  Model type: {adv_config.model_type}")
        print(f"  Task: {adv_config.task}")
        print(f"  Input names: {adv_config.get_input_names()[:5]}...")  # First 5
        print(f"  Uses past KV: {adv_config.use_past}")
    except Exception as e:
        print(f"\n{model_name}: Error - {e}")
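
When `use_past` is enabled, the `inputs` property declares `past_key_values.{i}.key` / `.value` entries, but the dummy-input generation does not create matching tensors. Below is a minimal NumPy sketch of building them. It assumes the per-layer layout `(batch_size, num_kv_heads, past_sequence_length, head_dim)` used by GPT-2-style models and uses `num_key_value_heads` when present (grouped-query attention, as in Llama or Mistral). It reads the standard `num_hidden_layers` / `num_attention_heads` / `hidden_size` names, which GPT-2's config also answers to through its attribute aliases. The helper `make_past_key_values` is hypothetical, and other architectures may use a different cache layout.

```python
from types import SimpleNamespace
from typing import Dict

import numpy as np


def make_past_key_values(config, batch_size: int = 1,
                         past_length: int = 8) -> Dict[str, np.ndarray]:
    """Create zero-filled past key/value tensors named like the inputs spec.

    Assumes the (batch, num_kv_heads, past_length, head_dim) layout.
    """
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    # Grouped-query attention models store fewer key/value heads
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    # Some configs set head_dim explicitly; otherwise derive it
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    shape = (batch_size, num_kv_heads, past_length, head_dim)

    past = {}
    for i in range(num_layers):
        past[f"past_key_values.{i}.key"] = np.zeros(shape, dtype=np.float32)
        past[f"past_key_values.{i}.value"] = np.zeros(shape, dtype=np.float32)
    return past


cfg = SimpleNamespace(num_hidden_layers=2, num_attention_heads=4, hidden_size=32)
past = make_past_key_values(cfg)
print(sorted(past))
print(past["past_key_values.0.key"].shape)  # (1, 4, 8, 8)
```

With PyTorch, the same shapes can be filled with `torch.zeros(shape)`. The attention mask then has to span `past_length + sequence_length` positions rather than just the new tokens.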

## Summary

### Key Takeaways:

1. **OnnxConfig** is essential for ONNX export - it defines:
   - Input/output specifications
   - Dynamic axes for variable dimensions
   - Dummy input generation for tracing

2. **From HF Config to OnnxConfig**:
   - Detect model type from `config.model_type`
   - Detect task from `config.architectures`
   - Map model type to input/output specs
   - Generate appropriate dummy inputs

3. **Universal Approach Benefits**:
   - Covers common architectures without a per-model class (unknown types fall back to generic text inputs)
   - Automatically detects configuration
   - Reduces maintenance burden
   - Easy to extend for special cases

4. **Optimum's Approach**:
   - Uses model-specific OnnxConfig classes
   - More precise but requires implementation for each model
   - Limited to the architectures for which Optimum registers an OnnxConfig

### Next Steps:

To make this production-ready:
1. Add more model type mappings
2. Handle more special cases (vision-language, speech, etc.)
3. Add validation and error handling
4. Integrate with existing export pipelines
5. Test with diverse model architectures