In [1]:
print('Importing libraries and loading the trained model')

import note_seq
from note_seq import sequences_lib
from note_seq.protobuf import music_pb2
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import NoExtractedExamplesError
from magenta.models.music_vae.trained_model import MultipleExtractedExamplesError 
import tensorflow as tf

import pandas as pd # For keaping the embeddings in a pandas dataframe
import pyreadr # for conversion to rda files for statistical analysis in R

import numpy as np
import os
import random
import sqlite3
import numpy as np
import faiss  # You'll need to install this: pip install faiss-cpu
import glob
from collections import defaultdict
# for setting up setup_faiss_db.py
import json
import hashlib
from collections import defaultdict

mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']

# Path to your newly created TFLite model
TFLITE_MODEL_PATH = 'models/music_vae_encoder_tf2.tflite'


print(f"\nLoading TFLite model from: {TFLITE_MODEL_PATH}...")
# The TFLite interpreter is independent of TF sessions and graphs.
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()
# Get input and output details for later use
tflite_input_details = interpreter.get_input_details()
tflite_output_details = interpreter.get_output_details()
print("TFLite model loaded successfully.")



Importing libraries and loading the trained model

Loading TFLite model from: models/music_vae_encoder_tf2.tflite...
TFLite model loaded successfully.


In [2]:
# --- Configuration and function definitions  --
DB_PATH = os.path.join('data_sets', 'midi_embeddings.db')
FAISS_INDEX_PATH = os.path.join('data_sets', 'midi_embeddings.index')
MELODY_DIR = os.path.join('data_sets', 'lmd_melodies') # Directory of extracted melodies
EMBEDDING_DIM = 512 # The dimension of your MusicVAE embeddings

def filter_pitch_range(ns, min_pitch=36, max_pitch=84):
    """Removes notes outside the specified MIDI pitch range."""
    valid_notes = [n for n in ns.notes if min_pitch <= n.pitch <= max_pitch]   
    del ns.notes[:]
    ns.notes.extend(valid_notes)   
    return ns

# --- Helper Functions (from previous iterations, still useful for cleaning) ---
def make_monophonic(ns,steps_per_quarter=4):
    """Reduces a NoteSequence to be monophonic by picking the highest note at each step."""
    if not ns.notes:
        return ns
    quantized_ns = sequences_lib.quantize_note_sequence(ns, steps_per_quarter)    
    notes_by_step = {}
    for note in quantized_ns.notes:
        # Use quantized_start_step for already quantized sequences
        if note.quantized_start_step not in notes_by_step:
            notes_by_step[note.quantized_start_step] = []
        notes_by_step[note.quantized_start_step].append(note)
    monophonic_notes = []
    for step in sorted(notes_by_step.keys()):
        notes_at_step = notes_by_step[step]
        # If multiple notes at a step, pick the highest pitch
        highest_note = max(notes_at_step, key=lambda n: n.pitch)
        monophonic_notes.append(highest_note)
    del ns.notes[:]
    ns.notes.extend(monophonic_notes)
    return ns

def snap_chunk_notes_to_grid(unquantized_chunk, steps_per_quarter):
    """
    Creates a new, unquantized NoteSequence with notes snapped to a grid.
    This is the key function. It takes a time-based chunk, finds the ideal
    quantized steps for its notes, and then creates a *new* unquantized
    sequence where the note start/end times correspond perfectly to those steps.
    Args:
      unquantized_chunk: The unquantized NoteSequence chunk.
      steps_per_quarter: The quantization resolution.
    Returns:
      A new, unquantized NoteSequence with grid-aligned note timings.
    """
    # 1. Quantize the chunk to determine the ideal grid steps for each note.
    try:
        quantized_temp_chunk = note_seq.quantize_note_sequence(
            unquantized_chunk, steps_per_quarter)
    except note_seq.BadTimeSignatureError:
        return None # Cannot process this chunk
    qpm = unquantized_chunk.tempos[0].qpm if unquantized_chunk.tempos else 120.0
    seconds_per_quarter = 60.0 / qpm
    # 2. Create a new, empty, unquantized sequence to be the output.
    grid_aligned_ns = music_pb2.NoteSequence()
    grid_aligned_ns.tempos.add().qpm = qpm
    grid_aligned_ns.ticks_per_quarter = unquantized_chunk.ticks_per_quarter
    # 3. For each note in the quantized version, create a new note in our
    #    output sequence with timings calculated from the quantized steps.
    for q_note in quantized_temp_chunk.notes:
        new_note = grid_aligned_ns.notes.add()
        new_note.pitch = q_note.pitch
        new_note.velocity = q_note.velocity
        new_note.instrument = q_note.instrument
        new_note.program = q_note.program
        # Convert quantized steps back into precise seconds
        start_quarters = q_note.quantized_start_step / steps_per_quarter
        end_quarters = q_note.quantized_end_step / steps_per_quarter
        new_note.start_time = start_quarters * seconds_per_quarter
        new_note.end_time = end_quarters * seconds_per_quarter
    # Set the total time of the new sequence.
    total_quarters = quantized_temp_chunk.total_quantized_steps / steps_per_quarter
    grid_aligned_ns.total_time = total_quarters * seconds_per_quarter
    return grid_aligned_ns

