In [2]:
print('Importing libraries and loading the trained model')
import magenta.music as mm
import note_seq
from note_seq import sequences_lib
from note_seq.protobuf import music_pb2
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
from magenta.models.music_vae.trained_model import NoExtractedExamplesError
from magenta.models.music_vae.trained_model import MultipleExtractedExamplesError
import numpy as np
import os
import tensorflow.compat.v1 as tf
import random
import sqlite3
import numpy as np
import faiss  # You'll need to install this: pip install faiss-cpu
import glob
from collections import defaultdict




tf.disable_v2_behavior()

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)
random.seed(SEED)

mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']

BASE_DIR="models/download.magenta.tensorflow.org/models/music_vae"
mel_2bar = TrainedModel(mel_2bar_config, batch_size=4, checkpoint_dir_or_path=BASE_DIR + '/checkpoints/mel_2bar_big.ckpt')

Importing libraries and loading the trained model


INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, CategoricalLstmDecoder, and hparams:
{'max_seq_len': 32, 'z_size': 512, 'free_bits': 0, 'max_beta': 0.5, 'beta_rate': 0.99999, 'batch_size': 4, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [2048, 2048, 2048], 'enc_rnn_size': [2048], 'dropout_keep_prob': 1.0, 'sampling_schedule': 'inverse_sigmoid', 'sampling_rate': 1000, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}
INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

INFO:tensorflow:Restoring parameters from models/download.magenta.tensorflow.org/models/music_vae/checkpoints/mel_2bar_big.ckpt


In [3]:
# --- Configuration ---
DB_PATH = os.path.join('data_sets', 'midi_embeddings.db')
FAISS_INDEX_PATH = os.path.join('data_sets', 'midi_embeddings.index')
MELODY_DIR = os.path.join('data_sets', 'lmd_melodies') # Directory of extracted melodies
EMBEDDING_DIM = 512 # The dimension of your MusicVAE embeddings



In [4]:
def filter_pitch_range(ns, min_pitch=36, max_pitch=84):
    """Removes notes outside the specified MIDI pitch range."""
    valid_notes = [n for n in ns.notes if min_pitch <= n.pitch <= max_pitch]   
    del ns.notes[:]
    ns.notes.extend(valid_notes)   
    return ns

In [5]:
# --- Helper Functions (from previous iterations, still useful for cleaning) ---
def make_monophonic(ns,steps_per_quarter=4):
    """Reduces a NoteSequence to be monophonic by picking the highest note at each step."""
    if not ns.notes:
        return ns
    quantized_ns = sequences_lib.quantize_note_sequence(ns, steps_per_quarter)    
    notes_by_step = {}
    for note in quantized_ns.notes:
        # Use quantized_start_step for already quantized sequences
        if note.quantized_start_step not in notes_by_step:
            notes_by_step[note.quantized_start_step] = []
        notes_by_step[note.quantized_start_step].append(note)
    monophonic_notes = []
    for step in sorted(notes_by_step.keys()):
        notes_at_step = notes_by_step[step]
        # If multiple notes at a step, pick the highest pitch
        highest_note = max(notes_at_step, key=lambda n: n.pitch)
        monophonic_notes.append(highest_note)
    del ns.notes[:]
    ns.notes.extend(monophonic_notes)
    return ns

