In [1]:
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os.path
import pretty_midi

### Generate ground truth labels

In [3]:
def importScoreInfo(scoreDir):
    """Load the line-to-measure layout of every score found in ``scoreDir``.

    Each ``p*.scoreinfo.csv`` file has one header row followed by rows of
    ``line,start_measure,end_measure``. The piece name is the file-name prefix
    before the first '.' (e.g. 'p7').

    Returns:
        ``{piece: {line: (start_measure, end_measure)}}`` with ints throughout.
    """
    scores = {}
    pattern = "{}/p*.scoreinfo.csv".format(scoreDir)
    for path in glob.glob(pattern):
        piece = os.path.basename(path).split('.')[0]  # e.g. 'p7'
        lineTable = {}
        with open(path, 'r') as fh:
            next(fh)  # the first row is a header
            for row in fh:
                cols = row.rstrip().split(',')
                lineNum, firstMeasure, lastMeasure = int(cols[0]), int(cols[1]), int(cols[2])
                lineTable[lineNum] = (firstMeasure, lastMeasure)
        scores[piece] = lineTable
    return scores

In [4]:
# Per-piece score layout: {piece: {line number: (start measure, end measure)}}
scoreDir = 'data/score_info'
scoreInfo = importScoreInfo(scoreDir)

In [5]:
def importMidiInfo(midiInfoDir, midiDir, fs=100):
    """Load per-measure timings for every piece and append an end-of-piece entry.

    Each ``p*_midinfo.csv`` file in ``midiInfoDir`` has no header and contains
    rows of ``measure,time``. The piece name is the file-name prefix before the
    first '_' (e.g. 'p7').

    After the rows are read, one extra entry is added at
    ``largest_measure + 1``. It holds the total duration of
    ``<midiDir>/<piece>.mid``, so the end of the final measure can be looked up
    the same way as the start of any following measure. Blank lines are
    ignored.

    Args:
        midiInfoDir: directory containing the ``p*_midinfo.csv`` files.
        midiDir: directory containing the matching ``<piece>.mid`` files.
        fs: piano-roll frame rate used to measure the MIDI duration, which is
            ``num_frames / fs``. 100 is pretty_midi's default rate.

    Returns:
        ``{piece: {measure: time}}``.

    Raises:
        ValueError: if a midinfo file contains no measure rows.
    """
    d = {}
    for csvfile in glob.glob("{}/p*_midinfo.csv".format(midiInfoDir)):
        pieceStr = os.path.basename(csvfile).split('_')[0]  # e.g. 'p7'
        measureTimes = {}
        with open(csvfile, 'r') as f:
            for row in f:
                row = row.strip()
                if not row:  # tolerate blank lines instead of failing on int('')
                    continue
                parts = row.split(',')
                measureTimes[int(parts[0])] = float(parts[1])

        # An empty table has no last measure to extend. Fail loudly rather than
        # borrowing a measure number from a different file.
        if not measureTimes:
            raise ValueError("no measure timings found in {}".format(csvfile))

        # Add an entry for the total duration, keyed off the largest measure
        # number rather than whichever row happened to come last in the file.
        midfile = "{}/{}.mid".format(midiDir, pieceStr)
        mid = pretty_midi.PrettyMIDI(midfile)
        totalDur = mid.get_piano_roll(fs=fs).shape[1] / fs
        measureTimes[max(measureTimes) + 1] = totalDur

        d[pieceStr] = measureTimes
    return d

In [6]:
# Per-piece measure timing: {piece: {measure: time}}, plus one extra entry after
# the last measure holding the MIDI file's total duration
midiInfoDir = 'data/midi_info'
midiDir = 'data/midi'
midiInfo = importMidiInfo(midiInfoDir, midiDir)



