import torch
import numpy as np
import matplotlib.pyplot as plt
import onnxruntime as ort
from pathlib import Path
import sys
sys.path.append('..')

# Import VampNet
from vampnet.interface import Interface
from vampnet import mask as pmask

# For comparing distributions
from scipy import stats

In [ ]:
# Load VampNet models: coarse, coarse-to-fine and codec checkpoints, on CPU
checkpoint_paths = {
    "coarse_ckpt": Path("/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/models/vampnet/coarse.pth"),
    "coarse2fine_ckpt": Path("/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/models/vampnet/c2f.pth"),
    "codec_ckpt": Path("/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/models/vampnet/codec.pth"),
}
interface = Interface(device="cpu", **checkpoint_paths)

# Keep direct handles on the wrapped models; the coarse one is driven by
# hand further down to extract raw logits
coarse_vamp, c2f_vamp = interface.coarse, interface.c2f

print("VampNet models loaded")
print(f"Coarse model: {type(coarse_vamp)}")
print(f"C2F model: {type(c2f_vamp)}")

In [None]:
# Load VampNet models (coarse, coarse-to-fine and codec checkpoints) on CPU.
# The coarse-to-fine checkpoint keyword is `coarse2fine_ckpt`, the same name
# the other model-loading cell in this file uses; `c2f_ckpt` is not an
# Interface parameter and raised TypeError.
interface = Interface(
    coarse_ckpt="/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/models/vampnet/coarse.pth",
    coarse2fine_ckpt="/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/models/vampnet/c2f.pth",
    codec_ckpt="/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/models/vampnet/codec.pth",
    device="cpu"
)

# Access the actual models wrapped by the interface
coarse_vamp = interface.coarse
c2f_vamp = interface.c2f

print("VampNet models loaded")
print(f"Coarse model: {type(coarse_vamp)}")
print(f"C2F model: {type(c2f_vamp)}")

In [None]:
# Open one ONNX Runtime session per exported logits model
coarse_model_path = "../onnx_models_fixed/coarse_logits_v3.onnx"
c2f_model_path = "../onnx_models_fixed/c2f_logits_v3.onnx"
onnx_coarse_session = ort.InferenceSession(coarse_model_path)
onnx_c2f_session = ort.InferenceSession(c2f_model_path)

print("ONNX models loaded")

# List the graph input names so the feed dict built later can be checked
coarse_input_names = [inp.name for inp in onnx_coarse_session.get_inputs()]
c2f_input_names = [inp.name for inp in onnx_c2f_session.get_inputs()]
print(f"Coarse inputs: {coarse_input_names}")
print(f"C2F inputs: {c2f_input_names}")

In [None]:
# Build a random (batch, codebook, time) token grid plus a boolean mask
batch_size, n_codebooks, n_tokens = 1, 4, 100
input_shape = (batch_size, n_codebooks, n_tokens)

# Random codes drawn uniformly from [0, 1024)
codes = torch.randint(0, 1024, input_shape)

# Each position is masked independently with probability `mask_ratio`
mask_ratio = 0.5
mask = torch.rand(input_shape) < mask_ratio

masked_count = mask.sum().item()
masked_percent = mask.float().mean()*100

print(f"Test input shape: {codes.shape}")
print(f"Mask shape: {mask.shape}")
print(f"Masked positions: {masked_count} / {mask.numel()} ({masked_percent:.1f}%)")

In [None]:
# Drive the coarse VampNet model step by step to obtain raw logits
print("Getting VampNet logits...")

with torch.no_grad():
    # Work on a copy so `codes` stays untouched for the later comparisons
    z_masked = codes.clone()
    z_masked[mask] = coarse_vamp.mask_token

    # Token embedding followed by positional encoding
    embedding = coarse_vamp.embedding
    z_e = embedding.add_positional_encoding(
        embedding.from_codes(z_masked, coarse_vamp.n_codebooks)
    )

    # Transformer stack, then the classification head
    z_hat = coarse_vamp.transformer(z_e)
    flat_logits = coarse_vamp.classifier(z_hat)

    # NOTE(review): the reshape below assumes the classifier emits
    # (batch, seq_len, n_codebooks * vocab_size) - confirm against the model.
    batch, seq_len, _ = flat_logits.shape
    logits_vampnet = flat_logits.view(
        batch, seq_len, n_codebooks, coarse_vamp.vocab_size
    ).permute(0, 2, 1, 3)  # -> (batch, n_codebooks, seq_len, vocab_size)