In [6]:
def snap_chunk_notes_to_grid(unquantized_chunk, steps_per_quarter):
    """
    Creates a new, unquantized NoteSequence with notes snapped to a grid.
    This is the key function. It takes a time-based chunk, finds the ideal
    quantized steps for its notes, and then creates a *new* unquantized
    sequence where the note start/end times correspond perfectly to those steps.
    Args:
      unquantized_chunk: The unquantized NoteSequence chunk.
      steps_per_quarter: The quantization resolution.
    Returns:
      A new, unquantized NoteSequence with grid-aligned note timings.
    """
    # 1. Quantize the chunk to determine the ideal grid steps for each note.
    try:
        quantized_temp_chunk = note_seq.quantize_note_sequence(
            unquantized_chunk, steps_per_quarter)
    except note_seq.BadTimeSignatureError:
        return None # Cannot process this chunk
    qpm = unquantized_chunk.tempos[0].qpm if unquantized_chunk.tempos else 120.0
    seconds_per_quarter = 60.0 / qpm
    # 2. Create a new, empty, unquantized sequence to be the output.
    grid_aligned_ns = music_pb2.NoteSequence()
    grid_aligned_ns.tempos.add().qpm = qpm
    grid_aligned_ns.ticks_per_quarter = unquantized_chunk.ticks_per_quarter
    # 3. For each note in the quantized version, create a new note in our
    #    output sequence with timings calculated from the quantized steps.
    for q_note in quantized_temp_chunk.notes:
        new_note = grid_aligned_ns.notes.add()
        new_note.pitch = q_note.pitch
        new_note.velocity = q_note.velocity
        new_note.instrument = q_note.instrument
        new_note.program = q_note.program
        # Convert quantized steps back into precise seconds
        start_quarters = q_note.quantized_start_step / steps_per_quarter
        end_quarters = q_note.quantized_end_step / steps_per_quarter
        new_note.start_time = start_quarters * seconds_per_quarter
        new_note.end_time = end_quarters * seconds_per_quarter
    # Set the total time of the new sequence.
    total_quarters = quantized_temp_chunk.total_quantized_steps / steps_per_quarter
    grid_aligned_ns.total_time = total_quarters * seconds_per_quarter
    return grid_aligned_ns


In [7]:
def set_program_for_all_notes(note_sequence, program_number=0):
    """
    Resets the instrument program for every note in a NoteSequence.
    Args:
      note_sequence: The note_seq.NoteSequence object to modify.
      program_number: The integer program number to set for all notes.
                      Defaults to 0 (Acoustic Grand Piano).
    Returns:
      The modified NoteSequence.
    """
    for note in note_sequence.notes:
        note.program = program_number
    return note_sequence


In [8]:
def estimate_tempo_from_notes(
    note_sequence: music_pb2.NoteSequence,
    min_bpm: float = 60.0,
    max_bpm: float = 240.0,
    prior_bpm: float = 120.0
) -> float:
    """
    Estimates the tempo of an unquantized NoteSequence by analyzing note onsets.

    Args:
        note_sequence: An unquantized NoteSequence object.
        min_bpm: The minimum plausible tempo to consider.
        max_bpm: The maximum plausible tempo to consider.
        prior_bpm: The tempo to prefer (e.g., 120 BPM). The algorithm will favor
                   candidates closer to this value.

    Returns:
        The estimated tempo in beats per minute (BPM). Returns prior_bpm if
        not enough notes are present to make a guess.
    """
    # 1. Extract unique, sorted note onset times
    onsets = sorted(list(set(note.start_time for note in note_sequence.notes)))

    if len(onsets) < 3:  # Need a reasonable number of notes for a good guess
        print("Warning: Too few notes to reliably estimate tempo. Returning prior.")
        bpm=prior_bpm
        if note_sequence.tempos:
            bpm = note_sequence.tempos[0].qpm
        return bpm

    # 2. Calculate Inter-Onset Intervals (IOIs)
    iois = np.diff(onsets)
    if len(iois) == 0:
        bpm=prior_bpm
        if note_sequence.tempos:
            bpm = note_sequence.tempos[0].qpm
        return bpm

    # 3. Build a histogram of IOIs to find the most common intervals
    # We use a small bin size to capture fine timing details
    hist, bin_edges = np.histogram(iois, bins=np.arange(0, 5, 0.01), density=False)
    
    # Find peaks in the histogram. These are our primary rhythmic intervals.
    # A simple way is to get the top N bins.
    peak_indices = np.argsort(hist)[-10:] # Get indices of 10 strongest peaks
    
    tempo_candidates = defaultdict(float)

    # 4. Generate and score tempo candidates from histogram peaks
    for i in peak_indices:
        if hist[i] < 2: # Ignore insignificant peaks
            continue
            
        # The time (in seconds) corresponding to this peak
        interval = bin_edges[i]
        
        # This interval could be a quarter note, eighth note, etc.
        # Generate hypotheses based on this interval.
        for multiple in [0.25, 0.33, 0.5, 1, 2, 3, 4]:
            potential_beat_duration = interval * multiple
            if potential_beat_duration == 0:
                continue
            
            tempo = 60.0 / potential_beat_duration
            
            if min_bpm <= tempo <= max_bpm:
                # 5. Score the candidate
                # Score part 1: Rhythmic Strength (how strong was the peak?)
                strength_score = hist[i]
                
                # Score part 2: Proximity to prior_bpm (Gaussian score)
                # This gives a high score if tempo is near prior_bpm
                proximity_score = np.exp(-0.5 * ((tempo - prior_bpm) / 20.0)**2)
                
                # Combine scores and add to any existing score for this tempo
                combined_score = strength_score * proximity_score
                tempo_candidates[tempo] += combined_score

    if not tempo_candidates:
        print("Warning: Could not find any valid tempo candidates. Returning prior.")
        bpm=prior_bpm
        if note_sequence.tempos:
            bpm = note_sequence.tempos[0].qpm
        return bpm

    # 6. Return the tempo with the highest score
    best_tempo = max(tempo_candidates, key=tempo_candidates.get)
    return best_tempo



