# MIDI-Trained Chord Recognition Model

## Data Preprocessing 

### 1. Load and Extract from midi_folder

In [50]:
import os
import json
import pretty_midi
import pandas as pd
import numpy as np
from collections import defaultdict
import mido
import io

# define chord type templates: intervals relative to root
CHORD_TEMPLATES = {
    "Major":         {0, 4, 7},
    "Minor":         {0, 3, 7},
    "Dominant 7th":  {0, 4, 7, 10},
    "Diminished":    {0, 3, 6},
    "Augmented":     {0, 4, 8},
}

PITCH_CLASS_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F',
                     'F#', 'G', 'G#', 'A', 'A#', 'B']

# normalize chord, removing octave transpositions 
def normalize_chord(chord_tuple):
    normalized_chord = {note % 12 for note in chord_tuple}  # keep only unique notes modulo 12
    return tuple(sorted(normalized_chord))

# identify and name chords 
def identify_named_chord(chord_tuple):
    if not chord_tuple:
        return "Unknown"

    pitch_classes = sorted({p % 12 for p in chord_tuple})
    for root in pitch_classes:
        transposed = sorted({(p - root) % 12 for p in pitch_classes})
        for label, template in CHORD_TEMPLATES.items():
            if set(transposed) == template:
                root_name = PITCH_CLASS_NAMES[root]
                return f"{root_name} {label}"
    return "Unknown"

# fixed mapping for chord vocab: all 12 roots * templates
def create_fixed_chord_vocab():
    ALL_CHORDS = [
        f"{pitch} {chord_type}"
        for pitch in PITCH_CLASS_NAMES
        for chord_type in CHORD_TEMPLATES.keys()
    ]
    chord_to_index = {chord: idx for idx, chord in enumerate(ALL_CHORDS)}
    return chord_to_index

# extract chord sequence
def midi_to_chord_sequence(midi_file, merge_threshold=0.3):
    #midi_data = pretty_midi.PrettyMIDI(midi_file)
    
    raw = mido.MidiFile(midi_file, clip=True)
    merged = mido.MidiFile() 
    merged.ticks_per_beat = raw.ticks_per_beat
    merged_track = mido.merge_tracks(raw.tracks)
    merged.tracks.append(merged_track)
    
    # dump to memory buffer
    buf = io.BytesIO()
    merged.save(file=buf)
    buf.seek(0)

    midi_data = pretty_midi.PrettyMIDI(buf)

    events = []
    # for each note, add two events: on/off
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            events.append((note.start, 'on', note.pitch))
            events.append((note.end, 'off', note.pitch))
    
    events.sort(key=lambda x: x[0])

    active_notes = set()  # track notes that are in use
    chords = []  # final list
    previous_chord = None
    chord_start_time = None
    last_event_time = 0

    # if note is starting, add to active set
    # if note ending, remove it from active set
    for time, action, pitch in events:
        if action == 'on':
            active_notes.add(pitch)
        elif action == 'off':
            active_notes.discard(pitch)

        current_chord = normalize_chord(active_notes) if active_notes else None
        chord_label = identify_named_chord(current_chord) if current_chord else None

        # if chord changed
        if chord_label != previous_chord:
            if previous_chord is not None and chord_start_time is not None:
                if time - chord_start_time >= merge_threshold:
                    chords.append((round(chord_start_time, 3), round(time, 3), previous_chord))
            chord_start_time = time
            previous_chord = chord_label

        last_event_time = time

    # capture final chord if any
    if previous_chord is not None and chord_start_time is not None:
        chords.append((round(chord_start_time, 3), round(midi_data.get_end_time(), 3), previous_chord))

    return chords, midi_data

# timeframe-level feature extraction and align with chord labels
def extract_frame_level_data(chords, midi_data, chord_to_index, frame_hop=1):
    end_time = midi_data.get_end_time()
    frame_times = np.arange(0, end_time, frame_hop)

    chroma = midi_data.get_chroma(fs=int(1 / frame_hop))
    chroma = chroma.T  # transpose to shape (frames, 12)

    data = []

    for i, t in enumerate(frame_times):
        frame_feature = chroma[i] if i < len(chroma) else np.zeros(12)
        label = None
        for start, end, chord in chords:
            if start <= t < end:
                if chord in chord_to_index:
                    label = chord_to_index[chord]
                break
        if label is not None:
            data.append((t, frame_feature, label))
    return data


# process all midi files in the folder, save to CSV
def process_midi_folder(input_folder, chord_csv, frame_csv, frame_hop=1):
    chord_rows = []
    frame_rows = []
    chord_to_index = create_fixed_chord_vocab()

    for root, _, files in os.walk(input_folder):
        for fname in files:
            if not fname.lower().endswith(('.mid','.midi')): continue
            path = os.path.join(root, fname)
            rel = os.path.relpath(path, input_folder)
            try:
                chords, midi = midi_to_chord_sequence(path)
                # chord-level
                for st, ed, ch in chords:
                    chord_rows.append([rel, st, ed, ch])
                # frame-level
                frames = extract_frame_level_data(chords, midi, chord_to_index, frame_hop)
                for t, feat, lbl in frames:
                    frame_rows.append([rel, t, *feat, lbl])

            except Exception as e:
                print(f"[ERROR] {rel}: {e}")

    # save to csv
    chord_df = pd.DataFrame(chord_rows, columns=["filename","start_time","end_time","chord"])
    chord_df.to_csv(chord_csv, index=False)
    cols = [f"chroma_{i}" for i in range(12)]
    frame_df = pd.DataFrame(frame_rows, columns=["filename","time", *cols, "label"])
    frame_df.to_csv(frame_csv, index=False)

    print(f"✔ Saved chords to: {chord_csv}")
    print(f"✔ Saved frames to: {frame_csv}")
    return chord_to_index

