In [None]:
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os.path
import glob
import pickle
import multiprocessing

### Evaluate Retrieval Accuracy

In [None]:
def evaluateAll(sampleLens, dbSizes, dbSamplings, hyp_dir, savefile = None):
    '''
    Compute the MRR for every (sample length, database size, sampling) combination.

    sampleLens  -- iterable of query sample lengths to evaluate
    dbSizes     -- iterable of database sizes to evaluate
    dbSamplings -- number of random database samplings per combination; the
                   sampling index is passed to evaluateSingle as its random seed
    hyp_dir     -- root directory handed through to evaluateSingle
    savefile    -- optional path; when truthy, the results and the evaluated
                   settings are pickled there

    Returns an array of shape (len(sampleLens), len(dbSizes), dbSamplings).
    '''
    mrrInfo = np.zeros((len(sampleLens), len(dbSizes), dbSamplings))
    for lenIdx, queryLen in enumerate(sampleLens):
        print('Evaluating sample len = {}'.format(queryLen))
        for sizeIdx, numPieces in enumerate(dbSizes):
            print('  dbSize = {}'.format(numPieces))
            # one MRR value per random sampling, seeded 0 .. dbSamplings-1
            mrrInfo[lenIdx, sizeIdx, :] = [
                evaluateSingle(queryLen, numPieces, hyp_dir, seed)
                for seed in range(dbSamplings)
            ]

    if savefile:
        payload = {
            'mrrInfo': mrrInfo,
            'sampleLens': sampleLens,
            'dbSizes': dbSizes,
            'dbSamplings': dbSamplings,
        }
        with open(savefile, 'wb') as out:
            pickle.dump(payload, out)

    return mrrInfo

In [None]:
def evaluateSingle(sampleLen, dbSize, hyp_dir, random_seed = 0):
    '''
    Compute the mean reciprocal rank (MRR) for one sample length and one
    randomly sub-sampled database of size dbSize.

    The number of queries is the count of *.hyp files under
    <hyp_dir>/sample<sampleLen>/; they are then read by name as p1.hyp,
    p2.hyp, ... Pieces whose number ends in 1 or 5 are skipped (training
    queries). Each hypothesis file must provide 'dbSize' and 'sorted_both';
    every row of 'sorted_both' contributes one ground-truth rank.

    Note: seeds numpy's global random generator with random_seed.
    '''
    np.random.seed(random_seed)
    numQueries = len(glob.glob('{}/sample{}/*.hyp'.format(hyp_dir, sampleLen)))
    # skip training queries (piece numbers ending in 1 or 5)
    testPieces = (p for p in range(1, numQueries + 1) if p % 10 not in (1, 5))

    gtRanks = []
    for pieceNum in testPieces:
        hypData = loadPickle('{}/sample{}/p{}.hyp'.format(hyp_dir, sampleLen, pieceNum))
        fullDBSize = hypData['dbSize']
        rankings = hypData['sorted_both']
        dbPieces = getDBSampling(fullDBSize, pieceNum, dbSize)
        gtRanks.extend(
            determineGroundTruthRank(rankings[row, :], dbPieces, pieceNum)
            for row in range(rankings.shape[0])
        )

    reciprocalRanks = 1 / np.array(gtRanks)
    return np.mean(reciprocalRanks)

In [None]:
def evaluateMultithreaded(sampleLen, dbSize, hyp_dir, dbSamplings, savefile):
    '''
    Worker for one (sampleLen, dbSize) pair: run evaluateSingle once per
    sampling seed (0 .. dbSamplings-1) and pickle the resulting array of MRR
    values to savefile. Returns None; the result is only written to disk.
    '''
    print('Processing sampleLen {}, dbSize {}'.format(sampleLen, dbSize))
    mrrPerSampling = np.array([
        evaluateSingle(sampleLen, dbSize, hyp_dir, seed)
        for seed in range(dbSamplings)
    ])

    with open(savefile, 'wb') as out:
        pickle.dump(mrrPerSampling, out)

In [None]:
def loadPickle(pkl_file):
    '''
    Read and return the object stored in the pickle file at pkl_file.
    '''
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(pkl_file, 'rb') as handle:
        return pickle.load(handle)

In [None]:
def getDBSampling(fullSize, refId, sampleSize):
    '''
    Draw a random database of sampleSize piece ids from 1..fullSize that is
    guaranteed to contain refId.

    sampleSize-1 ids are drawn without replacement from all ids except refId
    (using numpy's global random generator); refId is appended as the last
    element of the returned array.
    '''
    allIds = np.arange(1, fullSize + 1)
    distractorPool = np.delete(allIds, refId - 1)  # ids start at 1, so refId sits at index refId-1
    distractors = np.random.choice(distractorPool, size=sampleSize - 1, replace=False)
    return np.concatenate((distractors, [refId]))

