# MIDI-Trained Chord Recognition Model

## Data Preprocessing

### 1. Load and Extract from midi_folder

In [32]:
import os
import json
import pretty_midi
import pandas as pd
import numpy as np
from collections import defaultdict
import mido
import io

# define chord type templates: intervals relative to root
CHORD_TEMPLATES = {
    "Major":         {0, 4, 7},
    "Minor":         {0, 3, 7},
    "Dominant 7th":  {0, 4, 7, 10},
    "Diminished":    {0, 3, 6},
    "Augmented":     {0, 4, 8},
}

PITCH_CLASS_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F',
                     'F#', 'G', 'G#', 'A', 'A#', 'B']

# normalize chord, removing octave transpositions 
def normalize_chord(chord_tuple):
    normalized_chord = {note % 12 for note in chord_tuple}  # keep only unique notes modulo 12
    return tuple(sorted(normalized_chord))

# identify and name chords 
def identify_named_chord(chord_tuple):
    if not chord_tuple:
        return "Unknown"

    pitch_classes = sorted({p % 12 for p in chord_tuple})
    for root in pitch_classes:
        transposed = sorted({(p - root) % 12 for p in pitch_classes})
        for label, template in CHORD_TEMPLATES.items():
            if set(transposed) == template:
                root_name = PITCH_CLASS_NAMES[root]
                return f"{root_name} {label}"
    return "Unknown"

# fixed mapping for chord vocab: all 12 roots * templates
def create_fixed_chord_vocab():
    ALL_CHORDS = [
        f"{pitch} {chord_type}"
        for pitch in PITCH_CLASS_NAMES
        for chord_type in CHORD_TEMPLATES.keys()
    ]
    chord_to_index = {chord: idx for idx, chord in enumerate(ALL_CHORDS)}
    return chord_to_index

# extract chord sequence
def midi_to_chord_sequence(midi_file, merge_threshold=0.3):
    #midi_data = pretty_midi.PrettyMIDI(midi_file)
    
    raw = mido.MidiFile(midi_file, clip=True)
    merged = mido.MidiFile() 
    merged.ticks_per_beat = raw.ticks_per_beat
    merged_track = mido.merge_tracks(raw.tracks)
    merged.tracks.append(merged_track)
    
    # dump to memory buffer
    buf = io.BytesIO()
    merged.save(file=buf)
    buf.seek(0)

    midi_data = pretty_midi.PrettyMIDI(buf)

    events = []
    # for each note, add two events: on/off
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            events.append((note.start, 'on', note.pitch))
            events.append((note.end, 'off', note.pitch))
    

    events.sort(key=lambda x: x[0])

    active_notes = set()  # track notes that are in use
    chords = []  # final list
    previous_chord = None
    chord_start_time = None
    last_event_time = 0

    # if note is starting, add to active set
    # if note ending, remove it from active set
    for time, action, pitch in events:
        if action == 'on':
            active_notes.add(pitch)
        elif action == 'off':
            active_notes.discard(pitch)

        current_chord = normalize_chord(active_notes) if active_notes else None
        chord_label = identify_named_chord(current_chord) if current_chord else None

        # if chord changed
        if chord_label != previous_chord:
            if previous_chord is not None and chord_start_time is not None:
                if time - chord_start_time >= merge_threshold:
                    chords.append((round(chord_start_time, 3), round(time, 3), previous_chord))
            chord_start_time = time
            previous_chord = chord_label

        last_event_time = time

    # capture final chord if any
    if previous_chord is not None and chord_start_time is not None:
        chords.append((round(chord_start_time, 3), round(midi_data.get_end_time(), 3), previous_chord))

    return chords, midi_data

# timeframe-level feature extraction and align with chord labels
def extract_frame_level_data(chords, midi_data, chord_to_index, frame_hop=1):
    end_time = midi_data.get_end_time()
    frame_times = np.arange(0, end_time, frame_hop)

    chroma = midi_data.get_chroma(fs=int(1 / frame_hop))
    chroma = chroma.T  # transpose to shape (frames, 12)

    data = []

    for i, t in enumerate(frame_times):
        frame_feature = chroma[i] if i < len(chroma) else np.zeros(12)
        label = None
        for start, end, chord in chords:
            if start <= t < end:
                if chord in chord_to_index:
                    label = chord_to_index[chord]
                break
        if label is not None:
            data.append((t, frame_feature, label))
    return data


