In [None]:
# start+end Timestamp with chord labels

import os
import pretty_midi
import pandas as pd
from collections import defaultdict


# normalize chord, removing octave transpositions 
def normalize_chord(chord_tuple):
    """Reduce MIDI pitches to their distinct pitch classes, in ascending order.

    Args:
        chord_tuple: Any iterable of integer MIDI pitch numbers.

    Returns:
        A sorted tuple of the unique values ``pitch % 12``; octave
        transpositions of the same note collapse to one entry.
    """
    pitch_classes = set()
    for pitch in chord_tuple:
        pitch_classes.add(pitch % 12)
    return tuple(sorted(pitch_classes))

# index mapping for chord vocab based on set of chords in the data
def create_chord_vocab(chords):
    """Build a chord -> integer index mapping over the distinct chords given.

    Chords are deduplicated and sorted before numbering, so the same set of
    chords always yields the same indices regardless of input order.

    Args:
        chords: Iterable of hashable, mutually comparable chord labels.

    Returns:
        dict mapping each distinct chord to its position in sorted order.
    """
    vocab = {}
    for position, chord in enumerate(sorted(set(chords))):
        vocab[chord] = position
    return vocab

# extract chord sequence
def midi_to_chord_sequence(midi_file, merge_threshold=0.3):
    """Extract timed chord segments from a MIDI file.

    Every non-drum note contributes an "on" event at its start and an "off"
    event at its end. All events sharing a timestamp are applied together, then
    the set of sounding pitches is reduced to pitch classes via
    ``normalize_chord``. Whenever that pitch-class set changes, the previous
    segment is emitted if it lasted at least ``merge_threshold`` seconds;
    shorter segments and silences are dropped (not merged into neighbours).

    Args:
        midi_file: Path to a MIDI file, passed to ``pretty_midi.PrettyMIDI``.
        merge_threshold: Minimum segment duration, in seconds, for a chord
            segment to be kept.

    Returns:
        A list of ``(start_time, end_time, chord)`` tuples, with times rounded
        to 3 decimals and ``chord`` a sorted tuple of pitch classes.
    """
    from collections import Counter
    from itertools import groupby

    midi_data = pretty_midi.PrettyMIDI(midi_file)

    # For each note, add two events: on/off.
    events = []
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            # A zero-length note sorts its 'off' before its 'on' ('off' < 'on'),
            # so it would be switched on and never switched off again.
            if note.end <= note.start:
                continue
            events.append((note.start, 'on', note.pitch))
            events.append((note.end, 'off', note.pitch))

    events.sort()

    # Reference count per pitch: the same pitch can sound in several
    # instruments / overlapping notes at once, and one of them ending must not
    # silence the others.
    active_counts = Counter()
    chords = []
    previous_chord = None     # chord of the segment currently being tracked
    chord_start_time = None   # when that segment started

    # Apply every event at one timestamp before inspecting the chord, so notes
    # struck (or released) together never produce zero-length intermediate
    # chords -- important when merge_threshold is 0.
    for time, group in groupby(events, key=lambda event: event[0]):
        for _, action, pitch in group:
            if action == 'on':
                active_counts[pitch] += 1
            else:
                active_counts[pitch] -= 1
                if active_counts[pitch] <= 0:
                    del active_counts[pitch]

        current_chord = normalize_chord(active_counts) if active_counts else None

        if current_chord != previous_chord:
            if previous_chord is not None and chord_start_time is not None:
                # Only save the previous chord if it lasted long enough.
                if time - chord_start_time >= merge_threshold:
                    chords.append((round(chord_start_time, 3), round(time, 3), previous_chord))
            # Start tracking the new chord (or silence) from this time.
            chord_start_time = time
            previous_chord = current_chord

    # Every kept note has an 'off' event, so the last group normally leaves no
    # active notes; this is a defensive capture of any chord still sounding.
    if previous_chord is not None and chord_start_time is not None:
        chords.append((round(chord_start_time, 3), round(midi_data.get_end_time(), 3), previous_chord))

    return chords

# process all midi files in the folder, save to csv
def process_midi_folder(midi_folder, output_csv, merge_threshold=0.3):
    """Extract chords from every MIDI file in a folder and save them to CSV.

    Files are visited in sorted filename order so that the CSV row order is
    reproducible across runs and platforms (``os.listdir`` order is
    arbitrary). A file that fails to process is reported and skipped; the
    remaining files are still processed.

    Args:
        midi_folder: Directory scanned (non-recursively) for files ending in
            ``.mid`` or ``.midi``; the extension match is case-insensitive.
        output_csv: Path of the CSV written, with columns
            ``filename, start_time, end_time, chord``.
        merge_threshold: Minimum chord duration in seconds, forwarded to
            ``midi_to_chord_sequence``.

    Returns:
        The chord vocabulary built by ``create_chord_vocab`` over every chord
        extracted from the folder.
    """
    data = []
    all_chords = []  # collect all chords for vocab creation

    for midi_file in sorted(os.listdir(midi_folder)):
        if not midi_file.lower().endswith((".mid", ".midi")):
            continue
        file_path = os.path.join(midi_folder, midi_file)
        try:
            chords = midi_to_chord_sequence(file_path, merge_threshold=merge_threshold)
        except Exception as e:
            # Best-effort batch job: report the bad file and keep going.
            print(f"Error processing {midi_file}: {e}")
            continue

        for timestamp_start, timestamp_end, chord in chords:
            all_chords.append(chord)
            data.append([midi_file, timestamp_start, timestamp_end, chord])

    chord_to_index = create_chord_vocab(all_chords)

    # save to csv
    df = pd.DataFrame(data, columns=["filename", "start_time", "end_time", "chord"])
    df.to_csv(output_csv, index=False)
    print(f"Dataset saved to {output_csv}")

    return chord_to_index  # return the generated chord vocabulary