In [None]:
def determineGroundTruthRank(predRanking, dbPieces, refId):
    '''
    Return the 1-based rank of refId within predRanking when the ranking is
    restricted to the pieces contained in dbPieces.

    predRanking -- iterable of piece ids, best match first
    dbPieces    -- piece ids making up the (sub-sampled) database
    refId       -- id of the ground-truth piece

    If refId never occurs in predRanking, the loop runs to the end and the
    returned value is simply the number of ranked pieces that are in dbPieces
    (possibly 0), not a true rank.
    '''
    # Build the lookup once: `x in ndarray` is a full linear scan, which made
    # every call O(len(predRanking) * len(dbPieces)).
    dbLookup = set(np.asarray(dbPieces).ravel().tolist())
    rank = 0
    for pieceNum in predRanking:
        if pieceNum in dbLookup:
            rank += 1
        if pieceNum == refId:
            break
    return rank

In [None]:
# Single-CPU alternative to the parallel run below; uncomment this cell to use it.
# sampleLens = [10, 20, 50, 100, 200, 500, 1000, 100000]
# dbSizes = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000]
# numDBSamplings = 10
# hyp_dir = 'hyps'
# savefile = 'mrr.pkl'
# evaluateAll(sampleLens, dbSizes, numDBSamplings, hyp_dir, savefile)

In [None]:
# use multiple CPUs: one task per (sampleLen, dbSize) pair, each task writes
# its own result file into eval_dir (merged later by aggregateResults)
sampleLens = [10, 20, 50, 100, 200, 500, 1000, 100000]
dbSizes = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000]
numDBSamplings = 10
hyp_dir = 'hyps'
eval_dir = 'results'
savefile = 'mrr.pkl'

# prep output directory (exist_ok avoids the check-then-create race)
os.makedirs(eval_dir, exist_ok=True)

# number of cores to use
n_cores = 25 #multiprocessing.cpu_count()

# prep inputs for parallelization
inputs = []
for sampleLen in sampleLens:
    for dbSize in dbSizes:
        outfile = '{}/sampleLen{}_dbSize{}.pkl'.format(eval_dir, sampleLen, dbSize)
        inputs.append((sampleLen, dbSize, hyp_dir, numDBSamplings, outfile))

# process queries in parallel
# The context manager shuts the worker processes down once starmap has
# returned; previously the pool was never closed and its workers leaked.
# NOTE(review): with the 'spawn' start method (Windows, recent macOS) functions
# defined in a notebook cell may not be importable by workers -- confirm on the
# target platform.
with multiprocessing.Pool(processes=n_cores) as pool:
    outputs = pool.starmap(evaluateMultithreaded, inputs)  # starmap already returns a list

In [None]:
def aggregateResults(eval_dir, sampleLens, dbSizes, numDBSamplings, savefile):
    '''
    Collect the per-(sampleLen, dbSize) result files found in eval_dir into a
    single array of shape (len(sampleLens), len(dbSizes), numDBSamplings).

    When savefile is truthy, the array is pickled there together with the
    evaluated settings. Returns None.
    '''
    def _loadResult(sampleLen, dbSize):
        # result file for one (sampleLen, dbSize) combination
        path = '{}/sampleLen{}_dbSize{}.pkl'.format(eval_dir, sampleLen, dbSize)
        with open(path, 'rb') as src:
            return pickle.load(src)

    mrr = np.zeros((len(sampleLens), len(dbSizes), numDBSamplings))
    for row, sampleLen in enumerate(sampleLens):
        for col, dbSize in enumerate(dbSizes):
            mrr[row, col, :] = _loadResult(sampleLen, dbSize)

    if savefile:
        summary = {
            'mrrInfo': mrr,
            'sampleLens': sampleLens,
            'dbSizes': dbSizes,
            'dbSamplings': numDBSamplings,
        }
        with open(savefile, 'wb') as dst:
            pickle.dump(summary, dst)
    return None

In [None]:
# Merge the per-(sampleLen, dbSize) result files in eval_dir into one pickle at savefile.
aggregateResults(eval_dir, sampleLens, dbSizes, numDBSamplings, savefile)

### Visualize Results