### 2. Extract and Combine to csv file

In [51]:
# paths
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)

folder_to_process = 'midi_folder' # test use; change to 'lakh-midi-clean' for actual experiments

base = os.path.basename(folder_to_process.rstrip(os.sep))
chord_csv = os.path.join(output_dir, f"chord_dataset.csv")
frame_csv = os.path.join(output_dir, f"timeframe_dataset.csv")
vocab_json = os.path.join(output_dir, f"chord_vocab.json")

chord_to_index = process_midi_folder(folder_to_process, chord_csv, frame_csv)

with open(vocab_json, 'w') as f:
    json.dump(chord_to_index, f, indent=2)

✔ Saved chords to: output/chord_dataset.csv
✔ Saved frames to: output/timeframe_dataset.csv


### 3. One-hot Encoding

In [52]:
# one-hot encoding 
import pandas as pd
import numpy as np
import os
import json

output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

frame_csv_path = "output/timeframe_dataset.csv"
chord_vocab_path = "output/chord_vocab.json"
output_onehot_csv_path = os.path.join(output_dir, "timeframe_onehot.csv")


# load from JSON file
with open(chord_vocab_path, "r") as f:
    chord_to_index = json.load(f)

# reverse
chord_to_index = {str(k): v for k, v in chord_to_index.items()}


def one_hot_encode_labels(label_indices, num_classes):
    return np.eye(num_classes)[label_indices]

# load original timeframe-level dataset
df = pd.read_csv(frame_csv_path)

# get label col
label_indices = df["label"].astype(int).values

# one-hot encoding 
num_classes = len(chord_to_index)
one_hot = one_hot_encode_labels(label_indices, num_classes)

# create DataFrame 
one_hot_columns = [f"class_{i}" for i in range(num_classes)]
one_hot_df = pd.DataFrame(one_hot, columns=one_hot_columns)

# combine with filename + time 
minimal_df = df[["filename", "time"]].reset_index(drop=True)
result_df = pd.concat([minimal_df, one_hot_df], axis=1)

result_df.to_csv(output_onehot_csv_path, index=False)

print(f"One-hot encoded data saved to {output_onehot_csv_path}")

One-hot encoded data saved to output/timeframe_onehot.csv


## Baseline Model: SVM

In [53]:
import pandas as pd
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler

frame_csv_path = "output/timeframe_dataset.csv"
df = pd.read_csv(frame_csv_path)

# split to train and test dataset
feature_cols = [f"chroma_{i}" for i in range(12)]
X = df[feature_cols].values
y = df["label"].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# standardize features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# RBF kernel 
svm_model = SVC()
svm_model.fit(X_train_scaled, y_train)

y_pred = svm_model.predict(X_test_scaled)

# print confusion metrics with zero_division fix
print("Classification Report:")
print(classification_report(y_test, y_pred, zero_division=0))

print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))


Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.71      0.57       185
           1       0.50      0.38      0.43        24
           2       0.47      0.37      0.41        19
           3       0.00      0.00      0.00         1
           4       0.00      0.00      0.00         2
           5       0.55      0.68      0.61        31
           6       0.75      0.46      0.57        13
           7       0.00      0.00      0.00         4
           8       0.00      0.00      0.00         1
           9       0.00      0.00      0.00         2
          10       0.59      0.74      0.66       133
          11       0.53      0.47      0.49        45
          12       0.57      0.29      0.38        14
          13       0.00      0.00      0.00         1
          14       0.00      0.00      0.00         1
          15       0.69      0.74      0.71        65
          16       0.00      0.00      0.00         5
    

## Deep Learning Models

### Reorganize timing data

In [54]:
import torch

# auto-select device:
if torch.cuda.is_available():
    device = torch.device("cuda")
    backend = "CUDA"
elif getattr(torch.backends, "mps", None) is not None \
     and torch.backends.mps.is_available():
    device = torch.device("mps")
    backend = "MPS (Apple Silicon)"
else:
    device = torch.device("cpu")
    backend = "CPU"

print(f"Using device: {device}  |  backend: {backend}")

Using device: mps  |  backend: MPS (Apple Silicon)


In [55]:
import numpy as np
import pandas as pd

def build_sequence_tensor(
    frame_df: pd.DataFrame,
    seq_len: int,
    num_feat: int = 12,
    num_classes: int = 24,
    to_torch: bool = False,
    device: str | None = None,
):

    groups = frame_df.groupby("filename", sort=False)
    n_song = len(groups)

    X_seq = np.zeros((n_song, seq_len, num_feat), dtype=np.float32)
    y_seq = np.zeros((n_song, seq_len),        dtype=np.int64)

    for idx, (_, g) in enumerate(groups):
        g = g.sort_values("time")

        x = g[[f"chroma_{i}" for i in range(num_feat)]].to_numpy(dtype=np.float32)
        y = g["label"].to_numpy(dtype=np.int64)

        pad = max(seq_len - len(x), 0)
        X_seq[idx] = np.pad(x, ((0, pad), (0, 0)), mode="constant")[:seq_len]
        y_seq[idx] = np.pad(y, (0, pad), mode="constant")[:seq_len]

    # one-hot via NumPy (no TF)
    eye = np.eye(num_classes, dtype=np.float32)
    y_seq_ohe = eye[y_seq]                    # (N, seq_len, C)

    if to_torch:
        import torch
        X_seq      = torch.tensor(X_seq,      device=device)
        y_seq_ohe  = torch.tensor(y_seq_ohe,  device=device)

    return X_seq, y_seq_ohe

