# Correcting Note Lengths

In [70]:
from mido import MidiFile
from tqdm import tqdm
import matplotlib.pyplot as plt
import miditoolkit
import decimal
from math import floor
from decimal import Decimal as D
from collections import namedtuple

def continued_fraction(x, k):
    """Return the leading continued-fraction terms of ``x`` as a list of ints.

    The first entry is ``floor(x)``. Further terms are produced from the
    reciprocal of the remaining fractional part until one of these holds:

    * the fractional part is exactly zero,
    * ``k`` terms have been produced after the integer part, or
    * the next term would be larger than ``k`` (that term is discarded).
    """
    whole = floor(x)
    terms = [whole]
    remainder = x - whole

    produced = 0
    while remainder != 0 and produced < k:
        reciprocal = 1 / remainder
        term = floor(reciprocal)
        if term > k:
            # A term above the cap ends the expansion without being recorded.
            break
        terms.append(term)
        remainder = reciprocal - term
        produced += 1

    return terms

def _is_closer(target, numer, denom, ref_numer, ref_denom):
    """Tell whether numer/denom is strictly closer to target than ref_numer/ref_denom.

    Arithmetic failures (a zero denominator, a quotient too large for a float,
    a Decimal signal) return False so the candidate is simply skipped. Other
    exceptions, such as a TypeError from an unsupported ``target``, propagate.
    """
    try:
        candidate_error = (target - D(numer / denom)).copy_abs()
        reference_error = (target - D(ref_numer / ref_denom)).copy_abs()
        return candidate_error < reference_error
    except ArithmeticError:
        return False


def rationalApproximations(clist, app):
    """List rational approximations of ``app`` built from continued-fraction terms.

    Args:
        clist: Continued-fraction terms (ints), integer part first.
        app: The value being approximated; it is subtracted from ``Decimal``
            quotients, so it must support arithmetic with ``Decimal``.

    Returns:
        A list of ``"numer/denom"`` strings ordered by increasing denominator,
        with duplicates removed. It contains every convergent, plus each
        intermediate fraction ``(hn0 + i*hn1)/(kn0 + i*kn1)`` that is strictly
        closer to ``app`` than the previous convergent.
    """
    fraction = namedtuple("fraction", "ratio, numer, denom")
    # (hn0, kn0) and (hn1, kn1) are the two most recent convergents; the
    # seeds 0/1 and 1/0 start the standard recurrence.
    hn0, kn0 = 0, 1
    hn1, kn1 = 1, 0
    conlist = []
    for n in clist:
        for i in range(1, n + 1):
            ran = hn0 + i * hn1
            rad = kn0 + i * kn1
            # While kn1 is still 0 (before the first convergent exists) the
            # reference quotient cannot be formed and the candidate is skipped.
            if _is_closer(app, ran, rad, hn1, kn1):
                conlist.append(fraction(f'{ran}/{rad}', ran, rad))
        hn2 = n * hn1 + hn0
        kn2 = n * kn1 + kn0
        conlist.append(fraction(f'{hn2}/{kn2}', hn2, kn2))
        hn0, kn0 = hn1, kn1
        hn1, kn1 = hn2, kn2
    # Change item.ratio to item.denom or item.numer to get numerators or
    # denominators instead of the ratio strings.
    ordered = sorted(conlist, key=lambda item: item.denom)
    # dict.fromkeys removes repeated ratios while keeping the sorted order.
    return list(dict.fromkeys(item.ratio for item in ordered))

def bestApproximations(length,tps):
    """Return the rational approximations of ``length / tps`` as ratio strings.

    The quotient is converted to a ``Decimal``. The global decimal context
    precision is then set to five times the length of that value's string
    form, and the same number is used as the cap passed to
    ``continued_fraction``.
    """
    ratio = D(length / tps)
    digits = 5 * len(str(ratio))
    # Side effect: this changes the precision of the current decimal context.
    decimal.getcontext().prec = digits
    terms = continued_fraction(ratio, digits)
    return rationalApproximations(terms, ratio)

