In [3]:
import numpy as np
from Music_Style_Transfer_master.project.midi_handler import score2midi, midi2score
from Music_Style_Transfer_master.project.utils import add_beat, padding, load_model, get_representation, dataset_import
from Music_Style_Transfer_master.project.test import style_transfer
from Music_Style_Transfer_master.project.model_transformer_relative import _2way_transformer_wavenet_absolute_pitch, _2way_transformer_wavenet_relative_pitch

import tqdm
import pypianoroll as pr
import tensorflow as tf
import os

import matplotlib.pyplot as plt
#from keras.models import load_model
%matplotlib inline

# MIDI pitch windows (half-open ranges):
#   33:97  -> A1..C7
#   21:109 -> A0..C8  (full 88-key piano; the default below)
#   21:105 -> A0..G#7 (7 octaves, 84 pitches)
def midi2score(path, beat_resolution, pitch_range=np.arange(21, 109)):
    """Read a MIDI file and return its merged piano roll restricted to ``pitch_range``.

    Every track is binarized and its active notes set to 1 before the tracks are
    merged into a single time-by-pitch array. NOTE(review): with pypianoroll's
    default merge mode, notes shared by several tracks may sum above 1; callers
    in this notebook threshold with ``> 0``.

    Parameters
    ----------
    path : str
        MIDI file to load.
    beat_resolution : int
        pypianoroll time steps per beat.
    pitch_range : array-like of int
        MIDI pitch columns to keep (default: MIDI 21..108).

    Returns
    -------
    numpy.ndarray
        The merged piano roll with only the ``pitch_range`` columns.
    """
    multitrack = pr.Multitrack(filename=path, beat_resolution=beat_resolution)
    multitrack.binarize()
    multitrack.assign_constant(1)
    merged_roll = multitrack.get_merged_pianoroll()
    return merged_roll[:, pitch_range]

def load_model(path_model, len_context=None):
    """Build the absolute-pitch two-way transformer/WaveNet model and load its weights.

    This definition shadows the ``load_model`` imported from ``project.utils``.

    Parameters
    ----------
    path_model : str
        Weights file passed to ``model.load_weights`` (an ``.hdf5`` in this notebook).
    len_context : int, optional
        Context length (time steps) the model is built for. When omitted it
        falls back to ``8 * 4 * beat_resolution`` using the notebook-level
        ``beat_resolution`` global, which was previously the only option.

    Returns
    -------
    The model built by ``_2way_transformer_wavenet_absolute_pitch`` with the
    weights loaded. NOTE(review): the hyper-parameters below presumably must
    match the ones used to train the checkpoint.
    """
    if len_context is None:
        # Backward-compatible default; requires the config cell's global to exist.
        len_context = 8 * 4 * beat_resolution
    model = _2way_transformer_wavenet_absolute_pitch(len_context=len_context,
                                                     n_octave=7,
                                                     size_embedding=84,
                                                     n_transf_layers=4,
                                                     n_transf_downsample=0,
                                                     n_conv_layers=7,
                                                     context_layers=1)
    model.load_weights(path_model)

    return model

def time_index_chosing(_range,
                       interval,
                       random=True,
                       time_list=None,
                       n_steps=200):
    """Pick time steps spaced ``interval + 1`` apart so they can be updated in parallel.

    A base step is drawn uniformly from ``[0, _range)`` when ``random`` is True,
    otherwise uniformly from ``time_list``. The grid
    ``base + k * (interval + 1)`` for ``k`` in ``[-n_steps, n_steps)`` is then
    returned, keeping only values inside ``[0, _range)``.

    Parameters
    ----------
    _range : int
        Exclusive upper bound for the returned time steps.
    interval : int
        Minimum gap; consecutive returned steps differ by ``interval + 1``.
    random : bool
        Draw the base from ``[0, _range)`` (True) or from ``time_list`` (False).
    time_list : sequence of int, optional
        Candidate base steps; required when ``random`` is False.
    n_steps : int
        Number of grid points generated on each side of the base before
        clipping (previously hard-coded to 200).

    Returns
    -------
    numpy.ndarray of int
        Sorted time steps within ``[0, _range)``.

    Raises
    ------
    ValueError
        If ``random`` is False and ``time_list`` is None or empty.
    """
    step = interval + 1
    if random:
        time_index_base = np.random.randint(_range)
    else:
        if time_list is None or len(time_list) == 0:
            raise ValueError("time_list must be a non-empty sequence when random=False")
        time_index_base = time_list[np.random.randint(len(time_list))]

    candidates = time_index_base + np.arange(-step * n_steps, step * n_steps, step)
    return candidates[(candidates >= 0) & (candidates < _range)]


