Construct Data Pipeline:
Raw Data in Directory --> Extract Semantics Data --> Get Encoding --> Collect into batches --> Yield labels from expert system --> Model Training

In [1]:
import os
import sys
import constraint
import logging

from itertools import chain
from string import digits

import pickle as pkl

from pprint import pprint
from tqdm.notebook import tqdm

sys.path.insert(1, os.path.join(sys.path[0], '..'))
import finger_oracle
# Technically no need to import semantics_generator for this experiment,
# Can simply use semantics file as ground truth
# sys.path.insert(1, os.path.join(sys.path[0], '..'))
# import semantics_generator

In [2]:
# Dataset package directories that are scanned for <folder>/<folder>.semantic files.
DATA_DIR = [
    '../data/package_aa',
    '../data/package_ab',
]

# Pickle holding the set of semantic-file paths that already have a .finger file written.
DONE_STATUS_OUTPUT = '../tmp/finger_notation_generator_output.pkl'

# Text file, one path per line, of semantic files for which no fingering could be found.
IGNORE_LIST_PATH = '../tmp/finger_notation_generator_ignore.txt'

Convert Notes into black/non-black key encoding and semitone distances

In [3]:
# Static lookup tables for encoding notes as black/non-black keys and semitone distances.

# Every spelling that lands on a black key, as sharp/flat pairs:
# C#/Db, D#/Eb, F#/Gb, G#/Ab, A#/Bb.
# Enharmonic spellings of white keys (E#, Fb, B#, Cb) are intentionally absent.
BLACK_KEYS = {'C#', 'Db', 'D#', 'Eb', 'F#', 'Gb', 'G#', 'Ab', 'A#', 'Bb'}

# Semitone offset of each pitch name from C within its written octave.
# 'Cb' (-1) and 'B#' (12) reach into the neighbouring octave.
KEY_MAP = {'Cb': -1, 'C': 0, 'C#': 1, 'Db': 1, 'D': 2, 'D#': 3, 'Eb': 3,
           'E': 4, 'Fb': 4, 'E#': 5, 'F': 5, 'F#': 6, 'Gb': 6, 'G': 7, 'G#': 8,
           'Ab': 8, 'A': 9, 'A#': 10, 'Bb': 10, 'B': 11, 'B#': 12}

# Translation table that deletes every decimal digit (used to strip the octave number).
REMOVE_DIGITS = str.maketrans('', '', digits)

# Number of semitones in one octave.
OCTAVE_SEMITONES = 12

In [4]:
# Enumerate black keys. Do not filter by checking for '#' or 'b' in the string, because
# composers sometimes write things like E# or Cb, which are not black keys.
# The string translate() method is faster than iterating manually.
def is_black_key(key):
    """Tell whether a note name refers to a black piano key.

    Args:
        key: Note name, optionally carrying octave digits (e.g. ``'C#4'`` or ``'Bb'``).

    Returns:
        True if the name with all digits removed is a member of ``BLACK_KEYS``,
        False otherwise.
    """
    # Membership is checked against an explicit set rather than by looking for
    # '#'/'b', so spellings such as E# or Cb are not mistaken for black keys.
    return key.translate(REMOVE_DIGITS) in BLACK_KEYS

In [5]:
# Get the semitone distance of each note from the note before it.
# Notes written without an octave digit are all assumed to be in the same octave;
# octave 4 was picked arbitrarily because C4 is middle C.
def get_semitone_distances(notes, default_octave=4):
    """Return the signed semitone distance of each note from its predecessor.

    Args:
        notes: Iterable of note names such as ``'C#4'`` or ``'Bb'``. The last
            character of a name that contains a digit is read as its octave.
        default_octave: Octave assigned to names that contain no digit.

    Returns:
        A list with one integer per input note. The first entry is always 0
        (the first note has no predecessor); an empty input gives ``[]``.

    Raises:
        KeyError: If a pitch name is not present in ``KEY_MAP``.
    """
    notes = list(notes)
    # A score without any notes has no distances; without this guard the
    # function would fail when looking at the (non-existent) first note.
    if not notes:
        return []

    def absolute_semitone(note):
        # Fill in the missing octave digit, then split "<pitch name><octave>".
        if not any(ch.isdigit() for ch in note):
            note = note + str(default_octave)
        return int(note[-1]) * OCTAVE_SEMITONES + KEY_MAP[note[:-1]]

    # Parse every note exactly once, then diff neighbouring pitches.
    pitches = [absolute_semitone(note) for note in notes]
    return [0] + [after - before for before, after in zip(pitches, pitches[1:])]