def set_program_for_all_notes(note_sequence, program_number=0):
    """
    Resets the instrument program for every note in a NoteSequence.
    Args:
      note_sequence: The note_seq.NoteSequence object to modify.
      program_number: The integer program number to set for all notes.
                      Defaults to 0 (Acoustic Grand Piano).
    Returns:
      The modified NoteSequence.
    """
    for note in note_sequence.notes:
        note.program = program_number
    return note_sequence

def estimate_tempo_from_notes(
    note_sequence: music_pb2.NoteSequence,
    min_bpm: float = 60.0,
    max_bpm: float = 240.0,
    prior_bpm: float = 120.0
) -> float:
    """
    Estimates the tempo of an unquantized NoteSequence by analyzing note onsets.

    Args:
        note_sequence: An unquantized NoteSequence object.
        min_bpm: The minimum plausible tempo to consider.
        max_bpm: The maximum plausible tempo to consider.
        prior_bpm: The tempo to prefer (e.g., 120 BPM). The algorithm will favor
                   candidates closer to this value.

    Returns:
        The estimated tempo in beats per minute (BPM). Returns prior_bpm if
        not enough notes are present to make a guess.
    """
    # 1. Extract unique, sorted note onset times
    onsets = sorted(list(set(note.start_time for note in note_sequence.notes)))

    if len(onsets) < 3:  # Need a reasonable number of notes for a good guess
        print("Warning: Too few notes to reliably estimate tempo. Returning prior.")
        bpm=prior_bpm
        if note_sequence.tempos:
            bpm = note_sequence.tempos[0].qpm
        return bpm

    # 2. Calculate Inter-Onset Intervals (IOIs)
    iois = np.diff(onsets)
    if len(iois) == 0:
        bpm=prior_bpm
        if note_sequence.tempos:
            bpm = note_sequence.tempos[0].qpm
        return bpm

    # 3. Build a histogram of IOIs to find the most common intervals
    # We use a small bin size to capture fine timing details
    hist, bin_edges = np.histogram(iois, bins=np.arange(0, 5, 0.01), density=False)
    
    # Find peaks in the histogram. These are our primary rhythmic intervals.
    # A simple way is to get the top N bins.
    peak_indices = np.argsort(hist)[-10:] # Get indices of 10 strongest peaks
    
    tempo_candidates = defaultdict(float)

    # 4. Generate and score tempo candidates from histogram peaks
    for i in peak_indices:
        if hist[i] < 2: # Ignore insignificant peaks
            continue
            
        # The time (in seconds) corresponding to this peak
        interval = bin_edges[i]
        
        # This interval could be a quarter note, eighth note, etc.
        # Generate hypotheses based on this interval.
        for multiple in [0.25, 0.33, 0.5, 1, 2, 3, 4]:
            potential_beat_duration = interval * multiple
            if potential_beat_duration == 0:
                continue
            
            tempo = 60.0 / potential_beat_duration
            
            if min_bpm <= tempo <= max_bpm:
                # 5. Score the candidate
                # Score part 1: Rhythmic Strength (how strong was the peak?)
                strength_score = hist[i]
                
                # Score part 2: Proximity to prior_bpm (Gaussian score)
                # This gives a high score if tempo is near prior_bpm
                proximity_score = np.exp(-0.5 * ((tempo - prior_bpm) / 20.0)**2)
                
                # Combine scores and add to any existing score for this tempo
                combined_score = strength_score * proximity_score
                tempo_candidates[tempo] += combined_score

    if not tempo_candidates:
        print("Warning: Could not find any valid tempo candidates. Returning prior.")
        bpm=prior_bpm
        if note_sequence.tempos:
            bpm = note_sequence.tempos[0].qpm
        return bpm

    # 6. Return the tempo with the highest score
    best_tempo = max(tempo_candidates, key=tempo_candidates.get)
    return best_tempo

