# VampNet ONNX vs Original Model Comparison Demo

This notebook demonstrates the current state of VampNet ONNX export and compares outputs between:
1. Original PyTorch VampNet (both coarse and C2F models)
2. ONNX exported model (currently only coarse model)

## Current Limitations
- Only the coarse transformer (4 codebooks) has been exported to ONNX
- The coarse-to-fine (C2F) model (10 codebooks) has NOT been exported
- ONNX output will be lower quality due to missing fine codebooks

In [None]:
# Setup and imports
import torch
import numpy as np
import onnxruntime as ort
from pathlib import Path
import IPython.display as ipd
import matplotlib.pyplot as plt
from scipy.io import wavfile
import json

# VampNet imports
from vampnet import mask as pmask
from vampnet.interface import Interface

# ONNX pipeline imports
import sys
sys.path.append('..')
from vampnet_onnx.pipeline import VampNetONNXPipeline
from vampnet_onnx.audio_processor import AudioProcessor

print("Libraries imported successfully")

## 1. Load Original VampNet Models

In [None]:
# Load original VampNet interface
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


def layer_count_or_na(model):
    """Return ``len(model.net.layers)``, or ``'N/A'`` when that is unavailable.

    The status lines below explicitly allow ``interface.coarse`` /
    ``interface.c2f`` to be ``None``, so the layer report must tolerate a
    missing model (or a model without ``net`` / ``layers``) instead of raising
    and being misreported as a model-loading failure.
    """
    layers = getattr(getattr(model, 'net', None), 'layers', None)
    return len(layers) if layers is not None else 'N/A'


# Initialize VampNet interface - this loads both coarse and C2F models
try:
    interface = Interface(device=device)
    print("\nOriginal VampNet models loaded:")
    print(f"- Coarse model: {interface.coarse is not None}")
    print(f"- C2F model: {interface.c2f is not None}")
    print(f"- Codec: {interface.codec is not None}")
    print(f"\nModel details:")
    print(f"- Coarse model layers: {layer_count_or_na(interface.coarse)}")
    print(f"- C2F model layers: {layer_count_or_na(interface.c2f)}")
except Exception as e:
    # Best-effort demo: later cells check `'interface' in globals()` and skip
    # the PyTorch path when loading failed.
    print(f"Error loading VampNet models: {e}")
    print("Make sure you have the model checkpoints in the expected location")

## 2. Check Available ONNX Models

In [None]:
# Check what ONNX models are available
onnx_dirs = [
    "../onnx_models",
    "../onnx_models_optimized",
    "../onnx_models_quantized",
    "../onnx_models_test"
]


def find_c2f_onnx_models(directories):
    """Return the C2F/fine ONNX files found in ``directories``.

    A file is reported when its name matches ``*c2f*.onnx`` or ``*fine*.onnx``.
    Each file is listed once even if it matches both patterns, and results are
    sorted per directory so the output is stable across runs. Directories that
    do not exist are skipped.
    """
    found = []
    for directory in directories:
        model_dir = Path(directory)
        if not model_dir.exists():
            continue
        matches = set(model_dir.glob("*c2f*.onnx")) | set(model_dir.glob("*fine*.onnx"))
        found.extend(sorted(matches))
    return found


print("Available ONNX models:")
for dir_path in onnx_dirs:
    if Path(dir_path).exists():
        print(f"\n{dir_path}:")
        for model_file in sorted(Path(dir_path).glob("*.onnx")):
            print(f"  - {model_file.name}")

# Check for C2F model specifically
c2f_models = find_c2f_onnx_models(onnx_dirs)

print(f"\nC2F/Fine models found: {len(c2f_models)}")
if c2f_models:
    for model in c2f_models:
        print(f"  - {model}")

## 3. Load ONNX Models

In [None]:
# Load ONNX models
codec_path = "../onnx_models/codec_encoder.onnx"
decoder_path = "../onnx_models/codec_decoder.onnx"
transformer_path = "../onnx_models/vampnet_transformer.onnx"

