In [31]:
from pathlib import Path
import numpy as np
import torch

# Locations of the preprocessed archives, relative to the notebook's working directory.
# SNIPPETS_PATH holds the snippet library (read below via keys such as "intervals");
# the two *_EMBED_PATH archives hold precomputed embeddings, one file per model.
SNIPPETS_PATH = Path("../data/processed/snippets.npz")
AUTOENC_EMBED_PATH = Path("../data/processed/autoencoder_embeddings.npz")
CONTRASTIVE_EMBED_PATH = Path("../data/processed/contrastive_embeddings.npz")

def load_embedding_index(embed_npz_path=AUTOENC_EMBED_PATH, snippets_npz_path=SNIPPETS_PATH):
    """
    Load precomputed snippet embeddings plus the matching snippet metadata.

    Args:
        embed_npz_path: ``.npz`` archive with ``embeddings``, ``song_ids``,
            ``min_interval``, ``vocab_size`` and optionally ``midi_filenames``.
        snippets_npz_path: ``.npz`` archive with ``intervals`` and optionally
            ``genres``, ``snippet_labels``, ``snippet_start_secs``,
            ``snippet_end_secs``.

    Returns:
        dict with keys:
          - embeddings: (N, D)
          - song_ids: (N,)
          - midi_filenames: per-song, or None if absent
          - min_interval, vocab_size: ints (needed for embedding new MIDI)
          - genres, labels, start_secs, end_secs: arrays, or None if absent
          - snippet_length: second dimension of ``intervals``

    Raises:
        AssertionError: if embeddings, intervals and song_ids differ in length.
        KeyError: if a required key is missing from either archive.
    """
    def _optional(archive, key):
        # Membership test on `.files` works on every NumPy version; `NpzFile.get`
        # only exists on releases where NpzFile is a Mapping.
        return archive[key] if key in archive.files else None

    # NOTE(security): allow_pickle=True can execute arbitrary code while loading;
    # only point these paths at archives you produced yourself.
    # The `with` blocks close the underlying zip handles; every array is read
    # into memory before the archive is closed.
    with np.load(embed_npz_path, allow_pickle=True) as emb_data:
        embeddings = emb_data["embeddings"]       # (N, D)
        song_ids = emb_data["song_ids"]           # (N,)
        midi_filenames = _optional(emb_data, "midi_filenames")
        min_interval = int(emb_data["min_interval"])
        vocab_size = int(emb_data["vocab_size"])

    with np.load(snippets_npz_path, allow_pickle=True) as snip_data:
        intervals = snip_data["intervals"]        # (N, L)
        genres = _optional(snip_data, "genres")
        labels = _optional(snip_data, "snippet_labels")
        start_secs = _optional(snip_data, "snippet_start_secs")
        end_secs = _optional(snip_data, "snippet_end_secs")

    snippet_length = intervals.shape[1]

    # sanity: both archives must describe the same N snippets
    assert embeddings.shape[0] == intervals.shape[0] == song_ids.shape[0], \
        "Embeddings, intervals, and song_ids must have same length."

    return {
        "embeddings": embeddings,
        "song_ids": song_ids,
        "midi_filenames": midi_filenames,
        "min_interval": min_interval,
        "vocab_size": vocab_size,
        "genres": genres,
        "labels": labels,
        "start_secs": start_secs,
        "end_secs": end_secs,
        "snippet_length": snippet_length,
    }