def convert_note_sequence_to_input_tensor(note_sequence: music_pb2.NoteSequence) -> np.ndarray:
    """
    Converts a NoteSequence into the input tensor format expected by the TFLite model.
    Args:
        note_sequence: A NoteSequence object to convert.
    Returns:
        A numpy array of shape (1, sequence_length, feature_dim) suitable for model input.
    """
    data_converter = mel_2bar_config.data_converter
    extracted_tensors = data_converter.to_tensors(note_sequence)

    inputs = []
    

    if not extracted_tensors.inputs:
        raise NoExtractedExamplesError(
            'No examples extracted from NoteSequence: %s' % note_sequence)
    if len(extracted_tensors.inputs) > 1:
        raise MultipleExtractedExamplesError(
            'Multiple (%d) examples extracted from NoteSequence: %s' %
            (len(extracted_tensors.inputs), note_sequence))
    inputs.append(extracted_tensors.inputs[0])
    #controls.append(extracted_tensors.controls[0])
    #lengths.append(extracted_tensors.lengths[0])
    # Stack and reshape to match model input
    input_tensor = np.stack(inputs).astype(np.float32)
    return input_tensor
    

def get_embeddings_for_song(track_id: str) -> dict:
    """
    Finds all MIDI files for a given track_id, generates embeddings for each,
    and returns them in a dictionary.

    Args:
        track_id (str): The ID of the track, e.g., "TRAAAGR128F425B14B".
        root_path (str): The root path of the repository.

    Returns:
        A dictionary where keys are MIDI filenames and values are lists of
        numpy array embeddings generated from that MIDI file.
        Returns an empty dictionary if the folder is not found or contains no MIDI files.
    """
    # 1. Construct the folder path from the track_id
    # e.g., 'data_sets/lmd_melodies/A/A/A/TRAAAGR128F425B14B'
    if len(track_id) < 5:
        print(f"Error: track_id '{track_id}' is too short to build a path.")
        return {}
        
    song_folder_path = os.path.join(
        'data_sets',
        'lmd_melodies',
        track_id[2],
        track_id[3],
        track_id[4],
        track_id
    )

    if not os.path.isdir(song_folder_path):
        print(f"Warning: Directory not found at {song_folder_path}")
        return {}

    # 2. Find all MIDI files in the directory
    midi_filepaths = glob.glob(os.path.join(song_folder_path, '*.mid'))
    midi_filepaths.extend(glob.glob(os.path.join(song_folder_path, '*.midi')))

    if not midi_filepaths:
        print(f"Warning: No MIDI files found in {song_folder_path}")
        return {}

    # 3. Process each MIDI file to generate embeddings
    all_embeddings = {}
    
    

    for midi_path in midi_filepaths:
        filename = os.path.basename(midi_path)
        print(f"Processing {filename}...")
        
        try:
            # Load and quantize the MIDI file
            midi_ns = note_seq.midi_file_to_note_sequence(midi_path)
            print(f"Processing {filename}...")
            cleaned_quantized_list = []
            qpm = estimate_tempo_from_notes(midi_ns)
            seconds_per_quarter = 60.0 / qpm
            steps_per_quarter=mel_2bar_config.data_converter._steps_per_quarter 
            seconds_per_step = seconds_per_quarter / steps_per_quarter
            num_steps_per_chunk = mel_2bar_config.hparams.max_seq_len
            hop_size_in_seconds = num_steps_per_chunk * seconds_per_step # 32 / 4 = 8.0 seconds
            cleaned_ms = make_monophonic(midi_ns)
            
            cleaned_ms = snap_chunk_notes_to_grid(cleaned_ms, steps_per_quarter)
            cleaned_ms = set_program_for_all_notes(cleaned_ms, program_number=0)
            
            

            if cleaned_ms.notes:
                slices = sequences_lib.split_note_sequence(
                    note_sequence=cleaned_ms,
                    hop_size_seconds=hop_size_in_seconds  
                )
                for chunk in slices:
                    cleaned_quantized_list.append(chunk)
            embeddings=[]
            
            for chunk in cleaned_quantized_list:
                try:
                    input_tensor = convert_note_sequence_to_input_tensor(chunk)
                    interpreter.set_tensor(tflite_input_details[0]['index'], input_tensor)
                    interpreter.invoke()
                    tflite_embedding = interpreter.get_tensor(tflite_output_details[0]['index'])
                    embedding = np.array(tflite_embedding[0])
                    embeddings.append(embedding)
                except NoExtractedExamplesError as e:
                    print(f"Skipping chunk, insufficient note data")
                except MultipleExtractedExamplesError as e:
                    print(f"Skipping chunk, multiple examples extracted")    
                continue
            if(len(embeddings)>0):
                all_embeddings[filename] = embeddings
            

        except Exception as e:
            print(f"  -> Could not process file {filename}. Error: {e}")

    return all_embeddings

