# CNN-powered MP3 to MIDI for drums

In [2]:
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import os
import librosa
import subprocess as sp
import torchaudio
import torchvision
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from ipyfilechooser import FileChooser
import pretty_midi
from scipy.io.wavfile import write
import numpy as np

# Drum classes handled by the notebook.  The list length sets the size of the
# network's output layer, and the names double as DataFrame column names that
# receive the model's outputs positionally, so the order matters.
drum_labels = ['kick', 'snare', 'hihat', 'tom', 'crash', 'ride']

## Select song

After selecting a valid .mp3 file, run all cells below this one to extract drum notation.

In [3]:
# Interactive file-picker widget; later cells read the chosen song from
# `fc.selected` (path) and `fc.selected_filename` (file name).
fc = FileChooser()
display(fc)

FileChooser(path='C:\Work\Final Project\HeartsOnFire', filename='', title='', show_hidden=False, select_desc='…

## Define model architecture and data loader

### Model A

In [4]:
class DrumCNN(nn.Module):
    """Convolutional classifier that emits one raw score per drum label.

    Four conv -> batch-norm -> ReLU -> max-pool stages feed a three-layer
    fully connected head.  The first linear layer is sized for a 128-channel
    8x8 feature map, and the last layer has ``len(drum_labels)`` outputs with
    no activation applied.

    The submodule attribute names (``conv1``, ``bn1``, ..., ``fc3``) are the
    keys of the saved ``state_dict`` and therefore must stay as they are.
    """

    def __init__(self):
        super().__init__()

        # Convolutional stages, registered in order conv/bn/relu/pool with
        # channel widths 2 -> 16 -> 32 -> 64 -> 128.
        widths = (2, 16, 32, 64, 128)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}",
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))
            setattr(self, f"relu{stage}", nn.ReLU())
            setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Fully connected head.
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)
        self.relu5 = nn.ReLU()
        self.dropout1 = nn.Dropout(p=0.5)

        self.fc2 = nn.Linear(512, 256)
        self.bn6 = nn.BatchNorm1d(256)
        self.relu6 = nn.ReLU()
        self.dropout2 = nn.Dropout(p=0.4)

        self.fc3 = nn.Linear(256, len(drum_labels))

    def forward(self, x):
        """Return the raw per-label scores for a batch of input tensors."""
        # Run the four convolutional stages in sequence.
        for stage in range(1, 5):
            for layer in ("conv", "bn", "relu", "pool"):
                x = getattr(self, f"{layer}{stage}")(x)

        # Collapse everything but the batch axis for the linear layers.
        x = x.view(x.size(0), -1)

        x = self.dropout1(self.relu5(self.bn5(self.fc1(x))))
        x = self.dropout2(self.relu6(self.bn6(self.fc2(x))))
        return self.fc3(x)

### Dataset

In [5]:
class DrumDataset(Dataset):
    """Dataset yielding one fixed-length audio window per onset row.

    Args:
        df: DataFrame with an ``onset_time`` column (seconds) plus one column
            per drum label.
        audio: ``(waveform, sample_rate)`` pair.  The waveform is sliced along
            its first dimension, so a 1-D mono tensor is expected.
        transform: Callable applied to every extracted window; its result is
            returned as the sample's features.
        window_size: Nominal window length in samples, centred on the onset.
        label_columns: Names of the label columns to read from ``df``.  When
            omitted, the module-level ``drum_labels`` list is used.
    """

    def __init__(self, df, audio, transform, window_size=8192, label_columns=None):
        self.df = df
        self.window_size = window_size
        self.transform = transform
        self.audio = audio
        # Stored as a list so that ``row[...]`` selects several columns
        # instead of treating a tuple as a single key.
        self.label_columns = None if label_columns is None else list(label_columns)

    def __len__(self):
        return len(self.df)

    def _extract_window(self, audio, center):
        """Return the samples around ``center``, zero-padded to a fixed length.

        A plain slice would wrap around for onsets closer than half a window
        to the start (negative start index) and come back short for onsets
        near the end; both cases would produce windows whose length differs
        from the rest of the batch.
        """
        half = self.window_size // 2
        start, end = center - half, center + half
        total = audio.shape[0]
        if start >= 0 and end <= total:
            # Common case: the window lies fully inside the recording.
            return audio[start:end]

        window = audio.new_zeros((end - start,) + tuple(audio.shape[1:]))
        lo, hi = max(start, 0), min(end, total)
        if hi > lo:
            window[lo - start:hi - start] = audio[lo:hi]
        return window

    def __getitem__(self, idx):
        """Return ``(transform(window), labels)`` for the onset at row ``idx``."""
        row = self.df.iloc[idx]
        onset_time = row['onset_time']
        columns = drum_labels if self.label_columns is None else self.label_columns
        labels = row[columns].astype(int).values.flatten()
        labels = torch.tensor(labels).float()

        audio = self.audio[0]
        sr = self.audio[1]

        onset_window = self._extract_window(audio, int(onset_time * sr))
        spec = self.transform(onset_window)
        return spec, labels
    