def cosine_sim(a, b):
    """
    Cosine similarity between every row of ``a`` and the vector ``b``.

    a: (N, D)
    b: (D,)
    returns: (N,) cosine similarities

    A zero-norm row (or a zero-norm ``b``) yields similarity 0 instead of the
    NaN produced by dividing by zero.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    row_norms = np.linalg.norm(a, axis=1, keepdims=True)
    # Dividing a zero vector by 1 keeps it zero, so its dot product is 0.
    row_norms[row_norms == 0] = 1

    query_norm = np.linalg.norm(b)
    if query_norm == 0:
        query_norm = 1

    return (a / row_norms) @ (b / query_norm)


In [32]:
def demo_retrieval_by_snippet_index(
    query_idx: int,
    top_k: int = 10,
    embed_npz_path=AUTOENC_EMBED_PATH,   # or CONTRASTIVE_EMBED_PATH
    snippets_npz_path=SNIPPETS_PATH,
):
    """
    Simple demo: pick an existing snippet by index and print its nearest
    neighbors using the precomputed embeddings. No model needed.

    Args:
        query_idx: index of the query snippet in the embedding archive.
        top_k: number of neighbors to print; clamped to the number of other
            snippets so the query itself is never listed.
        embed_npz_path: embedding archive to search.
        snippets_npz_path: snippet archive providing genres/labels/timestamps.

    Raises:
        ValueError: if ``query_idx`` is outside ``[0, N-1]``.
    """
    index = load_embedding_index(embed_npz_path, snippets_npz_path)

    embeddings = index["embeddings"]
    song_ids   = index["song_ids"]
    midi_files = index["midi_filenames"]
    genres     = index["genres"]
    labels     = index["labels"]
    start_secs = index["start_secs"]
    end_secs   = index["end_secs"]
    have_times = start_secs is not None and end_secs is not None

    N = embeddings.shape[0]
    if query_idx < 0 or query_idx >= N:
        raise ValueError(f"query_idx {query_idx} out of range [0, {N-1}]")

    def _describe(i):
        """Collect (song_id, file name, genre, label) for snippet ``i``."""
        sid = int(song_ids[i])
        fname = midi_files[sid] if midi_files is not None else "N/A"
        genre = genres[i] if genres is not None else "unknown"
        label = labels[i] if labels is not None else f"snippet_{i}"
        return sid, fname, genre, label

    sims = cosine_sim(embeddings, embeddings[query_idx])
    sims[query_idx] = -np.inf  # exclude self
    sorted_idx = np.argsort(-sims)
    # Drop the query outright; otherwise top_k >= N would list it with sim=-inf.
    sorted_idx = sorted_idx[sorted_idx != query_idx]
    top_k = max(0, min(top_k, sorted_idx.shape[0]))

    print("\n=== Retrieval demo (existing snippet) ===")
    print(f"Query snippet index: {query_idx}")
    sid_q, fname_q, genre_q, label_q = _describe(query_idx)

    print(f"  query song_id={sid_q}, genre={genre_q}, file={fname_q}")
    print(f"  label={label_q}")
    if have_times:
        print(f"  approx time: {start_secs[query_idx]:.2f}s → {end_secs[query_idx]:.2f}s")

    print(f"\nTop {top_k} neighbors:")
    for rank, idx in enumerate(sorted_idx[:top_k], start=1):
        sid, fname, genre, label = _describe(idx)
        print(f"#{rank:02d}  sim={sims[idx]:.3f}")
        print(f"     idx={idx}, song_id={sid}, genre={genre}, file={fname}")
        print(f"     label={label}")
        if have_times:
            print(f"     approx time: {start_secs[idx]:.2f}s → {end_secs[idx]:.2f}s")
        print()


In [33]:
# Demo run: print the 10 nearest neighbours of library snippet 201 in the
# autoencoder embedding space (no model needed, embeddings are precomputed).
demo_retrieval_by_snippet_index(
    query_idx=201,
    top_k=10,
    embed_npz_path=AUTOENC_EMBED_PATH,  # or CONTRASTIVE_EMBED_PATH
)



=== Retrieval demo (existing snippet) ===
Query snippet index: 201
  query song_id=4, genre=classic, file=Axel_F_1.mid
  label=classic_Axel_F_1_idx000000_to000032_t0000.00s_to0027.26s
  approx time: 0.00s → 27.26s

Top 10 neighbors:
#01  sim=1.000
     idx=204, song_id=4, genre=classic, file=Axel_F_1.mid
     label=classic_Axel_F_1_idx000048_to000080_t0040.90s_to0068.16s
     approx time: 40.90s → 68.16s

#02  sim=0.869
     idx=28179, song_id=749, genre=rnb, file=Feel_So_High.mid
     label=rnb_Feel_So_High_idx000192_to000224_t0088.79s_to0104.59s
     approx time: 88.79s → 104.59s

#03  sim=0.835
     idx=26177, song_id=704, genre=pop, file=When the Going Gets Tough.mid
     label=pop_When the Going Gets Tough_idx000016_to000048_t0008.73s_to0028.51s
     approx time: 8.73s → 28.51s

#04  sim=0.834
     idx=207, song_id=4, genre=classic, file=Axel_F_1.mid
     label=classic_Axel_F_1_idx000096_to000128_t0081.79s_to0121.64s
     approx time: 81.79s → 121.64s

#05  sim=0.824
     idx=248

In [34]:
# -----------------------------
# Helper functions
# -----------------------------
from typing import List, Tuple
from music21 import converter, instrument, note, chord, stream, key, interval, pitch, tempo

# Number of consecutive melody notes per snippet (tunable); used as the default
# window size by the snippet-slicing helpers below.
SNIPPET_LENGTH = 32

# Duration quantization grid: 1 quarter note = 4 steps, so 1 step = sixteenth note.
STEPS_PER_QUARTER = 4


def load_midi(filepath: Path) -> stream.Score:
    """Parse the MIDI file at ``filepath`` with music21 and return the result."""
    midi_location = str(filepath)
    parsed_score = converter.parse(midi_location)
    return parsed_score


def pick_melody_part(score: stream.Score) -> stream.Part | None:
    """
    Choose the part of ``score`` that most likely carries the melody.

    Procedure:

    1. Parts whose instruments are *all* percussion-like are dropped, as are
       parts without any notes/chords.
    2. If a remaining part's name or instrument names contain one of the
       melody keywords (e.g. 'melody', 'lead', 'right hand'), the keyword
       match with the highest average pitch wins.
    3. Otherwise only parts whose average pitch is at or above the median
       average pitch are kept, and among them the part with the most notes
       is chosen (ties broken by higher average pitch).

    Chords contribute their highest pitch to the statistics.

    Returns the chosen Part, or None if nothing suitable is found.
    """
    def _is_percussion(inst) -> bool:
        # Percussion-like: an UnpitchedPercussion instance, or "percussion" in its best name.
        if isinstance(inst, instrument.UnpitchedPercussion):
            return True
        return "percussion" in (inst.bestName() or "").lower()

    candidates = []

    for part in score.parts:
        part_instruments = list(part.getInstruments())

        # Skip only parts that are *purely* percussion; mixed parts stay in.
        percussion_flags = [_is_percussion(inst) for inst in part_instruments]
        if percussion_flags and all(percussion_flags):
            continue

        # One pitch per event: the note's own pitch, or a chord's top pitch.
        top_pitches = []
        for event in part.recurse().notes:
            if isinstance(event, note.Note):
                top_pitches.append(event.pitch.midi)
            elif isinstance(event, chord.Chord):
                top_pitches.append(max(member.pitch.midi for member in event.notes))

        if not top_pitches:
            continue

        candidates.append({
            "part": part,
            "n_notes": len(top_pitches),
            "avg_pitch": sum(top_pitches) / len(top_pitches),
            "part_name": (part.partName or "").lower(),
            "inst_names": [str(inst.instrumentName or "").lower()
                           for inst in part_instruments],
        })

    if not candidates:
        print("  [warn] no suitable melodic parts; skipping this file.")
        return None

    # Stage 1: name-based shortcut.
    name_keywords = [
        "melody", "lead", "right hand", "rh", "treble", "solo", "violin", "flute", "trumpet", "saw wave"
    ]

    def looks_like_melody(cand):
        haystack = (cand["part_name"] + " " + " ".join(cand["inst_names"])).lower()
        for keyword in name_keywords:
            if keyword in haystack:
                return True
        return False

    named = [cand for cand in candidates if looks_like_melody(cand)]
    if named:
        # Several keyword hits: prefer the highest voice.
        winner = max(named, key=lambda cand: cand["avg_pitch"])
        message = (
            f"  [info] pick_melody_part: selected by name heuristic: "
            f"part_name='{winner['part_name']}', avg_pitch={winner['avg_pitch']:.1f}, n_notes={winner['n_notes']}"
        )
        print(message)
        return winner["part"]

    # Stage 2: keep parts at or above the (upper) median average pitch.
    ordered_avgs = sorted(cand["avg_pitch"] for cand in candidates)
    median_pitch = ordered_avgs[len(ordered_avgs) // 2]

    upper_voices = [cand for cand in candidates if cand["avg_pitch"] >= median_pitch]
    if not upper_voices:
        upper_voices = candidates  # fallback to all

    # Stage 3: most notes wins, higher average pitch breaks ties.
    winner = max(upper_voices, key=lambda cand: (cand["n_notes"], cand["avg_pitch"]))

    message = (
        f"  [info] pick_melody_part: selected by stats: "
        f"part_name='{winner['part_name']}', avg_pitch={winner['avg_pitch']:.1f}, "
        f"n_notes={winner['n_notes']}"
    )
    print(message)

    return winner["part"]




def detect_key_and_transpose(melody: stream.Part) -> stream.Part:
    """
    Normalize the key of ``melody``.

    The key is estimated with music21's ``analyze('key')``. The melody is then
    transposed so that the detected tonic lands on C when the mode is major and
    on A for any other mode. When the analysis raises, a warning is printed and
    the input is returned unchanged.
    """
    try:
        detected_key = melody.analyze('key')
    except Exception as err:
        print("  [warn] key analysis failed, leaving melody untransposed:", err)
        return melody

    # Major keys are moved to C; every other mode is treated like minor and moved to A.
    target_name = 'C' if detected_key.mode == 'major' else 'A'
    target_tonic = pitch.Pitch(target_name)

    # Interval leading from the detected tonic to the target tonic.
    shift = interval.Interval(detected_key.tonic, target_tonic)
    return melody.transpose(shift)


def extract_pitch_duration_sequence(melody: stream.Part) -> List[Tuple[int, float]]:
    """
    Flatten a melody into a list of (midi_pitch, quarter_length_duration) pairs.

    Notes contribute their own pitch, chords contribute their highest pitch,
    and everything else (rests included) is skipped.
    """
    events = []
    for element in melody.recurse().notesAndRests:
        if isinstance(element, note.Note):
            top_midi = element.pitch.midi
        elif isinstance(element, chord.Chord):
            # the top chord tone stands in for the melody note
            top_midi = max(member.pitch.midi for member in element.notes)
        else:
            # rests and anything else carry no pitch; drop them
            continue
        events.append((top_midi, float(element.quarterLength)))
    return events


def convert_to_intervals_and_durations(
    pitch_dur_seq: List[Tuple[int, float]]
) -> Tuple[List[int], List[int]]:
    """
    Turn (pitch, quarter-length) pairs into relative intervals and step counts.

    intervals[0] is 0; intervals[i] = pitch[i] - pitch[i-1] for i >= 1.
    durations[i] = round(quarter_length * STEPS_PER_QUARTER), never below 1.
    An empty input gives two empty lists.
    """
    if not pitch_dur_seq:
        return [], []

    midi_values, quarter_lengths = zip(*pitch_dur_seq)

    # The first note has no predecessor, so its interval is defined as 0.
    intervals = [0] + [
        int(current - previous)
        for previous, current in zip(midi_values, midi_values[1:])
    ]

    # Quantize to the step grid; very short notes still occupy one step.
    durations = [
        max(1, int(round(length * STEPS_PER_QUARTER)))
        for length in quarter_lengths
    ]

    return intervals, durations


def make_snippets(
    intervals: List[int],
    durations: List[int],
    snippet_length: int = SNIPPET_LENGTH
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Cut the two parallel sequences into fixed-length windows.

    Windows start every ``snippet_length // 2`` positions (50% overlap).
    Sequences shorter than ``snippet_length`` produce two empty
    (0, snippet_length) int32 arrays.
    """
    assert len(intervals) == len(durations)
    total = len(intervals)

    if total < snippet_length:
        no_intervals = np.empty((0, snippet_length), dtype=np.int32)
        no_durations = np.empty((0, snippet_length), dtype=np.int32)
        return no_intervals, no_durations

    hop = snippet_length // 2
    window_starts = range(0, total - snippet_length + 1, hop)

    interval_windows = [intervals[begin:begin + snippet_length] for begin in window_starts]
    duration_windows = [durations[begin:begin + snippet_length] for begin in window_starts]

    return (
        np.array(interval_windows, dtype=np.int32),
        np.array(duration_windows, dtype=np.int32),
    )