# --- Example Usage ---


# --- NEW: Main Orchestrator Function ---
def process_tracks_from_db(
    start_line_to_process: int=0,
    num_tracks_to_process: int = 1,
    db_path: str='data_sets/track_metadata.db'
) -> dict:
    """
    Reads track_ids from an SQLite database, generates embeddings for each,
    and returns a nested dictionary of all embeddings.

    Args:
        num_tracks_to_process (int, optional): The number of tracks to process.
                                               If None, processes all tracks. Defaults to 1.
        db_path (str): Path to the track_metadata.db SQLite file.
       

    Returns:
        A dictionary where keys are track_ids and values are the dictionaries
        returned by get_embeddings_for_song.
    """
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return {}

    print(f"Connecting to database: {db_path}")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # 1. Fetch the track_ids from the database
    query = "SELECT track_id FROM songs" # Assuming the table is named 'songs'
    if num_tracks_to_process is not None and num_tracks_to_process > 0:
        num_tracks_to_query=num_tracks_to_process
        query += f" LIMIT {start_line_to_process}, {num_tracks_to_query}"
    
    print("Fetching track_ids...")
    cursor.execute(query)
    # Fetch all rows and flatten the list of tuples [('id1',), ('id2',)] -> ['id1', 'id2']
    track_ids = [row[0] for row in cursor.fetchall()]
    conn.close()
    
    if not track_ids:
        print("No track_ids found in the database.")
        return {}

    total_tracks = len(track_ids)
    print(f"Found {total_tracks} track_ids to process.")

    # 2. Iterate through track_ids and get embeddings for each
    all_track_embeddings = {}
    for i, track_id in enumerate(track_ids):
        print(f"\n--- Processing track {i + 1}/{total_tracks}: {track_id} ---")
        
        song_embeddings = get_embeddings_for_song(track_id)
        
        if song_embeddings:
            all_track_embeddings[track_id] = song_embeddings
            print(f"-> Success: Found and processed {len(song_embeddings)} MIDI file(s) for this track.")
            
    print("\n--- Processing Complete ---")
    return all_track_embeddings











In [29]:

DB_FOLDER = os.path.join("data_sets","faiss")
INDEX_FILE = "faiss.index"
METADATA_FILE ="faiss_metadata.json"
FINGERPRINT_FILE =  "embedding_fingerprints.json"
RDA_FILE = "embeddings.rda"


# -------------------