In [56]:
import json, pandas as pd, numpy as np, torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# load data
df = pd.read_csv("output/timeframe_dataset.csv")
num_classes = len(json.load(open("output/chord_vocab.json")))

X_seq, y_seq_ohe = build_sequence_tensor(
    frame_df     = df,
    seq_len      = 64,
    num_feat     = 12,
    num_classes  = num_classes,
    to_torch     = True,
    device       = device
)

# train / val split
X_tr, X_te, y_tr, y_te = train_test_split(
    X_seq, y_seq_ohe,
    test_size   = 0.20,
    random_state= 42,
    shuffle     = True
)

y_tr_idx = y_tr.argmax(dim=-1)
y_te_idx = y_te.argmax(dim=-1)


batch_size = 16
train_dl = DataLoader(
    TensorDataset(X_tr.float(), y_tr_idx.long()),
    batch_size = batch_size,
    shuffle    = True
)
val_dl = DataLoader(
    TensorDataset(X_te.float(), y_te_idx.long()),
    batch_size = batch_size,
    shuffle    = False
)

seq_len  = X_tr.size(1)
num_feat = X_tr.size(2) 

## 1. CNN Model

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, TensorDataset

# model
class ChromaCNN_Frame(nn.Module):
    def __init__(self, n_feat, n_class, hidden=128, p_drop=0.3):
        super().__init__()

        def same_conv(cin, cout, k, dil=1):
            pad = dil * (k - 1) // 2         # IMPORTANT
            return nn.Sequential(
                nn.Conv1d(cin, cout, k, padding=pad, dilation=dil),
                nn.BatchNorm1d(cout), nn.ReLU()
            )

        self.backbone = nn.Sequential(
            same_conv(n_feat, 64, 3, 1),
            same_conv(64, 128, 5, 2),                # pad=4
            same_conv(128, hidden, 7, 4),            # pad=12
            nn.Dropout(p_drop)
        )
        self.cls = nn.Conv1d(hidden, n_class, 1)

    def forward(self, x):             # (B, T, F)
        x = x.transpose(1, 2)         # (B, F, T)
        h = self.backbone(x)          # (B, hidden, T)
        y = self.cls(h).transpose(1, 2)
        return y                      # (B, T, C)


# init
cnn_model = ChromaCNN_Frame(num_feat, num_classes).to(device)
criterion  = nn.CrossEntropyLoss()
optimizer  = torch.optim.Adam(cnn_model.parameters(), lr=1e-3)

epochs, patience = 30, 3
best_loss, patience_ctr = np.inf, 0
os.makedirs("checkpoints", exist_ok=True)
best_ckpt = "checkpoints/best_cnn.pt"
last_ckpt = "checkpoints/last_cnn.pt"

# class SequenceChromaCNN(nn.Module):
#     def __init__(self, input_dim, num_classes):
#         super().__init__()
#         self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=3, padding=1)
#         self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
#         self.conv3 = nn.Conv1d(64, num_classes, kernel_size=1)  # final output layer

#     def forward(self, x):
#         x = x.permute(0, 2, 1)        # (B, F=12, T)
#         x = F.relu(self.conv1(x))     # (B, 32, T)
#         x = F.relu(self.conv2(x))     # (B, 64, T)
#         x = self.conv3(x)             # (B, C, T)
#         return x.permute(0, 2, 1)     # (B, T, C)

In [58]:
# training loop
num_epochs = 20
for epoch in range(1, epochs + 1):
    cnn_model.train()
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = cnn_model(xb)
        loss   = criterion(logits.reshape(-1, num_classes), yb.reshape(-1))
        loss.backward()
        optimizer.step()

    # validation
    cnn_model.eval(); val_loss = 0.
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            loss = criterion(cnn_model(xb).reshape(-1, num_classes),
                             yb.reshape(-1))
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_dl.dataset)
    print(f"[CNN] Epoch {epoch:02d}  val_loss={val_loss:.4f}")
    torch.save(cnn_model.state_dict(), last_ckpt)

    if val_loss < best_loss:
        best_loss, patience_ctr = val_loss, 0
        torch.save(cnn_model.state_dict(), best_ckpt)
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print("Early stopping.\n"); break

cnn_model.load_state_dict(torch.load(best_ckpt, map_location=device))
print("✓ CNN training done!")

# model = SequenceChromaCNN(input_dim=num_feat, num_classes=num_classes).to(device)

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# num_epochs = 20
# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0

#     for X_batch, y_batch in train_dl: 
#         X_batch, y_batch = X_batch.to(device), y_batch.to(device)

#         outputs = model(X_batch)
#         loss = criterion(
#             outputs.reshape(-1, num_classes),   # (B*T, C)
#             y_batch.reshape(-1)                # (B*T,)
#         )

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")


