# Correcting Note Lengths

In [None]:
from mido import MidiFile
from tqdm import tqdm
import matplotlib.pyplot as plt
import miditoolkit
import decimal
from math import floor
from decimal import Decimal as D
from collections import namedtuple

def continued_fraction(x, k):
    """Expand ``x`` into (at most ``k + 1``) continued-fraction terms.

    The integer part is always emitted first. After that, at most ``k`` further
    terms are produced. Expansion also stops early when the remainder becomes
    exactly zero, or when the next term would exceed ``k`` (a very large term
    means the remaining tail contributes almost nothing).

    Args:
        x: the value to expand (int, float, Fraction or Decimal).
        k: both the maximum number of extra terms and the largest term kept.

    Returns:
        List of integer terms ``[a0, a1, ...]``.
    """
    whole = floor(x)
    terms = [whole]
    remainder = x - whole
    steps = 0
    while remainder != 0 and steps < k:
        reciprocal = 1 / remainder
        term = floor(reciprocal)
        if term > k:
            break
        terms.append(term)
        remainder = reciprocal - term
        steps += 1
    return terms

def rationalApproximations(clist, app):
    """Return the convergents and useful semiconvergents of a continued fraction.

    Walks the continued-fraction terms ``clist`` (as produced by
    ``continued_fraction``) with the standard recurrence
    ``h_n = a_n*h_{n-1} + h_{n-2}``, ``k_n = a_n*k_{n-1} + k_{n-2}``.
    For every term after the first it also tries the intermediate fractions
    ``(h_{n-2} + i*h_{n-1}) / (k_{n-2} + i*k_{n-1})`` for ``i = 1..a_n``. It
    keeps those strictly closer to ``app`` than the previous convergent.

    Args:
        clist: continued-fraction terms ``[a0, a1, ...]``.
        app: the value being approximated (Decimal, int, float or str); it is
            converted to ``Decimal`` so all comparisons use the current decimal
            context precision.

    Returns:
        Fraction strings such as ``'5/2'`` ordered by denominator (ties keep
        discovery order), with duplicates removed.
    """
    target = app if isinstance(app, D) else D(app)
    fraction = namedtuple("fraction", "ratio, numer, denom")
    hn0, kn0 = 0, 1
    hn1, kn1 = 1, 0
    conlist = []
    for n in clist:
        # Before the first convergent exists (k_{-1} == 0) there is nothing
        # to compare against, so semiconvergents are only tried afterwards.
        # (Previously this case raised ZeroDivisionError and a bare except
        # swallowed it together with any other error.)
        if kn1 != 0:
            prev_error = abs(target - D(hn1) / D(kn1))
            for i in range(1, n + 1):
                ran = hn0 + i * hn1
                rad = kn0 + i * kn1
                # Decimal division, not float division, so the comparison
                # keeps the full working precision of ``target``.
                if abs(target - D(ran) / D(rad)) < prev_error:
                    conlist.append(fraction(f'{ran}/{rad}', ran, rad))
        hn2 = (n * hn1) + hn0
        kn2 = (n * kn1) + kn0
        conlist.append(fraction(f'{hn2}/{kn2}', hn2, kn2))
        hn0, kn0 = hn1, kn1
        hn1, kn1 = hn2, kn2
    # Change x.ratio to x.denom or x.numer for numerators or denominators
    finallist = [x.ratio for x in sorted(conlist, key=lambda f: f.denom)]
    return list(dict.fromkeys(finallist))

def bestApproximations(length,tps):
    """Return the rational approximations of ``length / tps``.

    ``value`` is ``Decimal(length / tps)``, i.e. the exact binary value of the
    float quotient. The working precision scales with the length of that
    value's decimal representation. It is also used as ``k`` for
    ``continued_fraction`` (both the term-count limit and the term-size
    cut-off).

    The precision is applied only inside a local decimal context, so the
    caller's (thread-global) decimal context is left unchanged. The previous
    code assigned ``decimal.getcontext().prec`` and leaked it permanently.

    Args:
        length: note length in ticks.
        tps: ticks per beat.

    Returns:
        List of fraction strings as produced by ``rationalApproximations``.
    """
    value = D(length/tps)
    prec = len(str(value))*5
    with decimal.localcontext() as ctx:
        ctx.prec = prec
        vc = continued_fraction(value, prec)
        return rationalApproximations(vc, value)

def correctLength(length,tps,corrections=range(-3,3)):
    """Snap a note length to the nearby candidate with the fewest approximations.

    Each candidate ``length + c`` for ``c`` in ``corrections`` is scored by
    how many rational approximations ``bestApproximations`` returns for
    ``candidate / tps``. Fewer approximations generally indicates a shorter
    continued fraction, i.e. a simpler ratio to ``tps``.

    Args:
        length: note length in ticks.
        tps: ticks per beat.
        corrections: tick offsets to try. The default ``range(-3, 3)`` covers
            -3..+2. It is asymmetric as originally written; pass e.g.
            ``range(-3, 4)`` for a symmetric window.

    Returns:
        The best candidate length. On ties, the candidate whose offset comes
        first in ``corrections`` wins (``min`` keeps the first minimum).
    """
    scored = [
        (candidate, bestApproximations(candidate, tps))
        for candidate in (length + correction for correction in corrections)
    ]
    return min(scored, key=lambda item: len(item[1]))[0]