class FaissManager:
    """
    Manages a FAISS index with a deduplication layer, designed to accept
    simple 1D lists of numbers as embeddings.
    """
    def __init__(self, 
                 dimension: int = 512, 
                 index_file: str = 'faiss.index', 
                 hashes_file: str = 'embeddings_hashes.json', 
                 metadata_file: str = 'metadata.json',
                 data_root: str = 'data_sets/faiss'):
        """
        Initializes the manager.

        Args:
            dimension: The dimension of the embedding vectors (e.g., 512).
            index_file: Path to save/load the FAISS index.
            hashes_file: Path to the JSON file storing unique hashes.
            data_root: Root directory for storing FAISS data files. Defaults to 'data_sets/faiss'.
        """
        self.dimension = dimension
        self.data_root = data_root
        self.index_file = os.path.join(data_root,index_file)
        self.hashes_file = os.path.join(data_root,hashes_file)
        self.metadata_file = os.path.join(data_root,metadata_file)
        
        # Load or initialize the FAISS index

        # Load or initialize the FAISS index
        self.index = self._load_faiss_index()
            
        # Load existing hashes for deduplication
        self.existing_hashes = self._load_json_to_set(self.hashes_file)
        
        # Load the metadata mapping track_ids to FAISS indices
        # Use defaultdict to simplify adding new track_ids
        self.metadata = defaultdict(list, self._load_json_to_dict(self.metadata_file))
        
        self.generate_faiss_index_to_track_id_lookup()


        print(
            f"Manager initialized with {self.index.ntotal} vectors, "
            f"{len(self.existing_hashes)} unique hashes, and "
            f"{len(self.metadata)} tracks in metadata."
        )

    def _load_faiss_index(self) -> faiss.Index:
        """Loads or creates a FAISS index."""
        if os.path.exists(self.index_file):
            print(f"Loading existing FAISS index from {self.index_file}")
            index = faiss.read_index(self.index_file)
            if index.d != self.dimension:
                raise ValueError(f"Index dimension mismatch: loaded index has {index.d}, manager expects {self.dimension}.")
            return index
        else:
            print(f"Creating new FAISS index with dimension {self.dimension}.")
            return faiss.IndexFlatL2(self.dimension)

    def _load_json_to_set(self, file_path: str) -> set:
        """Loads a JSON array from a file into a set."""
        if not os.path.exists(file_path): return set()
        try:
            with open(file_path, 'r') as f: return set(json.load(f))
        except (json.JSONDecodeError, TypeError):
            print(f"Warning: {file_path} is corrupted. Starting fresh.")
            return set()

    def _load_json_to_dict(self, file_path: str) -> dict:
        """Loads a JSON object from a file into a dict."""
        if not os.path.exists(file_path): return {}
        try:
            with open(file_path, 'r') as f: return json.load(f)
        except json.JSONDecodeError:
            print(f"Warning: {file_path} is corrupted. Starting fresh.")
            return {}

    def _generate_hash(self, track_id: str, embedding: list[float]) -> str:
        """Generates a SHA-256 hash for a track_id and embedding."""

         # --- FIX ---
    # Convert the embedding to a standard Python list if it's a NumPy array.
    # The hasattr check is a safe way to detect NumPy-like objects.
        if hasattr(embedding, 'tolist'):
            embedding_list = embedding.tolist()
        else:
            embedding_list = embedding
        # --- END FIX ---

        data_to_hash = {"track_id": track_id, "embedding": embedding_list}
        canonical_string = json.dumps(data_to_hash, sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(canonical_string.encode('utf-8')).hexdigest()

    import numpy as np
# Assume other necessary parts of the class like __init__ are defined elsewhere.

    def add_embedding(self, track_id: str, embedding: list[float]) -> int | None:
        """
        Adds an embedding, updating the FAISS index and metadata if not a duplicate.

        Args:
            track_id: The identifier for the track.
            embedding: A 1D list of floats representing the embedding.

        Returns:
            The new Faiss index (int) if the embedding was added, None otherwise.
        """
        if len(embedding) != self.dimension:
            print(f"Error: Embedding length {len(embedding)} != index dimension {self.dimension}.")
            return None

        new_hash = self._generate_hash(track_id, embedding)
        if new_hash in self.existing_hashes:
            # This is now a silent failure as the batch method will handle reporting.
            # print(f"Duplicate found for (track_id, embedding) pair: '{track_id}'. Skipping.")
            return None
        
        # Get the numerical index for the new vector *before* adding it.
        new_faiss_id = self.index.ntotal
    
        # Add to FAISS index
        vector_batch = np.array([embedding], dtype='float32')
        self.index.add(vector_batch)
    
        # Update hashes
        self.existing_hashes.add(new_hash)
    
        # Update metadata
        # Assuming self.metadata is a defaultdict(list) or similar
        self.metadata[track_id].append(new_faiss_id)
    
        print(f"Added new embedding for track_id '{track_id}' at FAISS index {new_faiss_id}.")
        self.generate_faiss_index_to_track_id_lookup()
    
        # Return the new index ID on success
        return new_faiss_id

    
    
    # --- UPDATED FUNCTION ---
    def add_embeddings(self, track_ids: list[str], embeddings: list[list[float]]) -> tuple[list[str], list[list[float]], list[int]]:
        """
        Adds a batch of embeddings, ensuring no duplicates based on track_id.

        Args:
            track_ids: A list of track identifiers.
            embeddings: A list of 1D embedding lists (e.g., of 512 numbers each).

        Returns:
            A tuple containing three lists:
            - The list of track_ids for new embeddings that were successfully added.
            - The corresponding list of the new embeddings themselves.
            - The list of corresponding Faiss indices for the new embeddings.
        """
        if len(track_ids) != len(embeddings):
            raise ValueError("Input error: The number of track_ids must match the number of embeddings.")

        if not track_ids:
            print("Warning: Called add_embeddings with empty lists.")
            return [], [], []

        added_track_ids = []
        added_embeddings = []
        added_faiss_indices = []

        for track_id, embedding in zip(track_ids, embeddings):
            if len(embedding) != self.dimension:
                print(f"Warning: Skipping embedding for track '{track_id}'. Invalid dimension {len(embedding)}.")
                continue
        
            # This helper method is now expected to return the Faiss index (int) on success, or None on failure.
            new_index_id = self.add_embedding(track_id, embedding)
        
            if new_index_id is not None:
                added_track_ids.append(track_id)
                added_embeddings.append(embedding)
                added_faiss_indices.append(new_index_id)

        # Return the lists of added data, now including the Faiss indices
        return added_track_ids, added_embeddings, added_faiss_indices



    def save(self):
        """Saves the FAISS index, deduplication hashes, and metadata to disk."""
        print("\n--- Saving all data ---")


        # --- FIX: Ensure destination directories exist before writing ---
    # Create a set of unique directory paths to avoid redundant checks
        dir_paths = {
            os.path.dirname(self.index_file),
            os.path.dirname(self.hashes_file),
            os.path.dirname(self.metadata_file)
        }

        for path in dir_paths:
        # An empty path means the file is in the current directory, no need to create.
            if path:
                os.makedirs(path, exist_ok=True)
                print(f"Ensured directory exists: {path}")
    # --- END FIX ---
        
        # 1. Save FAISS index
        faiss.write_index(self.index, self.index_file)
        print(f"FAISS index with {self.index.ntotal} vectors saved to {self.index_file}.")
        
        # 2. Save hashes
        with open(self.hashes_file, 'w') as f:
            json.dump(list(self.existing_hashes), f)
        print(f"{len(self.existing_hashes)} hashes saved to {self.hashes_file}.")
        
        # 3. Save metadata
        with open(self.metadata_file, 'w') as f:
            json.dump(self.metadata, f, indent=2)
        print(f"Metadata for {len(self.metadata)} tracks saved to {self.metadata_file}.")
        
        print("Save complete.")
    
    # Add this code inside your FaissManager's __init__ method, after self.metadata is loaded.

# Create a reverse mapping from FAISS index ID -> track_id for fast lookups.
    def generate_faiss_index_to_track_id_lookup(self):
        self.reverse_metadata = {}
        for track_id, faiss_ids in self.metadata.items():
            for faiss_id in faiss_ids:
                self.reverse_metadata[faiss_id] = track_id

    
    def search_best_match(self, query_embedding: list[float]) -> tuple[str | None, float | None]:
        """
        Searches the FAISS index for the single best match to a query embedding.

        Args:
            query_embedding: A 1D list of floats representing the observed embedding.

        Returns:
            A tuple containing:
            - The track_id of the best match (or None if not found).
            - The L2 distance (similarity score) to the best match (or None if not found).
        """
        if self.index.ntotal == 0:
            print("Warning: Search attempted on an empty index.")
            return None, None

        if len(query_embedding) != self.dimension:
            print(f"Error: Query embedding has dimension {len(query_embedding)}, but index requires {self.dimension}.")
            return None, None

        # FAISS requires a 2D numpy array for queries.
        query_vector = np.array([query_embedding], dtype='float32')

        # Perform the search for the 1 nearest neighbor (k=1).
        # D: distances, I: indices
        distances, indices = self.index.search(query_vector, k=1)

        # The result for the first (and only) query vector is at index 0.
        best_faiss_id = indices[0][0]
        best_distance = distances[0][0]

        # If the index is empty, FAISS can return -1.
        if best_faiss_id == -1:
            return None, None
        
    # Use the reverse metadata map to find the track_id.
        track_id = self.reverse_metadata.get(best_faiss_id)

        if track_id is None:
            # This would indicate an inconsistency between the index and metadata.
            print(f"Warning: FAISS ID {best_faiss_id} found, but not in metadata.")
            return None, None

        return track_id, float(best_distance)
    
#===================================

class FaissRDAManager(FaissManager):
    """
    Extends FaissManager to also manage an .rda file with a specific
    data structure: track_id (factor), faiss_index (numeric), emb... (numeric).
    """
    def __init__(self, rda_path: str, *args, **kwargs):
        """
        Initializes the FaissRDAManager.

        Args:
            rda_path (str): The file path for the .rda file.
            *args, **kwargs: Arguments for the parent FaissManager (e.g., index_path, dimension).
        """
        super().__init__(*args, **kwargs)

        self.rda_path = os.path.join(self.data_root, rda_path)
        
        # Define column names
        self.embedding_columns = [f"emb{i:03d}" for i in range(1, self.dimension + 1)]
        self.all_columns = ['track_id', 'faiss_index'] + self.embedding_columns
        
        # NEW: Define the pandas dtypes to ensure correct R types.
        # 'category' -> R factor
        # 'int32'/'float32' -> R numeric
        self.column_dtypes = {
            'track_id': 'category',
            'faiss_index': 'int32'
        }
        self.column_dtypes.update({col: 'float32' for col in self.embedding_columns})
        
        self.embeddings_df = self._load_rda()
        
        print(f"FaissRDAManager initialized. Tracking {len(self.embeddings_df)} embeddings in '{self.rda_path}'.")

    def _load_rda(self) -> pd.DataFrame:
        """Loads embeddings from the .rda file or creates an empty DataFrame with correct types."""
        df = pd.DataFrame(columns=self.all_columns) # Start with a base empty frame
        if os.path.exists(self.rda_path):
            try:
                print(f"Loading existing embeddings from '{self.rda_path}'...")
                r_objects = pyreadr.read_r(self.rda_path)
                loaded_df = r_objects.get("embeddings_data")
                if loaded_df is not None:
                    df = loaded_df
                    df.columns = self.all_columns # Ensure column names are correct
                else:
                    print("Warning: RDA file exists but contains no 'embeddings_data' object.")
            except Exception as e:
                print(f"Warning: Could not read RDA file at '{self.rda_path}'. Error: {e}.")
        
        # NEW: Enforce the correct dtypes on the loaded or newly created DataFrame.
        # This is the crucial step for ensuring type consistency.
        return df.astype(self.column_dtypes)

    def _update_rda_file(self):
        """Writes the current DataFrame of embeddings to the .rda file."""
        if self.embeddings_df.empty:
            return

        print(f"Updating RDA file at '{self.rda_path}'...")
        # 3. Verify the dtypes before writing (optional but good for debugging)
        print("--- dtypes of the final DataFrame before writing ---")
        print(self.embeddings_df[self.embeddings_df.keys()[4 ]].dtypes)
        print(self.embeddings_df[self.embeddings_df.keys()[4 ]].values[0])
        # --- FIX for float32 not being recognized in pyrdr: Convert all float32 columns to float64 for pyreadr compatibility ---
        for col in self.embeddings_df.select_dtypes(include=['float32']).columns:
            self.embeddings_df[col] = self.embeddings_df[col].astype('float64')

        self.embeddings_df.columns = self.embeddings_df.columns.astype(str)
        pyreadr.write_rdata(self.rda_path, self.embeddings_df, df_name="embeddings_data")
        print("RDA file successfully updated.")

    def add_embeddings(self, track_ids: list[str], embeddings: list[list[float]]) -> tuple[list[str], list[list[float]], list[int]]:
        """
        Adds new embeddings to the FAISS index and updates the .rda file
        with the structured DataFrame.
        """
        newly_added_ids, newly_added_embeddings, newly_added_faiss_indices = super().add_embeddings(track_ids, embeddings)

        if newly_added_embeddings:
            print(f"Detected {len(newly_added_embeddings)} new embeddings to add to RDA file.")
            
            new_data = {
                'track_id': newly_added_ids,
                'faiss_index': newly_added_faiss_indices,
                **dict(zip(self.embedding_columns, np.array(newly_added_embeddings, dtype=np.float32).T))
            }
            new_df = pd.DataFrame(new_data)
            
            # NEW: Enforce the schema on the new data before concatenating.

            new_df = new_df.astype(self.column_dtypes)
            
            # Append to the main DataFrame
            self.embeddings_df = pd.concat([self.embeddings_df, new_df], ignore_index=True).astype(self.column_dtypes)

            # Iterate through the columns you want to convert
            for col, dtype in self.column_dtypes.items():
                if 'float' in str(dtype):
                # Use pd.to_numeric with errors='coerce'.
                # This will turn any non-numeric strings (like '') into NaN.
                    self.embeddings_df[col] = pd.to_numeric(self.embeddings_df[col], errors='coerce')
            
            self._update_rda_file()
        else:
            print("No new non-duplicate embeddings were added. RDA file is already up-to-date.")
        
        return newly_added_ids, newly_added_embeddings, newly_added_faiss_indices








### Instantiate it for further use




#===========================================
# --
manager = FaissRDAManager( rda_path=RDA_FILE,
                           dimension=EMBEDDING_DIM,
                           index_file=INDEX_FILE,
                           hashes_file=FINGERPRINT_FILE,
                           metadata_file=METADATA_FILE,
                           data_root=DB_FOLDER)



    # Add multiple embeddings for the same track


 





Creating new FAISS index with dimension 512.
Manager initialized with 0 vectors, 0 unique hashes, and 0 tracks in metadata.
FaissRDAManager initialized. Tracking 0 embeddings in 'data_sets\faiss\embeddings.rda'.


In [4]:
# Calculate embeddings for a set of tracks from the database
all_embeddings = process_tracks_from_db(
    start_line_to_process=0,
    num_tracks_to_process=50)










Connecting to database: data_sets/track_metadata.db
Fetching track_ids...
Found 50 track_ids to process.

--- Processing track 1/50: TRAAAAK128F9318786 ---

--- Processing track 2/50: TRAAAAV128F421A322 ---

--- Processing track 3/50: TRAAAAW128F429D538 ---

--- Processing track 4/50: TRAAAAY128F42A73F0 ---

--- Processing track 5/50: TRAAABD128F429CF47 ---

--- Processing track 6/50: TRAAACN128F9355673 ---

--- Processing track 7/50: TRAAACV128F423E09E ---

--- Processing track 8/50: TRAAADJ128F4287B47 ---

--- Processing track 9/50: TRAAADT12903CCC339 ---

--- Processing track 10/50: TRAAADZ128F9348C2E ---

--- Processing track 11/50: TRAAAEA128F935A30D ---

--- Processing track 12/50: TRAAAED128E0783FAB ---

--- Processing track 13/50: TRAAAEF128F4273421 ---

--- Processing track 14/50: TRAAAEM128F93347B9 ---

--- Processing track 15/50: TRAAAEW128F42930C0 ---

--- Processing track 16/50: TRAAAFD128F92F423A ---

--- Processing track 17/50: TRAAAFI12903CE4F0E ---

--- Processing trac

In [5]:
total_embeddings = sum(len(array) for track_dict in all_embeddings.values() for array in track_dict.values())

track_ids = []

# Iterate through each track ID and its corresponding inner dictionary
for track_id, inner_dict in all_embeddings.items():
    # Calculate the total number of embeddings for the current track_id
    # This uses the sum() and generator expression method from before
    count = sum(len(array) for array in inner_dict.values())  
    # Extend the main list by adding the track_id 'count' times
    # The expression [track_id] * count creates a new list like ['document_1', 'document_1', ...]
    track_ids.extend([track_id] * count)

# Flatten all numbers into a single list
embedding_matrix = []

# Loop through the top-level dictionary
for inner_dict in all_embeddings.values():
    # Loop through the second-level dictionary
    for list_of_tuples in inner_dict.values():
        # Loop through the list of tuples
        for tpl in list_of_tuples:
            # 1. Access the first element of the tuple (the wrapper)
            embedding_vector = tpl               
            # 3. Add the 512-element vector to our matrix
            embedding_matrix.append(embedding_vector)

print(f"\nTotal embeddings to add to FAISS index: {len(embedding_matrix)}")
print(f"\nLength of each embedding vector: {len(embedding_matrix[0])}")


Total embeddings to add to FAISS index: 150

Length of each embedding vector: 512


In [30]:
manager.add_embeddings(track_ids, embedding_matrix) 
manager.save()

Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 0.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 1.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 2.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 3.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 4.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 5.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 6.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 7.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 8.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 9.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 10.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 11.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAISS index 12.
Added new embedding for track_id 'TRAAAGR128F425B14B' at FAIS

  self.embeddings_df = pd.concat([self.embeddings_df, new_df], ignore_index=True).astype(self.column_dtypes)


Updating RDA file at 'data_sets\faiss\embeddings.rda'...
--- dtypes of the final DataFrame before writing ---
float32
-0.12608287
RDA file successfully updated.

--- Saving all data ---
Ensured directory exists: data_sets\faiss
FAISS index with 88 vectors saved to data_sets\faiss\faiss.index.
88 hashes saved to data_sets\faiss\embedding_fingerprints.json.
Metadata for 1 tracks saved to data_sets\faiss\faiss_metadata.json.
Save complete.


In [7]:
# Example query (with made-up embedding data)
query_vec = [0.1] * manager.dimension

#     # Perform the searc

found_track_id, distance = manager.search_best_match(query_vec)
#
if found_track_id:
    print(f"\nSearch Result:")
    print(f"Best match found is track_id: '{found_track_id}'")
    print(f"L2 Distance: {distance}")
else:
    print("\nNo match found in the index.")



Search Result:
Best match found is track_id: 'TRAAAGR128F425B14B'
L2 Distance: 98.55679321289062