def make_snippets_with_timestamps(
    intervals: List[int],
    durations: List[int],
    durations_q: List[float],
    snippet_length: int = SNIPPET_LENGTH
) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]:
    """
    Slice sequences into fixed-length snippets (stride = snippet_length // 2).

    Args:
        intervals: pitch intervals, one per note.
        durations: quantized durations, one per note.
        durations_q: durations in quarter lengths, one per note; any sequence
            (list, tuple, ndarray) is accepted.
        snippet_length: notes per snippet.

    Returns:
        (interval_snips, duration_snips, timestamps) where the two arrays are
        int32 with shape (num_snippets, snippet_length) and ``timestamps`` holds
        one (start_q, end_q) pair per snippet: the cumulative quarter-length
        position of the snippet's first note and of the end of its last note,
        each rounded to the nearest integer. Sequences shorter than
        ``snippet_length`` give empty arrays and an empty list.
    """
    assert len(intervals) == len(durations)
    n = len(intervals)
    if n < snippet_length:
        return np.empty((0, snippet_length), dtype=np.int32), np.empty((0, snippet_length), dtype=np.int32), []

    stride = snippet_length // 2
    interval_snips, duration_snips, timestamps = [], [], []

    # list(...) keeps this a concatenation: `[0] + ndarray` would broadcast an
    # element-wise addition instead of prepending 0, and `[0] + tuple` raises.
    cumulative_q = np.cumsum([0] + list(durations_q))  # cumulative time in quarter lengths

    for start in range(0, n - snippet_length + 1, stride):
        end = start + snippet_length
        interval_snips.append(intervals[start:end])
        duration_snips.append(durations[start:end])

        start_q = cumulative_q[start]
        end_q = cumulative_q[end]
        timestamps.append((int(round(start_q)), int(round(end_q))))

    return np.array(interval_snips, dtype=np.int32), np.array(duration_snips, dtype=np.int32), timestamps

