In [1]:
import os
import numpy as np
import pretty_midi
from note import MIDI_note

def find_files_by_extensions(root, exts=()):
    """Yield the path of every file under ``root`` whose name ends with one of ``exts``.

    Args:
        root: Directory that is walked recursively with ``os.walk``.
        exts: Iterable of file-name suffixes such as ``'.mid'``. Matching is
            case-insensitive on both the file name and the suffix. An empty
            iterable (the default) matches every file.

    Yields:
        ``os.path.join(path, name)`` for each matching file, in ``os.walk``
        order. Nothing is yielded when ``os.walk`` finds no directory at
        ``root``.
    """
    # Normalise the suffixes once; str.endswith accepts a tuple of candidates.
    suffixes = tuple(ext.lower() for ext in exts)

    for path, _, files in os.walk(root):
        for name in files:
            if not suffixes or name.lower().endswith(suffixes):
                yield os.path.join(path, name)

def preprocess_midi_files_under(midi_folder, preprocess_folder):
    """Encode every MIDI file under ``midi_folder`` into an ``.npz`` archive.

    Each ``.mid`` / ``.midi`` file found (recursively) is passed to ``encode``
    and the three returned sequences are stored with ``np.savez`` as
    ``<preprocess_folder>/<file name>.npz`` under the keys ``token_seq``,
    ``channel_seq`` and ``tempo_seq``. Files whose archive already exists are
    skipped, so the function can be re-run to resume an interrupted batch.

    Args:
        midi_folder: Folder that is searched for MIDI files. Created if missing.
        preprocess_folder: Folder that receives the archives. Created if missing.
    """
    # Create both folders *before* scanning so that a missing input folder
    # results in an empty run instead of an error from the directory walk.
    os.makedirs(midi_folder, exist_ok=True)
    os.makedirs(preprocess_folder, exist_ok=True)
    midi_paths = list(find_files_by_extensions(midi_folder, ['.mid', '.midi']))

    for path in midi_paths:
        # NOTE: the archive name is derived from the base name only, so two
        # files with the same name in different sub-folders share one archive
        # and the second one is skipped by the existence check below.
        npz_path = os.path.join(preprocess_folder, os.path.basename(path)) + '.npz'

        # Progress trace: " [<path>]" without a newline, flushed immediately.
        print(' [{}]'.format(path), end='', flush=True)
        if os.path.exists(npz_path):
            continue  # already encoded on a previous run

        token_seq, channel_seq, tempo_seq = encode(path)
        np.savez(npz_path, token_seq=token_seq, channel_seq=channel_seq, tempo_seq=tempo_seq)

def extract_midi(path):
    """Read a MIDI file and return its notes together with the tempo map.

    Args:
        path: Path of the MIDI file, handed to ``pretty_midi.PrettyMIDI``.

    Returns:
        A tuple ``(midi_notes, tempo_times, tempo_bpm)``:

        * ``midi_notes`` - de-duplicated ``MIDI_note`` objects from all
          instruments, sorted by ``time_start``.
        * ``tempo_times`` - the tempo-change times reported by
          ``get_tempo_changes`` with the file's end time appended, i.e. one
          more entry than ``tempo_bpm``.
        * ``tempo_bpm`` - the tempo of each segment, as reported by
          ``get_tempo_changes``.
    """
    mid = pretty_midi.PrettyMIDI(midi_file=path)
    change_times, tempo_bpm = mid.get_tempo_changes()
    tempo_times = np.append(change_times, mid.get_end_time())

    midi_notes = []
    for inst in mid.instruments:
        # The instrument's program number is what gets stored in the note's
        # ``channel`` field.
        channel = inst.program
        for n in inst.notes:
            # Index of the last tempo change at or before the note's start.
            # A binary search replaces the per-note linear scan, and the clamp
            # keeps the lookup valid for any start time instead of raising
            # StopIteration when no half-open segment contains it.
            idx = max(int(np.searchsorted(change_times, n.start, side='right')) - 1, 0)

            midi_notes.append(MIDI_note(pitch=n.pitch, time_start=n.start, time_end=n.end, dynamic=n.velocity, channel=channel, tempo=round(tempo_bpm[idx])))

    # De-duplication relies on MIDI_note's own hashing/equality, which is
    # defined in the ``note`` module and not visible here.
    midi_notes = list(set(midi_notes))
    midi_notes = sorted(midi_notes, key=lambda note: note.time_start)

    return midi_notes, tempo_times, tempo_bpm

