In [None]:
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [None]:
import sys
import os
import glob
import csv
import librosa
import librosa.display
import pretty_midi
import numpy as np
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import pandas as pd
from multiprocessing import Pool
from tqdm import tqdm, trange
import pickle

## 1. Audio Synchronization Baseline
### 1.1 Compute MFCC features from synthesized MIDI audio

In [None]:
# Example piece: the SharpEye-converted MIDI of score version 1 vs. the reference MIDI.
synth_midi_path, midi_path = 'synth_midi', 'midi'
piece = 'debussy_childrencorner6'
midi_file1 = '%s/sharpeye/%s_v1.mid' % (synth_midi_path, piece)
midi_file2 = '%s/%s.mid' % (midi_path, piece)

In [None]:
# Feature-extraction settings: sample rate (Hz), frame hop and analysis window (seconds).
sr, hop_size, window_len = 22050, 0.025, 0.025

In [None]:
mid1 = pretty_midi.PrettyMIDI(midi_file1)
mid2 = pretty_midi.PrettyMIDI(midi_file2)
# pretty_midi synthesizes at 44100 Hz unless told otherwise. Render at `sr` so that
# the features computed below (which are told the audio is at `sr`) have frames
# int(hop_size * sr) / sr seconds apart, as the error computation assumes.
audio1 = mid1.synthesize(fs=sr)
audio2 = mid2.synthesize(fs=sr)

In [None]:
# Despite the `chroma` names, these are MFCC matrices (coefficients x frames).
# librosa >= 0.10 only accepts the signal and sample rate as keywords (y=, sr=).
chroma1 = librosa.feature.mfcc(y=audio1, sr=sr, hop_length=int(hop_size*sr), n_fft=int(window_len*sr))
chroma2 = librosa.feature.mfcc(y=audio2, sr=sr, hop_length=int(hop_size*sr), n_fft=int(window_len*sr))

### 1.2 DTW on MFCC features

In [None]:
def alignAudio(M1, M2):
    """Align two frame-wise feature sequences with dynamic time warping.

    Parameters
    ----------
    M1, M2 : array-like
        Feature matrices with one row per frame (frames x features), e.g. the
        transposed MFCC matrices. Both must have the same number of columns.

    Returns
    -------
    np.ndarray, shape (2, L)
        Warping path ordered from start to end: row 0 holds frame indices into
        M1, row 1 the matching frame indices into M2.
    """
    # Pairwise standardized-Euclidean cost between every frame of M1 and of M2.
    C = cdist(M1, M2, 'seuclidean', V=None)

    # Allowed steps (rows are (di, dj)): (1,1), (1,2), (2,1), with multiplicative
    # weights 2, 3, 3 respectively.
    steps = np.array([[1, 1], [1, 2], [2, 1]])
    weights = np.array([2, 3, 3])

    # DTW lives in librosa.sequence in current librosa; older releases had it in
    # librosa.core. Use whichever this installation provides.
    dtw = getattr(getattr(librosa, 'sequence', None), 'dtw', None)
    if dtw is None:
        dtw = librosa.core.dtw
    _, wp = dtw(C=C, step_sizes_sigma=steps, weights_mul=weights)

    # librosa returns the path from end to start as (L, 2); flip and transpose.
    return wp[::-1, :].transpose()

In [None]:
# alignAudio expects frames as rows, so transpose the (coefficients x frames) matrices.
wp = alignAudio(chroma1.T, chroma2.T)

### 1.3 Calculate Error

In [None]:
def getMidiRefLocs(annot_file):
    """Read per-beat reference times from an annotation CSV.

    Only the first field of each row is used. It is either a time (callers
    treat it as seconds) or '-' for an unannotated beat, which is stored as
    +inf. Blank lines are skipped.

    Parameters
    ----------
    annot_file : str
        Path to the CSV file.

    Returns
    -------
    np.ndarray of float
        One entry per non-blank row, in file order.
    """
    timeStamps = []
    with open(annot_file, newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=',', quotechar='|'):
            if not row:
                continue  # csv.reader yields [] for a blank line
            cell = row[0].strip()
            timeStamps.append(float('inf') if cell == '-' else float(cell))
    return np.array(timeStamps)

