In [1]:
import torch 
from torch.utils.data import Dataset
import pandas as pd
import os
import pickle
import sys
sys.path.append('../')
import numpy as np

from onsets import input_features_extractor

In [5]:
# Metadata filters: column name -> list of allowed values, or None for "no constraint".
filters = {
    "drummer" : None,
    "session" : None,
    "loop_id" : None, 
    "master_id" : None,
    "style_primary" : None,
    "bpm"  : [120],
    "beat_type" :["beat"],
    "time_signature" : ["4-4"],
    "full_midi_filename"  : None,
    "full_audio_filename": None
}

sr = 44100             # audio sample rate (Hz); also used when rendering dataset audio
FRAME_INTERVAL = 0.01  # seconds between analysis frames
hop_length = int(round(FRAME_INTERVAL * sr))  # 441 samples at 44.1 kHz

# Keyword arguments forwarded verbatim to input_features_extractor.
input_features_parameters = {
    "sr" : sr,
    "n_fft" : 1024,
    "win_length" : 1024,
    "hop_length" : hop_length,
    "n_bins_per_octave" : 16,
    "n_octaves" : 9,
    "f_min" : 40,
    "mean_filter_size" : 22,
    "n_bars" : 2, 
    "time_signature_numerator" : 4, 
    "time_signature_denominator" : 4, 
    "beat_division_factors" : [4],
    "qpm" : 120
}

# Sanity-check the feature extractor on one file.
# (The original cell called an undefined `get_input_features`; the imported
# extractor is `input_features_extractor`.)
# NOTE(review): 2 bars x 4 beats x 4 divisions would suggest 32 rows, but the
# extractor returned 34 here — TODO confirm the expected length.
sample_features = input_features_extractor("./misc/temp.wav", **input_features_parameters)
sample_features.shape

torch.Size([34, 16])


In [6]:
def check_if_passes_filters(obj, filters):
    """Return True if a metadata row satisfies every active filter.

    Args:
        obj: a metadata row exposing ``to_dict()`` (e.g. the ``pandas.Series``
            returned by ``DataFrame.loc[ix]``).
        filters: mapping of field name -> collection of allowed values, or
            ``None`` to leave that field unconstrained.

    Returns:
        False as soon as one constrained field's value is not among its
        allowed values; True otherwise (including when no filter is active).

    Raises:
        KeyError: if a constrained field is missing from ``obj``.
    """
    active = {key: allowed for key, allowed in filters.items() if allowed is not None}
    if not active:
        # Nothing to check; avoid converting the row at all.
        return True
    row = obj.to_dict()  # convert once rather than once per filter key
    for key, allowed in active.items():
        if row[key] not in allowed:
            return False
    return True

class GrooveMidiDataset(Dataset):
    """Torch ``Dataset`` of fixed-length HVO drum sequences from Groove MIDI.

    On construction, each sequence in the pickled subset is kept only if it
      * has exactly one ``time_signatures`` entry,
      * is not all zeros, and
      * its metadata row passes ``check_if_passes_filters``.
    Kept sequences are annotated with metadata fields, zero-padded/truncated
    to ``max_len`` rows, rendered to a ``.wav`` with ``sf_path``, and run
    through ``input_features_extractor``.

    Attributes:
        hvo_sequences: kept HVO sequence objects (``.hvo`` has ``max_len`` rows).
        processed_inputs: extracted input features, one per kept sequence.
        processed_outputs: not populated yet (TODO).
    """

    # Metadata columns copied onto each kept sequence as attributes.
    _METADATA_FIELDS = ("drummer", "session", "master_id", "style_primary",
                        "style_secondary", "beat_type", "loop_id")

    def __init__(self,
                 source_path = '../processed_dataset/hvo_0.2.0/Processed_On_09_05_2021_at_23_06_hrs',
                 subset = 'GrooveMIDI_processed_test',
                 metadata_csv_filename='metadata.csv',
                 hvo_pickle_filename='hvo_sequence_data.obj',
                 filters=filters,
                 input_features_parameters=input_features_parameters,
                 synthesized_audio_path = '../synthesized_audio/hvo_0.2.0/Processed_On_09_05_2021_at_23_06_hrs',
                 sf_path = "../hvo_sequence/soundfonts/Standard_Drum_Kit.sf2",
                 max_len= 32):
        # NOTE(review): pickle.load can execute arbitrary code — only load trusted dataset files.
        with open(os.path.join(source_path, subset, hvo_pickle_filename), 'rb') as train_file:
            train_set = pickle.load(train_file)
        metadata = pd.read_csv(os.path.join(source_path, subset, metadata_csv_filename))

        # Render audio at the rate the feature extractor is told to assume.
        audio_sr = input_features_parameters.get("sr", sr)

        self.hvo_sequences = []
        self.processed_inputs = []
        self.processed_outputs = []  # TODO: output targets are not computed yet

        for ix, hvo_seq in enumerate(train_set):
            if len(hvo_seq.time_signatures) != 1:  # skip sequences with a time-signature change
                continue
            if not np.any(hvo_seq.hvo):  # skip silent (all-zero) patterns
                continue
            # NOTE(review): assumes pickle order matches metadata.csv row order.
            row = metadata.loc[ix]
            if not check_if_passes_filters(row, filters):
                continue

            for field in self._METADATA_FIELDS:
                setattr(hvo_seq, field, row.at[field])

            hvo_seq.hvo = self._fit_length(hvo_seq.hvo, max_len)
            self.hvo_sequences.append(hvo_seq)

            filename = self._render_audio(hvo_seq, synthesized_audio_path, subset, audio_sr, sf_path)
            self.processed_inputs.append(
                input_features_extractor(filename, **input_features_parameters))

    @staticmethod
    def _fit_length(hvo, max_len):
        """Zero-pad at the end of axis 0, or truncate, so a 2-D ``hvo`` has exactly ``max_len`` rows."""
        pad_count = max(max_len - hvo.shape[0], 0)
        padded = np.pad(hvo, ((0, pad_count), (0, 0)), 'constant')
        return padded[:max_len, :]

    @staticmethod
    def _render_audio(hvo_seq, synthesized_audio_path, subset, audio_sr, sf_path):
        """Synthesize ``hvo_seq`` to <synthesized_audio_path>/<subset>/<master_id>/<name>.wav and return the path."""
        audio_dir = os.path.join(synthesized_audio_path, subset, hvo_seq.master_id)
        os.makedirs(audio_dir, exist_ok=True)
        # Keep only the last '/'-separated part of loop_id and replace ':' with '_'.
        wav_name = hvo_seq.loop_id.split('/')[-1].replace(':', '_') + '.wav'
        filename = os.path.join(audio_dir, wav_name)
        hvo_seq.save_audio(filename, audio_sr, sf_path)
        return filename

    def __len__(self):
        """Number of sequences that passed all filters."""
        return len(self.hvo_sequences)

    def __getitem__(self, idx):
        """Return ``(hvo_array, idx)`` for the ``idx``-th kept sequence.

        NOTE(review): processed_inputs are not returned here yet.
        """
        return self.hvo_sequences[idx].hvo, idx




In [8]:
# Build the test-subset dataset with the notebook's default paths, filters and
# feature parameters; this renders audio and extracts features for every kept loop.
gmd = GrooveMidiDataset(subset='GrooveMIDI_processed_test')

In [9]:
# Number of loops that survived the time-signature, silence and metadata filters.
len(gmd)

30