[CNN] Epoch 01  val_loss=2.6379
[CNN] Epoch 02  val_loss=1.7465
[CNN] Epoch 03  val_loss=1.5066
[CNN] Epoch 04  val_loss=1.3481
[CNN] Epoch 05  val_loss=1.2838
[CNN] Epoch 06  val_loss=1.2287
[CNN] Epoch 07  val_loss=1.1864
[CNN] Epoch 08  val_loss=1.1578
[CNN] Epoch 09  val_loss=1.1157
[CNN] Epoch 10  val_loss=1.0798
[CNN] Epoch 11  val_loss=1.0741
[CNN] Epoch 12  val_loss=1.0581
[CNN] Epoch 13  val_loss=1.0414
[CNN] Epoch 14  val_loss=1.0474
[CNN] Epoch 15  val_loss=1.0560
[CNN] Epoch 16  val_loss=1.0278
[CNN] Epoch 17  val_loss=1.0277
[CNN] Epoch 18  val_loss=1.0216
[CNN] Epoch 19  val_loss=1.0207
[CNN] Epoch 20  val_loss=1.0388
[CNN] Epoch 21  val_loss=1.0308
[CNN] Epoch 22  val_loss=1.0576
Early stopping.

✓ CNN training done!


In [59]:
# evaluation
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np, torch

pad_id = num_classes
mask_pad = True               # =False not ignore padding

cnn_model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for xb, yb in val_dl:
        xb, yb = xb.to(device), yb.to(device)     # xb:(B,T,F)  yb:(B,T)
        logits = cnn_model(xb)                    # (B,T,C)
        preds  = logits.argmax(-1)                # (B,T)

        if mask_pad:
            valid = (yb != pad_id)
            all_preds.extend(preds[valid].cpu().numpy())
            all_labels.extend(yb[valid].cpu().numpy())
        else:
            all_preds.extend(preds.cpu().numpy().ravel())
            all_labels.extend(yb.cpu().numpy().ravel())


acc = accuracy_score(all_labels, all_preds)
print(f"Frame accuracy : {acc:.4f}\n")

print("Classification Report:")
print(classification_report(all_labels, all_preds, zero_division=0))

print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))




# model.eval()
# all_preds, all_labels = [], []

# with torch.no_grad():
#     for X_batch, y_batch in val_dl:
#         X_batch, y_batch = X_batch.to(device), y_batch.to(device)
#         outputs = model(X_batch)  # (B, T, C)
#         preds = outputs.argmax(dim=-1)  # (B, T)
        
#         all_preds.extend(preds.cpu().numpy().flatten())
#         all_labels.extend(y_batch.cpu().numpy().flatten())

# from sklearn.metrics import classification_report, confusion_matrix
# print(classification_report(all_labels, all_preds, zero_division=0))
# print(confusion_matrix(all_labels, all_preds))


Frame accuracy : 0.7369

Classification Report:
              precision    recall  f1-score   support

           0       0.93      0.97      0.95      1516
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00        16
           5       0.00      0.00      0.00         2
           6       0.00      0.00      0.00        32
           7       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         3
          10       0.57      0.82      0.67       134
          11       0.41      0.21      0.28        58
          12       0.00      0.00      0.00         3
          15       0.31      0.55      0.39        20
          16       0.00      0.00      0.00         3
          17       0.00      0.00      0.00         1
          20       0.63      0.60      0.62       177
          21       0.09      0.17      0.12        18
          22       0.00      0.00      0.00         6
          25       0.51      0.73

## 2. RNN Model

In [60]:
import os, torch, numpy as np
from torch import nn

# define RNN model
class SimpleRNNModel(nn.Module):
    def __init__(self, n_feat, n_classes, hidden=64):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=n_feat,
            hidden_size=hidden,
            batch_first=True,
            nonlinearity="tanh"
        )
        self.fc = nn.Linear(hidden, n_classes)

    def forward(self, x):                 # x: (B, T, F)
        out, _ = self.rnn(x)              # (B, T, H)
        return self.fc(out)               # (B, T, C)

# init
rnn_model = SimpleRNNModel(num_feat, num_classes).to(device)
criterion  = nn.CrossEntropyLoss()
optimizer  = torch.optim.Adam(rnn_model.parameters(), lr=1e-3)

epochs, patience = 30, 3
best_loss        = np.inf
patience_ctr     = 0

# checkpoint paths
os.makedirs("checkpoints", exist_ok=True)
best_ckpt = "checkpoints/best_rnn.pt"
last_ckpt = "checkpoints/last_rnn.pt"

# training loop
for epoch in range(1, epochs + 1):
    rnn_model.train()
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = rnn_model(xb)
        loss = criterion(logits.reshape(-1, num_classes), yb.reshape(-1))
        loss.backward()
        optimizer.step()

    # validation
    rnn_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            logits = rnn_model(xb)
            loss = criterion(logits.reshape(-1, num_classes), yb.reshape(-1))
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_dl.dataset)

    print(f"Epoch {epoch:02d}  val_loss={val_loss:.4f}")
    torch.save(rnn_model.state_dict(), last_ckpt)            # always save last

    if val_loss < best_loss:                                 # save best
        best_loss = val_loss
        patience_ctr = 0
        torch.save(rnn_model.state_dict(), best_ckpt)
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print("Early stopping.\n")
            break

# restore best
ckpt_to_load = best_ckpt if os.path.exists(best_ckpt) else last_ckpt
rnn_model.load_state_dict(torch.load(ckpt_to_load, map_location=device))
print(f"✓ RNN training done!")