# Check if all required models exist; later cells gate on `models_exist`.
models_exist = {
    "Encoder": Path(codec_path).exists(),
    "Decoder": Path(decoder_path).exists(),
    "Transformer (Coarse)": Path(transformer_path).exists()
}

print("ONNX model status:")
for model_name, exists in models_exist.items():
    print(f"  {model_name}: {'✓ Found' if exists else '✗ Missing'}")

if all(models_exist.values()):
    # Initialize ONNX sessions
    encoder_session = ort.InferenceSession(codec_path)
    decoder_session = ort.InferenceSession(decoder_path)
    transformer_session = ort.InferenceSession(transformer_path)

    # Check transformer details.
    # Loop variables are deliberately not called `input`/`output`: notebook
    # cells run at global scope, so `input` would shadow the builtin for every
    # later cell.
    print("\nTransformer ONNX model info:")
    for tensor_in in transformer_session.get_inputs():
        print(f"  Input: {tensor_in.name}, shape: {tensor_in.shape}, dtype: {tensor_in.type}")
    for tensor_out in transformer_session.get_outputs():
        print(f"  Output: {tensor_out.name}, shape: {tensor_out.shape}, dtype: {tensor_out.type}")
else:
    print("\nMissing required ONNX models. Please run export scripts first.")

## 4. Load Test Audio

In [None]:
# Load or generate test audio
test_audio_path = "../assets/test_audio.wav"  # You can change this to your audio file

if Path(test_audio_path).exists():
    # Load existing audio
    sr, audio_data = wavfile.read(test_audio_path)

    # Normalize to [-1, 1] according to the sample dtype the file was read as.
    # A fixed /32768 is only correct for 16-bit PCM.
    if audio_data.dtype == np.uint8:
        # 8-bit PCM is unsigned and centred on 128.
        audio_data = (audio_data.astype(np.float32) - 128.0) / 128.0
    elif np.issubdtype(audio_data.dtype, np.integer):
        full_scale = float(np.iinfo(audio_data.dtype).max) + 1.0  # 32768 for int16
        audio_data = audio_data.astype(np.float32) / full_scale
    else:
        # Floating-point WAVs are left unscaled.
        audio_data = audio_data.astype(np.float32)

    if len(audio_data.shape) > 1:
        audio_data = audio_data.mean(axis=1)  # Convert to mono
    print(f"Loaded audio: {len(audio_data)/sr:.2f} seconds at {sr} Hz")
else:
    # Generate test audio (sine wave)
    print("Test audio not found. Generating sine wave...")
    sr = 44100
    duration = 3.0
    # endpoint=False keeps the sample spacing at exactly 1/sr.
    t = np.linspace(0, duration, int(sr * duration), endpoint=False)
    frequency = 440  # A4 note
    audio_data = 0.5 * np.sin(2 * np.pi * frequency * t).astype(np.float32)

# Display audio
print(f"Audio shape: {audio_data.shape}")
print(f"Audio range: [{audio_data.min():.3f}, {audio_data.max():.3f}]")
ipd.display(ipd.Audio(audio_data, rate=sr))

## 5. Process with Original VampNet (Full Pipeline)

In [None]:
if 'interface' not in globals():
    # The interface-loading cell failed (or was skipped); nothing to run.
    print("Original VampNet not loaded")
