In [1]:
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os.path
import pretty_midi

# to remove later
import re

### Generate ground truth labels

In [3]:
def importScoreInfo(scoreDir):
    """Load the sheet-music line layout of every piece found in scoreDir.

    Scans scoreDir for files matching 'p*1.scoreinfo.csv'. In each file the
    first line is treated as a header and skipped; every following row is
    'linenum,startmeasure,endmeasure' (all parsed as int).

    Blank rows (e.g. a trailing empty line) are ignored, and a completely
    empty file yields an empty entry for that piece instead of raising.

    Returns:
        dict mapping the piece id (file basename up to the first '.',
        e.g. 'p7') to a dict {linenum: (startmeasure, endmeasure)}.
    """
    d = {}
    for csvfile in glob.glob("{}/p*1.scoreinfo.csv".format(scoreDir)): ### TO DO: REMOVE 1
        pieceStr = os.path.basename(csvfile).split('.')[0]  # e.g. 'p7'
        lineSpans = {}
        with open(csvfile, 'r') as f:
            next(f, None) # skip header (default avoids StopIteration on an empty file)
            for line in f:
                row = line.rstrip()
                if not row:
                    continue # tolerate blank / trailing empty lines
                parts = row.split(',')
                lineSpans[int(parts[0])] = (int(parts[1]), int(parts[2]))
        d[pieceStr] = lineSpans
    return d

In [4]:
# Directory that importScoreInfo searches for '*.scoreinfo.csv' files.
scoreDir = 'score_info'
# piece id -> {line number -> (start measure, end measure)}
scoreInfo = importScoreInfo(scoreDir)

In [5]:
def importMidiInfo(midiInfoDir, midiDir):
    """Load the measure -> time table of every piece found in midiInfoDir.

    Scans midiInfoDir for files matching 'p*1_midinfo.csv'. Every non-blank
    row is 'measure,time' (int, float); no header line is skipped. For each
    piece the matching '<midiDir>/<piece>.mid' is opened with pretty_midi and
    one extra entry is appended under key (largest measure + 1) holding the
    total duration, computed as the number of piano-roll columns times 0.01.

    Args:
        midiInfoDir: directory containing the '*_midinfo.csv' files.
        midiDir: directory containing the '<piece>.mid' files.

    Returns:
        dict mapping the piece id (file basename up to the first '_',
        e.g. 'p7') to a dict {measure: time}.

    Raises:
        ValueError: if a midinfo file contains no measure rows, since the
            total-duration entry could not be placed.
    """
    d = {}
    for csvfile in glob.glob("{}/p*1_midinfo.csv".format(midiInfoDir)): ### TO DO: REMOVE 1
        pieceStr = os.path.basename(csvfile).split('_')[0]  # e.g. 'p7'
        measureTimes = {}
        with open(csvfile, 'r') as f:
            for line in f:
                row = line.rstrip()
                if not row:
                    continue # tolerate blank / trailing empty lines
                parts = row.split(',')
                measureTimes[int(parts[0])] = float(parts[1])

        if not measureTimes:
            # Without this check the sentinel key would come from a stale
            # (previous file) or undefined loop variable.
            raise ValueError("no measure entries found in {}".format(csvfile))

        # add an additional entry to indicate the total duration
        midfile = "{}/{}.mid".format(midiDir, pieceStr)
        mid = pretty_midi.PrettyMIDI(midfile)
        totalDur = mid.get_piano_roll().shape[1] * .01 # default fs = 100
        # Key on the largest measure, not the last row read, so unsorted rows
        # cannot cause a real measure's time to be overwritten.
        measureTimes[max(measureTimes) + 1] = totalDur
        d[pieceStr] = measureTimes

    return d

In [6]:
# Directory that importMidiInfo searches for '*_midinfo.csv' files.
midiInfoDir = 'midi_info'
# Directory holding the '<piece>.mid' files used to get each piece's total duration.
midiDir = 'midi'
# piece id -> {measure -> time}, plus one trailing entry for the total duration
midiInfo = importMidiInfo(midiInfoDir, midiDir)



