In [None]:
from music21 import converter, instrument, note, chord, stream, duration, midi

import os
import pickle as pkl
from collections import defaultdict
import numpy as np
import copy
from itertools import combinations

In [None]:
# Eight half-beat grid positions, in quarter notes: 0.0, 0.5, ..., 3.5.
# NOTE(review): not referenced elsewhere in the visible notebook.
quarters = [0.5 * step for step in range(8)]

def snap_to_quarter(offset, dur):
    """Snap an onset to the nearest half quarter-note while keeping the note's end fixed.

    The candidates are floor(offset), floor(offset) + 0.5 and floor(offset) + 1.
    The closest one wins; on a tie, the earlier candidate wins because np.argmin
    returns the first minimum.

    Args:
        offset: onset in quarter notes.
        dur: duration in quarter notes.

    Returns:
        (snap_offset, snap_dur). The end time snap_offset + snap_dur equals
        offset + dur. Snapping later shortens the note and snapping earlier
        lengthens it, so snap_dur can be <= 0 for very short notes; callers
        filter those out.
    """
    base = np.floor(offset)
    options = [base, base + 0.5, base + 1]
    snap_offset = options[int(np.argmin([np.abs(x - offset) for x in options]))]
    # Preserve the end time (same convention as resolve_overlap's neighbour moves).
    snap_dur = dur - (snap_offset - offset)
    return snap_offset, snap_dur

def pitch_id(el):
    """Return a string id for a music21 Note or Chord.

    A Note maps to its MIDI number, e.g. '60'. A Chord maps to its pitches' MIDI
    numbers, sorted ascending and joined by '.', e.g. '60.64.67'. Any other
    element type yields None.
    """
    if isinstance(el, note.Note):
        return str(el.pitch.midi)
    if not isinstance(el, chord.Chord):
        return None
    return '.'.join(str(m) for m in sorted(p.midi for p in el.pitches))

def read_discrete_measures(measures, totlength):
    """Flatten a list of music21 Measures into one half-beat timepoint sequence.

    ``measures`` is sorted in place by ``.offset``. Within each measure, every
    Chord and then every Note found by ``recurse()`` is placed at
    (measure offset + element offset), snapped with ``snap_to_quarter``. An entry
    is kept only if its snapped onset is < ``totlength`` and its snapped
    duration is > 0.

    Args:
        measures: list of music21 Measure objects (sorted in place).
        totlength: total length in quarter notes. The result has
            ceil(2 * totlength) timepoints.

    Returns:
        List of timepoints, each a list of ``(pitch_id, duration)`` tuples.
    """
    measures.sort(key=lambda bar: bar.offset)
    timeline = [[] for _ in range(int(np.ceil(totlength * 2)))]

    for bar in measures:
        origin = bar.offset
        found_notes, found_chords = [], []
        for el in bar.recurse():
            if isinstance(el, note.Note):
                found_notes.append(el)
            elif isinstance(el, chord.Chord):
                found_chords.append(el)

        # Chords are placed ahead of single notes at the same timepoint.
        for el in found_chords + found_notes:
            onset, length = snap_to_quarter(origin + el.offset, el.duration.quarterLength)
            if onset < totlength and length > 0:
                timeline[int(onset * 2)].append((pitch_id(el), length))

    return timeline

def read_flat_stream(score, totlength):
    """Place every Note/Chord found by ``score.recurse()`` on a half-beat timeline.

    Each element's ``.offset`` and ``duration.quarterLength`` are snapped with
    ``snap_to_quarter``. An entry is kept only if its snapped onset is
    < ``totlength`` and its snapped duration is > 0.
    NOTE(review): ``.offset`` is used as-is, so for elements nested in inner
    containers it is presumably relative to that container; confirm.

    Returns:
        List of ceil(2 * totlength) timepoints, each a list of
        ``(pitch_id, duration)`` tuples.
    """
    timeline = [[] for _ in range(int(np.ceil(totlength * 2)))]
    sounding = (el for el in score.recurse() if isinstance(el, (note.Note, chord.Chord)))
    for el in sounding:
        onset, length = snap_to_quarter(el.offset, el.duration.quarterLength)
        if onset < totlength and length > 0:
            timeline[int(onset * 2)].append((pitch_id(el), length))
    return timeline