# process all midi files in the folder, save to CSV
def process_midi_folder(input_folder, chord_csv, frame_csv, frame_hop=1):
    chord_rows = []
    frame_rows = []
    chord_to_index = create_fixed_chord_vocab()

    for root, _, files in os.walk(input_folder):
        for fname in files:
            if not fname.lower().endswith(('.mid','.midi')): continue
            path = os.path.join(root, fname)
            rel = os.path.relpath(path, input_folder)
            try:
                chords, midi = midi_to_chord_sequence(path)
                # chord-level
                for st, ed, ch in chords:
                    chord_rows.append([rel, st, ed, ch])
                # frame-level
                frames = extract_frame_level_data(chords, midi, chord_to_index, frame_hop)
                for t, feat, lbl in frames:
                    frame_rows.append([rel, t, *feat, lbl])

            except Exception as e:
                print(f"[ERROR] {rel}: {e}")

    # save to csv
    chord_df = pd.DataFrame(chord_rows, columns=["filename","start_time","end_time","chord"])
    chord_df.to_csv(chord_csv, index=False)
    cols = [f"chroma_{i}" for i in range(12)]
    frame_df = pd.DataFrame(frame_rows, columns=["filename","time", *cols, "label"])
    frame_df.to_csv(frame_csv, index=False)

    print(f"✔ Saved chords to: {chord_csv}")
    print(f"✔ Saved frames to: {frame_csv}")
    return chord_to_index


# def process_midi_folder(midi_folder, chord_output_csv, frame_output_csv, frame_hop=1):
#     chord_data = []
#     frame_data = []

#     chord_to_index = create_fixed_chord_vocab()

#     for midi_file in os.listdir(midi_folder):
#         if midi_file.endswith(".mid") or midi_file.endswith(".midi"):
#             file_path = os.path.join(midi_folder, midi_file)
#             try:
#                 chords, midi_data = midi_to_chord_sequence(file_path)
#                 for timestamp_start, timestamp_end, chord in chords:
#                     chord_data.append([midi_file, timestamp_start, timestamp_end, chord])
#             except Exception as e:
#                 print(f"Error processing {midi_file}: {e}")

#     # second pass to align frame-wise data using finalized vocab
#     for midi_file in os.listdir(midi_folder):
#         if midi_file.endswith(".mid") or midi_file.endswith(".midi"):
#             file_path = os.path.join(midi_folder, midi_file)
#             try:
#                 chords, midi_data = midi_to_chord_sequence(file_path)
#                 frame_entries = extract_frame_level_data(chords, midi_data, chord_to_index, frame_hop)
#                 for t, feat, label in frame_entries:
#                     frame_data.append([midi_file, round(t, 3)] + list(feat) + [label])
#             except Exception as e:
#                 print(f"Error processing {midi_file} for frame-level: {e}")

#     # save chord segment CSV
#     chord_df = pd.DataFrame(chord_data, columns=["filename", "start_time", "end_time", "chord"])
#     chord_df.to_csv(chord_output_csv, index=False)

#     # save frame-level CSV
#     feat_cols = [f"chroma_{i}" for i in range(12)]
#     frame_df = pd.DataFrame(frame_data, columns=["filename", "time"] + feat_cols + ["label"])
#     frame_df.to_csv(frame_output_csv, index=False)

#     print(f"Chord segments saved to {chord_output_csv}")
#     print(f"Frame-level data saved to {frame_output_csv}")
    
#     return chord_to_index

### 2. Extract and Combine to csv file

In [33]:
# paths
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)

folder_to_process = 'midi_folder' # test use; change to 'lakh-midi-clean' for actual experiments

base = os.path.basename(folder_to_process.rstrip(os.sep))
chord_csv = os.path.join(output_dir, f"chord_dataset.csv")
frame_csv = os.path.join(output_dir, f"timeframe_dataset.csv")
vocab_json = os.path.join(output_dir, f"chord_vocab.json")

chord_to_index = process_midi_folder(folder_to_process, chord_csv, frame_csv)

with open(vocab_json, 'w') as f:
    json.dump(chord_to_index, f, indent=2)
    

✔ Saved chords to: output/chord_dataset.csv
✔ Saved frames to: output/timeframe_dataset.csv


### 3. One-hot Encoding

In [34]:
# one-hot encoding 
import pandas as pd
import numpy as np
import os
import json

output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

frame_csv_path = "output/timeframe_dataset.csv"
chord_vocab_path = "output/chord_vocab.json"
output_onehot_csv_path = os.path.join(output_dir, "timeframe_onehot.csv")