Epoch 01  val_loss=3.4407
Epoch 02  val_loss=2.2562
Epoch 03  val_loss=1.9988
Epoch 04  val_loss=1.7366
Epoch 05  val_loss=1.3954
Epoch 06  val_loss=1.3137
Epoch 07  val_loss=1.2524
Epoch 08  val_loss=1.2119
Epoch 09  val_loss=1.1785
Epoch 10  val_loss=1.1553
Epoch 11  val_loss=1.1364
Epoch 12  val_loss=1.1212
Epoch 13  val_loss=1.1049
Epoch 14  val_loss=1.0948
Epoch 15  val_loss=1.0831
Epoch 16  val_loss=1.0738
Epoch 17  val_loss=1.0639
Epoch 18  val_loss=1.0574
Epoch 19  val_loss=1.0519
Epoch 20  val_loss=1.0467
Epoch 21  val_loss=1.0421
Epoch 22  val_loss=1.0390
Epoch 23  val_loss=1.0345
Epoch 24  val_loss=1.0305
Epoch 25  val_loss=1.0262
Epoch 26  val_loss=1.0240
Epoch 27  val_loss=1.0183
Epoch 28  val_loss=1.0173
Epoch 29  val_loss=1.0140
Epoch 30  val_loss=1.0119
✓ RNN training done!


## 3. BiLSTM Model

In [61]:
import os, torch, numpy as np
from torch import nn

# bidirectional LSTM
class BiLSTMModel(nn.Module):
    def __init__(self, n_feat, n_classes, hidden=64):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_feat,
            hidden_size=hidden,
            batch_first=True,
            bidirectional=True
        )
        self.fc = nn.Linear(hidden * 2, n_classes)

    def forward(self, x):                  # x: (B, T, F)
        out, _ = self.lstm(x)              # (B, T, 2H)
        out = self.fc(out)                 # (B, T, C)
        return out

# init
lstm_model = BiLSTMModel(num_feat, num_classes).to(device)
criterion   = nn.CrossEntropyLoss()
optimizer   = torch.optim.Adam(lstm_model.parameters(), lr=1e-3)

epochs, patience = 30, 3
best_loss        = np.inf
patience_ctr     = 0

# checkpoint directory
os.makedirs("checkpoints", exist_ok=True)
best_ckpt  = "checkpoints/best_lstm.pt"
last_ckpt  = "checkpoints/last_lstm.pt"

# training loop
for epoch in range(1, epochs + 1):
    lstm_model.train()
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = lstm_model(xb)
        loss   = criterion(logits.reshape(-1, num_classes),
                           yb.reshape(-1))
        loss.backward()
        optimizer.step()

    # validation
    lstm_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            logits = lstm_model(xb)
            loss   = criterion(logits.reshape(-1, num_classes),
                               yb.reshape(-1))
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_dl.dataset)

    print(f"Epoch {epoch:02d}  val_loss={val_loss:.4f}")

    torch.save(lstm_model.state_dict(), last_ckpt)

    # save best & early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        patience_ctr = 0
        torch.save(lstm_model.state_dict(), best_ckpt)
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print("Early stopping.\n")
            break

# load best
ckpt_to_load = best_ckpt if os.path.exists(best_ckpt) else last_ckpt
lstm_model.load_state_dict(torch.load(ckpt_to_load, map_location=device))
print(f"✓ LSTM training done!")


Epoch 01  val_loss=3.8185
Epoch 02  val_loss=3.3541
Epoch 03  val_loss=1.8136
Epoch 04  val_loss=1.4809
Epoch 05  val_loss=1.3593
Epoch 06  val_loss=1.2667
Epoch 07  val_loss=1.1915
Epoch 08  val_loss=1.1339
Epoch 09  val_loss=1.0947
Epoch 10  val_loss=1.0609
Epoch 11  val_loss=1.0358
Epoch 12  val_loss=1.0171
Epoch 13  val_loss=1.0006
Epoch 14  val_loss=0.9869
Epoch 15  val_loss=0.9744
Epoch 16  val_loss=0.9664
Epoch 17  val_loss=0.9585
Epoch 18  val_loss=0.9478
Epoch 19  val_loss=0.9435
Epoch 20  val_loss=0.9367
Epoch 21  val_loss=0.9284
Epoch 22  val_loss=0.9223
Epoch 23  val_loss=0.9169
Epoch 24  val_loss=0.9141
Epoch 25  val_loss=0.9073
Epoch 26  val_loss=0.9044
Epoch 27  val_loss=0.9037
Epoch 28  val_loss=0.8999
Epoch 29  val_loss=0.8962
Epoch 30  val_loss=0.8960
✓ LSTM training done!


## 4. Hybrid Model (CNN+BiLSTM)

In [62]:
import torch, torch.nn as nn
import numpy as np, os
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