In [None]:
# Reload the aggregated results and reduce over the sampling axis (axis 2),
# leaving one mean / standard deviation per (sample length, database size).
d = loadPickle(savefile)
sampleLens, dbSizes, mrr = d['sampleLens'], d['dbSizes'], d['mrrInfo']
means, stdevs = np.mean(mrr, axis=2), np.std(mrr, axis=2)

In [None]:
# Grouped bar chart of mean MRR: one group per database size, one bar per
# query length within each group.
ind = np.arange(len(dbSizes))  # the x locations for the groups
numSeries = max(len(sampleLens), 1)  # bars per group
# Derive the bar geometry from the number of query lengths instead of
# hard-coding it for exactly 8 (previously width = 0.1, offset = 3.5).
width = 0.8 / numSeries  # the width of the bars; each group spans 0.8 of a tick interval
center = (numSeries - 1) / 2.0  # index offset that centers each group on its tick

fig, ax = plt.subplots(figsize=(12,5))
for i, sampleLen in enumerate(sampleLens):
    # the last entry of sampleLens is labelled as the full-file query
    if i == len(sampleLens) - 1:
        labelStr = 'Full MIDI File'
    else:
        labelStr = '{}'.format(sampleLen)
    ax.bar(ind + width*(i-center), means[i,:], width, label=labelStr)
    #ax.bar(ind + width*(i-center), means[i,:], width, yerr=stdevs[i,:], label='SampleLen = {}'.format(sampleLen))

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('MRR')
ax.set_xticks(ind)
ax.set_xticklabels([str(dbSize) for dbSize in dbSizes])
ax.set_xlabel('Database Size')
ax.set_ylim(top=1)
ax.legend(title = 'Query Length', loc = 'upper center', framealpha = 1.0, ncol=numSeries, bbox_to_anchor=(0.5, 1.15))
ax.yaxis.grid(True, linestyle='--')
#fig.tight_layout()
#plt.savefig('accuracy.png')
plt.show()

### Evaluate Runtime

In [None]:
def getRuntimeInfo(sampleLens, hyp_dir, savefile = None):
    '''
    Compute average search time per query.

    For each sample length, averages the 'profileDur' value stored in every
    *.hyp file under <hyp_dir>/sample<sampleLen>/.

    sampleLens -- iterable of sample lengths (one sub-directory each)
    hyp_dir    -- root directory containing the sample<N> sub-directories
    savefile   -- optional path; when truthy, {'t_avgs': [...]} is pickled there

    Returns a list with one average per sample length. A sample length whose
    directory contains no .hyp files yields NaN (and a printed warning) instead
    of raising ZeroDivisionError.
    '''
    t_avgs = []
    for sampleLen in sampleLens:
        hypdir = '{}/sample{}'.format(hyp_dir, sampleLen)
        durations = [loadPickle(hypfile)['profileDur']
                     for hypfile in glob.glob('{}/*.hyp'.format(hypdir))]
        if durations:
            t_avgs.append(sum(durations) / len(durations))
        else:
            # nothing to average: keep the slot so indices still line up with sampleLens
            print('Warning: no .hyp files found in {}'.format(hypdir))
            t_avgs.append(float('nan'))

    if savefile:
        with open(savefile, 'wb') as f:
            pickle.dump({'t_avgs': t_avgs}, f)

    return t_avgs

In [None]:
# Average the recorded 'profileDur' per sample length; the result is also pickled to runtime.pkl.
t_avgs = getRuntimeInfo(sampleLens, hyp_dir, 'runtime.pkl')

In [None]:
t_avgs # average search time per query on the full 5k database, one entry per sample query length in sampleLens

In [None]:
def calcMidiFeatureRuntime(midi_feat_dir):
    '''
    Print the average of the 'dur' values stored in the *.pkl files directly
    under midi_feat_dir.

    Prints a notice (instead of raising ZeroDivisionError) when the directory
    contains no *.pkl files. Returns None.
    '''
    # loadPickle opens the file itself; the extra open() around it was redundant
    durations = [loadPickle(pkl_file)['dur']
                 for pkl_file in glob.glob('{}/*.pkl'.format(midi_feat_dir))]
    if not durations:
        print('No MIDI feature files (*.pkl) found in {}'.format(midi_feat_dir))
        return
    avgTime = sum(durations) / len(durations)
    print('Average time to compute MIDI bootleg score: {:.2f} sec'.format(avgTime))

In [None]:
# Report the mean 'dur' recorded across the *.pkl files in the midi_feat directory.
calcMidiFeatureRuntime('midi_feat')