from music21 import instrument

def sanitize_melody_instrument(melody_part):
    """
    Remove any existing Instrument metadata and force a clean, non-percussion
    instrument on a non-drum channel.

    The part is modified in place and also returned. Instruments are looked up
    recursively, so they are removed recursively as well; a removal that still
    fails is reported instead of being silently ignored.
    """
    # Remove ALL Instrument objects from this part. recurse=True is required
    # for instruments that live inside nested streams rather than directly in
    # the part; a plain remove() only searches the part's own elements.
    stale_instruments = list(melody_part.recurse().getElementsByClass(instrument.Instrument))
    for inst in stale_instruments:
        try:
            melody_part.remove(inst, recurse=True)
        except Exception as err:
            # Best effort: keep going, but make the leftover instrument visible.
            print(f"  [warn] could not remove instrument {inst!r}: {err}")

    # Set a friendly part name
    melody_part.partName = "Melody"

    # Insert one clean Piano instrument at the beginning
    piano = instrument.Piano()
    piano.midiProgram = 0  # Acoustic Grand
    piano.midiChannel = 0  # Channel 1 (NOT 10/drums)
    melody_part.insert(0, piano)

    return melody_part


In [35]:
from pathlib import Path
from music21 import stream, note, pitch

# Re-declared here (an earlier cell assigns the same value); snippet_to_stream
# divides step counts by this to get back to quarter lengths.
STEPS_PER_QUARTER = 4  # must match your preprocessing

# Where to save snippet MIDIs: query snippets and retrieved library snippets
# go into separate sub-folders.
SNIPPET_MIDI_DIR = Path("../data/demo/snippet_midis")
QUERY_SNIPPET_MIDI_DIR = SNIPPET_MIDI_DIR / "query"
LIB_SNIPPET_MIDI_DIR = SNIPPET_MIDI_DIR / "retrieved"

# Side effect at cell execution: create the output folders if they are missing.
for d in [SNIPPET_MIDI_DIR, QUERY_SNIPPET_MIDI_DIR, LIB_SNIPPET_MIDI_DIR]:
    d.mkdir(parents=True, exist_ok=True)


def snippet_to_stream(interval_seq, duration_seq, base_midi_pitch=60):
    """
    Rebuild a playable music21 Stream from one snippet.

    interval_seq: pitch steps, applied cumulatively starting from base_midi_pitch.
    duration_seq: note lengths in *steps* (STEPS_PER_QUARTER steps per quarter note).
    base_midi_pitch: pitch the running sum starts from (60 = middle C).
    """
    melody_stream = stream.Stream()
    running_midi = base_midi_pitch

    for step, length_in_steps in zip(interval_seq, duration_seq):
        # Each interval is relative to the previous note's pitch.
        running_midi = running_midi + int(step)

        note_pitch = pitch.Pitch()
        note_pitch.midi = running_midi

        new_note = note.Note(note_pitch)
        # steps -> quarterLength
        new_note.quarterLength = float(length_in_steps) / STEPS_PER_QUARTER
        melody_stream.append(new_note)

    return melody_stream


def save_library_snippet_as_midi(
    lib_idx: int,
    intervals_arr: np.ndarray,
    durations_arr: np.ndarray,
    song_ids: np.ndarray,
    out_dir: Path = LIB_SNIPPET_MIDI_DIR,
    base_midi_pitch: int = 60,
):
    """
    Save a library snippet (by global index) to MIDI.

    Args:
        lib_idx: row index into ``intervals_arr`` / ``durations_arr`` / ``song_ids``.
        intervals_arr: (N, L) pitch intervals of all library snippets.
        durations_arr: (N, L) durations in steps of all library snippets.
        song_ids: (N,) song id per snippet; only used for the file name.
        out_dir: target directory (``str`` or ``Path``); created if missing.
        base_midi_pitch: pitch the reconstructed melody starts from.

    Returns:
        Path of the written ``.mid`` file.

    Raises:
        ValueError: if ``lib_idx`` is outside ``[0, N-1]``.
    """
    if lib_idx < 0 or lib_idx >= intervals_arr.shape[0]:
        raise ValueError(f"lib_idx {lib_idx} out of range [0, {intervals_arr.shape[0]-1}]")

    # Accept plain strings and make sure a caller-supplied directory exists;
    # only the default directories are created at notebook start-up.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    interval_seq = intervals_arr[lib_idx]
    duration_seq = durations_arr[lib_idx]

    s = snippet_to_stream(interval_seq, duration_seq, base_midi_pitch=base_midi_pitch)

    out_path = out_dir / f"lib_snippet_{lib_idx}_song{song_ids[lib_idx]}.mid"
    s.write("midi", fp=str(out_path))
    print(f"Saved library snippet {lib_idx} (song_id={song_ids[lib_idx]}) to {out_path}")
    return out_path