class CNN_BiLSTM(nn.Module):
    def __init__(self, n_feat, n_class,
                 cnn_hidden=128, lstm_hidden=64, p_drop=0.3):
        super().__init__()

        def same_conv(cin, out_channels, k, dil=1):
            pad = dil * (k - 1) // 2
            return nn.Sequential(
                nn.Conv1d(cin, out_channels, k, padding=pad, dilation=dil),
                nn.BatchNorm1d(out_channels), nn.ReLU()
            )

        # local convolution
        self.cnn = nn.Sequential(
            same_conv(n_feat, 64, 3, 1),
            same_conv(64,  cnn_hidden, 5, 1)
        )
        # long-term dependency
        self.lstm = nn.LSTM(cnn_hidden, lstm_hidden,
                            batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(p_drop)
        self.fc   = nn.Linear(lstm_hidden * 2, n_class)

    def forward(self, x):              # x:(B,T,F)
        x = x.transpose(1, 2)          # (B,F,T)
        h = self.cnn(x).transpose(1, 2)# (B,T,Hc)  len = T
        o, _ = self.lstm(h)            # (B,T,2Hl)
        o = self.drop(o)
        return self.fc(o)              # (B,T,C)

pad_id        = num_classes            # 24
total_classes = num_classes + 1        # 25

hyb_model = CNN_BiLSTM(num_feat, total_classes).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = torch.optim.Adam(hyb_model.parameters(), lr=1e-3)

best_loss, wait = np.inf, 0
os.makedirs("checkpoints", exist_ok=True)
best_ckpt, last_ckpt = "checkpoints/best_hyb.pt", "checkpoints/last_hyb.pt"


# class CNN_BiLSTM(nn.Module):
#     def __init__(self, input_dim, hidden_dim, num_classes):
#         super().__init__()
#         self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=3, padding=1)
#         self.bilstm = nn.LSTM(input_size=32, hidden_size=hidden_dim,
#                               num_layers=1, batch_first=True, bidirectional=True)
#         self.fc = nn.Linear(hidden_dim * 2, num_classes)

#     def forward(self, x):
#         x = x.permute(0, 2, 1)            # (B, F=12, T)
#         x = F.relu(self.conv1(x))         # (B, 32, T)
#         x = x.permute(0, 2, 1)            # (B, T, 32)
#         lstm_out, _ = self.bilstm(x)      # (B, T, 2*hidden)
#         return self.fc(lstm_out)          # (B, T, C)


In [63]:
# training loop
epochs, patience = 30, 3

for ep in range(1, epochs + 1):
    # train
    hyb_model.train()
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = hyb_model(xb)                             # (B,T,25)
        loss   = criterion(logits.reshape(-1, logits.size(-1)),
                           yb.reshape(-1))
        loss.backward(); optimizer.step()

    # val
    hyb_model.eval(); val_loss = 0.
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            logits = hyb_model(xb)
            loss   = criterion(logits.reshape(-1, logits.size(-1)),
                               yb.reshape(-1))
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_dl.dataset)
    print(f"[HYB] Ep{ep:02d} val_loss={val_loss:.4f}")

    torch.save(hyb_model.state_dict(), last_ckpt)
    if val_loss < best_loss:
        best_loss, wait = val_loss, 0
        torch.save(hyb_model.state_dict(), best_ckpt)
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping."); break

hyb_model.load_state_dict(torch.load(best_ckpt, map_location=device))
print("✓ Hybrid CNN+BiLSTM training done!")



# model = CNN_BiLSTM(input_dim=num_feat, hidden_dim=64, num_classes=num_classes).to(device)

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0.0

#     for X_batch, y_batch in train_dl:
#         X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # (B, T, F), (B, T)

#         outputs = model(X_batch)  # (B, T, C)

#        # flatten predictions and labels
#         B, T, C = outputs.shape
#         loss = criterion(
#             outputs.view(B * T, C),   # (B*T, C)
#             y_batch.view(B * T)       # (B*T,)
#         )

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()

#     print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")



[HYB] Ep01 val_loss=3.0002
[HYB] Ep02 val_loss=1.6374
[HYB] Ep03 val_loss=1.5227
[HYB] Ep04 val_loss=1.3855
[HYB] Ep05 val_loss=1.2503
[HYB] Ep06 val_loss=1.1931
[HYB] Ep07 val_loss=1.1287
[HYB] Ep08 val_loss=1.0742
[HYB] Ep09 val_loss=1.0441
[HYB] Ep10 val_loss=1.0174
[HYB] Ep11 val_loss=1.0117
[HYB] Ep12 val_loss=0.9777
[HYB] Ep13 val_loss=0.9701
[HYB] Ep14 val_loss=0.9560
[HYB] Ep15 val_loss=0.9403
[HYB] Ep16 val_loss=0.9341
[HYB] Ep17 val_loss=0.9163
[HYB] Ep18 val_loss=0.9254
[HYB] Ep19 val_loss=0.9392
[HYB] Ep20 val_loss=0.9514
Early stopping.
✓ Hybrid CNN+BiLSTM training done!


In [64]:
# evaluation
pad_id = num_classes
mask_pad = True

hyb_model.eval()
all_p, all_l = [], []

with torch.no_grad():
    for xb, yb in val_dl:
        xb, yb = xb.to(device), yb.to(device)
        preds  = hyb_model(xb).argmax(-1)          # (B,T)

        if mask_pad:
            valid = (yb != pad_id)
            all_p.extend(preds[valid].cpu().numpy())
            all_l.extend(yb[valid].cpu().numpy())
        else:
            all_p.extend(preds.cpu().numpy().ravel())
            all_l.extend(yb.cpu().numpy().ravel())

acc = accuracy_score(all_l, all_p)
print(f"Frame accuracy : {acc:.4f}\n")
print(classification_report(all_l, all_p, zero_division=0))
print(confusion_matrix(all_l, all_p))


# from sklearn.metrics import classification_report, confusion_matrix 

# model.eval()
# all_preds = []
# all_labels = []

# with torch.no_grad():
#     for X_batch, y_batch in val_dl:
#         X_batch = X_batch.to(device)
#         outputs = model(X_batch)
#         preds = torch.argmax(outputs, dim=1).cpu().numpy()
#         all_preds.extend(preds)
#         all_labels.extend(y_batch.numpy())

# print("Classification Report:")
# print(classification_report(all_labels, all_preds, zero_division=0))

# print("Confusion Matrix:")
# print(confusion_matrix(all_labels, all_preds))


Frame accuracy : 0.7671

              precision    recall  f1-score   support

           0       0.94      0.97      0.96      1516
           1       1.00      0.17      0.29         6
           2       0.00      0.00      0.00        16
           5       0.00      0.00      0.00         2
           6       0.00      0.00      0.00        32
           7       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         3
          10       0.58      0.88      0.70       134
          11       0.79      0.40      0.53        58
          12       0.00      0.00      0.00         3
          15       0.38      0.65      0.48        20
          16       0.00      0.00      0.00         3
          17       0.00      0.00      0.00         1
          20       0.58      0.68      0.63       177
          21       0.17      0.22      0.19        18
          22       0.00      0.00      0.00         6
          25       0.59      0.75      0.66       118
  