def resolve_extensions(sequence, mlength=8):
    """Merge notes that continue across bar lines (in place).

    ``sequence`` is a list of timepoints (one per half quarter-note), each a list
    of ``(pitch_id, duration)`` tuples, split into bars of ``mlength`` timepoints.
    A note whose span reaches or passes the end of its bar is matched against the
    first timepoint of the next bar. If that timepoint holds the same pitch id,
    the two are merged: the follow-on entry is removed and the note's duration
    becomes (time left in the bar + follow-on duration).

    Args:
        sequence: list of timepoint lists; modified in place.
        mlength: bar length in timepoints (half quarter-notes). The default 8
            is a bar of 4 quarter notes.

    Returns:
        None.
    """
    def find_in_timepoint(point, el):
        # Index of the first entry in ``point`` whose pitch id is ``el``, else -1.
        for idx, (alt_el, _) in enumerate(point):
            if alt_el == el:
                return idx
        return -1

    bar_quarters = mlength / 2  # bar length in quarter notes
    bkpts = list(range(0, len(sequence), mlength))

    # The final bar has no successor, so it is never a merge source.
    for i in range(len(bkpts) - 1):
        start = bkpts[i]
        next_bar = sequence[bkpts[i + 1]]
        # Walk the bar from its last timepoint backwards. Later timepoints get the
        # first chance to claim a matching entry in the next bar.
        for q in range(mlength - 1, -1, -1):
            current_tp = sequence[start + q]
            new_tp = []
            while current_tp:
                el, dur = current_tp.pop(-1)
                if (q / 2) + dur >= bar_quarters:
                    idx = find_in_timepoint(next_bar, el)
                    if idx != -1:
                        _, alt_dur = next_bar.pop(idx)
                        dur = bar_quarters - (q / 2) + alt_dur
                # insert(0, ...) restores the original entry order.
                new_tp.insert(0, (el, dur))
            sequence[start + q] = new_tp

def resolve_overlap(sequence, thresh=0.1):
    """Reduce every timepoint to a single entry, relocating extras where possible (in place).

    For each timepoint holding more than one ``(pitch_id, duration)`` entry, the
    first entry stays. Each remaining entry moves to the first *empty* neighbour
    slot, trying in this order:

    - one step earlier (duration + 0.5);
    - one step later (duration - 0.5);
    - two steps later (duration - 1.0).

    Each duration change keeps the note's end time. An entry is dropped if the
    adjusted duration would be <= 0 or no neighbour is free.

    Args:
        sequence: list of timepoint lists; modified in place.
        thresh: unused. Kept for call compatibility; it belonged to a disabled
            chord-merging strategy.

    Returns:
        None.
    """
    # (slot offset, duration change), in priority order.
    shifts = ((-1, 0.5), (1, -0.5), (2, -1.0))

    def try_move(el, dur, i, step, delta):
        # Put (el, dur + delta) into slot i + step if that slot is empty.
        new_dur = dur + delta
        if new_dur <= 0:
            return False
        if len(sequence[i + step]) == 0:
            sequence[i + step].append((el, new_dur))
            return True
        return False

    n = len(sequence)
    for i in range(n):
        if len(sequence[i]) > 1:
            for el, dur in sequence[i][1:]:
                for step, delta in shifts:
                    # Any in-range slot, including index 0, is a valid neighbour.
                    if 0 <= i + step < n and try_move(el, dur, i, step, delta):
                        break
        sequence[i] = sequence[i][:1]