print(f"VampNet logits shape: {logits_vampnet.shape}")
print(f"VampNet logits range: [{logits_vampnet.min():.2f}, {logits_vampnet.max():.2f}]")

In [None]:
# Run the exported coarse model on the same codes and mask
print("Getting ONNX logits...")

# Feed dict keyed by the graph input names 'codes' and 'mask'
onnx_inputs = dict(
    codes=codes.numpy().astype(np.int64),
    mask=mask.numpy(),
)

# Passing None as the output list requests every graph output; the logits
# are taken from the first one
onnx_outputs = onnx_coarse_session.run(None, onnx_inputs)
logits_onnx = onnx_outputs[0]

print(f"ONNX logits shape: {logits_onnx.shape}")
print(f"ONNX logits range: [{logits_onnx.min():.2f}, {logits_onnx.max():.2f}]")

# NOTE(review): the ONNX export is expected to emit vocab_size+1 (1025)
# entries to include the mask token, versus vocab_size (1024) for VampNet -
# the printout below shows the actual sizes.
vocab_vampnet = logits_vampnet.shape[-1]
vocab_onnx = logits_onnx.shape[-1]
print(f"\nVocab size difference:")
print(f"  VampNet: {vocab_vampnet}")
print(f"  ONNX: {vocab_onnx}")

In [None]:
# Spot-check the first few masked positions, one report per position
print("Comparing logits at masked positions...\n")

# Index tensors (batch, codebook, time) of every masked position
mask_indices = torch.where(mask)

# Inspect at most five positions
n_compare = min(5, len(mask_indices[0]))

for b, c, t in zip(*(axis[:n_compare] for axis in mask_indices)):
    # Logit vectors at this position; ONNX is trimmed to its first 1024 entries
    logits_v = logits_vampnet[b, c, t, :].numpy()
    logits_o = logits_onnx[b, c, t, :1024]

    # Summary statistics
    print(f"Position [{b}, {c}, {t}]:")
    print(f"  VampNet: mean={logits_v.mean():.3f}, std={logits_v.std():.3f}, max={logits_v.max():.3f}")
    print(f"  ONNX:    mean={logits_o.mean():.3f}, std={logits_o.std():.3f}, max={logits_o.max():.3f}")

    # Five highest-scoring token ids, best first
    top5_v = np.argsort(logits_v)[::-1][:5]
    top5_o = np.argsort(logits_o)[::-1][:5]
    shared = set(top5_v).intersection(top5_o)
    print(f"  VampNet top-5: {top5_v.tolist()}")
    print(f"  ONNX top-5: {top5_o.tolist()}")
    print(f"  Agreement: {len(shared)}/5")
    print()

In [None]:
# Aggregate comparison of the two logit sets over every masked position
print("Statistical comparison of all logits...\n")

# Boolean indexing collapses (batch, codebook, time) into one axis of masked
# positions; the ONNX side keeps only its first 1024 vocabulary entries
logits_v_masked = logits_vampnet[mask].numpy()
logits_o_masked = logits_onnx[mask.numpy()][:, :1024]

print(f"Number of masked positions: {mask.sum().item()}")
print(f"Logits shape per position: {logits_v_masked.shape[1]} (VampNet), {logits_o_masked.shape[1]} (ONNX)")

# Overall statistics
print(f"\nOverall statistics:")
print(f"  VampNet: mean={logits_v_masked.mean():.3f}, std={logits_v_masked.std():.3f}")
print(f"  ONNX:    mean={logits_o_masked.mean():.3f}, std={logits_o_masked.std():.3f}")