# load from JSON file
with open(chord_vocab_path, "r") as f:
    chord_to_index = json.load(f)

# reverse
chord_to_index = {str(k): v for k, v in chord_to_index.items()}


def one_hot_encode_labels(label_indices, num_classes):
    return np.eye(num_classes)[label_indices]

# load original timeframe-level dataset
df = pd.read_csv(frame_csv_path)

# get label col
label_indices = df["label"].astype(int).values

# one-hot encoding 
num_classes = len(chord_to_index)
one_hot = one_hot_encode_labels(label_indices, num_classes)

# create DataFrame 
one_hot_columns = [f"class_{i}" for i in range(num_classes)]
one_hot_df = pd.DataFrame(one_hot, columns=one_hot_columns)

# combine with filename + time 
minimal_df = df[["filename", "time"]].reset_index(drop=True)
result_df = pd.concat([minimal_df, one_hot_df], axis=1)

result_df.to_csv(output_onehot_csv_path, index=False)

print(f"One-hot encoded data saved to {output_onehot_csv_path}")

One-hot encoded data saved to output/timeframe_onehot.csv


## Baseline Model: SVM

In [35]:
import pandas as pd
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler

frame_csv_path = "output/timeframe_dataset.csv"
df = pd.read_csv(frame_csv_path)

# split to train and test dataset
feature_cols = [f"chroma_{i}" for i in range(12)]

X = df[feature_cols].values
y = df["label"].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# standardize features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# RBF kernel 
svm_model = SVC()
svm_model.fit(X_train_scaled, y_train)

y_pred = svm_model.predict(X_test_scaled)

# print confusion metrics with zero_division fix
print("Classification Report:")
print(classification_report(y_test, y_pred, zero_division=0))

print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))


Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.71      0.57       185
           1       0.50      0.38      0.43        24
           2       0.47      0.37      0.41        19
           3       0.00      0.00      0.00         1
           4       0.00      0.00      0.00         2
           5       0.55      0.68      0.61        31
           6       0.75      0.46      0.57        13
           7       0.00      0.00      0.00         4
           8       0.00      0.00      0.00         1
           9       0.00      0.00      0.00         2
          10       0.59      0.74      0.66       133
          11       0.53      0.47      0.49        45
          12       0.57      0.29      0.38        14
          13       0.00      0.00      0.00         1
          14       0.00      0.00      0.00         1
          15       0.69      0.74      0.71        65
          16       0.00      0.00      0.00         5
    

## Deep Learning Models

### Reorganize timing data

In [36]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

# auto-select device:
if torch.cuda.is_available():
    device = torch.device("cuda")
    backend = "CUDA"
elif getattr(torch.backends, "mps", None) is not None \
     and torch.backends.mps.is_available():
    device = torch.device("mps")
    backend = "MPS (Apple Silicon)"
else:
    device = torch.device("cpu")
    backend = "CPU"

print(f"Using device: {device}  |  backend: {backend}")


# convert one-hot
y_tr_idx = np.argmax(y_tr, axis=-1).astype(np.int64)   # shape (N, seq_len)
y_te_idx = np.argmax(y_te, axis=-1).astype(np.int64)

# convert to PyTorch tensors
batch_size = 16

train_ds = TensorDataset(
    torch.tensor(X_tr, dtype=torch.float32),
    torch.tensor(y_tr_idx, dtype=torch.long)
)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

val_ds = TensorDataset(
    torch.tensor(X_te, dtype=torch.float32),
    torch.tensor(y_te_idx, dtype=torch.long)
)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

seq_len, num_feat = X_tr.shape[1], X_tr.shape[2]
num_classes       = y_tr.shape[-1]        # 24


Using device: mps  |  backend: MPS (Apple Silicon)


### 1. CNN Model

### 2. RNN Model

In [37]:
class SimpleRNNModel(nn.Module):
    def __init__(self, n_feat, n_classes, hidden=64):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=n_feat,
            hidden_size=hidden,
            batch_first=True,
            nonlinearity="tanh"
        )
        self.fc  = nn.Linear(hidden, n_classes)

    def forward(self, x):                    # x: (B, T, F)
        out, _ = self.rnn(x)                 # (B, T, H)
        out    = self.fc(out)                # (B, T, C) logits
        return out

rnn_model = SimpleRNNModel(num_feat, num_classes).to(device)
print(rnn_model)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn_model.parameters(), lr=1e-3)

