# M2D Electronic Music Genre Classification

This notebook implements genre classification using M2D embeddings with a linear classifier on top.

In [3]:
# Install required packages
!pip install timm einops nnAudio librosa wget

Collecting wget
  Using cached wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25ldone
[?25h  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9655 sha256=cce7401b45382f05ec3302ee69651bb92ee825cbc7e46c2caee30c885a8dc852
  Stored in directory: /home/ziga/.cache/pip/wheels/01/46/3b/e29ffbe4ebe614ff224bad40fc6a5773a67a163251585a13a9
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


In [4]:
import torch
import torchaudio
import numpy as np
from pathlib import Path
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm.notebook import tqdm
import zipfile
import wget
from tqdm.notebook import tqdm

In [5]:
# Download M2D model files
!wget https://raw.githubusercontent.com/nttcslab/m2d/master/examples/portable_m2d.py
!wget https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip

# Extract the model weights
with zipfile.ZipFile("m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip", "r") as zip_ref:
    zip_ref.extractall(".")

--2025-01-04 01:19:58--  https://raw.githubusercontent.com/nttcslab/m2d/master/examples/portable_m2d.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 15862 (15K) [text/plain]
Saving to: ‘portable_m2d.py’


2025-01-04 01:19:58 (8.28 MB/s) - ‘portable_m2d.py’ saved [15862/15862]

--2025-01-04 01:19:58--  https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/589370928/0bdeb8a7-c3f3-44c5-afb9-9b9edaa3e861?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credenti

In [6]:
# Load base M2D model
from portable_m2d import PortableM2D

# Initialize model without classification head (we'll add our own)
model = PortableM2D(
    weight_file='m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth',
    num_classes=None  # Set to None to get embeddings instead of classification
)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()  # Set to evaluation mode



 using 151 parameters, while dropped 9 out of 160 parameters from m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>


PortableM2D(
  (backbone): LocalViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
      

In [7]:
# Audio preprocessing settings from the paper
SAMPLE_RATE = model.cfg.sample_rate  # Use M2D's sample rate
N_FFT = int(0.025 * SAMPLE_RATE)  # 25ms window
HOP_LENGTH = int(0.010 * SAMPLE_RATE)  # 10ms hop
N_MELS = 80
F_MIN = 50
F_MAX = 8000
AUDIO_MEAN = -7.1
AUDIO_STD = 4.2

In [16]:
class AudioDataset(Dataset):
    def __init__(self, data_dir, labels_file=None, transform=None):
        self.data_dir = Path(data_dir)
        self.transform = transform
        
        # Initialize mel spectrogram transform
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            n_mels=N_MELS,
            f_min=F_MIN,
            f_max=F_MAX
        )
        
        # Load labels from CSV
        self.labels_df = pd.read_csv('genre_dataset.csv')
        self.files = [self.data_dir / f for f in self.labels_df['path'].values]
        self.labels = self.labels_df['genre'].values
            
        # Convert genre names to indices
        self.label_to_idx = {label: idx for idx, label in enumerate(sorted(set(self.labels)))}
        self.labels = [self.label_to_idx[label] for label in self.labels]
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        # Load audio file
        waveform, sr = torchaudio.load(self.files[idx])
        
        # Resample if necessary
        if sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
            waveform = resampler(waveform)
        
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        return waveform, self.labels[idx]

In [36]:
def train_linear_classifier(model, train_loader, val_loader, num_classes, device='cuda', patience=10):
    classifier = nn.Linear(3840, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)
    
    num_epochs = 20
    best_acc = 0
    patience_counter = 0
    
    # For early stopping
    best_val_acc = 0
    epochs_without_improvement = 0
    
    for epoch in tqdm(range(num_epochs), desc='Training'):
        classifier.train()
        train_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader, leave=False)):
            inputs, targets = inputs.to(device), targets.to(device)
            
            with torch.no_grad():
                embeddings = model(inputs)
                embeddings = embeddings.mean(dim=1)
            
            outputs = classifier(embeddings)
            loss = criterion(outputs, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        
        train_acc = 100. * correct / total
        
        # Validation
        classifier.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                embeddings = model(inputs)
                embeddings = embeddings.mean(dim=1)
                outputs = classifier(embeddings)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        val_acc = 100. * correct / total
        
        print(f'Epoch: {epoch}')
        print(f'Train Loss: {train_loss/len(train_loader):.3f} | Train Acc: {train_acc:.3f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.3f} | Val Acc: {val_acc:.3f}%')
        
        # Save if it's the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            # Save both classifier and metadata
            torch.save({
                'epoch': epoch,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'train_acc': train_acc,
            }, 'best_genre_classifier.pth')
            print(f'New best model saved with validation accuracy: {val_acc:.3f}%')
        else:
            epochs_without_improvement += 1
            
        # Early stopping check
        if epochs_without_improvement >= patience:
            print(f'Early stopping after {epoch + 1} epochs without improvement')
            break
    
    return classifier

In [37]:
# Create dataset
dataset = AudioDataset('.')

# Split into train/val
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [38]:
# Train classifier
num_classes = len(dataset.label_to_idx)
classifier = train_linear_classifier(model, train_loader, val_loader, num_classes)

Training:   0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 0
Train Loss: 11.657 | Train Acc: 24.000%
Val Loss: 23.623 | Val Acc: 36.000%
New best model saved with validation accuracy: 36.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 1
Train Loss: 19.872 | Train Acc: 16.000%
Val Loss: 12.151 | Val Acc: 52.000%
New best model saved with validation accuracy: 52.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 2
Train Loss: 3.624 | Train Acc: 72.000%
Val Loss: 0.322 | Val Acc: 84.000%
New best model saved with validation accuracy: 84.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 3
Train Loss: 0.396 | Train Acc: 91.000%
Val Loss: 1.060 | Val Acc: 80.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 4
Train Loss: 0.174 | Train Acc: 96.000%
Val Loss: 1.230 | Val Acc: 80.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 5
Train Loss: 0.058 | Train Acc: 96.000%
Val Loss: 0.922 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 6
Train Loss: 0.021 | Train Acc: 99.000%
Val Loss: 0.931 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 7
Train Loss: 0.008 | Train Acc: 100.000%
Val Loss: 1.103 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 8
Train Loss: 0.004 | Train Acc: 100.000%
Val Loss: 1.027 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 9
Train Loss: 0.002 | Train Acc: 100.000%
Val Loss: 1.029 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 10
Train Loss: 0.002 | Train Acc: 100.000%
Val Loss: 1.010 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 11
Train Loss: 0.002 | Train Acc: 100.000%
Val Loss: 1.007 | Val Acc: 76.000%


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch: 12
Train Loss: 0.002 | Train Acc: 100.000%
Val Loss: 1.000 | Val Acc: 76.000%
Early stopping after 13 epochs without improvement