def extract(s, verbose=False):
    """Turn a parsed music21 score into a half-beat timepoint sequence.

    Part selection, based on ``instrument.partitionByInstrument(s)``:

    - if it yields nothing, the flattened score ``s.flat`` is used;
    - if it yields exactly one part, that part is used whatever its instrument;
    - otherwise only parts containing a music21 Piano are considered, and the
      one with the most Note/Chord objects wins.

    If the chosen part contains Measure objects, they are read with
    ``read_discrete_measures`` and cross-bar notes are merged with
    ``resolve_extensions``. Otherwise ``read_flat_stream`` is used.

    Args:
        s: parsed music21 stream (e.g. from ``converter.parse``).
        verbose: print a message when no suitable part is found.

    Returns:
        List of timepoints (each a list of ``(pitch_id, duration)`` tuples). Returns
        None when there are several parts and none is a Piano.
    """
    totlength = s.duration.quarterLength

    def has_piano(part):
        # True if any instrument found recursively in the part is a Piano.
        return any(isinstance(instr, instrument.Piano)
                   for instr in part.getInstruments(recurse=True))

    def count_sounding(part):
        # Number of Note and Chord objects in the flattened part.
        return sum(1 for el in part.flat if isinstance(el, (note.Note, chord.Chord)))

    parts = instrument.partitionByInstrument(s)
    if not parts:
        score = s.flat
    elif len(parts) == 1:
        score = parts[0]
    else:
        candidates = [p for p in parts if has_piano(p)]
        if not candidates:
            if verbose:
                print('File Error: Cannot find suitable part')
            return None
        # max() keeps the first of equally busy parts, like a stable descending sort.
        score = candidates[0] if len(candidates) == 1 else max(candidates, key=count_sounding)

    measures = list(score.getElementsByClass(stream.Measure))
    if not measures:
        # No Measure objects: read the stream directly.
        return read_flat_stream(score, totlength)

    sequence = read_discrete_measures(measures, totlength)
    resolve_extensions(sequence)
    return sequence

In [None]:
def clean_sequence(sequence, max_notes=4):
    """Normalize each timepoint to one chord code and fill empty timepoints with rests (in place).

    Each timepoint list becomes a single ``(code, duration)`` tuple:

    - an empty timepoint becomes ``('REST', 0.5)``;
    - otherwise the first entry's code is split on '.'. Its pitches are
      de-duplicated and sorted ascending, and at most ``max_notes`` of the lowest
      *distinct* pitches are kept.

    Args:
        sequence: list of timepoint lists (e.g. after resolve_overlap);
            elements are replaced in place.
        max_notes: maximum number of distinct pitches kept per chord.

    Returns:
        None.
    """
    for i, timepoint in enumerate(sequence):
        if len(timepoint) == 0:
            sequence[i] = ('REST', 0.5)
            continue
        midi_code, dur = timepoint[0]
        # De-duplicate before truncating so repeated pitches cannot push
        # distinct notes out of the chord.
        pitches = sorted({int(x) for x in midi_code.split('.')})[:max_notes]
        sequence[i] = ('.'.join(str(p) for p in pitches), dur)

def pad_sequence(sequence):
    """Lengthen short notes in rest-dominated sequences (in place).

    If more than 60% of the entries are rests and no note lasts longer than 0.5
    quarter notes, every non-rest entry's duration is increased by 0.5.

    Args:
        sequence: list of ``(midi_code, duration)`` tuples with fewer than 64
            entries; rests are coded as 'REST'.

    Returns:
        None.

    Raises:
        AssertionError: if ``sequence`` has 64 or more entries.
    """
    assert len(sequence) < 64
    if not sequence:
        return  # nothing to pad; also avoids dividing by zero below

    rests = sum(1 for code, _ in sequence if code == 'REST')
    if rests / len(sequence) <= 0.6:
        return
    if all(code == 'REST' or dur <= 0.5 for code, dur in sequence):
        for i, (code, dur) in enumerate(sequence):
            if code != 'REST':
                sequence[i] = (code, dur + 0.5)

def reduce_rests(sequence):
    """Trim leading and trailing rests from a cleaned sequence.

    Args:
        sequence: list of ``(midi_code, duration)`` tuples; rests are coded 'REST'.

    Returns:
        A new list from the first through the last non-rest entry (both
        inclusive). Returns None if ``sequence`` is empty or contains only rests.
    """
    note_positions = [k for k, item in enumerate(sequence) if item[0] != 'REST']
    if not note_positions:
        return None
    first, last = note_positions[0], note_positions[-1]
    # Slice end is exclusive, so +1 keeps the final note.
    return sequence[first:last + 1]