epochs, patience = 30, 3
best_loss, patience_ctr = np.inf, 0

for epoch in range(1, epochs+1):
    # training
    rnn_model.train()
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)          # yb (B, T)
        optimizer.zero_grad()
        logits = rnn_model(xb)                         # (B, T, C)
        loss = criterion(
            logits.reshape(-1, num_classes),
            yb.reshape(-1)
        )
        loss.backward()
        optimizer.step()

    # validation
    rnn_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            logits = rnn_model(xb)
            loss   = criterion(
                logits.reshape(-1, num_classes),
                yb.reshape(-1)
            )
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_ds)

    print(f"Epoch {epoch:02d}  val_loss={val_loss:.4f}")

    # early-stopping
    if val_loss < best_loss:
        best_loss = val_loss
        patience_ctr = 0
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print("Early stopping.")
            break

print("✓ RNN training done!")


SimpleRNNModel(
  (rnn): RNN(12, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=59, bias=True)
)
Epoch 01  val_loss=3.4412
Epoch 02  val_loss=2.2568
Epoch 03  val_loss=1.7548
Epoch 04  val_loss=1.5431
Epoch 05  val_loss=1.4525
Epoch 06  val_loss=1.4005
Epoch 07  val_loss=1.3597
Epoch 08  val_loss=1.3257
Epoch 09  val_loss=1.0984
Epoch 10  val_loss=1.0734
Epoch 11  val_loss=1.2563
Epoch 12  val_loss=1.0358
Epoch 13  val_loss=1.0258
Epoch 14  val_loss=1.2236
Epoch 15  val_loss=1.2144
Epoch 16  val_loss=0.9957
Epoch 17  val_loss=0.9838
Epoch 18  val_loss=0.9740
Epoch 19  val_loss=0.9668
Epoch 20  val_loss=0.9607
Epoch 21  val_loss=0.9531
Epoch 22  val_loss=0.9490
Epoch 23  val_loss=0.9449
Epoch 24  val_loss=0.9393
Epoch 25  val_loss=0.9326
Epoch 26  val_loss=0.9302
Epoch 27  val_loss=0.9255
Epoch 28  val_loss=0.9226
Epoch 29  val_loss=0.9204
Epoch 30  val_loss=0.9158
✓ RNN training done!


### 3. LSTM Model

In [38]:
# bidirectional LSTM model
class BiLSTMModel(nn.Module):
    def __init__(self, n_feat, n_classes, hidden=64):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_feat,
            hidden_size=hidden,
            batch_first=True,
            bidirectional=True
        )
        self.fc = nn.Linear(hidden*2, n_classes)

    def forward(self, x):
        out, _ = self.lstm(x)            # (B, T, 2H)
        out = self.fc(out)               # (B, T, C)
        return out

lstm_model = BiLSTMModel(num_feat, num_classes).to(device)
print(lstm_model)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-3)

epochs, patience = 30, 3
best_loss, patience_ctr = np.inf, 0

for epoch in range(1, epochs+1):
    # training
    lstm_model.train()
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = lstm_model(xb)
        loss   = criterion(
            logits.reshape(-1, num_classes),
            yb.reshape(-1)
        )
        loss.backward()
        optimizer.step()

    # validation
    lstm_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            logits = lstm_model(xb)
            loss   = criterion(
                logits.reshape(-1, num_classes),
                yb.reshape(-1)
            )
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_ds)

    print(f"Epoch {epoch:02d}  val_loss={val_loss:.4f}")

    if val_loss < best_loss:
        best_loss = val_loss
        patience_ctr = 0
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print("Early stopping.")
            break

lstm_model.load_state_dict(torch.load("best_lstm.pt"))
print("✓ LSTM training done!")


