# VampNet Masking Study 02: Random vs Regular Masking Patterns

This notebook compares different masking patterns:
- Random masking (uniformly distributed)
- Regular/periodic masking (keep every Nth token, mask everything in between)
- Block masking (contiguous regions)
- Strided masking (mask a fixed-length run of tokens at the start of every stride)

In [None]:
import IPython.display as ipd
import audiotools as at
import matplotlib.pyplot as plt
import numpy as np
import torch
import vampnet
from pathlib import Path

# Create output directory
output_dir = Path("outputs/02_random_regular")
output_dir.mkdir(parents=True, exist_ok=True)

# Seed the RNGs up front: the mask builders below draw from torch.rand /
# torch.randint, so without a fixed seed every Restart & Run All would produce
# different masks, figures and audio files.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Bring up the default VampNet interface and move it to the selected device.
interface = vampnet.interface.Interface.default()
interface.to(device)

# Read the source clip and pass it through the interface's own preprocessing
# step, then store that version on disk as the listening reference.
signal = interface._preprocess(at.AudioSignal("assets/stargazing.wav"))
signal.write(output_dir / "00_original.wav")

# Tokenize the preprocessed audio; the mask builders below unpack the shape of
# `codes` as (batch, codebooks, time).
codes = interface.encode(signal)
print(f"Encoded shape: {codes.shape}")

In [None]:
def create_random_mask(codes, density=0.5, upper_codebook_mask=3):
    """Build a mask in which every token is masked independently at random.

    Args:
        codes: Token tensor; only its 3-D shape (unpacked as batch, codebooks,
            time) and its device are used.
        density: Threshold compared against uniform samples in [0, 1); a
            position is masked when its sample falls below this value.
        upper_codebook_mask: Index of the first codebook that is masked in
            full. Values <= 0 disable this override.

    Returns:
        Long tensor with the shape of ``codes`` on ``codes.device``
        (1 = masked, 0 = kept).
    """
    n_batch, n_books, n_steps = codes.shape

    # Samples are drawn on the default device and only then moved over.
    uniform = torch.rand(n_batch, n_books, n_steps)
    mask = (uniform < density).to(torch.long).to(codes.device)

    if upper_codebook_mask > 0:
        # Everything from this codebook upwards is always regenerated.
        mask[:, upper_codebook_mask:, :] = 1

    return mask

def create_periodic_mask(codes, period=7, width=1, upper_codebook_mask=3):
    """Mask everything except short windows that recur every ``period`` steps.

    Args:
        codes: Token tensor; only its 3-D shape (unpacked as batch, codebooks,
            time) and its device are used.
        period: Distance in time steps between the starts of kept windows.
        width: Number of consecutive time steps kept in each window.
        upper_codebook_mask: Index of the first codebook that is masked in
            full. Values <= 0 disable this override.

    Returns:
        Long tensor with the shape of ``codes`` on ``codes.device``
        (1 = masked, 0 = kept).
    """
    batch_size, n_codebooks, seq_len = codes.shape

    # The phase of the pattern is random: the first kept window starts at a
    # uniformly drawn offset in [0, period).
    offset = torch.randint(0, period, (1,)).item()

    # Build the pattern once along time (fully masked, then punch out the kept
    # windows) and tile it over batch and codebooks.
    time_pattern = torch.ones(seq_len).long()
    for window_start in range(offset, seq_len, period):
        window_end = min(window_start + width, seq_len)
        time_pattern[window_start:window_end] = 0

    mask = time_pattern.repeat(batch_size, n_codebooks, 1).to(codes.device)

    if upper_codebook_mask > 0:
        mask[:, upper_codebook_mask:, :] = 1

    return mask

