In [1]:
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os.path
import pretty_midi
import pickle

### Generate ground truth labels

In [3]:
def importScoreInfo(scoreDir):
    d = {}
    for csvfile in glob.glob("{}/p*.scoreinfo.csv".format(scoreDir)):
        pieceStr = os.path.basename(csvfile).split('.')[0]  # e.g. 'p7'
        d[pieceStr] = {}
        with open(csvfile, 'r') as f:
            next(f) # skip header
            for line in f:
                parts = line.rstrip().split(',')
                linenum = int(parts[0])
                startmeasure = int(parts[1])
                endmeasure = int(parts[2])
                d[pieceStr][linenum] = (startmeasure, endmeasure)
    return d

In [4]:
scoreDir = 'data/score_info'
scoreInfo = importScoreInfo(scoreDir)

In [5]:
def importMidiInfo(midiInfoDir, midiDir):
    d = {}
    for csvfile in glob.glob("{}/p*_midinfo.csv".format(midiInfoDir)):
        pieceStr = os.path.basename(csvfile).split('_')[0]  # e.g. 'p7'
        d[pieceStr] = {}
        with open(csvfile, 'r') as f:
            for line in f:
                parts = line.rstrip().split(',')
                measure = int(parts[0])
                time = float(parts[1])
                d[pieceStr][measure] = time
        
        # add an additional entry to indicate the total duration
        midfile = "{}/{}.mid".format(midiDir, pieceStr)
        mid = pretty_midi.PrettyMIDI(midfile)
        totalDur = mid.get_piano_roll().shape[1] * .01 # default fs = 100
        d[pieceStr][measure+1] = totalDur
                
    return d

In [6]:
midiInfoDir = 'data/midi_info'
midiDir = 'data/midi'
midiInfo = importMidiInfo(midiInfoDir, midiDir)



In [7]:
def getQueryGroundTruth(infile, multMatchFile, scoreInfo, midiInfo):
    # infers ground truth timestamps for each query
    d = {}
    with open(infile, 'r') as fin: 
        next(fin) # skip header
        for line in fin:

            # get start, end lines
            parts = line.rstrip().split(',')  # e.g. 'p1_q1,0,3'
            queryStr = parts[0]
            startLine = int(parts[1])
            endLine = int(parts[2])

            # infer start, end measure
            pieceStr = queryStr.split('_')[0]
            #print("%s,%s,%s" % (queryStr, startLine,endLine))            
            startMeasure = scoreInfo[pieceStr][startLine][0]
            endMeasure = scoreInfo[pieceStr][endLine][1]

            # infer start, end time
            #print("%s,%s,%s" % (queryStr, startMeasure, endMeasure))
            startTime = midiInfo[pieceStr][startMeasure]
            endTime = midiInfo[pieceStr][endMeasure+1] # ends on downbeat of next measure

            d[queryStr] = [(startTime, endTime, startMeasure, endMeasure, startLine, endLine)]

    addMultipleMatches(d, multMatchFile, scoreInfo, midiInfo)
            
    return d                

In [8]:
def addMultipleMatches(d, multMatchFile, scoreInfo, midiInfo):
    # some queries match more than 1 segment of the score, these are indicated in multMatchFile
    with open(multMatchFile, 'r') as f:
        for line in f:
            
            # parse line 
            parts = line.rstrip().split(',')  # e.g. 'p31_q8,L3m6,L5m1'
            queryStr = parts[0]
            pieceStr = queryStr.split('_')[0]
            startStr = parts[1]
            endStr = parts[2]
            
            # infer start, end measure
            startLine = int(startStr.split('m')[0][1:])
            endLine = int(endStr.split('m')[0][1:])
            startOffset = int(startStr.split('m')[1])
            endOffset = int(endStr.split('m')[1])
            startMeasure = scoreInfo[pieceStr][startLine][0] + startOffset - 1
            endMeasure = scoreInfo[pieceStr][endLine][0] + endOffset - 1
            
            # infer start, end time
            startTime = midiInfo[pieceStr][startMeasure]
            endTime = midiInfo[pieceStr][endMeasure+1] # ends on downbeat of next measure
            
            tup = (startTime, endTime, startMeasure, endMeasure, startStr, endStr) # startStr more informative than startLine
            d[queryStr].append(tup)
            
    return d

In [9]:
def saveQueryInfoToFile(d, outfile):
    with open(outfile, 'w') as f:
        for query in sorted(d):
            for (tstart, tend, mstart, mend, lstart, lend) in d[query]:
                f.write('{},{:.2f},{:.2f},{},{},{},{}\n'.format(query, tstart, tend, mstart, mend, lstart, lend))

In [10]:
queryInfoFile = 'data/query_info/query_info.csv' # to read
multMatchesFile = 'data/query_info/query.multmatches' # to read
queryGTFile = 'data/query_info/query.gt' # to write
queryInfo = getQueryGroundTruth(queryInfoFile, multMatchesFile, scoreInfo, midiInfo)
saveQueryInfoToFile(queryInfo, queryGTFile)