## Evaluation

In [65]:
import numpy as np, json, mir_eval, torch, pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# chord vocab
with open("output/chord_vocab.json") as f:
    chord_to_idx = json.load(f)
idx_to_chord = {int(v): k for k, v in chord_to_idx.items()}

MIREX_MAPPING = {
    "Major": "maj",
    "Minor": "min",
    "Dominant 7th": "7",
    "Diminished": "dim",
    "Augmented": "aug"
}

def ints_to_chords(int_arr):
    out = []
    for i in int_arr.flatten():
        txt = idx_to_chord.get(int(i), "N")
        if txt in ("Unknown", "N"):
            out.append("N"); continue
        try:
            root, qual = txt.split(" ", 1)
            out.append(f"{root}:{MIREX_MAP.get(qual,'maj')}")
        except ValueError:
            out.append("N")
    return out


In [66]:
def predict_np(model, X_np, batch=128):
    model.eval(); outs=[]
    with torch.no_grad():
        for i in range(0, len(X_np), batch):
            xb = torch.tensor(X_np[i:i+batch], dtype=torch.float32, device=device)
            logits = model(xb)                     # (B,T,C)
            outs.append(torch.softmax(logits, -1).cpu().numpy())
    return np.concatenate(outs, 0)                 # (N,T,C)


In [None]:
import numpy as np, torch, mir_eval

def evaluate_chord_predictions(name,
                               y_pred_prob,
                               y_true_ohe,
                               frame_rate: float = 1.0,
                               pad_id: int | None = None,
                               verbose=True):

    # ensure NumPy
    if torch.is_tensor(y_pred_prob):
        y_pred_prob = y_pred_prob.detach().cpu().numpy()
    if torch.is_tensor(y_true_ohe):
        y_true_ohe  = y_true_ohe.detach().cpu().numpy()

    # flatten class indices
    y_pred_int = y_pred_prob.argmax(-1).ravel()
    y_true_int = y_true_ohe.argmax(-1).ravel()

    # optional PAD filtering
    if pad_id is not None:
        mask = y_true_int != pad_id
        y_pred_int, y_true_int = y_pred_int[mask], y_true_int[mask]

    # convert to MIREX chord strings
    est_labels = ints_to_chords(y_pred_int)
    ref_labels = ints_to_chords(y_true_int)

    n = len(y_true_int)
    intervals = np.column_stack([np.arange(n)/frame_rate,
                                 (np.arange(n)+1)/frame_rate])

    result = mir_eval.chord.evaluate(intervals, ref_labels,
                                     intervals, est_labels)

    score_names = list(result.keys())
    scores      = list(result.values())
    frame_acc   = (y_pred_int == y_true_int).mean()

    if verbose:
        print(f"\n=== {name} ===")
        for k, v in zip(score_names, scores):
            print(f"{k:>10}: {v:.4f}")
        print(f"{'frame_acc':>10}: {frame_acc:.4f}")

    out = dict(result)
    out["frame_acc"] = frame_acc
    out["model"]     = name

    return out


def evaluate_models(model_dict, X_test, y_test_ohe, pad_id):
    rows=[]
    for n,m in model_dict.items():
        probs = predict_np(m, X_test)
        rows.append(evaluate_chord_predictions(n, probs, y_test_ohe,
                                               frame_rate=1.0,
                                               pad_id=pad_id,
                                               verbose=False))
        
    return pd.DataFrame(rows).set_index("model").round(4)


In [None]:
# load best checkpoints
cnn_model = ChromaCNN_Frame(num_feat, num_classes).to(device)
cnn_model.load_state_dict(torch.load("checkpoints/best_cnn.pt",
                                     map_location=device))

hyb_model = CNN_BiLSTM(num_feat, total_classes).to(device)
hyb_model.load_state_dict(torch.load("checkpoints/best_hyb.pt",
                                     map_location=device))

models = {"RNN": rnn_model,
          "LSTM": lstm_model,
          "CNN": cnn_model,
          "CNN+BiLSTM": hyb_model}

df = evaluate_models(models, X_te, y_te, pad_id)
print("\n=== Summary ==="); display(df)



  xb = torch.tensor(X_np[i:i+batch], dtype=torch.float32, device=device)



=== Summary ===


Unnamed: 0_level_0,thirds,thirds_inv,triads,triads_inv,tetrads,tetrads_inv,root,mirex,majmin,majmin_inv,sevenths,sevenths_inv,underseg,overseg,seg,frame_acc
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
RNN,0.7878,0.7878,0.7871,0.7871,0.774,0.774,0.8016,0.7871,0.7922,0.7922,0.7791,0.7791,0.9095,0.8983,0.8983,0.774
LSTM,0.7969,0.7969,0.7969,0.7969,0.7871,0.7871,0.8056,0.7969,0.8021,0.8021,0.7922,0.7922,0.8688,0.919,0.8688,0.7871
CNN,0.7445,0.7445,0.7445,0.7445,0.7369,0.7369,0.7703,0.7445,0.7495,0.7495,0.7418,0.7418,0.8576,0.9044,0.8576,0.7369
CNN+BiLSTM,0.7798,0.7798,0.7798,0.7798,0.7671,0.7671,0.8005,0.7798,0.7849,0.7849,0.7721,0.7721,0.7903,0.952,0.7903,0.7671


