In [13]:
# src/preprocess_midi.py

import os
from pathlib import Path
from typing import List, Tuple

import numpy as np
from music21 import interval, pitch

In [14]:
# -----------------------------
# Config
# -----------------------------
RAW_MIDI_DIR = Path("../data/raw_midi")
PROCESSED_DIR = Path("../data/processed")
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
TRANSPOSED_DIR = PROCESSED_DIR / "transposed_midis"
TRANSPOSED_DIR.mkdir(parents=True, exist_ok=True)


# how many notes per snippet (you can tweak)
SNIPPET_LENGTH = 32

# base rhythmic unit: 1 quarter note = 4 steps, so 1 step = sixteenth note
STEPS_PER_QUARTER = 4


In [18]:
# -----------------------------
# Helper functions
# -----------------------------

from music21 import converter, instrument, note, chord, stream, key, interval, pitch

def load_midi(filepath: Path) -> stream.Score:
    """Load a MIDI file into a music21 Score."""
    return converter.parse(str(filepath))



def pick_melody_part(score: stream.Score) -> stream.Part:
    """
    Heuristic for picking the 'melody' part:

    1. Skip percussion parts.
    2. If any part name/instrument name suggests 'melody/lead/right hand',
       pick that directly.
    3. Otherwise:
       - For each remaining part, compute:
         * n_notes
         * avg_pitch
       - Compute median avg_pitch across candidates.
       - Filter to parts with avg_pitch >= median (favor higher voices).
       - Among those, pick the one with the most notes; break ties by higher avg_pitch.

    Fallback: if all else fails, flatten the score.
    """
    candidates = []

    for p in score.parts:
        # Skip percussion
        if any(isinstance(i, instrument.UnpitchedPercussion) for i in p.getInstruments()):
            continue

        # Collect notes/chords
        notes_chords = [n for n in p.recurse().notes if isinstance(n, (note.Note, chord.Chord))]
        if not notes_chords:
            continue

        # Basic stats
        pitches = []
        for n in notes_chords:
            if isinstance(n, note.Note):
                pitches.append(n.pitch.midi)
            elif isinstance(n, chord.Chord):
                pitches.append(max(nn.pitch.midi for nn in n.notes))

        if not pitches:
            continue

        n_notes = len(pitches)
        avg_pitch = sum(pitches) / len(pitches)

        # part/instrument names (lowercased)
        part_name = (p.partName or "").lower()
        inst_names = [str(inst.instrumentName or "").lower()
                      for inst in p.getInstruments()]

        candidates.append({
            "part": p,
            "n_notes": n_notes,
            "avg_pitch": avg_pitch,
            "part_name": part_name,
            "inst_names": inst_names,
        })

    if not candidates:
        print("  [warn] no suitable non-percussion parts; flattening score.")
        return score.flat

    # 1) Name-based shortcut: if any part name/instrument suggests "melody"
    name_keywords = [
        "melody", "lead", "right hand", "treble", "solo", "violin", "flute", "trumpet"
    ]

    def looks_like_melody(c):
        text = c["part_name"] + " " + " ".join(c["inst_names"])
        text = text.lower()
        return any(kw in text for kw in name_keywords)

    name_candidates = [c for c in candidates if looks_like_melody(c)]
    if name_candidates:
        # among these, pick the one with highest avg_pitch (just in case)
        best = max(name_candidates, key=lambda c: c["avg_pitch"])
        print(f"  [info] pick_melody_part: selected by name heuristic: "
              f"part_name='{best['part_name']}', avg_pitch={best['avg_pitch']:.1f}, n_notes={best['n_notes']}")
        return best["part"]

    # 2) Pitch-based filtering: keep only parts at or above median avg_pitch
    avg_pitches = [c["avg_pitch"] for c in candidates]
    median_pitch = sorted(avg_pitches)[len(avg_pitches) // 2]

    high_voice_candidates = [c for c in candidates if c["avg_pitch"] >= median_pitch]
    if not high_voice_candidates:
        high_voice_candidates = candidates  # fallback to all

    # 3) Among high-voice candidates, pick the one with the most notes & higher pitch
    best = max(
        high_voice_candidates,
        key=lambda c: (c["n_notes"], c["avg_pitch"])  # primary: many notes, secondary: higher pitch
    )

    print(
        f"  [info] pick_melody_part: selected by stats: "
        f"part_name='{best['part_name']}', avg_pitch={best['avg_pitch']:.1f}, "
        f"n_notes={best['n_notes']}"
    )

    return best["part"]




def detect_key_and_transpose(melody: stream.Part) -> stream.Part:
    """
    Detect key with music21 and transpose so tonic is C (for major) or A (for minor).
    If key detection fails for some reason, return the original melody.
    """
    try:
        key_guess = melody.analyze('key')
    except Exception as e:
        print("  [warn] key analysis failed, leaving melody untransposed:", e)
        return melody

    # Decide target tonic
    if key_guess.mode == 'major':
        target_pitch = pitch.Pitch('C')
    else:
        # treat minor keys as aiming for A minor tonic
        target_pitch = pitch.Pitch('A')

    # Build interval from current tonic to target tonic
    itvl = interval.Interval(key_guess.tonic, target_pitch)

    transposed = melody.transpose(itvl)
    return transposed


def extract_pitch_duration_sequence(melody: stream.Part) -> List[Tuple[int, float]]:
    """
    Extract (midi_pitch, quarter_length_duration) from a melody line.
    Ignore rests; collapse chords to their top note.
    """
    seq = []
    for elem in melody.recurse().notesAndRests:
        if isinstance(elem, note.Note):
            midi_pitch = elem.pitch.midi
            dur = float(elem.quarterLength)
            seq.append((midi_pitch, dur))
        elif isinstance(elem, chord.Chord):
            # take highest note in chord as melody approximation
            midi_pitch = max(n.pitch.midi for n in elem.notes)
            dur = float(elem.quarterLength)
            seq.append((midi_pitch, dur))
        else:
            # ignore rests and other stuff for now
            continue
    return seq


def convert_to_intervals_and_durations(
    pitch_dur_seq: List[Tuple[int, float]]
) -> Tuple[List[int], List[int]]:
    """
    Convert absolute pitches to pitch intervals and durations to integer steps.
    intervals[i] = pitch[i] - pitch[i-1], with first interval = 0
    durations[i] = round( quarter_length * STEPS_PER_QUARTER )
    """
    if not pitch_dur_seq:
        return [], []

    pitches = [p for (p, _) in pitch_dur_seq]
    durs_q = [d for (_, d) in pitch_dur_seq]

    intervals = [0]  # first note has no previous reference
    for i in range(1, len(pitches)):
        intervals.append(int(pitches[i] - pitches[i - 1]))

    durations = [max(1, int(round(d * STEPS_PER_QUARTER))) for d in durs_q]

    return intervals, durations


def make_snippets(
    intervals: List[int],
    durations: List[int],
    snippet_length: int = SNIPPET_LENGTH
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Slice sequences into fixed-length snippets.
    We use a simple sliding window with stride = snippet_length // 2 (50% overlap).
    Short sequences yield zero snippets.
    """
    assert len(intervals) == len(durations)
    n = len(intervals)
    if n < snippet_length:
        return np.empty((0, snippet_length), dtype=np.int32), np.empty((0, snippet_length), dtype=np.int32)

    stride = snippet_length // 2
    interval_snips = []
    duration_snips = []

    for start in range(0, n - snippet_length + 1, stride):
        end = start + snippet_length
        interval_snips.append(intervals[start:end])
        duration_snips.append(durations[start:end])

    return np.array(interval_snips, dtype=np.int32), np.array(duration_snips, dtype=np.int32)



In [19]:
# -----------------------------
# Main preprocessing
# -----------------------------
def process_all_midis(rebuild_all: bool = False):
    """
    Preprocess MIDI files into fixed-length snippets.

    Args:
        rebuild_all (bool): 
            - If False (default): 
                * Load existing snippets.npz (if present)
                * Only process NEW MIDI files not already in midi_filenames
                * Append their snippets to the existing dataset
            - If True:
                * Ignore existing snippets.npz
                * Rebuild dataset from ALL MIDI files in RAW_MIDI_DIR
    """
    out_path = PROCESSED_DIR / "snippets.npz"

    # --------------------------------------------------
    # Collect all MIDI filenames in the raw directory
    # --------------------------------------------------
    midi_files = sorted(list(RAW_MIDI_DIR.glob("*.mid")) +
                        list(RAW_MIDI_DIR.glob("*.midi")))

    if not midi_files:
        print(f"No MIDI files found in {RAW_MIDI_DIR}. Nothing to do.")
        return

    # We'll fill these as we go
    all_interval_snips = []
    all_duration_snips = []
    all_song_ids = []

    # These are only used if we are appending (rebuild_all=False)
    existing_intervals = None
    existing_durations = None
    existing_song_ids = None
    existing_midi_filenames = None
    existing_filenames_set = set()

    # --------------------------------------------------
    # Load existing NPZ (if present and not rebuilding)
    # --------------------------------------------------
    if not rebuild_all and out_path.exists():
        print(f"Loading existing dataset: {out_path}")
        data = np.load(out_path, allow_pickle=True)

        existing_intervals = data["intervals"]
        existing_durations = data["durations"]
        existing_song_ids = data["song_ids"]
        existing_midi_filenames = data["midi_filenames"]  # 1D array of filenames

        existing_filenames_set = set(existing_midi_filenames.tolist())

        print(f"  Existing snippets: {existing_intervals.shape[0]}")
        print(f"  Existing MIDI files: {len(existing_filenames_set)}")
    elif rebuild_all:
        print("Rebuilding dataset from scratch; ignoring existing snippets.npz (if any).")

    # --------------------------------------------------
    # Decide which files to process
    # --------------------------------------------------
    if rebuild_all or existing_midi_filenames is None:
        # process ALL files
        files_to_process = midi_files
        base_song_idx = 0
        existing_midi_filenames_list = []
    else:
        # Only process files not already in midi_filenames
        files_to_process = [p for p in midi_files if p.name not in existing_filenames_set]
        base_song_idx = len(existing_midi_filenames)
        existing_midi_filenames_list = existing_midi_filenames.tolist()

    print(f"Found {len(midi_files)} total MIDI files.")
    print(f"{len(files_to_process)} file(s) to process this run.")

    if not files_to_process:
        print("No new MIDI files found. Dataset unchanged.")
        return

    # We'll also build up the new filenames in order
    new_filenames_list = []

    # --------------------------------------------------
    # Process selected MIDI files
    # --------------------------------------------------
    for local_idx, midi_path in enumerate(files_to_process):
        # song_id for this file within the full dataset:
        # existing songs first, then new ones in order
        song_id = base_song_idx + local_idx

        print(f"Processing {midi_path.name} "
              f"({local_idx + 1}/{len(files_to_process)}), assigned song_id={song_id}")

        try:
            score = load_midi(midi_path)
        except Exception as e:
            print(f"  Failed to load {midi_path.name}: {e}")
            continue

        melody = pick_melody_part(score)
        melody = detect_key_and_transpose(melody)

        # Save transposed melody
        out_midi_path = TRANSPOSED_DIR / f"transposed_{midi_path.stem}.mid"
        try:
            melody.write("midi", fp=str(out_midi_path))
            print(f"  Saved transposed: {out_midi_path}")
        except Exception as e:
            print(f"  [warn] could not save transposed MIDI for {midi_path.name}: {e}")

        pitch_dur_seq = extract_pitch_duration_sequence(melody)

        if len(pitch_dur_seq) < SNIPPET_LENGTH:
            print(f"  Skipping {midi_path.name}: too few notes ({len(pitch_dur_seq)})")
            continue

        intervals, durations = convert_to_intervals_and_durations(pitch_dur_seq)
        i_snips, d_snips = make_snippets(intervals, durations, SNIPPET_LENGTH)

        if i_snips.shape[0] == 0:
            print(f"  No snippets extracted from {midi_path.name}")
            continue

        all_interval_snips.append(i_snips)
        all_duration_snips.append(d_snips)
        all_song_ids.append(np.full(i_snips.shape[0], song_id, dtype=np.int32))

        new_filenames_list.append(midi_path.name)

    # --------------------------------------------------
    # If nothing new was processed successfully
    # --------------------------------------------------
    if not all_interval_snips:
        print("No snippets extracted from selected MIDI files. Dataset unchanged.")
        return

    # Stack new snippets
    new_intervals = np.vstack(all_interval_snips)
    new_durations = np.vstack(all_duration_snips)
    new_song_ids = np.concatenate(all_song_ids)

    # --------------------------------------------------
    # Merge old + new or just use new (if rebuild_all or no existing)
    # --------------------------------------------------
    if not rebuild_all and existing_intervals is not None:
        intervals_arr = np.vstack([existing_intervals, new_intervals])
        durations_arr = np.vstack([existing_durations, new_durations])
        song_ids_arr = np.concatenate([existing_song_ids, new_song_ids])

        # Append new filenames after existing, in consistent order
        midi_filenames_arr = np.array(existing_midi_filenames_list + new_filenames_list)
    else:
        intervals_arr = new_intervals
        durations_arr = new_durations
        song_ids_arr = new_song_ids

        # When rebuilding, our song_ids are aligned with files_to_process in order
        # but we still want the full midi_files list in case some were skipped.
        midi_filenames_arr = np.array(new_filenames_list)

    # --------------------------------------------------
    # Save updated dataset
    # --------------------------------------------------
    np.savez_compressed(
        out_path,
        intervals=intervals_arr,
        durations=durations_arr,
        song_ids=song_ids_arr,
        midi_filenames=midi_filenames_arr,
    )

    print(f"Saved updated snippets to {out_path}")
    print(f"Total snippets: {intervals_arr.shape[0]}")
    print(f"Total MIDI files represented: {len(midi_filenames_arr)}")


In [20]:
process_all_midis(rebuild_all=True)

Rebuilding dataset from scratch; ignoring existing snippets.npz (if any).
Found 30 total MIDI files.
30 file(s) to process this run.
Processing Pirates of the Caribbean - He's a Pirate (3).mid (1/30), assigned song_id=0
  [info] pick_melody_part: selected by name heuristic: part_name='right hand', avg_pitch=70.4, n_notes=274
  Saved transposed: ../data/processed/transposed_midis/transposed_Pirates of the Caribbean - He's a Pirate (3).mid
Processing appass_1.mid (2/30), assigned song_id=1




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=73.7, n_notes=2483
  Saved transposed: ../data/processed/transposed_midis/transposed_appass_1.mid
Processing appass_2.mid (3/30), assigned song_id=2
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=69.0, n_notes=630
  Saved transposed: ../data/processed/transposed_midis/transposed_appass_2.mid
Processing appass_3.mid (4/30), assigned song_id=3
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=72.8, n_notes=2943
  Saved transposed: ../data/processed/transposed_midis/transposed_appass_3.mid
Processing beethoven_hammerklavier_1.mid (5/30), assigned song_id=4




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=80.4, n_notes=2757
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_hammerklavier_1.mid
Processing beethoven_hammerklavier_2.mid (6/30), assigned song_id=5
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=76.4, n_notes=623
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_hammerklavier_2.mid
Processing beethoven_hammerklavier_3.mid (7/30), assigned song_id=6
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=77.3, n_notes=1286
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_hammerklavier_3.mid
Processing beethoven_hammerklavier_4.mid (8/30), assigned song_id=7
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=72.6, n_notes=3447
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_hammerklavier_4.mid
Processing be



  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=69.0, n_notes=2354
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_opus22_1.mid
Processing beethoven_opus22_2.mid (16/30), assigned song_id=15
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=68.8, n_notes=773
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_opus22_2.mid
Processing beethoven_opus22_3.mid (17/30), assigned song_id=16
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=70.4, n_notes=769
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_opus22_3.mid
Processing beethoven_opus22_4.mid (18/30), assigned song_id=17
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=73.4, n_notes=1145
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_opus22_4.mid
Processing beethoven_opus90_1.mid (19/30), assigned song_



  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=75.0, n_notes=902
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_opus90_1.mid
Processing beethoven_opus90_2.mid (20/30), assigned song_id=19
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=74.4, n_notes=1289
  Saved transposed: ../data/processed/transposed_midis/transposed_beethoven_opus90_2.mid
Processing elise.mid (21/30), assigned song_id=20




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=72.6, n_notes=496
  Saved transposed: ../data/processed/transposed_midis/transposed_elise.mid
Processing mond_1.mid (22/30), assigned song_id=21




  [info] pick_melody_part: selected by stats: part_name='piano right second', avg_pitch=62.0, n_notes=803
  Saved transposed: ../data/processed/transposed_midis/transposed_mond_1.mid
Processing mond_2.mid (23/30), assigned song_id=22




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=70.9, n_notes=381
  Saved transposed: ../data/processed/transposed_midis/transposed_mond_2.mid
Processing mond_3.mid (24/30), assigned song_id=23




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=71.0, n_notes=2528
  Saved transposed: ../data/processed/transposed_midis/transposed_mond_3.mid
Processing pathetique_1.mid (25/30), assigned song_id=24




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=71.9, n_notes=2134




  Saved transposed: ../data/processed/transposed_midis/transposed_pathetique_1.mid
Processing pathetique_2.mid (26/30), assigned song_id=25




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=67.6, n_notes=241
  Saved transposed: ../data/processed/transposed_midis/transposed_pathetique_2.mid
Processing pathetique_3.mid (27/30), assigned song_id=26




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=73.0, n_notes=1198
  Saved transposed: ../data/processed/transposed_midis/transposed_pathetique_3.mid
Processing waldstein_1.mid (28/30), assigned song_id=27
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=74.1, n_notes=3782
  Saved transposed: ../data/processed/transposed_midis/transposed_waldstein_1.mid
Processing waldstein_2.mid (29/30), assigned song_id=28
  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=67.9, n_notes=168
  Saved transposed: ../data/processed/transposed_midis/transposed_waldstein_2.mid
Processing waldstein_3.mid (30/30), assigned song_id=29




  [info] pick_melody_part: selected by stats: part_name='piano right', avg_pitch=72.7, n_notes=2425
  Saved transposed: ../data/processed/transposed_midis/transposed_waldstein_3.mid
Saved updated snippets to ../data/processed/snippets.npz
Total snippets: 2608
Total MIDI files represented: 30
