In [18]:
import json
import os
from collections import defaultdict
from itertools import groupby

import numpy as np
import pandas as pd
import pretty_midi

# Chord qualities recognised by identify_named_chord, written as semitone offsets
# above the root. Dict order is significant: create_fixed_chord_vocab iterates it
# to assign class indices.
CHORD_TEMPLATES = dict(
    (quality, set(offsets))
    for quality, offsets in [
        ("Major", (0, 4, 7)),
        ("Minor", (0, 3, 7)),
        ("Dominant 7th", (0, 4, 7, 10)),
        ("Diminished", (0, 3, 6)),
        ("Augmented", (0, 4, 8)),
    ]
)

# Note names indexed by pitch class (MIDI pitch % 12), starting at C, sharps only.
PITCH_CLASS_NAMES = "C C# D D# E F F# G G# A A# B".split()

def normalize_chord(chord_tuple):
    """Collapse MIDI pitches to their distinct pitch classes.

    Octave information is discarded (``pitch % 12``) and duplicates removed;
    the result is an ascending tuple, e.g. ``(60, 64, 67, 72) -> (0, 4, 7)``.
    """
    pitch_classes = set()
    for pitch in chord_tuple:
        pitch_classes.add(pitch % 12)
    return tuple(sorted(pitch_classes))

def identify_named_chord(chord_tuple):
    """Name a chord such as ``"C Major"`` from a collection of pitches.

    Each distinct pitch class (in ascending order) is tried as the root; the
    first root whose interval set exactly equals one of ``CHORD_TEMPLATES``
    determines the name ``"<root> <quality>"``. Empty input, or a pitch-class
    set matching no template, yields ``"Unknown"``.
    """
    if not chord_tuple:
        return "Unknown"

    distinct_classes = sorted(set(pitch % 12 for pitch in chord_tuple))
    for root in distinct_classes:
        intervals = set((pc - root) % 12 for pc in distinct_classes)
        quality = next(
            (name for name, template in CHORD_TEMPLATES.items() if intervals == template),
            None,
        )
        if quality is not None:
            return f"{PITCH_CLASS_NAMES[root]} {quality}"
    return "Unknown"

def create_fixed_chord_vocab():
    """Return the fixed chord-name -> class-index mapping.

    Every root in ``PITCH_CLASS_NAMES`` is combined with every quality in
    ``CHORD_TEMPLATES`` (root-major order), so the index of a chord is
    ``root_position * len(CHORD_TEMPLATES) + quality_position``.
    """
    vocab = {}
    for pitch in PITCH_CLASS_NAMES:
        for chord_type in CHORD_TEMPLATES:
            vocab[f"{pitch} {chord_type}"] = len(vocab)
    return vocab

