# VampNet Masking Study 01: Sparse vs Dense Masking

This notebook investigates how mask density affects VampNet's output by varying the fraction of masked tokens from sparse (few tokens masked) to dense (most tokens masked).

In [None]:
import IPython.display as ipd
import audiotools as at
import matplotlib.pyplot as plt
import numpy as np
import torch
import vampnet
import os
from pathlib import Path

# Create output directory for saved audio
output_dir = Path("outputs/01_sparse_dense")
output_dir.mkdir(parents=True, exist_ok=True)

# Seed the global RNGs so the random masks drawn with torch.rand (and any
# torch/numpy-based sampling downstream) repeat under Restart Kernel -> Run All.
# NOTE(review): bit-exact repeatability on GPU may additionally depend on
# non-deterministic CUDA kernels -- confirm if exact reproduction matters.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is available; everything also runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load VampNet interface and move it to the selected device
interface = vampnet.interface.Interface.default()
interface.to(device)

# Load and preprocess input audio.
# NOTE(review): `_preprocess` is a private (underscore-prefixed) method of the
# interface; confirm it is the intended entry point for this vampnet version.
# `signal` is rebound to the preprocessed result, so everything below --
# including the "original" reference file -- uses the preprocessed audio,
# not the raw contents of the WAV file.
signal = at.AudioSignal("assets/stargazing.wav")
signal = interface._preprocess(signal)
print(f"Input signal shape: {signal.samples.shape}")
print(f"Duration: {signal.duration:.2f}s")

# Save original for reference (this is the preprocessed signal, see above)
signal.write(output_dir / "00_original.wav")

# Encode to tokens. Later cells unpack `codes.shape` as
# (batch, n_codebooks, seq_len) and compare generated tokens against `codes`.
codes = interface.encode(signal)
print(f"Encoded shape: {codes.shape}")

In [None]:
# Function to create masks with different densities
def create_density_mask(codes, density, upper_codebook_mask=3, generator=None):
    """Build a random binary mask over a token grid.

    Each position in the first ``upper_codebook_mask`` codebooks is masked
    independently with probability ``density``. Every codebook at index
    ``upper_codebook_mask`` or above is masked completely, regardless of
    ``density``, so the overall masked fraction is generally higher than
    ``density`` itself.

    Args:
        codes: Token tensor of shape (batch, n_codebooks, seq_len). Only its
            shape and device are used; its values are never read.
        density: Per-position masking probability in [0.0, 1.0] for the lower
            codebooks (0.0 = none of them masked, 1.0 = all of them masked).
        upper_codebook_mask: Index of the first codebook that is always fully
            masked. A value <= 0 disables this and applies ``density`` to
            every codebook.
        generator: Optional CPU ``torch.Generator`` used to draw the random
            positions, for reproducible masks. ``None`` uses torch's default
            generator.

    Returns:
        A ``torch.long`` tensor with the same shape as ``codes``, on
        ``codes.device``, holding 1 at masked positions and 0 elsewhere.

    Raises:
        ValueError: If ``density`` is outside [0.0, 1.0] (NaN included).
    """
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0.0 <= density <= 1.0:
        raise ValueError(f"density must be in [0.0, 1.0], got {density!r}")

    # Unpacking also enforces that `codes` is 3-dimensional.
    batch_size, n_codebooks, seq_len = codes.shape

    # Sample on the CPU and move afterwards: a seeded generator then yields
    # the same mask whether `codes` lives on the CPU or on a GPU.
    noise = torch.rand(batch_size, n_codebooks, seq_len, generator=generator)
    mask = (noise < density).long().to(codes.device)

    # Codebooks from `upper_codebook_mask` upwards are always regenerated in full.
    if upper_codebook_mask > 0:
        mask[:, upper_codebook_mask:, :] = 1

    return mask

# Preview the masks produced at a handful of densities, one panel per density
densities = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95]
fig, axes = plt.subplots(3, 2, figsize=(15, 10))

# axes.flat walks the 3x2 grid row by row, i.e. panel idx sits at (idx // 2, idx % 2)
for idx, (ax, density) in enumerate(zip(axes.flat, densities)):
    mask = create_density_mask(codes, density)

    # Only the first 4 codebooks and the first 100 timesteps of batch item 0 are drawn
    preview = mask[0, :4, :100].cpu().numpy()
    ax.imshow(preview, aspect='auto', cmap='RdBu')
    ax.set_title(f'Density = {density} ({density*100:.0f}% masked)')
    ax.set_ylabel('Codebook')

    # Panels 4 and 5 form the bottom row; only they get an x-axis label
    if idx >= 4:
        ax.set_xlabel('Time Steps')

plt.tight_layout()
plt.savefig(output_dir / "mask_densities_visualization.png")
plt.show()

In [None]:
# Generate outputs for different mask densities
densities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
results = []

print("Generating outputs for different mask densities...")
for density in densities:
    print(f"\nProcessing density = {density}")

    # Draw a fresh random mask for this density
    mask = create_density_mask(codes, density)

    # Generate and decode with autograd disabled: this is pure inference, so
    # no graph needs to be kept alive for the tensors stored in `results`.
    with torch.no_grad():
        output_tokens = interface.vamp(
            codes, mask,
            return_mask=False,
            temperature=1.0,
            typical_filtering=True,  # typical filtering is held on for every density in this study
            batch_size=1
        )

        # Decode to audio
        output_signal = interface.decode(output_tokens)

    # Save audio. round() instead of int(): int() truncates, so a density whose
    # float product lands just below an integer (e.g. 0.29 * 100 ->
    # 28.999999999999996) would otherwise be written under the wrong name.
    filename = f"density_{round(density * 100):02d}.wav"
    output_signal.write(output_dir / filename)

    # Store results for the analysis and playback cells below
    results.append({
        'density': density,
        'tokens': output_tokens,
        'signal': output_signal,
        'mask': mask,
        'filename': filename
    })

    print(f"  Saved: {filename}")

print("\nAll outputs generated and saved!")

In [None]:
# Analyze how many tokens were changed at each density
result_densities = [result['density'] for result in results]

# Percentage of token positions (over the whole tensor) whose generated value
# differs from the original encoding.
# NOTE(review): this assumes the generated tokens have the same shape as
# `codes` -- confirm against what interface.vamp returns.
changed_percentages = [
    (result['tokens'] != codes).float().mean().item() * 100
    for result in results
]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(result_densities, changed_percentages, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('Mask Density')
ax.set_ylabel('Percentage of Tokens Changed (%)')
ax.set_title('Token Changes vs Mask Density')
ax.grid(True, alpha=0.3)
fig.savefig(output_dir / "tokens_changed_analysis.png")
plt.show()

# Print summary. Two decimals are needed for the density: with one decimal,
# 0.95 is printed as "0.9" and becomes indistinguishable from the 0.9 row.
print("Summary of token changes:")
for density, changed in zip(result_densities, changed_percentages):
    print(f"Density {density:.2f}: {changed:.1f}% tokens changed")

In [None]:
# Create audio player for comparing outputs
print("Audio comparison (listen to how the output changes with mask density):")
print("\nOriginal:")
# Tensor.numpy() only works for CPU tensors that do not require grad, so
# detach and move to the CPU first; both calls are no-ops when already satisfied.
ipd.display(ipd.Audio(signal.samples.detach().cpu().squeeze().numpy(), rate=signal.sample_rate))

# Show a few key densities rather than all of them.
# The membership test compares floats exactly; it works because these values
# are the same literals that were used to build `results`.
key_densities = [0.1, 0.3, 0.5, 0.7, 0.9]
for result in results:
    if result['density'] in key_densities:
        print(f"\nDensity = {result['density']} ({result['density']*100:.0f}% masked):")
        # Decoded signals may live on the GPU when the notebook runs with CUDA,
        # hence the same detach/cpu conversion as for the original.
        output_samples = result['signal'].samples.detach().cpu().squeeze().numpy()
        ipd.display(ipd.Audio(output_samples, rate=result['signal'].sample_rate))

## Observations

1. **Low density (10-30%)**: Minimal changes, output closely resembles the original
2. **Medium density (40-60%)**: Noticeable variations while maintaining structure
3. **High density (70-90%)**: Significant changes, more creative variations
4. **Very high density (>90%)**: Output may diverge significantly from original

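The relationship between mask density and output variation appears roughly linear, with higher mask densities producing more dramatic changes to the audio. Note that the density only controls the first three codebooks: with the default `upper_codebook_mask=3`, all higher codebooks are fully masked at every density, so the overall fraction of masked tokens is larger than the nominal density.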