In [45]:
# preprocess_midi.py

import os
from pathlib import Path
from typing import List, Tuple

import numpy as np
from music21 import interval, pitch

In [46]:
# -----------------------------
# Config
# -----------------------------
RAW_MIDI_DIR = Path("../data/mini_dataset")
PROCESSED_DIR = Path("../data/processed")
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
TRANSPOSED_DIR = PROCESSED_DIR / "transposed_midis"
TRANSPOSED_DIR.mkdir(parents=True, exist_ok=True)


# how many notes per snippet
SNIPPET_LENGTH = 32

# base rhythmic unit: 1 quarter note = 4 steps, so 1 step = sixteenth note
STEPS_PER_QUARTER = 4


In [61]:
# -----------------------------
# Helper functions
# -----------------------------

from music21 import converter, instrument, note, chord, stream, key, interval, pitch

def load_midi(filepath: Path) -> stream.Score:
    """Load a MIDI file into a music21 Score."""
    return converter.parse(str(filepath))


def pick_melody_part(score: stream.Score) -> stream.Part | None:
    """
    Heuristic for picking the 'melody' part:

    1. Skip *purely* percussion parts.
    2. If any part name/instrument name suggests 'melody/lead/right hand',
       pick that directly.
    3. Otherwise:
       - For each remaining part, compute:
         * n_notes
         * avg_pitch
       - Compute median avg_pitch across candidates.
       - Filter to parts with avg_pitch >= median (favor higher voices).
       - Among those, pick the one with the most notes; break ties by higher avg_pitch.

    Returns the chosen Part, or None if nothing suitable is found.
    """
    candidates = []

    for p in score.parts:
        insts = list(p.getInstruments())

        # Determine if this part is purely percussion (all instruments percussion-like)
        has_percussion = any(
            isinstance(i, instrument.UnpitchedPercussion) or
            ("percussion" in (i.bestName() or "").lower())
            for i in insts
        )
        has_non_percussion = any(
            not isinstance(i, instrument.UnpitchedPercussion) and
            "percussion" not in (i.bestName() or "").lower()
            for i in insts
        )

        # Skip only if it's *purely* percussion, not mixed
        if has_percussion and not has_non_percussion:
            continue

        # Collect notes/chords
        notes_chords = [n for n in p.recurse().notes if isinstance(n, (note.Note, chord.Chord))]
        if not notes_chords:
            continue

        # Basic stats
        pitches = []
        for n in notes_chords:
            if isinstance(n, note.Note):
                pitches.append(n.pitch.midi)
            elif isinstance(n, chord.Chord):
                pitches.append(max(nn.pitch.midi for nn in n.notes))

        if not pitches:
            continue

        n_notes = len(pitches)
        avg_pitch = sum(pitches) / len(pitches)

        # part/instrument names (lowercased)
        part_name = (p.partName or "").lower()
        inst_names = [str(inst.instrumentName or "").lower()
                      for inst in insts]

        candidates.append({
            "part": p,
            "n_notes": n_notes,
            "avg_pitch": avg_pitch,
            "part_name": part_name,
            "inst_names": inst_names,
        })

    if not candidates:
        print("  [warn] no suitable melodic parts; skipping this file.")
        return None

    # 1) Name-based shortcut: if any part name/instrument suggests "melody"
    name_keywords = [
        "melody", "lead", "right hand", "rh", "treble", "solo", "violin", "flute", "trumpet", "saw wave"
    ]

    def looks_like_melody(c):
        text = c["part_name"] + " " + " ".join(c["inst_names"])
        text = text.lower()
        return any(kw in text for kw in name_keywords)

    name_candidates = [c for c in candidates if looks_like_melody(c)]
    if name_candidates:
        # among these, pick the one with highest avg_pitch (just in case)
        best = max(name_candidates, key=lambda c: c["avg_pitch"])
        print(f"  [info] pick_melody_part: selected by name heuristic: "
              f"part_name='{best['part_name']}', avg_pitch={best['avg_pitch']:.1f}, n_notes={best['n_notes']}")
        return best["part"]

    # 2) Pitch-based filtering: keep only parts at or above median avg_pitch
    avg_pitches = [c["avg_pitch"] for c in candidates]
    median_pitch = sorted(avg_pitches)[len(avg_pitches) // 2]

    high_voice_candidates = [c for c in candidates if c["avg_pitch"] >= median_pitch]
    if not high_voice_candidates:
        high_voice_candidates = candidates  # fallback to all

    # 3) Among high-voice candidates, pick the one with the most notes & higher pitch
    best = max(
        high_voice_candidates,
        key=lambda c: (c["n_notes"], c["avg_pitch"])  # primary: many notes, secondary: higher pitch
    )

    print(
        f"  [info] pick_melody_part: selected by stats: "
        f"part_name='{best['part_name']}', avg_pitch={best['avg_pitch']:.1f}, "
        f"n_notes={best['n_notes']}"
    )

    return best["part"]




def detect_key_and_transpose(melody: stream.Part) -> stream.Part:
    """
    Detect key with music21 and transpose so tonic is C (for major) or A (for minor).
    If key detection fails for some reason, return the original melody.
    """
    try:
        key_guess = melody.analyze('key')
    except Exception as e:
        print("  [warn] key analysis failed, leaving melody untransposed:", e)
        return melody

    # Decide target tonic
    if key_guess.mode == 'major':
        target_pitch = pitch.Pitch('C')
    else:
        # treat minor keys as aiming for A minor tonic
        target_pitch = pitch.Pitch('A')

    # Build interval from current tonic to target tonic
    itvl = interval.Interval(key_guess.tonic, target_pitch)

    transposed = melody.transpose(itvl)
    return transposed


def extract_pitch_duration_sequence(melody: stream.Part) -> List[Tuple[int, float]]:
    """
    Extract (midi_pitch, quarter_length_duration) from a melody line.
    Ignore rests; collapse chords to their top note.
    """
    seq = []
    for elem in melody.recurse().notesAndRests:
        if isinstance(elem, note.Note):
            midi_pitch = elem.pitch.midi
            dur = float(elem.quarterLength)
            seq.append((midi_pitch, dur))
        elif isinstance(elem, chord.Chord):
            # take highest note in chord as melody approximation
            midi_pitch = max(n.pitch.midi for n in elem.notes)
            dur = float(elem.quarterLength)
            seq.append((midi_pitch, dur))
        else:
            # ignore rests and other stuff for now
            continue
    return seq


def convert_to_intervals_and_durations(
    pitch_dur_seq: List[Tuple[int, float]]
) -> Tuple[List[int], List[int]]:
    """
    Convert absolute pitches to pitch intervals and durations to integer steps.
    intervals[i] = pitch[i] - pitch[i-1], with first interval = 0
    durations[i] = round( quarter_length * STEPS_PER_QUARTER )
    """
    if not pitch_dur_seq:
        return [], []

    pitches = [p for (p, _) in pitch_dur_seq]
    durs_q = [d for (_, d) in pitch_dur_seq]

    intervals = [0]  # first note has no previous reference
    for i in range(1, len(pitches)):
        intervals.append(int(pitches[i] - pitches[i - 1]))

    durations = [max(1, int(round(d * STEPS_PER_QUARTER))) for d in durs_q]

    return intervals, durations


def make_snippets(
    intervals: List[int],
    durations: List[int],
    snippet_length: int = SNIPPET_LENGTH
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Slice sequences into fixed-length snippets.
    We use a simple sliding window with stride = snippet_length // 2 (50% overlap).
    Short sequences yield zero snippets.
    """
    assert len(intervals) == len(durations)
    n = len(intervals)
    if n < snippet_length:
        return np.empty((0, snippet_length), dtype=np.int32), np.empty((0, snippet_length), dtype=np.int32)

    stride = snippet_length // 2
    interval_snips = []
    duration_snips = []

    for start in range(0, n - snippet_length + 1, stride):
        end = start + snippet_length
        interval_snips.append(intervals[start:end])
        duration_snips.append(durations[start:end])

    return np.array(interval_snips, dtype=np.int32), np.array(duration_snips, dtype=np.int32)

def make_snippets_with_timestamps(
    intervals: List[int],
    durations: List[int],
    durations_q: List[float],
    snippet_length: int = SNIPPET_LENGTH
) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]:
    """
    Slice sequences into fixed-length snippets.
    Also return timestamp pairs (start_q, end_q) in quarter lengths.
    """
    assert len(intervals) == len(durations)
    n = len(intervals)
    if n < snippet_length:
        return np.empty((0, snippet_length), dtype=np.int32), np.empty((0, snippet_length), dtype=np.int32), []

    stride = snippet_length // 2
    interval_snips, duration_snips, timestamps = [], [], []

    cumulative_q = np.cumsum([0] + durations_q)  # cumulative time in quarter lengths

    for start in range(0, n - snippet_length + 1, stride):
        end = start + snippet_length
        interval_snips.append(intervals[start:end])
        duration_snips.append(durations[start:end])

        start_q = cumulative_q[start]
        end_q = cumulative_q[end]
        timestamps.append((int(round(start_q)), int(round(end_q))))

    return np.array(interval_snips, dtype=np.int32), np.array(duration_snips, dtype=np.int32), timestamps


In [48]:
from music21 import instrument

def sanitize_melody_instrument(melody_part):
    """
    Remove any existing Instrument metadata and force a clean, non-percussion
    instrument on a non-drum channel.
    """
    # Remove ALL Instrument objects from this part
    for inst in list(melody_part.recurse().getElementsByClass(instrument.Instrument)):
        try:
            melody_part.remove(inst)
        except Exception:
            pass

    # Set a friendly part name
    melody_part.partName = "Melody"

    # Insert one clean Piano instrument at the beginning
    piano = instrument.Piano()
    piano.midiProgram = 0  # Acoustic Grand
    piano.midiChannel = 0  # Channel 1 (NOT 10/drums)
    melody_part.insert(0, piano)

    return melody_part


In [49]:
import warnings
from music21.midi.translate import TranslateWarning
warnings.filterwarnings("ignore", category=TranslateWarning)

In [50]:
# -----------------------------
# Main preprocessing
# -----------------------------
def process_all_midis(rebuild_all: bool = False):
    """
    Preprocess MIDI files into fixed-length snippets.

    Args:
        rebuild_all (bool): 
            - If False (default): 
                * Load existing snippets.npz (if present)
                * Only process NEW MIDI files not already in midi_filenames
                * Append their snippets to the existing dataset
            - If True:
                * Ignore existing snippets.npz
                * Rebuild dataset from ALL MIDI files in RAW_MIDI_DIR
    """
    out_path = PROCESSED_DIR / "snippets.npz"

    # --------------------------------------------------
    # Collect all MIDI filenames in the raw directory
    # --------------------------------------------------
    
    
    midi_files = sorted(list(RAW_MIDI_DIR.rglob("*.mid")) +
                        list(RAW_MIDI_DIR.rglob("*.midi")))

    if not midi_files:
        print(f"No MIDI files found in {RAW_MIDI_DIR}. Nothing to do.")
        return

    # We'll fill these as we go
    all_interval_snips = []
    all_duration_snips = []
    all_song_ids = []

    # These are only used if we are appending (rebuild_all=False)
    existing_intervals = None
    existing_durations = None
    existing_song_ids = None
    existing_midi_filenames = None
    existing_filenames_set = set()

    # --------------------------------------------------
    # Load existing NPZ (if present and not rebuilding)
    # --------------------------------------------------
    if not rebuild_all and out_path.exists():
        print(f"Loading existing dataset: {out_path}")
        data = np.load(out_path, allow_pickle=True)

        existing_intervals = data["intervals"]
        existing_durations = data["durations"]
        existing_song_ids = data["song_ids"]
        existing_midi_filenames = data["midi_filenames"]  # 1D array of filenames

        existing_filenames_set = set(existing_midi_filenames.tolist())

        print(f"  Existing snippets: {existing_intervals.shape[0]}")
        print(f"  Existing MIDI files: {len(existing_filenames_set)}")
    elif rebuild_all:
        print("Rebuilding dataset from scratch; ignoring existing snippets.npz (if any).")

    # --------------------------------------------------
    # Decide which files to process
    # --------------------------------------------------
    if rebuild_all or existing_midi_filenames is None:
        # process ALL files
        files_to_process = midi_files
        base_song_idx = 0
        existing_midi_filenames_list = []
    else:
        # Only process files not already in midi_filenames
        files_to_process = [p for p in midi_files if p.name not in existing_filenames_set]
        base_song_idx = len(existing_midi_filenames)
        existing_midi_filenames_list = existing_midi_filenames.tolist()

    print(f"Found {len(midi_files)} total MIDI files.")
    print(f"{len(files_to_process)} file(s) to process this run.")

    if not files_to_process:
        print("No new MIDI files found. Dataset unchanged.")
        return

    # We'll also build up the new filenames in order
    new_filenames_list = []

    # --------------------------------------------------
    # Process selected MIDI files
    # --------------------------------------------------
    for local_idx, midi_path in enumerate(files_to_process):
        # song_id for this file within the full dataset:
        # existing songs first, then new ones in order
        song_id = base_song_idx + local_idx

        print(f"Processing {midi_path.name} "
              f"({local_idx + 1}/{len(files_to_process)}), assigned song_id={song_id}")

        try:
            score = load_midi(midi_path)
        except Exception as e:
            print(f"  Failed to load {midi_path.name}: {e}")
            continue

        melody = pick_melody_part(score)
        if melody is None:
            print(f"  Skipping {midi_path.name}: no usable melodic part found.")
            continue
        melody = detect_key_and_transpose(melody)
        melody = sanitize_melody_instrument(melody)

        # Save transposed melody
        out_midi_path = TRANSPOSED_DIR / f"transposed_{midi_path.stem}.mid"
        try:
            melody.write("midi", fp=str(out_midi_path))
            print(f"  Saved transposed: {out_midi_path}")
        except Exception as e:
            print(f"  [warn] could not save transposed MIDI for {midi_path.name}: {e}")

        pitch_dur_seq = extract_pitch_duration_sequence(melody)

        if len(pitch_dur_seq) < SNIPPET_LENGTH:
            print(f"  Skipping {midi_path.name}: too few notes ({len(pitch_dur_seq)})")
            continue

        intervals, durations = convert_to_intervals_and_durations(pitch_dur_seq)

        i_snips, d_snips = make_snippets(intervals, durations, SNIPPET_LENGTH)

        # durs_q = [d for (_, d) in pitch_dur_seq]
        # i_snips, d_snips, timestamp_pairs = make_snippets_with_timestamps(intervals, durations, durs_q, SNIPPET_LENGTH)

        # genre = midi_path.parent.name.lower()  # extract genre from folder

        if i_snips.shape[0] == 0:
            print(f"  No snippets extracted from {midi_path.name}")
            continue

        all_interval_snips.append(i_snips)
        all_duration_snips.append(d_snips)
        all_song_ids.append(np.full(i_snips.shape[0], song_id, dtype=np.int32))

        new_filenames_list.append(midi_path.name)

    # --------------------------------------------------
    # If nothing new was processed successfully
    # --------------------------------------------------
    if not all_interval_snips:
        print("No snippets extracted from selected MIDI files. Dataset unchanged.")
        return

    # Stack new snippets
    new_intervals = np.vstack(all_interval_snips)
    new_durations = np.vstack(all_duration_snips)
    new_song_ids = np.concatenate(all_song_ids)

    # --------------------------------------------------
    # Merge old + new or just use new (if rebuild_all or no existing)
    # --------------------------------------------------
    if not rebuild_all and existing_intervals is not None:
        intervals_arr = np.vstack([existing_intervals, new_intervals])
        durations_arr = np.vstack([existing_durations, new_durations])
        song_ids_arr = np.concatenate([existing_song_ids, new_song_ids])

        # Append new filenames after existing, in consistent order
        midi_filenames_arr = np.array(existing_midi_filenames_list + new_filenames_list)
    else:
        intervals_arr = new_intervals
        durations_arr = new_durations
        song_ids_arr = new_song_ids

        # When rebuilding, our song_ids are aligned with files_to_process in order
        # but we still want the full midi_files list in case some were skipped.
        midi_filenames_arr = np.array(new_filenames_list)

    # --------------------------------------------------
    # Save updated dataset
    # --------------------------------------------------
    np.savez_compressed(
        out_path,
        intervals=intervals_arr,
        durations=durations_arr,
        song_ids=song_ids_arr,
        midi_filenames=midi_filenames_arr,
    )

    print(f"Saved updated snippets to {out_path}")
    print(f"Total snippets: {intervals_arr.shape[0]}")
    print(f"Total MIDI files represented: {len(midi_filenames_arr)}")


process_all_midis(rebuild_all=True)

In [51]:
# -----------------------------
# Main preprocessing
# -----------------------------
def process_all_midis_new(rebuild_all: bool = False):
    """
    Preprocess MIDI files into fixed-length snippets.

    Args:
        rebuild_all (bool): 
            - If False (default): 
                * Load existing snippets.npz (if present)
                * Only process NEW MIDI files not already in midi_filenames
                * Append their snippets to the existing dataset
            - If True:
                * Ignore existing snippets.npz
                * Rebuild dataset from ALL MIDI files in RAW_MIDI_DIR
    """
    out_path = PROCESSED_DIR / "snippets.npz"

    # --------------------------------------------------
    # Collect all MIDI filenames in the raw directory
    # --------------------------------------------------
    midi_files = sorted(list(RAW_MIDI_DIR.rglob("*.mid")) +
                        list(RAW_MIDI_DIR.rglob("*.midi")))

    if not midi_files:
        print(f"No MIDI files found in {RAW_MIDI_DIR}. Nothing to do.")
        return

    # We'll fill these as we go
    all_interval_snips = []
    all_duration_snips = []
    all_song_ids = []
    all_genre_snips = []  # <-- per-snippet genres

    # These are only used if we are appending (rebuild_all=False)
    existing_intervals = None
    existing_durations = None
    existing_song_ids = None
    existing_midi_filenames = None
    existing_genres = None  # <-- existing genres (if any)
    existing_filenames_set = set()

    # --------------------------------------------------
    # Load existing NPZ (if present and not rebuilding)
    # --------------------------------------------------
    if not rebuild_all and out_path.exists():
        print(f"Loading existing dataset: {out_path}")
        data = np.load(out_path, allow_pickle=True)

        existing_intervals = data["intervals"]
        existing_durations = data["durations"]
        existing_song_ids = data["song_ids"]
        existing_midi_filenames = data["midi_filenames"]  # 1D array of filenames

        # If an older npz didn't have genres, we'll synthesize "unknown"
        if "genres" in data.files:
            existing_genres = data["genres"]
        else:
            existing_genres = np.array(
                ["unknown"] * existing_intervals.shape[0],
                dtype=object
            )

        existing_filenames_set = set(existing_midi_filenames.tolist())

        print(f"  Existing snippets: {existing_intervals.shape[0]}")
        print(f"  Existing MIDI files: {len(existing_filenames_set)}")
    elif rebuild_all:
        print("Rebuilding dataset from scratch; ignoring existing snippets.npz (if any).")

    # --------------------------------------------------
    # Decide which files to process
    # --------------------------------------------------
    if rebuild_all or existing_midi_filenames is None:
        # process ALL files
        files_to_process = midi_files
        base_song_idx = 0
        existing_midi_filenames_list = []
    else:
        # Only process files not already in midi_filenames
        files_to_process = [p for p in midi_files if p.name not in existing_filenames_set]
        base_song_idx = len(existing_midi_filenames)
        existing_midi_filenames_list = existing_midi_filenames.tolist()

    print(f"Found {len(midi_files)} total MIDI files.")
    print(f"{len(files_to_process)} file(s) to process this run.")

    if not files_to_process:
        print("No new MIDI files found. Dataset unchanged.")
        return

    # We'll also build up the new filenames in order
    new_filenames_list = []

    # --------------------------------------------------
    # Process selected MIDI files
    # --------------------------------------------------
    for local_idx, midi_path in enumerate(files_to_process):
        # song_id for this file within the full dataset:
        # existing songs first, then new ones in order
        song_id = base_song_idx + local_idx

        print(f"Processing {midi_path.name} "
              f"({local_idx + 1}/{len(files_to_process)}), assigned song_id={song_id}")

        try:
            score = load_midi(midi_path)
        except Exception as e:
            print(f"  Failed to load {midi_path.name}: {e}")
            continue

        melody = pick_melody_part(score)
        if melody is None:
            print(f"  Skipping {midi_path.name}: no usable melodic part found.")
            continue
        melody = detect_key_and_transpose(melody)
        melody = sanitize_melody_instrument(melody)

        # Save transposed melody
        out_midi_path = TRANSPOSED_DIR / f"transposed_{midi_path.stem}.mid"
        try:
            melody.write("midi", fp=str(out_midi_path))
            print(f"  Saved transposed: {out_midi_path}")
        except Exception as e:
            print(f"  [warn] could not save transposed MIDI for {midi_path.name}: {e}")

        pitch_dur_seq = extract_pitch_duration_sequence(melody)

        if len(pitch_dur_seq) < SNIPPET_LENGTH:
            print(f"  Skipping {midi_path.name}: too few notes ({len(pitch_dur_seq)})")
            continue

        intervals, durations = convert_to_intervals_and_durations(pitch_dur_seq)

        i_snips, d_snips = make_snippets(intervals, durations, SNIPPET_LENGTH)

        if i_snips.shape[0] == 0:
            print(f"  No snippets extracted from {midi_path.name}")
            continue

        # --- Genre per snippet (from folder name) ---
        genre = midi_path.parent.name.lower()
        num_snips = i_snips.shape[0]
        all_genre_snips.append(np.array([genre] * num_snips, dtype=object))

        all_interval_snips.append(i_snips)
        all_duration_snips.append(d_snips)
        all_song_ids.append(np.full(num_snips, song_id, dtype=np.int32))

        new_filenames_list.append(midi_path.name)

    # --------------------------------------------------
    # If nothing new was processed successfully
    # --------------------------------------------------
    if not all_interval_snips:
        print("No snippets extracted from selected MIDI files. Dataset unchanged.")
        return

    # Stack new snippets
    new_intervals = np.vstack(all_interval_snips)
    new_durations = np.vstack(all_duration_snips)
    new_song_ids = np.concatenate(all_song_ids)
    new_genres = np.concatenate(all_genre_snips)  # <-- all new snippet genres

    # --------------------------------------------------
    # Merge old + new or just use new (if rebuild_all or no existing)
    # --------------------------------------------------
    if not rebuild_all and existing_intervals is not None:
        intervals_arr = np.vstack([existing_intervals, new_intervals])
        durations_arr = np.vstack([existing_durations, new_durations])
        song_ids_arr = np.concatenate([existing_song_ids, new_song_ids])

        # Merge genres (old + new)
        genres_arr = np.concatenate([existing_genres, new_genres])

        # Append new filenames after existing, in consistent order
        midi_filenames_arr = np.array(existing_midi_filenames_list + new_filenames_list)
    else:
        intervals_arr = new_intervals
        durations_arr = new_durations
        song_ids_arr = new_song_ids
        genres_arr = new_genres  # only new
        midi_filenames_arr = np.array(new_filenames_list)

    # --------------------------------------------------
    # Save updated dataset (with genres)
    # --------------------------------------------------
    np.savez_compressed(
        out_path,
        intervals=intervals_arr,
        durations=durations_arr,
        song_ids=song_ids_arr,
        midi_filenames=midi_filenames_arr,
        genres=genres_arr,  # <-- new field
    )

    print(f"Saved updated snippets to {out_path}")
    print(f"Total snippets: {intervals_arr.shape[0]}")
    print(f"Total MIDI files represented: {len(midi_filenames_arr)}")


process_all_midis_new(rebuild_all = True)

In [52]:
def process_all_midis_new_with_sec(rebuild_all: bool = False, counts: int = -1):
    """
    Preprocess MIDI files into fixed-length snippets.

    This version:
      - Uses original melody's quarter-length timing (musical time).
      - For each snippet, stores:
          * genre
          * snippet_start_idx, snippet_end_idx (note indices in melody)
          * snippet_start_qs, snippet_end_qs (cumulative quarter lengths)
          * snippet_start_secs, snippet_end_secs (seconds scaled to real song length)
          * human-readable snippet_labels:
              "<genre>_<stem>_idx000016_to000048_t0012.34s_to0020.56s"
      - Writes everything into one snippets.npz.
      - Can append to existing dataset (unless rebuild_all=True).
      - Can limit files processed via `counts`.
    """

    out_path = PROCESSED_DIR / "snippets.npz"

    # --------------------------------------------------
    # Collect MIDI files
    # --------------------------------------------------
    midi_files = sorted(
        list(RAW_MIDI_DIR.rglob("*.mid")) +
        list(RAW_MIDI_DIR.rglob("*.midi"))
    )

    if not midi_files:
        print("No MIDI files found.")
        return

    # Per-snippet accumulators
    all_interval_snips = []
    all_duration_snips = []
    all_song_ids = []
    all_genres = []
    all_snippet_labels = []
    all_start_indices = []
    all_end_indices = []
    all_start_qs = []
    all_end_qs = []
    all_start_secs = []
    all_end_secs = []

    # Existing data (for appending)
    existing_intervals = None
    existing_durations = None
    existing_song_ids = None
    existing_midi_filenames = None
    existing_genres = None
    existing_snippet_labels = None
    existing_start_indices = None
    existing_end_indices = None
    existing_start_qs = None
    existing_end_qs = None
    existing_start_secs = None
    existing_end_secs = None
    existing_filenames_set = set()

    # --------------------------------------------------
    # Load existing dataset if present
    # --------------------------------------------------
    if not rebuild_all and out_path.exists():
        print(f"Loading existing dataset from {out_path}")
        data = np.load(out_path, allow_pickle=True)

        existing_intervals = data["intervals"]
        existing_durations = data["durations"]
        existing_song_ids = data["song_ids"]
        existing_midi_filenames = data["midi_filenames"]

        existing_genres = data["genres"] if "genres" in data.files else None
        existing_snippet_labels = data["snippet_labels"] if "snippet_labels" in data.files else None

        existing_start_indices = data["snippet_start_indices"] if "snippet_start_indices" in data.files else None
        existing_end_indices = data["snippet_end_indices"] if "snippet_end_indices" in data.files else None
        existing_start_qs = data["snippet_start_qs"] if "snippet_start_qs" in data.files else None
        existing_end_qs = data["snippet_end_qs"] if "snippet_end_qs" in data.files else None
        existing_start_secs = data["snippet_start_secs"] if "snippet_start_secs" in data.files else None
        existing_end_secs = data["snippet_end_secs"] if "snippet_end_secs" in data.files else None

        existing_filenames_set = set(existing_midi_filenames.tolist())

        print(f"  Existing snippets: {existing_intervals.shape[0]}")
        print(f"  Existing MIDI songs: {len(existing_filenames_set)}")

    elif rebuild_all:
        print("Rebuilding dataset from scratch (ignoring existing snippets.npz).")

    # --------------------------------------------------
    # Decide which files to process
    # --------------------------------------------------
    if rebuild_all or existing_midi_filenames is None:
        files_to_process = midi_files
        base_song_idx = 0
        existing_filename_list = []
    else:
        files_to_process = [p for p in midi_files if p.name not in existing_filenames_set]
        base_song_idx = len(existing_midi_filenames)
        existing_filename_list = existing_midi_filenames.tolist()

    print(f"Found {len(midi_files)} total MIDI files.")
    print(f"{len(files_to_process)} new file(s) before limiting by counts.")

    if counts == -1:
        print(f"Processing all {len(files_to_process)} file(s)")
    elif counts > 0:
        files_to_process = files_to_process[:counts]
        print(f"Processing only first {len(files_to_process)} file(s) due to counts={counts}.")
    else:
        print("counts <= 0 → no processing performed.")
        return

    new_midi_filenames_list = []

    # ---------------------------------------------------------
    # Helper: extract original pitch + duration_q from melody
    # ---------------------------------------------------------
    def extract_pitch_dur_from_melody(melody_part):
        """
        Returns:
            pitches:  list[int]    (MIDI numbers)
            durs_q:   list[float]  (quarterLength)
        This is purely in musical time.
        """
        pitches = []
        durs_q = []
        for elem in melody_part.recurse().notesAndRests:
            if isinstance(elem, note.Rest):
                continue
            if isinstance(elem, note.Note):
                midi_pitch = elem.pitch.midi
                dur_q = float(elem.quarterLength)
                pitches.append(midi_pitch)
                durs_q.append(dur_q)
            elif isinstance(elem, chord.Chord):
                midi_pitch = max(n.pitch.midi for n in elem.notes)
                dur_q = float(elem.quarterLength)
                pitches.append(midi_pitch)
                durs_q.append(dur_q)
        return pitches, durs_q

    # ---------------------------------------------------------
    # Helper: make snippets + musical timing info
    # ---------------------------------------------------------
    def make_snippets_with_musical_time(intervals, durations_steps, original_durs_q, snippet_length=SNIPPET_LENGTH):
        """
        intervals:         list[int]         (pitch intervals)
        durations_steps:   list[int]         (quantized duration steps)
        original_durs_q:   list[float]       (original quarter-length durations, same indexing)
        Returns:
            interval_snips: (num_snips, L)
            duration_snips: (num_snips, L)
            start_indices:  (num_snips,)
            end_indices:    (num_snips,)
            start_qs:       (num_snips,)
            end_qs:         (num_snips,)
        """
        n = len(intervals)
        if n < snippet_length:
            return (np.empty((0, snippet_length), dtype=np.int32),
                    np.empty((0, snippet_length), dtype=np.int32),
                    np.empty((0,), dtype=np.int32),
                    np.empty((0,), dtype=np.int32),
                    np.empty((0,), dtype=float),
                    np.empty((0,), dtype=float))

        stride = snippet_length // 2

        # cumulative quarter lengths for original melody
        cum_q = np.cumsum([0.0] + list(original_durs_q))  # length n+1

        interval_snips = []
        duration_snips = []
        start_indices = []
        end_indices = []
        start_qs = []
        end_qs = []

        for start in range(0, n - snippet_length + 1, stride):
            end = start + snippet_length

            interval_snips.append(intervals[start:end])
            duration_snips.append(durations_steps[start:end])

            start_indices.append(start)
            end_indices.append(end)

            start_qs.append(cum_q[start])
            end_qs.append(cum_q[end])

        return (np.array(interval_snips, dtype=np.int32),
                np.array(duration_snips, dtype=np.int32),
                np.array(start_indices, dtype=np.int32),
                np.array(end_indices, dtype=np.int32),
                np.array(start_qs, dtype=float),
                np.array(end_qs, dtype=float))

    # ===========================================
    # PROCESS EACH FILE
    # ===========================================
    for file_idx, midi_path in enumerate(files_to_process):
        song_id = base_song_idx + file_idx
        genre = midi_path.parent.name.lower()

        print(f"\nProcessing {midi_path.name} ({file_idx + 1}/{len(files_to_process)}), genre={genre}")

        # Load score
        try:
            score = load_midi(midi_path)
        except Exception as e:
            print(f"  [error] failed to load {midi_path.name}: {e}")
            continue

        melody = pick_melody_part(score)
        if melody is None:
            print("  [warn] no melody part found; skipping.")
            continue

        # ---- Extract original (pitch, duration_q) in musical time ----
        pitches_orig, durs_orig_q = extract_pitch_dur_from_melody(melody)
        print(f"    [debug] original melodic events = {len(pitches_orig)}")

        if len(pitches_orig) < SNIPPET_LENGTH:
            print("  Skipping: too few melodic events.")
            continue

        # ---- Transpose + sanitize for saving / normalization ----
        melody_transposed = detect_key_and_transpose(melody)
        melody_transposed = sanitize_melody_instrument(melody_transposed)

        out_midi_path = TRANSPOSED_DIR / f"transposed_{midi_path.stem}.mid"
        try:
            melody_transposed.write("midi", fp=str(out_midi_path))
            print(f"  Saved transposed: {out_midi_path}")
        except Exception as e:
            print(f"  [warn] could not save transposed MIDI for {midi_path.name}: {e}")

        # ---- Convert to intervals + integer duration steps (model input) ----
        pitch_dur_seq = list(zip(pitches_orig, durs_orig_q))
        intervals, durations_steps = convert_to_intervals_and_durations(pitch_dur_seq)

        # ---- Make snippets with musical timing info ----
        (i_snips, d_snips,
         start_indices, end_indices,
         start_qs, end_qs) = make_snippets_with_musical_time(
            intervals, durations_steps, durs_orig_q
        )

        if i_snips.shape[0] == 0:
            print("  [info] No snippets extracted after slicing.")
            continue

        # ---- Approximate seconds from quarter lengths, scaled to real song length ----

        # 1) total melodic quarter length (from our cumulative Q values)
        total_q_melody = float(end_qs.max()) if end_qs.size > 0 else 0.0

        # 2) total song duration in seconds from score.secondsMap
        sm = score.secondsMap if not callable(score.secondsMap) else score.secondsMap()
        total_secs = 0.0
        for item in sm:
            off = float(item.get("offsetSeconds", 0.0))
            end_t = float(item.get("endTimeSeconds", off))
            if end_t > total_secs:
                total_secs = end_t

        # 3) derive effective sec-per-quarter so that melody span maps into real song length
        if total_q_melody > 0 and total_secs > 0:
            sec_per_q_eff = total_secs / total_q_melody
        else:
            # fallback if something weird happens
            tempos = score.flatten().getElementsByClass("MetronomeMark")
            if tempos:
                bpm = tempos[0].number
                if bpm and bpm > 0:
                    sec_per_q_eff = 60.0 / bpm
                else:
                    sec_per_q_eff = 0.5
            else:
                sec_per_q_eff = 0.5  # default 120 bpm

        start_secs = start_qs * sec_per_q_eff
        end_secs = end_qs * sec_per_q_eff

        # ---- Build labels (NO q ranges, WITH seconds) ----
        labels = []
        for s_idx, e_idx, ss, es in zip(start_indices, end_indices, start_secs, end_secs):
            label = (
                f"{genre}_{midi_path.stem}_"
                f"idx{s_idx:06d}_to{e_idx:06d}_"
                f"t{ss:07.2f}s_to{es:07.2f}s"
            )
            labels.append(label)

        # ---- Accumulate per-snippet arrays ----
        num_snips = i_snips.shape[0]
        all_interval_snips.append(i_snips)
        all_duration_snips.append(d_snips)
        all_song_ids.append(np.full(num_snips, song_id, dtype=np.int32))
        all_genres.append(np.array([genre] * num_snips, dtype=object))
        all_snippet_labels.append(np.array(labels, dtype=object))
        all_start_indices.append(start_indices)
        all_end_indices.append(end_indices)
        all_start_qs.append(start_qs)
        all_end_qs.append(end_qs)
        all_start_secs.append(start_secs)
        all_end_secs.append(end_secs)

        new_midi_filenames_list.append(midi_path.name)

    # ================================
    # MERGE + SAVE
    # ================================
    if not all_interval_snips:
        print("No snippets extracted from selected MIDI files. Dataset unchanged.")
        return

    new_intervals = np.vstack(all_interval_snips)
    new_durations = np.vstack(all_duration_snips)
    new_ids = np.concatenate(all_song_ids)
    new_genres = np.concatenate(all_genres)
    new_labels = np.concatenate(all_snippet_labels)
    new_start_indices = np.concatenate(all_start_indices)
    new_end_indices = np.concatenate(all_end_indices)
    new_start_qs = np.concatenate(all_start_qs)
    new_end_qs = np.concatenate(all_end_qs)
    new_start_secs = np.concatenate(all_start_secs)
    new_end_secs = np.concatenate(all_end_secs)

    if existing_intervals is not None and not rebuild_all:
        intervals_arr = np.vstack([existing_intervals, new_intervals])
        durations_arr = np.vstack([existing_durations, new_durations])
        song_ids_arr = np.concatenate([existing_song_ids, new_ids])

        if existing_genres is not None:
            genres_arr = np.concatenate([existing_genres, new_genres])
        else:
            genres_arr = np.concatenate(
                [np.array(["unknown"] * existing_intervals.shape[0], dtype=object), new_genres]
            )

        if existing_snippet_labels is not None:
            labels_arr = np.concatenate([existing_snippet_labels, new_labels])
        else:
            labels_arr = new_labels

        if existing_start_indices is not None:
            start_idx_arr = np.concatenate([existing_start_indices, new_start_indices])
        else:
            start_idx_arr = new_start_indices

        if existing_end_indices is not None:
            end_idx_arr = np.concatenate([existing_end_indices, new_end_indices])
        else:
            end_idx_arr = new_end_indices

        if existing_start_qs is not None:
            start_q_arr = np.concatenate([existing_start_qs, new_start_qs])
        else:
            start_q_arr = new_start_qs

        if existing_end_qs is not None:
            end_q_arr = np.concatenate([existing_end_qs, new_end_qs])
        else:
            end_q_arr = new_end_qs

        if existing_start_secs is not None:
            start_sec_arr = np.concatenate([existing_start_secs, new_start_secs])
        else:
            start_sec_arr = new_start_secs

        if existing_end_secs is not None:
            end_sec_arr = np.concatenate([existing_end_secs, new_end_secs])
        else:
            end_sec_arr = new_end_secs

        midi_filenames_arr = np.array(existing_filename_list + new_midi_filenames_list)

    else:
        intervals_arr = new_intervals
        durations_arr = new_durations
        song_ids_arr = new_ids
        genres_arr = new_genres
        labels_arr = new_labels
        start_idx_arr = new_start_indices
        end_idx_arr = new_end_indices
        start_q_arr = new_start_qs
        end_q_arr = new_end_qs
        start_sec_arr = new_start_secs
        end_sec_arr = new_end_secs
        midi_filenames_arr = np.array(new_midi_filenames_list)

    # Save final dataset
    np.savez_compressed(
        out_path,
        intervals=intervals_arr,
        durations=durations_arr,
        song_ids=song_ids_arr,
        midi_filenames=midi_filenames_arr,
        genres=genres_arr,
        snippet_labels=labels_arr,
        snippet_start_indices=start_idx_arr,
        snippet_end_indices=end_idx_arr,
        snippet_start_qs=start_q_arr,
        snippet_end_qs=end_q_arr,
        snippet_start_secs=start_sec_arr,
        snippet_end_secs=end_sec_arr,
    )

    print("\n✅ Saved final dataset:", out_path)
    print("Total snippets:", intervals_arr.shape[0])
    print("Total songs:", len(midi_filenames_arr))


process_all_midis_new_with_sec(rebuild_all = True, counts=-1)

In [53]:
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

def _process_single_midi_for_sec(args):
    """
    Worker function for parallel processing.
    Args is a tuple: (midi_path_str, snippet_length)

    Returns a dict of arrays for this MIDI, or None if skipped.
    NOTE: This function does NOT assign song_ids. That is done centrally
    in the main process after collecting all successful results.
    """
    midi_path_str, snippet_length = args
    midi_path = Path(midi_path_str)

    try:
        score = load_midi(midi_path)
    except Exception as e:
        print(f"  [worker] failed to load {midi_path.name}: {e}")
        return None

    melody = pick_melody_part(score)
    if melody is None:
        print(f"  [worker] no melody part in {midi_path.name}; skipping.")
        return None

    # ---- Extract original (pitch, duration_q) in musical time ----
    pitches_orig = []
    durs_orig_q = []
    for elem in melody.recurse().notesAndRests:
        if isinstance(elem, note.Rest):
            continue
        if isinstance(elem, note.Note):
            midi_pitch = elem.pitch.midi
            dur_q = float(elem.quarterLength)
        elif isinstance(elem, chord.Chord):
            midi_pitch = max(n.pitch.midi for n in elem.notes)
            dur_q = float(elem.quarterLength)
        else:
            continue
        pitches_orig.append(midi_pitch)
        durs_orig_q.append(dur_q)

    if len(pitches_orig) < snippet_length:
        print(f"  [worker] {midi_path.name}: too few melodic events ({len(pitches_orig)}).")
        return None

    # ---- Transpose + sanitize and save transposed melody (optional but kept) ----
    melody_transposed = detect_key_and_transpose(melody)
    melody_transposed = sanitize_melody_instrument(melody_transposed)

    out_midi_path = TRANSPOSED_DIR / f"transposed_{midi_path.stem}.mid"
    try:
        melody_transposed.write("midi", fp=str(out_midi_path))
        print(f"  [worker] Saved transposed: {out_midi_path}")
    except Exception as e:
        print(f"  [worker] could not save transposed MIDI for {midi_path.name}: {e}")

    # ---- Convert to intervals + integer duration steps (model input) ----
    pitch_dur_seq = list(zip(pitches_orig, durs_orig_q))
    intervals, durations_steps = convert_to_intervals_and_durations(pitch_dur_seq)

    # ---- Make snippets with musical timing info ----
    n = len(intervals)
    stride = snippet_length // 2
    cum_q = np.cumsum([0.0] + list(durs_orig_q))

    interval_snips = []
    duration_snips = []
    start_indices = []
    end_indices = []
    start_qs = []
    end_qs = []

    for start in range(0, n - snippet_length + 1, stride):
        end = start + snippet_length
        interval_snips.append(intervals[start:end])
        duration_snips.append(durations_steps[start:end])
        start_indices.append(start)
        end_indices.append(end)
        start_qs.append(cum_q[start])
        end_qs.append(cum_q[end])

    if not interval_snips:
        print(f"  [worker] {midi_path.name}: no snippets after slicing.")
        return None

    interval_snips = np.array(interval_snips, dtype=np.int32)
    duration_snips = np.array(duration_snips, dtype=np.int32)
    start_indices = np.array(start_indices, dtype=np.int32)
    end_indices = np.array(end_indices, dtype=np.int32)
    start_qs = np.array(start_qs, dtype=float)
    end_qs = np.array(end_qs, dtype=float)

    # ---- Approximate seconds from quarter lengths, scaled to real song length ----
    total_q_melody = float(end_qs.max()) if end_qs.size > 0 else 0.0

    sm = score.secondsMap if not callable(score.secondsMap) else score.secondsMap()
    total_secs = 0.0
    for item in sm:
        off = float(item.get("offsetSeconds", 0.0))
        end_t = float(item.get("endTimeSeconds", off))
        if end_t > total_secs:
            total_secs = end_t

    if total_q_melody > 0 and total_secs > 0:
        sec_per_q_eff = total_secs / total_q_melody
    else:
        tempos = score.flatten().getElementsByClass("MetronomeMark")
        if tempos:
            bpm = tempos[0].number
            if bpm and bpm > 0:
                sec_per_q_eff = 60.0 / bpm
            else:
                sec_per_q_eff = 0.5
        else:
            sec_per_q_eff = 0.5

    start_secs = start_qs * sec_per_q_eff
    end_secs = end_qs * sec_per_q_eff

    # ---- Genre & labels ----
    genre = midi_path.parent.name.lower()
    labels = [
        f"{genre}_{midi_path.stem}_idx{s_idx:06d}_to{e_idx:06d}_t{ss:07.2f}s_to{es:07.2f}s"
        for s_idx, e_idx, ss, es in zip(start_indices, end_indices, start_secs, end_secs)
    ]
    genres_arr = np.array([genre] * interval_snips.shape[0], dtype=object)
    labels_arr = np.array(labels, dtype=object)

    return {
        "intervals": interval_snips,
        "durations": duration_snips,
        "genres": genres_arr,
        "labels": labels_arr,
        "start_indices": start_indices,
        "end_indices": end_indices,
        "start_qs": start_qs,
        "end_qs": end_qs,
        "start_secs": start_secs,
        "end_secs": end_secs,
        "filename": midi_path.name,
    }


In [54]:
def process_all_midis_new_with_sec_parallel(
    rebuild_all: bool = False, counts: int = -1, num_workers: int = 4
):
    """
    Parallel version of process_all_midis_new_with_sec with FIXED song_id ↔ filename alignment.

    - Worker does NOT assign song_ids.
    - Main process enumerates successful results and assigns song_ids
      sequentially, so song_id == index into midi_filenames.
    """

    out_path = PROCESSED_DIR / "snippets.npz"

    # --------------------------------------------------
    # Collect MIDI files
    # --------------------------------------------------
    midi_files = sorted(
        list(RAW_MIDI_DIR.rglob("*.mid")) +
        list(RAW_MIDI_DIR.rglob("*.midi"))
    )

    if not midi_files:
        print("No MIDI files found.")
        return

    # Per-snippet accumulators
    all_interval_snips = []
    all_duration_snips = []
    all_song_ids = []
    all_genres = []
    all_snippet_labels = []
    all_start_indices = []
    all_end_indices = []
    all_start_qs = []
    all_end_qs = []
    all_start_secs = []
    all_end_secs = []

    # Existing data (for appending)
    existing_intervals = None
    existing_durations = None
    existing_song_ids = None
    existing_midi_filenames = None
    existing_genres = None
    existing_snippet_labels = None
    existing_start_indices = None
    existing_end_indices = None
    existing_start_qs = None
    existing_end_qs = None
    existing_start_secs = None
    existing_end_secs = None
    existing_filenames_set = set()

    # --------------------------------------------------
    # Load existing dataset if present
    # --------------------------------------------------
    if not rebuild_all and out_path.exists():
        print(f"Loading existing dataset from {out_path}")
        data = np.load(out_path, allow_pickle=True)

        existing_intervals = data["intervals"]
        existing_durations = data["durations"]
        existing_song_ids = data["song_ids"]
        existing_midi_filenames = data["midi_filenames"]

        existing_genres = data.get("genres", None)
        existing_snippet_labels = data.get("snippet_labels", None)
        existing_start_indices = data.get("snippet_start_indices", None)
        existing_end_indices = data.get("snippet_end_indices", None)
        existing_start_qs = data.get("snippet_start_qs", None)
        existing_end_qs = data.get("snippet_end_qs", None)
        existing_start_secs = data.get("snippet_start_secs", None)
        existing_end_secs = data.get("snippet_end_secs", None)

        existing_filenames_set = set(existing_midi_filenames.tolist())

        print(f"  Existing snippets: {existing_intervals.shape[0]}")
        print(f"  Existing MIDI songs: {len(existing_filenames_set)}")

    elif rebuild_all:
        print("Rebuilding dataset from scratch (ignoring existing snippets.npz).")

    # --------------------------------------------------
    # Decide which files to process
    # --------------------------------------------------
    if rebuild_all or existing_midi_filenames is None:
        files_to_process = midi_files
        base_song_idx = 0
        existing_filename_list = []
    else:
        files_to_process = [p for p in midi_files if p.name not in existing_filenames_set]
        base_song_idx = len(existing_midi_filenames)
        existing_filename_list = existing_midi_filenames.tolist()

    print(f"Found {len(midi_files)} total MIDI files.")
    print(f"{len(files_to_process)} new file(s) before limiting by counts.")

    if counts == -1:
        print(f"Processing all {len(files_to_process)} file(s)")
    elif counts > 0:
        files_to_process = files_to_process[:counts]
        print(f"Processing only first {len(files_to_process)} file(s) due to counts={counts}.")
    else:
        print("counts <= 0 → no processing performed.")
        return

    # --------------------------------------------------
    # PARALLEL PROCESSING
    # --------------------------------------------------
    worker_args = [
        (str(midi_path), SNIPPET_LENGTH)
        for midi_path in files_to_process
    ]

    print(f"Using {num_workers} worker process(es).")

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(_process_single_midi_for_sec, worker_args))

    # Filter to only successful results (non-None)
    valid_results = [res for res in results if res is not None]

    new_midi_filenames_list = []
    # Assign song_ids sequentially starting from base_song_idx
    for local_song_idx, res in enumerate(valid_results):
        song_id = base_song_idx + local_song_idx
        n_snips = res["intervals"].shape[0]

        all_interval_snips.append(res["intervals"])
        all_duration_snips.append(res["durations"])
        all_song_ids.append(np.full(n_snips, song_id, dtype=np.int32))
        all_genres.append(res["genres"])
        all_snippet_labels.append(res["labels"])
        all_start_indices.append(res["start_indices"])
        all_end_indices.append(res["end_indices"])
        all_start_qs.append(res["start_qs"])
        all_end_qs.append(res["end_qs"])
        all_start_secs.append(res["start_secs"])
        all_end_secs.append(res["end_secs"])

        new_midi_filenames_list.append(res["filename"])

    # ================================
    # MERGE + SAVE
    # ================================
    if not all_interval_snips:
        print("No snippets extracted from selected MIDI files. Dataset unchanged.")
        return

    new_intervals = np.vstack(all_interval_snips)
    new_durations = np.vstack(all_duration_snips)
    new_ids = np.concatenate(all_song_ids)
    new_genres = np.concatenate(all_genres)
    new_labels = np.concatenate(all_snippet_labels)
    new_start_indices = np.concatenate(all_start_indices)
    new_end_indices = np.concatenate(all_end_indices)
    new_start_qs = np.concatenate(all_start_qs)
    new_end_qs = np.concatenate(all_end_qs)
    new_start_secs = np.concatenate(all_start_secs)
    new_end_secs = np.concatenate(all_end_secs)

    if existing_intervals is not None and not rebuild_all:
        intervals_arr = np.vstack([existing_intervals, new_intervals])
        durations_arr = np.vstack([existing_durations, new_durations])
        song_ids_arr = np.concatenate([existing_song_ids, new_ids])

        if existing_genres is not None:
            genres_arr = np.concatenate([existing_genres, new_genres])
        else:
            genres_arr = np.concatenate(
                [np.array(["unknown"] * existing_intervals.shape[0], dtype=object), new_genres]
            )

        if existing_snippet_labels is not None:
            labels_arr = np.concatenate([existing_snippet_labels, new_labels])
        else:
            labels_arr = new_labels

        if existing_start_indices is not None:
            start_idx_arr = np.concatenate([existing_start_indices, new_start_indices])
        else:
            start_idx_arr = new_start_indices

        if existing_end_indices is not None:
            end_idx_arr = np.concatenate([existing_end_indices, new_end_indices])
        else:
            end_idx_arr = new_end_indices

        if existing_start_qs is not None:
            start_q_arr = np.concatenate([existing_start_qs, new_start_qs])
        else:
            start_q_arr = new_start_qs

        if existing_end_qs is not None:
            end_q_arr = np.concatenate([existing_end_qs, new_end_qs])
        else:
            end_q_arr = new_end_qs

        if existing_start_secs is not None:
            start_sec_arr = np.concatenate([existing_start_secs, new_start_secs])
        else:
            start_sec_arr = new_start_secs

        if existing_end_secs is not None:
            end_sec_arr = np.concatenate([existing_end_secs, new_end_secs])
        else:
            end_sec_arr = new_end_secs

        midi_filenames_arr = np.array(existing_filename_list + new_midi_filenames_list)

    else:
        intervals_arr = new_intervals
        durations_arr = new_durations
        song_ids_arr = new_ids
        genres_arr = new_genres
        labels_arr = new_labels
        start_idx_arr = new_start_indices
        end_idx_arr = new_end_indices
        start_q_arr = new_start_qs
        end_q_arr = new_end_qs
        start_sec_arr = new_start_secs
        end_sec_arr = new_end_secs
        midi_filenames_arr = np.array(new_midi_filenames_list)

    # Save final dataset
    np.savez_compressed(
        out_path,
        intervals=intervals_arr,
        durations=durations_arr,
        song_ids=song_ids_arr,
        midi_filenames=midi_filenames_arr,
        genres=genres_arr,
        snippet_labels=labels_arr,
        snippet_start_indices=start_idx_arr,
        snippet_end_indices=end_idx_arr,
        snippet_start_qs=start_q_arr,
        snippet_end_qs=end_q_arr,
        snippet_start_secs=start_sec_arr,
        snippet_end_secs=end_sec_arr,
    )

    print("\n✅ Saved final dataset (parallel, aligned):", out_path)
    print("Total snippets:", intervals_arr.shape[0])
    print("Total songs:", len(midi_filenames_arr))


In [55]:
process_all_midis_new_with_sec_parallel(
    rebuild_all=True,
    counts=-1,        # or some smaller number for testing
    num_workers=8     # set to number of CPU cores or slightly less
)

Rebuilding dataset from scratch (ignoring existing snippets.npz).
Found 928 total MIDI files.
928 new file(s) before limiting by counts.
Processing all 928 file(s)
Using 8 worker process(es).
  [info] pick_melody_part: selected by stats: part_name='[mariah carey] dreamlover', avg_pitch=62.6, n_notes=686
  [info] pick_melody_part: selected by name heuristic: part_name='rhythm guitar', avg_pitch=78.0, n_notes=1
  [worker] All the Young Dudes.mid: too few melodic events (1).
  [info] pick_melody_part: selected by name heuristic: part_name='saw wave', avg_pitch=68.5, n_notes=231
  [info] pick_melody_part: selected by stats: part_name='untitled', avg_pitch=57.8, n_notes=1196
  [info] pick_melody_part: selected by stats: part_name='polysynth', avg_pitch=72.5, n_notes=1152
  [info] pick_melody_part: selected by name heuristic: part_name='mama mia', avg_pitch=66.8, n_notes=953
  [worker] Saved transposed: ../data/processed/transposed_midis/transposed_Axel_F_1.mid
  [info] pick_melody_part: sel

In [56]:
import numpy as np

data_old= np.load("../data/processed/snippets_old.npz", allow_pickle=True)
data= np.load("../data/processed/snippets.npz", allow_pickle=True)

In [57]:
print(data.files)

['intervals', 'durations', 'song_ids', 'midi_filenames', 'genres', 'snippet_labels', 'snippet_start_indices', 'snippet_end_indices', 'snippet_start_qs', 'snippet_end_qs', 'snippet_start_secs', 'snippet_end_secs']


In [58]:
print(data_old.files)

['intervals', 'durations', 'song_ids', 'midi_filenames']


In [59]:
data('midi_filenames','snippet_labels')

TypeError: 'NpzFile' object is not callable

In [60]:
data_old['intervals']

array([[  0,  -3,  -5, ...,   4,   0, -11],
       [  0,   2,  -4, ...,  -5,   0,   3],
       [  4,   3,   4, ...,   2,  -7,   7],
       ...,
       [ -1,  -5,   1, ...,   0,   0,   0],
       [  2,  -5,   3, ...,   3,   2,  -5],
       [  0,   3, -10, ...,  -3,  -2,  -2]], dtype=int32)