def create_block_mask(codes, block_size=50, density=0.5, upper_codebook_mask=3):
    """Mask randomly placed contiguous blocks of time steps.

    Block start positions are drawn independently, so blocks may overlap and
    the realized masked fraction can be lower than ``density``.

    Args:
        codes: Token tensor; only its 3-D shape (unpacked as batch, codebooks,
            time) and its device are used.
        block_size: Length of each masked block in time steps. Blocks longer
            than the sequence are clipped to the sequence length.
        density: Target masked fraction; the number of blocks drawn is
            ``int((seq_len / block_size) * density)``.
        upper_codebook_mask: Index of the first codebook that is masked in
            full. Values <= 0 disable this override.

    Returns:
        Long tensor with the shape of ``codes`` on ``codes.device``
        (1 = masked, 0 = kept).
    """
    batch_size, n_codebooks, seq_len = codes.shape
    mask = torch.zeros(
        batch_size, n_codebooks, seq_len, dtype=torch.long, device=codes.device
    )

    # Calculate number of blocks to mask
    n_blocks = int((seq_len / block_size) * density)

    # A block can never be longer than the sequence itself.
    span = min(block_size, seq_len)
    # Largest start index that still fits a whole block (inclusive).
    last_start = seq_len - span

    for _ in range(n_blocks):
        # torch.randint's upper bound is exclusive, hence the +1; without it
        # the final time step could never be masked and a block spanning the
        # whole sequence would raise.
        start = torch.randint(0, last_start + 1, (1,)).item()
        mask[:, :, start:start + span] = 1

    if upper_codebook_mask > 0:
        mask[:, upper_codebook_mask:, :] = 1

    return mask

def create_strided_mask(codes, stride=10, mask_length=5, upper_codebook_mask=3):
    """Mask a fixed-length run of time steps at the start of every stride.

    Args:
        codes: Token tensor; only its 3-D shape (unpacked as batch, codebooks,
            time) and its device are used.
        stride: Distance in time steps between the starts of masked runs; the
            first run starts at time step 0.
        mask_length: Number of consecutive time steps masked in each run.
        upper_codebook_mask: Index of the first codebook that is masked in
            full. Values <= 0 disable this override.

    Returns:
        Long tensor with the shape of ``codes`` on ``codes.device``
        (1 = masked, 0 = kept).
    """
    batch_size, n_codebooks, seq_len = codes.shape

    # Lay the pattern out once along time, then tile it over batch/codebooks.
    time_pattern = torch.zeros(seq_len).long()
    for run_start in range(0, seq_len, stride):
        run_end = min(run_start + mask_length, seq_len)
        time_pattern[run_start:run_end] = 1

    mask = time_pattern.repeat(batch_size, n_codebooks, 1).to(codes.device)

    if upper_codebook_mask > 0:
        mask[:, upper_codebook_mask:, :] = 1

    return mask

In [None]:
# Visualize different masking patterns
fig, axes = plt.subplots(4, 1, figsize=(15, 12))

# One example mask per pattern family, created in the same order as the panels.
mask_random = create_random_mask(codes, density=0.5)
mask_periodic = create_periodic_mask(codes, period=7)
mask_block = create_block_mask(codes, block_size=30, density=0.5)
mask_strided = create_strided_mask(codes, stride=15, mask_length=10)

panels = [
    (mask_random, 'Random Masking (50% density)'),
    (mask_periodic, 'Periodic Masking (keep every 7th token)'),
    (mask_block, 'Block Masking (30-token blocks, 50% density)'),
    (mask_strided, 'Strided Masking (10 masked, 5 kept)'),
]

for ax, (panel_mask, title) in zip(axes, panels):
    # Show the first batch item, first 4 codebooks, first 200 time steps.
    # vmin/vmax pin the colour scale so 0 and 1 get the same colours in every
    # panel, even when a slice happens to contain only one of the two values;
    # nearest-neighbour interpolation keeps the binary cells crisp.
    ax.imshow(
        panel_mask[0, :4, :200].cpu().numpy(),
        aspect='auto',
        cmap='RdBu',
        vmin=0,
        vmax=1,
        interpolation='nearest',
    )
    ax.set_title(title)
    ax.set_ylabel('Codebook')

axes[-1].set_xlabel('Time Steps')

plt.tight_layout()
plt.savefig(output_dir / "masking_patterns_visualization.png")
plt.show()

In [None]:
# Generate outputs for different masking patterns
# Each entry is (name, mask factory); the factory takes the code tensor and
# returns a mask of the same shape.
patterns = [
    ("random_30", lambda c: create_random_mask(c, density=0.3)),
    ("random_50", lambda c: create_random_mask(c, density=0.5)),
    ("random_70", lambda c: create_random_mask(c, density=0.7)),
    ("periodic_5", lambda c: create_periodic_mask(c, period=5)),
    ("periodic_10", lambda c: create_periodic_mask(c, period=10)),
    ("periodic_20", lambda c: create_periodic_mask(c, period=20)),
    ("block_20", lambda c: create_block_mask(c, block_size=20, density=0.5)),
    ("block_50", lambda c: create_block_mask(c, block_size=50, density=0.5)),
    ("block_100", lambda c: create_block_mask(c, block_size=100, density=0.5)),
    ("strided_5_5", lambda c: create_strided_mask(c, stride=10, mask_length=5)),
    ("strided_10_10", lambda c: create_strided_mask(c, stride=20, mask_length=10)),
]

