# VampNet Masking Study 03: Periodic Prompt Intervals

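This notebook investigates how different periodic prompt intervals affect VampNet's output. With `periodic_prompt=N`, the mask keeps one token every N timesteps unmasked as a prompt. The remaining positions are masked and regenerated by the model, so larger N means fewer prompts and more freedom for the model.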

In [None]:
import IPython.display as ipd
import audiotools as at
import matplotlib.pyplot as plt
import numpy as np
import torch
import vampnet
from pathlib import Path

# Every artifact produced by this study (audio clips, figures) is written here.
output_dir = Path("outputs") / "03_periodic_intervals"
output_dir.mkdir(exist_ok=True, parents=True)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Load the pretrained VampNet interface and place it on the chosen device.
interface = vampnet.interface.Interface.default()
interface.to(device)

# Read the source clip and pass it through the interface's private
# preprocessing step (NOTE(review): exact transforms not visible here).
# The preprocessed audio is saved as the reference "original".
signal = interface._preprocess(at.AudioSignal("assets/stargazing.wav"))
signal.write(output_dir / "00_original.wav")

# Tokenize the preprocessed audio with the codec.
codes = interface.encode(signal)
print(f"Encoded shape: {codes.shape}")

# Token rate of the codec: sample_rate / hop_length tokens per second.
sample_rate = interface.codec.sample_rate
hop_length = interface.codec.hop_length
tokens_per_second = sample_rate / hop_length
print(f"\nTiming information:")
print(f"Tokens per second: {tokens_per_second:.1f}")
print(f"Milliseconds per token: {1000 / tokens_per_second:.1f} ms")

In [None]:
# Periodic prompt intervals to sweep: from dense prompting (a kept token every
# 2 timesteps) to sparse prompting (a kept token every 100 timesteps).
intervals = [2, 3, 5, 7, 10, 13, 20, 30, 50, 70, 100]

# Express each interval as wall-clock time using the codec's token rate.
print("Periodic prompt intervals and their time equivalents:")
for n_tokens in intervals:
    span_ms = n_tokens * (1000 / tokens_per_second)
    print(f"Every {n_tokens:3d} tokens = every {span_ms:6.1f} ms")

In [None]:
# Visualize the mask produced for each periodic prompt interval.
# Only a small window (first few codebooks, first few hundred timesteps) is
# drawn so that individual prompt positions remain visible.
N_PROMPT_CODEBOOKS = 3  # passed as upper_codebook_mask; density is measured over these codebooks
N_SHOW_CODEBOOKS = 4    # one extra row so a codebook above the cutoff is visible too
N_SHOW_STEPS = 200

# squeeze=False keeps `axes` 2-D, so indexing works even for a single interval.
fig, axes = plt.subplots(len(intervals), 1, figsize=(15, 2 * len(intervals)), squeeze=False)

for idx, interval in enumerate(intervals):
    # Create mask using VampNet's build_mask with periodic_prompt
    mask = interface.build_mask(
        codes, signal,
        periodic_prompt=interval,
        upper_codebook_mask=N_PROMPT_CODEBOOKS,
        _dropout=0.0
    )

    ax = axes[idx, 0]
    # Fixed vmin/vmax so a given mask value maps to the same color in every row
    # (per-row autoscaling would make rows visually incomparable).
    ax.imshow(
        mask[0, :N_SHOW_CODEBOOKS, :N_SHOW_STEPS].float().cpu().numpy(),
        aspect='auto', cmap='RdBu', vmin=0, vmax=1, interpolation='nearest',
    )

    # Mean mask value over the prompted codebooks only. Assuming a binary mask,
    # this is the fraction of masked positions (NOTE(review): presumably
    # 1 = masked/regenerated in VampNet's convention; confirm).
    density = mask[:, :N_PROMPT_CODEBOOKS, :].float().mean().item()
    time_ms = interval * (1000 / tokens_per_second)

    ax.set_title(f'Interval = {interval} tokens ({time_ms:.0f} ms), Density = {density:.3f}')
    ax.set_ylabel('CB')

axes[-1, 0].set_xlabel('Time Steps')

fig.tight_layout()
fig.savefig(output_dir / "periodic_intervals_visualization.png")
plt.show()

In [None]:
# Generate one VampNet variation per periodic prompt interval.
# The RNG is re-seeded with the same value before every generation so that
# differences between intervals come from the mask rather than sampling noise,
# and so the cell reproduces under Restart & Run All
# (assumes vamp samples via torch's global RNG; confirm).
SEED = 0
N_PROMPT_CODEBOOKS = 3  # passed as upper_codebook_mask; density is measured over these codebooks

results = []

print("Generating outputs for different periodic intervals...")
for interval in intervals:
    print(f"\nProcessing interval = {interval}")

    # Create mask with periodic prompt
    mask = interface.build_mask(
        codes, signal,
        periodic_prompt=interval,
        upper_codebook_mask=N_PROMPT_CODEBOOKS,
        _dropout=0.0
    )

    # Mean mask value over the prompted codebooks (same measure as the plot above)
    density = mask[:, :N_PROMPT_CODEBOOKS, :].float().mean().item()
    time_ms = interval * (1000 / tokens_per_second)

    torch.manual_seed(SEED)
    # Sampling and decoding are both inference-only. Keeping decode inside
    # no_grad prevents the decoded samples from carrying autograd history,
    # which would otherwise break later .numpy() conversions.
    with torch.no_grad():
        output_tokens = interface.vamp(
            codes, mask,
            return_mask=False,
            temperature=1.0,
            typical_filtering=True,
            batch_size=1
        )
        output_signal = interface.decode(output_tokens)

    # Save audio
    filename = f"interval_{interval:03d}.wav"
    output_signal.write(output_dir / filename)

    results.append({
        'interval': interval,
        'time_ms': time_ms,
        'density': density,
        'tokens': output_tokens,
        'signal': output_signal,
        'mask': mask,
        'filename': filename
    })

    print(f"  Time interval: {time_ms:.1f} ms")
    print(f"  Mask density: {density:.3f}")
    print(f"  Saved: {filename}")

print("\nAll intervals generated!")

In [None]:
# Relate each interval (and its mask density) to the share of output tokens
# that differ from the input codes. Each percentage is computed once and is
# reused by both the plots and the printed summary.
intervals_list = [r['interval'] for r in results]
densities = [r['density'] for r in results]
# Percentage of all token positions (every codebook, every timestep) whose
# value differs from the original codes.
changes = [(r['tokens'] != codes).float().mean().item() * 100 for r in results]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Plot 1: Interval vs Token Changes (log-scaled x-axis)
ax1.plot(intervals_list, changes, 'o-', linewidth=2, markersize=8)
ax1.set_xlabel('Periodic Interval (tokens)')
ax1.set_ylabel('Percentage of Tokens Changed (%)')
ax1.set_title('Token Changes vs Periodic Interval')
ax1.grid(True, alpha=0.3)
ax1.set_xscale('log')

# Plot 2: Density vs Token Changes
ax2.plot(densities, changes, 'o-', linewidth=2, markersize=8, color='red')
# Label each point with its interval so the two panels can be cross-referenced.
for interval, density, change in zip(intervals_list, densities, changes):
    ax2.annotate(str(interval), (density, change),
                 textcoords='offset points', xytext=(4, 4), fontsize=8)
ax2.set_xlabel('Mask Density')
ax2.set_ylabel('Percentage of Tokens Changed (%)')
ax2.set_title('Token Changes vs Mask Density')
ax2.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(output_dir / "interval_analysis.png")
plt.show()

# Print summary
print("\nSummary:")
for r, change_pct in zip(results, changes):
    print(f"Interval {r['interval']:3d} ({r['time_ms']:5.0f} ms): "
          f"density={r['density']:.3f}, changed={change_pct:.1f}%")

In [None]:
# Listen to the original next to a few representative intervals:
# very frequent, moderate, and sparse prompting.
# .detach().cpu() is required before .numpy(): decoded signals may live on the
# GPU (or carry autograd history), and numpy() only accepts plain CPU tensors.
print("Audio comparison of different periodic intervals:")
print("\nOriginal:")
ipd.display(ipd.Audio(signal.samples.detach().cpu().squeeze().numpy(), rate=signal.sample_rate))

# Index results by interval; a missing interval raises a KeyError naming it
# (instead of a bare StopIteration from next()).
results_by_interval = {r['interval']: r for r in results}

key_intervals = [3, 10, 30, 70]
for interval in key_intervals:
    result = results_by_interval[interval]
    print(f"\nInterval = {interval} tokens ({result['time_ms']:.0f} ms):")
    # "1 of every N" avoids ungrammatical ordinals such as "3th".
    print(f"(Keeps 1 of every {interval} tokens, mask density = {result['density']:.3f})")
    audio = result['signal'].samples.detach().cpu().squeeze().numpy()
    ipd.display(ipd.Audio(audio, rate=result['signal'].sample_rate))

In [None]:
# Compare spectrograms of the original and several generated variations.
from scipy import signal as scipy_signal

N_FREQ_BINS = 1000  # lowest frequency bins to display (nperseg=2048 gives 1025 bins)
POWER_FLOOR = 1e-12  # keeps log10 finite on silent bins (floors them at -120 dB)


def _plot_spectrogram(ax, audio_signal, title):
    """Draw the dB power spectrogram of one channel of ``audio_signal`` on ``ax``.

    Uses ``samples[0, 0]`` (assumes samples is (batch, channels, time), as in
    audiotools; confirm). ``detach().cpu()`` lets this accept decoded signals
    that live on the GPU or carry autograd history.
    """
    samples = audio_signal.samples[0, 0].detach().cpu().numpy()
    freqs, times, power = scipy_signal.spectrogram(
        samples, fs=audio_signal.sample_rate, nperseg=2048
    )
    power_db = 10 * np.log10(np.maximum(power[:N_FREQ_BINS], POWER_FLOOR))
    ax.pcolormesh(times, freqs[:N_FREQ_BINS], power_db, shading='gouraud')
    ax.set_title(title)


fig, axes = plt.subplots(3, 2, figsize=(15, 12))

# Original occupies the top-left panel.
_plot_spectrogram(axes[0, 0], signal, 'Original')
axes[0, 0].set_ylabel('Frequency [Hz]')

# Remaining panels, filled left-to-right, top-to-bottom after the original.
results_by_interval = {r['interval']: r for r in results}
comparison_intervals = [3, 7, 13, 30, 70]
for panel, interval in enumerate(comparison_intervals, start=1):
    result = results_by_interval[interval]
    row, col = divmod(panel, 2)
    ax = axes[row, col]
    _plot_spectrogram(ax, result['signal'], f'Interval = {interval} ({result["time_ms"]:.0f} ms)')

    if row == axes.shape[0] - 1:
        ax.set_xlabel('Time [s]')
    if col == 0:
        ax.set_ylabel('Frequency [Hz]')

fig.tight_layout()
fig.savefig(output_dir / "spectrogram_comparison.png")
plt.show()

## Observations

1. **Very frequent prompts (2-5 tokens)**: Minimal variation, output stays very close to original
2. **Moderate intervals (7-20 tokens)**: Good balance between preservation and variation
3. **Large intervals (30-70 tokens)**: Significant variation while maintaining overall structure
4. **Very large intervals (100 tokens, the sparsest tested)**: Maximum variation; the output may lose coherence with the original

The periodic prompt interval directly controls the trade-off between:
- **Fidelity**: How closely the output matches the original
- **Creativity**: How much variation is introduced

Musical considerations:
- For preserving rhythm: Use intervals that align with the beat. For example, at ≈57 tokens per second, one beat at 120 BPM (0.5 s) spans about 29 tokens.
- For harmonic preservation: Smaller intervals (5-15) work well
- For creative remixing: Larger intervals (30-70) provide more freedom