In [1]:
import soundfile
import librosa
import os
import pandas as pd
from IPython.display import Audio
import matplotlib.pyplot as plt
import librosa.display as display
import numpy as np
import json
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import time
from random import randrange

import torchcrepe

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Root of the downloaded NSynth dataset. Set the NSYNTH_DATA_DIR environment
# variable to point the notebook at another location without editing this cell;
# the fallback is the original author's local path.
base_data_dir = os.environ.get('NSYNTH_DATA_DIR', '/home/purnima/appdir/Github/DATA/NSynth/')
train_data_dir = os.path.join(base_data_dir, 'nsynth-train', 'audio')
test_data_dir = os.path.join(base_data_dir, 'nsynth-test', 'audio')
validate_data_dir = os.path.join(base_data_dir, 'nsynth-valid', 'audio')

# The label files live next to the audio folders, so derive this from the same
# root instead of repeating the absolute path.
labels_dir = os.path.normpath(base_data_dir)
labels_file_name = 'examples-subset-full-acoustic-3000.json'

# NOTE: despite the "_dir" suffix, these three are paths to JSON metadata files.
labels_train_dir = os.path.join(labels_dir, 'nsynth-train', labels_file_name)
labels_test_dir = os.path.join(labels_dir, 'nsynth-test', labels_file_name)
labels_validate_dir = os.path.join(labels_dir, 'nsynth-valid', labels_file_name)

sample_rate = 16000    # Hz; passed to the datasets as the loading sample rate
sample_length = 2048   # number of samples in each excerpt returned by the dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
# This class differs from the one in the code base. It supplies the dataset for pYIN and CREPE,
# which need only the raw audio data and not the STFT channels.
class NSynthDataSet_RawAudio(Dataset):
    """Raw-audio excerpts of NSynth guitar notes, labelled with their pitch.

    The metadata JSON is filtered to rows whose ``instrument_family_str`` is
    ``'guitar'`` and whose ``pitch`` lies in ``[lower_pitch_limit,
    upper_pitch_limit)``. Every remaining note is then listed ``NUM_PARTS``
    times, once per excerpt ("part"). Rows for parts 2..N get ``-<part>``
    appended to their index label so the labels stay unique.

    Args:
        meta_data_file: Path to the NSynth metadata JSON file.
        audio_dir: Directory containing the ``<note_str>.wav`` files.
        lower_pitch_limit: Smallest ``pitch`` value to keep (inclusive).
        upper_pitch_limit: Upper ``pitch`` bound (exclusive).
        sr: Sample rate passed to ``librosa.load``; also used to convert the
            excerpt offsets from seconds to samples.
        segment_length: Number of samples per excerpt. When ``None`` (the
            default) the notebook-level ``sample_length`` is used at access
            time.
    """

    # Excerpt k (1-based) starts FIRST_OFFSET_S + (k - 1) * PART_SPACING_S
    # seconds into the recording.
    NUM_PARTS = 4
    FIRST_OFFSET_S = 0.25
    PART_SPACING_S = 0.5

    def __init__(self, meta_data_file, audio_dir, lower_pitch_limit, upper_pitch_limit, sr=16000,
                 segment_length=None):
        self.meta_data_file = meta_data_file
        self.audio_dir = audio_dir
        self.sr = sr
        self.segment_length = segment_length

        with open(meta_data_file) as f:
            params = json.load(f)

        # The JSON maps note name -> metadata dict; transpose so that each note
        # becomes one row and each metadata field one column.
        notes = pd.DataFrame.from_dict(params).transpose()
        notes = notes[notes['instrument_family_str'] == 'guitar']
        notes = notes[(notes['pitch'] >= lower_pitch_limit)
                      & (notes['pitch'] < upper_pitch_limit)]

        # One deep copy of the filtered table per part, concatenated in part
        # order (all part-1 rows first, then part 2, ...).
        per_part = []
        for part in range(1, self.NUM_PARTS + 1):
            part_df = notes.copy(deep=True)
            part_df['part'] = part
            if part > 1:
                part_df.index = part_df.index + '-{}'.format(part)
            per_part.append(part_df)
        self.nsynth_meta_df = pd.concat(per_part)

    def __len__(self):
        """Return the number of excerpts (filtered notes times ``NUM_PARTS``)."""
        return self.nsynth_meta_df.shape[0]

    def _segment_start(self, part):
        """Return the first sample index of excerpt number ``part`` (1-based)."""
        offset_s = self.FIRST_OFFSET_S + (part - 1) * self.PART_SPACING_S
        # Scale by the dataset's own sample rate so that the excerpts stay at
        # the same time positions whatever rate the audio is loaded at.
        return int(self.sr * offset_s)

    def __getitem__(self, idx):
        """Load one excerpt.

        Returns:
            ``(audio_data, audio_pitch)``: the slice of the loaded waveform for
            this row's part, and the row's ``pitch`` value.
        """
        if torch.is_tensor(idx):  # In case we get [0] instead of 0
            idx = idx.tolist()

        row = self.nsynth_meta_df.iloc[idx]
        audio_path = os.path.join(self.audio_dir, row.note_str + '.wav')
        audio_data, _ = librosa.load(audio_path, sr=self.sr)

        # Fall back to the notebook-level setting when no explicit length was given.
        length = self.segment_length if self.segment_length is not None else sample_length
        start_location = self._segment_start(row.part)

        return audio_data[start_location:start_location + length], row.pitch

In [4]:
# Both validation sets read the same metadata file and audio folder and are
# served one excerpt at a time, in order; only the pitch range differs.
_validate_ds_kwargs = dict(meta_data_file=labels_validate_dir, audio_dir=validate_data_dir, sr=sample_rate)
_validate_loader_kwargs = dict(batch_size=1, shuffle=False)