def get_beats_and_tempo(tempo_times, tempo_bpm, res_per_beat=12):
    """Build the seconds-to-ticks conversion tables for a tempo map.

    The time line is cut into segments ``[tempo_times[i], tempo_times[i + 1])``
    that each use the tempo ``tempo_bpm[i]``. Every beat is divided into
    ``res_per_beat`` ticks.

    Args:
        tempo_times: Segment boundaries; one more entry than ``tempo_bpm``
            (the last entry closes the final segment).
        tempo_bpm: Tempo of each segment, in beats per minute.
        res_per_beat: Number of ticks per beat. Defaults to 12.

    Returns:
        A tuple ``(resolutions, total_beats)`` of NumPy arrays:

        * ``resolutions[i]`` - duration of one tick within segment ``i``, in
          the same time unit as ``tempo_times``.
        * ``total_beats[i]`` - number of ticks elapsed before segment ``i``
          starts (despite the name, the unit is ticks, not beats). It has one
          more entry than ``resolutions``; the last entry is the tick count
          of the whole piece.
    """
    resolutions = []
    total_beats = [0]

    for idx, segment_start in enumerate(tempo_times[:-1]):
        # 60 / bpm is the length of one beat; split it into ticks.
        resolution = (60 / tempo_bpm[idx]) / res_per_beat
        resolutions.append(resolution)

        # Each segment is rounded to a whole number of ticks so that the
        # cumulative offsets stay integral.
        ticks_in_segment = round((tempo_times[idx + 1] - segment_start) / resolution)
        total_beats.append(total_beats[-1] + ticks_in_segment)

    return np.array(resolutions), np.array(total_beats)

def adjust_note_time(midi_notes, tempo_times, resolutions, total_beats, attribute):
    """Convert one time attribute of every note to ticks, in place.

    For each note the tempo segment containing ``getattr(note, attribute)`` is
    looked up, and the attribute is overwritten with
    ``np.round(total_beats[idx] + (time - tempo_times[idx]) / resolutions[idx])``.

    Args:
        midi_notes: Iterable of note objects that expose ``attribute``.
        tempo_times: Segment boundaries; the last entry closes the final segment.
        resolutions: Duration of one tick in each segment.
        total_beats: Number of ticks elapsed before each segment starts.
        attribute: Name of the attribute to convert, e.g. ``'time_start'``.

    Returns:
        None. The notes are modified in place, so calling this twice on the
        same attribute converts already-converted values again.
    """
    if len(tempo_times) < 2:
        return  # no complete segment, nothing can be converted

    segment_starts = np.asarray(tempo_times[:-1])
    last_segment = len(segment_starts) - 1

    for n in midi_notes:
        moment = getattr(n, attribute)
        # Last segment that starts at or before the moment. The result is
        # clamped so that a moment equal to the final boundary (typically the
        # end of the last-sounding note) is converted with the final segment
        # rather than being left untouched in the original time unit.
        idx = int(np.searchsorted(segment_starts, moment, side='right')) - 1
        idx = min(max(idx, 0), last_segment)

        ticks = (moment - tempo_times[idx]) / resolutions[idx]
        setattr(n, attribute, np.round(total_beats[idx] + ticks))

# Width of each attribute's slot in the token id space; START_IDX below lays
# the slots out back to back in the order pitch, dynamic, length, time.
PITCH_RES, DYN_RES = 128, 128
LENGTH_RES, TIME_RES = 400, 400

# First token id of each slot: the running sum of the widths that precede it.
START_IDX = dict(
    PITCH_RES=0,
    DYN_RES=PITCH_RES,
    LENGTH_RES=PITCH_RES + DYN_RES,
    TIME_RES=PITCH_RES + DYN_RES + LENGTH_RES,
)