def filter_sequence(sequence, min_note=8, min_ratio=0.4):
    """Decide whether a cleaned sequence is worth keeping.

    A sequence passes when both conditions hold:

    - the half-beat timepoints covered by non-rest entries make up at least
      ``min_ratio`` of the sequence length;
    - it contains at least ``min_note`` non-rest entries.

    Args:
        sequence: list of ``(midi_code, duration)`` tuples, or None.
        min_note: minimum number of non-rest entries.
        min_ratio: minimum covered-timepoints / length ratio.

    Returns:
        True if both checks pass. False for None, an empty sequence, or a
        failed check.
    """
    if sequence is None or len(sequence) == 0:
        return False

    # Walk the sequence, jumping over the timepoints each entry spans, and count
    # the timepoints covered by non-rest entries.
    covered = 0
    i = 0
    while i < len(sequence):
        midi_code, dur = sequence[i]
        half_beats = 2 * dur
        # Round to the nearest whole number of timepoints, with .5 rounding up.
        step = int(np.ceil(half_beats))
        if half_beats % 1 < 0.5 and half_beats % 1 != 0:
            step -= 1
        if step > 0:
            i += step
            if midi_code != 'REST':
                covered += step
        else:
            i += 1

    if covered / len(sequence) < min_ratio:
        return False

    num_notes = sum(1 for code, _ in sequence if code != 'REST')
    return num_notes >= min_note

In [None]:
def converge_octaves(sequences, lbound=40, ubound=80):
    """Fold every pitch into [lbound, ubound) by whole-octave shifts (in place).

    Pitches below ``lbound`` move up by 12 until they reach it. Pitches at or
    above ``ubound`` move down by 12 until they fall below it. Shifting can
    reorder a chord's notes or collapse two notes onto one pitch, so the
    adjusted pitches are de-duplicated and re-sorted ascending. This keeps
    codes canonical: the same chord always yields the same string.

    Args:
        sequences: list of sequences of ``(midi_code, duration)`` tuples; entries
            are replaced in place. 'REST' entries are left unchanged.
        lbound: inclusive lower pitch bound.
        ubound: exclusive upper pitch bound.

    Returns:
        None.
    """
    for seq in sequences:
        for i, (midi_code, dur) in enumerate(seq):
            if midi_code == 'REST':
                continue
            adjusted = set()
            for pitch in (int(x) for x in midi_code.split('.')):
                while pitch < lbound:
                    pitch += 12
                while pitch >= ubound:
                    pitch -= 12
                adjusted.add(pitch)
            seq[i] = ('.'.join(str(p) for p in sorted(adjusted)), dur)

def create_variants(sequence, target_len=64, window=32):
    """Cut a sequence into fixed-size training windows and next-step labels.

    The input is copied first, so the caller's list is never modified. If the
    copy is shorter than ``target_len``, it is extended by repeatedly appending
    a ``('REST', 0.5)`` separator followed by the original content, then
    truncated to ``target_len``. Every position ``i`` in [window, len) then
    yields the window ``padded[i - window:i]`` with label ``padded[i]``.

    Args:
        sequence: list of ``(midi_code, duration)`` tuples (not modified).
        target_len: length that short sequences are padded/truncated to.
        window: number of entries per input window.

    Returns:
        (subseqs, labels): parallel lists of windows and their next entries.
    """
    padded = list(sequence)  # work on a copy; never mutate the caller's list
    if len(padded) < target_len:
        n = len(padded)
        while len(padded) < target_len:
            padded.append(('REST', 0.5))
            padded += padded[:n]
        padded = padded[:target_len]

    positions = range(window, len(padded))
    subseqs = [padded[i - window:i] for i in positions]
    labels = [padded[i] for i in positions]
    return subseqs, labels

In [None]:
def find_closest_four(ch, high_freq):
    """Map a rare 4-note chord code to a frequent chord code, or return None.

    Strategies are tried in order. The first one that yields any candidate wins,
    and one of its candidates is drawn with ``np.random.choice``:

      1. a 3-note sub-chord of ``ch`` that is in ``high_freq``;
      2. one note moved to another octave (gen_variants_1diff_octave);
      3. one note replaced by another pitch in [40, 80) (gen_variants_1diff);
      4. two notes moved to other octaves (gen_variants_2diff_octave);
      5. a frequent 4-note chord containing the other two notes
         (gen_variants_2diff).

    NOTE(review): the choice is random and no seed is set in the visible code.
    """
    comp = ch.split('.')
    triads = ['.'.join(c) for c in combinations(comp, 3)]
    found = [c for c in triads if c in high_freq]
    if found:
        return np.random.choice(found)

    for finder in (gen_variants_1diff_octave, gen_variants_1diff,
                   gen_variants_2diff_octave, gen_variants_2diff):
        found = finder(comp, high_freq)
        if found:
            return np.random.choice(found)
    return None

