In [10]:
import argparse
from pathlib import Path
import numpy as np
import random


In [11]:
# Default location of the preprocessed snippet archive; load_snippets() reads
# this file when no explicit path is passed. The path is relative, so it is
# resolved against the kernel's current working directory.
DATA_PATH = Path("../data/processed/snippets.npz")

In [12]:
def load_snippets(data_path: Path = DATA_PATH):
    """
    Load the preprocessed snippet arrays from an ``.npz`` archive.

    Parameters
    ----------
    data_path : Path
        Archive to read. Defaults to ``DATA_PATH``.

    Returns
    -------
    tuple
        ``(intervals, durations, song_ids, midi_filenames, genres,
        snippet_labels, snippet_start_secs, snippet_end_secs)``.
        The first three keys are required; each of the remaining five is
        ``None`` when the archive does not contain that key.

    Raises
    ------
    FileNotFoundError
        If ``data_path`` does not exist.
    KeyError
        If one of the required keys is missing from the archive.
    """
    if not data_path.exists():
        raise FileNotFoundError(f"{data_path} not found. Run preprocess_midi.py first.")

    # allow_pickle=True is needed because genres / labels are object arrays.
    # NOTE(review): unpickling can execute arbitrary code, so only load
    # archives produced by a trusted preprocessing run.
    #
    # The context manager closes the underlying zip file once the arrays
    # have been read; indexing the archive materialises each array in memory,
    # so the returned values stay valid after the file is closed.
    with np.load(data_path, allow_pickle=True) as data:

        def optional(key):
            # `.files` lists the stored array names on every NumPy version,
            # so this does not depend on the archive object offering `.get`.
            return data[key] if key in data.files else None

        intervals = data["intervals"]            # shape: (N, L)
        durations = data["durations"]            # shape: (N, L)
        song_ids = data["song_ids"]              # shape: (N,)

        midi_filenames = optional("midi_filenames")          # shape: (num_songs,)
        genres = optional("genres")                          # shape: (N,)
        snippet_labels = optional("snippet_labels")          # shape: (N,)
        snippet_start_secs = optional("snippet_start_secs")  # shape: (N,)
        snippet_end_secs = optional("snippet_end_secs")      # shape: (N,)

    return (intervals, durations, song_ids,
            midi_filenames, genres, snippet_labels,
            snippet_start_secs, snippet_end_secs)


def levenshtein_distance(seq1, seq2):
    """
    Return the Levenshtein (edit) distance between two indexable 1D sequences.

    Insertions, deletions and substitutions each cost 1. The table is filled
    row by row while keeping only the previous and the current row, so the
    running time stays proportional to len(seq1) * len(seq2).
    """
    rows, cols = len(seq1), len(seq2)

    # Row 0: turning an empty prefix of seq1 into seq2[:c] takes c insertions.
    previous = list(range(cols + 1))

    for r in range(1, rows + 1):
        item = seq1[r - 1]
        # Column 0: turning seq1[:r] into an empty sequence takes r deletions.
        current = [r]
        for c in range(1, cols + 1):
            substitution = previous[c - 1] + (0 if item == seq2[c - 1] else 1)
            deletion = previous[c] + 1
            insertion = current[c - 1] + 1
            current.append(min(deletion, insertion, substitution))
        previous = current

    return previous[cols]


def compute_all_distances(query_idx, intervals):
    """
    Edit distance from the snippet at ``query_idx`` to every row of ``intervals``.

    Returns a float32 array with one entry per row; the entry at
    ``query_idx`` itself is 0.
    """
    total = intervals.shape[0]
    reference = intervals[query_idx]

    # The buffer starts out zeroed, so the query's own slot already holds its
    # distance and only the other rows need to be computed.
    distances = np.zeros(total, dtype=np.float32)
    for row in range(total):
        if row != query_idx:
            distances[row] = levenshtein_distance(reference, intervals[row])

    return distances


In [13]:
def _describe_snippet(idx, song_ids, midi_filenames, genres,
                      snippet_labels, snippet_start_secs, snippet_end_secs):
    """Return (song_id, file, genre, label, start, end) for snippet ``idx``; absent optional arrays give "N/A" / None."""
    sid = int(song_ids[idx])
    fname = midi_filenames[sid] if midi_filenames is not None else "N/A"
    genre = genres[idx] if genres is not None else "N/A"
    label = snippet_labels[idx] if snippet_labels is not None else None
    start = snippet_start_secs[idx] if snippet_start_secs is not None else None
    end = snippet_end_secs[idx] if snippet_end_secs is not None else None
    return sid, fname, genre, label, start, end


def _print_snippet_details(label, start, end, indent):
    """Print the label and time-span lines of a snippet, skipping whichever is unavailable."""
    if label is not None:
        print(f"{indent}label={label}")
    if start is not None and end is not None:
        print(f"{indent}approx time: {start:.2f}s → {end:.2f}s")