In [7]:
def getQueryGroundTruth(infile, scoreInfo, midiInfo):
    """Derive ground-truth start/end times for every query listed in infile.

    infile has a header line followed by rows 'query,startLine,endLine'
    (e.g. 'p1_q1,0,3'). The piece id is the part of the query id before the
    first '_'. Lines are mapped to measures through scoreInfo and measures to
    times through midiInfo; the end time is looked up at endMeasure + 1.

    Only pieces whose id matches r'^p\\d*1$' are kept (temporary filter).

    Returns:
        dict query id -> (startTime, endTime, startMeasure, endMeasure,
        startLine, endLine).
    """
    groundTruth = {}
    with open(infile, 'r') as fin:
        next(fin) # skip header
        for row in fin:
            # row looks like 'p1_q1,0,3'
            fields = row.rstrip().split(',')
            queryStr = fields[0]
            startLine, endLine = int(fields[1]), int(fields[2])

            pieceStr = queryStr.split('_')[0]
            if re.search(r'^p\d*1$', pieceStr) is None: ### TO DO: REMOVE
                continue

            # lines -> measures
            lineSpans = scoreInfo[pieceStr]
            startMeasure = lineSpans[startLine][0]
            endMeasure = lineSpans[endLine][1]

            # measures -> seconds
            measureTimes = midiInfo[pieceStr]
            startTime = measureTimes[startMeasure]
            endTime = measureTimes[endMeasure + 1] # ends on downbeat of next measure

            groundTruth[queryStr] = (
                startTime, endTime,
                startMeasure, endMeasure,
                startLine, endLine,
            )

    return groundTruth

In [8]:
def saveQueryInfoToFile(d, outfile):
    """Write one CSV row per query to outfile (overwriting it).

    d maps a query id to a 6-tuple (tstart, tend, mstart, mend, lstart, lend).
    Each row is 'query,tstart,tend,mstart,mend,lstart,lend' with the two
    times formatted to two decimals. No header line is written.
    """
    rowTemplate = '{},{:.2f},{:.2f},{},{},{},{}\n'
    with open(outfile, 'w') as out:
        for queryId, record in d.items():
            startSec, endSec, startMeas, endMeas, startLn, endLn = record
            out.write(rowTemplate.format(
                queryId, startSec, endSec, startMeas, endMeas, startLn, endLn))

In [9]:
# Input: rows of 'query,startLine,endLine' (header line skipped by the reader).
queryInfoFile = 'query_info/query_info.csv' # to read
# Output: one 'query,tstart,tend,mstart,mend,lstart,lend' row per query.
queryGTFile = 'query_info/query.gt' # to write
# query id -> (startTime, endTime, startMeasure, endMeasure, startLine, endLine)
queryInfo = getQueryGroundTruth(queryInfoFile, scoreInfo, midiInfo)
saveQueryInfoToFile(queryInfo, queryGTFile)

### Evaluate system performance

In [19]:
def readGroundTruthLabels(gtfile):
    """Read reference segments from a ground-truth file.

    Every line is comma separated; the first three fields are the query id,
    start time and end time (e.g. 'p1_q1,1.55,32.59'). Any further fields
    are ignored.

    Returns:
        dict query id -> (tstart, tend) as floats.
    """
    labels = {}
    with open(gtfile, 'r') as handle:
        for record in handle:
            fields = record.rstrip().split(',') # e.g. 'p1_q1,1.55,32.59'
            labels[fields[0]] = (float(fields[1]), float(fields[2]))
    return labels

In [20]:
def readHypothesisFiles(hypdir):
    """Collect the hypothesis segment stored in each '*.hyp' file of hypdir.

    Files are visited in sorted path order and only the first line of each
    is used; it is parsed as 'query,tstart,tend' (e.g. query 'p1_q1').

    Returns:
        list of (query, tstart, tend) tuples, times as floats.
    """
    hypotheses = []
    pattern = "{}/*.hyp".format(hypdir)
    for path in sorted(glob.glob(pattern)):
        with open(path, 'r') as handle:
            fields = next(handle).rstrip().split(',')
        # fields[0] is the query id, e.g. p1_q1
        hypotheses.append((fields[0], float(fields[1]), float(fields[2])))
    return hypotheses