def find_closest_three(ch, high_freq):
    """Map a rare 3-note chord code to a frequent chord code, or return None.

    Strategies are tried in order; the first that yields candidates wins and one
    of them is drawn with ``np.random.choice``:

      1. one note moved to another octave;
      2. one note replaced by another pitch in [40, 80);
      3. a 2-note sub-chord of ``ch`` that is in ``high_freq``.
    """
    comp = ch.split('.')
    for finder in (gen_variants_1diff_octave, gen_variants_1diff):
        found = finder(comp, high_freq)
        if found:
            return np.random.choice(found)

    dyads = ['.'.join(pair) for pair in combinations(comp, 2)]
    found = [d for d in dyads if d in high_freq]
    return np.random.choice(found) if found else None

def find_closest_two(ch, high_freq):
    """Map a rare 2-note chord code to a frequent chord code, or return None.

    First tries moving one note to another octave, then moving two notes. One
    candidate from the first non-empty result is drawn with
    ``np.random.choice``.
    """
    comp = ch.split('.')
    hits = gen_variants_1diff_octave(comp, high_freq)
    if not hits:
        hits = gen_variants_2diff_octave(comp, high_freq)
    return np.random.choice(hits) if hits else None

def note_ranges(c, lbound=40, ubound=80):
    """List the other octaves of pitch ``c`` that lie in [lbound, ubound).

    Lower octaves come first, nearest first (descending). Higher octaves follow,
    nearest first (ascending). ``c`` itself is not included.
    """
    below = []
    pitch = c - 12
    while pitch >= lbound:
        below.append(pitch)
        pitch -= 12

    above = []
    pitch = c + 12
    while pitch < ubound:
        above.append(pitch)
        pitch += 12

    return below + above

def gen_variants_1diff_octave(comp, high_freq):
    """Frequent chord codes reachable by moving exactly one note of ``comp`` to another octave.

    For each note, every octave returned by ``note_ranges`` (default bounds)
    replaces it. The candidate's pitches are re-sorted numerically and joined
    with '.', and the candidate is kept if it is in ``high_freq``.

    Args:
        comp: list of MIDI numbers as strings (a split chord code).
        high_freq: collection of frequent chord codes.

    Returns:
        List of distinct matching codes (order unspecified).
    """
    pitches = [int(c) for c in comp]
    octave_options = [note_ranges(p) for p in pitches]

    matches = []
    for pos, options in enumerate(octave_options):
        others = [p for k, p in enumerate(pitches) if k != pos]
        for shifted in options:
            candidate = '.'.join(str(x) for x in sorted([shifted] + others))
            if candidate in high_freq:
                matches.append(candidate)
    return list(set(matches))

def gen_variants_1diff(comp, high_freq):
    """Frequent chord codes reachable by replacing one note of ``comp`` with any other pitch.

    Replacement pitches range over [40, 80), excluding pitches already in
    ``comp``. Candidates are re-sorted numerically before lookup in
    ``high_freq``.

    Returns:
        List of distinct matching codes (order unspecified).
    """
    pitches = [int(c) for c in comp]
    replacements = [p for p in range(40, 80) if p not in pitches]

    hits = []
    for skip in range(len(pitches)):
        kept = [p for k, p in enumerate(pitches) if k != skip]
        for new_pitch in replacements:
            code = '.'.join(map(str, sorted(kept + [new_pitch])))
            if code in high_freq:
                hits.append(code)
    return list(set(hits))

def gen_variants_2diff_octave(comp, high_freq):
    """Frequent chord codes reachable by moving two *different* notes of ``comp`` to other octaves.

    For each pair of distinct positions, every combination of their octave
    alternatives (from ``note_ranges`` with default bounds) is tried. The other
    notes are kept. The candidate is re-sorted numerically and kept if it is in
    ``high_freq``. The candidate always has as many notes as ``comp``.

    Args:
        comp: list of MIDI numbers as strings (a split chord code).
        high_freq: collection of frequent chord codes.

    Returns:
        List of distinct matching codes (order unspecified).
    """
    pitches = [int(c) for c in comp]
    octave_options = [note_ranges(p) for p in pitches]

    matches = []
    # combinations() yields each pair of distinct positions once. Pairing a
    # position with itself would replace one note by two, producing a larger
    # chord instead of two octave moves.
    for i1, i2 in combinations(range(len(pitches)), 2):
        others = [p for k, p in enumerate(pitches) if k not in (i1, i2)]
        for alt1 in octave_options[i1]:
            for alt2 in octave_options[i2]:
                candidate = '.'.join(str(x) for x in sorted([alt1, alt2] + others))
                if candidate in high_freq:
                    matches.append(candidate)
    return list(set(matches))

