# Building a Neural Audio Codec from Scratch

Hands-on implementation of a neural audio codec with:
- Encoder/Decoder architecture
- Vector Quantization (VQ)
- Residual Vector Quantization (RVQ)
- Discriminator training (not implemented in this notebook; left as Exercise 1 in Section 10)
- Full training loop with profiling

This notebook implements a simplified version of SoundStream/EnCodec.

In [None]:
# !pip install torch torchaudio matplotlib numpy

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Optional
import time

# Report the library version and GPU visibility so later timings can be interpreted.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Building Blocks

### 1.1 Causal Convolution

In [None]:
class CausalConv1d(nn.Module):
    """Conv1d that zero-pads only the left side of the time axis.

    ``(kernel_size - 1) * dilation`` zeros are prepended to the last
    dimension before an unpadded ``nn.Conv1d`` is applied, so no output
    position reads input samples that come after it.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1):
        super().__init__()
        taps_before_current = kernel_size - 1
        self.padding = taps_before_current * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Left-pad the last dimension of ``x`` and run the convolution."""
        left_padded = F.pad(x, (self.padding, 0))
        return self.conv(left_padded)


class CausalConvTranspose1d(nn.Module):
    """Transposed 1D convolution that crops the tail of its output.

    After ``nn.ConvTranspose1d`` runs, the last ``kernel_size - stride``
    samples are dropped whenever that difference is positive.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1):
        super().__init__()
        # Number of trailing output samples to discard in forward().
        self.trim = kernel_size - stride
        self.conv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and remove the trailing ``self.trim`` samples."""
        upsampled = self.conv(x)
        if self.trim <= 0:
            return upsampled
        return upsampled[..., :-self.trim]


# Test causal conv
# Smoke test: push one random 100-sample mono signal through a 1 -> 32 channel layer.
conv = CausalConv1d(1, 32, kernel_size=7)
x = torch.randn(1, 1, 100)
y = conv(x)
# With stride 1 the left padding of kernel_size - 1 keeps the time length at 100.
print(f"Input: {x.shape} -> Output: {y.shape}")

### 1.2 Residual Block