# Dataset for MIDI 21 to 40 (the upper limit is exclusive)
lower_freq_validate_ds = NSynthDataSet_RawAudio(lower_pitch_limit=21, upper_pitch_limit=41, **_validate_ds_kwargs)
lower_freq_validate_loader = torch.utils.data.DataLoader(lower_freq_validate_ds, **_validate_loader_kwargs)

# Dataset for MIDI 41 to 80 (the upper limit is exclusive)
upper_freq_validate_ds = NSynthDataSet_RawAudio(lower_pitch_limit=41, upper_pitch_limit=81, **_validate_ds_kwargs)
upper_freq_validate_loader = torch.utils.data.DataLoader(upper_freq_validate_ds, **_validate_loader_kwargs)


In [5]:
def validate_pyin(dl, sr=16000, tolerance=0.5):
    """Score librosa's pYIN pitch tracker on a data loader and print the result.

    For every clip the per-frame f0 estimates are reduced to their maximum
    (unvoiced frames count as 0 Hz), converted to a MIDI note number, and
    counted as correct when strictly within ``tolerance`` of the target pitch.

    Args:
        dl: Loader yielding ``(audio, pitch)`` batches and exposing the
            underlying dataset as ``dl.dataset``. Assumes ``audio`` has shape
            (batch, samples), as produced by the loaders in this notebook.
        sr: Sample rate of the clips in Hz.
        tolerance: Allowed deviation from the target, in MIDI note numbers.

    Prints the number of correct clips and the percentage relative to
    ``len(dl.dataset)``; returns nothing.
    """
    num_correct_pyin = 0
    num_total = len(dl.dataset)

    for data, target in dl:
        # Score every clip in the batch so the count stays consistent with
        # len(dl.dataset) for any batch size.
        for clip, target_pitch in zip(data, target):
            samples = clip.float().numpy()
            target_pitch = float(target_pitch)

            pyin_f0, _, _ = librosa.pyin(samples, fmin=librosa.note_to_hz('A0'), fmax=librosa.note_to_hz('A5'),
                                         sr=sr, frame_length=1000, hop_length=1000)
            # Unvoiced frames come back as NaN; treat them as 0 Hz before
            # taking the highest estimate over all frames.
            pyin_f0 = np.max(np.nan_to_num(pyin_f0))
            if pyin_f0 <= 0:
                # No voiced frame: 0 Hz has no finite MIDI number, so the clip
                # can never match the target.
                continue

            prediction = librosa.hz_to_midi(pyin_f0)
            if target_pitch - tolerance < prediction < target_pitch + tolerance:
                num_correct_pyin += 1

    # An empty dataset would otherwise raise ZeroDivisionError here.
    percent_correct = num_correct_pyin * 100 / num_total if num_total else 0.0
    print('Total correct = ', num_correct_pyin, ' i.e.{:.2f}%'.format(percent_correct))

In [6]:
# pYIN algo takes time to execute. Around 5-10 mins.
# Low pitch range first, then the high range.
for _pyin_loader in (lower_freq_validate_loader, upper_freq_validate_loader):
    validate_pyin(_pyin_loader)

Total correct =  749  i.e.55.24%
Total correct =  2517  i.e.93.78%


In [7]:
def validate_crepe(dl, sr=16000, tolerance=0.5):
    """Score the torchcrepe pitch tracker on a data loader and print the result.

    For every clip the f0 values returned by ``torchcrepe.predict`` are
    averaged, converted to a MIDI note number, and counted as correct when
    strictly within ``tolerance`` of the target pitch. Inference runs on the
    notebook-level ``device``.

    Args:
        dl: Loader yielding ``(audio, pitch)`` batches and exposing the
            underlying dataset as ``dl.dataset``. Assumes ``audio`` has shape
            (batch, samples), as produced by the loaders in this notebook.
        sr: Sample rate of the clips in Hz.
        tolerance: Allowed deviation from the target, in MIDI note numbers.

    Prints the number of correct clips and the percentage relative to
    ``len(dl.dataset)``; returns nothing.
    """
    num_correct_crepe = 0
    num_total = len(dl.dataset)

    with torch.no_grad():
        for data, target in dl:
            # Score every clip in the batch so the count stays consistent with
            # len(dl.dataset) for any batch size.
            for clip, target_pitch in zip(data, target):
                target_pitch = float(target_pitch)

                # predict() is given a 2-D tensor with a leading dimension of 1.
                data_tensor = clip.float().reshape(1, -1)
                crepe_f0 = torchcrepe.predict(data_tensor,
                                              sr,      # sample rate
                                              2048,    # hop length
                                              40,      # fmin (Hz)
                                              3400,    # fmax (Hz)
                                              'full',  # model capacity
                                              batch_size=1,
                                              device=device)
                # Move to host memory before handing the values to numpy.
                crepe_f0 = np.mean(crepe_f0.cpu().numpy())

                prediction = librosa.hz_to_midi(crepe_f0)
                if target_pitch - tolerance < prediction < target_pitch + tolerance:
                    num_correct_crepe += 1

    # An empty dataset would otherwise raise ZeroDivisionError here.
    percent_correct = num_correct_crepe * 100 / num_total if num_total else 0.0
    print('Total correct = ', num_correct_crepe, ' i.e.{:.2f}%'.format(percent_correct))

In [None]:
# Low pitch range first, then the high range.
for _crepe_loader in (lower_freq_validate_loader, upper_freq_validate_loader):
    validate_crepe(_crepe_loader)

Total correct =  240  i.e.17.70%