def correctLength(length,tps):
    """Pick the nearby length whose ratio to ``tps`` has the fewest approximations.

    Candidates are ``length + offset`` for every offset in ``range(-3, 3)``.
    Each is scored by how many entries ``bestApproximations`` returns for it;
    the candidate with the smallest count wins, the earliest one on ties.
    """
    # NOTE(review): range(-3, 3) covers -3..2 and so excludes +3 - confirm
    # that the asymmetric window is intended.
    candidates = [length + offset for offset in range(-3, 3)]
    scored = [
        (candidate, bestApproximations(candidate, tps))
        for candidate in candidates
    ]
    best_candidate, _ = min(scored, key=lambda pair: len(pair[1]))
    return best_candidate

def noteLengths(file):
    """Read a MIDI file and return its note lengths with the file's resolution.

    Notes from every instrument are gathered and ordered by start. Previously
    only the last instrument's notes were kept, and a file without
    instruments failed on an unbound ``notes`` variable.

    Args:
        file: Passed unchanged to ``miditoolkit.midi.parser.MidiFile``.

    Returns:
        A tuple ``(note_lengths, tps)`` where ``note_lengths`` holds
        ``end - start + 1`` for each note and ``tps`` is the parsed file's
        ``ticks_per_beat``.
    """
    mid_in = miditoolkit.midi.parser.MidiFile(file)
    notes = [note for channel in mid_in.instruments for note in channel.notes]
    notes.sort(key=lambda note: note.start)
    note_lengths = [note.end - note.start + 1 for note in notes]
    return (note_lengths, mid_in.ticks_per_beat)

def correctNoteLengths(note_lengths,tps):
    """Correct each note length and express it as a multiple of ``tps``.

    Every distinct length is passed through ``correctLength`` once and the
    result is reused for repeats. Corrected lengths that are not positive are
    left out, so the returned list can be shorter than ``note_lengths``.
    """
    corrected_by_length = {}
    results = []
    for raw_length in note_lengths:
        fixed = corrected_by_length.get(raw_length)
        if fixed is None:
            fixed = correctLength(raw_length, tps)
            corrected_by_length[raw_length] = fixed
        if fixed > 0:
            results.append(fixed / tps)
    return results

## I/O Utilities

In [97]:
from statistics import mean
import miditoolkit
import numpy as np

def getNotes(file):
    """Return the notes of every instrument in a MIDI file, ordered by start."""
    midi = miditoolkit.midi.parser.MidiFile(file)
    collected = [note for track in midi.instruments for note in track.notes]
    collected.sort(key=lambda note: note.start)
    return collected

def getNoteLengths(file):
    """Return the corrected note lengths of a MIDI file.

    Notes from all instruments are ordered by start. Each raw length is
    ``end - start + 1``, and the list is then passed to
    ``correctNoteLengths`` together with the file's ``ticks_per_beat``.
    """
    midi = miditoolkit.midi.parser.MidiFile(file)
    ordered = sorted(
        (note for track in midi.instruments for note in track.notes),
        key=lambda note: note.start,
    )
    raw_lengths = [note.end - note.start + 1 for note in ordered]
    return correctNoteLengths(raw_lengths, midi.ticks_per_beat)

def getNoteStarts(file):
    """Return the ``start`` value of every note in a MIDI file, ascending.

    The values are taken from the parsed notes as they are; no scaling by
    ``ticks_per_beat`` is applied here.
    """
    midi = miditoolkit.midi.parser.MidiFile(file)
    starts = [note.start for track in midi.instruments for note in track.notes]
    # Sorting the start values directly gives the same list as sorting the
    # notes by start and then reading their starts.
    starts.sort()
    return starts

def scaleLengths(note_lengths,min_beat):
    """Multiply each length by ``min_beat`` and truncate to an int.

    Any result that is not positive is replaced by ``-1``.
    """
    scaled = []
    for length in note_lengths:
        units = int(length * min_beat)
        scaled.append(units if units > 0 else -1)
    return scaled

def scalingError(note_lengths,min_beat,norm=0):
    """Return the mean truncation error of lengths scaled by ``min_beat``.

    Each length is multiplied by ``min_beat`` and compared with its value
    cast to int. With ``norm > 0`` the per-note error is the absolute
    difference raised to ``norm``. Otherwise it is 1 when the difference is
    positive and 0 when it is not. The mean of those errors is returned.
    """
    scaled = np.array([length * min_beat for length in note_lengths])
    residual = scaled - scaled.astype(int)

    if norm > 0:
        penalties = np.abs(residual) ** norm
    else:
        # Indicator of a positive remainder after truncation.
        penalties = np.where(residual > 0, 1, 0)

    return np.mean(penalties)