In [None]:
class ResidualBlock(nn.Module):
    """Skip-connected pair of causal convolutions.

    The residual branch is ELU -> kernel-3 dilated causal conv -> ELU ->
    kernel-1 causal conv; its output is added to the block input. The channel
    count is the same on the way in and out.
    """

    def __init__(self, channels: int, dilation: int = 1):
        super().__init__()
        dilated_conv = CausalConv1d(channels, channels, kernel_size=3, dilation=dilation)
        pointwise_conv = CausalConv1d(channels, channels, kernel_size=1)
        self.block = nn.Sequential(nn.ELU(), dilated_conv, nn.ELU(), pointwise_conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the residual branch."""
        branch_out = self.block(x)
        return x + branch_out


# Test
# Smoke test: run a random (1, 32, 100) tensor through a 32-channel residual block
# and print the input and output shapes.
block = ResidualBlock(32)
x = torch.randn(1, 32, 100)
y = block(x)
print(f"ResBlock: {x.shape} -> {y.shape}")

## 2. Encoder Architecture

In [None]:
class Encoder(nn.Module):
    """
    Audio encoder that compresses waveform to latent representation.

    Architecture:
    - Initial conv
    - Multiple downsample blocks (residual stack + strided conv)
    - Final residual stack and conv to latent dim

    Args:
        channels: Width of the first convolution. The width doubles after
            every downsampling stage and is capped at 512.
        latent_dim: Number of output channels.
        strides: Downsampling factor of each stage. ``None`` selects
            ``[2, 4, 5, 8]`` (total stride 320). The sequence is copied, so
            the caller's list is never aliased by the module.
        num_residual: Residual blocks per stage, with dilations
            ``3**0, 3**1, ...``.

    Attributes:
        strides: Private copy of the stage strides.
        total_stride: Product of all strides as a plain Python ``int``.
    """

    def __init__(
        self,
        channels: int = 32,
        latent_dim: int = 128,
        strides: Optional[List[int]] = None,
        num_residual: int = 3,
    ):
        super().__init__()
        if strides is None:
            strides = [2, 4, 5, 8]  # Total: 2*4*5*8 = 320
        # Copy so that neither a shared default nor the caller's list can be
        # mutated through ``self.strides``.
        self.strides = list(strides)
        # Cast the NumPy scalar so the attribute behaves like a regular int.
        self.total_stride = int(np.prod(self.strides))

        layers = []

        # Initial conv
        layers.append(CausalConv1d(1, channels, kernel_size=7))

        # Downsample blocks
        in_ch = channels
        for stride in self.strides:
            out_ch = min(in_ch * 2, 512)

            # Residual blocks with increasing dilation
            for j in range(num_residual):
                layers.append(ResidualBlock(in_ch, dilation=3**j))

            # Strided conv for downsampling (kernel is twice the stride)
            layers.append(nn.ELU())
            layers.append(CausalConv1d(in_ch, out_ch, kernel_size=stride*2, stride=stride))

            in_ch = out_ch

        # Final residual blocks
        for j in range(num_residual):
            layers.append(ResidualBlock(in_ch, dilation=3**j))

        # Project to latent dim
        layers.append(nn.ELU())
        layers.append(CausalConv1d(in_ch, latent_dim, kernel_size=3))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, 1, samples) -> (batch, latent_dim, frames)"""
        return self.encoder(x)


# Test encoder
# Smoke test: encode one second of random mono audio and report the frame count.
encoder = Encoder(channels=32, latent_dim=128)
x = torch.randn(1, 1, 16000)  # 1 second @ 16kHz
z = encoder(x)
print(f"Encoder: {x.shape} -> {z.shape}")
# The printed stride is the integer ratio of input samples to output frames.
print(f"Compression: {x.shape[2]} -> {z.shape[2]} (stride={x.shape[2]//z.shape[2]}x)")

## 3. Vector Quantization

In [None]:
class VectorQuantizer(nn.Module):
    """
    Vector Quantization with EMA codebook update.

    Maps continuous vectors to their nearest entry in a codebook. The
    codebook is stored in buffers and refreshed with exponential moving
    averages while the module is in training mode; it receives no gradients.

    Args:
        codebook_size: Number of codebook entries (K).
        codebook_dim: Dimensionality of every entry (D).
        commitment_weight: Scale applied to the commitment loss.
        ema_decay: Decay factor of the moving averages.
        epsilon: Laplace-smoothing constant that keeps the cluster sizes of
            unused entries away from zero.
    """

    def __init__(
        self,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        commitment_weight: float = 0.25,
        ema_decay: float = 0.99,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment_weight = commitment_weight
        self.ema_decay = ema_decay
        self.epsilon = epsilon

        # Codebook (EMA updated, not gradient trained), drawn uniformly from
        # [-1/K, 1/K].
        bound = 1.0 / codebook_size
        embedding = torch.empty(codebook_size, codebook_dim).uniform_(-bound, bound)
        self.register_buffer('embedding', embedding)
        self.register_buffer('cluster_size', torch.zeros(codebook_size))
        # Clone only after the codebook has its final initial values, so the
        # EMA numerator starts from the same vectors as the codebook itself.
        self.register_buffer('embed_avg', embedding.clone())

    @torch.no_grad()
    def _ema_update(self, z_flat: torch.Tensor, indices: torch.Tensor) -> None:
        """Refresh the codebook buffers from one batch of assignments.

        Args:
            z_flat: (N, D) encoder vectors, detached from autograd.
            indices: (N,) index of the codebook entry chosen for each vector.
        """
        # One-hot encoding
        encodings = F.one_hot(indices, self.codebook_size).float()  # (N, K)

        # Update cluster sizes
        self.cluster_size.mul_(self.ema_decay).add_(
            encodings.sum(0), alpha=1 - self.ema_decay
        )

        # Update embedding averages
        embed_sum = encodings.t() @ z_flat  # (K, D)
        self.embed_avg.mul_(self.ema_decay).add_(
            embed_sum, alpha=1 - self.ema_decay
        )

        # Laplace-smoothed cluster sizes, rescaled so they still sum to n
        n = self.cluster_size.sum()
        cluster_size = (
            (self.cluster_size + self.epsilon)
            / (n + self.codebook_size * self.epsilon) * n
        )
        self.embedding.copy_(self.embed_avg / cluster_size.unsqueeze(1))

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: (batch, dim, time) encoder output
        Returns:
            z_q: quantized output (same shape as z); gradients pass straight
                through to ``z``
            loss: commitment loss scaled by ``commitment_weight`` (there is
                no codebook loss term because the codebook is EMA-updated)
            indices: codebook indices (batch, time)
        """
        B, D, T = z.shape

        # Reshape: (B, D, T) -> (B*T, D)
        z_flat = z.permute(0, 2, 1).reshape(-1, D)

        # The nearest-neighbour search and the EMA bookkeeping are not
        # differentiable, so keep them out of the autograd graph entirely.
        with torch.no_grad():
            z_detached = z_flat.detach()

            # ||z - e||^2 = ||z||^2 + ||e||^2 - 2<z, e>
            dist = (
                z_detached.pow(2).sum(1, keepdim=True)
                + self.embedding.pow(2).sum(1)
                - 2 * z_detached @ self.embedding.t()
            )

            # Find nearest codebook entry
            indices = dist.argmin(dim=1)  # (B*T,)

            # Advanced indexing copies, so the in-place codebook update below
            # does not change the vectors returned for this call.
            codes = self.embedding[indices]  # (B*T, D)

            # EMA codebook update (during training)
            if self.training:
                self._ema_update(z_detached, indices)

        # Commitment loss pulls the encoder output towards its chosen code
        commitment_loss = F.mse_loss(z_flat, codes)
        loss = self.commitment_weight * commitment_loss

        # Straight-through estimator
        z_q_flat = z_flat + (codes - z_flat).detach()

        # Reshape back
        z_q = z_q_flat.view(B, T, D).permute(0, 2, 1)
        indices = indices.view(B, T)

        return z_q, loss, indices

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode (batch, time) indices to (batch, dim, time) vectors."""
        B, T = indices.shape
        # reshape (not view) so non-contiguous index tensors are accepted too
        z_q = self.embedding[indices.reshape(-1)].view(B, T, -1)
        return z_q.permute(0, 2, 1)


# Test VQ
# Smoke test on random latents. A freshly built module is in training mode, so
# this forward call also runs the EMA codebook update guarded by `self.training`.
vq = VectorQuantizer(codebook_size=1024, codebook_dim=128)
z = torch.randn(2, 128, 50)  # 2 batch, 128 dim, 50 frames
z_q, loss, indices = vq(z)
print(f"VQ: {z.shape} -> {z_q.shape}")
print(f"Indices: {indices.shape}")
print(f"Commitment loss: {loss.item():.4f}")

## 4. Residual Vector Quantization

In [None]:
class ResidualVQ(nn.Module):
    """
    Cascade of vector quantizers.

    Every level quantizes whatever the previous levels left over, and the
    outputs of all levels are summed to form the final quantized latent.
    """

    def __init__(
        self,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        commitment_weight: float = 0.25,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers

        levels = []
        for _ in range(num_quantizers):
            levels.append(
                VectorQuantizer(
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    commitment_weight=commitment_weight,
                )
            )
        self.quantizers = nn.ModuleList(levels)

    def forward(
        self,
        z: torch.Tensor,
        num_quantizers: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Quantize ``z`` with the first ``num_quantizers`` levels.

        Args:
            z: encoder output
            num_quantizers: how many levels to run; ``None`` runs all of them
        Returns:
            Sum of the per-level quantized outputs, sum of the per-level
            losses, and the list of per-level index tensors.
        """
        active_levels = self.num_quantizers if num_quantizers is None else num_quantizers

        accumulated = torch.zeros_like(z)
        remainder = z
        loss_sum = 0
        codes_per_level = []

        for level in range(active_levels):
            level_out, level_loss, level_codes = self.quantizers[level](remainder)
            codes_per_level.append(level_codes)
            loss_sum = loss_sum + level_loss
            accumulated = accumulated + level_out
            # The next level sees only what this level failed to capture.
            remainder = remainder - level_out

        return accumulated, loss_sum, codes_per_level

    def decode(self, indices_list: List[torch.Tensor]) -> torch.Tensor:
        """Sum the codebook vectors selected by each level's indices."""
        return sum(
            self.quantizers[level].decode(level_codes)
            for level, level_codes in enumerate(indices_list)
        )


# Test RVQ
# Smoke test: quantize random latents with all eight levels.
rvq = ResidualVQ(num_quantizers=8, codebook_size=1024, codebook_dim=128)
z = torch.randn(2, 128, 50)
z_q, loss, indices = rvq(z)
print(f"RVQ: {z.shape} -> {z_q.shape}")
print(f"Number of codebook levels: {len(indices)}")
print(f"Total loss: {loss.item():.4f}")

# Compute reconstruction error at each level
# NOTE: `rvq` is still in training mode here, so every call below also applies
# an EMA codebook update; the codebooks therefore shift between measurements.
print("\nReconstruction error by level:")
for n in [1, 2, 4, 8]:
    z_q_n, _, _ = rvq(z, num_quantizers=n)
    mse = F.mse_loss(z, z_q_n).item()
    print(f"  {n} levels: MSE = {mse:.6f}")

## 5. Decoder Architecture

In [None]:
class Decoder(nn.Module):
    """
    Waveform decoder: turns a latent sequence back into audio samples.

    Layout: a kernel-7 causal conv from the latent, a residual stack, then one
    upsampling stage per stride (ELU, transposed conv, residual stack), and a
    final ELU / kernel-7 conv / tanh head that emits a single channel.
    """

    def __init__(
        self,
        channels: int = 32,
        latent_dim: int = 128,
        strides: List[int] = [8, 5, 4, 2],  # Reverse of encoder
        num_residual: int = 3,
    ):
        super().__init__()

        num_stages = len(strides)

        # Input width of every stage: channels * min(2**k, 16) for
        # k = num_stages, ..., 1.
        stage_widths = []
        for exponent in range(num_stages, 0, -1):
            stage_widths.append(channels * min(2**exponent, 16))

        def residual_stack(width: int) -> list:
            """Build ``num_residual`` blocks with dilations 1, 3, 9, ..."""
            return [ResidualBlock(width, dilation=3**j) for j in range(num_residual)]

        # Stem: latent -> widest feature map, followed by a residual stack
        modules = [CausalConv1d(latent_dim, stage_widths[0], kernel_size=7)]
        modules.extend(residual_stack(stage_widths[0]))

        # Upsampling stages
        for idx, stride in enumerate(strides):
            width_in = stage_widths[idx]
            if idx + 1 < num_stages:
                width_out = stage_widths[idx + 1]
            else:
                # The last stage lands on the base channel count.
                width_out = channels

            modules.append(nn.ELU())
            modules.append(
                CausalConvTranspose1d(width_in, width_out, kernel_size=stride*2, stride=stride)
            )
            modules.extend(residual_stack(width_out))

        # Head: project to one waveform channel, squashed to [-1, 1]
        modules.append(nn.ELU())
        modules.append(CausalConv1d(channels, 1, kernel_size=7))
        modules.append(nn.Tanh())

        self.decoder = nn.Sequential(*modules)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """z: (batch, latent_dim, frames) -> (batch, 1, samples)"""
        return self.decoder(z)


# Test decoder
# Smoke test: decode 50 random latent frames and report the upsampling factor.
decoder = Decoder(channels=32, latent_dim=128)
z = torch.randn(1, 128, 50)  # 50 frames
x_hat = decoder(z)
print(f"Decoder: {z.shape} -> {x_hat.shape}")
# The printed factor is the integer ratio of output samples to input frames.
print(f"Upsampling: {z.shape[2]} -> {x_hat.shape[2]} ({x_hat.shape[2]//z.shape[2]}x)")

## 6. Complete Codec

In [None]:
class NeuralAudioCodec(nn.Module):
    """
    End-to-end codec: Encoder -> ResidualVQ -> Decoder.

    ``forward`` runs the full reconstruction path, while ``encode`` and
    ``decode`` expose the token interface on its own.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        channels: int = 32,
        latent_dim: int = 128,
        strides: List[int] = [2, 4, 5, 8],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.latent_dim = latent_dim

        self.encoder = Encoder(channels=channels, latent_dim=latent_dim, strides=strides)
        self.quantizer = ResidualVQ(
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=latent_dim,
        )
        # The decoder upsamples with the encoder strides in reverse order.
        decoder_strides = strides[::-1]
        self.decoder = Decoder(channels=channels, latent_dim=latent_dim, strides=decoder_strides)

        self.total_stride = np.prod(strides)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Encode, quantize and decode ``x``.

        Returns:
            The reconstruction (cropped so it is never longer than ``x``),
            the quantizer loss, and the per-level token indices.
        """
        latents = self.encoder(x)
        quantized, vq_loss, indices = self.quantizer(latents)
        reconstruction = self.decoder(quantized)

        # Crop the reconstruction to at most the input length.
        target_len = min(x.shape[-1], reconstruction.shape[-1])
        return reconstruction[..., :target_len], vq_loss, indices

    def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Encode audio to tokens (one index tensor per quantizer level)."""
        return self.quantizer(self.encoder(x))[2]

    def decode(self, indices: List[torch.Tensor]) -> torch.Tensor:
        """Decode tokens to audio."""
        return self.decoder(self.quantizer.decode(indices))


# Test complete codec
# Smoke test: build the codec with the notebook's default sizes and run one
# second of random audio through the full encode -> quantize -> decode path.
codec = NeuralAudioCodec(
    sample_rate=16000,
    channels=32,
    latent_dim=128,
    num_quantizers=8,
    codebook_size=1024,
)

x = torch.randn(1, 1, 16000)  # 1 second @ 16kHz
x_hat, vq_loss, indices = codec(x)

print(f"Input: {x.shape}")
print(f"Output: {x_hat.shape}")
print(f"VQ Loss: {vq_loss.item():.4f}")
# `indices` holds one (batch, frames) tensor per RVQ level. The input is one
# second long, so levels x frames is the token rate per second.
print(f"Tokens per second: {len(indices)} levels × {indices[0].shape[1]} frames = {len(indices) * indices[0].shape[1]} tokens")

# Count parameters
# Only nn.Parameter tensors are counted; the VectorQuantizer codebooks are
# registered as buffers and are therefore not part of this total.
total_params = sum(p.numel() for p in codec.parameters())
print(f"Total parameters: {total_params:,}")

## 7. Loss Functions

In [None]:
class MultiScaleSpectralLoss(nn.Module):
    """
    Multi-scale spectral loss (L1 + L2 on magnitude spectrograms).

    Computes the loss at several STFT resolutions and averages over the
    resolutions that were evaluated.

    Args:
        n_ffts: FFT size of each resolution. ``None`` selects
            ``[512, 1024, 2048]``.
        hop_lengths: Hop length of each resolution, paired positionally with
            ``n_ffts``. ``None`` selects ``[128, 256, 512]``.
        alpha: Weight of the L1 term.
        beta: Weight of the L2 (MSE) term.

    Raises:
        ValueError: If pairing ``n_ffts`` with ``hop_lengths`` yields no
            resolution at all.
    """

    def __init__(
        self,
        n_ffts: Optional[List[int]] = None,
        hop_lengths: Optional[List[int]] = None,
        alpha: float = 1.0,  # L1 weight
        beta: float = 1.0,   # L2 weight
    ):
        super().__init__()
        # Copy the sequences so instances never share (or alias) a list.
        self.n_ffts = [512, 1024, 2048] if n_ffts is None else list(n_ffts)
        self.hop_lengths = [128, 256, 512] if hop_lengths is None else list(hop_lengths)
        if not self.n_ffts or not self.hop_lengths:
            raise ValueError("n_ffts and hop_lengths must each contain at least one entry")
        self.alpha = alpha
        self.beta = beta

    @staticmethod
    def _magnitude(signal: torch.Tensor, n_fft: int, hop_length: int) -> torch.Tensor:
        """Return the STFT magnitude of a (batch, samples) signal."""
        spec = torch.stft(
            signal, n_fft=n_fft, hop_length=hop_length,
            window=torch.hann_window(n_fft, device=signal.device),
            return_complex=True
        )
        return spec.abs()

    def forward(self, x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
        """Compute multi-scale spectral loss.

        Args:
            x: Reference audio, (batch, 1, samples).
            x_hat: Reconstructed audio with the same shape as ``x``.
        Returns:
            Scalar tensor: mean over resolutions of
            ``alpha * L1 + beta * MSE`` between the magnitude spectrograms.
        """
        # Drop the channel axis once instead of once per resolution.
        target = x.squeeze(1)
        estimate = x_hat.squeeze(1)

        # zip() stops at the shorter list; average over what was actually
        # evaluated so a length mismatch cannot skew the result.
        scales = list(zip(self.n_ffts, self.hop_lengths))
        if not scales:
            raise ValueError("no STFT resolutions configured")

        loss = 0
        for n_fft, hop_length in scales:
            x_mag = self._magnitude(target, n_fft, hop_length)
            x_hat_mag = self._magnitude(estimate, n_fft, hop_length)

            # L1 and L2 loss
            loss = loss + self.alpha * F.l1_loss(x_hat_mag, x_mag)
            loss = loss + self.beta * F.mse_loss(x_hat_mag, x_mag)

        return loss / len(scales)


# Test spectral loss
# Smoke test: compare two unrelated random signals with the default resolutions.
spec_loss = MultiScaleSpectralLoss()
x = torch.randn(2, 1, 16000)
x_hat = torch.randn(2, 1, 16000)
loss = spec_loss(x, x_hat)
print(f"Spectral loss: {loss.item():.4f}")

## 8. Training Loop

In [None]:
def train_codec(
    codec: NeuralAudioCodec,
    num_steps: int = 100,
    batch_size: int = 4,
    audio_length: int = 16000,
    lr: float = 3e-4,
    device: str = 'cuda',
):
    """Simple training loop for demonstration.

    Trains on freshly sampled Gaussian noise with Adam, minimising the
    multi-scale spectral loss plus the quantizer loss.

    Args:
        codec: Model to train. It is moved to ``device`` and switched to
            training mode.
        num_steps: Number of optimizer steps.
        batch_size: Clips per step.
        audio_length: Samples per clip.
        lr: Adam learning rate.
        device: Device to train on. The default requires CUDA.

    Returns:
        List with the total loss (as a Python float) of every step.
    """

    codec = codec.to(device)
    # The quantizer only refreshes its EMA codebooks while `self.training` is
    # set, so make sure a codec that was left in eval mode (e.g. after
    # profiling) is switched back before optimising.
    codec.train()
    optimizer = torch.optim.Adam(codec.parameters(), lr=lr)
    spec_loss_fn = MultiScaleSpectralLoss().to(device)

    losses = []

    print("Training codec...")
    start_time = time.time()

    for step in range(num_steps):
        # Generate random audio (replace with real data)
        x = torch.randn(batch_size, 1, audio_length, device=device)

        # Forward pass
        x_hat, vq_loss, _ = codec(x)

        # Compute losses
        recon_loss = spec_loss_fn(x, x_hat)
        total_loss = recon_loss + vq_loss

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        losses.append(total_loss.item())

        # Progress line every 20 steps
        if (step + 1) % 20 == 0:
            elapsed = time.time() - start_time
            print(f"Step {step+1}/{num_steps} | Loss: {total_loss.item():.4f} | "
                  f"Recon: {recon_loss.item():.4f} | VQ: {vq_loss.item():.4f} | "
                  f"Time: {elapsed:.1f}s")

    return losses


# Run training (short demo)
# train_codec defaults to device='cuda', so the demo only runs when a GPU is present.
if torch.cuda.is_available():
    # Start from a fresh codec with default hyper-parameters; this rebinds the
    # notebook-level `codec` name.
    codec = NeuralAudioCodec()
    losses = train_codec(codec, num_steps=60, batch_size=4)
    
    # Plot losses
    plt.figure(figsize=(10, 4))
    plt.plot(losses)
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.grid(True)
    plt.show()
else:
    print("CUDA not available, skipping training demo")

## 9. Inference and Profiling

In [None]:
def profile_codec(codec: NeuralAudioCodec, audio_seconds: float = 10.0,
                  num_runs: int = 50):
    """Measure mean encode and decode latency on random audio and print a report.

    The codec is switched to eval mode and is left in eval mode afterwards.
    Timings are averaged over ``num_runs`` calls after five warm-up rounds and
    are reported in milliseconds together with real-time factors
    (processing time divided by audio duration).
    """

    device = next(codec.parameters()).device
    num_samples = int(codec.sample_rate * audio_seconds)
    audio = torch.randn(1, 1, num_samples, device=device)
    uses_cuda = device.type == 'cuda'

    def wait_for_device():
        # CUDA work is asynchronous; block until it finishes so that the wall
        # clock covers the actual computation.
        if uses_cuda:
            torch.cuda.synchronize()

    codec.eval()

    # Warmup
    with torch.no_grad():
        for _ in range(5):
            tokens = codec.encode(audio)
            _ = codec.decode(tokens)
    wait_for_device()

    # Profile encoding
    tic = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            tokens = codec.encode(audio)
    wait_for_device()
    encode_time = (time.perf_counter() - tic) / num_runs * 1000

    # Profile decoding (reuses the tokens from the last encode call)
    tic = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = codec.decode(tokens)
    wait_for_device()
    decode_time = (time.perf_counter() - tic) / num_runs * 1000

    # Real-time factors: seconds of compute per second of audio
    rtf_encode = encode_time / 1000 / audio_seconds
    rtf_decode = decode_time / 1000 / audio_seconds
    rtf_total = rtf_encode + rtf_decode

    print(f"\nCodec Profiling ({audio_seconds}s audio):")
    print(f"  Encode: {encode_time:.2f}ms (RTF: {rtf_encode:.4f})")
    print(f"  Decode: {decode_time:.2f}ms (RTF: {rtf_decode:.4f})")
    print(f"  Total RTF: {rtf_total:.4f}")
    print(f"  {1/rtf_total:.0f}x faster than real-time")


# Profile only when a GPU is present, using whatever `codec` currently refers to.
if torch.cuda.is_available():
    profile_codec(codec, audio_seconds=10.0)

## 10. Exercises

1. **Add a discriminator** for adversarial training (like EnCodec)
2. **Implement RVQ dropout** for variable bitrate
3. **Add mel spectrogram loss** for better perceptual quality
4. **Train on real audio** from LibriSpeech or similar
5. **Implement streaming** encode/decode for real-time use

In [None]:
# Print a closing summary banner followed by the notebook's key points.
print("\n" + "="*60)
print("KEY TAKEAWAYS")
print("="*60)
print("""
1. Neural codecs = Encoder + VQ + Decoder
   - Encoder compresses waveform
   - VQ discretizes for LLM compatibility
   - Decoder reconstructs waveform

2. RVQ enables high-fidelity with small codebooks
   - Each level refines the residual
   - Variable bitrate by using fewer levels

3. EMA codebook updates are crucial
   - More stable than gradient descent
   - Avoids codebook collapse

4. Spectral loss improves quality
   - Multi-scale for different frequencies
   - Perceptually meaningful

5. Real systems add discriminators
   - Adversarial training for sharper output
   - Multi-scale and multi-period discriminators
""")