# Device used for spectrogram computation.  Falls back to the CPU so the
# notebook also runs on machines without CUDA.
_spec_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the spectrogram modules once instead of once per sample.  No sample
# rate is passed, so torchaudio's default is used; the parameters are kept
# exactly as before because the loaded checkpoint depends on these features.
_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    n_fft=1024,
    hop_length=64,
    n_mels=128
).to(_spec_device)
_amplitude_to_db = torchaudio.transforms.AmplitudeToDB().to(_spec_device)


def _two_channel_spectrogram(waveform):
    """Return a CPU tensor stacking [dB-scaled mel, raw mel] along dim 0."""
    mel = _mel_spectrogram(waveform.to(_spec_device))
    # NOTE: an MFCC channel was tried here previously and left disabled.
    return torch.stack([_amplitude_to_db(mel), mel], dim=0).to("cpu")


transforms = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(_two_channel_spectrogram),
])

## Load the model

In [6]:
# Run on the GPU when one is present, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a CUDA session can also be opened on a CPU-only machine.
checkpoint = torch.load('models/HeartsOnFire-v.1.0.3_nfft1024_89.76.pth', map_location=device)
model = DrumCNN().to(device)
# The weights live under the 'model' key of the checkpoint dictionary.
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

## Preprocessing

In [7]:
# Separate the drum stem with Demucs.  The separation must run on every
# machine, not only when CUDA is present, because the next cell loads the
# file it writes.  (Demucs chooses its own device unless told otherwise.)
if fc.selected is None:
    raise ValueError("No file selected: choose an .mp3 in the file chooser above first.")

command = ["demucs", "--mp3", "--mp3-bitrate", "320", "--two-stems=drums", fc.selected]
backend = "GPU" if torch.cuda.is_available() else "CPU"
print("Generating splits for \""+fc.selected_filename+"\" with "+backend+"...")
# check=True raises here if the separation fails, instead of letting a later
# cell fail on a missing drums.mp3.
sp.run(command, check=True)

# Demucs writes its stems into a folder named after the track without its
# extension.
demucs_path = os.path.join("separated/htdemucs", os.path.splitext(fc.selected_filename)[0], "drums.mp3")

Generating splits for "ATLUS - The Fog.mp3" with GPU...


In [8]:
audio, sr = torchaudio.load(demucs_path, format="mp3")
# torchaudio returns (channels, samples).  Averaging over the channel axis
# downmixes any channel count and always leaves a 1-D waveform; a mono file
# would otherwise stay 2-D and break the per-sample slicing done later.
audio = torch.mean(audio, dim=0, keepdim=False)

y = audio.numpy()
onset_env = librosa.onset.onset_strength(y=y, sr=sr, n_fft=1024)
tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
# Depending on the librosa version the tempo may come back as a one-element
# array; normalise it to a plain float for later use.
tempo = float(np.atleast_1d(tempo)[0])
onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr)
# One row per detected onset; the label columns start out False and are
# filled in by the model further down.
onset_times = pd.DataFrame(librosa.frames_to_time(onset_frames, sr=sr), columns=['onset_time'])
onset_times[drum_labels] = False
onset_times