def rankMinBeats(note_lengths,min_beats,norm=0):
    """Return up to five ``(min_beat, error)`` pairs with the lowest error.

    Each candidate in ``min_beats`` is scored with ``scalingError``. The
    pairs are ordered by ascending error, and candidates with equal error
    keep their input order.
    """
    scored = [
        (candidate, scalingError(note_lengths, candidate, norm))
        for candidate in min_beats
    ]
    scored.sort(key=lambda pair: pair[1])
    return scored[:5]



## Tokenization

In [102]:
# Cell-level settings. `norm` and `min_beats` are defined here but not used
# in this cell; `min_beat` is the multiplier handed to scaleLengths below.
norm=1
min_beats=[8,12,16,24,32,48,64]
min_beat=48

# MIDI file read by every helper call in this cell.
file="Datasets/asap-dataset/Bach/Fugue/bwv_846/midi_score.mid"


# All notes of the file, ordered by start.
notes=getNotes(file)

# Corrected note lengths, then multiplied by min_beat and truncated to ints.
note_lengths=getNoteLengths(file)
scaled_note_lengths=scaleLengths(note_lengths,min_beat)

# NOTE(review): getNoteStarts returns the raw note.start values (presumably
# ticks, not beats). scaleLengths multiplies them by min_beat as they are and
# turns a start of 0 into -1 - confirm that this is intended.
note_starts=getNoteStarts(file)
scaled_note_starts=scaleLengths(note_starts,min_beat)
# Scaled start of the last note in start order.
print(scaled_note_starts[-1])


2442240


In [104]:
from miditok import REMI, get_midi_programs
# NOTE: this rebinds the name MidiFile, which an earlier cell imported from mido.
from miditoolkit import MidiFile

# Our parameters
# NOTE(review): the meaning of each argument is defined by the installed
# miditok version - check them against its REMI documentation.
pitch_range = range(21, 109)
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 1
# Every optional token type is switched off.
additional_tokens = {'Chord': False, 'Rest': False, 'Tempo': False, 'Program': False, 'TimeSignature': False}

# Creates the tokenizer and loads a MIDI
# `file` is the path defined in the previous cell.
tokenizer = REMI(pitch_range, beat_res, nb_velocities, additional_tokens, mask=True)
midi = MidiFile(file)

# Converts the MIDI to tokens (the reverse conversion is not done in this cell)
tokens = tokenizer.midi_to_tokens(midi)
print(len(tokens))

[[1, 160, 42, 91, 95, 164, 44, 91, 95, 168, 46, 91, 95, 172, 47, 91, 97, 178, 49, 91, 92, 179, 47, 91, 92, 180, 46, 91, 95, 184, 51, 91, 95, 1, 156, 44, 91, 95, 160, 49, 91, 97, 166, 51, 91, 93, 168, 49, 91, 93, 170, 47, 91, 93, 172, 46, 91, 93, 174, 47, 91, 93, 176, 46, 91, 93, 49, 91, 95, 178, 44, 91, 93, 180, 42, 91, 93, 51, 91, 95, 182, 44, 91, 93, 184, 42, 91, 93, 53, 91, 95, 186, 41, 91, 93, 1, 156, 39, 91, 95, 54, 91, 97, 160, 48, 91, 95, 162, 56, 91, 92, 163, 54, 91, 92, 164, 49, 91, 103, 53, 91, 95, 168, 58, 91, 95, 172, 51, 91, 95, 176, 48, 91, 93, 56, 91, 97, 178, 46, 91, 93, 180, 48, 91, 95, 182, 58, 91, 93, 184, 44, 91, 95, 56, 91, 93, 186, 54, 91, 93, 1, 156, 49, 91, 93, 53, 91, 93, 158, 49, 91, 93, 160, 47, 91, 95, 51, 91, 93, 162, 53, 91, 93, 164, 46, 91, 95, 54, 91, 93, 166, 53, 91, 93, 168, 44, 91, 95, 54, 91, 93, 170, 56, 91, 93, 172, 58, 91, 93, 174, 56, 91, 93, 176, 58, 91, 93, 178, 60, 91, 93, 180, 61, 91, 95, 184, 49, 91, 99, 53, 91, 95, 1, 156, 54, 91, 95, 160, 