In [None]:
"""
Audio Tower - Spectrogram-Aware Acoustic Encoding
Encodes audio spectrograms into acoustic embeddings
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import numpy as np
from typing import Optional, Tuple
import torchaudio.transforms as T

In [None]:
class AudioEncoder(nn.Module):
    """
    Audio tower that encodes raw waveforms into L2-normalized acoustic embeddings.

    Pipeline: waveform -> log-mel spectrogram -> backbone -> L2 normalization,
    plus an optional auxiliary archetype-classification head.

    Architecture options:
    - 'resnet': ResNet-18-style CNN for image-like spectrogram processing
    - 'ast': Audio Spectrogram Transformer (AST) for patch-based encoding

    Args:
        embedding_dim: Size of the output embedding vector.
        architecture: 'resnet' or 'ast'.
        sample_rate: Sample rate (Hz) of input waveforms; ``load_audio_file``
            resamples files to this rate.
        n_mels: Number of mel bands in the spectrogram.
        n_fft: STFT window size.
        hop_length: Number of samples between successive STFT frames.
        use_archetype_supervision: If True, build the auxiliary archetype head.
        device: Device the module is moved to at the end of construction.
        num_archetypes: Number of archetype classes predicted by the auxiliary
            head (defaults to 5, the previously hard-coded value).

    Raises:
        ValueError: If ``architecture`` is not 'resnet' or 'ast'.
    """

    def __init__(
        self,
        embedding_dim=768,
        architecture='resnet',  # 'resnet' or 'ast'
        sample_rate=44100,
        n_mels=128,
        n_fft=2048,
        hop_length=512,
        use_archetype_supervision=True,
        device='cpu',
        num_archetypes=5,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.architecture = architecture
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.device = device
        self.num_archetypes = num_archetypes

        # Waveform -> mel spectrogram (log compression happens in extract_spectrogram)
        self.mel_transform = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            normalized=True
        )

        # Choose backbone architecture; both take (batch, 1, n_mels, time)
        # and return (batch, embedding_dim).
        if architecture == 'resnet':
            self.backbone = ResNetAudioBackbone(embedding_dim)
        elif architecture == 'ast':
            self.backbone = AudioSpectrogramTransformer(embedding_dim, n_mels)
        else:
            raise ValueError(f"Unknown architecture: {architecture}")

        # Auxiliary archetype prediction head (for supervision).
        # NOTE: the trailing Softmax makes this head output probabilities, not
        # logits. Train it with nn.NLLLoss on torch.log(probs); feeding these
        # outputs to nn.CrossEntropyLoss would apply softmax twice.
        self.use_archetype_supervision = use_archetype_supervision
        if use_archetype_supervision:
            self.archetype_classifier = nn.Sequential(
                nn.Linear(embedding_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_archetypes),
                nn.Softmax(dim=1)
            )

        self.to(device)

    def _module_device(self) -> torch.device:
        """Device currently holding the module's parameters (follows later ``.to()`` calls)."""
        return next(self.parameters()).device

    def extract_spectrogram(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Convert an audio waveform to a log-mel spectrogram.

        The input is first moved to the module's device, so a CPU tensor can be
        passed to a GPU-resident encoder.

        Args:
            audio: Waveform tensor of shape (batch, samples) or (samples,).

        Returns:
            Log-mel spectrogram of shape (batch, 1, n_mels, time). A 3-D input
            (batch, channels, samples) yields (batch, channels, n_mels, time).
        """
        # Ensure batch dimension
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        # Put the waveform on the same device as the mel filterbank and backbone
        audio = audio.to(self._module_device())

        mel_spec = self.mel_transform(audio)

        # Log compression; the epsilon keeps log() finite on silent bins
        log_mel_spec = torch.log(mel_spec + 1e-9)

        # (batch, n_mels, time) -> (batch, 1, n_mels, time) for the 2-D backbones
        if log_mel_spec.dim() == 3:
            log_mel_spec = log_mel_spec.unsqueeze(1)

        return log_mel_spec

    def forward(
        self,
        audio: torch.Tensor,
        return_archetype_pred=False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encode audio into embeddings.

        Args:
            audio: Audio waveform tensor (batch, samples) or (samples,).
            return_archetype_pred: Whether to return archetype predictions.
                Ignored (None is returned) when the encoder was built with
                ``use_archetype_supervision=False``.

        Returns:
            - Unit-L2-norm audio embeddings (batch, embedding_dim)
            - Archetype probabilities (batch, num_archetypes), or None
        """
        spectrogram = self.extract_spectrogram(audio)

        embeddings = self.backbone(spectrogram)

        # L2 normalize for contrastive learning
        embeddings = F.normalize(embeddings, p=2, dim=1)

        archetype_pred = None
        if return_archetype_pred and self.use_archetype_supervision:
            archetype_pred = self.archetype_classifier(embeddings)

        return embeddings, archetype_pred

    def load_audio_file(self, audio_path: str, duration: Optional[float] = None) -> torch.Tensor:
        """
        Load an audio file, resampled to ``self.sample_rate``, as a float tensor.

        Args:
            audio_path: Path to the audio file.
            duration: Optional duration to load (in seconds).

        Returns:
            Float32 waveform tensor on the module's device.
        """
        audio, _ = librosa.load(audio_path, sr=self.sample_rate, duration=duration)
        return torch.from_numpy(audio).float().to(self._module_device())

In [None]:

class ResNetAudioBackbone(nn.Module):
    """
    ResNet-18-style CNN that encodes a spectrogram as a single-channel image.

    Input:  (batch, 1, freq, time)
    Output: (batch, embedding_dim)

    Layout: 7x7 stride-2 conv stem and 3x3 stride-2 max-pool, then four stages
    of two ResidualBlocks each (64, 128, 256, 512 channels; stages 2-4 use
    stride 2), global average pooling and a linear projection.
    """

    def __init__(self, embedding_dim=768):
        super().__init__()

        # Stem
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages registered as layer1..layer4: (in_channels, out_channels, stride)
        stage_specs = ((64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        for stage_idx, (c_in, c_out, stage_stride) in enumerate(stage_specs, start=1):
            setattr(
                self,
                f'layer{stage_idx}',
                self._make_layer(c_in, c_out, blocks=2, stride=stage_stride),
            )

        # Head: global pooling followed by projection to the embedding size
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, embedding_dim)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Stack ResidualBlocks; only the first one may change stride or channel count."""
        head = ResidualBlock(in_channels, out_channels, stride)
        tail = [ResidualBlock(out_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features).flatten(1)
        return self.fc(pooled)

In [None]:
class ResidualBlock(nn.Module):
    """
    Basic two-convolution residual block (ResNet-18/34 style).

    Main path: 3x3 conv (optionally strided) -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: identity, or a 1x1 strided conv + BN projection whenever the
    stride or channel count changes. The two paths are summed and passed
    through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            shortcut_layers = (
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut_layers = ()
        # An empty nn.Sequential returns its input unchanged (identity skip)
        self.shortcut = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        skip = self.shortcut(x)
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return self.relu(main + skip)

In [None]:
class AudioSpectrogramTransformer(nn.Module):
    """
    Audio Spectrogram Transformer (AST) using patch embeddings.
    Inspired by the Vision Transformer, applied to spectrograms.

    The spectrogram is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each projected to ``embedding_dim``. A learnable [CLS] token is
    prepended, and its final hidden state is returned as the clip embedding.

    Positional embeddings are learned on a reference grid of
    ``(n_mels // patch_size) x max_time_patches`` patches. When an input's
    patch grid differs from that reference (e.g. a longer or shorter clip), the
    grid is bilinearly resized to match. Every patch therefore receives the
    embedding for its own (frequency, time) position instead of a linear slice,
    and clips longer than the reference no longer crash.

    Args:
        embedding_dim: Transformer width; must be divisible by the 8 heads.
        n_mels: Number of mel bands of the input spectrogram.
        patch_size: Side length of the square patches (also the conv stride).
        max_time_patches: Time patches in the learned reference grid (default
            10, the original "~10 time patches" assumption; with AudioEncoder's
            defaults this is roughly 2 s of audio).

    Raises:
        ValueError: If ``n_mels < patch_size`` or ``max_time_patches < 1``.
    """

    def __init__(self, embedding_dim=768, n_mels=128, patch_size=16, max_time_patches=10):
        super().__init__()

        if n_mels < patch_size:
            raise ValueError(
                f"n_mels ({n_mels}) must be at least patch_size ({patch_size})"
            )
        if max_time_patches < 1:
            raise ValueError(f"max_time_patches must be >= 1, got {max_time_patches}")

        self.patch_size = patch_size
        self.n_mels = n_mels
        self.freq_patches = n_mels // patch_size
        self.time_patches = max_time_patches

        # Patch embedding: non-overlapping patches projected to embedding_dim
        self.patch_embed = nn.Conv2d(
            in_channels=1,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Learnable positional embedding on the reference grid, stored flattened
        # frequency-major to match patches.flatten(2) in forward().
        self.num_patches = self.freq_patches * self.time_patches
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embedding_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=embedding_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)

        # Classification token (like BERT's [CLS])
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def _positional_embedding(self, freq_patches: int, time_patches: int) -> torch.Tensor:
        """
        Positional embeddings for a ``freq_patches`` x ``time_patches`` patch grid.

        Returns a tensor of shape (1, freq_patches * time_patches, embedding_dim)
        in the same frequency-major order as the flattened patches.
        """
        if (freq_patches, time_patches) == (self.freq_patches, self.time_patches):
            return self.pos_embed
        if freq_patches * time_patches == 0:
            # Input shorter than one patch: there are no patch tokens to position
            return self.pos_embed[:, :0, :]
        embed_dim = self.pos_embed.size(-1)
        grid = self.pos_embed.reshape(1, self.freq_patches, self.time_patches, embed_dim)
        grid = grid.permute(0, 3, 1, 2)  # (1, embed_dim, freq, time)
        grid = F.interpolate(
            grid, size=(freq_patches, time_patches), mode='bilinear', align_corners=False
        )
        return grid.flatten(2).transpose(1, 2)

    def forward(self, x):
        batch_size = x.shape[0]

        patches = self.patch_embed(x)  # (batch, embed_dim, freq_patches, time_patches)
        freq_patches, time_patches = patches.shape[-2:]
        patches = patches.flatten(2).transpose(1, 2)  # (batch, num_patches, embed_dim)

        patches = patches + self._positional_embedding(freq_patches, time_patches)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)

        encoded = self.transformer(tokens)

        # The [CLS] token's final state is the clip embedding
        return encoded[:, 0]

In [None]:
class AudioFeatureExtractor:
    """
    Hand-crafted librosa descriptors, used for analysis and for sanity-checking
    learned embeddings.

    Args:
        sample_rate: Sample rate (Hz) assumed for every waveform passed to
            ``extract_features``.
    """

    def __init__(self, sample_rate=44100):
        self.sample_rate = sample_rate

    def extract_features(self, audio: np.ndarray) -> dict:
        """
        Summarize a waveform as a dict of clip-level (time-averaged) features.

        Args:
            audio: Waveform array passed directly to librosa (presumably a
                mono float signal at ``self.sample_rate`` — confirm at call sites).

        Returns:
            dict with keys, in this order:
            - 'spectral_centroid', 'spectral_rolloff', 'spectral_bandwidth':
              frame-wise spectral descriptors averaged over time
            - 'zero_crossing_rate': mean zero-crossing rate (indicates noisiness)
            - 'rms_energy': mean frame RMS energy
            - 'harmonic_ratio': mean |harmonic part| (from HPSS) divided by
              mean |audio|, epsilon-guarded against silence
            - 'mfccs': time-averaged MFCCs (13 coefficients)
        """
        sr = self.sample_rate

        spectral_fns = (
            ('spectral_centroid', librosa.feature.spectral_centroid),
            ('spectral_rolloff', librosa.feature.spectral_rolloff),
            ('spectral_bandwidth', librosa.feature.spectral_bandwidth),
        )
        features = {name: np.mean(fn(y=audio, sr=sr)) for name, fn in spectral_fns}

        features['zero_crossing_rate'] = np.mean(librosa.feature.zero_crossing_rate(audio))
        features['rms_energy'] = np.mean(librosa.feature.rms(y=audio))

        harmonic, _percussive = librosa.effects.hpss(audio)
        mean_abs_harmonic = np.mean(np.abs(harmonic))
        mean_abs_audio = np.mean(np.abs(audio))
        features['harmonic_ratio'] = mean_abs_harmonic / (mean_abs_audio + 1e-10)

        mfcc_matrix = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
        features['mfccs'] = np.mean(mfcc_matrix, axis=1)

        return features
    

In [None]:
# Reproducibility: seed torch before building any model so weight
# initialization and the random test audio are identical on every
# Restart & Run All (torch.manual_seed also seeds CUDA devices).
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize audio encoder with ResNet
audio_encoder_resnet = AudioEncoder(
    embedding_dim=768,
    architecture='resnet',
    use_archetype_supervision=True,
    device=device,
)

In [None]:
# Smoke-test the ResNet encoder on random noise.
# eval() disables dropout in the archetype head and stops BatchNorm from
# updating its running statistics with noise on every re-run; no_grad()
# skips building an autograd graph. Together they make this cell repeatable.
batch_size = 4
audio_length = audio_encoder_resnet.sample_rate * 2  # 2 seconds
dummy_audio = torch.randn(batch_size, audio_length).to(device)

audio_encoder_resnet.eval()
with torch.no_grad():
    embeddings, archetype_pred = audio_encoder_resnet(
        dummy_audio,
        return_archetype_pred=True,
    )

# Invariants: unit-norm embeddings and one probability distribution per clip
assert embeddings.shape == (batch_size, audio_encoder_resnet.embedding_dim)
assert archetype_pred.shape[0] == batch_size
assert torch.allclose(
    embeddings.norm(dim=1), torch.ones(batch_size, device=embeddings.device), atol=1e-4
)
assert torch.allclose(
    archetype_pred.sum(dim=1), torch.ones(batch_size, device=archetype_pred.device), atol=1e-4
)

print(f"Audio embeddings shape: {embeddings.shape}")
print(f"Archetype predictions shape: {archetype_pred.shape}")
print(f"Sample embedding norm: {torch.norm(embeddings[0]).item():.4f}")
print(f"Sample archetype prediction: {archetype_pred[0]}")

In [None]:
# Test AST architecture
print("\n--- Testing AST Architecture ---")
audio_encoder_ast = AudioEncoder(
    embedding_dim=768,
    architecture='ast',
    device=device,
)

# Inference only: eval() disables the transformer's dropout so the embeddings
# are deterministic, and no_grad() avoids building an autograd graph.
audio_encoder_ast.eval()
with torch.no_grad():
    embeddings_ast, _ = audio_encoder_ast(dummy_audio)

assert embeddings_ast.shape == (dummy_audio.shape[0], audio_encoder_ast.embedding_dim)
print(f"AST embeddings shape: {embeddings_ast.shape}")