def noteLengths(file):
    """Return the raw note lengths of a MIDI file together with its ticks per beat.

    Notes from every instrument are merged and ordered by start tick, matching
    ``getNotes`` / ``getNoteLengths``. Previously ``notes`` was reassigned for
    each instrument, so only the last instrument was measured. A file without
    instruments raised UnboundLocalError; it now yields an empty list.

    Args:
        file: path to a MIDI file readable by miditoolkit.

    Returns:
        ``(note_lengths, tps)``, where each length is ``end - start + 1`` in
        ticks and ``tps`` is the file's ``ticks_per_beat``.
    """
    mid_in = miditoolkit.midi.parser.MidiFile(file)
    notes = []
    for channel in mid_in.instruments:
        notes.extend(channel.notes)
    notes.sort(key=lambda x: x.start)
    # The +1 is kept for consistency with getNoteLengths.
    note_lengths = [note.end - note.start + 1 for note in notes]
    return note_lengths, mid_in.ticks_per_beat

def correctNoteLengths(note_lengths,tps):
    """Correct each note length with ``correctLength`` and convert it to beats.

    Each distinct raw length is corrected only once; the result is memoised
    in a local dict. A correction that is not positive is dropped, so the
    output may be shorter than the input.

    Args:
        note_lengths: raw note lengths in ticks.
        tps: ticks per beat.

    Returns:
        Corrected lengths divided by ``tps`` (i.e. in beats), in input order.
    """
    memo = {}
    beats = []
    for raw in note_lengths:
        fixed = memo.get(raw)
        if fixed is None:
            fixed = memo[raw] = correctLength(raw, tps)
        if fixed > 0:
            beats.append(fixed / tps)
    return beats

## I/O Utilities

In [None]:
from statistics import mean
import miditoolkit
import numpy as np

def getNotes(file):
    """Load a MIDI file and return the notes of all instruments ordered by start tick.

    The sort is stable, so notes sharing a start tick keep instrument order.
    """
    midi_file = miditoolkit.midi.parser.MidiFile(file)
    merged = [note for instrument in midi_file.instruments for note in instrument.notes]
    return sorted(merged, key=lambda note: note.start)

def getNoteLengths(file):
    """Return the corrected lengths, in beats, of all notes in a MIDI file.

    Notes from every instrument are ordered by start tick. Each raw length is
    ``end - start + 1`` ticks and is passed through ``correctNoteLengths``
    with the file's ``ticks_per_beat``. That call may drop notes whose
    correction is not positive.
    """
    midi_file = miditoolkit.midi.parser.MidiFile(file)
    ordered = sorted(
        (note for instrument in midi_file.instruments for note in instrument.notes),
        key=lambda note: note.start,
    )
    raw_lengths = [note.end - note.start + 1 for note in ordered]
    return correctNoteLengths(raw_lengths, midi_file.ticks_per_beat)

def getNoteStarts(file):
    """Return the start tick of every note in a MIDI file, in ascending order.

    Sorting the start values directly yields the same sequence as sorting the
    notes by start and then reading their starts.
    """
    midi_file = miditoolkit.midi.parser.MidiFile(file)
    return sorted(
        note.start
        for instrument in midi_file.instruments
        for note in instrument.notes
    )

def scaleLengths(note_lengths,min_beat):
    """Scale lengths by ``min_beat`` and truncate them to integers.

    Truncation is toward zero (``int``). Any result that is not positive is
    replaced by the sentinel ``-1``.
    """
    result = []
    for value in note_lengths:
        steps = int(value * min_beat)
        result.append(steps if steps > 0 else -1)
    return result

def scalingError(note_lengths,min_beat,norm=0):
    """Measure how poorly note lengths fit a grid of ``min_beat`` steps per beat.

    Each length is multiplied by ``min_beat``. The residual is
    ``scaled - trunc(scaled)``, using truncation toward zero as in
    ``scaleLengths``.

    Args:
        note_lengths: note lengths in beats.
        min_beat: grid subdivisions per beat.
        norm: if > 0, the error is the mean of ``|residual| ** norm``.
            Otherwise it is the fraction of notes whose residual is non-zero.

    Returns:
        The mean error; ``0.0`` for an empty input. Previously an empty input
        produced NaN with a RuntimeWarning.
    """
    scaled_note_lengths = np.array([note_length * min_beat for note_length in note_lengths])
    if scaled_note_lengths.size == 0:
        return 0.0
    residuals = scaled_note_lengths - scaled_note_lengths.astype(int)

    if norm > 0:
        errors = np.abs(residuals) ** norm
    else:
        # Count non-zero residuals. Truncation toward zero leaves a negative
        # residual for negative values, which the old ``> 0`` test missed.
        errors = residuals != 0

    return np.mean(errors)


