# VampNet TorchScript Export

This notebook explores using TorchScript as an intermediate format for ONNX export.
TorchScript can preserve more of the model's behavior and supports type annotations.

In [None]:
import torch
import torch.nn as nn
from typing import Tuple, Optional, List
import numpy as np
import vampnet
import onnx
import onnxruntime as ort
import sys
import os

# Make the directory two levels above the current working directory importable.
# (The original line was missing its final closing parenthesis, which is a
# SyntaxError and prevented this whole cell from running.)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('./'))))

# Prefer CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 1. Load Pretrained Models

In [None]:
# Build the VampNet interface from the local codec and coarse checkpoints
interface = vampnet.interface.Interface(
    codec_ckpt="../models/vampnet/codec.pth",
    coarse_ckpt="../models/vampnet/coarse.pth",
)

coarse_model = interface.coarse
has_orig_mod = hasattr(coarse_model, '_orig_mod')
print(f"Coarse model type: {type(coarse_model)}")
print(f"Has _orig_mod: {has_orig_mod}")

# When the coarse model is a wrapper that exposes the underlying module as
# `_orig_mod` (presumably a torch.compile wrapper — confirm), work with that
# underlying module; otherwise use the coarse model directly.
actual_model = coarse_model._orig_mod if has_orig_mod else coarse_model

print(f"Actual model type: {type(actual_model)}")

## 2. Understanding VampNet Model Structure

In [None]:
# List the public names of the model once, then report them in two passes:
# first the non-callable attributes (with their values), then the callables.
public_names = [name for name in dir(actual_model) if not name.startswith('_')]

print("Model attributes:")
for name in public_names:
    if callable(getattr(actual_model, name, None)):
        continue
    print(f"  {name}: {getattr(actual_model, name, 'N/A')}")

print("\nModel methods:")
for name in public_names:
    if callable(getattr(actual_model, name, None)):
        print(f"  {name}")

## 3. Create TorchScript-Compatible Wrapper

