# **Feature Extraction from Audio using Variational Autoencoder (VAE)**

This section describes the process of feature extraction from music data using a Variational Autoencoder (VAE). The workflow includes loading audio files, extracting Mel-spectrograms, training a VAE, and saving the latent representations for downstream tasks like clustering.

In [1]:
# Import necessary libraries
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
import pandas as pd
from pathlib import Path
import os
import glob
import soundfile as sf
import librosa
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

In [2]:
# Load audio files from the specified directory
audio_dir = Path("../data/audio")
audio_files = sorted(audio_dir.glob("*/*.mp3"))

print(f"Loaded {len(audio_files)} audio files")

Loaded 3554 audio files


In [3]:
# Extract Mel-spectrogram features
def load_audio(path, target_sr=22050):
    audio, sr = sf.read(path, dtype='float32')

    # Convert stereo to mono
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Resample if needed (keyword arguments!)
    if sr != target_sr:
        audio = librosa.resample(
            y=audio,
            orig_sr=sr,
            target_sr=target_sr
        )

    return audio

In [None]:
# Extract Mel-spectrogram features
def extract_mel(path, n_mels=64, fixed_len=1304):  
    y, sr = librosa.load(path, sr=22050)
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, hop_length=512)
    mel = np.log(mel + 1e-9)
    
    # Normalize to [0,1]
    mel = (mel - mel.min()) / (mel.max() - mel.min() + 1e-9)
    
    # Pad or truncate to fixed_len
    if mel.shape[1] < fixed_len:
        pad_width = fixed_len - mel.shape[1]
        mel = np.pad(mel, ((0,0), (0,pad_width)), mode='constant')
    else:
        mel = mel[:, :fixed_len]

    return torch.tensor(mel, dtype=torch.float32)

In [21]:
# Test the extract_mel function
x = extract_mel(audio_files[0])
print(x.shape)

torch.Size([64, 1304])


In [22]:
# Dataset Custom Class
class AudioDataset(Dataset):
    def __init__(self, audio_files, fixed_len=1304):
        self.audio_files = audio_files
        self.fixed_len = fixed_len

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        mel = extract_mel(self.audio_files[idx], fixed_len=self.fixed_len)
        return mel.unsqueeze(0)  

In [23]:
# VAE Model Definition
class VAE(nn.Module):
    def __init__(self, latent_dim=32, fixed_len=1304, n_mels=64):
        super().__init__()
        self.n_mels = n_mels
        self.fixed_len = fixed_len

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Compute encoded shape dynamically
        dummy_input = torch.zeros(1, 1, n_mels, fixed_len)
        h = self.encoder(dummy_input)
        self.enc_shape = h.shape[1:]  # (C,H,W)
        self.flattened_size = h.numel() // h.shape[0]

        # Latent space
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, self.flattened_size)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=(0,1)),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=(0,1)),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=(0,1)),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(h.size(0), *self.enc_shape)
        recon = self.decoder(h)
        # Ensure output matches input size
        recon = torch.nn.functional.interpolate(
            recon, size=(self.n_mels, self.fixed_len), mode='bilinear', align_corners=False
        )
        return recon

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [24]:
# VAE Loss Function
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='mean')  
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld

In [25]:
# Training Loop
def train_vae(audio_files, epochs=10, batch_size=8, latent_dim=32, lr=1e-3, fixed_len=1304):
    dataset = AudioDataset(audio_files, fixed_len=fixed_len)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE(latent_dim=latent_dim, fixed_len=fixed_len).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss = vae_loss(recon, batch, mu, logvar)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(loader):.6f}")

    return model

In [26]:
# Train the VAE model
vae_model = train_vae(audio_files, epochs=10, batch_size=8, latent_dim=32, fixed_len=1304)
os.makedirs("../results/models", exist_ok=True)
torch.save(vae_model.state_dict(), "../results/models/audio_vae.pth")
print("VAE model saved successfully at '../results/models/audio_vae.pth'")

Epoch [1/10], Loss: 0.029492
Epoch [2/10], Loss: 0.017347
Epoch [3/10], Loss: 0.017001
Epoch [4/10], Loss: 0.016912
Epoch [5/10], Loss: 0.016804
Epoch [6/10], Loss: 0.016812
Epoch [7/10], Loss: 0.016667
Epoch [8/10], Loss: 0.016703
Epoch [9/10], Loss: 0.016638
Epoch [10/10], Loss: 0.016648
VAE model saved successfully at '../results/models/audio_vae.pth'