### Evaluate system performance

In [11]:
def readGroundTruthLabels(gtfile):
    d = {}
    with open(gtfile, 'r') as f:
        for line in f:
            parts = line.rstrip().split(',') # e.g. 'p1_q1,1.55,32.59'
            queryStr = parts[0]
            tstart = float(parts[1])
            tend = float(parts[2])
            if queryStr in d:
                d[queryStr].append((tstart, tend))
            else:
                d[queryStr] = [(tstart, tend)]
    return d

In [29]:
def loadHypothesisFile(hypfile):
    with open(hypfile, 'rb') as f:
        d = pickle.load(f)
    return d

In [38]:
def calcPrecisionRecall(hypdir, gtfile):
    d = readGroundTruthLabels(gtfile)
    hypfiles = sorted(glob.glob("{}/*.hyp".format(hypdir)))
    hypinfo = [] 
    overlapTotal, hypTotal, refTotal = (0,0,0)
    for hypfile in hypfiles:
        
        # read hypothesis data
        hypdata = loadHypothesisFile(hypfile)
        queryid = hypdata['query'] # e.g. p1_q7
        if hypdata['results'] == (0,0): # no noteheads detected
            hyp_pieceid, hyp_score, hyp_tstart, hyp_tend = None, 0, 0, 0 
        else:
            hyp_pieceid, hyp_score, hyp_tstart, hyp_tend = hypdata['results'][0] # top result
            
        # ground truth
        ref_pieceid = queryid.split('_')[0]
        refSegments = d[queryid]        
        
        # calculate overlap
        idxMax = 0
        overlapMax = 0
        if hyp_pieceid == ref_pieceid: # picks correct piece
            for i, refSegment in enumerate(refSegments): # find ref segment with most overlap
                overlap = calcOverlap((hyp_tstart, hyp_tend), refSegment)
                if overlap > overlapMax:
                    idxMax = i
                    overlapMax = overlap
                    
        # accumulate stats
        hyplen = hyp_tend - hyp_tstart
        reflen = refSegments[idxMax][1] - refSegments[idxMax][0]        
        overlapTotal += overlapMax
        hypTotal += hyplen
        refTotal += reflen
        hypinfo.append((queryid, overlapMax, refSegments[idxMax][0], refSegments[idxMax][1], idxMax)) # keep for error analysis                    
        
    P = overlapTotal / hypTotal
    R = overlapTotal / refTotal
    F = 2 * P * R / (P + R)
    return F, P, R, hypinfo, hypfiles

In [39]:
def calcOverlap(seg1, seg2):
    overlap_lb = max(seg1[0], seg2[0])
    overlap_ub = min(seg1[1], seg2[1])
    overlap = np.clip(overlap_ub - overlap_lb, 0, None)
    return overlap    

In [47]:
def calcRankStats(hypdir, N = 100):
    hypfiles = sorted(glob.glob("{}/*.hyp".format(hypdir)))
    ranks = [] # rank of ground truth piece
    for hypfile in hypfiles:
        
        # rank of ground truth
        hypdata = loadHypothesisFile(hypfile)
        queryid = hypdata['query'] # e.g. p1_q7
        ref_pieceid = queryid.split('_')[0]
        if hypdata['results'] == (0,0): # no noteheads detected
            ranks.append(np.inf)
        else:
            for i, tup in enumerate(hypdata['results']):
                hyp_pieceid = tup[0]
                if hyp_pieceid == ref_pieceid:
                    ranks.append(i+1)
                    break
        
    # compute MRR
    mrr = np.mean([1.0/r for r in ranks])
    
    # compute topN
    topn = np.zeros(N)
    for n in range(N):
        topn[n] = np.mean(np.array(ranks) <= n+1)
            
    return mrr, topn, ranks

In [60]:
hypdir = 'experiments/search1/hyp'
F, P, R, hypinfo, hypfiles = calcPrecisionRecall(hypdir, queryGTFile)
mrr, topn, ranks = calcRankStats(hypdir)

In [61]:
F, P, R, len(hypinfo)

(0.7839038672753129, 0.8387211076466727, 0.7358125382173216, 900)

In [62]:
mrr, topn[0:5]

(0.8370295552326883,
 array([0.81222222, 0.83666667, 0.85111111, 0.85555556, 0.86222222]))

### Investigate Errors

