# Test ONNX with Proper Sampling

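This notebook tests the new ONNX export, in which the transformer models return logits instead of applying ArgMax inside the graph, so that temperature and top-p sampling can be done in Python.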

In [None]:
import sys

# Put the parent directory (presumably the repo root) on sys.path so that the
# `scripts.*` package can be imported from this notebook's working directory.
sys.path.append('..')

from scripts.export_vampnet_transformer_v3_sampling import export_model_with_proper_sampling

# Re-export both transformer stages (coarse first, then coarse-to-fine) in the
# variant that outputs logits.
for model_name in ("coarse", "c2f"):
    export_model_with_proper_sampling(model_name)

In [ ]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf
import onnxruntime as ort
from pathlib import Path

from scripts.export_vampnet_transformer_v3_sampling import sample_from_onnx_output

# Open one ONNX Runtime session per pipeline stage (loaded in this order).
model_paths = {
    "encoder": "../scripts/models/vampnet_encoder_prepadded.onnx",
    "decoder": "../scripts/models/vampnet_codec_decoder.onnx",
    "coarse": "../onnx_models_fixed/coarse_logits_v3.onnx",
    "c2f": "../onnx_models_fixed/c2f_logits_v3.onnx",
}
sessions = {stage: ort.InferenceSession(path) for stage, path in model_paths.items()}
encoder_session = sessions["encoder"]
decoder_session = sessions["decoder"]
coarse_session = sessions["coarse"]
c2f_session = sessions["c2f"]

print("Models loaded")

# Report the declared input names of the two transformer models so the feed
# dicts used below can be checked against them.
print("\nModel inputs:")
for label, session in (("Coarse", coarse_session), ("C2F", c2f_session)):
    print(f"  {label}: {[inp.name for inp in session.get_inputs()]}")

In [None]:
# Create test audio: a 440 Hz sine wave at amplitude 0.5.
n_samples = 76800  # noted as 100 tokens, i.e. presumably a 768-sample hop — confirm against the encoder
sample_rate = 44100
# Sample times k / sample_rate for k = 0 .. n_samples-1. np.linspace(0, n/sr, n)
# would include the endpoint, making the step duration/(n-1) and the tone
# slightly sharp of 440 Hz.
t = np.arange(n_samples) / sample_rate
audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# Encode to discrete codes. The encoder is fed a (1, 1, n_samples) array,
# presumably (batch, channel, time).
codes = encoder_session.run(None, {'audio_padded': audio.reshape(1, 1, -1)})[0]
print(f"Encoded shape: {codes.shape}")

In [ ]:
# Test iterative generation with proper sampling.
# Each step masks a random subset of coarse-codebook positions, asks the coarse
# model for logits, and resamples the masked positions outside the ONNX graph.
mask_schedule = [0.9, 0.7, 0.5, 0.3, 0.1, 0.0]
z = codes.copy()
n_coarse = 4
# Derive the number of time frames from the encoded codes instead of
# hard-coding 100, so a different test-audio length still works.
n_frames = z.shape[-1]
# Seeded generator so the random masks are reproducible between runs.
# NOTE: this does not seed whatever RNG sample_from_onnx_output uses internally.
mask_rng = np.random.default_rng(0)

print("Running iterative generation with proper sampling...")

for i, mask_ratio in enumerate(mask_schedule):
    print(f"\nStep {i+1}: mask_ratio = {mask_ratio}")

    # A ratio of 0 masks nothing, so there is nothing to resample.
    if mask_ratio > 0:
        # Boolean mask: True where a coarse token should be regenerated.
        mask = mask_rng.random((1, n_coarse, n_frames)) < mask_ratio

        # Get logits from ONNX
        logits = coarse_session.run(None, {
            'codes': z[:, :n_coarse, :].astype(np.int64),
            'mask': mask,
        })[0]

        print(f"  Logits shape: {logits.shape}")

        # Snapshot before sampling so the per-step change count is correct even
        # if the sampler writes into the view it is given.
        z_prev = z[:, :n_coarse, :].copy()

        # Apply proper sampling
        z_new = sample_from_onnx_output(
            z[:, :n_coarse, :],
            mask,
            logits,
            temperature=0.8,
            top_p=0.9
        )

        # Update coarse codes
        z[:, :n_coarse, :] = z_new

        # Report both this step's changes and the cumulative drift from the
        # original encoding (the old message reported only the cumulative count).
        n_step = (z[:, :n_coarse, :] != z_prev).sum()
        n_total = (z != codes).sum()
        print(f"  Changed {n_step} tokens this step ({n_total} total vs. original)")

# Apply C2F: condition on the first n_coarse codebooks and regenerate the rest.
print("\nApplying C2F...")
n_codebooks = 14  # total codebooks fed to C2F (value from this notebook) — confirm against the export
n_frames = z.shape[-1]  # derived from the codes rather than hard-coding 100

# Mask every fine codebook; the coarse ones are conditioning input.
c2f_mask = np.zeros((1, n_codebooks, n_frames), dtype=bool)
c2f_mask[:, n_coarse:, :] = True

# Pad z to n_codebooks codebooks; fine rows start as 0 and are fully masked.
z_full = np.zeros((1, n_codebooks, n_frames), dtype=np.int64)
z_full[:, :n_coarse, :] = z[:, :n_coarse, :]

# Check if C2F expects mask input
c2f_input_names = [inp.name for inp in c2f_session.get_inputs()]
print(f"C2F expects inputs: {c2f_input_names}")

# Get C2F logits. If the model has no 'mask' input, the mask is presumably
# baked into the model, so only the codes are fed.
c2f_feed = {'codes': z_full}
if 'mask' in c2f_input_names:
    c2f_feed['mask'] = c2f_mask
c2f_logits = c2f_session.run(None, c2f_feed)[0]

print(f"C2F logits shape: {c2f_logits.shape}")

# Apply sampling; the first n_coarse codebooks are passed as conditioning.
z_final = sample_from_onnx_output(
    z_full,
    c2f_mask,
    c2f_logits,
    temperature=0.8,
    top_p=0.9,
    n_conditioning_codebooks=n_coarse
)

# Decode the full code stack back to a waveform, keeping index [0, 0, :]
# of the decoder's first output.
audio_generated = decoder_session.run(None, {'codes': z_final})[0][0, 0, :]

# Write the reference tone and the generated audio next to each other.
output_dir = Path("outputs/sampling_test")
output_dir.mkdir(parents=True, exist_ok=True)

for wav_name, signal in (
    ("original.wav", audio),
    ("generated_with_sampling.wav", audio_generated),
):
    sf.write(output_dir / wav_name, signal, sample_rate)

print(f"\nSaved audio to {output_dir}")

# Plot the first 4410 samples (0.1 s at the 44100 Hz rate used above) of each
# signal, stacked vertically, for a quick visual comparison.
n_preview = 4410
fig, axes = plt.subplots(2, 1, figsize=(12, 6))
panels = zip(axes, (audio, audio_generated), ('Original', 'Generated with Proper Sampling'))
for ax, signal, title in panels:
    ax.plot(signal[:n_preview])
    ax.set_title(title)
plt.tight_layout()
plt.show()