else:
    # Round-trip the test audio through the original VampNet codec.
    print("Encoding with original VampNet...")

    # Add two leading singleton axes to the 1-D waveform and move it to `device`.
    audio_tensor = torch.from_numpy(audio_data)[None, None].to(device)

    with torch.no_grad():
        # Encode to latent codes
        z = interface.encode(audio_tensor)
        print(f"Encoded shape: {z.shape}")
        print(f"Codebooks: {z.shape[1]} (first 4 are coarse, remaining 10 are fine)")

        # Decode straight back to a waveform (no transformer involved here)
        reconstructed_audio = interface.decode(z).squeeze().cpu().numpy()

    print("\nOriginal VampNet reconstruction:")
    ipd.display(ipd.Audio(reconstructed_audio, rate=sr))

    # Keep the codes as a NumPy array for the latent-code comparison cell
    original_codes = z.cpu().numpy()

    # Count distinct code values in the coarse (0-3) and fine (4+) codebooks
    coarse_unique = len(np.unique(original_codes[0, :4]))
    fine_unique = len(np.unique(original_codes[0, 4:]))
    print("\nCode statistics:")
    print(f"Coarse codes (0-3): unique values = {coarse_unique}")
    print(f"Fine codes (4-13): unique values = {fine_unique}")

## 6. Process with ONNX Pipeline (Coarse Only)

In [None]:
if not all(models_exist.values()):
    # At least one of the encoder/decoder/transformer files is missing.
    print("ONNX models not available")
else:
    print("Processing with ONNX pipeline...")

    # Build the ONNX pipeline from the three exported model files
    pipeline_paths = {
        "encoder_path": codec_path,
        "decoder_path": decoder_path,
        "transformer_path": transformer_path,
    }
    pipeline = VampNetONNXPipeline(**pipeline_paths)

    try:
        # Audio -> codes
        codes = pipeline.encode_audio(audio_data, sample_rate=sr)
        print(f"ONNX encoded shape: {codes.shape}")

        # Note: ONNX pipeline currently only uses coarse codes (first 4)
        print("\n⚠️ WARNING: ONNX pipeline only processes coarse codes (4 codebooks)")
        print("Fine codes (10 codebooks) are padded with zeros for decoding")

        # Codes -> audio
        onnx_reconstructed = pipeline.decode_codes(codes)

        print("\nONNX reconstruction (coarse only):")
        ipd.display(ipd.Audio(onnx_reconstructed, rate=sr))

        # Only published once decoding succeeded, so the comparison cells
        # skip themselves after a failed run.
        onnx_codes = codes

    except Exception as e:
        import traceback

        print(f"Error in ONNX pipeline: {e}")
        traceback.print_exc()

## 7. Compare Outputs

In [None]:
# Visual comparison of waveforms
def compare_reconstruction_mse(reference, baseline, candidate):
    """Compare two reconstructions against a reference waveform.

    All three inputs are squeezed to 1-D and truncated to their common length
    before the mean squared error is computed in float64.

    Returns:
        (baseline_mse, candidate_mse, degradation_pct) where
        ``degradation_pct`` is ``(candidate_mse / baseline_mse - 1) * 100``,
        or ``None`` when the ratio is undefined (zero baseline MSE or no
        overlapping samples). Both MSE values are NaN when there is no overlap.
    """
    ref, base, cand = (
        np.atleast_1d(np.squeeze(np.asarray(wave, dtype=np.float64)))
        for wave in (reference, baseline, candidate)
    )
    n_common = min(len(ref), len(base), len(cand))
    if n_common == 0:
        return float('nan'), float('nan'), None

    baseline_mse = float(np.mean((ref[:n_common] - base[:n_common]) ** 2))
    candidate_mse = float(np.mean((ref[:n_common] - cand[:n_common]) ** 2))
    degradation_pct = (candidate_mse / baseline_mse - 1) * 100 if baseline_mse > 0 else None
    return baseline_mse, candidate_mse, degradation_pct