In [None]:
def getSheetRefLocs(scoreid, changeDPI = False):
    """Return the global horizontal pixel position of every annotated beat of a score.

    Strip lengths come from 'hyp_align/<scoreid>.pkl' (key 'striplens'). Beat
    annotations come from 'annot_data/<part0>_<part1>_beats.csv', filtered to rows
    whose `score` column equals '<part1>_<part2>'. Here the parts are
    `scoreid.split('_')`. Pages are 0-based and strips within a page 1-based.
    Each beat's position is the cumulative length of all preceding strips plus
    its `hpixel`.

    Parameters
    ----------
    scoreid : str
        Score identifier with at least three '_'-separated parts.
    changeDPI : bool
        If True, rescale `hpixel` by 400 // (72*4) before adding the strip offset
        (presumably converting annotation coordinates to 400 dpi — TODO confirm).

    Returns
    -------
    np.ndarray
        Pixel positions in CSV row order.

    Raises
    ------
    ValueError
        If the annotation file has no rows for this score.
    """
    hyp_file = 'hyp_align/' + scoreid + '.pkl'
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with open(hyp_file, 'rb') as fh:
        dhyp = pickle.load(fh)
    striplens = dhyp['striplens']

    # get annotation file
    annot_dir = 'annot_data'
    piece = scoreid.split('_')
    annot_file_beats = '%s/%s_%s_beats.csv' % (annot_dir, piece[0], piece[1])
    df_all = pd.read_csv(annot_file_beats)

    score_key = piece[1] + '_' + piece[2]
    df = df_all.loc[df_all.score == score_key]
    if df.empty:
        raise ValueError('no beat annotations for score %r in %s' % (score_key, annot_file_beats))

    # Cumulative pixel offset of each strip. np.concatenate works whether
    # striplens is a list or an ndarray ("[0] + ndarray" would add 0 elementwise).
    pixelOffset = np.cumsum(np.concatenate(([0], np.asarray(striplens))))

    # Number of strips per page. A page with no annotated beats contributes 0
    # strips rather than NaN, so the offsets stay valid integers.
    n_pages = int(df.page.max()) + 1
    stripsPerPage = df.groupby('page')['strip'].max().reindex(range(n_pages), fill_value=0)
    stripOffset = np.cumsum(np.concatenate(([0], stripsPerPage.to_numpy())))

    pages = df.page.to_numpy()
    strips = df.strip.to_numpy()
    hpixel = df.hpixel.to_numpy()
    stripIdx = stripOffset[pages] + strips - 1  # cumulative 0-based strip index

    if changeDPI:
        return pixelOffset[stripIdx] + (hpixel * 400 // (72*4))
    return pixelOffset[stripIdx] + hpixel

In [None]:
# Beat times (from pretty_midi's beat tracker) of the synthesized and reference MIDI.
synth_timestamps, perf_timestamps1 = mid1.get_beats(), mid2.get_beats()

In [None]:
def calcPredErrors(wp, perf_timestamps, synth_timestamps, frame_hop=None):
    """Map performance beats through a DTW path and measure the timing error.

    For each beat pair, the performance beat time is converted to a frame index,
    and the nearest path entry on the performance side is found. The matching
    synth frame is converted back to seconds and compared with the annotated
    synth time.

    Parameters
    ----------
    wp : array-like, shape (2, L)
        Warping path as returned by alignAudio: wp[0] = synth frame indices,
        wp[1] = performance frame indices.
    perf_timestamps : sequence of float
        Beat times in the performance. Non-finite values (e.g. inf for '-'
        annotations) mark unannotated beats.
    synth_timestamps : sequence of float
        Times of the same beats in the synthesized MIDI.
    frame_hop : float, optional
        Seconds per feature frame. Defaults to int(hop_size * sr) / sr, which is
        the hop the features were actually computed with (551 samples at 22050 Hz,
        slightly less than the nominal hop_size).

    Returns
    -------
    list of float
        Error in milliseconds (synth time minus predicted synth time), one per
        beat pair up to the shorter of the two sequences. The value is NaN where
        the performance beat is unannotated, so calcErrorStats drops it.
    """
    if frame_hop is None:
        frame_hop = int(hop_size * sr) / sr
    synth_frames = np.asarray(wp[0])
    perf_frames = np.asarray(wp[1])

    all_errs = []
    # zip stops at the shorter sequence instead of raising IndexError.
    for beat_time, synth_time in zip(perf_timestamps, synth_timestamps):
        if not np.isfinite(beat_time):
            all_errs.append(float('nan'))
            continue
        frame_id2 = beat_time // frame_hop
        wp_id2 = np.argmin(np.abs(perf_frames - frame_id2))  # first closest path entry
        frame_id1 = synth_frames[wp_id2]
        all_errs.append((synth_time - frame_hop * frame_id1) * 1000)  # in ms
    return all_errs

In [None]:
def calcErrorStats(errs_raw, tols, isSingle = False):
    """Fraction of errors whose magnitude exceeds each tolerance.

    Parameters
    ----------
    errs_raw : sequence
        If isSingle, a flat sequence (list or array) of errors. Otherwise a
        sequence of sequences that is flattened first.
    tols : sequence of float
        Tolerances, in the same units as the errors.
    isSingle : bool
        Whether errs_raw is already flat.

    Returns
    -------
    list of float
        For each tolerance, the fraction of non-NaN errors with |err| > tol.
        NaN entries (unannotated beats) are excluded; inf entries always count
        as errors. If no non-NaN errors remain, every rate is NaN.
    """
    if isSingle:
        # asarray so a plain list can be boolean-masked below.
        errs = np.asarray(errs_raw, dtype=float).ravel()
    else:
        errs = np.array([err for sublist in errs_raw for err in sublist], dtype=float)
    errs = errs[~np.isnan(errs)]  # when beat is not annotated, value is nan
    if errs.size == 0:
        return [float('nan')] * len(tols)
    absErrs = np.abs(errs)
    return [np.count_nonzero(absErrs > tol) / errs.size for tol in tols]

In [None]:
errs1 = calcPredErrors(wp=wp, perf_timestamps=perf_timestamps1, synth_timestamps=synth_timestamps)

In [None]:
tols = np.arange(5000)
# Pass an array: errs1 is a list, and calcErrorStats masks its input with a boolean array.
errorRates1 = calcErrorStats(np.asarray(errs1, dtype=float), tols, True)
fig, ax = plt.subplots()
ax.plot(tols, 100.0*np.array(errorRates1), 'k-', label='auto-annot')
ax.set_title('Audio synchronization baseline')
ax.set_xlabel('Error Tolerance (milliseconds)')
ax.set_ylabel('Error Rate (%)')
ax.set_ylim([0, 100])
ax.legend()
plt.show()

### 1.4 Run experiment on the whole dataset

In [None]:
synth_midi_path = 'synth_midi/'
midi_path = 'midi/'
annot_dir = 'annot_data/'

# Evaluation pieces, grouped by composer (alphabetical order).
pieces = [
    # Brahms
    'brahms_op116no6',
    'brahms_op117no2',
    # Chopin
    'chopin_op30no2',
    'chopin_op63no3',
    'chopin_op68no3',
    # Clementi
    'clementi_op36no1mv3',
    'clementi_op36no2mv3',
    'clementi_op36no3mv3',
    # Debussy
    'debussy_childrencorner1',
    'debussy_childrencorner3',
    'debussy_childrencorner6',
    # Mendelssohn
    'mendelssohn_op19no2',
    'mendelssohn_op62no3',
    'mendelssohn_op62no5',
    # Mozart
    'mozart_kv311mv3',
    'mozart_kv333mv3',
    # Schubert
    'schubert_op90no1',
    'schubert_op90no3',
    'schubert_op94no2',
    # Tchaikovsky
    'tchaikovsky_season01',
    'tchaikovsky_season06',
    'tchaikovsky_season08',
]

In [None]:
def calcSingleError(mid_pair, perf_timestamps, synth_timestamps):
    """Align a synthesized MIDI against a reference MIDI and compute beat timing errors.

    Parameters
    ----------
    mid_pair : sequence of two pretty_midi.PrettyMIDI
        (synthesized MIDI, reference/performance MIDI).
    perf_timestamps, synth_timestamps : sequence of float
        Annotated beat times in the performance and in the synthesized MIDI.
        Both are truncated to the shorter length before comparison.

    Returns
    -------
    (errs, wp)
        errs: per-beat errors in ms from calcPredErrors.
        wp: the (2, L) warping path from alignAudio (row 0 synth frames,
        row 1 performance frames).
    """
    mid1, mid2 = mid_pair[0], mid_pair[1]
    # pretty_midi defaults to fs=44100. Synthesize at `sr` so the features
    # (computed below with sr=sr) really are int(hop_size*sr)/sr seconds per frame.
    audio1 = mid1.synthesize(fs=sr)
    audio2 = mid2.synthesize(fs=sr)
    hop_length = int(hop_size*sr)
    n_fft = int(window_len*sr)
    # MFCC features; y/sr passed as keywords (required by librosa >= 0.10).
    chroma1 = librosa.feature.mfcc(y=audio1, sr=sr, hop_length=hop_length, n_fft=n_fft)
    chroma2 = librosa.feature.mfcc(y=audio2, sr=sr, hop_length=hop_length, n_fft=n_fft)
    wp = alignAudio(np.transpose(chroma1), np.transpose(chroma2))

    minLen = min(len(synth_timestamps), len(perf_timestamps))
    errs = calcPredErrors(wp, perf_timestamps[:minLen], synth_timestamps[:minLen])
    return errs, wp

In [None]:
def _annotSuffix(program):
    """Return the synth-MIDI annotation filename suffix for an OMR program."""
    suffixes = {'sharpeye': '_se.csv', 'photoscore': '_ps.csv'}
    if program not in suffixes:
        raise ValueError("unknown program %r; expected 'sharpeye' or 'photoscore'" % (program,))
    return suffixes[program]


def _hypPixelErrors(wp, perf_timestamps, synth_timestamps, sheet_annot, frame_hop):
    """Pixel error of each performance beat after mapping it through the DTW path onto the sheet.

    Performance time -> performance frame -> synth frame (via wp) -> synth time
    -> sheet pixel, interpolating between beats annotated in both the synth MIDI
    and the sheet. Returns predicted minus annotated pixel per beat; NaN where
    either the performance beat or the sheet position is unannotated.
    """
    perf_t = np.asarray(perf_timestamps, dtype=float)
    n = min(len(synth_timestamps), len(sheet_annot))
    m = min(len(perf_t), n)
    ref_t = np.asarray(synth_timestamps[:n], dtype=float)
    ref_px = np.asarray(sheet_annot[:n], dtype=float)
    known = np.isfinite(ref_t) & np.isfinite(ref_px)
    if not known.any():
        return np.full(m, np.nan)

    # wp rows: 0 = synth frames, 1 = performance frames (increasing along the path).
    synth_times = np.interp(perf_t[:m] / frame_hop, wp[1], wp[0]) * frame_hop
    # Assumes annotated synth times increase with beat index (np.interp needs sorted xp).
    hypPixels = np.interp(synth_times, ref_t[known], ref_px[known])
    errs = hypPixels - ref_px[:m]
    errs[~np.isfinite(perf_t[:m])] = np.nan
    return errs


def runExperiment(program, pieces_list):
    """Audio-synchronization baseline: align every real MIDI with every OMR-synthesized MIDI.

    For each piece, every reference MIDI under `midi_path` is paired with every
    sheet in score_data/prepped_pdf. A pair is evaluated only if its synth-MIDI
    annotation file exists (and it is not the excluded photoscore chopin_op68no3_v6).

    Parameters
    ----------
    program : {'sharpeye', 'photoscore'}
        OMR program whose synthesized MIDI is used.
    pieces_list : sequence of str
        Piece identifiers such as 'brahms_op116no6'.

    Returns
    -------
    (allErrs_pixel, allErrs_time)
        One entry per (real MIDI, sheet) pair. The time entries are errors in ms
        from calcSingleError. The pixel entries are predicted minus annotated
        pixel positions per beat. Pairs that cannot be evaluated get lists of inf.

    Raises
    ------
    ValueError
        If `program` is not recognised.
    """
    suffix = _annotSuffix(program)
    # Actual seconds per feature frame (hop_length samples at sr).
    frame_hop = int(hop_size * sr) / sr
    allErrs_time = []
    allErrs_pixel = []

    for piece in pieces_list:
        all_sheets = sorted(glob.glob('score_data/prepped_pdf/%s*' % piece))
        real_midis = sorted(glob.glob(midi_path+'%s*' % piece))
        perf_timestamps = getMidiRefLocs(annot_dir + 'midi/' + piece + '.csv')
        synth_annot_files = set(glob.glob(annot_dir+'synth_midi/'+'%s*%s' % (piece.split('_')[1], suffix)))

        for real_midi in real_midis:
            mid2 = pretty_midi.PrettyMIDI(real_midi)

            for sheet in all_sheets:
                scoreid = sheet.split('/')[-1].split('.')[0]
                sheet_annot = getSheetRefLocs(scoreid)
                synth_file = synth_midi_path+program+'/'+scoreid+'.mid'
                name_parts = scoreid.split('_')
                synth_annot_file = annot_dir + 'synth_midi/' + name_parts[1] + '_' + name_parts[2] + suffix
                # This photoscore conversion is always skipped (reason not recorded — TODO document).
                excluded = program == 'photoscore' and scoreid == 'chopin_op68no3_v6'

                if synth_annot_file in synth_annot_files and not excluded:
                    mid1 = pretty_midi.PrettyMIDI(synth_file)
                    print(real_midi, synth_file)

                    synth_timestamps = getMidiRefLocs(synth_annot_file)
                    err_t, wp = calcSingleError([mid1, mid2], perf_timestamps, synth_timestamps)
                    allErrs_time.append(err_t)
                    allErrs_pixel.append(
                        _hypPixelErrors(wp, perf_timestamps, synth_timestamps, sheet_annot, frame_hop))
                else:
                    allErrs_pixel.append([float('inf')]*len(sheet_annot))
                    allErrs_time.append([float('inf')]*len(perf_timestamps))

    return allErrs_pixel, allErrs_time

In [None]:
def alignAll(program, pieces_list):
    """Compute and pickle DTW paths between every real and every synthesized MIDI of each piece.

    For each piece, every MIDI matching '<midi_path>/*<piece>*' is aligned against
    every MIDI matching '<synth_midi_path><program>/<piece>*' using MFCC features
    and alignAudio. The resulting dict maps (real_full_path, synth_full_path) to
    the (2, L) warping path. It is written to
    'results/audioalign_nonmzk_<program>.pkl', and the results directory is
    created if needed.

    Parameters
    ----------
    program : str
        OMR program subdirectory of `synth_midi_path` (e.g. 'sharpeye').
    pieces_list : sequence of str
        Piece identifiers.
    """
    hop_length = int(hop_size*sr)
    n_fft = int(window_len*sr)

    def _features(path):
        # Frames x MFCC matrix of a MIDI file. pretty_midi defaults to fs=44100,
        # so synthesize at `sr` to match the sample rate given to librosa.
        audio = pretty_midi.PrettyMIDI(path).synthesize(fs=sr)
        return librosa.feature.mfcc(y=audio, sr=sr, hop_length=hop_length, n_fft=n_fft).T

    allwp = {}

    for piece in pieces_list:
        synth_midis = [os.path.basename(elem) for elem in sorted(glob.glob(synth_midi_path+program+'/%s*' % piece))]
        real_midis = [os.path.basename(elem) for elem in sorted(glob.glob(midi_path+'/*%s*' % piece))]
        if not real_midis:
            continue

        # Key strings are kept exactly as before so existing pickles stay compatible.
        synth_paths = [synth_midi_path+'/'+program+'/'+name for name in synth_midis]
        # Synth features do not depend on the real MIDI: compute them once per piece.
        synth_feats = [_features(path) for path in synth_paths]

        for real_midi_name in tqdm(real_midis):
            real_full_path = midi_path + '/' + real_midi_name
            real_feats = _features(real_full_path)
            for synth_full_path, synth_feat in zip(synth_paths, synth_feats):
                allwp[(real_full_path, synth_full_path)] = alignAudio(synth_feat, real_feats)

    os.makedirs('results', exist_ok=True)
    with open('results/audioalign_nonmzk_'+program+'.pkl', 'wb') as f:
        pickle.dump(allwp, f)

In [None]:
alignAll(program='photoscore', pieces_list=pieces)

In [None]:
alignAll(program='sharpeye', pieces_list=pieces)

In [None]:
# Audio-sync (AS) baseline errors for both OMR programs.
allErrs_se_as_pix, allErrs_se_as_t = runExperiment(program='sharpeye', pieces_list=pieces)
allErrs_ps_as_pix, allErrs_ps_as_t = runExperiment(program='photoscore', pieces_list=pieces)

In [None]:
with open('results/errorData_real_as.pkl', 'wb') as pkl_out:
    pickle.dump(
        [allErrs_ps_as_pix, allErrs_ps_as_t, allErrs_se_as_pix, allErrs_se_as_t],
        pkl_out,
    )

## 2. MIDI Beat Matching

In [None]:
def midiBeatMatch(program, pieces_list):
    """MIDI beat-matching baseline: use pretty_midi's beat estimates of the synthesized MIDI.

    For each sheet of each piece with a synth-MIDI annotation file (excluding
    photoscore chopin_op68no3_v6), the synthesized MIDI's `get_beats()` times are
    compared with the annotated synth beat times. They are also mapped to sheet
    pixels by interpolating annotated synth time -> annotated sheet pixel.

    Parameters
    ----------
    program : {'sharpeye', 'photoscore'}
        OMR program whose synthesized MIDI is used.
    pieces_list : sequence of str
        Piece identifiers.

    Returns
    -------
    (allErrs_pixel, allErrs_time)
        One entry per sheet. Time errors are (annotated - estimated) in ms.
        Pixel errors are predicted minus annotated pixel. Sheets that cannot be
        evaluated get lists of inf.

    Raises
    ------
    ValueError
        If `program` is not recognised.
    """
    suffixes = {'sharpeye': '_se.csv', 'photoscore': '_ps.csv'}
    if program not in suffixes:
        raise ValueError("unknown program %r; expected 'sharpeye' or 'photoscore'" % (program,))
    suffix = suffixes[program]

    allErrs_pixel = []
    allErrs_time = []

    for piece in pieces_list:
        perf_timestamps = getMidiRefLocs(annot_dir + 'midi/' + piece + '.csv')
        all_sheets = sorted(glob.glob('score_data/prepped_pdf/%s*' % piece))
        synth_annot_files = set(glob.glob(annot_dir+'synth_midi/'+'%s*%s' % (piece.split('_')[1], suffix)))

        for sheet in all_sheets:
            scoreid = sheet.split('/')[-1].split('.')[0]
            sheet_annot = getSheetRefLocs(scoreid)
            synth_file = synth_midi_path+program+'/'+scoreid+'.mid'
            name_parts = scoreid.split('_')
            synth_annot_file = annot_dir + 'synth_midi/' + name_parts[1] + '_' + name_parts[2] + suffix
            # This photoscore conversion is always skipped (reason not recorded — TODO document).
            excluded = program == 'photoscore' and scoreid == 'chopin_op68no3_v6'

            if synth_annot_file in synth_annot_files and not excluded:
                mid1 = pretty_midi.PrettyMIDI(synth_file)
                # get_beats() with its default start time (0.0).
                auto_beat = np.asarray(mid1.get_beats())
                synth_timestamps = getMidiRefLocs(synth_annot_file)

                minLen_t = min(len(synth_timestamps), len(auto_beat))
                allErrs_time.append((synth_timestamps[:minLen_t] - auto_beat[:minLen_t]) * 1000)

                # Interpolate only between beats annotated in both synth MIDI and sheet;
                # inf ('-') entries would corrupt np.interp.
                minLen_p = min(minLen_t, len(sheet_annot))
                ref_t = synth_timestamps[:minLen_p]
                ref_px = np.asarray(sheet_annot[:minLen_p], dtype=float)
                known = np.isfinite(ref_t) & np.isfinite(ref_px)
                if known.any():
                    hypPixels = np.interp(auto_beat[:minLen_p], ref_t[known], ref_px[known])
                else:
                    hypPixels = np.full(minLen_p, np.nan)
                allErrs_pixel.append(hypPixels - ref_px)

            else:
                # One inf per beat (previously a single-element list [inf * len]).
                allErrs_pixel.append([float('inf')]*len(sheet_annot))
                allErrs_time.append([float('inf')]*len(perf_timestamps))

    return allErrs_pixel, allErrs_time

In [None]:
allErrs_se_bm_pix, allErrs_se_bm_t = midiBeatMatch(program='sharpeye', pieces_list=pieces)

In [None]:
allErrs_ps_bm_pix, allErrs_ps_bm_t = midiBeatMatch(program='photoscore', pieces_list=pieces)

In [None]:
with open('results/errorData_real_bm.pkl', 'wb') as pkl_out:
    pickle.dump(
        [allErrs_ps_bm_pix, allErrs_ps_bm_t, allErrs_se_bm_pix, allErrs_se_bm_t],
        pkl_out,
    )

## 3. Compare Error to Bootleg System

In [None]:
# Load all saved error data; `with` closes each file (the bare open() calls leaked handles).
# NOTE: pickle.load can execute arbitrary code — only load result files produced by this project.
with open('results/errorData_real_bm.pkl', 'rb') as f:
    [allErrs_ps_bm_pix, allErrs_ps_bm_t, allErrs_se_bm_pix, allErrs_se_bm_t] = pickle.load(f)
with open('results/errorData_real_bootleg.pkl', 'rb') as f:
    [pixel_errs_bs, pixel_errs_b1, time_errs_bs, time_errs_b1] = pickle.load(f)
with open('results/errorData_real_as.pkl', 'rb') as f:
    [allErrs_ps_as_pix, allErrs_ps_as_t, allErrs_se_as_pix, allErrs_se_as_t] = pickle.load(f)

In [None]:
tols = np.arange(2001)
# (errors, line format, legend label) for each system, in plotting order.
curve_specs = [
    (time_errs_b1, 'k-', 'GL'),
    (allErrs_se_bm_t, 'g-.', 'MBM-se'),
    (allErrs_ps_bm_t, 'r-.', 'MBM-ps'),
    (allErrs_se_as_t, 'g--', 'AS-se'),
    (allErrs_ps_as_t, 'r--', 'AS-ps'),
    (time_errs_bs, 'g-', 'BS'),
]

fig, ax = plt.subplots()
for errs, fmt, label in curve_specs:
    ax.plot(tols, 100.0*np.array(calcErrorStats(errs, tols)), fmt, label=label)

ax.set_title('Beat timing error rate by system')
ax.set_xlabel('Error Tolerance (milliseconds)')
ax.set_ylabel('Error Rate (%)')
ax.set_ylim([0, 100])
# Legend outside the axes, to the right.
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
os.makedirs('figs', exist_ok=True)
fig.savefig('figs/error_curves(final).png', dpi=300, bbox_inches='tight')
plt.show()