In [7]:
def getQueryGroundTruth(infile, multMatchFile, scoreInfo, midiInfo):
    """Infer ground-truth time segments for every query.

    After one header row, each row of ``infile`` has the form
    ``query,start_line,end_line`` (e.g. 'p1_q1,0,3'). The line span is mapped
    to a measure span through ``scoreInfo`` and then to a time span through
    ``midiInfo``. The segment ends at the time stored for the measure after
    the last one, i.e. the downbeat of the next measure. Extra matches listed
    in ``multMatchFile`` are then appended by ``addMultipleMatches``.

    Returns:
        ``{query: [(start_time, end_time, start_measure, end_measure,
        start_line, end_line), ...]}``, with the primary match first.
    """
    truth = {}
    with open(infile, 'r') as fin:
        next(fin)  # header row
        for row in fin:
            fields = row.rstrip().split(',')  # e.g. 'p1_q1,0,3'
            query = fields[0]
            firstLine, lastLine = int(fields[1]), int(fields[2])

            # line span -> measure span
            piece = query.split('_')[0]
            layout = scoreInfo[piece]
            firstMeasure = layout[firstLine][0]
            lastMeasure = layout[lastLine][1]

            # measure span -> time span
            timing = midiInfo[piece]
            segment = (
                timing[firstMeasure],
                timing[lastMeasure + 1],  # ends on downbeat of next measure
                firstMeasure,
                lastMeasure,
                firstLine,
                lastLine,
            )
            truth[query] = [segment]

    addMultipleMatches(truth, multMatchFile, scoreInfo, midiInfo)
    return truth

In [8]:
def _parseLineMeasure(token):
    """Parse a score position such as 'L3m6' into (line number, measure offset) = (3, 6)."""
    lineTok, sep, offsetTok = token.partition('m')
    if not sep:
        raise ValueError("malformed score position {!r}; expected e.g. 'L3m6'".format(token))
    return int(lineTok[1:]), int(offsetTok)  # lineTok[0] is the 'L' prefix


def addMultipleMatches(d, multMatchFile, scoreInfo, midiInfo):
    """Append alternate ground-truth segments for queries that match more than once.

    Some queries match more than one segment of the score. Each row of
    ``multMatchFile`` (no header) lists one extra match as
    ``query,start_pos,end_pos``, e.g. 'p31_q8,L3m6,L5m1'.

    A position 'L<line>m<k>' means the k-th measure (1-based) of score line
    <line>. The match is converted to measures via ``scoreInfo`` and to times
    via ``midiInfo``, then appended to ``d[query]``. Blank lines are skipped,
    and whitespace around fields is ignored.

    Args:
        d: ground-truth dict being built by ``getQueryGroundTruth``; it is
            modified in place.
        multMatchFile: path of the multiple-match listing.
        scoreInfo: ``{piece: {line: (start_measure, end_measure)}}``.
        midiInfo: ``{piece: {measure: time}}``.

    Returns:
        ``d`` (the same object).

    Raises:
        KeyError: if a listed query has no primary entry in ``d``, or a line
            or measure is missing from ``scoreInfo``/``midiInfo``.
        ValueError: on a malformed position token.
    """
    with open(multMatchFile, 'r') as f:
        for line in f:
            if not line.strip():  # e.g. a trailing empty line
                continue

            # parse line, e.g. 'p31_q8,L3m6,L5m1'
            parts = [p.strip() for p in line.split(',')]
            queryStr, startStr, endStr = parts[0], parts[1], parts[2]
            pieceStr = queryStr.split('_')[0]

            # infer start, end measure (offsets are 1-based within a line)
            startLine, startOffset = _parseLineMeasure(startStr)
            endLine, endOffset = _parseLineMeasure(endStr)
            startMeasure = scoreInfo[pieceStr][startLine][0] + startOffset - 1
            endMeasure = scoreInfo[pieceStr][endLine][0] + endOffset - 1

            # infer start, end time
            startTime = midiInfo[pieceStr][startMeasure]
            endTime = midiInfo[pieceStr][endMeasure + 1]  # ends on downbeat of next measure

            if queryStr not in d:
                raise KeyError("{} listed in {} has no primary ground-truth entry".format(queryStr, multMatchFile))
            # keep the raw position strings; they are more informative than bare line numbers
            d[queryStr].append((startTime, endTime, startMeasure, endMeasure, startStr, endStr))

    return d

In [9]:
def saveQueryInfoToFile(d, outfile):
    """Write every ground-truth segment in ``d`` to ``outfile`` as CSV rows.

    Queries are written in sorted order, one row per segment, with no header.
    The columns are: query, start time, end time (both to 2 decimals), start
    measure, end measure, start line, end line.
    """
    template = '{},{:.2f},{:.2f},{},{},{},{}\n'
    with open(outfile, 'w') as fout:
        fout.writelines(
            template.format(query, tstart, tend, mstart, mend, lstart, lend)
            for query in sorted(d)
            for (tstart, tend, mstart, mend, lstart, lend) in d[query]
        )

