# Transformer Masking Test with 100-Token Batch

This notebook tests the VampNet transformer with different periodic masking patterns on the coarse layers using a single 100-token batch.

In [ ]:
# Initialize VampNet
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load VampNet interface
interface = vampnet.interface.Interface(
    device=device,
    codec_ckpt="../models/vampnet/codec.pth",
    coarse_ckpt="../models/vampnet/coarse.pth",
    wavebeat_ckpt="../models/vampnet/wavebeat.pth"
)

codec = interface.codec
coarse_model = interface.coarse
codec.eval()
coarse_model.eval()

print("\nModels loaded:")
print(f"  Codec - Sample rate: {codec.sample_rate}, Hop length: {codec.hop_length}")
print(f"  Coarse model - n_codebooks: {coarse_model.n_codebooks}")
print(f"  Vocabulary size: {coarse_model.vocab_size}")

# Load ONNX models
onnx_encoder_path = Path("../scripts/models/vampnet_encoder_prepadded.onnx")
onnx_coarse_path = Path("../onnx_models_fixed/coarse_complete_v3.onnx")

if not onnx_encoder_path.exists():
    raise FileNotFoundError(f"ONNX encoder not found at {onnx_encoder_path}")
if not onnx_coarse_path.exists():
    raise FileNotFoundError(f"ONNX coarse transformer not found at {onnx_coarse_path}")

onnx_encoder = ort.InferenceSession(str(onnx_encoder_path))
onnx_coarse = ort.InferenceSession(str(onnx_coarse_path))

print("\nONNX models loaded")

## 1. Initialize Models

In [2]:
# Initialize VampNet
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load VampNet interface
interface = vampnet.interface.Interface(
    device=device,
    codec_ckpt="../models/vampnet/codec.pth",
    coarse_ckpt="../models/vampnet/coarse.pth",
    wavebeat_ckpt="../models/vampnet/wavebeat.pth"
)

codec = interface.codec
coarse_model = interface.coarse
codec.eval()
coarse_model.eval()

print("\nModels loaded:")
print(f"  Codec - Sample rate: {codec.sample_rate}, Hop length: {codec.hop_length}")
print(f"  Coarse model - n_codebooks: {coarse_model.n_codebooks}")
print(f"  Vocabulary size: {coarse_model.vocab_size}")

# Load ONNX models
onnx_encoder_path = Path("../scripts/models/vampnet_encoder_prepadded.onnx")
onnx_coarse_path = Path("../onnx_models_fixed/coarse_transformer_v2_weighted.onnx")

if not onnx_encoder_path.exists():
    raise FileNotFoundError(f"ONNX encoder not found at {onnx_encoder_path}")
if not onnx_coarse_path.exists():
    raise FileNotFoundError(f"ONNX coarse transformer not found at {onnx_coarse_path}")

onnx_encoder = ort.InferenceSession(str(onnx_encoder_path))
onnx_coarse = ort.InferenceSession(str(onnx_coarse_path))

print("\nONNX models loaded")

Using device: cpu


  model_dict = torch.load(location, "cpu")
  WeightNorm.apply(module, name, dim)