In [None]:
class VampNetTorchScriptWrapper(nn.Module):
    """
    Type-annotated nn.Module around a VampNet model for TorchScript experiments.

    A forward pass runs the wrapped model once, picks the arg-max token at every
    position and writes those tokens into the masked positions of the input.
    """

    def __init__(self, vampnet_model):
        super().__init__()
        self.model = vampnet_model
        # Fixed sizes recorded on the wrapper; forward() does not read them.
        self.n_codebooks: int = 4
        self.vocab_size: int = 1024
        self.mask_token: int = 1024

    def forward(self, codes: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Run the wrapped model and fill masked positions with greedy predictions.

        Args:
            codes: Input codes [batch, n_codebooks, seq_len]
            mask: Binary mask [batch, n_codebooks, seq_len]; non-zero entries
                mark positions to be replaced.

        Returns:
            Generated codes [batch, n_codebooks, seq_len]: the prediction where
            the mask is set, the original code everywhere else.
        """
        # Temperature is pinned to 1.0, so the division leaves the logits unscaled.
        temperature: float = 1.0
        scaled_logits = self.model(codes) / temperature

        # Greedy choice over the last (vocabulary) axis.
        greedy_tokens = scaled_logits.argmax(dim=-1)

        replace_here = mask.bool()
        return torch.where(replace_here, greedy_tokens, codes)


# Wrap the unwrapped model and switch the wrapper to evaluation mode
wrapper = VampNetTorchScriptWrapper(vampnet_model=actual_model)
wrapper.eval()

## 4. Test the Wrapper

In [None]:
# Seed the RNG so the random test inputs are the same on every run.
torch.manual_seed(0)

# Create the test inputs on the device that actually holds the wrapper's
# weights. Nothing in this notebook moves the model to the global `device`, so
# building the inputs there raises a device-mismatch error on machines where
# CUDA is available but the model sits on the CPU. Fall back to `device` only
# when the wrapper exposes no parameters to inspect.
wrapper_param = next(wrapper.parameters(), None)
input_device = wrapper_param.device if wrapper_param is not None else torch.device(device)

test_codes = torch.randint(0, 1024, (1, 4, 100), device=input_device)
test_mask = torch.randint(0, 2, (1, 4, 100), device=input_device)

# Overwrite the positions selected by the mask with the mask token id (1024)
test_codes[test_mask.bool()] = 1024

print(f"Test codes shape: {test_codes.shape}")
print(f"Number of masked positions: {test_mask.sum().item()}")

# Test forward pass; a failure is reported rather than raised so the rest of
# the notebook can still be explored.
try:
    with torch.no_grad():
        output = wrapper(test_codes, test_mask)
    print(f"✅ Forward pass successful! Output shape: {output.shape}")
    print(f"Output differs at masked positions: {(output[test_mask.bool()] != test_codes[test_mask.bool()]).any().item()}")
except Exception as e:
    print(f"❌ Forward pass failed: {e}")

## 5. Attempt TorchScript Tracing

In [None]:
# Trace the wrapper using the test inputs as example inputs
print("Attempting to trace the wrapper...")

try:
    example_inputs = (test_codes, test_mask)
    traced_model = torch.jit.trace(wrapper, example_inputs)
    print("✅ Tracing successful!")

    # Run the traced module on the same inputs, without autograd bookkeeping
    with torch.no_grad():
        traced_output = traced_model(*example_inputs)

    # `output` is the eager-mode result computed in the wrapper test cell
    outputs_match = torch.allclose(output, traced_output)
    if not outputs_match:
        max_diff = (output - traced_output).abs().max().item()
        print(f"⚠️ Traced model has differences. Max diff: {max_diff}")
    else:
        print("✅ Traced model produces identical output!")

    # Persist the traced module to disk
    traced_model.save("traced_vampnet.pt")
    print("Saved traced model to traced_vampnet.pt")

except Exception as e:
    print(f"❌ Tracing failed: {e}")
    print("\nThis is expected if the model has dynamic control flow.")

## 6. Try TorchScript Scripting

In [None]:
# Compile the wrapper with torch.jit.script instead of tracing it
print("Attempting to script the wrapper...")

try:
    scripted_model = torch.jit.script(wrapper)
    print("✅ Scripting successful!")

    # Run the scripted module on the test inputs, without autograd bookkeeping
    with torch.no_grad():
        scripted_output = scripted_model(test_codes, test_mask)

    scripted_shape = scripted_output.shape
    print(f"Scripted output shape: {scripted_shape}")

    # Persist the scripted module to disk
    scripted_model.save("scripted_vampnet.pt")
    print("Saved scripted model to scripted_vampnet.pt")

except Exception as e:
    print(f"❌ Scripting failed: {e}")
    print("\nThe model likely contains operations not supported by TorchScript.")

## 7. Simplified Approach: Extract Key Operations

In [None]:
class SimplifiedVampNet(nn.Module):
    """
    Skeleton model that borrows sub-modules from a loaded VampNet instance.

    The constructor probes the given model for an embedding, a transformer
    (``transformer`` or, failing that, ``net``) and a ``classifier``, attaches
    whichever ones exist to ``self`` and prints what it found. Components that
    are missing are simply not set as attributes.

    ``forward`` is a placeholder: it does not use the extracted components and
    returns random logits of the expected shape.
    """

    def __init__(self, vampnet_model):
        """
        Args:
            vampnet_model: Object to probe for ``embedding``, ``transformer`` /
                ``net`` and ``classifier`` attributes.
        """
        super().__init__()

        # Fixed sizes used by the placeholder forward pass
        self.n_codebooks = 4
        self.vocab_size = 1024

        # Extract embeddings if available
        if hasattr(vampnet_model, 'embedding'):
            self.embedding = vampnet_model.embedding
            print("✅ Extracted embedding layer")
        else:
            print("❌ No embedding layer found")

        # Extract transformer if available; `net` is accepted as a fallback name
        if hasattr(vampnet_model, 'transformer'):
            self.transformer = vampnet_model.transformer
            print("✅ Extracted transformer")
        elif hasattr(vampnet_model, 'net'):
            self.transformer = vampnet_model.net
            print("✅ Extracted net as transformer")
        else:
            print("❌ No transformer found")

        # Extract output projection if available
        if hasattr(vampnet_model, 'classifier'):
            self.classifier = vampnet_model.classifier
            print("✅ Extracted classifier")
        else:
            print("❌ No classifier found")

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Placeholder forward pass returning random logits.

        Args:
            codes: Tensor whose dimensions 0 and 2 give the batch size and
                sequence length of the result.

        Returns:
            Random tensor of shape
            [codes.shape[0], n_codebooks, codes.shape[2], vocab_size], created
            on the same device as ``codes``.
        """
        # This is a placeholder - would need actual implementation.
        # The result is allocated on the input's device so that callers can mix
        # it with `codes` without a device-mismatch error.
        return torch.randn(
            codes.shape[0],
            self.n_codebooks,
            codes.shape[2],
            self.vocab_size,
            device=codes.device,
        )


# Instantiate the simplified model; the constructor prints what it could extract
simplified = SimplifiedVampNet(actual_model)

# Print every sub-module of the actual model with its class name
print("\nActual model modules:")
for name, module in actual_model.named_modules():
    if not name:
        # The root module is reported under an empty name; skip it
        continue
    print(f"  {name}: {type(module).__name__}")

## 8. Alternative: Export Individual Components

In [None]:
# Instead of exporting the whole model, we can export individual operations
# and reconstruct the model in ONNX

class ComponentWrapper(nn.Module):
    """Minimal nn.Module shell that forwards its single input to one wrapped component."""

    def __init__(self, component):
        super().__init__()
        # The callable (typically a sub-module) that forward() delegates to
        self.component = component

    def forward(self, x):
        """Call the wrapped component on ``x`` and return its result unchanged."""
        wrapped = self.component
        return wrapped(x)


# Try to export the embedding layer on its own
if hasattr(actual_model, 'embedding'):
    print("Attempting to export embedding layer...")
    embedding_wrapper = ComponentWrapper(actual_model.embedding)

    try:
        # Create the example input on the device that holds the embedding's
        # weights (CPU when it has no parameters). A hard-coded CPU tensor
        # fails with a device-mismatch error if the model lives on the GPU.
        embedding_param = next(embedding_wrapper.parameters(), None)
        example_device = (
            embedding_param.device if embedding_param is not None else torch.device('cpu')
        )
        example_tokens = torch.randint(0, 1024, (1, 4, 100), device=example_device)

        # Trace
        traced_embedding = torch.jit.trace(embedding_wrapper, example_tokens)

        # Export to ONNX; batch (axis 0) and sequence (axis 2) of the input are
        # declared dynamic.
        torch.onnx.export(
            traced_embedding,
            example_tokens,
            "vampnet_embedding.onnx",
            input_names=['tokens'],
            output_names=['embeddings'],
            dynamic_axes={'tokens': {0: 'batch', 2: 'sequence'}},
            opset_version=14
        )

        print("✅ Successfully exported embedding layer!")

    except Exception as e:
        print(f"❌ Failed to export embedding: {e}")

## 9. Custom TorchScript Operations

In [None]:
# We can also define custom TorchScript operations
@torch.jit.script
def generate_tokens(logits: torch.Tensor, 
                   mask: torch.Tensor,
                   codes: torch.Tensor,
                   temperature: float = 1.0) -> torch.Tensor:
    """
    TorchScript function for token generation.
    
    Args:
        logits: Model output logits [batch, n_codebooks, seq_len, vocab_size]
        mask: Binary mask [batch, n_codebooks, seq_len]
        codes: Original codes [batch, n_codebooks, seq_len]
        temperature: Temperature for scaling
        
    Returns:
        Generated codes [batch, n_codebooks, seq_len]
    """
    # Scale by temperature
    scaled_logits = logits / temperature
    
    # Get predictions
    predictions = torch.argmax(scaled_logits, dim=-1)
    
    # Apply mask
    output = torch.where(mask.bool(), predictions, codes)
    
    return output


# Test the scripted function. The logits are created on the same device as the
# mask and codes: torch.where needs all of its tensor operands on one device,
# and a CPU-only `test_logits` fails when the test inputs live on the GPU.
test_logits = torch.randn(1, 4, 100, 1024, device=test_mask.device)
result = generate_tokens(test_logits, test_mask, test_codes)
print(f"Scripted function output shape: {result.shape}")

# This function can be used as part of a larger TorchScript model

## 10. Hybrid Approach: TorchScript + ONNX

In [None]:
class HybridVampNet(nn.Module):
    """
    Hybrid model that uses TorchScript for complex operations
    and can be exported to ONNX.

    The network itself is built only from standard PyTorch layers (embedding,
    transformer encoder, linear projection); the final token selection is
    delegated to the scripted ``generate_tokens`` function.

    Args:
        n_codebooks: Stored on the instance for reference; the forward pass
            takes the codebook count from the input shape.
        vocab_size: Number of regular tokens. The embedding table has one
            extra row for the mask token (id ``vocab_size``).
        d_model: Embedding / transformer width.
        nhead: Number of attention heads per encoder layer.
        dim_feedforward: Hidden width of each encoder layer's feed-forward part.
        num_layers: Number of stacked encoder layers.

    All defaults reproduce the previously hard-coded configuration, so
    ``HybridVampNet()`` builds the same architecture as before.
    """

    def __init__(
        self,
        n_codebooks: int = 4,
        vocab_size: int = 1024,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        num_layers: int = 6,
    ):
        super().__init__()
        self.n_codebooks = n_codebooks
        self.vocab_size = vocab_size
        self.d_model = d_model

        # Use standard PyTorch layers that ONNX supports well
        # +1 for the mask token; derived from vocab_size instead of a literal 1025
        self.embedding = nn.Embedding(self.vocab_size + 1, self.d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.output_proj = nn.Linear(self.d_model, self.vocab_size)

    def forward(self, codes: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Predict tokens for every position and write them into the masked ones.

        Args:
            codes: Integer token ids of shape [batch, n_codebooks, seq_len].
            mask: Tensor of the same shape; non-zero entries mark positions to
                replace with the model's prediction.

        Returns:
            Tensor shaped like ``codes`` holding the predicted token where the
            mask is set and the original code elsewhere.
        """
        batch_size, n_codebooks, seq_len = codes.shape

        # Flatten codebooks into one sequence per batch item. reshape (not view)
        # so that non-contiguous inputs, e.g. transposed tensors, are accepted.
        flat_codes = codes.reshape(batch_size, -1)

        # Embed
        embeddings = self.embedding(flat_codes)

        # Transform
        transformed = self.transformer(embeddings)

        # Project to per-token vocabulary logits
        logits = self.output_proj(transformed)

        # Restore the [batch, n_codebooks, seq_len, vocab] layout
        logits = logits.reshape(batch_size, n_codebooks, seq_len, -1)

        # Use our scripted generation function
        output = generate_tokens(logits, mask, codes, 1.0)

        return output


# Create the hybrid model on the same device as the test inputs. A freshly
# constructed module lives on the CPU, so calling it with GPU-resident
# `test_codes` / `test_mask` raises a device-mismatch error.
hybrid = HybridVampNet().to(test_codes.device)
hybrid.eval()

# Test
with torch.no_grad():
    hybrid_output = hybrid(test_codes, test_mask)
print(f"Hybrid model output shape: {hybrid_output.shape}")

# Try to export to ONNX; batch (axis 0) and sequence (axis 2) are dynamic
try:
    torch.onnx.export(
        hybrid,
        (test_codes, test_mask),
        "hybrid_vampnet.onnx",
        input_names=['codes', 'mask'],
        output_names=['generated_codes'],
        dynamic_axes={
            'codes': {0: 'batch', 2: 'sequence'},
            'mask': {0: 'batch', 2: 'sequence'},
            'generated_codes': {0: 'batch', 2: 'sequence'}
        },
        opset_version=14
    )
    print("✅ Successfully exported hybrid model to ONNX!")

    # Verify that the exported file is a structurally valid ONNX model
    onnx_model = onnx.load("hybrid_vampnet.onnx")
    onnx.checker.check_model(onnx_model)
    print("✅ ONNX model verification passed!")

except Exception as e:
    print(f"❌ Failed to export hybrid model: {e}")

## Summary

TorchScript can help with ONNX export in several ways:

1. **Type Annotations**: Help TorchScript understand the model better
2. **@torch.jit.script**: Can script individual functions for use in models
3. **torch.jit.trace**: Can trace models with fixed control flow (it is called with example inputs rather than applied as a decorator)
4. **Hybrid Approach**: Combine TorchScript operations with ONNX-friendly layers

However, the pretrained VampNet model is still too complex for direct export due to:
- Custom layers and operations
- Dynamic control flow
- Complex attention mechanisms

The best approaches are:
1. Create a new model architecture that's ONNX-friendly and transfer weights
2. Use a hybrid pipeline (ONNX codec + PyTorch transformer)
3. Export individual components and reconstruct in the target framework