In [7]:
import argparse
from pathlib import Path
import numpy as np
import random


In [8]:
DATA_PATH = Path("../data/processed/snippets.npz")

In [9]:
def load_snippets(data_path: Path = DATA_PATH):
    if not data_path.exists():
        raise FileNotFoundError(f"{data_path} not found. Run preprocess_midi.py first.")
    data = np.load(data_path)
    intervals = data["intervals"]   # shape: (N, L)
    durations = data["durations"]   # shape: (N, L)
    song_ids = data["song_ids"]     # shape: (N,)
    return intervals, durations, song_ids


def levenshtein_distance(seq1, seq2):
    """
    Compute Levenshtein edit distance between two 1D sequences.
    Assumes small-ish length (like 32), so O(L^2) is fine.
    """
    len1, len2 = len(seq1), len(seq2)
    # dp[i][j] = edit distance between seq1[:i] and seq2[:j]
    dp = [[0] * (len2 + 1) for _ in range(len1 + 1)]

    for i in range(len1 + 1):
        dp[i][0] = i
    for j in range(len2 + 1):
        dp[0][j] = j

    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            cost = 0 if seq1[i - 1] == seq2[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,      # deletion
                dp[i][j - 1] + 1,      # insertion
                dp[i - 1][j - 1] + cost,  # substitution
            )
    return dp[len1][len2]


def compute_all_distances(query_idx, intervals):
    """
    Compute distances from one query snippet to all snippets.
    Returns array of distances shape (N,).
    """
    N = intervals.shape[0]
    query_seq = intervals[query_idx]
    dists = np.zeros(N, dtype=np.float32)

    for i in range(N):
        if i == query_idx:
            dists[i] = 0.0
        else:
            dists[i] = levenshtein_distance(query_seq, intervals[i])

    return dists


In [10]:
def run_baseline(query_index=None, top_k=5):
    """
    Run baseline retrieval and print the top_k nearest neighbors.
    This is safe to call from a notebook.
    """
    intervals, durations, song_ids = load_snippets()

    N, L = intervals.shape
    print(f"Loaded {N} snippets, each of length {L}.")

    # Pick query index
    if query_index is not None:
        q_idx = query_index
        if q_idx < 0 or q_idx >= N:
            raise ValueError(f"query_index {q_idx} out of range [0, {N-1}]")
    else:
        q_idx = random.randint(0, N - 1)

    print(f"Using snippet #{q_idx} as query (song_id={song_ids[q_idx]}).")

    # Compute distances
    dists = compute_all_distances(q_idx, intervals)

    # Get top-k most similar snippets (excluding the query itself)
    k = top_k
    sorted_idx = np.argsort(dists)

    print(f"\nTop {k} most similar snippets by edit distance:")
    printed = 0
    for idx in sorted_idx:
        if idx == q_idx:
            continue  # skip the query itself
        print(f"  idx={idx:4d}, song_id={int(song_ids[idx])}, dist={dists[idx]:.2f}")
        printed += 1
        if printed >= k:
            break

In [13]:
# random query, top 5 neighbors
run_baseline()

# or explicitly choose a query index:
run_baseline(query_index=10, top_k=10)

Loaded 2608 snippets, each of length 32.
Using snippet #1686 as query (song_id=19).

Top 5 most similar snippets by edit distance:
  idx=2482, song_id=29, dist=17.00
  idx= 345, song_id=3, dist=17.00
  idx=2425, song_id=27, dist=17.00
  idx= 690, song_id=7, dist=17.00
  idx= 160, song_id=1, dist=18.00
Loaded 2608 snippets, each of length 32.
Using snippet #10 as query (song_id=0).

Top 10 most similar snippets by edit distance:
  idx=1759, song_id=20, dist=18.00
  idx= 739, song_id=7, dist=18.00
  idx= 619, song_id=6, dist=18.00
  idx=1907, song_id=23, dist=19.00
  idx=2518, song_id=29, dist=19.00
  idx= 563, song_id=5, dist=20.00
  idx= 363, song_id=3, dist=20.00
  idx=1677, song_id=19, dist=20.00
  idx= 764, song_id=7, dist=20.00
  idx= 196, song_id=2, dist=20.00


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Baseline melody similarity using edit distance.")
    parser.add_argument("--query_index", type=int, default=None,
                        help="Index of query snippet (0..N-1). If omitted, choose random.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of nearest neighbors to show.")
    args = parser.parse_args()

    run_baseline(query_index=args.query_index, top_k=args.top_k)