def save_query_snippet_as_midi(
    q_idx: int,
    query_intervals: np.ndarray,
    query_durations: np.ndarray,
    out_dir: Path = QUERY_SNIPPET_MIDI_DIR,
    base_midi_pitch: int = 60,
):
    """
    Save a query (input-song) snippet (by local index) to MIDI.

    Args:
        q_idx: row index into ``query_intervals`` / ``query_durations``.
        query_intervals: (Q, L) pitch intervals of the query snippets.
        query_durations: (Q, L) durations in steps of the query snippets.
        out_dir: target directory (``str`` or ``Path``); created if missing.
        base_midi_pitch: pitch the reconstructed melody starts from.

    Returns:
        Path of the written ``.mid`` file.

    Raises:
        ValueError: if ``q_idx`` is outside ``[0, Q-1]``.
    """
    if q_idx < 0 or q_idx >= query_intervals.shape[0]:
        raise ValueError(f"q_idx {q_idx} out of range [0, {query_intervals.shape[0]-1}]")

    # Accept plain strings and make sure a caller-supplied directory exists;
    # only the default directories are created at notebook start-up.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    interval_seq = query_intervals[q_idx]
    duration_seq = query_durations[q_idx]

    s = snippet_to_stream(interval_seq, duration_seq, base_midi_pitch=base_midi_pitch)

    out_path = out_dir / f"query_snippet_{q_idx}.mid"
    s.write("midi", fp=str(out_path))
    print(f"Saved query snippet {q_idx} to {out_path}")
    return out_path


In [36]:
from music21 import tempo

def embed_midi_file_to_snippets(
    midi_path,
    model,
    min_interval: int,
    snippet_length: int,
    device=None,
    model_type: str = "auto",  # "auto" or "contrastive"
):
    """
    Slice one MIDI file into overlapping melody snippets and embed each of them.

    Args:
        midi_path: path of the query MIDI file.
        model: trained encoder. With ``model_type == "auto"`` its ``encode``
            method is called, otherwise the model itself is called.
        min_interval: offset subtracted from raw intervals to obtain token ids.
        snippet_length: notes per snippet; windows overlap by 50%.
        device: device for the input tensor. Defaults to the device the model's
            parameters already live on, so a CPU model keeps working on a
            machine that also has CUDA; falls back to CUDA-if-available when
            the model exposes no parameters.
        model_type: "auto" (autoencoder) or anything else (contrastive encoder).

    Returns:
        embeddings: (Q, D)
        meta: dict with:
          - start_indices, end_indices
          - start_qs, end_qs
          - start_secs, end_secs
          - interval_snips: (Q, L)
          - duration_snips: (Q, L)  [in STEPS, like training]
        When no melodic part is found or the melody has fewer than
        ``snippet_length`` notes, returns ``(np.empty((0, 128)), {})``.

    Note:
        Snippet times are cumulative sums of the extracted note durations
        (rests are not part of the extracted sequence) scaled by the first
        tempo mark only, so they are approximations.
    """
    if device is None:
        # Follow the model instead of guessing: picking CUDA for a model that
        # sits on the CPU would make the forward pass fail on a device mismatch.
        parameters = getattr(model, "parameters", None)
        first_param = next(parameters(), None) if callable(parameters) else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    midi_path = Path(midi_path)
    print(f"Loading query MIDI: {midi_path.name}")
    score = load_midi(midi_path)

    melody = pick_melody_part(score)
    if melody is None:
        print("  [warn] No melodic part found; returning empty.")
        return np.empty((0, 128)), {}

    # transpose + sanitize (same as training preprocessing)
    melody_transposed = detect_key_and_transpose(melody)
    melody_transposed = sanitize_melody_instrument(melody_transposed)

    # extract pitch/duration seq
    pitch_dur_seq = extract_pitch_duration_sequence(melody_transposed)
    print(f"  [info] extracted {len(pitch_dur_seq)} melodic events.")

    if len(pitch_dur_seq) < snippet_length:
        print(f"  [warn] Too few notes (< {snippet_length}) → no snippets.")
        return np.empty((0, 128)), {}

    # quarter-length durations, kept unquantized for the time axis
    durs_q = [d for (p, d) in pitch_dur_seq]

    # convert to intervals + quantized duration steps
    intervals, durations_steps = convert_to_intervals_and_durations(pitch_dur_seq)

    # ---- Slice into snippets with indices & quarter-time ----
    N = len(intervals)
    stride = snippet_length // 2

    # cum_q[i] = quarter-length position where note i starts; len N+1
    cum_q = np.cumsum([0.0] + list(durs_q))

    all_interval_snips = []
    all_duration_snips = []
    all_start_idx = []
    all_end_idx = []
    all_start_q = []
    all_end_q = []

    for start in range(0, N - snippet_length + 1, stride):
        end = start + snippet_length
        all_interval_snips.append(intervals[start:end])
        all_duration_snips.append(durations_steps[start:end])
        all_start_idx.append(start)
        all_end_idx.append(end)
        all_start_q.append(cum_q[start])
        all_end_q.append(cum_q[end])

    interval_snips = np.array(all_interval_snips, dtype=np.int32)     # (Q, L)
    duration_snips = np.array(all_duration_snips, dtype=np.int32)     # (Q, L)
    start_indices = np.array(all_start_idx, dtype=np.int32)
    end_indices = np.array(all_end_idx, dtype=np.int32)
    start_qs = np.array(all_start_q, dtype=float)
    end_qs = np.array(all_end_q, dtype=float)

    num_snips = interval_snips.shape[0]
    print(f"  [info] created {num_snips} snippets of length {snippet_length}.")

    # ---- Approximate seconds using tempo (first mark only, 120 BPM fallback) ----
    tempos = score.flatten().getElementsByClass(tempo.MetronomeMark)
    if tempos:
        bpm = tempos[0].number or 120.0
    else:
        bpm = 120.0
    sec_per_q = 60.0 / bpm
    start_secs = start_qs * sec_per_q
    end_secs = end_qs * sec_per_q

    # ---- Map intervals to token IDs (shift by min_interval) ----
    shifted_snips = interval_snips - min_interval  # (Q, L)
    # Intervals below min_interval would give negative ids; clamp them to 0.
    shifted_snips = np.clip(shifted_snips, 0, None)

    x = torch.tensor(shifted_snips, dtype=torch.long, device=device)

    model.eval()
    with torch.no_grad():
        if model_type == "auto":
            z = model.encode(x)    # (Q, H)
        else:  # contrastive encoder directly returns normalized embedding
            z = model(x)           # (Q, D)

    embeddings = z.cpu().numpy()

    meta = {
        "start_indices": start_indices,
        "end_indices": end_indices,
        "start_qs": start_qs,
        "end_qs": end_qs,
        "start_secs": start_secs,
        "end_secs": end_secs,
        "interval_snips": interval_snips,
        "duration_snips": duration_snips,
    }

    return embeddings, meta