if 'reconstructed_audio' in globals() and 'onnx_reconstructed' in globals():
    # The PyTorch output was squeezed when it was produced; the ONNX output was
    # not, so drop any singleton batch/channel axes before slicing and plotting.
    # NOTE(review): assumes both reconstructions are mono - confirm.
    pytorch_wave = np.atleast_1d(np.squeeze(reconstructed_audio))
    onnx_wave = np.atleast_1d(np.squeeze(onnx_reconstructed))

    fig, axes = plt.subplots(3, 1, figsize=(12, 8))

    # Original
    axes[0].plot(audio_data[:sr//10])  # Show first 0.1 seconds
    axes[0].set_title("Original Audio")
    axes[0].set_ylabel("Amplitude")

    # PyTorch reconstruction
    axes[1].plot(pytorch_wave[:sr//10])
    axes[1].set_title("PyTorch VampNet Reconstruction (Full: Coarse + Fine)")
    axes[1].set_ylabel("Amplitude")

    # ONNX reconstruction
    axes[2].plot(onnx_wave[:sr//10])
    axes[2].set_title("ONNX Reconstruction (Coarse Only - Lower Quality)")
    axes[2].set_ylabel("Amplitude")
    axes[2].set_xlabel("Samples")

    plt.tight_layout()
    plt.show()

    # Calculate differences.
    # NOTE(review): sample-wise MSE is only meaningful if both reconstructions
    # share the input's sample rate and alignment - confirm.
    pytorch_mse, onnx_mse, degradation_pct = compare_reconstruction_mse(
        audio_data, pytorch_wave, onnx_wave
    )

    print("\nReconstruction Quality (MSE):")
    print(f"PyTorch (Full): {pytorch_mse:.6f}")
    print(f"ONNX (Coarse only): {onnx_mse:.6f}")
    if degradation_pct is None:
        # Avoid printing inf/nan when the baseline error is zero or empty.
        print("\nQuality degradation: undefined (PyTorch MSE is zero or no overlapping samples)")
    else:
        print(f"\nQuality degradation: {degradation_pct:.1f}% worse")

## 8. Compare Latent Codes

In [None]:
def compare_latent_codes(reference, candidate, n_coarse=4):
    """Compare the coarse codebooks of two code arrays.

    Either array may carry a leading batch axis (3-D); only the first batch
    item is compared. Codes are compared for exact equality: a tolerance-based
    check would report nearby-but-different values as a match.
    NOTE(review): assumes the codes are discrete codebook indices - confirm.

    Returns a dict with:
        frames_match      - whether both arrays have the same last-axis length
        coarse_match      - exact equality of the first ``n_coarse`` codebooks
                            (None when the frame counts differ)
        mismatch_fraction - fraction of coarse tokens that differ (or None)
        max_diff / mean_diff - absolute value differences (or None)
        fine_codes        - ``candidate`` codebooks beyond ``n_coarse``, or
                            None when the candidate has no such codebooks
    """
    ref = np.asarray(reference)
    cand = np.asarray(candidate)
    if ref.ndim == 3:
        ref = ref[0]
    if cand.ndim == 3:
        cand = cand[0]

    report = {
        "frames_match": ref.shape[-1] == cand.shape[-1],
        "coarse_match": None,
        "mismatch_fraction": None,
        "max_diff": None,
        "mean_diff": None,
        "fine_codes": cand[n_coarse:] if cand.shape[0] > n_coarse else None,
    }
    if not report["frames_match"]:
        return report

    n_compared = min(n_coarse, ref.shape[0], cand.shape[0])
    ref_coarse = ref[:n_compared]
    cand_coarse = cand[:n_compared]
    report["coarse_match"] = bool(np.array_equal(ref_coarse, cand_coarse))
    if ref_coarse.size:
        # float64 avoids wrap-around when subtracting unsigned integer codes.
        diff = np.abs(ref_coarse.astype(np.float64) - cand_coarse.astype(np.float64))
        report["max_diff"] = float(diff.max())
        report["mean_diff"] = float(diff.mean())
        report["mismatch_fraction"] = float(np.mean(diff != 0))
    return report


if 'original_codes' in globals() and 'onnx_codes' in globals():
    print("Latent code comparison:")
    print(f"Original codes shape: {original_codes.shape}")
    print(f"ONNX codes shape: {onnx_codes.shape}")

    code_report = compare_latent_codes(original_codes, onnx_codes)

    # Compare coarse codes
    if code_report["frames_match"]:
        coarse_match = code_report["coarse_match"]
        print(f"\nCoarse codes (0-3) match: {coarse_match}")

        if not coarse_match and code_report["max_diff"] is not None:
            print(f"Max difference in coarse codes: {code_report['max_diff']}")
            print(f"Mean difference in coarse codes: {code_report['mean_diff']}")
            print(f"Mismatching coarse tokens: {code_report['mismatch_fraction']:.1%}")
    else:
        # Say so explicitly instead of silently skipping the comparison.
        print("\nFrame counts differ; coarse codes were not compared")

    # Check fine codes in ONNX
    if code_report["fine_codes"] is not None:
        fine_codes = code_report["fine_codes"]
        print(f"\nFine codes in ONNX: {fine_codes.shape}")
        print(f"Are fine codes all zeros? {np.all(fine_codes == 0)}")
    else:
        print("\n⚠️ ONNX output only contains coarse codes!")

## 9. Check for C2F Model Checkpoint

In [None]:
# Look for a C2F checkpoint in a fixed list of candidate locations.
possible_c2f_paths = [
    "../models/c2f.pth",
    "../models/vampnet_c2f.pth",
    "../models/coarse_to_fine.pth",
    "~/.cache/vampnet/c2f.pth",
    "~/.cache/audiocraft/vampnet/c2f.pth"
]

print("Searching for C2F model checkpoint...")
c2f_checkpoint = None
for candidate in possible_c2f_paths:
    resolved = Path(candidate).expanduser()
    if not resolved.exists():
        print(f"✗ Not found: {candidate}")
        continue
    # First hit wins; remaining candidates are not probed.
    c2f_checkpoint = resolved
    print(f"✓ Found C2F checkpoint: {resolved}")
    break

if not c2f_checkpoint:
    print("\n⚠️ C2F model checkpoint not found!")
    print("This is why only coarse model has been exported.")
else:
    # Sanity-check that the file can be deserialized on CPU.
    # NOTE(review): torch.load unpickles the file - only point this at
    # checkpoints from a trusted source.
    try:
        checkpoint = torch.load(c2f_checkpoint, map_location='cpu')
        print("\nC2F checkpoint loaded successfully")
        print(f"Keys in checkpoint: {list(checkpoint.keys())[:5]}...")
        if 'model' in checkpoint:
            print(f"Model state dict keys: {len(checkpoint['model'])}")
    except Exception as e:
        print(f"Error loading C2F checkpoint: {e}")

## 10. Summary and Next Steps

In [None]:
# Static status report; only the final note depends on `c2f_checkpoint`
# from the checkpoint-search cell.
banner = "=" * 60
summary_lines = [
    banner,
    "SUMMARY",
    banner,
    "\nCurrent State:",
    "✓ Coarse transformer exported to ONNX (4 codebooks)",
    "✓ Codec (encoder/decoder) exported to ONNX",
    "✗ C2F transformer NOT exported (10 codebooks)",
    "✗ Complete weight transfer pending (embeddings/classifiers)",
    "\nQuality Impact:",
    "- Original VampNet: Uses all 14 codebooks (4 coarse + 10 fine)",
    "- ONNX Pipeline: Only uses 4 coarse codebooks",
    "- Result: ONNX output is lower quality (missing detail from fine codes)",
    "\nNext Steps:",
    "1. Export C2F model to ONNX",
    "2. Complete weight transfer for embeddings and classifiers",
    "3. Implement full two-stage pipeline (coarse → C2F)",
    "4. Verify outputs match original quality",
]
print("\n".join(summary_lines))

if not c2f_checkpoint:
    print("\n⚠️ C2F checkpoint not found - need to locate it first")
else:
    print(f"\n✓ Good news: C2F checkpoint found at {c2f_checkpoint}")
    print("  We can proceed with exporting the C2F model!")