In [21]:
def calcPrecisionRecall(hypdir, gtfile):
    """Compute duration-weighted precision, recall and F-measure.

    Every hypothesis read from hypdir is compared against the reference
    segment of the same query id in gtfile. Overlap, hypothesis length and
    reference length are summed over all hypotheses, then
        P = total overlap / total hypothesis length
        R = total overlap / total reference length
        F = 2 * P * R / (P + R)
    Any of these whose denominator is zero (no hypotheses, zero-length
    segments, or no overlap at all) is reported as 0.0 instead of raising
    ZeroDivisionError or yielding NaN.

    Returns:
        (F, P, R, hypinfo) where hypinfo is a list of
        (queryid, overlap, hyplen, reflen) kept for error analysis.

    Raises:
        KeyError: if a hypothesis refers to a query missing from gtfile.
    """
    d = readGroundTruthLabels(gtfile)
    hyps = readHypothesisFiles(hypdir)
    hypinfo = []
    overlapTotal, hypTotal, refTotal = (0, 0, 0)
    for (queryid, hypstart, hypend) in hyps:
        refSegment = d[queryid]
        overlap = calcOverlap((hypstart, hypend), refSegment)
        hyplen = hypend - hypstart
        reflen = refSegment[1] - refSegment[0]
        overlapTotal += overlap
        hypTotal += hyplen
        refTotal += reflen
        hypinfo.append((queryid, overlap, hyplen, reflen)) # keep for error analysis
    # Guard each ratio: an empty or degenerate hypothesis set must not crash.
    P = overlapTotal / hypTotal if hypTotal != 0 else 0.0
    R = overlapTotal / refTotal if refTotal != 0 else 0.0
    F = 2 * P * R / (P + R) if (P + R) != 0 else 0.0
    return F, P, R, hypinfo

In [22]:
def calcOverlap(seg1, seg2):
    """Return the length of the intersection of two (start, end) segments.

    Only elements [0] and [1] of each segment are read. Disjoint segments
    give 0.
    """
    latestStart = max(seg1[0], seg2[0])
    earliestEnd = min(seg1[1], seg2[1])
    # a negative span means the segments do not intersect -> floor at zero
    return np.clip(earliestEnd - latestStart, 0, None)

In [33]:
# Directory whose '*.hyp' files hold the system's predicted segments.
hypdir = 'experiments/exp1/hyp/cur'
# F-measure, precision, recall, and per-query (queryid, overlap, hyplen, reflen) tuples
F, P, R, hypinfo = calcPrecisionRecall(hypdir, queryGTFile)

In [34]:
# Display F-measure, precision, recall and the number of hypotheses evaluated.
F, P, R, len(hypinfo)

(0.5845634183414965, 0.521830900976963, 0.6644398372110881, 10)

### Investigate Errors

In [35]:
def printDebuggingInfo(hypdir, gtfile, scoreInfo, midiInfo, queryInfo, showMeasures=False):
    """Print a per-query comparison of hypothesis and reference segments.

    For each hypothesis in hypdir two lines are printed: the hypothesis and
    reference in seconds with their overlap, and the hypothesis expressed as
    (line number, measure offset within the line) next to the reference line
    numbers. Times are converted to fractional measure numbers by linear
    interpolation over the piece's measure -> time table in midiInfo.

    Args:
        hypdir: directory with the '*.hyp' hypothesis files.
        gtfile: ground-truth file read by readGroundTruthLabels.
        scoreInfo: piece id -> {line: (start measure, end measure)}.
        midiInfo: piece id -> {measure: time}.
        queryInfo: query id -> tuple whose elements [4] and [5] are the
            reference start and end line.
        showMeasures: when True, also print the comparison in measure
            numbers (off by default, matching the previous output).
    """
    d = readGroundTruthLabels(gtfile)
    hyps = readHypothesisFiles(hypdir)
    for (query, hyp_tstart, hyp_tend) in hyps:

        # hyp and ref info (sec)
        piece = query.split('_')[0]
        ref_tstart, ref_tend = d[query]
        overlap = calcOverlap((hyp_tstart, hyp_tend), (ref_tstart, ref_tend))

        # hyp and ref info (measures)
        # np.interp requires increasing x-coordinates, so order the
        # (time, measure) anchors by time rather than relying on dict order.
        anchors = sorted((t, m) for m, t in midiInfo[piece].items())
        interp_t = [t for t, _ in anchors]
        interp_m = [m for _, m in anchors]
        hyp_mstart, hyp_mend, ref_mstart, ref_mend = np.interp([hyp_tstart, hyp_tend, ref_tstart, ref_tend], interp_t, interp_m)

        # hyp and ref info (line # + measure offset)
        hyp_lstart, hyp_lstartoff = getLineNumberMeasureOffset(hyp_mstart, scoreInfo[piece])
        hyp_lend, hyp_lendoff = getLineNumberMeasureOffset(hyp_mend, scoreInfo[piece])
        ref_lstart = queryInfo[query][4]
        ref_lend = queryInfo[query][5]

        # compare in sec
        print("{}: hyp ({:.1f} s,{:.1f} s), ref ({:.1f} s,{:.1f} s), overlap {:.1f} of {:.1f} s".format(query, hyp_tstart, hyp_tend, ref_tstart, ref_tend, overlap, ref_tend - ref_tstart))

        # compare in measure numbers (optional)
        if showMeasures:
            moverlap = calcOverlap((hyp_mstart, hyp_mend), (ref_mstart, ref_mend))
            print("\thyp ({:.1f} m, {:.1f} m), ref ({:.1f} m, {:.1f} m), overlap {:.1f} m".format(hyp_mstart, hyp_mend, ref_mstart, ref_mend, moverlap))

        # compare in line + measure offset
        print("\thyp (ln {} m{:.1f}, ln {} m{:.1f}), ref (ln {}, ln {})".format(hyp_lstart, hyp_lstartoff, hyp_lend, hyp_lendoff, ref_lstart, ref_lend))

