# Complete Audio LLM Inference Guide

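Hands-on guide to the audio-tokenization building blocks behind audio language models such as Moshi: neural codec tokens (EnCodec), self-supervised speech features (WavLM), multi-stream token interleaving, and codec latency profiling.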

**Covers:**
- Audio tokenization with neural codecs (EnCodec)
- Semantic feature extraction (WavLM)
- Interleaving audio and text tokens into a single LLM input sequence
- Multi-stream (user audio, system audio, text) processing
- Codec performance profiling

In [None]:
import torch
import torchaudio
import numpy as np
import time
from typing import List, Tuple, Optional

# Report the runtime environment so the timings below can be interpreted.
print(f"PyTorch: {torch.__version__}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")
else:
    print("Device: CPU")

## 1. Audio Tokenization with EnCodec

In [None]:
# Load the 24 kHz EnCodec model. ENCODEC_AVAILABLE is always defined afterwards,
# so later cells can test it even when the package or its weights are missing.
try:
    from encodec import EncodecModel
    from encodec.utils import convert_audio
except ImportError:
    print("EnCodec not available. Install: pip install encodec")
    ENCODEC_AVAILABLE = False
else:
    try:
        # Building the pretrained model may need to fetch weights over the
        # network; a download failure surfaces as an OSError (e.g. URLError).
        encodec = EncodecModel.encodec_model_24khz()
    except OSError as err:
        print(f"EnCodec weights could not be loaded: {err}")
        ENCODEC_AVAILABLE = False
    else:
        encodec.set_target_bandwidth(6.0)  # 6 kbps
        encodec = encodec.cuda() if torch.cuda.is_available() else encodec
        encodec.eval()

        print("✓ EnCodec loaded")
        ENCODEC_AVAILABLE = True

In [None]:
if ENCODEC_AVAILABLE:
    # Round-trip random audio through EnCodec and inspect the resulting tokens.
    clip_seconds = 5
    # EncodecModel is a plain nn.Module, which has no `.device` attribute;
    # read the device from its parameters instead.
    codec_device = next(encodec.parameters()).device

    # Generate test audio: (batch, channels, samples) at the codec's own rate.
    test_audio = torch.randn(
        1, encodec.channels, encodec.sample_rate * clip_seconds
    ).to(codec_device)

    # Encode to tokens; each encoded frame is indexed as (codes, scale) below.
    with torch.no_grad():
        encoded_frames = encodec.encode(test_audio)

    # Extract codes and join frames along time -> (batch, codebooks, frames)
    codes = [frame[0] for frame in encoded_frames]
    codes_tensor = torch.cat(codes, dim=-1)

    print(f"Audio shape: {test_audio.shape}")
    print(f"Encoded codes shape: {codes_tensor.shape}")
    print(f"Frame rate: {codes_tensor.shape[-1] / clip_seconds:.1f} Hz")
    print(f"Tokens per second: {codes_tensor.shape[1] * codes_tensor.shape[-1] / clip_seconds:.0f}")

    # Decode back
    with torch.no_grad():
        reconstructed = encodec.decode(encoded_frames)

    print(f"Reconstructed shape: {reconstructed.shape}")

    # Measure reconstruction quality over the overlapping span only: the decoder
    # output length is not guaranteed to equal the input length exactly.
    overlap_len = min(test_audio.shape[-1], reconstructed.shape[-1])
    mse = torch.nn.functional.mse_loss(
        test_audio[..., :overlap_len], reconstructed[..., :overlap_len]
    )
    print(f"Reconstruction MSE: {mse.item():.6f}")

## 2. Semantic Token Extraction with WavLM

In [None]:
# Load WavLM-base. WAVLM_AVAILABLE is always defined afterwards so later cells
# can test it even when transformers or the checkpoint cannot be loaded.
try:
    from transformers import WavLMModel, Wav2Vec2FeatureExtractor
except ImportError:
    print("WavLM not available. Install: pip install transformers")
    WAVLM_AVAILABLE = False
else:
    try:
        wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        # wavlm-base is an audio-only checkpoint without a tokenizer, so
        # Wav2Vec2Processor (feature extractor + tokenizer) cannot be loaded
        # for it; the feature extractor alone is what audio preprocessing needs.
        processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base")
    except OSError as err:
        # from_pretrained reports missing files / failed downloads as OSError.
        print(f"WavLM weights could not be loaded: {err}")
        WAVLM_AVAILABLE = False
    else:
        wavlm = wavlm.cuda() if torch.cuda.is_available() else wavlm
        wavlm.eval()

        print("✓ WavLM loaded")
        WAVLM_AVAILABLE = True

In [None]:
if WAVLM_AVAILABLE:
    # Five seconds of random 16 kHz audio stand in for real speech.
    clip_seconds = 5
    test_audio_16k = torch.randn(1, 16000 * clip_seconds).to(wavlm.device)

    # Run the encoder once, keeping every layer's hidden states.
    with torch.no_grad():
        outputs = wavlm(test_audio_16k, output_hidden_states=True)

    # Pick an intermediate layer (index 7) as the "semantic" representation.
    semantic_features = outputs.hidden_states[7]
    num_feature_frames = semantic_features.shape[1]

    for line in (
        f"Audio shape: {test_audio_16k.shape}",
        f"Semantic features shape: {semantic_features.shape}",
        f"Feature rate: {num_feature_frames / clip_seconds:.1f} Hz",
    ):
        print(line)

    # These features are continuous; discrete semantic tokens are typically
    # obtained by quantizing them (e.g. with k-means over a speech corpus).

## 3. Multi-Stream Token Processing

In [None]:
class MultiStreamTokenizer:
    """
    Tokenize audio for a multi-stream LLM (Moshi-style).

    Streams, interleaved frame by frame:
    - User audio (one token per codebook per frame)
    - System audio (one token per codebook per frame)
    - Text (one token every ``frames_per_text_token`` audio frames)

    Note: the streams are flattened into one integer sequence without
    per-stream offsets, so audio codes and text ids share one id space.
    A real model would need separate vocabularies or id offsets.

    Args:
        codec: Object with ``encode(audio)`` returning either a code tensor
            shaped (batch, codebooks, frames) or an EnCodec-style list of
            ``(codes, scale)`` frames to be joined along time.
        text_tokenizer: Callable mapping text to a sequence of integer ids.
        frame_rate: Nominal audio frame rate in Hz (metadata only; it is not
            used by ``tokenize``).
        frames_per_text_token: Audio frames between consecutive text tokens.
    """
    def __init__(self, codec, text_tokenizer, frame_rate=12.5, frames_per_text_token=4):
        if frames_per_text_token < 1:
            raise ValueError(
                f"frames_per_text_token must be >= 1, got {frames_per_text_token}"
            )
        self.codec = codec
        self.text_tokenizer = text_tokenizer
        self.frame_rate = frame_rate
        self.frames_per_text_token = frames_per_text_token
        self.text_rate = frame_rate / frames_per_text_token

    @staticmethod
    def _codes_from_encoding(encoded):
        """Return a (batch, codebooks, frames) code tensor from ``encode`` output."""
        if isinstance(encoded, torch.Tensor):
            return encoded
        # EnCodec returns a list of (codes, scale) frames; join them along time.
        return torch.cat([frame[0] for frame in encoded], dim=-1)

    def tokenize(self, user_audio, system_audio, text):
        """
        Tokenize all streams.

        Streams are truncated to the shorter of the two audio code sequences.
        Batch elements of each frame are flattened together, so batch size 1
        is the intended use.

        Returns:
            1-D ``torch.long`` tensor holding the interleaved token sequence.
        """
        with torch.no_grad():
            user_tokens = self._codes_from_encoding(self.codec.encode(user_audio))
            system_tokens = self._codes_from_encoding(self.codec.encode(system_audio))

        # Move to CPU once instead of forcing a device sync on every .tolist().
        user_tokens = user_tokens.cpu()
        system_tokens = system_tokens.cpu()

        # Accept lists of ints or tensors of ids from the text tokenizer.
        text_tokens = [int(tok) for tok in self.text_tokenizer(text)]

        num_frames = min(user_tokens.shape[-1], system_tokens.shape[-1])
        step = self.frames_per_text_token
        interleaved = []

        for t in range(num_frames):
            interleaved.extend(user_tokens[:, :, t].flatten().tolist())
            interleaved.extend(system_tokens[:, :, t].flatten().tolist())

            # Text token on every `step`-th frame, while text remains.
            if t % step == 0 and t // step < len(text_tokens):
                interleaved.append(text_tokens[t // step])

        # Explicit dtype keeps an empty result integer-typed too.
        return torch.tensor(interleaved, dtype=torch.long)


# Demo
if ENCODEC_AVAILABLE:
    tokenizer = MultiStreamTokenizer(encodec, lambda x: [1, 2, 3])  # Dummy text tokenizer

    # EncodecModel is a plain nn.Module without a `.device` attribute, so take
    # the device from its parameters. One second of audio at the codec's rate.
    demo_device = next(encodec.parameters()).device
    user_audio = torch.randn(1, encodec.channels, encodec.sample_rate).to(demo_device)
    system_audio = torch.randn(1, encodec.channels, encodec.sample_rate).to(demo_device)

    tokens = tokenizer.tokenize(user_audio, system_audio, "test")
    print(f"\nInterleaved tokens: {tokens.shape}")
    print(f"Tokens for 1 second: {len(tokens)} tokens")

## 4. Performance Profiling

In [None]:
def profile_audio_tokenization(codec, audio_seconds=10.0, num_runs=50, *,
                               num_warmup=5, sample_rate=None):
    """
    Profile a codec's encode/decode latency on random audio.

    Args:
        codec: ``nn.Module`` exposing ``encode(audio)`` and ``decode(frames)``.
        audio_seconds: Length of the synthetic test signal in seconds.
        num_runs: Timed iterations for each of encode and decode.
        num_warmup: Untimed encode+decode iterations run before timing.
        sample_rate: Sample rate of the test signal. Defaults to
            ``codec.sample_rate`` when the codec has one, else 24000.

    Returns:
        dict with mean per-call latency (``encode_ms``, ``decode_ms``) and
        real-time factors (``encode_rtf``, ``decode_rtf``, ``total_rtf``);
        an RTF below 1 means faster than real time.

    Raises:
        ValueError: if ``audio_seconds`` is not positive or ``num_runs`` < 1.
    """
    if audio_seconds <= 0:
        raise ValueError(f"audio_seconds must be positive, got {audio_seconds}")
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    if sample_rate is None:
        sample_rate = getattr(codec, "sample_rate", 24000)
    channels = getattr(codec, "channels", 1)
    # A codec without parameters has no device to read; fall back to CPU.
    first_param = next(codec.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    audio = torch.randn(1, channels, int(round(sample_rate * audio_seconds)), device=device)

    def _sync():
        # CUDA launches are asynchronous: wait for queued work before timing.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    def _mean_ms(fn):
        """Run ``fn`` num_runs times; return (mean ms per call, last result)."""
        _sync()
        start = time.perf_counter()
        for _ in range(num_runs):
            result = fn()
        _sync()
        return (time.perf_counter() - start) / num_runs * 1000, result

    with torch.no_grad():
        # Warmup (kernel selection, allocator growth) is excluded from timing.
        for _ in range(num_warmup):
            codec.decode(codec.encode(audio))
        encode_time, frames = _mean_ms(lambda: codec.encode(audio))
        decode_time, _ = _mean_ms(lambda: codec.decode(frames))

    encode_rtf = encode_time / 1000 / audio_seconds
    decode_rtf = decode_time / 1000 / audio_seconds
    total_rtf = encode_rtf + decode_rtf
    throughput = 1 / total_rtf if total_rtf > 0 else float("inf")

    print(f"\nCodec Profiling ({audio_seconds}s audio):")
    print(f"  Encode: {encode_time:.2f}ms (RTF: {encode_rtf:.4f})")
    print(f"  Decode: {decode_time:.2f}ms (RTF: {decode_rtf:.4f})")
    print(f"  Total RTF: {total_rtf:.4f}")
    print(f"  Throughput: {throughput:.0f}x real-time")

    return {
        "encode_ms": encode_time,
        "decode_ms": decode_time,
        "encode_rtf": encode_rtf,
        "decode_rtf": decode_rtf,
        "total_rtf": total_rtf,
    }


# Benchmark only when EnCodec loaded and a CUDA device is present.
_should_profile_codec = ENCODEC_AVAILABLE and torch.cuda.is_available()
if _should_profile_codec:
    profile_audio_tokenization(encodec, audio_seconds=10.0)

## 5. Key Takeaways

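1. **Neural codecs make audio LLM-compatible** - they turn continuous audio into discrete tokens.
2. **Semantic features complement acoustic tokens** - intermediate WavLM layers tend to capture phonetic/linguistic content. Quantizing them (e.g. with k-means) yields discrete semantic tokens.
3. **Multi-stream processing** - user audio, system audio, and text are handled simultaneously by interleaving them frame by frame.
4. **Profiling is critical** - a real-time factor (RTF) below 1 means the codec runs faster than real time.
5. **Token rate matters** - a lower token rate means more seconds of audio fit into the LLM's context window.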