In [6]:
# encodes tokens as inputs to model
# for the piano model, we don't really care which octave we're at
# in fact, we can simplify the notes into semitone distance from previous note,
# and whether the current note is a white or black key (affects ease of playing)
# note that the primus dataset does not have double-sharp/flat as inputs
def get_encoding(tokens):
    """Turn semantic-file lines into per-note ``(is_black_key, semitone_distance)`` pairs.

    Args:
        tokens: List of strings (the lines of a semantic file), each holding
            tab-separated tokens.

    Returns:
        A list of ``(bool, int)`` tuples, one for every ``note-`` or
        ``gracenote-`` token, in order of appearance.
    """
    note_prefix = 'note-'
    grace_prefix = 'gracenote-'

    # Treat all lines as one long tab-separated sequence of tokens.
    flat_tokens = '\t'.join(tokens).split('\t')

    pitch_names = []
    for token in flat_tokens:
        # Everything that is not a note or grace note is irrelevant here.
        if not (token.startswith(note_prefix) or token.startswith(grace_prefix)):
            continue
        # The part after the first '_' is the note length, which is dropped.
        head = token.split('_')[0]
        if token.startswith(note_prefix):
            pitch_names.append(head[len(note_prefix):])
        else:
            pitch_names.append(head[len(grace_prefix):])

    black_flags = [is_black_key(name) for name in pitch_names]
    distances = get_semitone_distances(pitch_names)

    return list(zip(black_flags, distances))
    

Heuristic-based labelling

In [7]:
# Treats this as a constraint satisfaction problem, with heuristics as strict constraints
# An alternative neighbourhood-search algorithm exists at https://qmro.qmul.ac.uk/xmlui/bitstream/handle/123456789/11801/Herremans%20A%20variable%20neighborhood%20search%20algorithm%202015%20Accepted.pdf
# but it is over-engineered for this purpose
def solve(enc):
    """Assign a finger (1-5) to every note by solving a constraint satisfaction problem.

    Args:
        enc: List of ``(is_black_key, semitone_distance)`` tuples, one per note,
            as returned by ``get_encoding``.

    Returns:
        A list with one finger number per note, in note order, or ``None`` when
        ``enc`` is empty or no assignment satisfies all constraints.
    """
    # Nothing to label. Stated explicitly instead of relying on the solver
    # returning no solution for a problem without variables.
    if not enc:
        return None

    note_count = len(enc)
    problem = constraint.Problem()

    # One oracle per pair of consecutive notes, built from the black-key flags
    # of both notes and the semitone distance stored on the second note.
    oracles = [
        finger_oracle.FingerOracle(enc[i][0], enc[i + 1][0], enc[i + 1][1])
        for i in range(note_count - 1)
    ]

    # One variable per note; its value is the finger used.
    for note_index in range(note_count):
        problem.addVariable(note_index, range(1, 6))

    # Consecutive notes must not be played with the same finger.
    # These are registered before the oracle constraints, as in the original
    # ordering, so that the same-finger check is evaluated first for each pair.
    for i in range(note_count - 1):
        problem.addConstraint(constraint.AllDifferentConstraint(), [i, i + 1])

    # Each consecutive pair must also be accepted by its oracle.
    for i, oracle in enumerate(oracles):
        problem.addConstraint(constraint.FunctionConstraint(oracle.is_valid),
                              [i, i + 1])

    # problem.getSolutionIter() can be used here to generate multiple solutions.
    solution = problem.getSolution()
    if not solution:
        return None

    # The solution maps note index -> finger; flatten it into note order.
    return [solution[note_index] for note_index in range(len(solution))]

Pull data from data directory/directories