In [None]:
def get_embeddings_for_song(track_id: str) -> dict:
    """
    Finds all MIDI files for a given track_id, generates embeddings for each,
    and returns them in a dictionary.

    Args:
        track_id (str): The ID of the track, e.g., "TRAAAGR128F425B14B".
        root_path (str): The root path of the repository.

    Returns:
        A dictionary where keys are MIDI filenames and values are lists of
        numpy array embeddings generated from that MIDI file.
        Returns an empty dictionary if the folder is not found or contains no MIDI files.
    """
    # 1. Construct the folder path from the track_id
    # e.g., 'data_sets/lmd_melodies/A/A/A/TRAAAGR128F425B14B'
    if len(track_id) < 5:
        print(f"Error: track_id '{track_id}' is too short to build a path.")
        return {}
        
    song_folder_path = os.path.join(
        'data_sets',
        'lmd_melodies',
        track_id[2],
        track_id[3],
        track_id[4],
        track_id
    )

    if not os.path.isdir(song_folder_path):
        print(f"Warning: Directory not found at {song_folder_path}")
        return {}

    # 2. Find all MIDI files in the directory
    midi_filepaths = glob.glob(os.path.join(song_folder_path, '*.mid'))
    midi_filepaths.extend(glob.glob(os.path.join(song_folder_path, '*.midi')))

    if not midi_filepaths:
        print(f"Warning: No MIDI files found in {song_folder_path}")
        return {}

    # 3. Process each MIDI file to generate embeddings
    all_embeddings = {}
    
    

    for midi_path in midi_filepaths:
        filename = os.path.basename(midi_path)
        print(f"Processing {filename}...")
        
        try:
            # Load and quantize the MIDI file
            midi_ns = note_seq.midi_file_to_note_sequence(midi_path)
            print(f"Processing {filename}...")
            cleaned_quantized_list = []
            qpm = estimate_tempo_from_notes(midi_ns)
            seconds_per_quarter = 60.0 / qpm
            steps_per_quarter=mel_2bar_config.data_converter._steps_per_quarter 
            seconds_per_step = seconds_per_quarter / steps_per_quarter
            num_steps_per_chunk = mel_2bar_config.hparams.max_seq_len
            hop_size_in_seconds = num_steps_per_chunk * seconds_per_step # 32 / 4 = 8.0 seconds
            cleaned_ms = make_monophonic(midi_ns)
            
            cleaned_ms = snap_chunk_notes_to_grid(cleaned_ms, steps_per_quarter)
            cleaned_ms = set_program_for_all_notes(cleaned_ms, program_number=0)
            
            

            if cleaned_ms.notes:
                slices = sequences_lib.split_note_sequence(
                    note_sequence=cleaned_ms,
                    hop_size_seconds=hop_size_in_seconds  
                )
                for chunk in slices:
                    cleaned_quantized_list.append(chunk)
            embeddings=[]
            
            for chunk in cleaned_quantized_list:
                try:
                    embedding = mel_2bar.encode([chunk])
                    embeddings.append(embedding)
                except NoExtractedExamplesError as e:
                    print(f"Skipping chunk, insufficient note data")
                except MultipleExtractedExamplesError as e:
                    print(f"Skipping chunk, multiple examples extracted")    
                continue
            if(len(embeddings)>0):
                all_embeddings[filename] = embeddings
            

        except Exception as e:
            print(f"  -> Could not process file {filename}. Error: {e}")

    return all_embeddings

