In [1]:
print('Importing libraries and loading the trained model')
import magenta.music as mm
import note_seq
from note_seq import sequences_lib
from note_seq.protobuf import music_pb2
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
from magenta.models.music_vae.trained_model import NoExtractedExamplesError
from magenta.models.music_vae.trained_model import MultipleExtractedExamplesError
import numpy as np
import os
import tensorflow.compat.v1 as tf
import random
import sqlite3
import numpy as np
import faiss  # You'll need to install this: pip install faiss-cpu
import glob
from collections import defaultdict
# json: used to persist the FAISS ID-mapping and fingerprint files
import json





# tensorflow.compat.v1 is imported above as `tf`; switch off TF2 behavior so the
# v1 graph/session API is in effect before the model graph is built below.
tf.disable_v2_behavior()

# Set random seeds for reproducibility (NumPy, TensorFlow graph-level seed, and
# Python's `random` all share the same SEED).
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)
random.seed(SEED)

# MusicVAE configuration entry for the 'cat-mel_2bar_big' model.
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']

# Directory (relative to the working directory) that holds the downloaded
# MusicVAE checkpoints.
BASE_DIR="models/download.magenta.tensorflow.org/models/music_vae"
# Build the model with batch_size=4 and restore its weights from the
# mel_2bar_big checkpoint. `mel_2bar` is used by get_embeddings_for_song()
# to encode melody chunks.
mel_2bar = TrainedModel(mel_2bar_config, batch_size=4, checkpoint_dir_or_path=BASE_DIR + '/checkpoints/mel_2bar_big.ckpt')

Importing libraries and loading the trained model
Instructions for updating:
non-resource variables are not supported in the long term
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, CategoricalLstmDecoder, and hparams:
{'max_seq_len': 32, 'z_size': 512, 'free_bits': 0, 'max_beta': 0.5, 'beta_rate': 0.99999, 'batch_size': 4, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [2048, 2048, 2048], 'enc_rnn_size': [2048], 'dropout_keep_prob': 1.0, 'sampling_schedule': 'inverse_sigmoid', 'sampling_rate': 1000, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}
INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