In [10]:
# Build ground-truth segments for every query (primary match plus any extra
# matches from the .multmatches file) and write them to queryGTFile
queryInfoFile = 'data/query_info/query_info.csv' # to read
multMatchesFile = 'data/query_info/query.multmatches' # to read
queryGTFile = 'data/query_info/query.gt' # to write
queryInfo = getQueryGroundTruth(queryInfoFile, multMatchesFile, scoreInfo, midiInfo)
saveQueryInfoToFile(queryInfo, queryGTFile)

### Evaluate system performance

In [11]:
def readGroundTruthLabels(gtfile):
    """Read the reference time segments from a ground-truth CSV file.

    Only the first three columns are used: query, start time and end time
    (e.g. 'p1_q1,1.55,32.59,...').

    Returns:
        ``{query: [(tstart, tend), ...]}``, with segments kept in file order.
    """
    refs = {}
    with open(gtfile, 'r') as fh:
        for row in fh:
            cols = row.rstrip().split(',')
            segment = (float(cols[1]), float(cols[2]))
            refs.setdefault(cols[0], []).append(segment)
    return refs

In [12]:
def readHypothesisFiles(hypdir):
    """Collect the predicted segment from every ``*.hyp`` file in ``hypdir``.

    Files are visited in sorted path order, and only the first line of each is
    read. That line has the form 'query,tstart,tend', where query is e.g.
    'p1_q1'.

    Returns:
        A list of ``(query, tstart, tend)`` tuples, in sorted file order.
    """
    hypotheses = []
    for path in sorted(glob.glob("{}/*.hyp".format(hypdir))):
        with open(path, 'r') as fh:
            first = next(fh)
        fields = first.rstrip().split(',')
        hypotheses.append((fields[0], float(fields[1]), float(fields[2])))
    return hypotheses

In [13]:
def _safeRatio(num, den):
    """Return num / den, or 0.0 when den is zero (no hypotheses, or no overlap at all)."""
    return num / den if den else 0.0


def calcPrecisionRecall(hypdir, gtfile):
    """Score hypothesis segments against the ground truth by temporal overlap.

    For each hypothesis from ``readHypothesisFiles(hypdir)``, the reference
    segment of the same query from ``readGroundTruthLabels(gtfile)`` with the
    largest overlap is chosen. Ties keep the earliest segment, and if nothing
    overlaps, index 0 is used. Totals are accumulated over all hypotheses:

        P = sum(overlap) / sum(hypothesis length)
        R = sum(overlap) / sum(chosen reference length)
        F = 2 * P * R / (P + R)

    Any ratio with a zero denominator is reported as 0.0. Without this guard,
    an empty hypothesis list or zero total overlap would raise
    ZeroDivisionError or yield NaN.

    Returns:
        ``(F, P, R, hypinfo)``. ``hypinfo`` holds one
        ``(query, overlap, ref_start, ref_end, ref_index)`` tuple per
        hypothesis, in hypothesis order, for error analysis.
    """
    d = readGroundTruthLabels(gtfile)
    hyps = readHypothesisFiles(hypdir)
    hypinfo = []
    overlapTotal, hypTotal, refTotal = (0, 0, 0)
    for (queryid, hypstart, hypend) in hyps:
        refSegments = d[queryid]

        # Find the ref segment with the most overlap. The strict '>' keeps the
        # earliest segment on ties.
        idxMax = 0
        overlapMax = 0
        for i, refSegment in enumerate(refSegments):
            overlap = calcOverlap((hypstart, hypend), refSegment)
            if overlap > overlapMax:
                idxMax = i
                overlapMax = overlap

        refStart, refEnd = refSegments[idxMax]
        overlapTotal += overlapMax
        hypTotal += hypend - hypstart
        refTotal += refEnd - refStart
        hypinfo.append((queryid, overlapMax, refStart, refEnd, idxMax))  # keep for error analysis

    P = _safeRatio(overlapTotal, hypTotal)
    R = _safeRatio(overlapTotal, refTotal)
    F = _safeRatio(2 * P * R, P + R)
    return F, P, R, hypinfo

In [14]:
def calcOverlap(seg1, seg2):
    """Return the length of the intersection of two (start, end) segments.

    Disjoint segments give 0. The result is a numpy scalar.
    """
    lo = max(seg1[0], seg2[0])
    hi = min(seg1[1], seg2[1])
    return np.maximum(hi - lo, 0)