# Three side-by-side panels: value histogram, max logit, entropy
fig, (ax_hist, ax_max, ax_entropy) = plt.subplots(1, 3, figsize=(12, 4))

# Panel 1: overlaid histograms of every logit value
ax_hist.hist(logits_v_masked.flatten(), bins=50, alpha=0.5, label='VampNet', density=True)
ax_hist.hist(logits_o_masked.flatten(), bins=50, alpha=0.5, label='ONNX', density=True)
ax_hist.set_xlabel('Logit Value')
ax_hist.set_ylabel('Density')
ax_hist.set_title('Logit Distribution')
ax_hist.legend()

# Panel 2: per-position max logit, with the y = x line as reference
max_v = logits_v_masked.max(axis=1)
max_o = logits_o_masked.max(axis=1)
max_span = [max_v.min(), max_v.max()]
ax_max.scatter(max_v, max_o, alpha=0.5)
ax_max.plot(max_span, max_span, 'r--')
ax_max.set_xlabel('VampNet Max Logit')
ax_max.set_ylabel('ONNX Max Logit')
ax_max.set_title('Max Logit Comparison')

# Panel 3: entropy of the softmax distribution at each position; the 1e-10
# term keeps log() finite where a probability underflows to zero
probs_v = torch.softmax(torch.from_numpy(logits_v_masked), dim=-1).numpy()
probs_o = torch.softmax(torch.from_numpy(logits_o_masked), dim=-1).numpy()
entropy_v = -np.sum(probs_v * np.log(probs_v + 1e-10), axis=1)
entropy_o = -np.sum(probs_o * np.log(probs_o + 1e-10), axis=1)
entropy_span = [entropy_v.min(), entropy_v.max()]
ax_entropy.scatter(entropy_v, entropy_o, alpha=0.5)
ax_entropy.plot(entropy_span, entropy_span, 'r--')
ax_entropy.set_xlabel('VampNet Entropy')
ax_entropy.set_ylabel('ONNX Entropy')
ax_entropy.set_title('Distribution Entropy Comparison')

fig.tight_layout()
plt.show()

# Pearson correlation across all (position, token) logit pairs
correlation_matrix = np.corrcoef(logits_v_masked.flatten(), logits_o_masked.flatten())
correlation = correlation_matrix[0, 1]
print(f"\nLogit correlation: {correlation:.4f}")

In [None]:
# Draw tokens from both logit sets with matching settings and compare
print("Sampling from logits and comparing outputs...\n")

from scripts.export_vampnet_transformer_v3_sampling import sample_from_onnx_output

# One temperature / nucleus threshold shared by both sampling paths
temperature = 0.8
top_p = 0.9


def _most_common(tokens, k=5):
    """Return the ids of the k most frequent tokens, most frequent first."""
    return np.bincount(tokens).argsort()[-k:][::-1].tolist()