Unnamed: 0,onset_time,kick,snare,hihat,tom,crash,ride
0,0.592109,False,False,False,False,False,False
1,1.335147,False,False,False,False,False,False
2,1.718277,False,False,False,False,False,False
3,1.880816,False,False,False,False,False,False
4,2.078186,False,False,False,False,False,False
...,...,...,...,...,...,...,...
1069,230.968889,False,False,False,False,False,False
1070,231.154649,False,False,False,False,False,False
1071,231.340408,False,False,False,False,False,False
1072,231.711927,False,False,False,False,False,False


### Load the data

In [9]:
# One sample per detected onset.  Shuffling stays off so the predictions come
# back in the same order as the rows of `onset_times`.
pred_dataset = DrumDataset(df=onset_times, audio=(audio, sr), transform=transforms)
pred_loader = DataLoader(dataset=pred_dataset, batch_size=16, shuffle=False)

## Run the model to map onsets to drums

In [10]:
# Put the network in evaluation mode before labelling.
model.eval()
predicted_labels = []
progress = tqdm(pred_loader, total=len(pred_loader), unit='batch', desc="Labeling")
with torch.no_grad():
    for batch in progress:
        # Only the features are needed; the placeholder labels are ignored.
        inputs = batch[0].to(device)
        outputs = model(inputs)
        # A label counts as present when its raw score is positive.
        predicted_labels.extend(outputs.gt(0.0).tolist())
# Write the per-onset decisions back into the label columns, row by row.
onset_times[drum_labels] = predicted_labels
onset_times

Labeling: 100%|██████████| 68/68 [00:10<00:00,  6.47batch/s]


Unnamed: 0,onset_time,kick,snare,hihat,tom,crash,ride
0,0.592109,True,False,False,False,True,False
1,1.335147,False,True,True,False,False,False
2,1.718277,False,False,True,False,False,False
3,1.880816,True,False,False,False,False,False
4,2.078186,False,False,True,False,False,False
...,...,...,...,...,...,...,...
1069,230.968889,False,False,True,False,False,False
1070,231.154649,True,False,False,False,False,False
1071,231.340408,False,True,True,False,False,False
1072,231.711927,False,False,False,False,False,False


## MIDI generation

Currently this section builds the drum track as an in-memory MIDI object and renders it straight to a `.wav` file using `fluidsynth`; no MIDI file is written to disk yet. It will be altered in the future.

In [11]:
midi = pretty_midi.PrettyMIDI(initial_tempo=tempo)

drum_program = 0  # MIDI program number for standard drum kit
drum_instrument = pretty_midi.Instrument(program=drum_program, is_drum=True)

# MIDI note number emitted for each drum label.  The insertion order fixes
# the order in which simultaneous hits are appended.
drum_pitches = {
    'kick': 36,   # kick drum
    'snare': 38,  # snare drum
    'hihat': 42,  # hi-hat
    'tom': 48,    # tom
    'crash': 49,  # crash cymbal
    'ride': 51,   # ride cymbal
}
note_velocity = 100  # velocity (volume) shared by every note
note_length = 0.1    # seconds between a note's start and end

# Emit one short note per active label on every onset row.
for row_index, row in onset_times.iterrows():
    onset_time = row['onset_time']
    drum_instrument.notes.extend(
        pretty_midi.Note(
            velocity=note_velocity,
            pitch=pitch,
            start=onset_time,
            end=onset_time + note_length,
        )
        for label, pitch in drum_pitches.items()
        if row[label]
    )

# Attach the drum track to the MIDI object.
midi.instruments.append(drum_instrument)

# Render the track with FluidSynth and store the audio as a float32 WAV file
# next to the separated stems.
synth = midi.fluidsynth(sf2_path='JV_1080_Drums.sf2')
output_file = f'{fc.selected_filename[:-4]}_drums.wav'
output_dir = os.path.join("separated", "htdemucs", fc.selected_filename[:-4])
write(os.path.join(output_dir, output_file), 44100, synth.astype(np.float32))

In [12]:
# Save the labelled onset table alongside the separated stems.
onsets_csv_path = os.path.join("separated", "htdemucs", fc.selected_filename[:-4], "onsets.csv")
onset_times.to_csv(onsets_csv_path, index=False)