In [18]:
# Score the hypotheses in hypdir against the ground-truth file written by the
# "Generate ground truth labels" section above
hypdir = 'experiments/exp2/hyp'
F, P, R, hypinfo = calcPrecisionRecall(hypdir, queryGTFile)

In [20]:
# F-measure, precision, recall, and the number of hypotheses scored
F, P, R, len(hypinfo)

(0.8581609057904117, 0.8894500066157977, 0.8289983744189549, 900)

### Investigate Errors

In [21]:
def printDebuggingInfo(hypdir, gtfile, scoreInfo, midiInfo, queryInfo, hypInfo, showMeasures=False):
    """Print a per-query comparison of each hypothesis with its best-matching reference.

    Hypotheses come from ``readHypothesisFiles(hypdir)`` and are paired by
    position with ``hypInfo``, as returned by ``calcPrecisionRecall`` for the
    same ``hypdir``. For each one, this prints:
      - the hypothesis and reference spans in seconds, with their overlap;
      - the hypothesis span as (score line, fractional measure within the
        line), next to the reference's start/end line labels from
        ``queryInfo``.

    Times are converted to fractional measure numbers by linear interpolation
    over ``midiInfo[piece]``.

    Args:
        hypdir: directory of ``*.hyp`` files.
        gtfile: ground-truth file. Not read here, because the reference spans
            come from ``hypInfo``; kept for interface compatibility.
        scoreInfo: ``{piece: {line: (start_measure, end_measure)}}``.
        midiInfo: ``{piece: {measure: time}}``.
        queryInfo: ground-truth dict from ``getQueryGroundTruth``.
        hypInfo: per-hypothesis tuples from ``calcPrecisionRecall``.
        showMeasures: if True, also print the spans and overlap in measure numbers.
    """
    hyps = readHypothesisFiles(hypdir)
    for i, (query, hyp_tstart, hyp_tend) in enumerate(hyps):

        # hyp and ref info (sec)
        piece = query.split('_')[0]
        _, overlap, ref_tstart, ref_tend, bestIdx = hypInfo[i]

        # hyp and ref info (measures). np.interp requires increasing sample
        # points, so order by measure number instead of relying on dict
        # insertion order. This assumes measure times increase with measure
        # number -- TODO confirm.
        interp_m = sorted(midiInfo[piece])
        interp_t = [midiInfo[piece][m] for m in interp_m]
        hyp_mstart, hyp_mend, ref_mstart, ref_mend = np.interp(
            [hyp_tstart, hyp_tend, ref_tstart, ref_tend], interp_t, interp_m)

        # hyp and ref info (line # + measure offset)
        hyp_lstart, hyp_lstartoff = getLineNumberMeasureOffset(hyp_mstart, scoreInfo[piece])
        hyp_lend, hyp_lendoff = getLineNumberMeasureOffset(hyp_mend, scoreInfo[piece])
        ref_lstart = queryInfo[query][bestIdx][4]
        ref_lend = queryInfo[query][bestIdx][5]

        # compare in sec
        print("{}: hyp ({:.1f} s,{:.1f} s), ref ({:.1f} s,{:.1f} s), overlap {:.1f} of {:.1f} s".format(query, hyp_tstart, hyp_tend, ref_tstart, ref_tend, overlap, ref_tend - ref_tstart))

        # compare in measure numbers (opt-in)
        if showMeasures:
            moverlap = calcOverlap((hyp_mstart, hyp_mend), (ref_mstart, ref_mend))
            print("\thyp ({:.1f} m, {:.1f} m), ref ({:.1f} m, {:.1f} m), overlap {:.1f} m".format(hyp_mstart, hyp_mend, ref_mstart, ref_mend, moverlap))

        # compare in line + measure offset
        print("\thyp (ln {} m{:.1f}, ln {} m{:.1f}), ref (ln {}, ln {})".format(hyp_lstart, hyp_lstartoff, hyp_lend, hyp_lendoff, ref_lstart, ref_lend))

