# VampNet Masking Study 04: Codebook Masking Depth

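This notebook investigates how the `upper_codebook_mask` parameter affects generation. Setting `upper_codebook_mask = k` fully masks (regenerates) every codebook with index `k` or higher. Codebooks below `k` follow the rest of the mask, which here is a periodic prompt. A larger `k` therefore keeps more of the lower (coarse) codebooks anchored to the input, which controls how much acoustic detail is regenerated.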

In [None]:
import IPython.display as ipd
import audiotools as at
import matplotlib.pyplot as plt
import numpy as np
import torch
import vampnet
from pathlib import Path

# Create output directory
output_dir = Path("outputs/04_codebook_depth")
output_dir.mkdir(parents=True, exist_ok=True)

# Generation below samples with temperature=1.0, so outputs differ run to run unless
# the RNGs are seeded. Seeding here makes Restart & Run All reproducible
# (assumes vampnet draws from torch's / numpy's global RNGs — TODO confirm).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load the pretrained VampNet interface and place it on the selected device.
interface = vampnet.interface.Interface.default()
interface.to(device)

# Run the example clip through the interface's own preprocessing, and keep a
# copy of the (preprocessed) input next to the generated outputs.
source_path = "assets/stargazing.wav"
signal = interface._preprocess(at.AudioSignal(source_path))
signal.write(output_dir / "00_original.wav")

# Tokenize; elsewhere in this notebook codes is indexed as codes[batch, codebook, time].
codes = interface.encode(signal)
n_total = codes.shape[1]
n_coarse = interface.coarse.n_codebooks
for label, value in (
    ("Encoded shape", codes.shape),
    ("Total codebooks", n_total),
    ("Coarse codebooks", n_coarse),
    ("Fine codebooks", n_total - n_coarse),
):
    print(f"{label}: {value}")

In [None]:
# Understanding codebook hierarchy
# The coarse/fine split is derived from the loaded model rather than hard-coded
# ("0-3" / "4-13"), so the printout stays correct for checkpoints with other counts.
n_coarse = interface.coarse.n_codebooks
n_total = codes.shape[1]

print("\nCodebook hierarchy in VampNet:")
print(f"- Lower codebooks (0-{n_coarse - 1}): Coarse acoustic features")
print("  - Fundamental frequencies, pitch, rhythm")
print("  - Core musical structure")
print(f"\n- Higher codebooks ({n_coarse}-{n_total - 1}): Fine acoustic details")
print("  - Timbre, texture, harmonics")
print("  - Specific instrument characteristics")
# upper_codebook_mask is an index, not a count of "preserved" codebooks:
# everything at or above it is fully masked, everything below follows the prompt mask.
print("\nupper_codebook_mask = k fully masks codebooks k and above;")
print("codebooks below k follow the rest of the mask (here, a periodic prompt).")

In [None]:
# Visualize different codebook masking depths
# upper_codebook_mask = k fully masks codebooks k and above; codebooks below k keep the
# periodic prompt. k = 0 therefore masks everything (the prompt has no effect), and
# k = number of coarse codebooks fully masks every fine codebook.
codebook_depths = [0, 1, 2, 3, 4]
PERIODIC_PROMPT = 13  # same periodic prompt the generation cell below uses
N_VIS_STEPS = 200     # timesteps shown in each panel

n_total = codes.shape[1]
fig, axes = plt.subplots(
    len(codebook_depths), 1,
    figsize=(15, 2 * len(codebook_depths)),
    squeeze=False,  # keep axes 2-D so indexing works even with a single depth
)

for idx, depth in enumerate(codebook_depths):
    mask = interface.build_mask(
        codes, signal,
        periodic_prompt=PERIODIC_PROMPT,
        upper_codebook_mask=depth,
        _dropout=0.0
    )

    ax = axes[idx, 0]
    # Pin the colour range: the depth-0 mask is constant, and autoscaling would
    # otherwise render it in the same colour as the 0-valued cells of other panels.
    ax.imshow(
        mask[0, :, :N_VIS_STEPS].cpu().numpy(),
        aspect='auto', cmap='RdBu', vmin=0, vmax=1
    )

    # Label every codebook the model actually has (not a hard-coded 14).
    ax.set_yticks(range(n_total))
    ax.set_yticklabels([f'CB{i}' for i in range(n_total)])

    if depth == 0:
        title = f'Depth = {depth} (all codebooks fully masked)'
    else:
        title = (f'Depth = {depth} (CB0-{depth-1} keep periodic prompt, '
                 f'CB{depth}+ fully masked)')
    ax.set_title(title)

axes[-1, 0].set_xlabel('Time Steps')

plt.tight_layout()
plt.savefig(output_dir / "codebook_depth_visualization.png")
plt.show()

In [None]:
# Generate outputs for different codebook depths
# Each entry of `results` records one depth: generated tokens, decoded audio, the mask,
# and mask density (mean mask value over all codebooks / over the coarse codebooks).
results = []
n_coarse = interface.coarse.n_codebooks

print("Generating outputs for different codebook masking depths...")
for depth in codebook_depths:
    print(f"\nProcessing upper_codebook_mask = {depth}")

    mask = interface.build_mask(
        codes, signal,
        periodic_prompt=13,  # same value as the visualization cell
        upper_codebook_mask=depth,
        _dropout=0.0
    )

    total_density = mask.float().mean().item()
    coarse_density = mask[:, :n_coarse, :].float().mean().item()

    # Sample AND decode without autograd. Decoding outside no_grad can produce audio
    # tensors that require grad (if the codec's parameters do), which makes the later
    # .numpy() playback calls fail.
    with torch.no_grad():
        output_tokens = interface.vamp(
            codes, mask,
            return_mask=False,
            temperature=1.0,
            typical_filtering=True,
            batch_size=1
        )
        # Move decoded audio to the CPU so it can be written and converted to numpy.
        output_signal = interface.decode(output_tokens).cpu()

    filename = f"codebook_depth_{depth}.wav"
    output_signal.write(output_dir / filename)

    results.append({
        'depth': depth,
        'tokens': output_tokens,
        'signal': output_signal,
        'mask': mask,
        'total_density': total_density,
        'coarse_density': coarse_density,
        'filename': filename
    })

    print(f"  Total mask density: {total_density:.3f}")
    print(f"  Coarse mask density: {coarse_density:.3f}")
    print(f"  Saved: {filename}")

print("\nAll depths generated!")

In [None]:
# Analyze token changes by codebook
n_coarse = interface.coarse.n_codebooks
n_total = codes.shape[1]
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

# Plot 1: heatmap of % tokens changed, one row per depth, one column per codebook
changes_matrix = np.zeros((len(codebook_depths), n_total))
for i, result in enumerate(results):
    # Per-codebook fraction of timesteps whose token differs from the input.
    changed_per_cb = (result['tokens'][0] != codes[0]).float().mean(dim=-1)
    changes_matrix[i] = changed_per_cb.cpu().numpy() * 100

# Fixed 0-100 scale so colours are comparable across runs / notebooks.
im = ax1.imshow(changes_matrix.T, aspect='auto', cmap='viridis', vmin=0, vmax=100)
ax1.set_xlabel('Upper Codebook Mask Value')
ax1.set_ylabel('Codebook Index')
ax1.set_title('Percentage of Tokens Changed by Codebook')
ax1.set_xticks(range(len(codebook_depths)))
ax1.set_xticklabels(codebook_depths)
ax1.set_yticks(range(n_total))
plt.colorbar(im, ax=ax1, label='% Changed')

# Separate coarse and fine codebooks (boundary derived from the model, not hard-coded)
ax1.axhline(y=n_coarse - 0.5, color='red', linestyle='--', linewidth=2)
ax1.text(-0.5, (n_coarse - 1) / 2, 'Coarse', rotation=90, va='center', color='red')
ax1.text(-0.5, (n_coarse + n_total - 1) / 2, 'Fine', rotation=90, va='center', color='red')

# Plot 2: Total changes vs depth
total_changes = [(r['tokens'] != codes).float().mean().item() * 100 for r in results]
ax2.plot(codebook_depths, total_changes, 'o-', linewidth=2, markersize=10)
ax2.set_xlabel('Upper Codebook Mask Value')
ax2.set_ylabel('Total Percentage Changed (%)')
ax2.set_title('Total Token Changes vs Codebook Masking Depth')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(codebook_depths)

plt.tight_layout()
plt.savefig(output_dir / "codebook_analysis.png")
plt.show()

In [None]:
# Audio comparison
def _as_playable(sig):
    """Return the samples of an AudioSignal as a numpy array for ipd.Audio.

    Detaches and moves to CPU first, since decoded signals may live on the GPU
    or carry autograd history, either of which makes a bare .numpy() fail.
    """
    return sig.samples.detach().cpu().squeeze().numpy()


n_coarse = interface.coarse.n_codebooks
print("Audio comparison of different codebook masking depths:")
print("\nOriginal:")
ipd.display(ipd.Audio(_as_playable(signal), rate=signal.sample_rate))

for result in results:
    depth = result['depth']
    # Codebooks below `depth` keep the periodic prompt; codebooks >= depth are fully masked.
    if depth == 0:
        desc = "All codebooks masked (maximum variation)"
    elif depth == n_coarse:
        desc = "Coarse codebooks keep periodic prompt; all fine codebooks masked"
    else:
        desc = f"Codebooks 0-{depth-1} keep periodic prompt; {depth}+ masked"

    print(f"\nDepth = {depth} ({desc}):")
    ipd.display(ipd.Audio(_as_playable(result['signal']), rate=result['signal'].sample_rate))

In [None]:
# Special comparison: Fixed periodic prompt with varying depths
# With a very sparse prompt (one kept step every 50), codebooks below `depth` are only
# loosely anchored to the input. At depth 0 every codebook is fully masked, so the
# prompt has no effect there.
print("\nSpecial test: Very sparse periodic prompt (every 50 tokens) with different depths:")

sparse_results = []
for depth in [0, 2, 4]:
    mask = interface.build_mask(
        codes, signal,
        periodic_prompt=50,  # Very sparse
        upper_codebook_mask=depth,
        _dropout=0.0
    )

    # Sample and decode without autograd, then move audio to CPU for writing/playback.
    with torch.no_grad():
        output_tokens = interface.vamp(
            codes, mask,
            return_mask=False,
            temperature=1.0,
            typical_filtering=True,
            batch_size=1
        )
        output_signal = interface.decode(output_tokens).cpu()

    filename = f"sparse_depth_{depth}.wav"
    output_signal.write(output_dir / filename)

    # Previously this list was created but never filled; keep the results for later analysis.
    sparse_results.append({
        'depth': depth,
        'tokens': output_tokens,
        'signal': output_signal,
        'mask': mask,
        'filename': filename,
    })

    print(f"\nSparse prompt, depth = {depth}:")
    ipd.display(ipd.Audio(output_signal.samples.detach().squeeze().numpy(), rate=output_signal.sample_rate))

## Observations

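1. **Depth = 0**: All codebooks are fully masked, so the periodic prompt has no effect. This gives maximum variation and may lose musical coherence.
2. **Depth = 1-2**: The lowest codebooks keep the periodic prompt. This anchors fundamental pitch and rhythm while harmony and timbre vary.
3. **Depth = 3**: Most coarse codebooks keep the prompt. This preserves more structure while finer details vary.
4. **Depth = 4**: All coarse codebooks keep the periodic prompt, and every fine codebook is fully masked. This is the most conservative setting tested and gives subtler timbral variations.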

Musical applications:
- **Remixing/Transformation**: Use depth 0-1 for dramatic changes
- **Variation/Improvisation**: Use depth 2-3 for musical variations
- **Sound Design/Texture**: Use depth 4 (the highest value tested here) for timbral modifications

The codebook hierarchy allows precise control over what aspects of the music are preserved vs regenerated.