def midi_index_chosing(_len, pitch_range):
    """Draw ``_len`` pitch column indices uniformly from ``[0, pitch_range)``."""
    lowest_pitch = 0
    return np.random.randint(lowest_pitch, pitch_range, size=_len)


def generation(model, score, meta, time_indexes, midi_indexes, len_context):
    """Run one batched model prediction for a set of (time step, pitch) queries.

    For query ``j`` the center row is ``time_indexes[j] + len_context``; the
    indexes are relative to the unpadded score. The model receives, per query:

    * the ``len_context`` rows before the center row (score binarized with
      ``> 0``, plus raw metadata);
    * the center row, in which column ``midi_indexes[j]`` and every column to
      its right are replaced by 2, plus its metadata;
    * the ``len_context`` rows after the center row.

    Parameters
    ----------
    model : object with a ``predict`` method
        Called once with ``[left_features, left_metas, central_features,
        central_metas, right_features, right_metas]``.
    score, meta : 2-D arrays indexed as ``[row, column]``
        Padded score and its row-aligned metadata.
    time_indexes : sequence of int
        Query time steps.
    midi_indexes : sequence of int
        One pitch column per query.
    len_context : int
        Number of context rows on each side.

    Returns
    -------
    Whatever ``model.predict`` returns.
    """
    centers = np.array(time_indexes) + len_context
    # Row indices of every context window, shape (n_queries, len_context).
    before = centers[:, np.newaxis] + np.arange(-len_context, 0)
    after = centers[:, np.newaxis] + np.arange(1, len_context + 1)

    def as_binary(values):
        return (np.array(values) > 0).astype(int)

    n_queries = len(centers)
    center_notes = as_binary(score[centers, :])
    for query, first_hidden in enumerate(midi_indexes):
        # 2 marks "unknown": the queried column and every column after it.
        center_notes[query, first_hidden:] = 2

    inputs = [
        as_binary(score[before, :]),
        np.array(meta[before, :]),
        np.reshape(center_notes, (n_queries, 1, -1)),
        np.reshape(np.array(meta[centers, :]), (n_queries, 1, -1)),
        as_binary(score[after, :]),
        np.array(meta[after, :]),
    ]
    return model.predict(inputs)

def _threshold_update(new_score, predictions, time_indexes, midi_indexes,
                      len_context, threshold, editable=None):
    """Binarize a batch of model predictions and write them into ``new_score`` in place.

    Cell ``(time_indexes[j] + len_context, midi_indexes[j])`` becomes 0 when
    ``predictions[j] < threshold`` and 1 otherwise, so NaN maps to 1, like the
    previous per-element ``if``. If ``editable`` is given (one boolean per
    prediction), only cells where it is True are written.
    """
    rows = np.asarray(time_indexes) + len_context
    cols = np.asarray(midi_indexes)
    # Each prediction must be a single value; reshape raises rather than
    # silently mis-assigning if the model's output layout differs.
    values = np.where(np.asarray(predictions).reshape(len(rows)) < threshold, 0, 1)
    if editable is not None:
        rows, cols, values = rows[editable], cols[editable], values[editable]
    new_score[rows, cols] = values