def run_baseline(query_index=None, top_k=5):
    """
    Run baseline retrieval and print the top_k nearest neighbors.
    This is safe to call from a notebook.

    Parameters
    ----------
    query_index : int, optional
        Index of the query snippet. When omitted, one is drawn at random.
    top_k : int
        Number of neighbours to print. 0 prints only the header.

    Raises
    ------
    ValueError
        If ``top_k`` is negative, the dataset holds no snippets, or
        ``query_index`` is outside ``[0, N-1]``.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    (intervals,
     _durations,
     song_ids,
     midi_filenames,
     genres,
     snippet_labels,
     snippet_start_secs,
     snippet_end_secs) = load_snippets()
    metadata = (song_ids, midi_filenames, genres,
                snippet_labels, snippet_start_secs, snippet_end_secs)

    N, L = intervals.shape
    print(f"Loaded {N} snippets, each of length {L}.")

    if N == 0:
        # Fail with a clear message instead of an opaque randint() range error.
        raise ValueError("No snippets available to query.")

    # Pick query index
    if query_index is not None:
        q_idx = query_index
        if q_idx < 0 or q_idx >= N:
            raise ValueError(f"query_index {q_idx} out of range [0, {N-1}]")
    else:
        q_idx = random.randint(0, N - 1)

    q_sid, q_file, q_genre, q_label, q_start, q_end = _describe_snippet(q_idx, *metadata)

    print(f"\nUsing snippet #{q_idx} as query (song_id={q_sid}, genre={q_genre}).")
    print(f"  file={q_file}")
    _print_snippet_details(q_label, q_start, q_end, indent="  ")

    # Compute distances
    dists = compute_all_distances(q_idx, intervals)

    # A stable sort keeps equal distances in index order, so tied neighbours
    # are listed in the same order on every run.
    sorted_idx = np.argsort(dists, kind="stable")

    print(f"\nTop {top_k} most similar snippets by edit distance:")
    printed = 0
    for idx in sorted_idx:
        # Check the budget before printing so top_k=0 really prints nothing.
        if printed >= top_k:
            break
        if idx == q_idx:
            continue  # skip the query itself

        sid, fname, g, lbl, s_sec, e_sec = _describe_snippet(idx, *metadata)

        print(f"  idx={idx:4d}, song_id={sid}, genre={g}, dist={dists[idx]:.2f}, file={fname}")
        _print_snippet_details(lbl, s_sec, e_sec, indent="     ")

        printed += 1


In [18]:
# Explicitly choose a query index instead of letting run_baseline() draw a
# random one, so the same query snippet is used every time this cell runs.
run_baseline(query_index=201, top_k=10)

Loaded 34513 snippets, each of length 32.

Using snippet #201 as query (song_id=4, genre=classic).
  file=Axel_F_1.mid
  label=classic_Axel_F_1_idx000000_to000032_t0000.00s_to0027.26s
  approx time: 0.00s → 27.26s

Top 10 most similar snippets by edit distance:
  idx= 204, song_id=4, genre=classic, dist=0.00, file=Axel_F_1.mid
     label=classic_Axel_F_1_idx000048_to000080_t0040.90s_to0068.16s
     approx time: 40.90s → 68.16s
  idx= 207, song_id=4, genre=classic, dist=4.00, file=Axel_F_1.mid
     label=classic_Axel_F_1_idx000096_to000128_t0081.79s_to0121.64s
     approx time: 81.79s → 121.64s
  idx= 212, song_id=4, genre=classic, dist=8.00, file=Axel_F_1.mid
     label=classic_Axel_F_1_idx000176_to000208_t0146.81s_to0173.03s
     approx time: 146.81s → 173.03s
  idx= 205, song_id=4, genre=classic, dist=16.00, file=Axel_F_1.mid
     label=classic_Axel_F_1_idx000064_to000096_t0054.00s_to0081.79s
     approx time: 54.00s → 81.79s
  idx= 206, song_id=4, genre=classic, dist=16.00, file=Axe

# Command-line entry point: parse the optional query index / neighbour count
# and run the baseline retrieval once.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Baseline melody similarity using edit distance.")
    parser.add_argument("--query_index", type=int, default=None,
                        help="Index of query snippet (0..N-1). If omitted, choose random.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of nearest neighbors to show.")

    # A Jupyter kernel also runs with __name__ == "__main__" and carries its
    # own arguments in sys.argv (e.g. "-f <connection file>"). parse_args()
    # would reject those and abort the cell with SystemExit, so unknown
    # arguments are tolerated and ignored here.
    args, _unknown = parser.parse_known_args()

    run_baseline(query_index=args.query_index, top_k=args.top_k)