Instructions for updating:
Use `tf.cast` instead.


  tf.layers.dense(
  self._kernel = self.add_variable(
  self._bias = self.add_variable(


Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API


  mu = tf.layers.dense(
  sigma = tf.layers.dense(


Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.
INFO:tensorflow:Restoring parameters from models/download.magenta.tensorflow.org/models/music_vae/checkpoints/mel_2bar_big.ckpt


In [None]:
# --- Configuration ---
# All paths are relative to the notebook's working directory.
# NOTE(review): DB_PATH, FAISS_INDEX_PATH and MELODY_DIR are not referenced by
# any other visible cell (get_embeddings_for_song builds the same melody path
# from literals) -- confirm whether they are still needed.
DB_PATH = os.path.join('data_sets', 'midi_embeddings.db')
FAISS_INDEX_PATH = os.path.join('data_sets', 'midi_embeddings.index')
MELODY_DIR = os.path.join('data_sets', 'lmd_melodies') # Directory of extracted melodies
EMBEDDING_DIM = 512 # Dimension of the MusicVAE embeddings (the model log above reports z_size 512)

def filter_pitch_range(ns, min_pitch=36, max_pitch=84):
    """Drops every note whose pitch lies outside [min_pitch, max_pitch].

    The sequence is modified in place: its `notes` container is emptied and
    refilled with the surviving notes, in their original order.

    Args:
      ns: Object with a mutable `notes` sequence whose items expose `.pitch`.
      min_pitch: Lowest pitch to keep (inclusive).
      max_pitch: Highest pitch to keep (inclusive).
    Returns:
      The same `ns` object that was passed in.
    """
    survivors = []
    for note in ns.notes:
        if not (min_pitch <= note.pitch <= max_pitch):
            continue
        survivors.append(note)

    # Collect first, then clear and refill, so iteration never runs over a
    # container that is being modified.
    del ns.notes[:]
    ns.notes.extend(survivors)
    return ns

# --- Helper Functions (from previous iterations, still useful for cleaning) ---
def make_monophonic(ns,steps_per_quarter=4):
    """Keeps only the highest-pitched note among notes that start on the same step.

    The sequence is quantized (on a copy returned by
    `sequences_lib.quantize_note_sequence`) to find each note's onset step.
    For every onset step the note with the greatest pitch wins; on a tie the
    note encountered first is kept. `ns.notes` is then replaced, in place, by
    the winning notes taken from the quantized copy, ordered by onset step.

    Only notes sharing the same onset step are collapsed; notes that start on
    different steps are all retained.

    Args:
      ns: The NoteSequence to reduce. Modified in place.
      steps_per_quarter: Quantization resolution used to group onsets.
    Returns:
      The same `ns` object (returned untouched when it has no notes).
    """
    if not ns.notes:
        return ns

    quantized = sequences_lib.quantize_note_sequence(ns, steps_per_quarter)

    # onset step -> best note seen so far for that step
    top_note_at = {}
    for candidate in quantized.notes:
        onset = candidate.quantized_start_step
        incumbent = top_note_at.get(onset)
        # Strict '>' keeps the earliest note when pitches tie.
        if incumbent is None or candidate.pitch > incumbent.pitch:
            top_note_at[onset] = candidate

    winners = [top_note_at[onset] for onset in sorted(top_note_at)]

    del ns.notes[:]
    ns.notes.extend(winners)
    return ns

def snap_chunk_notes_to_grid(unquantized_chunk, steps_per_quarter):
    """
    Builds a fresh, time-based NoteSequence whose notes sit exactly on the grid.

    The input is quantized only to learn each note's start/end step; those
    steps are then converted back to seconds and written into a new sequence.
    The input sequence itself is not modified.

    Args:
      unquantized_chunk: The unquantized NoteSequence to align. Its first
        tempo (or 120 qpm when it declares none) defines the grid spacing.
      steps_per_quarter: The quantization resolution.
    Returns:
      A new NoteSequence carrying a single tempo, the input's
      ticks_per_quarter, and one grid-aligned note per quantized note (pitch,
      velocity, instrument and program copied). Returns None when
      quantization raises note_seq.BadTimeSignatureError.
    """
    try:
        quantized = note_seq.quantize_note_sequence(
            unquantized_chunk, steps_per_quarter)
    except note_seq.BadTimeSignatureError:
        # This time signature cannot be quantized; signal "skip" to the caller.
        return None

    if unquantized_chunk.tempos:
        qpm = unquantized_chunk.tempos[0].qpm
    else:
        qpm = 120.0
    seconds_per_quarter = 60.0 / qpm

    def steps_to_seconds(step_count):
        # steps -> quarter notes -> seconds
        return (step_count / steps_per_quarter) * seconds_per_quarter

    snapped = music_pb2.NoteSequence()
    snapped.tempos.add().qpm = qpm
    snapped.ticks_per_quarter = unquantized_chunk.ticks_per_quarter

    for source_note in quantized.notes:
        target_note = snapped.notes.add()
        target_note.pitch = source_note.pitch
        target_note.velocity = source_note.velocity
        target_note.instrument = source_note.instrument
        target_note.program = source_note.program
        target_note.start_time = steps_to_seconds(source_note.quantized_start_step)
        target_note.end_time = steps_to_seconds(source_note.quantized_end_step)

    # Total duration follows the quantized length, not the original total_time.
    snapped.total_time = steps_to_seconds(quantized.total_quantized_steps)
    return snapped

def set_program_for_all_notes(note_sequence, program_number=0):
    """
    Overwrites the `program` field of every note in a sequence, in place.

    Args:
      note_sequence: Object whose `notes` items expose a writable `program`.
      program_number: The program value assigned to each note. Defaults to 0.
    Returns:
      The same `note_sequence` object that was passed in.
    """
    all_notes = note_sequence.notes
    for position in range(len(all_notes)):
        all_notes[position].program = program_number
    return note_sequence

def _fallback_tempo(note_sequence, prior_bpm):
    """Returns the sequence's first declared tempo if it is positive, else prior_bpm."""
    if note_sequence.tempos:
        declared_qpm = note_sequence.tempos[0].qpm
        # A zero/negative tempo would make callers divide by zero when they
        # convert it to seconds-per-beat, so it is treated as "not declared".
        if declared_qpm > 0:
            return declared_qpm
    return prior_bpm


def estimate_tempo_from_notes(
    # Quoted so the annotation is not evaluated when the function is defined.
    note_sequence: "music_pb2.NoteSequence",
    min_bpm: float = 60.0,
    max_bpm: float = 240.0,
    prior_bpm: float = 120.0
) -> float:
    """
    Estimates the tempo of an unquantized NoteSequence by analyzing note onsets.

    The inter-onset intervals (IOIs) of the unique note start times are
    histogrammed in 10 ms bins over [0, 5) seconds. Each of the ten fullest
    bins that holds at least two IOIs proposes beat durations (its left edge
    times 0.25, 0.33, 0.5, 1, 2, 3, 4). Every proposal whose tempo falls inside
    [min_bpm, max_bpm] is scored as `bin count * Gaussian(tempo - prior_bpm,
    sigma=20 BPM)`, and the best-scoring tempo wins.

    Args:
        note_sequence: An unquantized NoteSequence object.
        min_bpm: The minimum plausible tempo to consider (inclusive).
        max_bpm: The maximum plausible tempo to consider (inclusive).
        prior_bpm: The tempo to prefer (e.g., 120 BPM). The algorithm will favor
                   candidates closer to this value.

    Returns:
        The estimated tempo in beats per minute (BPM). When fewer than three
        distinct onsets exist, or no candidate survives, the sequence's first
        declared tempo is returned if it is positive, otherwise prior_bpm.
    """
    # 1. Extract unique, sorted note onset times
    onsets = sorted({note.start_time for note in note_sequence.notes})

    if len(onsets) < 3:  # Need a reasonable number of notes for a good guess
        print("Warning: Too few notes to reliably estimate tempo. "
              "Falling back to the declared tempo or the prior.")
        return _fallback_tempo(note_sequence, prior_bpm)

    # 2. Calculate Inter-Onset Intervals (IOIs); at least two exist here.
    iois = np.diff(onsets)

    # 3. Build a histogram of IOIs to find the most common intervals.
    # IOIs beyond the last bin edge fall outside the bins and are ignored.
    hist, bin_edges = np.histogram(iois, bins=np.arange(0, 5, 0.01), density=False)

    # The ten fullest bins are treated as the primary rhythmic intervals.
    peak_indices = np.argsort(hist)[-10:]

    tempo_candidates = defaultdict(float)

    # 4. Generate and score tempo candidates from histogram peaks
    for i in peak_indices:
        if hist[i] < 2:  # Ignore insignificant peaks
            continue

        # The interval is represented by the bin's left edge (in seconds).
        interval = bin_edges[i]
        if interval == 0:
            # First bin: a zero beat duration cannot be converted to a tempo.
            continue

        # This interval could be a quarter note, eighth note, etc.
        for multiple in [0.25, 0.33, 0.5, 1, 2, 3, 4]:
            tempo = 60.0 / (interval * multiple)
            if not (min_bpm <= tempo <= max_bpm):
                continue

            # 5. Score = rhythmic strength (bin count) weighted by a Gaussian
            # centered on prior_bpm with a 20 BPM standard deviation.
            proximity_score = np.exp(-0.5 * ((tempo - prior_bpm) / 20.0)**2)
            tempo_candidates[tempo] += hist[i] * proximity_score

    if not tempo_candidates:
        print("Warning: Could not find any valid tempo candidates. "
              "Falling back to the declared tempo or the prior.")
        return _fallback_tempo(note_sequence, prior_bpm)

    # 6. Return the tempo with the highest score
    best_tempo = max(tempo_candidates, key=tempo_candidates.get)
    return float(best_tempo)

def get_embeddings_for_song(
    track_id: str,
    melody_dir: str = os.path.join('data_sets', 'lmd_melodies')
) -> dict:
    """
    Finds all MIDI files for a given track_id, generates embeddings for each,
    and returns them in a dictionary.

    Each file is reduced with make_monophonic, snapped to the model's step
    grid, forced to program 0, split into fixed-length chunks, and every chunk
    is passed to `mel_2bar.encode`. Files that cannot be processed are
    reported and skipped; they never abort the whole track.

    Args:
        track_id (str): The ID of the track, e.g., "TRAAAGR128F425B14B".
            Characters 2, 3 and 4 select the nested sub-folders.
        melody_dir (str): Root directory of the extracted melodies. Defaults
            to 'data_sets/lmd_melodies'.

    Returns:
        A dictionary where keys are MIDI filenames and values are lists with
        one entry per successfully encoded chunk; each entry is the raw return
        value of `mel_2bar.encode([chunk])`.
        Returns an empty dictionary if the track_id is too short, the folder
        is not found, or it contains no MIDI files.
    """
    # 1. Construct the folder path from the track_id
    # e.g., 'data_sets/lmd_melodies/A/A/A/TRAAAGR128F425B14B'
    if len(track_id) < 5:
        print(f"Error: track_id '{track_id}' is too short to build a path.")
        return {}

    song_folder_path = os.path.join(
        melody_dir,
        track_id[2],
        track_id[3],
        track_id[4],
        track_id
    )

    if not os.path.isdir(song_folder_path):
        print(f"Warning: Directory not found at {song_folder_path}")
        return {}

    # 2. Find all MIDI files in the directory (sorted for a stable order)
    midi_filepaths = sorted(
        glob.glob(os.path.join(song_folder_path, '*.mid'))
        + glob.glob(os.path.join(song_folder_path, '*.midi'))
    )

    if not midi_filepaths:
        print(f"Warning: No MIDI files found in {song_folder_path}")
        return {}

    # Model-dependent constants are the same for every file.
    steps_per_quarter = mel_2bar_config.data_converter._steps_per_quarter
    num_steps_per_chunk = mel_2bar_config.hparams.max_seq_len

    # 3. Process each MIDI file to generate embeddings
    all_embeddings = {}

    for midi_path in midi_filepaths:
        filename = os.path.basename(midi_path)
        print(f"Processing {filename}...")

        try:
            midi_ns = note_seq.midi_file_to_note_sequence(midi_path)

            # Chunk length in seconds = steps per chunk * seconds per step.
            # NOTE(review): the hop size uses the *estimated* tempo, while
            # snap_chunk_notes_to_grid lays out its grid from the tempo
            # declared in the sequence -- confirm both are meant to differ.
            qpm = estimate_tempo_from_notes(midi_ns)
            seconds_per_step = (60.0 / qpm) / steps_per_quarter
            hop_size_in_seconds = num_steps_per_chunk * seconds_per_step

            cleaned_ms = make_monophonic(midi_ns)
            cleaned_ms = snap_chunk_notes_to_grid(cleaned_ms, steps_per_quarter)
            if cleaned_ms is None:
                # snap_chunk_notes_to_grid returns None when the sequence
                # cannot be quantized (bad time signature).
                print(f"  -> Could not process file {filename}. "
                      "Error: sequence could not be quantized (bad time signature)")
                continue
            cleaned_ms = set_program_for_all_notes(cleaned_ms, program_number=0)

            chunks = []
            if cleaned_ms.notes:
                chunks = list(sequences_lib.split_note_sequence(
                    note_sequence=cleaned_ms,
                    hop_size_seconds=hop_size_in_seconds
                ))

            embeddings = []
            for chunk in chunks:
                try:
                    embeddings.append(mel_2bar.encode([chunk]))
                except NoExtractedExamplesError:
                    print("Skipping chunk, insufficient note data")
                except MultipleExtractedExamplesError:
                    print("Skipping chunk, multiple examples extracted")

            if embeddings:
                all_embeddings[filename] = embeddings

        except Exception as e:
            # Deliberate best-effort: one bad file must not stop the others.
            print(f"  -> Could not process file {filename}. Error: {e}")

    return all_embeddings

# --- Example Usage ---


# --- NEW: Main Orchestrator Function ---
def process_tracks_from_db(
    start_line_to_process: int=0,
    num_tracks_to_process: int = 1,
    db_path: str='data_sets/track_metadata.db'
) -> dict:
    """
    Reads track_ids from an SQLite database, generates embeddings for each,
    and returns a nested dictionary of all embeddings.

    Args:
        start_line_to_process (int): Number of rows to skip before the first
            track that is processed (SQL OFFSET). Negative values are treated
            as 0. Defaults to 0.
        num_tracks_to_process (int, optional): The number of tracks to process.
            If None (or not positive), every track from the offset onwards is
            processed. Defaults to 1.
        db_path (str): Path to the track_metadata.db SQLite file. The database
            must contain a `songs` table with a `track_id` column.

    Returns:
        A dictionary where keys are track_ids and values are the dictionaries
        returned by get_embeddings_for_song. Tracks that yield no embeddings
        are omitted. Returns an empty dictionary when the database file is
        missing or the query returns no rows.
    """
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return {}

    # In SQLite a negative LIMIT means "no upper bound", which lets the same
    # parameterized statement serve both the bounded and unbounded case while
    # still honoring the offset.
    if num_tracks_to_process is not None and num_tracks_to_process > 0:
        row_limit = int(num_tracks_to_process)
    else:
        row_limit = -1
    row_offset = max(0, int(start_line_to_process))

    # 1. Fetch the track_ids from the database
    query = "SELECT track_id FROM songs LIMIT ? OFFSET ?" # Assuming the table is named 'songs'

    print(f"Connecting to database: {db_path}")
    conn = sqlite3.connect(db_path)
    try:
        print("Fetching track_ids...")
        cursor = conn.execute(query, (row_limit, row_offset))
        # Flatten the list of 1-tuples [('id1',), ('id2',)] -> ['id1', 'id2']
        track_ids = [row[0] for row in cursor.fetchall()]
    finally:
        # Close the connection even when the query fails.
        conn.close()

    if not track_ids:
        print("No track_ids found in the database.")
        return {}

    total_tracks = len(track_ids)
    print(f"Found {total_tracks} track_ids to process.")

    # 2. Iterate through track_ids and get embeddings for each
    all_track_embeddings = {}
    for i, track_id in enumerate(track_ids):
        print(f"\n--- Processing track {i + 1}/{total_tracks}: {track_id} ---")

        song_embeddings = get_embeddings_for_song(track_id)

        if song_embeddings:
            all_track_embeddings[track_id] = song_embeddings
            print(f"-> Success: Found and processed {len(song_embeddings)} MIDI file(s) for this track.")

    print("\n--- Processing Complete ---")
    return all_track_embeddings











In [7]:

# Folder holding the on-disk FAISS database files (relative to the working directory).
DB_FOLDER = os.path.join("data_sets","faiss")
# Serialized FAISS index, written with faiss.write_index and read with faiss.read_index.
INDEX_FILE = os.path.join(DB_FOLDER,"my_faiss_index.index")
# JSON object mapping each track_id string to the integer ID stored in the index.
METADATA_FILE =os.path.join(DB_FOLDER,"faiss_metadata.json")
# JSON list of (track_id, embedding values) fingerprints used to skip duplicates.
FINGERPRINT_FILE = os.path.join(DB_FOLDER, "embedding_fingerprints.json")


# -------------------

def setup_faiss_database(index_file, metadata_file,EMBEDDING_DIM=512, fingerprint_file=None):
    """
    Initializes an empty FAISS index, an empty ID mapping and an empty
    fingerprint list, saving all three to disk (existing files are
    overwritten).

    Args:
        index_file: Path the empty FAISS index is written to.
        metadata_file: Path of the JSON file that will hold the
            track_id -> integer ID mapping (initialized to `{}`).
        EMBEDDING_DIM: Dimension of the vectors the index will store.
        fingerprint_file: Path of the JSON file that will hold the
            de-duplication fingerprints (initialized to `[]`). Defaults to
            the module-level FINGERPRINT_FILE.
    """
    if fingerprint_file is None:
        fingerprint_file = FINGERPRINT_FILE

    # Make sure every target directory exists before anything is written.
    for target_path in (index_file, metadata_file, fingerprint_file):
        parent_dir = os.path.dirname(target_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Create a flat L2 index (Euclidean distance)
    # IndexFlatL2 is a simple, brute-force index. For very large datasets,
    # you might consider more advanced indices like IndexIVFFlat.
    base_index = faiss.IndexFlatL2(EMBEDDING_DIM)

    # Wrap it with IndexIDMap to allow custom (integer) IDs
    index_with_ids = faiss.IndexIDMap(base_index)

    # Save the empty index
    faiss.write_index(index_with_ids, index_file)
    print(f"Empty FAISS index (dimension {EMBEDDING_DIM}) created and saved to: {index_file}")

    # Initialize an empty ID mapping (string_id -> integer_id) and save it
    with open(metadata_file, 'w') as f:
        json.dump({}, f)
    print(f"Empty ID metadata file created and saved to: {metadata_file}")

    # --- Create Empty Fingerprint Set ---
    # We save it as an empty list in JSON
    with open(fingerprint_file, 'w') as f:
        json.dump([], f)
    print(f"Empty fingerprint file saved to: {fingerprint_file}")



def add_embeddings_with_deduplication(new_embedding_matrix, new_track_ids):
    """
    Loads an existing FAISS index and adds new embeddings, skipping any
    (track_id, embedding_vector) pairs that already exist.

    The index, the track_id -> integer ID map and the fingerprint list are
    read from INDEX_FILE, METADATA_FILE and FINGERPRINT_FILE and written back
    after the new vectors have been added. All embeddings of one track share
    that track's integer ID inside the index.

    Args:
        new_embedding_matrix: Sequence of embedding vectors (array-likes).
            Each one is converted to a flat float32 vector, which is also the
            form that is fingerprinted and stored.
        new_track_ids: Sequence of track_id strings, one per embedding, in
            the same order.

    Returns:
        None. Problems (missing database files, mismatched input lengths) are
        reported on stdout and leave the database untouched.
    """
    if not all(os.path.exists(f) for f in [INDEX_FILE, METADATA_FILE, FINGERPRINT_FILE]):
        print("Error: Database files not found. Please run setup_faiss_db.py first.")
        return

    if len(new_embedding_matrix) != len(new_track_ids):
        # A misaligned pairing would attach vectors to the wrong tracks.
        print(f"Error: Got {len(new_embedding_matrix)} embeddings but "
              f"{len(new_track_ids)} track ids. Nothing was added.")
        return

    # 1. LOAD EXISTING DATA
    index = faiss.read_index(INDEX_FILE)
    with open(METADATA_FILE, 'r') as f:
        string_to_int_map = json.load(f)
    with open(FINGERPRINT_FILE, 'r') as f:
        # Each stored fingerprint is [track_id, [v0, v1, ...]]. JSON turns the
        # inner tuple into a list, so it must be converted back explicitly;
        # a list inside the tuple would not be hashable.
        fingerprint_set = {(fp[0], tuple(fp[1])) for fp in json.load(f)}

    print(f"Loaded index with {index.ntotal} embeddings and {len(fingerprint_set)} fingerprints.")

    # 2. FILTER OUT DUPLICATES
    truly_new_embeddings = []
    truly_new_track_ids = []

    for track_id, embedding in zip(new_track_ids, new_embedding_matrix):
        # Flatten to a 1-D float32 vector so the fingerprint matches exactly
        # what the index stores and contains only hashable floats.
        vector = np.asarray(embedding, dtype='float32').ravel()
        fingerprint = (track_id, tuple(vector.tolist()))

        if fingerprint not in fingerprint_set:
            truly_new_embeddings.append(vector)
            truly_new_track_ids.append(track_id)
            # Also guards against duplicates inside the new batch itself.
            fingerprint_set.add(fingerprint)

    if not truly_new_embeddings:
        print("No new unique embeddings to add. Database is already up-to-date.")
        return

    print(f"Found {len(truly_new_embeddings)} new unique embeddings to add.")

    # 3. PREPARE AND ADD THE NEW DATA
    new_embedding_array = np.array(truly_new_embeddings, dtype='float32')

    # Assign integer IDs, creating new ones if necessary
    next_int_id = max(string_to_int_map.values()) + 1 if string_to_int_map else 0
    integer_ids_for_faiss = []
    for tid in truly_new_track_ids:
        if tid not in string_to_int_map:
            string_to_int_map[tid] = next_int_id
            next_int_id += 1
        integer_ids_for_faiss.append(string_to_int_map[tid])

    integer_ids_for_faiss = np.array(integer_ids_for_faiss, dtype='int64')

    # Add to the FAISS index
    index.add_with_ids(new_embedding_array, integer_ids_for_faiss)

    # 4. SAVE EVERYTHING BACK TO DISK
    faiss.write_index(index, INDEX_FILE)
    with open(METADATA_FILE, 'w') as f:
        json.dump(string_to_int_map, f)

    # Stored as [track_id, [values...]] pairs (see the loader above).
    updated_fingerprint_list = [[tid, list(values)] for tid, values in fingerprint_set]
    with open(FINGERPRINT_FILE, 'w') as f:
        json.dump(updated_fingerprint_list, f)

    print(f"Successfully added {len(truly_new_embeddings)} embeddings. Total in index: {index.ntotal}.")
    print("Database files have been updated.")


def search_nearest_track(query_embedding, k=1):
    """
    Looks up the nearest stored embedding(s) for a query vector in the FAISS
    index at INDEX_FILE and translates the hits back to track_id strings
    using the mapping stored in METADATA_FILE.

    Args:
        query_embedding (np.ndarray or list): The query vector. A 1-D input is
            reshaped to a single row; for 2-D input only the first row's
            results are returned.
        k (int): The number of nearest neighbors to request. Defaults to 1.

    Returns:
        list: (track_id, distance) tuples in the order FAISS returned them.
              Hits with integer ID -1 are dropped; an integer ID missing from
              the mapping is reported as "ID_Not_Found". An empty list is
              returned when the database files are missing, the index is
              empty, or the query dimension does not match the index.
    """
    # --- Guard clauses: database present and non-empty -------------------
    files_present = os.path.exists(INDEX_FILE) and os.path.exists(METADATA_FILE)
    if not files_present:
        print("Error: Database files not found. Please run setup_faiss_db.py first.")
        return []

    index = faiss.read_index(INDEX_FILE)
    if index.ntotal == 0:
        print("Database is empty. No search can be performed.")
        return []

    # --- Normalize the query to a float32 row matrix ----------------------
    query = np.array(query_embedding, dtype='float32')
    if query.ndim == 1:
        query = query.reshape(1, -1)

    if query.shape[1] != index.d:
        print(f"Error: Query embedding dimension ({query.shape[1]}) does not match index dimension ({index.d}).")
        return []

    # --- Invert the stored track_id -> integer ID mapping -----------------
    with open(METADATA_FILE, 'r') as f:
        string_to_int_map = json.load(f)
    int_to_string_map = {
        int(int_id): track for track, int_id in string_to_int_map.items()
    }

    # --- Search and translate the hits of the first query row -------------
    distances, indices = index.search(query, k)
    hit_ids = indices[0]
    hit_distances = distances[0]

    return [
        (int_to_string_map.get(hit_ids[rank], "ID_Not_Found"), hit_distances[rank])
        for rank in range(k)
        if hit_ids[rank] != -1  # -1 marks an empty result slot
    ]






# Set to True to wipe and re-create the FAISS database files; leave False to
# keep the existing database.
first_time = False

if first_time:
    # Make sure the database folder exists. A failure is reported, not raised.
    try:
        os.makedirs('data_sets/faiss', exist_ok=True)
        print("Folder 'data_sets/faiss' created successfully or already exists.")
    except OSError as e:
        print(f"Error creating folder 'data_sets/faiss': {e}")

    # Start fresh: delete a previous index and ID map if present.
    for stale_path in (INDEX_FILE, METADATA_FILE):
        if os.path.exists(stale_path):
            os.remove(stale_path)
            print(f"Removed existing {stale_path}")

    setup_faiss_database(INDEX_FILE, METADATA_FILE, EMBEDDING_DIM)

# Printed on every run, whether or not the database was re-created.
print("\nFAISS database setup complete. Ready to add embeddings.")




FAISS database setup complete. Ready to add embeddings.


In [None]:
# Calculate embeddings for a set of tracks from the database
all_embeddings = process_tracks_from_db(
    start_line_to_process=0,
    num_tracks_to_process=30)

# Total number of encoded chunks across every track and every MIDI file.
total_embeddings = sum(
    len(encoded_chunks)
    for per_file in all_embeddings.values()
    for encoded_chunks in per_file.values()
)

# Flatten the nested {track_id: {filename: [encode() results]}} structure into
# two parallel lists: one embedding vector per chunk, and the owning track_id
# repeated once per chunk, in the same order.
track_ids = []
embedding_matrix = []

for track_id, per_file in all_embeddings.items():
    for encoded_chunks in per_file.values():
        for encoded in encoded_chunks:
            # Each entry is the return value of mel_2bar.encode([chunk]);
            # element [0] is its first output and [0][0] that output's single
            # row, i.e. the embedding vector.
            # NOTE(review): the first output of encode() is presumably the
            # sampled z rather than the mean mu -- confirm that z is intended
            # here, since exact-match de-duplication relies on repeatable vectors.
            track_ids.append(track_id)
            embedding_matrix.append(encoded[0][0])

add_embeddings_with_deduplication(embedding_matrix, track_ids)



In [8]:

# --- Example Usage ---
# Demonstrates a lookup against the FAISS database.

# 1. Simulate an "observed" embedding with a random vector. In a real
#    application this would come from the MIDI processing pipeline.
observed_embedding = np.random.rand(EMBEDDING_DIM)

# 2. Ask for the single nearest track (k=1).
print("--- Searching for the single nearest track ---")
nearest_results = search_nearest_track(observed_embedding, k=1)

if not nearest_results:
    print("Search returned no results.")
else:
    # The result is a list of (track_id, distance); take the best hit.
    track_id, distance = nearest_results[0]
    print(f"The nearest track is: '{track_id}'")
    print(f"Distance: {distance:.4f}")
    


--- Searching for the single nearest track ---
The nearest track is: 'TRAAAGR128F425B14B'
Distance: 585.3926