# Build the chord dataset CSV and keep the resulting chord vocabulary
# (used by the one-hot encoding cell below).
midi_folder, output_csv = "midi_folder", "chord_dataset.csv"
chord_to_index = process_midi_folder(midi_folder, output_csv)


Dataset saved to chord_dataset.csv


In [5]:
# # Single Timestamp with chord labels

# # normalize chord, removing octave transpositions 
# def normalize_chord(chord_tuple):
#     normalized_chord = {note % 12 for note in chord_tuple}  # keep only unique notes modulo 12
#     return tuple(sorted(normalized_chord))

# # index mapping for chord vocab based on set of chords in the data
# def create_chord_vocab(chords):
#     unique_chords = sorted(set(chords)) # ensure consistency
#     chord_to_index = {chord: idx for idx, chord in enumerate(unique_chords)}
#     return chord_to_index

# def midi_to_chord_sequence(midi_file):
#     midi_data = pretty_midi.PrettyMIDI(midi_file)
    
#     # dictionary to store active notes at each time
#     active_notes = defaultdict(set)
    
#     for instrument in midi_data.instruments:
#         if instrument.is_drum:
#             continue  # skip drum 
        
#         for note in instrument.notes:
#             active_notes[note.start].add(note.pitch)
#             active_notes[note.end].discard(note.pitch)
    
#     timestamps = sorted(active_notes.keys())
#     chords = []
#     previous_chord = None

#     active_chord = set()
#     for t in timestamps:
#         active_chord.update(active_notes[t])
#         chord_label = tuple(sorted(active_chord)) 
        
#         # only store if the chord actually changes
#         if chord_label and chord_label != previous_chord:
#             normalized_chord = normalize_chord(chord_label)
#             rounded_timestamp = round(t, 3)
#             chords.append((rounded_timestamp, normalized_chord))  
#             previous_chord = normalized_chord
    
#     return chords

# def process_midi_folder(midi_folder, output_csv):
#     data = []
#     all_chords = [] 
    
#     for midi_file in os.listdir(midi_folder):
#         if midi_file.endswith(".mid") or midi_file.endswith(".midi"):
#             file_path = os.path.join(midi_folder, midi_file)
#             try:
#                 chords = midi_to_chord_sequence(file_path)
#                 all_chords.extend([chord for _, chord in chords]) 
                
#                 for timestamp, chord in chords:
#                     data.append([midi_file, timestamp, chord])
#             except Exception as e:
#                 print(f"Error processing {midi_file}: {e}")
    
#     chord_to_index = create_chord_vocab(all_chords)

#     df = pd.DataFrame(data, columns=["filename", "timestamp", "chord"])
#     df.to_csv(output_csv, index=False)
#     print(f"Dataset saved to {output_csv}")
    
#     return chord_to_index  


# midi_folder = "midi_folder"  
# output_csv = "chord_dataset.csv"
# chord_to_index = process_midi_folder(midi_folder, output_csv)


Dataset saved to chord_dataset.csv


In [5]:
# one-hot encoding 
# chord dictionary is based on current chords
def one_hot_encode(chord, chord_to_index):
    """Return a one-hot list for ``chord`` under the given vocabulary.

    Args:
        chord: A chord key, e.g. a sorted tuple of pitch classes.
        chord_to_index: Mapping from chord to integer index.

    Returns:
        A list of ``len(chord_to_index)`` zeros with a 1 at the chord's index,
        or ``None`` when the chord is not in the vocabulary.
    """
    position = chord_to_index.get(chord)
    if position is None:
        return None
    vector = [0] * len(chord_to_index)
    vector[position] = 1
    return vector
# D minor triad: pitch classes 2, 5, 9 = D, F, A
# (C maj7 would be pitch classes 0, 4, 7, 11)
sample_chord = (2, 5, 9)
one_hot_sample = one_hot_encode(sample_chord, chord_to_index)
print("One-hot encoded chord:", one_hot_sample)

One-hot encoded chord: [0, 0, 0, 1, 0]


In [None]:
# decode output from trained model

