# M2D Electronic Music Genre Inference

This notebook loads a trained M2D genre classifier and performs inference on new audio files.

In [1]:
import torch
import torchaudio
import numpy as np
from pathlib import Path
import torch.nn as nn
from tqdm.notebook import tqdm
from portable_m2d import PortableM2D



In [2]:
# Load M2D base model
model = PortableM2D(
    weight_file='m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth',
    num_classes=None
)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

 using 151 parameters, while dropped 9 out of 160 parameters from m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>


PortableM2D(
  (backbone): LocalViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
      

In [3]:
# Load trained classifier
checkpoint = torch.load('best_genre_classifier.pth')
num_classes = 5  # Set this to match your number of genres
classifier = nn.Linear(3840, num_classes).to(device)
classifier.load_state_dict(checkpoint['model_state_dict'])
classifier.eval()

# Map indices to genre names
idx_to_genre = {0: 'ambient', 1: 'drum_and_bass', 2: 'house', 3: 'techno', 4: 'trance'}  # Adjust these to match your genres

In [4]:
def predict_genre(audio_path):
    """Predict genre for a single audio file"""
    # Load audio
    waveform, sr = torchaudio.load(audio_path)
    
    # Resample if necessary
    if sr != model.cfg.sample_rate:
        resampler = torchaudio.transforms.Resample(sr, model.cfg.sample_rate)
        waveform = resampler(waveform)
    
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    # Move to device
    waveform = waveform.to(device)
    
    # Get embeddings
    with torch.no_grad():
        embeddings = model(waveform.unsqueeze(0))  # Add batch dimension
        embeddings = embeddings.mean(dim=1)  # Average over time
        
        # Get predictions
        outputs = classifier(embeddings)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        pred = outputs.argmax(dim=1)
    
    # Get genre name and probability
    genre = idx_to_genre[pred.item()]
    confidence = probs[0, pred].item()
    
    # Get all probabilities
    all_probs = {idx_to_genre[i]: probs[0, i].item() for i in range(num_classes)}
    
    return {
        'predicted_genre': genre,
        'confidence': confidence,
        'all_probabilities': all_probs
    }

In [12]:
def predict_directory(directory):
    """Predict genres for all audio files in a directory"""
    directory = Path(directory)
    results = []

    
    # Get all audio files in all subdirectories
    audio_files = list(directory.glob('**/*.wav')) + list(directory.glob('**/*.mp3'))

    
    for audio_file in tqdm(audio_files, desc='Predicting genres'):
        try:
            prediction = predict_genre(audio_file)
            results.append({
                'file': str(audio_file),
                **prediction
            })
        except Exception as e:
            print(f"Error processing {audio_file}: {e}")
    
    return results

In [11]:
# Example usage for a single file
result = predict_genre("/mnt/g/glasba/minimal/02. Floating Points - Birth4000.mp3")
print(f"Predicted genre: {result['predicted_genre']} (confidence: {result['confidence']:.2%})")
print("\nAll probabilities:")
for genre, prob in result['all_probabilities'].items():
    print(f"{genre}: {prob:.2%}")

Predicted genre: techno (confidence: 100.00%)

All probabilities:
ambient: 0.00%
drum_and_bass: 0.00%
house: 0.00%
techno: 100.00%
trance: 0.00%


In [14]:
# Example usage for a directory
results = predict_directory("/mnt/g/glasba/hiša/classics")

# Display results as a table
import pandas as pd
df = pd.DataFrame(results)
display(df)

Predicting genres:   0%|          | 0/13 [00:00<?, ?it/s]

[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?


Unnamed: 0,file,predicted_genre,confidence,all_probabilities
0,"/mnt/g/glasba/hiša/classics/010. cajmere, daja...",house,0.999964,"{'ambient': 1.9782055767775253e-12, 'drum_and_..."
1,/mnt/g/glasba/hiša/classics/08-eric_prydz-woz_...,techno,0.998484,"{'ambient': 2.4735991033253413e-10, 'drum_and_..."
2,/mnt/g/glasba/hiša/classics/11. Todd Terry - S...,house,0.964874,"{'ambient': 6.005754471516411e-07, 'drum_and_b..."
3,/mnt/g/glasba/hiša/classics/2 Feeling For You ...,house,0.988789,"{'ambient': 8.604713053195212e-10, 'drum_and_b..."
4,"/mnt/g/glasba/hiša/classics/20. Benny Benassi,...",techno,0.999986,"{'ambient': 3.225889404234218e-15, 'drum_and_b..."
5,/mnt/g/glasba/hiša/classics/305. Shakedown - A...,techno,0.701145,"{'ambient': 1.1366687755962057e-12, 'drum_and_..."
6,/mnt/g/glasba/hiša/classics/A-Trak - Bubble Gu...,house,0.601523,"{'ambient': 2.1735168331815657e-07, 'drum_and_..."
7,/mnt/g/glasba/hiša/classics/barbara tucker - b...,house,0.999607,"{'ambient': 7.305317684114243e-09, 'drum_and_b..."
8,/mnt/g/glasba/hiša/classics/Cajmere - Brighter...,house,0.999987,"{'ambient': 6.222221089063895e-12, 'drum_and_b..."
9,/mnt/g/glasba/hiša/classics/Cassius - Feeling ...,house,0.992378,"{'ambient': 1.2867519272319328e-09, 'drum_and_..."


In [15]:
df.to_csv('predictions.csv', index=False)