BiLSTMModel(
  (lstm): LSTM(12, 64, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=128, out_features=59, bias=True)
)
Epoch 01  val_loss=3.8850
Epoch 02  val_loss=3.3640
Epoch 03  val_loss=1.7729
Epoch 04  val_loss=1.3918
Epoch 05  val_loss=1.2711
Epoch 06  val_loss=1.1898
Epoch 07  val_loss=1.1248
Epoch 08  val_loss=1.0695
Epoch 09  val_loss=1.0271
Epoch 10  val_loss=0.9926
Epoch 11  val_loss=0.9640
Epoch 12  val_loss=0.9403
Epoch 13  val_loss=0.9219
Epoch 14  val_loss=0.9032
Epoch 15  val_loss=0.8888
Epoch 16  val_loss=0.8763
Epoch 17  val_loss=0.8662
Epoch 18  val_loss=0.8563
Epoch 19  val_loss=0.8479
Epoch 20  val_loss=0.8404
Epoch 21  val_loss=0.8346
Epoch 22  val_loss=0.8311
Epoch 23  val_loss=0.8292
Epoch 24  val_loss=0.8224
Epoch 25  val_loss=0.8172
Epoch 26  val_loss=0.8152
Epoch 27  val_loss=0.8105
Epoch 28  val_loss=0.8058
Epoch 29  val_loss=0.8031
Epoch 30  val_loss=0.8001


FileNotFoundError: [Errno 2] No such file or directory: 'best_lstm.pt'

### 4. CNN + LSTM Model

## Evaluation

In [39]:
def predict_np(model, X_np, batch_size=128):
    model.eval()
    out = []
    with torch.no_grad():
        for i in range(0, len(X_np), batch_size):
            xb = torch.tensor(X_np[i:i+batch_size], dtype=torch.float32, device=device)
            logits = model(xb)                # (B, T, C)
            prob   = torch.softmax(logits, -1)
            out.append(prob.cpu().numpy())
    return np.concatenate(out, axis=0)

rnn_probs  = predict_np(rnn_model,  X_te)
lstm_probs = predict_np(lstm_model, X_te)


In [40]:
import numpy as np
import mir_eval

def evaluate_chord_predictions(name, y_pred_prob, y_true_ohe, frame_rate=1.0):
    """
    name : Identifier for printing.
    y_pred_prob : ndarray, shape (N, seq_len, C)
        Predicted class probabilities or logits.
    y_true_ohe : ndarray, shape (N, seq_len, C)
        One-hot ground-truth labels.
    frame_rate : float, default 1.0
        Frames per second (interval length = 1/frame_rate sec).
    """
    # flatten to 1-D vectors of class indices
    y_pred_int = y_pred_prob.argmax(-1).flatten()
    y_true_int = y_true_ohe.argmax(-1).flatten()

    # map int to mir_eval chord strings 
    est_labels = ints_to_chords(y_pred_int)
    ref_labels = ints_to_chords(y_true_int)

    n = len(y_true_int)
    intervals = np.column_stack([np.arange(n)/frame_rate,
                                 (np.arange(n)+1)/frame_rate])

    result = mir_eval.chord.evaluate(
        intervals, ref_labels, intervals, est_labels
    )

    # normalize output (names, scores)
    if isinstance(result, dict):
        score_names = list(result.keys())
        scores      = list(result.values())

    elif isinstance(result, (list, tuple)) and len(result) == 2:
        a, b = result
        score_names, scores = (a, b) if isinstance(a[0], str) else (b, a)

    else:
        scores = list(result)
        score_names = ["root", "majmin", "thirds",
                       "triads", "sevenths", "tetrads", "mirex"][:len(scores)]

    # ---- print ----
    print(f"\n=== {name} ===")
    for nm, sc in zip(score_names, scores):
        print(f"{nm:>10}: {sc:.4f}")

    frame_acc = (y_pred_int == y_true_int).mean()
    print(f"{'frame_acc':>10}: {frame_acc:.4f}")


evaluate_chord_predictions("RNN Model",  rnn_probs,  y_te)
evaluate_chord_predictions("LSTM Model", lstm_probs, y_te)



=== RNN Model ===
    thirds: 0.8118
thirds_inv: 0.8118
    triads: 0.8118
triads_inv: 0.8118
   tetrads: 0.7787
tetrads_inv: 0.7787
      root: 0.8256
     mirex: 0.8118
    majmin: 0.8144
majmin_inv: 0.8144
  sevenths: 0.7813
sevenths_inv: 0.7813
  underseg: 0.9342
   overseg: 0.9197
       seg: 0.9197
 frame_acc: 0.7787

=== LSTM Model ===
    thirds: 0.8205
thirds_inv: 0.8205
    triads: 0.8205
triads_inv: 0.8205
   tetrads: 0.7918
tetrads_inv: 0.7918
      root: 0.8289
     mirex: 0.8205
    majmin: 0.8232
majmin_inv: 0.8232
  sevenths: 0.7944
sevenths_inv: 0.7944
  underseg: 0.8924
   overseg: 0.9353
       seg: 0.8924
 frame_acc: 0.7918