# --- Example Usage ---


In [None]:
track_id = "TRAAAGR128F425B14B"
embeddings_dict = get_embeddings_for_song(track_id)
print(f"Generated embeddings for track {track_id}:")
for midi_file, embeddings in embeddings_dict.items():
    print(f"  {midi_file}: {len(embeddings)} embeddings")
    

Processing 1d9d16a9da90c090809c153754823c2b.mid...
Processing 1d9d16a9da90c090809c153754823c2b.mid...
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, insufficient note data
Skipping chunk, multiple examples extracted
Skipping chunk, multiple examples extracted
Skipping chunk, multiple exam

In [20]:



# --- NEW: Main Orchestrator Function ---
def process_tracks_from_db(
    start_line_to_process: int=0,
    num_tracks_to_process: int = 1,
    db_path: str='data_sets/track_metadata.db'
) -> dict:
    """
    Reads track_ids from an SQLite database, generates embeddings for each,
    and returns a nested dictionary of all embeddings.

    Args:
        num_tracks_to_process (int, optional): The number of tracks to process.
                                               If None, processes all tracks. Defaults to 1.
        db_path (str): Path to the track_metadata.db SQLite file.
       

    Returns:
        A dictionary where keys are track_ids and values are the dictionaries
        returned by get_embeddings_for_song.
    """
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return {}

    print(f"Connecting to database: {db_path}")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # 1. Fetch the track_ids from the database
    query = "SELECT track_id FROM songs" # Assuming the table is named 'songs'
    if num_tracks_to_process is not None and num_tracks_to_process > 0:
        num_tracks_to_query=num_tracks_to_process
        query += f" LIMIT {start_line_to_process}, {num_tracks_to_query}"
    
    print("Fetching track_ids...")
    cursor.execute(query)
    # Fetch all rows and flatten the list of tuples [('id1',), ('id2',)] -> ['id1', 'id2']
    track_ids = [row[0] for row in cursor.fetchall()]
    conn.close()
    
    if not track_ids:
        print("No track_ids found in the database.")
        return {}

    total_tracks = len(track_ids)
    print(f"Found {total_tracks} track_ids to process.")

    # 2. Iterate through track_ids and get embeddings for each
    all_track_embeddings = {}
    for i, track_id in enumerate(track_ids):
        print(f"\n--- Processing track {i + 1}/{total_tracks}: {track_id} ---")
        
        song_embeddings = get_embeddings_for_song(track_id)
        
        if song_embeddings:
            all_track_embeddings[track_id] = song_embeddings
            print(f"-> Success: Found and processed {len(song_embeddings)} MIDI file(s) for this track.")
            
    print("\n--- Processing Complete ---")
    return all_track_embeddings




    



In [None]:


# Example usage of the orchestrator function
all_embeddings = process_tracks_from_db(
    start_line_to_process=0,
    num_tracks_to_process=100)


Connecting to database: data_sets/track_metadata.db
Fetching track_ids...
Found 100 track_ids to process.

--- Processing track 1/100: TRAAAAK128F9318786 ---

--- Processing track 2/100: TRAAAAV128F421A322 ---

--- Processing track 3/100: TRAAAAW128F429D538 ---

--- Processing track 4/100: TRAAAAY128F42A73F0 ---

--- Processing track 5/100: TRAAABD128F429CF47 ---

--- Processing track 6/100: TRAAACN128F9355673 ---

--- Processing track 7/100: TRAAACV128F423E09E ---

--- Processing track 8/100: TRAAADJ128F4287B47 ---

--- Processing track 9/100: TRAAADT12903CCC339 ---

--- Processing track 10/100: TRAAADZ128F9348C2E ---

--- Processing track 11/100: TRAAAEA128F935A30D ---

--- Processing track 12/100: TRAAAED128E0783FAB ---

--- Processing track 13/100: TRAAAEF128F4273421 ---

--- Processing track 14/100: TRAAAEM128F93347B9 ---

--- Processing track 15/100: TRAAAEW128F42930C0 ---

--- Processing track 16/100: TRAAAFD128F92F423A ---

--- Processing track 17/100: TRAAAFI12903CE4F0E ---

-