In [37]:
def demo_retrieval_for_midi_snippets(
    midi_path,
    model,
    model_type: str = "auto",          # "auto" or "contrastive"
    embed_npz_path=AUTOENC_EMBED_PATH,
    snippets_npz_path=SNIPPETS_PATH,
    top_k: int = 10,
    device=None,
):
    """
    For a given MIDI file:

      1) Slice into snippets + embed each snippet.
      2) Compute cosine similarity between *every* query snippet and *every* library snippet.
      3) For each library snippet, take the maximum similarity over all query snippets.
      4) Show top_k library snippets.
      5) Additionally: save the BEST query snippet and BEST library snippet as MIDI.

    Args:
        midi_path: query MIDI file.
        model: trained encoder matching ``model_type`` and ``embed_npz_path``.
        model_type: "auto" or "contrastive" (forwarded to the embedding step).
        embed_npz_path: library embedding archive.
        snippets_npz_path: library snippet archive.
        top_k: number of library snippets to report (clamped to library size).
        device: device for inference; defaults to the device the model's
            parameters live on, falling back to CUDA-if-available.

    Returns:
        List of result dicts (rank, similarity, library and query metadata),
        or an empty list when the query yields no snippets.
    """

    # --- Load library embeddings + metadata ---
    index = load_embedding_index(embed_npz_path, snippets_npz_path)
    embeddings   = index["embeddings"]       # (N, D)
    song_ids_emb = index["song_ids"]         # (N,)
    midi_files   = index["midi_filenames"]   # per-song
    genres       = index["genres"]
    labels       = index["labels"]
    start_secs   = index["start_secs"]
    end_secs     = index["end_secs"]
    min_interval = index["min_interval"]
    snippet_length = index["snippet_length"]

    # Also need intervals & durations for library snippets to reconstruct MIDI.
    # The `with` block closes the archive once the arrays are in memory.
    with np.load(snippets_npz_path, allow_pickle=True) as snip_data:
        intervals_arr = snip_data["intervals"]     # (N, L)
        durations_arr = snip_data["durations"]     # (N, L)
        song_ids_snip = snip_data["song_ids"]      # (N,)
    # sanity check: both archives must list the snippets in the same order
    assert np.array_equal(song_ids_emb, song_ids_snip)

    if device is None:
        # Follow the model instead of guessing: picking CUDA for a model that
        # sits on the CPU would make the forward pass fail on a device mismatch.
        parameters = getattr(model, "parameters", None)
        first_param = next(parameters(), None) if callable(parameters) else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Embed the query MIDI into snippet embeddings ---
    query_embs, meta = embed_midi_file_to_snippets(
        midi_path=midi_path,
        model=model,
        min_interval=min_interval,
        snippet_length=snippet_length,
        device=device,
        model_type=model_type,
    )

    Q = query_embs.shape[0]
    if Q == 0:
        print("No query snippets to retrieve with.")
        return []

    query_interval_snips = meta["interval_snips"]   # (Q, L)
    query_duration_snips = meta["duration_snips"]   # (Q, L)

    # Normalize library and query embeddings. Zero-norm rows are divided by 1
    # (similarity 0) rather than by 0, which would spread NaN through the ranking.
    lib_lengths = np.linalg.norm(embeddings, axis=1, keepdims=True)
    lib_lengths[lib_lengths == 0] = 1
    query_lengths = np.linalg.norm(query_embs, axis=1, keepdims=True)
    query_lengths[query_lengths == 0] = 1
    lib_norm = embeddings / lib_lengths    # (N, D)
    q_norm   = query_embs / query_lengths  # (Q, D)

    # --- Similarity matrix: query_snippets x library_snippets ---
    sim_matrix = q_norm @ lib_norm.T  # (Q, N)

    # For each library snippet, take the best similarity over all query snippets
    best_sim_per_lib = sim_matrix.max(axis=0)      # (N,)
    best_q_idx_per_lib = sim_matrix.argmax(axis=0) # (N,)

    # Sort library snippets by descending similarity
    sorted_lib_idx = np.argsort(-best_sim_per_lib)

    top_k = min(top_k, embeddings.shape[0])

    print(f"\n=== Global snippet-level retrieval for: {Path(midi_path).name} ===")
    print(f"Total query snippets from this song: {Q}")
    print(f"Showing top {top_k} library snippets (across ALL query snippets).\n")

    results = []

    for rank, lib_idx in enumerate(sorted_lib_idx[:top_k], start=1):
        sim_score = best_sim_per_lib[lib_idx]
        qid = best_q_idx_per_lib[lib_idx]  # which query snippet gave this best match

        # query snippet timing
        q_s_sec = meta["start_secs"][qid]
        q_e_sec = meta["end_secs"][qid]

        # library snippet metadata
        sid   = int(song_ids_emb[lib_idx])
        fname = midi_files[sid] if midi_files is not None else "N/A"
        genre = genres[lib_idx] if genres is not None else "unknown"
        label = labels[lib_idx] if labels is not None else f"snippet_{lib_idx}"
        s_sec = start_secs[lib_idx] if start_secs is not None else None
        e_sec = end_secs[lib_idx] if end_secs is not None else None

        print(f"#{rank:02d}  sim={sim_score:.3f}")
        print(f"     BEST query snippet #{qid} in input song")
        print(f"       query time:  {q_s_sec:7.2f}s → {q_e_sec:7.2f}s")
        print(f"     library idx={lib_idx}, song_id={sid}, genre={genre}, file={fname}")
        print(f"       label: {label}")
        if s_sec is not None and e_sec is not None:
            print(f"       library time: {s_sec:7.2f}s → {e_sec:7.2f}s")
        print()

        results.append({
            "rank": rank,
            "similarity": float(sim_score),
            "library_idx": int(lib_idx),
            "library_song_id": sid,
            "library_file": fname,
            "library_genre": genre,
            "library_label": label,
            "library_start_sec": float(s_sec) if s_sec is not None else None,
            "library_end_sec": float(e_sec) if e_sec is not None else None,
            "query_snippet_idx": int(qid),
            "query_start_sec": float(q_s_sec),
            "query_end_sec": float(q_e_sec),
        })

    # ---- Save the BEST matching pair as MIDI ----
    if len(sorted_lib_idx) > 0:
        best_lib_idx = sorted_lib_idx[0]
        best_q_idx = best_q_idx_per_lib[best_lib_idx]

        print("Saving best matching query and library snippets as MIDI...")

        # Query snippet MIDI
        q_midi_path = save_query_snippet_as_midi(
            q_idx=best_q_idx,
            query_intervals=query_interval_snips,
            query_durations=query_duration_snips,
            out_dir=QUERY_SNIPPET_MIDI_DIR,
            base_midi_pitch=60,
        )

        # Library snippet MIDI
        lib_midi_path = save_library_snippet_as_midi(
            lib_idx=best_lib_idx,
            intervals_arr=intervals_arr,
            durations_arr=durations_arr,
            song_ids=song_ids_emb,
            out_dir=LIB_SNIPPET_MIDI_DIR,
            base_midi_pitch=60,
        )

        print(f"Best query snippet saved to:    {q_midi_path}")
        print(f"Best library snippet saved to:  {lib_midi_path}")

    return results