In [8]:
# Get semantic data from dataset
# assumes data path is data_src[i]/folder_name/folder_name.semantic, in-line with primus dataset
def get_semantics_data(data_src, done, ignore):
    """Lazily yield the contents of every unprocessed semantic file.

    For each immediate sub-folder ``<folder>`` of each directory in
    ``data_src`` the file ``<src>/<folder>/<folder>.semantic`` is read.

    Args:
        data_src: Iterable of dataset directory paths.
        done: Container of semantic-file paths to skip (already processed).
        ignore: Container of semantic-file paths to skip (known to be unusable).

    Yields:
        Tuples ``(path_to_file, folder, lines)`` where ``lines`` is the result
        of ``readlines()`` on the semantic file.
    """
    for src in data_src:
        # os.walk yields nothing for a directory it cannot list. The default
        # keeps next() from raising StopIteration, which would surface as a
        # RuntimeError inside a generator.
        walk_root = next(os.walk(src), None)
        if walk_root is None:
            logging.warning("Data directory not found: {}".format(src))
            continue

        for folder in walk_root[1]:
            path_to_file = os.path.join(src, folder, "{}.semantic".format(folder))
            if path_to_file in done or path_to_file in ignore:
                continue
            # open() is what raises FileNotFoundError, so it must be inside the try.
            try:
                with open(path_to_file, "r") as f:
                    lines = f.readlines()
            except FileNotFoundError:
                logging.warning("Data file not found: {}".format(path_to_file))
                continue
            # Yield after the file is closed so no handle stays open while the
            # consumer works on the data.
            yield path_to_file, folder, lines
    

Execution Main

In [9]:
def _load_done(path):
    """Load the pickled set of already-processed paths; an empty or truncated file gives an empty set."""
    with open(path, 'rb') as f:
        try:
            # NOTE(security): pickle.load can execute arbitrary code. This is only
            # acceptable because the file is written by this notebook itself.
            return pkl.load(f)
        except (EOFError, pkl.UnpicklingError):
            return set()


def _save_done(done, path):
    """Persist the set of processed paths atomically, so an interrupt cannot truncate it."""
    tmp_path = '{}.tmp'.format(path)
    with open(tmp_path, 'wb') as f:
        pkl.dump(done, f)
    # os.replace swaps the file in one step: readers see the old or the new content.
    os.replace(tmp_path, path)


def main():
    """Generate a ``.finger`` file next to every semantic file that can be labelled.

    Progress is resumable: paths that were labelled successfully are stored in
    the pickle at ``DONE_STATUS_OUTPUT``, and paths for which ``solve`` found no
    fingering are appended to ``IGNORE_LIST_PATH``. Both are skipped on the next
    run.

    Raises:
        Any exception raised while processing is logged together with the file
        being processed and then re-raised (this includes KeyboardInterrupt).
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # Make sure both state files (and their directory) exist before reading them.
    for state_path, mode in ((DONE_STATUS_OUTPUT, 'wb'), (IGNORE_LIST_PATH, 'w')):
        os.makedirs(os.path.dirname(state_path) or '.', exist_ok=True)
        if not os.path.exists(state_path):
            with open(state_path, mode):
                pass

    # Load the files which are already done and the ignore list.
    done = _load_done(DONE_STATUS_OUTPUT)
    ignore = set()
    with open(IGNORE_LIST_PATH, 'r') as f:
        for line in f:
            line = line.rstrip()
            if line:
                ignore.add(line)

    last_file = "No File"
    try:
        with tqdm() as pbar:
            pbar.set_description("Generating finger positions")
            for path_to_file, _folder, semantics_data in get_semantics_data(DATA_DIR, done, ignore):
                last_file = path_to_file
                # "<dir>/<name>.semantic" -> "<dir>/<name>.finger"
                finger_file_path = "{}.finger".format(os.path.splitext(path_to_file)[0])
                encoding = get_encoding(semantics_data)
                pseudo_labels = solve(encoding)
                if pseudo_labels:
                    with open(finger_file_path, "w") as f:
                        f.write(' '.join(map(str, pseudo_labels)))
                    done.add(path_to_file)
                    # Save after every success so an interrupted run can resume.
                    _save_done(done, DONE_STATUS_OUTPUT)
                    pbar.update(1)
                else:
                    with open(IGNORE_LIST_PATH, 'a') as f:
                        f.write('{}\n'.format(path_to_file))
    except BaseException:
        # BaseException rather than Exception so that a manual interrupt is
        # logged with the offending file too. The progress bar is closed by the
        # `with` block, so it needs no explicit close here.
        logging.error("Exception triggered while processing {}".format(last_file))
        raise
        

In [13]:
# Run the labelling pipeline over every directory listed in DATA_DIR.
main()

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

ERROR:root:Exception triggered while processing ../data/package_aa\000122858-1_1_1\000122858-1_1_1.semantic





KeyboardInterrupt: 