In [69]:

# import mir_eval
# import json

# # Load and invert your chord vocab
# with open("output/chord_vocab.json", "r") as f:
#     chord_to_index = json.load(f)
# index_to_chord = {int(v): k for k, v in chord_to_index.items()}

# # MIREX chord type mapping
# MIREX_MAPPING = {
#     "Major": "maj",
#     "Minor": "min",
#     "Dominant 7th": "7",
#     "Diminished": "dim",
#     "Augmented": "aug"
# }

# # convert class index to MIREX-style chord strings
# # return chord array in MIREX
# def ints_to_chords(y_int): # np.ndarray
#     if isinstance(y_int, np.ndarray):
#         y_int = y_int.flatten()

#     chords = []
#     for i in y_int:
#         chord_str = index_to_chord.get(int(i), "N")
#         if chord_str == "Unknown" or chord_str == "N":
#             chords.append("N")  # mirex for Null
#         else:
#             try:
#                 root, quality = chord_str.split(" ", 1)
#                 mir_label = f"{root}:{MIREX_MAPPING.get(quality, 'maj')}"
#                 chords.append(mir_label)
#             except Exception:
#                 chords.append("N")
#     return chords

# # model evaluation using mir_eval
# def evaluate_chord_predictions(name, y_pred_prob, y_true_ohe, frame_rate=1.0, return_scores=False):
    
#     y_pred_int = y_pred_prob.argmax(-1).flatten()
#     y_true_int = y_true_ohe.argmax(-1).flatten()

#     est_labels = ints_to_chords(y_pred_int)
#     ref_labels = ints_to_chords(y_true_int)

#     n = len(y_true_int)
#     intervals = np.column_stack([np.arange(n)/frame_rate, (np.arange(n)+1)/frame_rate])

#     result = mir_eval.chord.evaluate(intervals, ref_labels, intervals, est_labels)

#     score_names = list(result.keys())
#     scores = list(result.values())
#     frame_acc = np.mean(np.array(y_pred_int) == np.array(y_true_int))

#     print(f"\n=== {name} ===")
#     for nm, sc in zip(score_names, scores):
#         print(f"{nm:>10}: {sc:.4f}")
#     print(f"{'frame_acc':>10}: {frame_acc:.4f}")

#     if return_scores:
#         score_dict = {nm: sc for nm, sc in zip(score_names, scores)}
#         score_dict["frame_acc"] = frame_acc
#         score_dict["model"] = name
#         return score_dict

# # evaluate all models 
# def evaluate_models(model_dict, X_test, y_true_ohe, frame_rate=1.0):
#     results = []
#     for name, model in model_dict.items():
#         probs = predict_np(model, X_test)
#         scores = evaluate_chord_predictions(name, probs, y_true_ohe, frame_rate, return_scores=True)
#         results.append(scores)

#     df = pd.DataFrame(results).set_index("model")
#     print("\n=== Summary Table ===")
#     print(df.round(4))
#     return df

# cnn_model = SequenceChromaCNN(input_dim=num_feat, num_classes=num_classes).to(device)
# cnn_bilstm_model = CNN_BiLSTM(input_dim=num_feat, hidden_dim=64, num_classes=num_classes).to(device)

# model_dict = {
#     "RNN": rnn_model,
#     "LSTM": lstm_model,
#     "CNN": cnn_model,
#     "CNN+BiLSTM": cnn_bilstm_model
# }

# df_results = evaluate_models(model_dict, X_te, y_te)


In [70]:
# import numpy as np
# import mir_eval

# def evaluate_chord_predictions(name, y_pred_prob, y_true_ohe, frame_rate=1.0):
#     """
#     name : Identifier for printing.
#     y_pred_prob : ndarray, shape (N, seq_len, C)
#         Predicted class probabilities or logits.
#     y_true_ohe : ndarray, shape (N, seq_len, C)
#         One-hot ground-truth labels.
#     frame_rate : float, default 1.0
#         Frames per second (interval length = 1/frame_rate sec).
#     """
#     # flatten to 1-D vectors of class indices
#     y_pred_int = y_pred_prob.argmax(-1).flatten()
#     y_true_int = y_true_ohe.argmax(-1).flatten()

#     # map int to mir_eval chord strings 
#     est_labels = ints_to_chords(y_pred_int)
#     ref_labels = ints_to_chords(y_true_int)

#     n = len(y_true_int)
#     intervals = np.column_stack([np.arange(n)/frame_rate,
#                                  (np.arange(n)+1)/frame_rate])

#     result = mir_eval.chord.evaluate(
#         intervals, ref_labels, intervals, est_labels
#     )

#     # normalize output (names, scores)
#     if isinstance(result, dict):
#         score_names = list(result.keys())
#         scores      = list(result.values())

#     elif isinstance(result, (list, tuple)) and len(result) == 2:
#         a, b = result
#         score_names, scores = (a, b) if isinstance(a[0], str) else (b, a)

#     else:
#         scores = list(result)
#         score_names = ["root", "majmin", "thirds",
#                        "triads", "sevenths", "tetrads", "mirex"][:len(scores)]

#     print(f"\n=== {name} ===")
#     for nm, sc in zip(score_names, scores):
#         print(f"{nm:>10}: {sc:.4f}")

#     frame_acc = (y_pred_int == y_true_int).mean()
#     print(f"{'frame_acc':>10}: {frame_acc:.4f}")


# evaluate_chord_predictions("RNN Model",  rnn_probs,  y_te)
# evaluate_chord_predictions("LSTM Model", lstm_probs, y_te)