results = []

print("Generating outputs for different masking patterns...")
for name, mask_fn in patterns:
    print(f"\nProcessing pattern: {name}")

    # Create mask
    mask = mask_fn(codes)

    # Generate and decode under no_grad: this is inference only, so neither
    # step needs autograd bookkeeping.
    with torch.no_grad():
        output_tokens = interface.vamp(
            codes, mask,
            return_mask=False,
            temperature=1.0,
            typical_filtering=True,
            batch_size=1
        )

        # Decode to audio
        output_signal = interface.decode(output_tokens)

    # Save audio
    filename = f"pattern_{name}.wav"
    output_signal.write(output_dir / filename)

    # Calculate mask density. This is the mean over the whole mask, so it also
    # counts the upper codebooks that the mask builders always mask in full.
    density = mask.float().mean().item()

    results.append({
        'name': name,
        'tokens': output_tokens,
        'signal': output_signal,
        'mask': mask,
        'density': density,
        'filename': filename
    })

    print(f"  Mask density: {density:.3f}")
    print(f"  Saved: {filename}")

print("\nAll patterns generated!")

In [None]:
# Analyze token changes by pattern type
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Group results by pattern type
pattern_types = {
    'Random': [r for r in results if 'random' in r['name']],
    'Periodic': [r for r in results if 'periodic' in r['name']],
    'Block': [r for r in results if 'block' in r['name']],
    'Strided': [r for r in results if 'strided' in r['name']]
}

colors = ['blue', 'red', 'green', 'orange']

# The per-result change percentage feeds both plots, so compute it once per
# pattern type: scatter it against density (plot 1) and keep the mean for the
# bar chart (plot 2).
avg_changes = {}
for (pattern_name, pattern_results), color in zip(pattern_types.items(), colors):
    densities = [r['density'] for r in pattern_results]
    # Share of token positions whose value differs from the original codes.
    changes = [(r['tokens'] != codes).float().mean().item() * 100 for r in pattern_results]
    ax1.scatter(densities, changes, label=pattern_name, s=100, alpha=0.7, color=color)
    avg_changes[pattern_name] = np.mean(changes)

# Plot 1: Token changes vs density for each pattern type
ax1.set_xlabel('Mask Density')
ax1.set_ylabel('Percentage of Tokens Changed (%)')
ax1.set_title('Token Changes by Pattern Type')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Bar chart of average changes by pattern type
ax2.bar(list(avg_changes.keys()), list(avg_changes.values()), color=colors, alpha=0.7)
ax2.set_ylabel('Average Token Changes (%)')
ax2.set_title('Average Token Changes by Pattern Type')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(output_dir / "pattern_analysis.png")
plt.show()

In [None]:
# Audio comparison
print("Audio comparison of different masking patterns:")
print("\nOriginal:")
# .detach().cpu() before .numpy(): Tensor.numpy() raises for tensors that live
# on a CUDA device or that require grad, and is a no-op cost otherwise.
ipd.display(ipd.Audio(signal.samples.squeeze().detach().cpu().numpy(), rate=signal.sample_rate))

# Show one example from each pattern type
examples = [
    ("Random (50%)", "random_50"),
    ("Periodic (every 10)", "periodic_10"),
    ("Block (50 tokens)", "block_50"),
    ("Strided (10/10)", "strided_10_10")
]

for display_name, pattern_name in examples:
    # Look up the generated result by the name it was stored under.
    result = next(r for r in results if r['name'] == pattern_name)
    print(f"\n{display_name}:")
    ipd.display(
        ipd.Audio(
            result['signal'].samples.squeeze().detach().cpu().numpy(),
            rate=result['signal'].sample_rate,
        )
    )

## Observations

1. **Random masking**: Produces smooth, evenly distributed changes throughout the audio
2. **Periodic masking**: Maintains temporal structure while allowing variation, good for preserving rhythm
3. **Block masking**: Creates sections of original audio interspersed with generated content
4. **Strided masking**: Regular alternation between original and generated content

Different patterns are suitable for different creative applications:
- Use random for general variation
- Use periodic for rhythm-preserving variations
- Use block for creating mashup-like effects
- Use strided for regular, predictable variations