def style_transfer_new(score, meta, score_melody, model, len_context, pitch_range,
                       iter_num=25, threshold=0.5):
    """Re-harmonize ``score`` around a fixed melody by iterative model resampling.

    Each of the ``iter_num`` sweeps visits every melody time step once. The
    steps are visited in groups from ``time_index_chosing``, spaced
    ``len_context + 1`` apart. Thresholded model predictions are written into
    ``new_score``.

    The update mode depends on ``alpha``:

    * ``alpha >= 0`` (block updates): for each group, all ``pitch_range``
      pitches are predicted in one batch from a conditioning copy ``score``.
      That copy is refreshed from ``new_score`` whenever more than an
      ``alpha`` fraction of ``original_len`` steps has been updated since the
      last refresh. ``alpha`` decays linearly from 0.6 to 0 as the update
      count reaches ``0.6 * iter_num * original_len`` (about 60% of the sweeps).
    * ``alpha == -1`` (sequential updates): ``score`` aliases ``new_score``,
      and pitches are predicted one at a time, so each prediction sees the
      previous writes. After 80% of the sweeps, cells where the melody is
      non-zero are no longer overwritten.

    During the first 90% of the sweeps, the melody is OR-ed into the
    conditioning score at the start of every sweep.

    Parameters
    ----------
    score : array
        Padded accompaniment with ``len_context`` extra rows at each end; the
        melody is merged into ``score[len_context:-len_context]``. It is
        copied and never modified.
    meta : array
        Row-aligned metadata for ``score`` (forwarded to ``generation``).
    score_melody : array
        Unpadded melody piano roll; its length sets the number of time steps.
    model : object with ``predict``
        Model forwarded to ``generation``.
    len_context : int
        Context rows on each side of a predicted step.
    pitch_range : int
        Number of pitch columns to resample.
    iter_num : int
        Number of sweeps.
    threshold : float
        Predictions below it become 0, all others 1.

    Returns
    -------
    array
        ``new_score`` without padding, restricted to the first ``pitch_range`` columns.
    """
    # Work on a private copy: the melody merge below writes into `score` in
    # place, which previously modified the caller's array.
    score = np.array(score)
    fixed_rhythm_score = score_melody
    original_len = len(score_melody)
    new_score = np.array(score)
    counter = 0
    alpha_initial = 0.6
    alpha = alpha_initial
    alpha_min = 0
    annealing_fraction = 0.6
    update_count = 0

    for i in tqdm.tqdm(range(iter_num)):
        time_list = np.arange(original_len)
        if i < iter_num * 0.9:
            # Force the melody notes on in the conditioning score.
            score[len_context:-len_context] = np.logical_or(
                score[len_context:-len_context], score_melody).astype(int)
        print("alpha = ", alpha)

        while time_list.size > 0:
            if alpha != -1:
                alpha = max(0, alpha_initial - update_count * (alpha_initial - alpha_min) / (
                    iter_num * original_len * annealing_fraction))
            if alpha == 0:
                # Switch to sequential updates: `score` now aliases `new_score`,
                # so every write is visible to the next prediction.
                score = new_score
                alpha = -1
            elif counter / original_len > alpha and alpha != -1:
                counter = 0
                score = np.array(new_score)

            time_indexes = time_index_chosing(original_len, len_context, random=False, time_list=time_list)
            n_chosen = len(time_indexes)
            # Remove the chosen steps from the steps still pending in this sweep.
            sorter = np.argsort(time_list)
            chosen_pos = sorter[np.searchsorted(time_list, time_indexes, sorter=sorter)]
            time_list = np.delete(time_list, chosen_pos, 0)
            counter += n_chosen
            update_count += n_chosen

            if alpha != -1:
                # Block step: every pitch at every chosen step in a single batch.
                midi_indexes = np.arange(pitch_range).tolist() * n_chosen
                time_indexes_repeat = np.repeat(time_indexes, pitch_range)
                p = generation(model, score, meta, time_indexes_repeat, midi_indexes, len_context)
                _threshold_update(new_score, p[0], time_indexes_repeat, midi_indexes,
                                  len_context, threshold)
            else:
                for midi_index in range(pitch_range):
                    midi_indexes = [midi_index] * n_chosen
                    p = generation(model, score, meta, time_indexes, midi_indexes, len_context)
                    editable = None
                    if i > (0.8 * iter_num):
                        # Late sweeps: leave cells occupied by the melody untouched.
                        editable = fixed_rhythm_score[time_indexes, midi_index] == 0
                    _threshold_update(new_score, p[0], time_indexes, midi_indexes,
                                      len_context, threshold, editable)

    return new_score[len_context:-len_context, :pitch_range]

In [2]:
# Configuration shared by the cells below.
beat_resolution = 8                     # pypianoroll time steps per beat
pitch_range = 84                        # pitch columns used by the model (MIDI 21..104)
len_context = 8 * 4 * beat_resolution   # context rows on each side of a predicted step
max_len = 4 * 4 * beat_resolution