def gen_variants_2diff(comp, high_freq):
    """Frequent 4-note chord codes that keep all but two notes of ``comp``.

    For each pair of distinct positions in ``comp`` treated as replaceable, every
    4-note code in ``high_freq`` (exactly three '.') that contains all the
    remaining notes is collected.

    Args:
        comp: list of MIDI numbers as strings (a split chord code). Fewer than 3
            notes would leave nothing fixed, so an empty list is returned.
        high_freq: collection of frequent chord codes.

    Returns:
        List of distinct matching codes (order unspecified).
    """
    pitches = [int(c) for c in comp]
    if len(pitches) < 3:
        return []

    # Parse each frequent 4-note chord once rather than once per pair.
    four_note = [(hf, {int(x) for x in hf.split('.')})
                 for hf in high_freq if hf.count('.') == 3]

    variants = set()
    for i1, i2 in combinations(range(len(pitches)), 2):
        fixed = [p for k, p in enumerate(pitches) if k not in (i1, i2)]
        for hf, hnotes in four_note:
            if all(p in hnotes for p in fixed):
                variants.add(hf)
    return list(variants)

def compress_dictionary(sequences, ffreqs, thresh=25):
    """Map rare chord codes onto frequent ones to shrink the vocabulary.

    Args:
        sequences: unused; kept for call compatibility.
        ffreqs: list of ``(code, count)`` pairs (iterated twice).
        thresh: codes with ``count >= thresh`` (except 'REST') are frequent;
            codes with ``count < thresh`` are candidates for remapping.

    Returns:
        Dict from a rare code to its replacement frequent code. Codes with no
        replacement are omitted. Single notes, 'REST' and chords of 5+ notes
        have no matching strategy and are always omitted.
    """
    high_freq = [code for code, count in ffreqs if count >= thresh and code != 'REST']
    low_freq = [code for code, count in ffreqs if count < thresh]

    mappings = {}
    for ch in low_freq:
        num_dots = ch.count('.')
        if num_dots == 3:
            m = find_closest_four(ch, high_freq)
        elif num_dots == 2:
            m = find_closest_three(ch, high_freq)
        elif num_dots == 1:
            m = find_closest_two(ch, high_freq)
        else:
            # No strategy for this code. Reset m so an earlier chord's
            # replacement is never reused.
            m = None
        if m is not None:
            mappings[ch] = m
    return mappings

In [None]:
def gather_all(data_dir):
    """Parse every MIDI file under ``data_dir`` and return the extracted sequences.

    Files are found recursively with ``os.walk``. Both ``.mid`` and ``.midi``
    extensions are accepted, case-insensitively. Paths are sorted so runs
    process files in a reproducible order.

    Each file is parsed with ``converter.parse`` and passed to ``extract``. A
    file that fails to parse is reported and skipped, so one malformed file
    does not abort the whole run. Files for which ``extract`` returns None or
    an empty result are dropped.

    Args:
        data_dir: root directory to search.

    Returns:
        List of sequences as returned by ``extract``.
    """
    midi_files = []
    for dirpath, _, filenames in os.walk(data_dir):
        for filename in filenames:
            if filename.lower().endswith(('.mid', '.midi')):
                midi_files.append(os.path.join(dirpath, filename))
    midi_files.sort()

    print(f'Total # Files: {len(midi_files)}')

    sequences = []
    for i, mf in enumerate(midi_files):
        if i % 50 == 0:
            print(f'Num Processed: {i}')
        try:
            s = converter.parse(mf)
        except Exception as exc:  # music21 raises varied exception types on malformed MIDI
            print(f'Skipping {mf}: {exc!r}')
            continue
        sequence = extract(s)
        if sequence:
            sequences.append(sequence)

    return sequences

In [None]:
# Either re-parse every MIDI file under 'datasets' (slow) and cache the result,
# or reuse the cached raw sequences.
# NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
from_scratch = False

if not from_scratch:
    with open('raw_seqs.pkl', 'rb') as f:
        raw_sequences = pkl.load(f)
