## Setup

In [1]:
import pathlib
import pretty_midi
import os
import csv

## Retrieve YM2413-MDB (v1.0.2) Dataset

In [2]:
# Directory containing MIDI files
data_dir = pathlib.Path('../music_dataset/YM2413-MDB-v1.0.2/midi/adjust_tempo_remove_delayed_inst')

# Get list of MIDI files (any name containing '.mid', e.g. .mid and .midi).
# glob() yields entries in an arbitrary, filesystem-dependent order, so sort
# them to keep anything that iterates over `filenames` (such as the CSV export)
# reproducible from one run or machine to the next.
filenames = sorted(data_dir.glob('*.mid*'))
print('Number of files:', len(filenames))

Number of files: 669


## Process MIDI Files

In [3]:
# Raise pretty_midi's MAX_TICK limit before loading any file
# (presumably so files with very large tick counts are not rejected -- see the
# pretty_midi documentation for the exact semantics).
pretty_midi.pretty_midi.MAX_TICK = 1e16

# Initialize list to store inspection results (one dict per MIDI file)
inspection_results = []

for filepath in filenames:
    # Convert the Path object to a string before handing it to pretty_midi
    filepath_str = str(filepath)

    # Process MIDI file
    pm = pretty_midi.PrettyMIDI(filepath_str)
    num_instruments = len(pm.instruments)
    instrument_names = [pretty_midi.program_to_instrument_name(inst.program) for inst in pm.instruments]

    # Store inspection results
    inspection_results.append({
        "Filename": filepath.name,
        "Number of Instruments": num_instruments,
        "Instrument Names": instrument_names
    })

# Write inspection results to CSV file
csv_filename = '../ym2413_jupyter_proj/inspection_results.csv'
csv_columns = ["Filename", "Number of Instruments", "Instrument Names"]