In [38]:
import torch
import torch.nn as nn
# ---------------------
# Autoencoder model
# ---------------------

class MelodyAutoencoder(nn.Module):
    """
    GRU sequence autoencoder over interval tokens.

    The encoder compresses a (B, L) batch of token ids into the final hidden
    state of its last GRU layer; the decoder reconstructs logits for all L
    positions from that vector alone.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: token embedding size.
        hidden_dim: GRU hidden size, which is also the latent size.
        num_layers: GRU layers used by both encoder and decoder.
    """

    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.encoder_rnn = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder_rnn = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.output_fc = nn.Linear(hidden_dim, vocab_size)

        # Learned start token for the decoder
        self.start_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def encode(self, x):
        """x: (B, L) token ids -> (B, H) final hidden state of the last encoder layer."""
        emb = self.embed(x)
        _, h_n = self.encoder_rnn(emb)
        return h_n[-1]  # (B, H)

    def decode(self, z, seq_len):
        """
        z: (B, H)
        seq_len: int (L)
        returns: (B, L, V) logits
        Decoder only gets z + a learned start vector, not the target tokens.
        """
        B = z.size(0)
        # The GRU expects one initial state per layer: (num_layers, B, H).
        # Give every decoder layer the same latent vector; a bare unsqueeze(0)
        # only fits num_layers == 1 and crashes for deeper decoders.
        h0 = z.unsqueeze(0).expand(self.decoder_rnn.num_layers, -1, -1).contiguous()

        # Repeat a learned start embedding L times as input
        # shape: (B, L, E)
        start_emb = self.start_token.expand(B, seq_len, -1)

        out, _ = self.decoder_rnn(start_emb, h0)  # (B, L, H)
        logits = self.output_fc(out)              # (B, L, V)
        return logits

    def forward(self, x):
        """x: (B, L) token ids -> (logits (B, L, V), z (B, H))."""
        z = self.encode(x)
        L = x.shape[1]
        logits = self.decode(z, L)
        return logits, z