else:
    raw_sequences = gather_all('datasets')
    with open('raw_seqs.pkl', 'wb') as f:
        pkl.dump(raw_sequences, f)

In [None]:
# Work on a deep copy so later processing can never alter raw_sequences.
wsequences = copy.deepcopy(raw_sequences)
print(len(raw_sequences))

In [None]:
# Clean every raw sequence and keep those that pass the quality filter.
# Each sequence is deep-copied first. resolve_overlap, clean_sequence and
# pad_sequence mutate their argument, and clean_sequence replaces timepoint
# lists with tuples. Processing wsequences directly would therefore break a
# second run of this cell.
filtered_sequences = []
for source_seq in wsequences:
    seq = copy.deepcopy(source_seq)
    resolve_overlap(seq, thresh=0.1)
    clean_sequence(seq)
    seq = reduce_rests(seq)
    if seq is None:
        continue
    if len(seq) < 64:
        pad_sequence(seq)
    if filter_sequence(seq, min_ratio=0.6):
        filtered_sequences.append(seq)
print(len(filtered_sequences))
print(np.mean([len(seq) for seq in filtered_sequences]))

In [None]:
# Fold pitches into one octave range, then count how often each code occurs.
converge_octaves(filtered_sequences)

freqs = defaultdict(int)
for code, _ in (item for seq in filtered_sequences for item in seq):
    freqs[code] += 1

print(len(freqs))

# Most frequent first; ties keep first-seen order (sorted is stable).
flat_freqs = sorted(freqs.items(), key=lambda kv: kv[1], reverse=True)

In [None]:
mappings = compress_dictionary(filtered_sequences, flat_freqs, thresh=25)
# Replace every rare code with its mapped frequent code, in place (seq[:] keeps
# the same list objects).
for seq in filtered_sequences:
    seq[:] = [(mappings.get(code, code), dur) for code, dur in seq]

In [None]:
# Re-count codes after remapping; this ranking also fixes the vocabulary ids
# assigned in the preprocessing cell.
freqs = defaultdict(int)
for seq in filtered_sequences:
    for code, _ in seq:
        freqs[code] += 1

print(len(freqs))

flat_freqs = sorted(freqs.items(), key=lambda kv: kv[1], reverse=True)

In [None]:
# Turn every sequence into fixed-size input windows with next-step labels.
final_sequences, labels = [], []
for seq in filtered_sequences:
    windows, targets = create_variants(seq)
    final_sequences += windows
    labels += targets

print(len(final_sequences))

In [None]:
def normalize_duration(d, d_min, d_max):
    """Min-max scale duration ``d`` using the observed range [d_min, d_max].

    If the range is degenerate (``d_max == d_min``), 0.0 is returned instead of
    dividing by zero. ``denormalize_duration`` then maps it back to ``d_min``.
    """
    span = d_max - d_min
    if span == 0:
        return 0.0
    return (d - d_min) / span

def denormalize_duration(d_normalized, d_min, d_max):
    """Inverse of ``normalize_duration``: map a scaled value back into [d_min, d_max]."""
    return d_normalized * (d_max - d_min) + d_min

# Preprocessing: encode chord codes as vocabulary ids (frequency rank from
# flat_freqs) and min-max normalize durations over all input and label durations.
vocab = {code: idx for idx, (code, _) in enumerate(flat_freqs)}
encoded_sequences = [[(vocab[code], dur) for code, dur in seq] for seq in final_sequences]

all_durations = [dur for seq in encoded_sequences for _, dur in seq] + [dur for _, dur in labels]
d_min = min(all_durations)
d_max = max(all_durations)

normalized_sequences = [
    [(x, normalize_duration(d, d_min, d_max)) for x, d in seq]
    for seq in encoded_sequences
]

# Labels go through the same normalize_duration as the inputs, so both always
# use one formula.
note_labels = [vocab[code] for code, _ in labels]
duration_labels = [normalize_duration(dur, d_min, d_max) for _, dur in labels]
encoded_labels = list(zip(note_labels, duration_labels))

In [None]:
# Persist the model-ready data as (inputs, labels, vocab, d_min, d_max).
with open('processed_data2.pkl', 'wb') as out_file:
    pkl.dump((normalized_sequences, encoded_labels, vocab, d_min, d_max), out_file)