def midi_to_chord_sequence(midi_file, merge_threshold=0.3):
    """Extract a sequence of named chord segments from a MIDI file.

    All non-drum notes are merged into one stream of note-on/note-off events.
    After all events at a given timestamp have been applied, the sounding
    pitches are reduced to pitch classes (``normalize_chord``) and named
    (``identify_named_chord``). A new segment starts whenever that name
    changes; while no note sounds the name is ``None`` and nothing is recorded.

    Args:
        midi_file: Path (or anything ``pretty_midi.PrettyMIDI`` accepts).
        merge_threshold: Minimum segment length in seconds. Shorter segments
            are dropped (not merged into a neighbour), leaving a gap in the
            returned timeline.

    Returns:
        ``(chords, midi_data)``: ``chords`` is a list of ``(start, end, label)``
        tuples with times rounded to 3 decimals and ``label`` a string such as
        ``"C Major"`` or ``"Unknown"``; ``midi_data`` is the parsed
        ``PrettyMIDI`` object.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_file)

    events = []
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            # A zero-length note's 'off' would sort before its 'on' at the same
            # time and leave the pitch sounding forever, so skip such notes.
            if note.end <= note.start:
                continue
            events.append((note.start, 'on', note.pitch))
            events.append((note.end, 'off', note.pitch))

    # Sort by time; at equal times 'off' < 'on', so releases are applied first.
    events.sort()

    # Count sounding notes per pitch: the same pitch may be held by several
    # notes at once (e.g. two instruments in unison), and one of them ending
    # must not silence the others.
    active_counts = defaultdict(int)
    chords = []
    previous_chord = None
    chord_start_time = None

    # Apply every event sharing a timestamp before naming the chord, so a chord
    # change or a chord re-struck exactly at its release does not produce
    # spurious intermediate segments that reset the segment start.
    for time, same_time_events in groupby(events, key=lambda event: event[0]):
        for _, action, pitch in same_time_events:
            if action == 'on':
                active_counts[pitch] += 1
            else:
                active_counts[pitch] -= 1
                if active_counts[pitch] <= 0:
                    del active_counts[pitch]

        current_chord = normalize_chord(active_counts) if active_counts else None
        chord_label = identify_named_chord(current_chord) if current_chord else None

        if chord_label != previous_chord:
            if previous_chord is not None and chord_start_time is not None:
                if time - chord_start_time >= merge_threshold:
                    chords.append((round(chord_start_time, 3), round(time, 3), previous_chord))
            chord_start_time = time
            previous_chord = chord_label

    # Defensive: close a segment that is still open after the last event.
    if previous_chord is not None and chord_start_time is not None:
        chords.append((round(chord_start_time, 3), round(midi_data.get_end_time(), 3), previous_chord))

    return chords, midi_data

def extract_frame_level_data(chords, midi_data, chord_to_index, frame_hop=1):
    """Sample chroma on a fixed time grid and label each frame with its chord.

    Args:
        chords: ``(start, end, label)`` segments, e.g. from ``midi_to_chord_sequence``.
        midi_data: Object providing ``get_end_time()`` and ``get_chroma(fs=...)``
            (a ``pretty_midi.PrettyMIDI`` instance in this notebook).
        chord_to_index: Mapping from chord label to class index.
        frame_hop: Frame spacing in seconds; must be positive.

    Returns:
        List of ``(time, chroma_vector, label_index)`` for frame times
        ``0, frame_hop, 2*frame_hop, ... < end_time``. The first segment with
        ``start <= time < end`` decides the label; frames in no segment, or whose
        segment label is not in ``chord_to_index`` (e.g. ``"Unknown"``), are
        omitted. Frames beyond the chroma matrix get a zero vector.

    Raises:
        ValueError: if ``frame_hop`` is not positive.
    """
    if frame_hop <= 0:
        raise ValueError(f"frame_hop must be positive, got {frame_hop}")

    end_time = midi_data.get_end_time()
    frame_times = np.arange(0, end_time, frame_hop)

    # Use a float sampling rate so chroma column i lines up with frame_times[i].
    # Truncating to int would give fs=0 (no columns, all-zero features) for
    # frame_hop > 1 and a mismatched rate for hops such as 0.3.
    chroma = midi_data.get_chroma(fs=1.0 / frame_hop)
    chroma = chroma.T  # transpose to shape (frames, 12)

    data = []
    for i, t in enumerate(frame_times):
        frame_feature = chroma[i] if i < len(chroma) else np.zeros(12)
        label = None
        for start, end, chord in chords:
            if start <= t < end:
                # Out-of-vocabulary labels (e.g. "Unknown") leave the frame unlabelled.
                label = chord_to_index.get(chord)
                break
        if label is not None:
            data.append((t, frame_feature, label))
    return data

def process_midi_folder(midi_folder, chord_output_csv, frame_output_csv, frame_hop=1):
    """Build the chord-segment and frame-level datasets for a folder of MIDI files.

    Each file whose name ends in ``.mid`` or ``.midi`` (case-insensitive) is
    parsed once with ``midi_to_chord_sequence``. Its segments go to
    ``chord_output_csv`` and its per-frame chroma/label rows (from
    ``extract_frame_level_data``) go to ``frame_output_csv``. Files are visited
    in sorted name order so the CSV row order does not depend on the file
    system's listing order.

    Errors are reported with ``print`` and the offending file is skipped; if
    only the frame-level step fails, the file's chord segments are still kept.

    Args:
        midi_folder: Directory to scan (top level only).
        chord_output_csv: Output path; columns ``filename, start_time, end_time, chord``.
        frame_output_csv: Output path; columns ``filename, time, chroma_0..chroma_11, label``.
        frame_hop: Frame spacing in seconds, forwarded to ``extract_frame_level_data``.

    Returns:
        The chord-name -> class-index mapping used for the ``label`` column.
    """
    # The vocabulary is fixed (every root x every chord template), so it is known
    # before any file is read and one pass over the folder is enough.
    chord_to_index = create_fixed_chord_vocab()
    chord_data = []
    frame_data = []

    for midi_file in sorted(os.listdir(midi_folder)):
        if not midi_file.lower().endswith((".mid", ".midi")):
            continue
        file_path = os.path.join(midi_folder, midi_file)

        try:
            chords, midi_data = midi_to_chord_sequence(file_path)
        except Exception as exc:
            print(f"Error processing {midi_file}: {exc}")
            continue
        for start, end, chord in chords:
            chord_data.append([midi_file, start, end, chord])

        try:
            frame_entries = extract_frame_level_data(chords, midi_data, chord_to_index, frame_hop)
        except Exception as exc:
            print(f"Error processing {midi_file} for frame-level: {exc}")
            continue
        for t, feat, label in frame_entries:
            frame_data.append([midi_file, round(t, 3)] + list(feat) + [label])

    # save chord segment CSV
    chord_df = pd.DataFrame(chord_data, columns=["filename", "start_time", "end_time", "chord"])
    chord_df.to_csv(chord_output_csv, index=False)

    # save frame-level CSV
    feat_cols = [f"chroma_{i}" for i in range(12)]
    frame_df = pd.DataFrame(frame_data, columns=["filename", "time"] + feat_cols + ["label"])
    frame_df.to_csv(frame_output_csv, index=False)

    print(f"Chord segments saved to {chord_output_csv}")
    print(f"Frame-level data saved to {frame_output_csv}")

    return chord_to_index

# Paths, relative to the notebook's working directory.
midi_folder = "midi_folder"
chord_output_csv = "chord_dataset.csv"
frame_output_csv = "timeframe_dataset.csv"
chord_vocab_json = "chord_vocab.json"

chord_to_index = process_midi_folder(midi_folder, chord_output_csv, frame_output_csv)

# Persist the chord-name -> class-index vocabulary for the one-hot encoding cell.
# `json` must be imported at the top of this cell (not only in the next one) so
# the notebook survives Restart & Run All.
with open(chord_vocab_json, "w", encoding="utf-8") as f:
    json.dump(chord_to_index, f)



Chord segments saved to chord_dataset.csv
Frame-level data saved to timeframe_dataset.csv


In [19]:
# one-hot encoding 
import pandas as pd
import numpy as np
import os
import json

frame_csv_path = "timeframe_dataset.csv"
chord_vocab_path = "chord_vocab.json"
output_onehot_csv_path = "timeframe_onehot.csv"

# Load the chord-name -> class-index vocabulary saved by the extraction cell.
with open(chord_vocab_path, "r", encoding="utf-8") as f:
    chord_to_index = json.load(f)

# JSON object keys always load as str, so no key conversion is needed. What the
# one-hot step does rely on is that the indices are exactly 0..N-1, because it
# uses len(chord_to_index) as the number of classes.
assert sorted(chord_to_index.values()) == list(range(len(chord_to_index))), (
    f"{chord_vocab_path} must map chords to the contiguous indices 0..N-1"
)


def one_hot_encode_labels(label_indices, num_classes, dtype=float):
    """Return a one-hot matrix for integer class labels.

    Args:
        label_indices: Integer class indices (any array-like); each must lie
            in ``[0, num_classes)``.
        num_classes: Width of the one-hot encoding.
        dtype: Element type of the result (default float, matching ``np.eye``).

    Returns:
        Array of shape ``np.shape(label_indices) + (num_classes,)`` with a
        single 1 along the last axis for each label.

    Raises:
        TypeError: if the labels are not integers.
        ValueError: if any label is outside ``[0, num_classes)``. Plain
            ``np.eye(n)[labels]`` would silently wrap negative labels to the
            last classes instead of failing.
    """
    labels = np.asarray(label_indices)
    if labels.size == 0:
        # np.asarray([]) is float; handle empty input before the dtype check.
        return np.zeros(labels.shape + (num_classes,), dtype=dtype)
    if not np.issubdtype(labels.dtype, np.integer):
        raise TypeError(f"labels must be integers, got dtype {labels.dtype}")
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(
            f"labels must lie in [0, {num_classes}), got range "
            f"[{labels.min()}, {labels.max()}]"
        )
    return np.eye(num_classes, dtype=dtype)[labels]

# Frame-level dataset produced by the extraction cell: filename, time, 12 chroma
# columns and an integer class label per row.
df = pd.read_csv(frame_csv_path)

# Expand each frame's class index into a one-hot row whose width is the size
# of the loaded vocabulary.
label_indices = df["label"].astype(int).to_numpy()
num_classes = len(chord_to_index)
one_hot = one_hot_encode_labels(label_indices, num_classes)

one_hot_columns = ["class_%d" % i for i in range(num_classes)]
one_hot_df = pd.DataFrame(one_hot, columns=one_hot_columns)

# Keep only the identifying columns and place the one-hot block beside them;
# both frames carry a 0..n-1 RangeIndex, so rows line up one-to-one.
minimal_df = df.loc[:, ["filename", "time"]].reset_index(drop=True)
result_df = minimal_df.join(one_hot_df)

result_df.to_csv(output_onehot_csv_path, index=False)

print(f"One-hot encoded data saved to {output_onehot_csv_path}")


One-hot encoded data saved to timeframe_onehot.csv