/Users/stephen/Documents/Development/MusicHackspace/vampnet-onnx-export-cleanup/venv/lib/python3.11/site-packages/pytorch_lightning/utilities/migration/migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.1.8 to v2.5.1.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../models/vampnet/wavebeat.pth`



Models loaded:
  Codec - Sample rate: 44100, Hop length: 768
  Coarse model - n_codebooks: 4
  Vocabulary size: 1024


FileNotFoundError: ONNX coarse transformer not found at ../onnx_models_fixed/coarse_transformer_v2_weighted.onnx

## 2. Create Test Audio (100 tokens)

In [None]:
# Create exactly 100 tokens worth of audio
n_tokens = 100
hop_length = 768
n_samples = n_tokens * hop_length  # 76,800 samples
sample_rate = 44100
duration = n_samples / sample_rate

print(f"Creating test audio: {n_samples} samples ({duration:.2f} seconds)")

# Create a musical test signal
t = np.linspace(0, duration, n_samples)
audio = np.zeros_like(t)

# Add harmonics to create a richer sound
fundamentals = [220, 330, 440]  # A3, E4, A4
for i, freq in enumerate(fundamentals):
    # Fundamental
    audio += 0.3 * np.sin(2 * np.pi * freq * t) / (i + 1)
    # Harmonics
    audio += 0.1 * np.sin(2 * np.pi * freq * 2 * t) / (i + 1)
    audio += 0.05 * np.sin(2 * np.pi * freq * 3 * t) / (i + 1)

# Add some envelope
envelope = np.exp(-t * 0.5) * (1 + 0.2 * np.sin(2 * np.pi * 6 * t))
audio = audio * envelope

# Normalize
audio = audio / np.max(np.abs(audio)) * 0.8
audio = audio.astype(np.float32)

print(f"Audio shape: {audio.shape}")
print(f"Duration: {duration:.2f} seconds")
print(f"Expected tokens: {n_tokens}")

# Plot the audio
plt.figure(figsize=(12, 4))
plt.plot(t[:4410], audio[:4410])  # First 0.1 seconds
plt.title('Test Audio Waveform (first 0.1s)')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.grid(True, alpha=0.3)
plt.show()

## 3. Encode Audio

In [None]:
# Encode with VampNet
print("Encoding with VampNet...")
audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0).to(device)

with torch.no_grad():
    result = codec.encode(audio_tensor, sample_rate)
    vampnet_codes = result["codes"]

print(f"VampNet codes shape: {vampnet_codes.shape}")

# Encode with ONNX
print("\nEncoding with ONNX...")
audio_onnx = audio.reshape(1, 1, -1)
onnx_codes = onnx_encoder.run(None, {'audio_padded': audio_onnx})[0]

print(f"ONNX codes shape: {onnx_codes.shape}")

# Verify they match
match_rate = (vampnet_codes.cpu().numpy() == onnx_codes).mean()
print(f"\nEncoder match rate: {match_rate:.1%}")

# Use VampNet codes for transformer testing
codes = vampnet_codes
codes_numpy = codes.cpu().numpy()

# Extract just the coarse codes (first 4 codebooks)
n_coarse = 4
coarse_codes = codes[:, :n_coarse, :]
coarse_codes_numpy = codes_numpy[:, :n_coarse, :]

print(f"\nCoarse codes shape: {coarse_codes.shape}")
print(f"Coarse codes range: [{coarse_codes.min()}, {coarse_codes.max()}]")

## 4. Create Different Masking Patterns

In [None]:
def create_periodic_mask(shape, period, offset=0, mask_ratio=None):
    """
    Create a periodic mask pattern.
    
    Args:
        shape: (batch, codebooks, sequence) shape
        period: Mask every `period`-th token
        offset: Start offset for the pattern
        mask_ratio: If provided, use this ratio instead of periodic
    """
    batch, n_codebooks, seq_len = shape
    mask = np.zeros(shape, dtype=bool)
    
    if mask_ratio is not None:
        # Random masking with given ratio
        n_mask = int(seq_len * mask_ratio)
        for b in range(batch):
            for c in range(n_codebooks):
                indices = np.random.choice(seq_len, n_mask, replace=False)
                mask[b, c, indices] = True
    else:
        # Periodic masking
        for i in range(seq_len):
            if (i + offset) % period == 0:
                mask[:, :, i] = True
    
    return mask

# Create different masking patterns
mask_patterns = {
    "every_2": create_periodic_mask(coarse_codes.shape, period=2),
    "every_3": create_periodic_mask(coarse_codes.shape, period=3),
    "every_4": create_periodic_mask(coarse_codes.shape, period=4),
    "every_5": create_periodic_mask(coarse_codes.shape, period=5),
    "random_30": create_periodic_mask(coarse_codes.shape, period=None, mask_ratio=0.3),
    "random_50": create_periodic_mask(coarse_codes.shape, period=None, mask_ratio=0.5),
    "random_70": create_periodic_mask(coarse_codes.shape, period=None, mask_ratio=0.7),
}

# Visualize masks
fig, axes = plt.subplots(len(mask_patterns), 1, figsize=(15, len(mask_patterns) * 1.5))

for idx, (name, mask) in enumerate(mask_patterns.items()):
    ax = axes[idx]
    # Show mask for first codebook
    ax.imshow(mask[0, :1, :], aspect='auto', cmap='RdBu', interpolation='nearest')
    ax.set_title(f'{name} - {mask.mean():.1%} masked')
    ax.set_ylabel('Codebook')
    if idx == len(mask_patterns) - 1:
        ax.set_xlabel('Token Index')
    ax.set_yticks([0])
    ax.set_yticklabels(['CB 0'])

plt.suptitle('Masking Patterns (Blue = Masked, Red = Unmasked)', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Test VampNet Transformer with Different Masks

In [None]:
# Important: Check if coarse model expects embeddings or raw codes
print("Checking VampNet coarse model structure...")
print(f"Model type: {type(coarse_model)}")

# Test with VampNet's native generation
print("\nTesting VampNet transformer with different masks...")

vampnet_results = {}

for name, mask in mask_patterns.items():
    print(f"\n{name}:")
    
    # Create masked input
    masked_codes = coarse_codes.clone()
    mask_torch = torch.from_numpy(mask).to(device)
    
    # Apply mask (set to mask token)
    mask_token = coarse_model.vocab_size  # Usually 1024
    masked_codes[mask_torch] = mask_token
    
    print(f"  Masked tokens: {mask_torch.sum().item()} / {mask_torch.numel()} ({mask_torch.float().mean():.1%})")
    print(f"  Unique values in masked codes: {torch.unique(masked_codes).cpu().numpy()}")
    
    # Run through transformer
    with torch.no_grad():
        # The coarse model expects the codes directly, not embeddings
        # It will handle the embedding internally
        try:
            # Try the standard forward pass
            output = coarse_model(masked_codes)
            
            # Get predictions for masked positions
            predicted_codes = output.argmax(dim=-1)
            
            # Calculate accuracy on masked positions
            masked_positions = mask_torch
            correct = (predicted_codes[masked_positions] == coarse_codes[masked_positions]).float()
            accuracy = correct.mean().item()
            
            print(f"  Accuracy on masked positions: {accuracy:.1%}")
            
            vampnet_results[name] = {
                'masked_codes': masked_codes.cpu().numpy(),
                'predictions': predicted_codes.cpu().numpy(),
                'accuracy': accuracy
            }
            
        except Exception as e:
            print(f"  Error with standard forward: {e}")
            print(f"  Trying alternative approach...")
            
            # Try using the interface's generate method
            # This handles the generation process properly
            from vampnet import mask as mask_module
            
            # Create a proper mask using VampNet's masking
            mask_obj = mask_module.random(coarse_codes, mask_torch.float().mean().item())
            
            # Generate
            generated = interface.coarse_to_fine(
                coarse_codes,
                mask=mask_obj
            )
            
            print(f"  Generated shape: {generated.shape}")
            vampnet_results[name] = {'generated': generated.cpu().numpy()}

## 6. Test ONNX Transformer with Different Masks

In [None]:
print("Testing ONNX transformer with different masks...")

# Check ONNX model inputs
print("\nONNX Coarse model inputs:")
for inp in onnx_coarse.get_inputs():
    print(f"  {inp.name}: shape={inp.shape}, type={inp.type}")

onnx_results = {}

for name, mask in mask_patterns.items():
    print(f"\n{name}:")
    
    # Prepare inputs for ONNX
    # ONNX expects raw codes, not embeddings
    codes_input = coarse_codes_numpy.astype(np.int64)
    mask_input = mask.astype(bool)
    
    print(f"  Codes shape: {codes_input.shape}, dtype: {codes_input.dtype}")
    print(f"  Mask shape: {mask_input.shape}, dtype: {mask_input.dtype}")
    print(f"  Masked ratio: {mask_input.mean():.1%}")
    
    try:
        # Run ONNX model
        outputs = onnx_coarse.run(None, {
            'codes': codes_input,
            'mask': mask_input
        })
        
        generated_codes = outputs[0]
        print(f"  Output shape: {generated_codes.shape}")
        print(f"  Output range: [{generated_codes.min()}, {generated_codes.max()}]")
        
        # Check for mask tokens (1024) in output
        mask_token = 1024
        n_mask_tokens = (generated_codes == mask_token).sum()
        print(f"  Mask tokens in output: {n_mask_tokens}")
        
        # Calculate how many tokens were actually changed
        changed = (generated_codes != codes_input).sum()
        print(f"  Tokens changed: {changed} / {codes_input.size} ({changed/codes_input.size:.1%})")
        
        onnx_results[name] = {
            'input_codes': codes_input,
            'generated_codes': generated_codes,
            'mask': mask_input,
            'n_changed': changed
        }
        
    except Exception as e:
        print(f"  Error: {e}")
        onnx_results[name] = {'error': str(e)}

## 7. Compare VampNet vs ONNX Results

In [None]:
# Compare results for each masking pattern
print("Comparing VampNet vs ONNX results...\n")

for name in mask_patterns.keys():
    print(f"{name}:")
    
    if name in vampnet_results and name in onnx_results:
        vamp_res = vampnet_results[name]
        onnx_res = onnx_results[name]
        
        if 'error' not in onnx_res:
            # Compare generated codes
            if 'predictions' in vamp_res:
                vamp_codes = vamp_res['predictions']
            elif 'generated' in vamp_res:
                vamp_codes = vamp_res['generated'][:, :n_coarse, :]
            else:
                print("  VampNet: No valid output")
                continue
                
            onnx_codes = onnx_res['generated_codes']
            
            # Calculate match rate
            if vamp_codes.shape == onnx_codes.shape:
                matches = (vamp_codes == onnx_codes)
                match_rate = matches.mean()
                print(f"  Match rate: {match_rate:.1%}")
                
                # Check match rate on masked positions only
                mask = mask_patterns[name]
                masked_matches = matches[mask]
                masked_match_rate = masked_matches.mean() if len(masked_matches) > 0 else 0
                print(f"  Match rate (masked positions): {masked_match_rate:.1%}")
            else:
                print(f"  Shape mismatch: VampNet {vamp_codes.shape} vs ONNX {onnx_codes.shape}")
        else:
            print(f"  ONNX error: {onnx_res['error']}")
    else:
        print("  Missing results from one or both models")
    
    print()

## 8. Visualize Token Distributions

In [None]:
# Visualize token distributions for one masking pattern
test_pattern = "every_4"

if test_pattern in onnx_results and 'error' not in onnx_results[test_pattern]:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    
    # Get data
    original = coarse_codes_numpy[0]  # Shape: (4, 100)
    onnx_gen = onnx_results[test_pattern]['generated_codes'][0]
    mask = mask_patterns[test_pattern][0]
    
    for cb in range(4):
        ax = axes[cb]
        
        # Plot original and generated
        x = np.arange(100)
        ax.plot(x, original[cb], 'b-', label='Original', alpha=0.7)
        ax.plot(x, onnx_gen[cb], 'r--', label='Generated', alpha=0.7)
        
        # Highlight masked positions
        masked_indices = np.where(mask[cb])[0]
        ax.scatter(masked_indices, original[cb, masked_indices], 
                  color='blue', s=100, marker='o', edgecolor='black', 
                  label='Original (masked)', zorder=5)
        ax.scatter(masked_indices, onnx_gen[cb, masked_indices], 
                  color='red', s=100, marker='x', 
                  label='Generated (masked)', zorder=5)
        
        ax.set_title(f'Codebook {cb}')
        ax.set_xlabel('Token Index')
        ax.set_ylabel('Token Value')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.suptitle(f'Token Generation with {test_pattern} Masking', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Show statistics
    print(f"\nStatistics for {test_pattern} masking:")
    print(f"Total masked positions: {mask.sum()}")
    print(f"Generated tokens range: [{onnx_gen.min()}, {onnx_gen.max()}]")
    
    # Check if any mask tokens (1024) remain
    mask_token = 1024
    remaining_masks = (onnx_gen == mask_token).sum()
    print(f"Remaining mask tokens: {remaining_masks}")
    
    # Token diversity
    for cb in range(4):
        unique_orig = len(np.unique(original[cb]))
        unique_gen = len(np.unique(onnx_gen[cb]))
        print(f"Codebook {cb} - Unique tokens: Original={unique_orig}, Generated={unique_gen}")

## 9. Test Embedding Issue

Check if the transformer is using raw codes instead of embeddings (a known issue).

In [None]:
# Check if codes are being used directly instead of embeddings
print("Testing for direct code usage vs embeddings...\n")

# Create a simple test case
test_codes = np.array([[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 10]], dtype=np.int64)  # Shape: (1, 1, 100)
test_codes = np.tile(test_codes, (1, 4, 1))  # Shape: (1, 4, 100)

# Create a mask for positions 0, 2, 4, 6, 8
test_mask = np.zeros((1, 4, 100), dtype=bool)
test_mask[:, :, ::2] = True  # Mask every other position

print(f"Test codes shape: {test_codes.shape}")
print(f"Test codes sample: {test_codes[0, 0, :10]}")
print(f"Masked positions: {np.where(test_mask[0, 0])[0][:10]}")

# Run through ONNX
try:
    test_output = onnx_coarse.run(None, {
        'codes': test_codes,
        'mask': test_mask
    })
    
    generated = test_output[0]
    print(f"\nGenerated shape: {generated.shape}")
    print(f"Generated sample: {generated[0, 0, :10]}")
    
    # Check if the generated values are suspiciously similar to input codes
    # If codes are used directly, we might see patterns related to the input values
    matches = (generated == test_codes)
    print(f"\nExact matches with input: {matches.mean():.1%}")
    
    # Check value ranges
    print(f"\nValue ranges:")
    print(f"  Input: [{test_codes.min()}, {test_codes.max()}]")
    print(f"  Generated: [{generated.min()}, {generated.max()}]")
    
    # If embeddings are working correctly, generated values should be in [0, 1023]
    # If codes are used directly, we might see strange patterns
    if generated.max() > 1023:
        print("\n⚠️ WARNING: Generated values exceed vocabulary size!")
        print("This suggests the model might be using codes directly instead of embeddings.")
    
except Exception as e:
    print(f"Error in test: {e}")

## 10. Decode and Listen to Results

In [None]:
# Decode one of the generated results
test_pattern = "every_4"

if test_pattern in onnx_results and 'error' not in onnx_results[test_pattern]:
    print(f"Decoding results for {test_pattern} masking...")
    
    # Get the generated coarse codes
    generated_coarse = onnx_results[test_pattern]['generated_codes']
    
    # We need all 14 codebooks for decoding
    # Use original fine codes (codebooks 4-13)
    full_codes = codes_numpy.copy()
    full_codes[:, :n_coarse, :] = generated_coarse
    
    # Convert to torch
    full_codes_torch = torch.from_numpy(full_codes).long().to(device)
    
    # Decode
    with torch.no_grad():
        audio_dict = interface.decode(full_codes_torch)
        reconstructed = audio_dict.audio_data.squeeze().cpu().numpy()
    
    print(f"Original audio shape: {audio.shape}")
    print(f"Reconstructed shape: {reconstructed.shape}")
    
    # Save audio files
    output_dir = Path("outputs/transformer_masking_test")
    output_dir.mkdir(exist_ok=True, parents=True)
    
    sf.write(output_dir / "original.wav", audio, sample_rate)
    sf.write(output_dir / f"generated_{test_pattern}.wav", reconstructed[:len(audio)], sample_rate)
    
    # Also decode with original codes for comparison
    with torch.no_grad():
        audio_dict_orig = interface.decode(codes)
        reconstructed_orig = audio_dict_orig.audio_data.squeeze().cpu().numpy()
    
    sf.write(output_dir / "reconstructed_original_codes.wav", reconstructed_orig[:len(audio)], sample_rate)
    
    print(f"\nAudio files saved to {output_dir}")
    print("  - original.wav: Original test audio")
    print(f"  - generated_{test_pattern}.wav: Generated with {test_pattern} masking")
    print("  - reconstructed_original_codes.wav: Reconstructed from original codes")
    
    # Plot comparison
    fig, axes = plt.subplots(3, 1, figsize=(12, 10))
    
    t_plot = np.arange(4410) / sample_rate  # First 0.1 seconds
    
    axes[0].plot(t_plot, audio[:4410])
    axes[0].set_title('Original Audio')
    axes[0].set_ylabel('Amplitude')
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(t_plot, reconstructed_orig[:4410], color='green')
    axes[1].set_title('Reconstructed (Original Codes)')
    axes[1].set_ylabel('Amplitude')
    axes[1].grid(True, alpha=0.3)
    
    axes[2].plot(t_plot, reconstructed[:4410], color='red')
    axes[2].set_title(f'Generated ({test_pattern} masking)')
    axes[2].set_xlabel('Time (s)')
    axes[2].set_ylabel('Amplitude')
    axes[2].grid(True, alpha=0.3)
    
    plt.suptitle('Audio Comparison (first 0.1s)', fontsize=14)
    plt.tight_layout()
    plt.show()

## Summary

In [None]:
print("\n" + "=" * 70)
print("TRANSFORMER MASKING TEST SUMMARY")
print("=" * 70)

print(f"\nTest configuration:")
print(f"  Audio duration: {duration:.2f} seconds")
print(f"  Number of tokens: {n_tokens}")
print(f"  Coarse codebooks: {n_coarse}")
print(f"  Encoder match rate: {match_rate:.1%}")

print(f"\nMasking patterns tested:")
for name, mask in mask_patterns.items():
    print(f"  {name}: {mask.mean():.1%} masked")

print(f"\nKey findings:")
if match_rate > 0.99:
    print("  ✅ Encoders produce identical results")
else:
    print("  ❌ Encoder mismatch detected")

# Check for embedding issues
has_embedding_issue = False
for name, result in onnx_results.items():
    if 'generated_codes' in result:
        if result['generated_codes'].max() > 1023:
            has_embedding_issue = True
            break

if has_embedding_issue:
    print("  ⚠️  WARNING: Possible embedding issue detected!")
    print("     Generated tokens exceed vocabulary size.")
    print("     The model might be using codes directly instead of embeddings.")
else:
    print("  ✅ Token values within expected range [0, 1023]")

print("\n" + "=" * 70)