In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MelodyEncoder(nn.Module):
    """
    Bidirectional GRU encoder that maps token sequences to unit-length vectors.

    The final forward and backward hidden states of the top GRU layer are
    concatenated, linearly projected to ``proj_dim`` and L2-normalized.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, proj_dim=128, num_layers=1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.encoder_rnn = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Two directions feed the projection, hence 2 * hidden_dim inputs.
        self.proj = nn.Linear(2 * hidden_dim, proj_dim)

    def forward(self, x):
        """
        x: (B, L) token ids
        returns: (B, D) L2-normalized embedding
        """
        token_vectors = self.token_embed(x)                 # (B, L, E)
        _, final_states = self.encoder_rnn(token_vectors)   # (2*num_layers, B, H)

        # The last two entries belong to the top layer: forward, then backward.
        forward_state, backward_state = final_states[-2], final_states[-1]
        both_directions = torch.cat((forward_state, backward_state), dim=-1)  # (B, 2H)

        projected = self.proj(both_directions)              # (B, D)
        return F.normalize(projected, dim=-1)               # unit length for cosine similarity


In [40]:
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load min_interval & vocab_size from any embedding NPZ
# (only vocab_size is used in this cell, to size the embedding tables).
idx_info = load_embedding_index(AUTOENC_EMBED_PATH, SNIPPETS_PATH)
vocab_size = idx_info["vocab_size"]

# 1) Autoencoder
# The constructor arguments must reproduce the architecture that produced the
# checkpoint, otherwise load_state_dict fails on mismatched shapes.
auto_model = MelodyAutoencoder(
    vocab_size=vocab_size,
    embed_dim=64,    # SAME values used during training
    hidden_dim=128,
)

# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
auto_model.load_state_dict(torch.load("../models/autoencoder.pt", map_location=device))
auto_model.to(device)
auto_model.eval()

# 2) Contrastive encoder
contrastive_model = MelodyEncoder(
    vocab_size=vocab_size,
    embed_dim=128,    # same as training
    hidden_dim=256 ,
    proj_dim=128,
    num_layers=1,
)

# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
contrastive_model.load_state_dict(torch.load("../models/contrastive_encoder.pt", map_location=device))
contrastive_model.to(device)
# Last expression of the cell: eval() returns the module, so its repr is displayed.
contrastive_model.eval()


MelodyEncoder(
  (token_embed): Embedding(184, 128)
  (encoder_rnn): GRU(128, 256, batch_first=True, bidirectional=True)
  (proj): Linear(in_features=512, out_features=128, bias=True)
)

In [42]:
# Demo: retrieve library snippets for a query MIDI with the autoencoder.
# If the model was just trained in this session instead of loaded from disk:
# result = train_autoencoder(...)
# auto_model = result["model"]   # or whatever variable you kept

demo_midi_path = "../data/demo/MAROON 5.She will be loved K.MID"

# model, model_type and embed_npz_path must all refer to the same encoder.
results = demo_retrieval_for_midi_snippets(
    midi_path=demo_midi_path,
    model=auto_model,                # or contrastive_model
    model_type="auto",               # "contrastive" if using contrastive encoder
    embed_npz_path=AUTOENC_EMBED_PATH,
    snippets_npz_path=SNIPPETS_PATH,
    top_k=10,
    device=device,
)

Loading query MIDI: MAROON 5.She will be loved K.MID
  [info] pick_melody_part: selected by name heuristic: part_name='she will be loved', avg_pitch=63.1, n_notes=2014
  [info] extracted 2014 melodic events.
  [info] created 124 snippets of length 32.

=== Global snippet-level retrieval for: MAROON 5.She will be loved K.MID ===
Total query snippets from this song: 124
Showing top 10 library snippets (across ALL query snippets).

#01  sim=0.866
     BEST query snippet #18 in input song
       query time:   207.40s →  228.63s
     library idx=24031, song_id=655, genre=pop, file=Horny 98.mid
       label: pop_Horny 98_idx000336_to000368_t0094.04s_to0102.07s
       library time:   94.04s →  102.07s

#02  sim=0.834
     BEST query snippet #53 in input song
       query time:   653.14s →  675.74s
     library idx=11233, song_id=308, genre=folk, file=Lodi_2.mid
       label: folk_Lodi_2_idx000368_to000400_t0062.31s_to0068.03s
       library time:   62.31s →   68.03s

#03  sim=0.833
     BEST 

In [29]:
# Demo: same retrieval with the contrastive encoder and its embedding archive.
# NOTE(review): this cell's execution count is lower than the preceding cell's,
# i.e. it was last run out of order; re-run the notebook top to bottom to confirm.
# If the model was just trained in this session instead of loaded from disk:
# info_contrastive = train_contrastive_encoder(...)
# contrastive_model = info_contrastive["model"]

demo_midi_path = "../data/demo/ABBA.Mamma Mia K.mid"

# The return value is not assigned, so the result list is echoed as cell output.
demo_retrieval_for_midi_snippets(
    midi_path=demo_midi_path,
    model=contrastive_model,
    model_type="contrastive",
    top_k=10,
    embed_npz_path=CONTRASTIVE_EMBED_PATH,
    snippets_npz_path=SNIPPETS_PATH,
)


Loading query MIDI: ABBA.Mamma Mia K.mid
  [info] pick_melody_part: selected by stats: part_name='mamma mia', avg_pitch=64.2, n_notes=877
  [info] extracted 877 melodic events.
  [info] created 53 snippets of length 32.

=== Global snippet-level retrieval for: ABBA.Mamma Mia K.mid ===
Total query snippets from this song: 53
Showing top 10 library snippets (across ALL query snippets).

#01  sim=0.862
     BEST query snippet #41 in input song
       query time:   376.95s →  387.52s
     library idx=10567, song_id=285, genre=folk, file=Elvis Crespo _ Pintame.mid
       label: folk_Elvis Crespo _ Pintame_idx001328_to001360_t0212.50s_to0217.13s
       library time:  212.50s →  217.13s

#02  sim=0.856
     BEST query snippet #41 in input song
       query time:   376.95s →  387.52s
     library idx=10570, song_id=285, genre=folk, file=Elvis Crespo _ Pintame.mid
       label: folk_Elvis Crespo _ Pintame_idx001376_to001408_t0220.15s_to0225.25s
       library time:  220.15s →  225.25s

#03  sim

[{'rank': 1,
  'similarity': 0.8619452118873596,
  'library_idx': 10567,
  'library_song_id': 285,
  'library_file': 'Elvis Crespo _ Pintame.mid',
  'library_genre': 'folk',
  'library_label': 'folk_Elvis Crespo _ Pintame_idx001328_to001360_t0212.50s_to0217.13s',
  'library_start_sec': 212.49905867796122,
  'library_end_sec': 217.13257887974112,
  'query_snippet_idx': 41,
  'query_start_sec': 376.95488721804514,
  'query_end_sec': 387.5187969924813},
 {'rank': 2,
  'similarity': 0.8564080595970154,
  'library_idx': 10570,
  'library_song_id': 285,
  'library_file': 'Elvis Crespo _ Pintame.mid',
  'library_genre': 'folk',
  'library_label': 'folk_Elvis Crespo _ Pintame_idx001376_to001408_t0220.15s_to0225.25s',
  'library_start_sec': 220.15285775472324,
  'library_end_sec': 225.24730404987386,
  'query_snippet_idx': 41,
  'query_start_sec': 376.95488721804514,
  'query_end_sec': 387.5187969924813},
 {'rank': 3,
  'similarity': 0.8541462421417236,
  'library_idx': 10569,
  'library_song_i