def encode(path):
    """Turn a MIDI file into parallel token, channel and tempo sequences.

    Note times are first converted to ticks. Every note except the last one
    (which has no successor to measure a time shift against) then contributes
    four tokens, in this order: dynamic, pitch, length, and the shift to the
    next note's start. Each token is the attribute value offset by the
    matching ``START_IDX`` entry.

    Length and shift offsets are bounded to ``[0, LENGTH_RES - 1]`` and
    ``[0, TIME_RES - 1]`` so that a very long note or gap cannot produce a
    token that falls into a neighbouring range or past the end of the
    vocabulary.

    Args:
        path: Path of the MIDI file.

    Returns:
        A tuple ``(token_seq, channel_seq, tempo_seq)`` of lists.
        ``token_seq`` holds four entries per encoded note; ``channel_seq`` and
        ``tempo_seq`` hold one entry per encoded note.
    """
    midi_notes, tempo_times, tempo_bpm = extract_midi(path)
    resolutions, total_beats = get_beats_and_tempo(tempo_times, tempo_bpm)
    adjust_note_time(midi_notes, tempo_times, resolutions, total_beats, 'time_start')
    adjust_note_time(midi_notes, tempo_times, resolutions, total_beats, 'time_end')

    def _bounded(offset, size):
        # Keep an offset inside its slot of ``size`` token ids.
        return min(max(offset, 0), size - 1)

    token_seq, channel_seq, tempo_seq = [], [], []
    # Pair every note with its successor; the final note is therefore only
    # used as the "next" of the one before it.
    for note, following in zip(midi_notes, midi_notes[1:]):
        length = _bounded(note.time_end - note.time_start, LENGTH_RES)
        shift = _bounded(following.time_start - note.time_start, TIME_RES)
        token_seq.extend([
            START_IDX['DYN_RES'] + note.dynamic,
            START_IDX['PITCH_RES'] + note.pitch,
            START_IDX['LENGTH_RES'] + length,
            START_IDX['TIME_RES'] + shift,
        ])
        channel_seq.append(note.channel)
        tempo_seq.append(note.tempo)

    return token_seq, channel_seq, tempo_seq

In [2]:
# Encode every .mid/.midi file found under the first folder into an .npz
# archive in the second folder; archives that already exist are skipped.
# NOTE(review): both paths are absolute and machine-specific - adjust them
# before running this cell elsewhere.
preprocess_midi_files_under('F:\\GitHub\\dataset\\midi_dataset', 'F:\\GitHub\\dataset\\np_dataset')

 [F:\GitHub\dataset\midi_dataset\Beethoven - Symphony no. 5.mid] [F:\GitHub\dataset\midi_dataset\Beethoven - Symphony no. 6 - 1st movement.mid] [F:\GitHub\dataset\midi_dataset\Beethoven - Symphony no. 7 - 2nd movement.mid] [F:\GitHub\dataset\midi_dataset\Beethoven - Symphony no. 9 - 2nd movement.mid] [F:\GitHub\dataset\midi_dataset\Beethoven - Symphony no. 9 - 4th movement.mid] [F:\GitHub\dataset\midi_dataset\Dukas - Sorcerer's Apprentice.mid] [F:\GitHub\dataset\midi_dataset\Hans_Zimmer_-_Pirates_Of_The_Caribbean_-_He's_A_Pirate.mid]



 [F:\GitHub\dataset\midi_dataset\Mozart - Eine Kleine Nachtmusik.mid] [F:\GitHub\dataset\midi_dataset\Mozart - Lacrimoza.mid] [F:\GitHub\dataset\midi_dataset\Mozart - Marige of Figaro.mid] [F:\GitHub\dataset\midi_dataset\Mozart - Queen of the Night.mid] [F:\GitHub\dataset\midi_dataset\Mozart - Symphony no. 40.mid] [F:\GitHub\dataset\midi_dataset\Mozart - Symphony no. 41 - 3rd movement.mid] [F:\GitHub\dataset\midi_dataset\Paul Dukas - Sorcerer's Apprentice.mid]

In [3]:
# Read back the three sequences stored for one piece. np.load returns an
# NpzFile that keeps the archive open; the with-block closes it once the
# arrays have been read into memory.
with np.load("F:\\GitHub\\dataset\\np_dataset\\Mozart - Lacrimoza.mid.npz") as loaded:
    token_seq = loaded["token_seq"]
    channel_seq = loaded["channel_seq"]
    tempo_seq = loaded["tempo_seq"]

In [4]:
# Show the loaded token sequence as the cell's output.
token_seq

array([ 48., 193., 268., ..., 194., 328., 656.])