# Weights from training run 53; this notebook's load_model rebuilds the architecture and loads them.
path_model = "../../../scratch/wtl272/projects/Music_style_transfer/runs/run_53/model/model.hdf5"
model = load_model(path_model)

In [4]:
path = "input_midi/0716/"
# Keep only the melody files; their "<name>_harmonized.mid" counterparts are
# loaded separately later. A comprehension avoids list.remove() during
# iteration, which skips the entry following each removed one.
file_list = [f for f in dataset_import(path, ".mid") if "harmonized" not in f]

Load from input_midi/0716/
22 files loaded!


In [5]:
# Melody-only MIDI files to harmonize (names containing "harmonized" were filtered out in the previous cell).
file_list

['input_midi/0716/bicycle.mid',
 'input_midi/0716/donkey.mid',
 'input_midi/0716/happy_birthday.mid',
 'input_midi/0716/home.mid',
 'input_midi/0716/indians.mid',
 'input_midi/0716/londonbridge.mid',
 'input_midi/0716/macdonald.mid',
 'input_midi/0716/stars.mid',
 'input_midi/0716/together.mid',
 'input_midi/0716/train.mid',
 'input_midi/0716/two_tigers.mid']

In [6]:
# Harmonize every melody in `file_list`: "<name>_harmonized.mid" supplies the
# accompaniment to restyle and "<name>.mid" the melody it must keep.
midi_pitches = np.arange(21, 21 + pitch_range)  # MIDI 21..104 when pitch_range == 84
out_dir = "output_midi/Results_1112/"
os.makedirs(out_dir, exist_ok=True)

for melody_path in file_list:
    stem = os.path.splitext(melody_path)[0]
    path = stem + "_harmonized.mid"
    path_melody = melody_path
    print(path, path_melody)

    score = midi2score(path, beat_resolution, pitch_range=midi_pitches)
    score_melody = midi2score(path_melody, beat_resolution, pitch_range=midi_pitches)

    # NOTE(review): get_representation presumably pads `score` by len_context
    # rows on each side (style_transfer_new strips that much) — confirm in utils.
    score, score_meta = get_representation(score, pitch_range, beat_resolution, len_context)

    result = style_transfer_new(score, score_meta, score_melody, model, len_context, pitch_range, iter_num=15)

    out_path = os.path.join(out_dir, os.path.basename(stem) + "_bach_model53.mid")
    score2midi(out_path, result[:len(score_melody)], beat_resolution, 120,
               pitch_range=midi_pitches, melody_constraint=True, melody=score_melody)


  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/bicycle_harmonized.mid input_midi/0716/bicycle.mid
alpha =  0.6


  7%|▋         | 1/15 [00:25<05:59, 25.67s/it]

alpha =  0.5338235294117647


 13%|█▎        | 2/15 [00:42<04:57, 22.91s/it]

alpha =  0.46715686274509804


 20%|██        | 3/15 [00:58<04:11, 20.99s/it]

alpha =  0.4004901960784314


 27%|██▋       | 4/15 [01:15<03:36, 19.64s/it]

alpha =  0.3338235294117647


 33%|███▎      | 5/15 [01:31<03:07, 18.70s/it]

alpha =  0.26715686274509803


 40%|████      | 6/15 [01:48<02:42, 18.04s/it]

alpha =  0.20049019607843133


 47%|████▋     | 7/15 [02:04<02:20, 17.58s/it]

alpha =  0.13382352941176467


 53%|█████▎    | 8/15 [02:21<02:00, 17.25s/it]

alpha =  0.06715686274509802


 60%|██████    | 9/15 [02:37<01:42, 17.03s/it]

alpha =  0.0004901960784313708


 67%|██████▋   | 10/15 [07:03<07:37, 91.53s/it]

alpha =  -1


 73%|███████▎  | 11/15 [11:27<09:34, 143.53s/it]

alpha =  -1


 80%|████████  | 12/15 [15:57<09:03, 181.29s/it]

alpha =  -1


 87%|████████▋ | 13/15 [20:27<06:56, 208.12s/it]

alpha =  -1


 93%|█████████▎| 14/15 [24:54<03:45, 225.68s/it]

alpha =  -1


