# Audio Preprocessing Pipeline

Complete walkthrough of audio preprocessing for ML, with profiling at every step.

**Topics covered:**
- Audio loading and format handling
- Resampling strategies
- Spectrogram extraction
- Mel spectrogram computation
- Data augmentation
- Batch processing optimization

In [None]:
# Install dependencies
# !pip install torch torchaudio librosa soundfile matplotlib numpy

In [None]:
import time
import numpy as np
import torch
import torchaudio
from pathlib import Path
from typing import Tuple, Optional
import matplotlib.pyplot as plt

# Report library versions and whether CUDA is usable, one item per line
print(
    f"PyTorch: {torch.__version__}",
    f"torchaudio: {torchaudio.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## 1. Audio Loading Comparison

Compare different audio loading backends.

In [None]:
def generate_test_audio(duration_sec: float, sample_rate: int = 16000) -> torch.Tensor:
    """Generate synthetic speech-like audio for testing.

    The signal is a fundamental whose ``f0`` term varies between 100 and
    200 (at 0.5 Hz), plus its 2nd and 3rd harmonics and a small amount of
    Gaussian noise.

    Args:
        duration_sec: Length of the clip in seconds.
        sample_rate: Number of samples per second.

    Returns:
        Tensor of shape ``(1, int(duration_sec * sample_rate))``; the
        leading dimension is the (single) channel.
    """
    num_samples = int(duration_sec * sample_rate)

    # Sample n lies at n / sample_rate seconds. torch.linspace(0, duration, N)
    # includes the end point, which spaces samples duration / (N - 1) apart and
    # therefore renders every frequency slightly off.
    t = torch.arange(num_samples, dtype=torch.get_default_dtype()) / sample_rate

    # Multi-frequency signal (speech-like)
    f0 = 150 + 50 * torch.sin(2 * np.pi * 0.5 * t)  # Varying pitch
    audio = (
        0.5 * torch.sin(2 * np.pi * f0 * t) +
        0.3 * torch.sin(2 * np.pi * 2 * f0 * t) +
        0.1 * torch.sin(2 * np.pi * 3 * f0 * t) +
        0.05 * torch.randn_like(t)
    )
    return audio.unsqueeze(0)  # Add channel dimension

# Build the 5-second, 16 kHz clip that later cells reuse
test_audio = generate_test_audio(duration_sec=5.0, sample_rate=16000)
print(f"Test audio shape: {test_audio.shape}")
print(f"Duration: {test_audio.size(1) / 16000:.2f} seconds")

In [None]:
# Save test audio for loading experiments
import tempfile
import os

# mkdtemp() creates a fresh directory that is not removed automatically;
# temp_dir is kept so the cleanup cell can delete it with shutil.rmtree.
temp_dir = tempfile.mkdtemp()
wav_path = os.path.join(temp_dir, "test.wav")
# Write the synthetic clip with a 16000 Hz sample rate so every loader reads the same file
torchaudio.save(wav_path, test_audio, 16000)
print(f"Saved to: {wav_path}")

In [None]:
def benchmark_loading(path: str, n_runs: int = 10):
    """Benchmark different audio loading methods.

    Every available backend loads ``path`` once untimed (warm-up) and then
    ``n_runs`` times under the clock.

    Args:
        path: Audio file to load.
        n_runs: Number of timed loads per backend; must be positive.

    Returns:
        Dict with keys ``'torchaudio'``, ``'librosa'`` and ``'soundfile'``.
        Each value is the mean load time in milliseconds, or ``None`` when
        the optional backend raised ``ImportError``.

    Raises:
        ValueError: If ``n_runs`` is not positive.
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs}")

    def mean_load_ms(load) -> float:
        # Untimed first call: keeps one-off costs (cold file cache, lazy
        # backend initialisation) out of the measurement, so the backend that
        # happens to run first is not penalised.
        load()
        start = time.perf_counter()
        for _ in range(n_runs):
            load()
        return (time.perf_counter() - start) / n_runs * 1000

    results = {}

    # torchaudio
    results['torchaudio'] = mean_load_ms(lambda: torchaudio.load(path))

    # librosa (if available). The load stays inside the try so an ImportError
    # raised while loading is still reported as "unavailable", as before.
    try:
        import librosa
        results['librosa'] = mean_load_ms(lambda: librosa.load(path, sr=None))
    except ImportError:
        results['librosa'] = None

    # soundfile (if available)
    try:
        import soundfile as sf
        results['soundfile'] = mean_load_ms(lambda: sf.read(path))
    except ImportError:
        results['soundfile'] = None

    return results

# Run the loader benchmark on the saved clip; entries that are None are skipped
loading_results = benchmark_loading(wav_path)
print("\nAudio Loading Benchmark (ms):")
for method, time_ms in loading_results.items():
    if time_ms is None:
        continue
    print(f"  {method}: {time_ms:.3f} ms")

## 2. Resampling

Resampling is expensive - always pre-resample your dataset!

In [None]:
def benchmark_resampling(audio: torch.Tensor, orig_sr: int, target_sr: int, n_runs: int = 10):
    """Benchmark resampling methods.

    Three paths are timed: ``torchaudio.functional.resample``, a reusable
    ``torchaudio.transforms.Resample`` module on CPU and, when CUDA is
    available, the same module on GPU. Every path gets one untimed warm-up
    call so that all of them are measured under the same conditions.

    Args:
        audio: Waveform tensor to resample.
        orig_sr: Sample rate of ``audio``.
        target_sr: Sample rate to convert to.
        n_runs: Number of timed repetitions per path; must be positive.

    Returns:
        Dict mapping ``'torchaudio_functional'``, ``'torchaudio_transform'``
        and (only with CUDA) ``'torchaudio_gpu'`` to mean time per call in
        milliseconds.

    Raises:
        ValueError: If ``n_runs`` is not positive.
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs}")

    def mean_ms(run, on_gpu: bool = False) -> float:
        # Warmup
        run()
        if on_gpu:
            # CUDA kernels launch asynchronously; wait for the warm-up to
            # finish before starting the clock ...
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_runs):
            run()
        if on_gpu:
            # ... and for the timed work to finish before stopping it.
            torch.cuda.synchronize()
        return (time.perf_counter() - start) / n_runs * 1000

    results = {}

    # torchaudio functional
    results['torchaudio_functional'] = mean_ms(
        lambda: torchaudio.functional.resample(audio, orig_sr, target_sr)
    )

    # torchaudio transform (reusable kernel)
    resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
    results['torchaudio_transform'] = mean_ms(lambda: resampler(audio))

    # GPU resampling (if available)
    if torch.cuda.is_available():
        audio_gpu = audio.cuda()
        resampler_gpu = torchaudio.transforms.Resample(orig_sr, target_sr).cuda()
        results['torchaudio_gpu'] = mean_ms(lambda: resampler_gpu(audio_gpu), on_gpu=True)

    return results

# Generate 44.1kHz audio
audio_44k = generate_test_audio(duration_sec=10.0, sample_rate=44100)
print(f"Original: {audio_44k.size(1)} samples @ 44100 Hz")

# Time every resampling path on the 44100 -> 16000 conversion
resample_results = benchmark_resampling(audio_44k, orig_sr=44100, target_sr=16000)
print("\nResampling Benchmark (44.1kHz → 16kHz, 10s audio):")
for method, time_ms in resample_results.items():
    print(f"  {method}: {time_ms:.3f} ms")

## 3. Spectrogram Extraction

STFT and mel spectrogram computation.

In [None]:
class AudioFeatureExtractor:
    """Whisper-style log-mel feature extractor built on torchaudio."""

    def __init__(
        self,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        device: str = "cpu"
    ):
        """Remember the settings and build the mel transform on ``device``."""
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.device = device

        # Slaney normalisation and mel scale are requested explicitly
        mel_settings = dict(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            norm="slaney",
            mel_scale="slaney",
        )
        transform = torchaudio.transforms.MelSpectrogram(**mel_settings)
        self.mel_transform = transform.to(device)

    def extract(self, audio: torch.Tensor) -> torch.Tensor:
        """Return the natural-log mel spectrogram of ``audio``.

        The input is moved to the extractor's device first. Mel energies are
        floored at 1e-10 before the log so silent bins do not produce -inf.
        """
        on_device = audio.to(self.device)
        floored = self.mel_transform(on_device).clamp(min=1e-10)
        return torch.log(floored)

# Extractor with default settings
extractor = AudioFeatureExtractor()

# Run it on the 5-second test clip and report shapes and frame rate
features = extractor.extract(test_audio)
print(f"Input audio: {test_audio.shape}")
print(f"Output features: {features.shape}")
print(f"Feature rate: {features.size(-1) / 5.0:.1f} frames/sec")

In [None]:
def benchmark_feature_extraction(audio: torch.Tensor, n_runs: int = 50):
    """Time log-mel extraction on CPU and, if CUDA is available, on GPU.

    Returns a dict with the mean milliseconds per ``extract`` call under
    ``'cpu'`` and (only with CUDA) ``'gpu'``.
    """

    def timed_ms(feature_extractor, clip, sync: bool) -> float:
        # One untimed call first so set-up cost is not measured
        feature_extractor.extract(clip)
        if sync:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(n_runs):
            feature_extractor.extract(clip)
        if sync:
            # Wait for queued GPU work before reading the clock
            torch.cuda.synchronize()
        return (time.perf_counter() - t0) / n_runs * 1000

    results = {'cpu': timed_ms(AudioFeatureExtractor(device="cpu"), audio, False)}

    if torch.cuda.is_available():
        gpu_extractor = AudioFeatureExtractor(device="cuda")
        results['gpu'] = timed_ms(gpu_extractor, audio.cuda(), True)

    return results

# Benchmark feature extraction on a 10-second, 16 kHz clip
audio_10s = generate_test_audio(duration_sec=10.0, sample_rate=16000)
feat_results = benchmark_feature_extraction(audio_10s)
print("\nMel Spectrogram Extraction (10s audio):")
for device, time_ms in feat_results.items():
    print(f"  {device}: {time_ms:.3f} ms")

## 4. Visualization

In [None]:
def plot_audio_and_spectrogram(audio: torch.Tensor, sample_rate: int = 16000):
    """Draw three stacked panels: waveform, spectrogram (dB), mel spectrogram (dB).

    Only the first channel of ``audio`` is plotted. ``audio`` is converted
    with ``.numpy()``, so it has to be a CPU tensor.
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 8))
    wave_ax, spec_ax, mel_ax = axes

    def show_db(ax, power, ylabel, title):
        # Convert power to decibels (floored at 1e-10) and show it as an image
        db = 10 * torch.log10(power.clamp(min=1e-10))
        ax.imshow(db[0].numpy(), aspect='auto', origin='lower', cmap='viridis')
        ax.set_xlabel("Time frames")
        ax.set_ylabel(ylabel)
        ax.set_title(title)

    # Panel 1: waveform against time in seconds
    seconds = torch.arange(audio.shape[1]) / sample_rate
    wave_ax.plot(seconds.numpy(), audio[0].numpy())
    wave_ax.set_xlabel("Time (s)")
    wave_ax.set_ylabel("Amplitude")
    wave_ax.set_title("Waveform")

    # Panel 2: linear-frequency spectrogram
    to_spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160)
    show_db(spec_ax, to_spec(audio), "Frequency bins", "Spectrogram (dB)")

    # Panel 3: mel spectrogram
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=512, hop_length=160, n_mels=80
    )
    show_db(mel_ax, to_mel(audio), "Mel bins", "Mel Spectrogram (dB)")

    plt.tight_layout()
    plt.show()

plot_audio_and_spectrogram(test_audio, sample_rate=16000)

## 5. Data Augmentation

In [None]:
class AudioAugmentation:
    """Common audio augmentations for training.

    All methods are static. ``random_crop`` indexes ``audio`` as
    ``(channels, samples)``; the sox-based methods pass ``audio`` straight to
    ``torchaudio.sox_effects.apply_effects_tensor``.
    NOTE(review): the sox-based methods need a torchaudio build that still
    ships ``sox_effects`` - confirm for the installed version.
    """

    @staticmethod
    def add_noise(audio: torch.Tensor, snr_db: float = 20.0) -> torch.Tensor:
        """Add Gaussian noise at specified SNR.

        Args:
            audio: Input waveform.
            snr_db: Target signal-to-noise ratio in decibels.

        Returns:
            ``audio`` plus white noise, same shape as the input.
        """
        # SNR(dB) = 10 * log10(P_signal / P_noise)  =>  P_noise = P_signal / 10^(SNR/10)
        signal_power = audio.pow(2).mean()
        noise_power = signal_power / (10 ** (snr_db / 10))
        noise = torch.randn_like(audio) * noise_power.sqrt()
        return audio + noise

    @staticmethod
    def time_stretch(
        audio: torch.Tensor, rate: float = 1.0, sample_rate: int = 16000
    ) -> torch.Tensor:
        """Time stretch without changing pitch.

        Args:
            audio: Input waveform.
            rate: Tempo factor; values above 1 shorten the clip, values below
                1 lengthen it.
            sample_rate: Sample rate of ``audio``.

        Returns:
            The stretched waveform.
        """
        # sox "tempo" changes duration while keeping pitch. The "speed" effect
        # used previously resamples the signal, which shifts the pitch as well
        # and changes the output sample rate, contradicting this method's contract.
        effects = [["tempo", str(rate)]]
        augmented, _ = torchaudio.sox_effects.apply_effects_tensor(
            audio, sample_rate, effects
        )
        return augmented

    @staticmethod
    def pitch_shift(
        audio: torch.Tensor, semitones: float = 0.0, sample_rate: int = 16000
    ) -> torch.Tensor:
        """Shift pitch by semitones.

        Args:
            audio: Input waveform.
            semitones: Shift amount; positive raises the pitch.
            sample_rate: Sample rate of ``audio``.

        Returns:
            The pitch-shifted waveform.
        """
        effects = [
            # sox expects the shift in cents (1 semitone = 100 cents)
            ["pitch", str(semitones * 100)],
            # Pin the output back to the input rate; a no-op if it is unchanged
            ["rate", str(sample_rate)],
        ]
        augmented, _ = torchaudio.sox_effects.apply_effects_tensor(
            audio, sample_rate, effects
        )
        return augmented

    @staticmethod
    def random_crop(audio: torch.Tensor, length: int) -> torch.Tensor:
        """Random crop to fixed length.

        Args:
            audio: Waveform shaped ``(channels, samples)``.
            length: Number of samples to return.

        Returns:
            A tensor with exactly ``length`` samples: ``audio`` zero-padded on
            the right when it is not longer than ``length``, otherwise a
            uniformly chosen window of it.
        """
        num_samples = audio.shape[1]
        if num_samples <= length:
            # Pad if too short
            pad_amount = length - num_samples
            return torch.nn.functional.pad(audio, (0, pad_amount))

        # torch.randint's upper bound is exclusive, so add 1: otherwise the
        # last valid window (the one ending on the final sample) is never drawn.
        start = torch.randint(0, num_samples - length + 1, (1,)).item()
        return audio[:, start:start + length]

# Show what each tensor-only augmentation does to the clip's shape
aug = AudioAugmentation()

print("Original audio:", test_audio.shape)
print("With noise:", aug.add_noise(test_audio, 20).shape)
print("Random crop:", aug.random_crop(test_audio, length=48000).shape)

## 6. Complete Pipeline

In [None]:
class AudioPipeline:
    """
    Complete audio preprocessing pipeline for training.

    Typical use for Whisper-like models.

    Calling the pipeline runs, in order: resample -> mono mix-down ->
    pad/trim to ``max_duration`` -> optional noise augmentation -> move to
    ``device`` -> log-mel extraction -> mean/std normalization. The
    wall-clock time of every step (milliseconds) is stored in ``self.timing``
    and can be printed with ``print_timing()``.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        n_mels: int = 80,
        max_duration: float = 30.0,
        augment: bool = True,
        device: str = "cpu"
    ):
        """Configure the pipeline.

        Args:
            sample_rate: Rate every input is resampled to.
            n_mels: Number of mel bins in the output.
            max_duration: Clip length in seconds after padding/trimming.
            augment: Enable training-time behaviour (random crop, random noise).
            device: Device the mel transform runs on.
        """
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.max_samples = int(max_duration * sample_rate)
        self.augment = augment
        self.device = device

        # Feature extractor
        # NOTE(review): unlike AudioFeatureExtractor, no norm/mel_scale is
        # passed here, so torchaudio's defaults apply - confirm this is intended.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=400,
            hop_length=160,
            n_mels=n_mels
        ).to(device)

        self.timing = {}

    def _elapsed_ms(self, start: float) -> float:
        """Return milliseconds since ``start``, waiting for pending CUDA work first."""
        device = torch.device(self.device)
        if device.type == "cuda":
            # CUDA ops are asynchronous: without this the clock would measure
            # kernel launch time instead of execution time.
            torch.cuda.synchronize(device)
        return (time.perf_counter() - start) * 1000

    def __call__(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
        """Process audio through full pipeline.

        Args:
            audio: Waveform shaped ``(channels, samples)``.
            sr: Sample rate of ``audio``.

        Returns:
            Normalized log-mel spectrogram produced by the mel transform for
            the single-channel, fixed-length clip.
        """
        self.timing = {}

        # 1. Resample if needed
        start = time.perf_counter()
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
        self.timing['resample'] = self._elapsed_ms(start)

        # 2. Convert to mono if stereo
        start = time.perf_counter()
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        self.timing['to_mono'] = self._elapsed_ms(start)

        # 3. Pad/trim to max duration
        start = time.perf_counter()
        num_samples = audio.shape[1]
        if num_samples > self.max_samples:
            if self.augment:
                # Random crop during training. randint's upper bound is
                # exclusive, hence the +1 so the final window can be chosen.
                last_start = num_samples - self.max_samples
                start_idx = torch.randint(0, last_start + 1, (1,)).item()
                audio = audio[:, start_idx:start_idx + self.max_samples]
            else:
                # Trim from start during inference
                audio = audio[:, :self.max_samples]
        else:
            # Pad with zeros
            pad_amount = self.max_samples - num_samples
            audio = torch.nn.functional.pad(audio, (0, pad_amount))
        self.timing['pad_trim'] = self._elapsed_ms(start)

        # 4. Augmentation (training only)
        start = time.perf_counter()
        if self.augment:
            # Add noise with probability 0.5
            if torch.rand(1).item() > 0.5:
                # Integer SNR in dB drawn from 15..29 (upper bound exclusive)
                snr = torch.randint(15, 30, (1,)).item()
                audio = AudioAugmentation.add_noise(audio, snr)
        self.timing['augment'] = self._elapsed_ms(start)

        # 5. Move to device
        start = time.perf_counter()
        audio = audio.to(self.device)
        self.timing['to_device'] = self._elapsed_ms(start)

        # 6. Extract mel spectrogram
        start = time.perf_counter()
        mel_spec = self.mel_transform(audio)
        # Floor at 1e-10 so silent (zero-padded) frames do not become -inf
        log_mel = torch.log(mel_spec.clamp(min=1e-10))
        self.timing['mel_extract'] = self._elapsed_ms(start)

        # 7. Normalize (zero mean, unit variance over the whole clip)
        start = time.perf_counter()
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-8)
        self.timing['normalize'] = self._elapsed_ms(start)

        return log_mel

    def print_timing(self):
        """Print timing breakdown of the most recent call, with each step's share."""
        total = sum(self.timing.values())
        print("\nPipeline Timing:")
        for step, time_ms in self.timing.items():
            pct = 100 * time_ms / total if total > 0 else 0
            print(f"  {step:<15}: {time_ms:>8.3f} ms ({pct:>5.1f}%)")
        print(f"  {'TOTAL':<15}: {total:>8.3f} ms")

# Run the training-mode pipeline on the CPU and show the per-step timings
pipeline = AudioPipeline(augment=True, device="cpu")
features = pipeline(test_audio, sr=16000)
print(f"Input: {test_audio.shape}")
print(f"Output: {features.shape}")
pipeline.print_timing()

## 7. Memory Analysis

In [None]:
def analyze_memory(
    duration_sec: float = 30.0,
    sample_rate: int = 16000,
    n_fft: int = 512,
    hop_length: int = 160,
    n_mels: int = 80,
):
    """Analyze memory usage for different representations.

    Prints the size of one clip stored as int16 samples, float32 samples,
    a float32 spectrogram and a float32 mel spectrogram, followed by the
    size ratio between raw float32 audio and the mel spectrogram.

    Args:
        duration_sec: Clip length in seconds.
        sample_rate: Samples per second.
        n_fft: FFT size; determines the number of frequency bins.
        hop_length: Samples between consecutive frames.
        n_mels: Number of mel bins.
    """
    print(f"\nMemory Analysis for {duration_sec}s audio @ {sample_rate} Hz:")
    print("=" * 50)

    num_samples = int(duration_sec * sample_rate)

    # Raw audio (int16)
    raw_int16 = num_samples * 2  # 2 bytes per sample
    print(f"Raw audio (int16):     {raw_int16 / 1024:.2f} KB")

    # Raw audio (float32)
    raw_float32 = num_samples * 4  # 4 bytes per sample
    print(f"Raw audio (float32):   {raw_float32 / 1024:.2f} KB")

    # Spectrogram
    # Frame count of a centred (padded) STFT, which is what the torchaudio
    # transforms in this notebook compute by default. The un-padded formula
    # (num_samples - n_fft) // hop_length + 1 under-counts and turns zero or
    # negative for clips shorter than n_fft.
    num_frames = num_samples // hop_length + 1
    num_bins = n_fft // 2 + 1
    spec_size = num_frames * num_bins * 4  # float32
    print(f"Spectrogram:           {spec_size / 1024:.2f} KB ({num_frames} × {num_bins})")

    # Mel spectrogram
    mel_size = num_frames * n_mels * 4  # float32
    print(f"Mel spectrogram:       {mel_size / 1024:.2f} KB ({num_frames} × {n_mels})")

    # Compression ratios
    print("\nCompression vs raw float32:")
    print(f"  Mel spectrogram: {raw_float32 / mel_size:.1f}x")

analyze_memory(30.0, 16000)

## 8. Key Takeaways

1. **Pre-resample your data** - Resampling is expensive, do it once during dataset preparation

2. **Use GPU for feature extraction** - 10-20x speedup for mel spectrograms

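3. **Mel spectrograms are efficient** - At 16 kHz with 80 mel bins and a 10 ms hop, a float32 mel spectrogram is about half the size of raw float32 audio (and roughly the same size as int16 audio)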

4. **Augmentation is cheap** - Noise addition and time stretching are fast

5. **Batch processing helps** - Process multiple files together for better GPU utilization

In [None]:
# Cleanup: delete the temporary directory that holds the test WAV file
import shutil
# ignore_errors=True keeps this cell from failing if the directory is already gone
shutil.rmtree(temp_dir, ignore_errors=True)
print("Cleaned up temporary files.")