def rankMinBeats(note_lengths,min_beats,norm=0,top=5):
    """Rank candidate grid resolutions by how well they fit the note lengths.

    Args:
        note_lengths: note lengths in beats.
        min_beats: candidate grid subdivisions per beat.
        norm: forwarded to ``scalingError``.
        top: how many of the best candidates to return. Defaults to 5, the
            previously hard-coded value; ``None`` returns all of them.

    Returns:
        ``(min_beat, error)`` pairs sorted by ascending error. The sort is
        stable, so ties keep the order of ``min_beats``.
    """
    errors = [
        (min_beat, scalingError(note_lengths, min_beat, norm))
        for min_beat in min_beats
    ]
    errors.sort(key=lambda pair: pair[1])
    return errors[:top]



## Tokenization

In [None]:
# Grid-selection parameters: min_beat is the number of grid steps per beat.
norm=1
min_beats=[8,12,16,24,32,48,64]
min_beat=48

file="Datasets/asap-dataset/Bach/Fugue/bwv_846/midi_score.mid"


notes=getNotes(file)

# getNoteLengths returns lengths in beats (corrected ticks / ticks_per_beat),
# so multiplying by min_beat gives grid steps.
note_lengths=getNoteLengths(file)
scaled_note_lengths=scaleLengths(note_lengths,min_beat)

# getNoteStarts returns raw MIDI ticks, not beats. They must be divided by
# ticks_per_beat before being placed on the same grid; previously they were
# multiplied by min_beat directly. scaleLengths is not reused because it maps
# non-positive values to -1, which would corrupt an onset at tick 0.
# Integer floor division keeps the exact grid index for non-negative ticks.
ticks_per_beat=miditoolkit.midi.parser.MidiFile(file).ticks_per_beat
note_starts=getNoteStarts(file)
scaled_note_starts=[start*min_beat//ticks_per_beat for start in note_starts]
print(scaled_note_starts[-1])


In [None]:
from miditok import REMI, get_midi_programs
from miditoolkit import MidiFile

# Our parameters
# Tokenizer settings passed positionally to REMI below.
# NOTE(review): the positional REMI(pitch_range, beat_res, nb_velocities,
# additional_tokens, mask=...) signature follows older miditok releases;
# confirm it against the installed miditok version.
# range(21, 109) covers MIDI pitches 21..108 (88 values, the piano range).
pitch_range = range(21, 109)
# presumably {(start_beat, end_beat): positions per beat}, i.e. 8 per beat for
# the 0-4 beat range and 4 per beat for 4-12 beats — confirm with miditok docs.
beat_res = {(0, 4): 8, (4, 12): 4}
# Presumably a single velocity bin, so velocity information is discarded.
nb_velocities = 1
# All optional token families are switched off.
additional_tokens = {'Chord': False, 'Rest': False, 'Tempo': False, 'Program': False, 'TimeSignature': False}

# Creates the tokenizer and loads the MIDI file chosen in the previous cell.
# MidiFile here is miditoolkit's: the import at the top of this cell rebinds
# the name previously imported from mido.
tokenizer = REMI(pitch_range, beat_res, nb_velocities, additional_tokens, mask=True)
midi = MidiFile(file)

# Converts the MIDI to tokens (no conversion back to MIDI happens here).
# NOTE(review): in older miditok versions midi_to_tokens returns one token
# sequence per track, so len(tokens) may be the track count rather than the
# token count — confirm.
tokens = tokenizer.midi_to_tokens(midi)
print(len(tokens))

In [3]:
import pretty_midi as pm
import miditoolkit as mt

midi_path="Store/Score2Melody/Bach/Fugue/bwv_846/Shi05M_score.mid"

# Piano roll sampled at fs=100 columns per second.
midi_data = pm.PrettyMIDI(midi_path)
piano_roll = midi_data.get_piano_roll(fs=100) # shape=(pitch, timestep)
print(piano_roll.shape)

midi_obj=mt.midi.parser.MidiFile(midi_path)

# One note list per instrument. The previous wrapping comprehension was a
# no-op copy of this list.
notes=[instrument.notes for instrument in midi_obj.instruments]
# Notes of every instrument (not just the first) ordered by start tick.
sorted_notes=sorted((note for track in notes for note in track), key=lambda x: x.start)
# Span in ticks from the earliest onset to the latest release. The note that
# starts last need not be the one that ends last (an earlier note can be held
# longer), so take the maximum end explicitly.
length=max(note.end for note in sorted_notes)-sorted_notes[0].start
print(length)

(128, 5399)
51599


In [1]:
import torch

In [2]:
# 5x5 matrix holding 0..24 in row-major order.
x=torch.arange(25).view(5,5)
print(x)

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])


In [5]:
# Boolean mask that is True on and above the main diagonal: entry (i, j) is
# True when i <= j. A (5, 1) column of row indices broadcasts against a (1, 5)
# row of column indices to produce a (5, 5) result.
mask=torch.arange(5).unsqueeze(1)<=torch.arange(5).unsqueeze(0)
print(mask)

tensor([[ True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True]])


In [7]:
# Elements where the mask is False (strictly below the diagonal), returned as
# a 1-D tensor in row-major order.
x[mask.logical_not()]

tensor([ 5, 10, 11, 15, 16, 17, 20, 21, 22, 23])