100%|██████████| 15/15 [29:19<00:00, 237.32s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/donkey_harmonized.mid input_midi/0716/donkey.mid
alpha =  0.6


  7%|▋         | 1/15 [00:31<07:22, 31.64s/it]

alpha =  0.53359375


 13%|█▎        | 2/15 [01:03<06:51, 31.63s/it]

alpha =  0.4669270833333333


 20%|██        | 3/15 [01:34<06:19, 31.62s/it]

alpha =  0.4002604166666667


 27%|██▋       | 4/15 [02:06<05:47, 31.61s/it]

alpha =  0.33359375


 33%|███▎      | 5/15 [02:38<05:16, 31.63s/it]

alpha =  0.2669270833333333


 40%|████      | 6/15 [03:09<04:44, 31.62s/it]

alpha =  0.20026041666666666


 47%|████▋     | 7/15 [03:41<04:12, 31.60s/it]

alpha =  0.13359375


 53%|█████▎    | 8/15 [04:12<03:41, 31.59s/it]

alpha =  0.0669270833333333


 60%|██████    | 9/15 [04:44<03:09, 31.59s/it]

alpha =  0.00026041666666665186


 67%|██████▋   | 10/15 [13:00<14:14, 170.96s/it]

alpha =  -1


 73%|███████▎  | 11/15 [21:18<17:56, 269.09s/it]

alpha =  -1


 80%|████████  | 12/15 [29:38<16:55, 338.34s/it]

alpha =  -1


 87%|████████▋ | 13/15 [37:55<12:52, 386.05s/it]

alpha =  -1


 93%|█████████▎| 14/15 [46:13<06:59, 419.48s/it]

alpha =  -1


100%|██████████| 15/15 [54:32<00:00, 443.30s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/happy_birthday_harmonized.mid input_midi/0716/happy_birthday.mid
alpha =  0.6


  7%|▋         | 1/15 [00:29<06:55, 29.65s/it]

alpha =  0.533611111111111


 13%|█▎        | 2/15 [00:59<06:25, 29.64s/it]

alpha =  0.46694444444444444


 20%|██        | 3/15 [01:28<05:55, 29.63s/it]

alpha =  0.40027777777777773


 27%|██▋       | 4/15 [01:58<05:25, 29.63s/it]

alpha =  0.3336111111111111


 33%|███▎      | 5/15 [02:28<04:56, 29.64s/it]

alpha =  0.26694444444444443


 40%|████      | 6/15 [02:57<04:26, 29.63s/it]

alpha =  0.20027777777777778


 47%|████▋     | 7/15 [03:27<03:57, 29.64s/it]

alpha =  0.13361111111111112


 53%|█████▎    | 8/15 [03:57<03:27, 29.62s/it]

alpha =  0.06694444444444447


 60%|██████    | 9/15 [04:26<02:57, 29.63s/it]

alpha =  0.0002777777777778212


 67%|██████▋   | 10/15 [12:15<13:27, 161.50s/it]

alpha =  -1


 73%|███████▎  | 11/15 [20:08<16:59, 254.94s/it]

alpha =  -1


 80%|████████  | 12/15 [27:57<15:56, 318.95s/it]

alpha =  -1


 87%|████████▋ | 13/15 [35:44<12:07, 363.51s/it]

alpha =  -1


 93%|█████████▎| 14/15 [43:36<06:35, 395.90s/it]

alpha =  -1


100%|██████████| 15/15 [51:24<00:00, 417.58s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/home_harmonized.mid input_midi/0716/home.mid
alpha =  0.6


  7%|▋         | 1/15 [02:07<29:51, 128.00s/it]

alpha =  0.5335858585858586


 13%|█▎        | 2/15 [04:11<27:27, 126.71s/it]

alpha =  0.4669191919191919


 20%|██        | 3/15 [06:15<25:09, 125.80s/it]

alpha =  0.40025252525252525


 27%|██▋       | 4/15 [08:18<22:56, 125.14s/it]

alpha =  0.33358585858585854


 33%|███▎      | 5/15 [10:22<20:46, 124.67s/it]

alpha =  0.26691919191919194


 40%|████      | 6/15 [12:26<18:39, 124.34s/it]

alpha =  0.20025252525252524


 47%|████▋     | 7/15 [14:29<16:33, 124.14s/it]

alpha =  0.13358585858585859


 53%|█████▎    | 8/15 [16:33<14:27, 123.99s/it]

alpha =  0.06691919191919193


 60%|██████    | 9/15 [18:37<12:23, 123.87s/it]

alpha =  0.0002525252525252819


 67%|██████▋   | 10/15 [27:16<20:12, 242.59s/it]

alpha =  -1


 73%|███████▎  | 11/15 [35:52<21:38, 324.55s/it]

alpha =  -1


 80%|████████  | 12/15 [44:34<19:11, 383.70s/it]

alpha =  -1


 87%|████████▋ | 13/15 [53:15<14:10, 425.05s/it]

alpha =  -1


 93%|█████████▎| 14/15 [1:01:52<07:32, 452.72s/it]

alpha =  -1


100%|██████████| 15/15 [1:10:28<00:00, 471.43s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/indians_harmonized.mid input_midi/0716/indians.mid
alpha =  0.6


  7%|▋         | 1/15 [00:29<06:54, 29.61s/it]

alpha =  0.533611111111111


 13%|█▎        | 2/15 [00:59<06:24, 29.60s/it]

alpha =  0.46694444444444444


 20%|██        | 3/15 [01:28<05:55, 29.60s/it]

alpha =  0.40027777777777773


 27%|██▋       | 4/15 [01:58<05:25, 29.59s/it]

alpha =  0.3336111111111111


 33%|███▎      | 5/15 [02:27<04:55, 29.58s/it]

alpha =  0.26694444444444443


 40%|████      | 6/15 [02:57<04:26, 29.58s/it]

alpha =  0.20027777777777778


 47%|████▋     | 7/15 [03:27<03:56, 29.59s/it]

alpha =  0.13361111111111112


 53%|█████▎    | 8/15 [03:56<03:27, 29.59s/it]

alpha =  0.06694444444444447


 60%|██████    | 9/15 [04:26<02:57, 29.60s/it]

alpha =  0.0002777777777778212


 67%|██████▋   | 10/15 [12:14<13:25, 161.03s/it]

alpha =  -1


 73%|███████▎  | 11/15 [20:01<16:52, 253.08s/it]

alpha =  -1


 80%|████████  | 12/15 [27:49<15:52, 317.54s/it]

alpha =  -1


 87%|████████▋ | 13/15 [35:38<12:05, 362.77s/it]

alpha =  -1


 93%|█████████▎| 14/15 [43:26<06:34, 394.37s/it]

alpha =  -1


100%|██████████| 15/15 [51:16<00:00, 417.08s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/londonbridge_harmonized.mid input_midi/0716/londonbridge.mid
alpha =  0.6


  7%|▋         | 1/15 [00:31<07:22, 31.64s/it]

alpha =  0.53359375


 13%|█▎        | 2/15 [01:03<06:51, 31.64s/it]

alpha =  0.4669270833333333


 20%|██        | 3/15 [01:34<06:19, 31.63s/it]

alpha =  0.4002604166666667


 27%|██▋       | 4/15 [02:06<05:47, 31.63s/it]

alpha =  0.33359375


 33%|███▎      | 5/15 [02:38<05:16, 31.63s/it]

alpha =  0.2669270833333333


 40%|████      | 6/15 [03:09<04:44, 31.63s/it]

alpha =  0.20026041666666666


 47%|████▋     | 7/15 [03:41<04:13, 31.63s/it]

alpha =  0.13359375


 53%|█████▎    | 8/15 [04:13<03:41, 31.64s/it]

alpha =  0.0669270833333333


 60%|██████    | 9/15 [04:44<03:09, 31.63s/it]

alpha =  0.00026041666666665186


 67%|██████▋   | 10/15 [13:05<14:22, 172.52s/it]

alpha =  -1


 73%|███████▎  | 11/15 [21:24<18:01, 270.33s/it]

alpha =  -1


 80%|████████  | 12/15 [29:43<16:56, 338.99s/it]

alpha =  -1


 87%|████████▋ | 13/15 [38:01<12:53, 386.78s/it]

alpha =  -1


 93%|█████████▎| 14/15 [46:18<06:59, 419.66s/it]

alpha =  -1


100%|██████████| 15/15 [54:36<00:00, 443.27s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/macdonald_harmonized.mid input_midi/0716/macdonald.mid
alpha =  0.6


  7%|▋         | 1/15 [01:04<15:07, 64.81s/it]

alpha =  0.5335897435897435


 13%|█▎        | 2/15 [02:07<13:55, 64.24s/it]

alpha =  0.4669230769230769


 20%|██        | 3/15 [03:10<12:46, 63.84s/it]

alpha =  0.40025641025641023


 27%|██▋       | 4/15 [04:13<11:39, 63.55s/it]

alpha =  0.3335897435897436


 33%|███▎      | 5/15 [05:16<10:33, 63.37s/it]

alpha =  0.26692307692307693


 40%|████      | 6/15 [06:19<09:29, 63.23s/it]

alpha =  0.20025641025641022


 47%|████▋     | 7/15 [07:22<08:25, 63.13s/it]

alpha =  0.13358974358974363


 53%|█████▎    | 8/15 [08:25<07:21, 63.07s/it]

alpha =  0.06692307692307697


 60%|██████    | 9/15 [09:28<06:18, 63.05s/it]

alpha =  0.0002564102564103221


 67%|██████▋   | 10/15 [17:57<16:24, 196.88s/it]

alpha =  -1


 73%|███████▎  | 11/15 [26:26<19:22, 290.51s/it]

alpha =  -1


 80%|████████  | 12/15 [34:59<17:51, 357.22s/it]

alpha =  -1


 87%|████████▋ | 13/15 [43:26<13:24, 402.29s/it]

alpha =  -1


 93%|█████████▎| 14/15 [51:54<07:14, 434.05s/it]

alpha =  -1


100%|██████████| 15/15 [1:00:23<00:00, 456.37s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/stars_harmonized.mid input_midi/0716/stars.mid
alpha =  0.6


  7%|▋         | 1/15 [00:46<10:54, 46.75s/it]

alpha =  0.5335069444444445


 13%|█▎        | 2/15 [01:33<10:07, 46.75s/it]

alpha =  0.46684027777777776


 20%|██        | 3/15 [02:20<09:21, 46.75s/it]

alpha =  0.40017361111111105


 27%|██▋       | 4/15 [03:06<08:34, 46.73s/it]

alpha =  0.3335069444444444


 33%|███▎      | 5/15 [03:53<07:46, 46.66s/it]

alpha =  0.2668402777777778


 40%|████      | 6/15 [04:39<06:59, 46.59s/it]

alpha =  0.2001736111111111


 47%|████▋     | 7/15 [05:26<06:12, 46.55s/it]

alpha =  0.1335069444444444


 53%|█████▎    | 8/15 [06:12<05:25, 46.53s/it]

alpha =  0.06684027777777779


 60%|██████    | 9/15 [06:59<04:39, 46.52s/it]

alpha =  0.00017361111111113825


 67%|██████▋   | 10/15 [15:24<15:20, 184.03s/it]

alpha =  -1


 73%|███████▎  | 11/15 [23:49<18:41, 280.32s/it]

alpha =  -1


 80%|████████  | 12/15 [32:14<17:23, 347.79s/it]

alpha =  -1


 87%|████████▋ | 13/15 [40:39<13:10, 395.04s/it]

alpha =  -1


 93%|█████████▎| 14/15 [49:08<07:09, 429.32s/it]

alpha =  -1


100%|██████████| 15/15 [57:37<00:00, 453.06s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

input_midi/0716/together_harmonized.mid input_midi/0716/together.mid
alpha =  0.6


  7%|▋         | 1/15 [00:49<11:33, 49.53s/it]

alpha =  0.5334967320261438


 13%|█▎        | 2/15 [01:38<10:43, 49.50s/it]

alpha =  0.46683006535947713


 20%|██        | 3/15 [02:28<09:53, 49.48s/it]

alpha =  0.4001633986928105


 27%|██▋       | 4/15 [03:17<09:04, 49.47s/it]

alpha =  0.33349673202614377


 33%|███▎      | 5/15 [04:07<08:14, 49.47s/it]

alpha =  0.2668300653594771


 40%|████      | 6/15 [04:56<07:25, 49.50s/it]

alpha =  0.2001633986928104


 47%|████▋     | 7/15 [05:46<06:35, 49.49s/it]

alpha =  0.13349673202614376


 53%|█████▎    | 8/15 [06:35<05:46, 49.50s/it]

alpha =  0.06683006535947711


 60%|██████    | 9/15 [07:25<04:57, 49.51s/it]

alpha =  0.00016339869281045694


 67%|██████▋   | 10/15 [15:50<15:31, 186.34s/it]

alpha =  -1


 73%|███████▎  | 11/15 [24:16<18:48, 282.21s/it]

alpha =  -1


 80%|████████  | 12/15 [32:43<17:28, 349.39s/it]

alpha =  -1


 87%|████████▋ | 13/15 [41:09<13:12, 396.37s/it]

alpha =  -1


KeyboardInterrupt: 

In [None]:
# NOTE(review): near-duplicate of the harmonization loop in the earlier cell,
# writing to an older output folder; it was never executed in this session.
# The output names say "model49" but `model` is loaded from run_53 — confirm
# the label before running.
midi_pitches = np.arange(21, 21 + pitch_range)  # MIDI 21..104 when pitch_range == 84
out_dir = "output_midi/Results_1022/"
os.makedirs(out_dir, exist_ok=True)

for melody_path in file_list:
    stem = os.path.splitext(melody_path)[0]
    path = stem + "_harmonized.mid"
    path_melody = melody_path
    print(path, path_melody)

    score = midi2score(path, beat_resolution, pitch_range=midi_pitches)
    score_melody = midi2score(path_melody, beat_resolution, pitch_range=midi_pitches)

    score, score_meta = get_representation(score, pitch_range, beat_resolution, len_context)

    result = style_transfer_new(score, score_meta, score_melody, model, len_context, pitch_range, iter_num=15)

    out_path = os.path.join(out_dir, os.path.basename(stem) + "_bach_model49.mid")
    score2midi(out_path, result[:len(score_melody)], beat_resolution, 120,
               pitch_range=midi_pitches, melody_constraint=True, melody=score_melody)


def len_generator(path_dataset, len_context, phase="valid", percentage_train=0.8):
    """Count the (time step, pitch) samples in one split of a ``.npy`` score dataset.

    A score of shape ``(T, P)`` contributes ``(T - 2 * len_context) * P``
    samples: every position with a full ``len_context`` window on both sides.

    Parameters
    ----------
    path_dataset : str
        Directory passed to ``dataset_import``. Files whose path contains
        ``"meta"`` are treated as metadata and excluded.
    len_context : int
        Context length on each side of a predicted time step.
    phase : {"train", "valid", "test", "all"}
        Split to count. Splits are positional over the score list: the first
        ``percentage_train`` fraction is train, and the remainder is divided
        equally between valid and test.
    percentage_train : float
        Fraction of scores assigned to the training split.

    Returns
    -------
    int
        Total number of samples in the split.

    Raises
    ------
    ValueError
        If ``phase`` is not one of the supported values.
    """
    all_files = dataset_import(path_dataset, ".npy")
    # Single pass instead of list.remove() inside the loop, which skipped the
    # entry after each removed "meta" file and could leave metas in the count.
    d_scores = [f for f in all_files if "meta" not in f]
    # NOTE(review): the split is positional over dataset_import's order;
    # confirm it matches the order used by the training generator.

    n_scores = len(d_scores)
    train_end = int(n_scores * percentage_train)
    valid_end = int((1 + percentage_train) / 2 * n_scores)
    splits = {
        "train": range(0, train_end),
        "valid": range(train_end, valid_end),
        "test": range(valid_end, n_scores),
        "all": range(n_scores),
    }
    if phase not in splits:
        raise ValueError("phase must be one of 'train', 'valid', 'test', 'all'; got %r" % (phase,))

    num_samples = 0
    for score_index in splits[phase]:
        # Memory-map: only the array shape is needed, not its contents.
        score = np.load(d_scores[score_index], mmap_mode="r")
        # A score shorter than 2 * len_context has no valid positions; count 0, not a negative number.
        num_samples += max(0, score.shape[0] - 2 * len_context) * score.shape[1]

    return num_samples

In [21]:
bach_dataset_dir = "../../../scratch/wtl272/dataset/bach_8_on_offset/"
# len_generator counts (time step, pitch) samples of the validation split by
# default; dividing by pitch_range gives usable time steps (assumes the
# dataset scores have pitch_range columns). Uses the config constants
# instead of the literals 256 and 84.
len_generator(bach_dataset_dir, len_context) / pitch_range

Load from ../../../scratch/wtl272/dataset/bach_8_on_offset/
742 files loaded!


16832.0