In [22]:
def getLineNumberMeasureOffset(measureNum, d):
    """Locate a (possibly fractional) measure number on the score.

    ``d`` maps line number -> (start measure, end measure). The lines are
    scanned in ``d``'s iteration order, and the first one whose span
    [start, end + 1) contains ``measureNum`` is used.

    Returns:
        ``(line, offset)``. ``offset`` is 1-based, so it equals 1 when
        ``measureNum`` is the line's start measure. Returns ``(-1, -1)`` if no
        line contains ``measureNum``.
    """
    for lineNum, (first, last) in d.items():
        if first <= measureNum < last + 1:
            return lineNum, measureNum - first + 1
    return -1, -1

In [23]:
# Per-query comparison of each hypothesis with its best-matching reference segment
printDebuggingInfo(hypdir, queryGTFile, scoreInfo, midiInfo, queryInfo, hypinfo)

p100_q1: hyp (192.0 s,226.4 s), ref (193.3 s,227.4 s), overlap 33.2 of 34.1 s
	hyp (ln 18 m3.7, ln 21 m3.7), ref (ln 19, ln 21)
p100_q10: hyp (170.5 s,204.3 s), ref (170.5 s,204.6 s), overlap 33.8 of 34.1 s
	hyp (ln 17 m1.0, ln 19 m3.9), ref (ln 17, ln 19)
p100_q2: hyp (226.9 s,261.0 s), ref (227.4 s,261.5 s), overlap 33.6 of 34.1 s
	hyp (ln 21 m3.9, ln 24 m3.9), ref (ln 22, ln 24)
p100_q3: hyp (261.4 s,283.3 s), ref (261.5 s,291.6 s), overlap 21.8 of 30.1 s
	hyp (ln 24 m4.0, ln 27 m1.7), ref (ln 25, ln 27)
p100_q4: hyp (10.4 s,45.5 s), ref (11.4 s,45.5 s), overlap 34.1 of 34.1 s
	hyp (ln 1 m3.7, ln 4 m4.0), ref (ln 2, ln 4)
p100_q5: hyp (79.3 s,102.0 s), ref (79.6 s,102.3 s), overlap 22.4 of 22.7 s
	hyp (ln 7 m3.9, ln 9 m3.9), ref (ln 8, ln 9)
p100_q6: hyp (0.0 s,33.6 s), ref (0.0 s,34.1 s), overlap 33.6 of 34.1 s
	hyp (ln 1 m1.0, ln 3 m3.9), ref (ln 1, ln 3)
p100_q7: hyp (55.9 s,90.6 s), ref (56.8 s,91.0 s), overlap 33.8 of 34.1 s
	hyp (ln 5 m3.7, ln 8 m3.9), ref (ln 6, ln 8)
p100_q8

p75_q1: hyp (63.9 s,93.1 s), ref (51.3 s,93.5 s), overlap 29.3 of 42.3 s
	hyp (ln 5 m3.9, ln 7 m3.9), ref (ln 5, ln 7)
p75_q10: hyp (105.7 s,134.2 s), ref (109.2 s,134.6 s), overlap 25.1 of 25.4 s
	hyp (ln 8 m3.4, ln 10 m3.9), ref (ln 9, ln 10)
p75_q2: hyp (128.6 s,158.4 s), ref (121.8 s,165.9 s), overlap 29.8 of 44.0 s
	hyp (ln 10 m2.6, ln 12 m3.0), ref (ln 10, ln 12)
p75_q3: hyp (134.2 s,157.5 s), ref (134.6 s,165.9 s), overlap 22.9 of 31.3 s
	hyp (ln 10 m3.9, ln 12 m2.8), ref (ln 11, ln 12)
p75_q4: hyp (0.0 s,25.4 s), ref (0.0 s,25.9 s), overlap 25.4 of 25.9 s
	hyp (ln 1 m1.0, ln 2 m3.9), ref (ln 1, ln 2)
p75_q5: hyp (70.3 s,110.2 s), ref (79.1 s,121.8 s), overlap 31.1 of 42.7 s
	hyp (ln 6 m2.2, ln 9 m1.3), ref (ln 7, ln 9)
p75_q6: hyp (0.0 s,38.5 s), ref (0.0 s,38.5 s), overlap 38.5 of 38.5 s
	hyp (ln 1 m1.0, ln 3 m4.0), ref (ln 1, ln 3)
p75_q7: hyp (12.7 s,50.9 s), ref (12.7 s,51.3 s), overlap 38.2 of 38.5 s
	hyp (ln 2 m1.0, ln 4 m3.9), ref (ln 2, ln 4)
p75_q8: hyp (50.9 s,78.7 s)