In [65]:
def printDebuggingInfo(hypfiles, gtfile, scoreInfo, midiInfo, queryInfo, hypInfo, rankInfo):
    d = readGroundTruthLabels(gtfile)
    for i, hypfile in enumerate(hypfiles): # (query, hyp_tstart, hyp_tend)
        
        # hyp and ref info (sec)
        hypdata = loadHypothesisFile(hypfile)
        query = hypdata['query']
        if hypdata['results'] == (0,0):
            hyp_pieceid, hyp_tstart, hyp_tend = None, 0, 0
        else:
            hyp_pieceid, _, hyp_tstart, hyp_tend = hypdata['results'][0]
        ref_pieceid = query.split('_')[0]
        _, overlap, ref_tstart, ref_tend, bestIdx = hypInfo[i]
        
        # hyp and ref info (measures)
        interp_m = list(midiInfo[hyp_pieceid].keys())
        interp_t = [midiInfo[hyp_pieceid][m] for m in interp_m]
        hyp_mstart, hyp_mend = np.interp([hyp_tstart, hyp_tend], interp_t, interp_m)
        interp_m = list(midiInfo[ref_pieceid].keys())
        interp_t = [midiInfo[ref_pieceid][m] for m in interp_m]
        ref_mstart, ref_mend = np.interp([ref_tstart, ref_tend], interp_t, interp_m)
        if hyp_pieceid == ref_pieceid:
            moverlap = calcOverlap((hyp_mstart, hyp_mend),(ref_mstart, ref_mend))
        else:
            moverlap = 0
        
        # hyp and ref info (line # + measure offset)
        hyp_lstart, hyp_lstartoff = getLineNumberMeasureOffset(hyp_mstart, scoreInfo[hyp_pieceid])
        hyp_lend, hyp_lendoff = getLineNumberMeasureOffset(hyp_mend, scoreInfo[hyp_pieceid])
        ref_lstart = queryInfo[query][bestIdx][4]
        ref_lend = queryInfo[query][bestIdx][5]
        
        # compare in sec
        print("{}: hyp {} ({:.1f} s,{:.1f} s), ref {} ({:.1f} s,{:.1f} s), rank {}, overlap {:.1f} of {:.1f} s".format(query, hyp_pieceid, hyp_tstart, hyp_tend, ref_pieceid, ref_tstart, ref_tend, rankInfo[i], overlap, ref_tend - ref_tstart))
        
        # compare in measure numbers
        #print("\thyp {} ({:.1f} m, {:.1f} m), ref {} ({:.1f} m, {:.1f} m), overlap {:.1f} m".format(hyp_pieceid, hyp_mstart, hyp_mend, ref_pieceid, ref_mstart, ref_mend, moverlap))
        
        # compare in line + measure offset
        print("\thyp {} (ln {} m{:.1f}, ln {} m{:.1f}), ref {} (ln {}, ln {})".format(hyp_pieceid, hyp_lstart, hyp_lstartoff, hyp_lend, hyp_lendoff, ref_pieceid, ref_lstart, ref_lend))
    return

In [66]:
def getLineNumberMeasureOffset(measureNum, d):
    line = -1
    moffset = -1
    for key in d:
        lb, ub = d[key] # line start, end measure 
        if measureNum >= lb and measureNum < ub + 1:
            line = key
            moffset = measureNum - lb + 1
            break
    return line, moffset

In [67]:
printDebuggingInfo(hypfiles, queryGTFile, scoreInfo, midiInfo, queryInfo, hypinfo, ranks)

p100_q1: hyp p100 (192.0 s,227.1 s), ref p100 (193.3 s,227.4 s), rank 1, overlap 33.8 of 34.1 s
	hyp p100 (ln 18 m3.7, ln 21 m3.9), ref p100 (ln 19, ln 21)
p100_q10: hyp p100 (170.5 s,204.3 s), ref p100 (170.5 s,204.6 s), rank 1, overlap 33.8 of 34.1 s
	hyp p100 (ln 17 m1.0, ln 19 m3.9), ref p100 (ln 17, ln 19)
p100_q2: hyp p100 (226.9 s,259.4 s), ref p100 (227.4 s,261.5 s), rank 1, overlap 32.1 of 34.1 s
	hyp p100 (ln 21 m3.9, ln 24 m3.5), ref p100 (ln 22, ln 24)
p100_q3: hyp p100 (261.8 s,283.1 s), ref p100 (261.5 s,291.6 s), rank 1, overlap 21.3 of 30.1 s
	hyp p100 (ln 25 m1.1, ln 27 m1.7), ref p100 (ln 25, ln 27)
p100_q4: hyp p100 (10.4 s,45.5 s), ref p100 (11.4 s,45.5 s), rank 1, overlap 34.1 of 34.1 s
	hyp p100 (ln 1 m3.8, ln 5 m1.0), ref p100 (ln 2, ln 4)
p100_q5: hyp p100 (79.3 s,102.0 s), ref p100 (79.6 s,102.3 s), rank 1, overlap 22.4 of 22.7 s
	hyp p100 (ln 7 m3.9, ln 9 m3.9), ref p100 (ln 8, ln 9)
p100_q6: hyp p100 (0.0 s,33.6 s), ref p100 (0.0 s,34.1 s), rank 1, overlap 33

KeyError: None