# --- VampNet side: temperature softmax + top-p filtering done by hand ---
with torch.no_grad():
    probs = torch.softmax(logits_vampnet / temperature, dim=-1)

    # Sort descending and zero every entry whose cumulative mass exceeds
    # top_p; the top-ranked token is always kept
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cutoff_mask = torch.cumsum(sorted_probs, dim=-1) > top_p
    cutoff_mask[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(cutoff_mask, 0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # Scatter the filtered probabilities back to vocabulary order
    probs_vampnet = torch.zeros_like(probs).scatter_(-1, sorted_indices, sorted_probs)

    # One multinomial draw per (batch, codebook, time) position
    vocab_dim = probs_vampnet.shape[-1]
    sampled_vampnet = torch.multinomial(
        probs_vampnet.view(-1, vocab_dim), num_samples=1
    ).view(batch_size, n_codebooks, n_tokens)

    # Sampled tokens replace the original codes only where masked
    output_vampnet = torch.where(mask, sampled_vampnet, codes)

# --- ONNX side: delegated to the project's sampling helper ---
output_onnx = sample_from_onnx_output(
    codes.numpy(),
    mask.numpy(),
    logits_onnx,
    temperature=temperature,
    top_p=top_p,
)

# How many tokens differ from the input in each output
changes_vampnet = (output_vampnet != codes).sum().item()
changes_onnx = (torch.from_numpy(output_onnx) != codes).sum().item()

print(f"Changes from original:")
print(f"  VampNet: {changes_vampnet} tokens")
print(f"  ONNX: {changes_onnx} tokens")
print(f"  Expected (masked): {mask.sum().item()} tokens")

# Tokens that ended up at the masked positions
tokens_vampnet = output_vampnet[mask].numpy()
tokens_onnx = output_onnx[mask.numpy()]

print(f"\nToken distribution at masked positions:")
print(f"  VampNet: {len(np.unique(tokens_vampnet))} unique tokens")
print(f"  ONNX: {len(np.unique(tokens_onnx))} unique tokens")

# Five most frequent token ids per model
print(f"\nMost common tokens:")
print(f"  VampNet: {_most_common(tokens_vampnet)}")
print(f"  ONNX: {_most_common(tokens_onnx)}")

In [None]:
# Look for a consistent offset between ONNX and VampNet logits
print("Checking for systematic differences...\n")

# Element-wise gap over every masked position and token id
diff = logits_o_masked - logits_v_masked
print(f"Average logit difference (ONNX - VampNet): {diff.mean():.4f} ± {diff.std():.4f}")

# Averaging over positions leaves one bias value per token id
per_token_diff = diff.mean(axis=0)
bias_low, bias_high = per_token_diff.min(), per_token_diff.max()
print(f"\nPer-token bias range: [{bias_low:.4f}, {bias_high:.4f}]")

# Bias per token id, with a zero reference line
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(per_token_diff)
ax.axhline(y=0, color='r', linestyle='--')
ax.set_xlabel('Token ID')
ax.set_ylabel('Average Difference (ONNX - VampNet)')
ax.set_title('Per-Token Logit Bias')
ax.grid(True, alpha=0.3)
plt.show()

# Flag token ids whose mean gap exceeds 0.5 in magnitude
problematic_tokens = np.flatnonzero(np.abs(per_token_diff) > 0.5)
if problematic_tokens.size:
    print(f"\nTokens with large bias (|diff| > 0.5): {problematic_tokens.tolist()}")
else:
    print("\nNo tokens with large systematic bias found.")

In [None]:
# Final summary of the logit, entropy and sampling comparisons.
# Every line reports a value computed in an earlier cell; nothing here is a
# fixed conclusion.
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

print(f"\nLogit Statistics:")
print(f"  Correlation: {correlation:.4f}")
print(f"  Mean difference: {diff.mean():.4f} ± {diff.std():.4f}")
print(f"  Max absolute difference: {np.abs(diff).max():.4f}")

print(f"\nDistribution Comparison:")
print(f"  VampNet entropy: {entropy_v.mean():.3f} ± {entropy_v.std():.3f}")
print(f"  ONNX entropy: {entropy_o.mean():.3f} ± {entropy_o.std():.3f}")

# Report the measured sampling numbers rather than asserting that the two
# models behaved alike.
n_masked = int(mask.sum().item())
n_unique_vampnet = len(np.unique(tokens_vampnet))
n_unique_onnx = len(np.unique(tokens_onnx))
print(f"\nSampling Results:")
print(f"  Tokens changed: VampNet {changes_vampnet}, ONNX {changes_onnx} (masked: {n_masked})")
print(f"  Unique tokens at masked positions: VampNet {n_unique_vampnet}, ONNX {n_unique_onnx}")

# A NaN correlation fails every `>` test, so it is reported explicitly
# rather than being labelled "poorly correlated".
if np.isnan(correlation):
    print(f"\n❌ Logit correlation is undefined (NaN) - check the logits for NaNs or constant values")
elif correlation > 0.9:
    print(f"\n✅ Logits are highly correlated - models are producing similar outputs")
elif correlation > 0.5:
    print(f"\n⚠️ Logits are moderately correlated - there may be some differences")
else:
    print(f"\n❌ Logits are poorly correlated - models are producing very different outputs")