# An explicit UTF-8 encoding keeps the export independent of the platform's
# default locale encoding, which may be unable to represent every character
# appearing in the file names. newline="" is what the csv module requires.
with open(csv_filename, "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
    writer.writeheader()
    writer.writerows(inspection_results)

print("Inspection results saved to", csv_filename)

Inspection results saved to ../ym2413_jupyter_proj/inspection_results.csv


## Process Single MIDI File

In [4]:
# sample_file = filenames[0]
# print(sample_file)

# # Convert WindowsPath object to string
# sample_file_str = str(sample_file)

# pm = pretty_midi.PrettyMIDI(sample_file_str)

# print('Number of instruments:', len(pm.instruments))

# # Extract instrument names without considering brackets
# instrument_names = [pretty_midi.program_to_instrument_name(inst.program) for inst in pm.instruments]

# # Print all instrument names
# print('Instrument names:', instrument_names)

## Convert MIDI to Nintendo Entertainment System (NES) Format

In [5]:
import itertools
import os
import random

In [6]:
# Emotion label for each quarter code (Q1-Q4). preprocess_midi() looks up the
# leading '_'-separated token of a file name in this table.
quarter_to_emotion = dict(
    Q1='happy',
    Q2='angry',
    Q3='sad',
    Q4='relaxed',
)

In [7]:
# Lowest MIDI pitch kept for each melodic NES voice (p1, p2, tr); notes below
# this are discarded when a source instrument is assigned to that voice.
nes_ins_name_to_min_pitch = dict(p1=33, p2=33, tr=21)

# Highest MIDI pitch kept for each melodic NES voice; notes above are discarded.
nes_ins_name_to_max_pitch = dict(p1=108, p2=108, tr=108)

In [8]:
def instrument_is_monophonic(ins):
    """Report whether no two consecutive notes of ``ins`` overlap in time.

    ``ins.notes`` must already be ordered by start time; an AssertionError is
    raised otherwise. Returns True when every note ends no later than the
    start of the note that follows it (vacuously True for zero or one note),
    False as soon as one note is found to end after its successor starts.
    """
    note_seq = ins.notes

    # Precondition: start times are non-decreasing
    previous_start = -1
    for note in note_seq:
        assert note.start >= previous_start
        previous_start = note.start

    # Compare each note against its immediate successor
    overlaps = (
        current.end > following.start
        for current, following in zip(note_seq, note_seq[1:])
    )
    return not any(overlaps)

In [9]:
def preprocess_midi(midi_fp):
    """Return the emotion label encoded at the start of a MIDI file's name.

    The file name (directory stripped) is cut at its first '.', and the part
    before the first '_' is looked up in ``quarter_to_emotion``. The string
    'Unknown' is returned when that token is not a known quarter label.
    """
    base_name = os.path.basename(midi_fp)
    name_without_suffix = base_name.partition('.')[0]
    quarter_code = name_without_suffix.partition('_')[0]
    return quarter_to_emotion.get(quarter_code, 'Unknown')

In [10]:
def emit_nesmdb_midi_examples(
    midi_fp,
    output_dir,
    min_num_instruments=1,
    filter_mid_len_below_seconds=5.,
    filter_mid_len_above_seconds=600.,
    filter_mid_bad_times=True,
    filter_ins_max_below=21,
    filter_ins_min_above=108,
    filter_ins_duplicate=True,
    output_include_drums=True,
    output_max_num=16,
    output_max_num_seconds=180.):
    """Re-arrange one MIDI file into four-voice NES-style MIDI files.

    The source file is loaded and filtered, then up to ``output_max_num``
    assignments of its monophonic melodic instruments to the voices ``p1``,
    ``p2`` and ``tr`` (plus an optional, randomly chosen drum track for ``no``)
    are written to ``output_dir`` as ``<name>_<iii>.mid``, where ``<name>`` is
    the source file name up to its first '.' and ``<iii>`` is a zero-padded
    index. Assignments are sampled with the global ``random`` module, so seed
    it beforehand for reproducible output.

    The function returns without writing anything when the file is larger
    than 512 KiB, cannot be parsed, falls outside the allowed duration, holds
    invalid note times, or has too few usable instruments left after a filter.

    Args:
        midi_fp: Path (string) of the source MIDI file.
        output_dir: Directory receiving the generated files; it is not
            created here.
        min_num_instruments: Minimum number of melodic instruments that must
            remain after each filtering stage. Must be at least 1.
        filter_mid_len_below_seconds: Skip files whose end time is shorter.
        filter_mid_len_above_seconds: Skip files whose end time is longer.
        filter_mid_bad_times: If True, skip files containing a note with a
            negative start/end time or an end before its start.
        filter_ins_max_below: Drop instruments whose highest pitch is below
            this value.
        filter_ins_min_above: Drop instruments whose lowest pitch is above
            this value.
        filter_ins_duplicate: If True, drop instruments whose sequence of
            (pitch, start rounded to 0.1 s) repeats an earlier instrument.
        output_include_drums: If True, map a randomly chosen drum track to
            the ``no`` voice of every output.
        output_max_num: Maximum number of files emitted per source file.
        output_max_num_seconds: Maximum length of an output; notes ending
            later than this after the ensemble's first note are dropped.

    Returns:
        None.

    Raises:
        ValueError: If ``min_num_instruments`` is not positive.
    """
    # Base name up to the first '.', used as the prefix of every output file
    midi_name = os.path.split(midi_fp)[1].split('.')[0]

    if min_num_instruments <= 0:
        raise ValueError('min_num_instruments must be at least 1')

    # Ignore unusually large MIDI files
    if os.path.getsize(midi_fp) > (512 * 1024): #512K
        return

    # Unparseable files are skipped on purpose (best effort over a whole
    # dataset). Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit still stop a long-running batch.
    try:
        midi = pretty_midi.PrettyMIDI(midi_fp)
    except Exception:
        return

    # Filter MIDIs with extreme length
    midi_len = midi.get_end_time()
    if midi_len < filter_mid_len_below_seconds or midi_len > filter_mid_len_above_seconds:
        return

    # Filter out negative times and quantize to audio samples (44.1 kHz grid)
    for ins in midi.instruments:
        for n in ins.notes:
            if filter_mid_bad_times:
                if n.start < 0 or n.end < 0 or n.end < n.start:
                    return
            n.start = round(n.start * 44100.) / 44100.
            n.end = round(n.end * 44100.) / 44100.

    instruments = midi.instruments

    # Separate drum instruments from melodic ones
    drums = [i for i in instruments if i.is_drum]
    instruments = [i for i in instruments if not i.is_drum]

    # Filter out instruments with bizarre ranges
    instruments_normal_range = []
    for ins in instruments:
        pitches = [n.pitch for n in ins.notes]
        if not pitches:
            # An instrument without notes has no pitch range (min()/max()
            # would raise on the empty list) and nothing to contribute.
            continue
        min_pitch = min(pitches)
        max_pitch = max(pitches)
        if max_pitch >= filter_ins_max_below and min_pitch <= filter_ins_min_above:
            instruments_normal_range.append(ins)
    instruments = instruments_normal_range
    if len(instruments) < min_num_instruments:
        return

    # Sort notes for polyphonic filtering and proper saving
    for ins in instruments:
        ins.notes = sorted(ins.notes, key=lambda x: x.start)
    if output_include_drums:
        for ins in drums:
            ins.notes = sorted(ins.notes, key=lambda x: x.start)

    # Filter out polyphonic instruments
    instruments = [i for i in instruments if instrument_is_monophonic(i)]
    if len(instruments) < min_num_instruments:
        return

    # Filter out duplicate instruments
    if filter_ins_duplicate:
        uniques = set()
        instruments_unique = []
        for ins in instruments:
            # Fingerprint: every note's pitch and start time (0.1 s precision)
            pitches = ','.join(['{}:{:.1f}'.format(str(n.pitch), n.start) for n in ins.notes])
            if pitches not in uniques:
                instruments_unique.append(ins)
                uniques.add(pitches)
        instruments = instruments_unique
        if len(instruments) < min_num_instruments:
            return

    # Create assignments of MIDI instruments to NES instruments.
    # Each tuple is (p1, p2, tr) as indices into `instruments`; -1 = unused.
    num_instruments = len(instruments)
    if num_instruments == 1:
        instrument_perms = [(0, -1, -1), (-1, 0, -1), (-1, -1, 0)]
    elif num_instruments == 2:
        instrument_perms = [(-1, 0, 1), (-1, 1, 0), (0, -1, 1), (0, 1, -1), (1, -1, 0), (1, 0, -1)]
    elif num_instruments > 32:
        # Bound the number of permutations by sampling 32 instruments first
        instrument_perms = list(itertools.permutations(random.sample(range(num_instruments), 32), 3))
    else:
        instrument_perms = list(itertools.permutations(range(num_instruments), 3))

    if len(instrument_perms) > output_max_num:
        instrument_perms = random.sample(instrument_perms, output_max_num)

    # Append a fourth slot: index of a random drum track, or -1 for none
    num_drums = len(drums) if output_include_drums else 0
    instrument_perms_plus_drums = []
    for perm in instrument_perms:
        selection = -1 if num_drums == 0 else random.choice(range(num_drums))
        instrument_perms_plus_drums.append(perm + (selection,))
    instrument_perms = instrument_perms_plus_drums

    # Emit midi files
    for i, perm in enumerate(instrument_perms):
        # Create MIDI instruments
        p1_prog = pretty_midi.instrument_name_to_program('Lead 1 (square)')
        p2_prog = pretty_midi.instrument_name_to_program('Lead 2 (sawtooth)')
        tr_prog = pretty_midi.instrument_name_to_program('Synth Bass 1')
        no_prog = pretty_midi.instrument_name_to_program('Breath Noise')
        p1 = pretty_midi.Instrument(program=p1_prog, name='p1', is_drum=False)
        p2 = pretty_midi.Instrument(program=p2_prog, name='p2', is_drum=False)
        tr = pretty_midi.Instrument(program=tr_prog, name='tr', is_drum=False)
        no = pretty_midi.Instrument(program=no_prog, name='no', is_drum=True)

        # Filter out invalid notes (outside the target voice's pitch range)
        perm_mid_ins_notes = []
        for mid_ins_id, nes_ins_name in zip(perm, ['p1', 'p2', 'tr', 'no']):
            if mid_ins_id < 0:
                perm_mid_ins_notes.append(None)
            else:
                if nes_ins_name == 'no':
                    mid_ins = drums[mid_ins_id]
                    mid_ins_notes_valid = mid_ins.notes
                else:
                    mid_ins = instruments[mid_ins_id]
                    mid_ins_notes_valid = [n for n in mid_ins.notes if n.pitch >= nes_ins_name_to_min_pitch[nes_ins_name] and n.pitch <= nes_ins_name_to_max_pitch[nes_ins_name]]
                perm_mid_ins_notes.append(mid_ins_notes_valid)
        assert len(perm_mid_ins_notes) == 4

        # Calculate length of this ensemble
        start = None
        end = None
        for notes in perm_mid_ins_notes:
            if notes is None or len(notes) == 0:
                continue
            ins_start = min([n.start for n in notes])
            ins_end = max([n.end for n in notes])
            if start is None or ins_start < start:
                start = ins_start
            if end is None or ins_end > end:
                end = ins_end
        if start is None or end is None:
            # No voice has any note left: nothing to write for this assignment
            continue

        # Clip if needed
        if (end - start) > output_max_num_seconds:
            end = start + output_max_num_seconds

        # Create notes
        for mid_ins_notes, nes_ins_name, nes_ins in zip(perm_mid_ins_notes, ['p1', 'p2', 'tr', 'no'], [p1, p2, tr, no]):
            if mid_ins_notes is None:
                continue

            if nes_ins_name == 'no':
                # Random table taking each source drum pitch to a value in 1..16
                random_noise_mapping = [random.randint(1, 16) for _ in range(128)]

            last_nend = -1
            for ni, n in enumerate(mid_ins_notes):
                nvelocity = n.velocity
                npitch = n.pitch
                nstart = n.start
                nend = n.end

                # Drums are not necessarily monophonic so we need to filter
                if nes_ins_name == 'no' and nstart < last_nend:
                    continue
                last_nend = nend

                assert nstart >= start
                # Drop notes that run past the (possibly clipped) end
                if nend > end:
                    continue
                assert nend <= end

                # 'tr' always gets velocity 1; other voices are rescaled from
                # 0..127 onto 1..15
                nvelocity = 1 if nes_ins_name == 'tr' else int(round(1. + (14. * nvelocity / 127.)))
                assert nvelocity > 0
                if nes_ins_name == 'no':
                    npitch = random_noise_mapping[npitch]
                # Shift so the ensemble starts at time zero
                nstart = nstart - start
                nend = nend - start

                nes_ins.notes.append(pretty_midi.Note(nvelocity, npitch, nstart, nend))

        # Add instruments to MIDI file (separate name: do not shadow the source)
        out_midi = pretty_midi.PrettyMIDI(initial_tempo=120, resolution=22050)
        out_midi.instruments.extend([p1, p2, tr, no])

        # Create indicator for end of song
        eos = pretty_midi.TimeSignature(1, 1, end - start)
        out_midi.time_signature_changes.append(eos)

        # Save MIDI file
        out_fp = '{}_{}.mid'.format(midi_name, str(i).zfill(3))
        out_fp = os.path.join(output_dir, out_fp)
        out_midi.write(out_fp)

In [11]:
import shutil
import os

if __name__ == '__main__':
    pretty_midi.pretty_midi.MAX_TICK = 1e16

    data_dir = "../music_dataset/YM2413-MDB-v1.0.2/midi/adjust_tempo_remove_delayed_inst"
    out_dir = './output'

    # emit_nesmdb_midi_examples() samples instrument assignments with the
    # global `random` module; seed it so a re-run regenerates the same files.
    random.seed(0)

    # Start from an empty output directory: results of any previous run are
    # deleted before the directory is recreated.
    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    # Iterate through the directory and its subdirectories in sorted order
    # (os.walk order is otherwise arbitrary, which would defeat the seed).
    for root, dirs, files in os.walk(data_dir):
        dirs.sort()
        for file in sorted(files):
            # Check if the file is a MIDI file; accept .mid and .midi in any
            # letter case, in line with the '*.mid*' glob used to count files.
            if file.lower().endswith(('.mid', '.midi')):
                # Construct the full path of the MIDI file
                midi_fp = os.path.join(root, file)

                # Generate MIDI examples and save them in the output directory
                emit_nesmdb_midi_examples(midi_fp, out_dir)

In [12]:
# Count the MIDI files generated into the output directory
data_dir = pathlib.Path('./output')
filenames = [midi_path for midi_path in data_dir.glob('*.mid*')]
print(f'Number of files (after permutations): {len(filenames)}')

Number of files (after permutations): 7612


## Extract Notes

In [13]:
# def midi_to_notes(midi_file: str) -> pd.DataFrame:
#     pm = pretty_midi.PrettyMIDI(midi_file)
#     notes_data = {'instrument': [], 'pitch': [], 'start': [], 'end': [], 'duration': []}

#     for instrument in pm.instruments:
#         instrument_name = pretty_midi.program_to_instrument_name(instrument.program)
#         for note in instrument.notes:
#             notes_data['instrument'].append(instrument_name)
#             notes_data['pitch'].append(note.pitch)
#             notes_data['start'].append(note.start)
#             notes_data['end'].append(note.end)
#             notes_data['duration'].append(note.end - note.start)

#     return pd.DataFrame(notes_data)

In [14]:
# midi_file_path = "ym2413_jupyter_proj/out/01 - Game de check! Koutsuu Anzen (FM) - Instructions_000.mid"
# notes_df = midi_to_notes(midi_file_path)
# print(notes_df.head(10))