In [36]:
def getLineNumberMeasureOffset(measureNum, d):
    """Locate a (possibly fractional) measure number within the score lines.

    d maps a line number to (first measure, last measure). The first line
    whose range [first, last + 1) contains measureNum is returned together
    with the 1-based offset of measureNum inside that line.

    Returns:
        (line, offset), or (-1, -1) when no line contains measureNum.
    """
    for lineNum, (firstMeasure, lastMeasure) in d.items():
        if firstMeasure <= measureNum < lastMeasure + 1:
            return lineNum, measureNum - firstMeasure + 1
    return -1, -1

In [37]:
# Print, for every hypothesis, its segment vs. the reference in seconds and in score lines.
printDebuggingInfo(hypdir, queryGTFile, scoreInfo, midiInfo, queryInfo)

p51_q1: hyp (0.0 s,11.5 s), ref (0.0 s,8.8 s), overlap 8.8 of 8.8 s
	hyp (ln 1 m1.0, ln 3 m4.7), ref (ln 1, ln 2)
p51_q10: hyp (0.2 s,10.7 s), ref (43.2 s,52.3 s), overlap 0.0 of 9.1 s
	hyp (ln 1 m1.2, ln 3 m3.7), ref (ln 10, ln 11)
p51_q2: hyp (0.0 s,26.0 s), ref (0.0 s,13.2 s), overlap 13.2 of 13.2 s
	hyp (ln 1 m1.0, ln 7 m1.5), ref (ln 1, ln 3)
p51_q3: hyp (0.0 s,20.6 s), ref (4.4 s,17.6 s), overlap 13.2 of 13.2 s
	hyp (ln 1 m1.0, ln 5 m5.2), ref (ln 2, ln 4)
p51_q4: hyp (6.4 s,21.9 s), ref (4.4 s,21.2 s), overlap 14.8 of 16.8 s
	hyp (ln 2 m3.7, ln 6 m2.0), ref (ln 2, ln 5)
p51_q5: hyp (15.2 s,24.9 s), ref (17.6 s,25.6 s), overlap 7.3 of 8.1 s
	hyp (ln 4 m3.7, ln 6 m6.0), ref (ln 5, ln 6)
p51_q6: hyp (23.4 s,37.9 s), ref (25.6 s,37.3 s), overlap 11.7 of 11.7 s
	hyp (ln 6 m4.0, ln 9 m1.8), ref (ln 7, ln 8)
p51_q7: hyp (23.4 s,43.4 s), ref (25.6 s,43.2 s), overlap 17.6 of 17.6 s
	hyp (ln 6 m4.0, ln 10 m1.2), ref (ln 7, ln 9)
p51_q8: hyp (11.9 s,29.1 s), ref (31.5 s,48.3 s), overlap 0.