# Similarity Search Demo
This notebook demonstrates similarity search over audio embeddings in two modes: audio-to-audio search and text-to-audio search.

### Load the model and tokenizer

In [1]:
import torch

from src.modules.clap_model import CLAPModel
from transformers import RobertaTokenizer

# Load the pretrained CLAP checkpoint and the RoBERTa tokenizer that is later
# used to tokenize the text queries for text-to-audio search.
model = CLAPModel.from_pretrained("yuhuacheng/clap-musicgen-1sec")
tokenizer = RobertaTokenizer.from_pretrained("yuhuacheng/clap-roberta-finetuned")

def get_device():
    """Pick the best available torch device and report the choice.

    Preference order is CUDA, then Apple's MPS backend, then CPU.

    Returns:
        str: one of "cuda", "mps" or "cpu".
    """
    # `torch.backends.mps` does not exist on older torch builds, so look it up
    # defensively instead of assuming the attribute is there.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"using device: {device}")
    return device

# Choose the compute device once; later cells reuse `device` when moving inputs.
# The trailing ';' on the last line keeps the cell from echoing the model repr.
device = get_device()
model.to(device);

  from .autonotebook import tqdm as notebook_tqdm
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)
Some weights of RobertaModel were not initialized from the model checkpoint at model/roberta_finetuned and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


using evice:  mps


### Load the eval data

In [2]:
import pandas as pd

# Read the split manifest and keep only the rows labelled as the evaluation split.
data = pd.read_csv("data/train_10000_split.csv")
# Optional: swap in the line below (and comment out the one after it) to run on a
# reproducible 100-row random subsample of the eval split instead of all of it.
# eval_data = data[data['split'] == 'eval'].sample(100, random_state=42)
eval_data = data[data['split'] == 'eval']

# Print the (rows, columns) of the eval subset, then preview its first rows.
print(eval_data.shape)
eval_data.head()


(1049, 14)


Unnamed: 0,id,video_url,audio_url,image_url,major_model_version,model_name,tags,prompt,lyrics,is_en,genres,top_genres,caption,split
24,0039ab46-7ed0-4f0d-ab4f-29188497cc2c,https://cdn1.suno.ai/0039ab46-7ed0-4f0d-ab4f-2...,https://cdn1.suno.ai/0039ab46-7ed0-4f0d-ab4f-2...,https://cdn2.suno.ai/image_0039ab46-7ed0-4f0d-...,v3.5,chirp-v3,electric fast-paced rock,a background song for car chasing scene,[Verse]\nRev the engines hear the roar\nMetal ...,True,['rock'],['rock'],The music with styles or genres of electric fa...,eval
33,0047cdda-d7fe-4b5c-b675-63f9caa1e7da,https://cdn1.suno.ai/0047cdda-d7fe-4b5c-b675-6...,https://cdn1.suno.ai/0047cdda-d7fe-4b5c-b675-6...,https://cdn2.suno.ai/image_0047cdda-d7fe-4b5c-...,v3.5,chirp-v3,pop hip hop female voice,"Hip hop, pop song about liking apples, female ...",[Verse]\nGot a shiny red delight\nIn my hand i...,True,['pop'],['pop'],The music with styles or genres of pop hip hop...,eval
53,006d173d-dba9-4811-aaad-caa3662d48a3,https://cdn1.suno.ai/006d173d-dba9-4811-aaad-c...,https://cdn1.suno.ai/006d173d-dba9-4811-aaad-c...,https://cdn2.suno.ai/image_006d173d-dba9-4811-...,v3.5,chirp-v3,collaborative hip-hop,I want this song to have rhyme and a rap that ...,"[Verse]\nStep into the ring, mics swing, hands...",True,['hip-hop'],['hip-hop/rap'],The music with styles or genres of collaborati...,eval
60,00831752-ac48-4201-afc8-d15d2ab3b2fc,https://cdn1.suno.ai/00831752-ac48-4201-afc8-d...,https://cdn1.suno.ai/00831752-ac48-4201-afc8-d...,https://cdn2.suno.ai/image_00831752-ac48-4201-...,v3.5,chirp-v3,melodic acoustic country,a country song about tommy pace having finger ...,[Verse]\nIn a town where the shadows loom long...,True,['country'],['country'],The music with styles or genres of melodic aco...,eval
71,00986d4e-7e3d-408a-9d15-13c3dcf33e13,https://cdn1.suno.ai/00986d4e-7e3d-408a-9d15-1...,https://cdn1.suno.ai/00986d4e-7e3d-408a-9d15-1...,https://cdn2.suno.ai/image_00986d4e-7e3d-408a-...,v3.5,chirp-v3,smooth electronic edm,A smooth edm song about a cozy rainy day,[Verse]\nGlistening on the window pane\nRaindr...,True,['electronic'],['electronic'],The music with styles or genres of smooth elec...,eval


### Load the waveforms and produce a dataset list

In [3]:
from src.utils import parallel_download
from src.utils import parallel_load_audio
from src.preprocessor import AudioPreprocessor

# Local directory for the audio files and the worker count shared by both
# parallel helpers below.
audio_dir = 'data/audios'
max_workers = 4
# Fetch the audio for the eval rows into `audio_dir`.
# NOTE(review): presumably files already on disk are skipped — confirm in src.utils.
parallel_download(eval_data, audio_dir, max_workers=max_workers)

# Preprocessing configuration applied to every loaded track.
# NOTE(review): start_sec / chunk_duration / sec_to_sample look like they select a
# 1-second chunk starting 10 s into each track — confirm against AudioPreprocessor.
ap = AudioPreprocessor(
    resample_rate=32000, # sample rate configured for the pretrained EnCodec from MusicGen model
    to_mono=True,
    sec_to_sample=1,
    start_sec=10,
    chunk_duration=1,
)

# Load and preprocess the audio for the eval rows. The keyword is named
# `train_data`, but the eval subset is what gets passed here.
all_dataset_list = parallel_load_audio(
    train_data=eval_data,
    audio_dir=audio_dir,
    ap=ap,
    max_workers=max_workers
)

# Report how many sample objects were produced.
print(f"Total TrainingSample items: {len(all_dataset_list)}")

100%|██████████| 1049/1049 [00:00<00:00, 141696.72it/s]


All downloads complete!


100%|██████████| 1049/1049 [00:35<00:00, 29.94it/s]

Total TrainingSample items: 1049





###

## Similarity Search

In [4]:
import html

# `IPython.display` is the supported public import location; importing `display`
# from `IPython.core.display` is deprecated and emits a DeprecationWarning.
from IPython.display import display, HTML


def _suno_iframe(track_id):
    """Return the HTML for one embedded Suno player showing `track_id`."""
    # Escape the id so an unexpected character cannot break out of the src attribute.
    safe_id = html.escape(str(track_id), quote=True)
    return f'<iframe src="https://suno.com/embed/{safe_id}" width="400" height="200"></iframe>'


def _build_results_table(source_cell_html, track_ids, source_cell_attrs=""):
    """Build a two-row HTML table: a header row and a row of embedded players.

    Args:
        source_cell_html: trusted/escaped HTML placed in the "Source" cell.
        track_ids: ids rendered as "Top 1", "Top 2", ... player columns.
        source_cell_attrs: extra attribute text (with leading space) applied to
            both the "Source" header cell and the source data cell.

    Returns:
        str: the complete ``<table>`` markup.
    """
    track_ids = list(track_ids)
    header_cells = "".join(f"<th>Top {rank}</th>" for rank in range(1, len(track_ids) + 1))
    player_cells = "".join(f"<td>{_suno_iframe(track_id)}</td>" for track_id in track_ids)
    return (
        '<table border="1" cellspacing="5" cellpadding="5">\n'
        f"<tr><th{source_cell_attrs}>Source</th>{header_cells}</tr>\n"
        f"<tr><td{source_cell_attrs}>{source_cell_html}</td>{player_cells}</tr>\n"
        "</table>"
    )


def generate_iframe_table(source_id, derived_ids):
    """Display a table with the source track's player followed by its matches.

    Args:
        source_id: Suno track id shown in the "Source" column.
        derived_ids: Suno track ids shown, in order, as "Top 1..N".
    """
    display(HTML(_build_results_table(_suno_iframe(source_id), derived_ids)))


def generate_iframe_table_with_source_tag(source_tag, track_ids):
    """Display a table with a text query followed by its matching track players.

    Args:
        source_tag: free-text query shown (bold) in the "Source" column.
        track_ids: Suno track ids shown, in order, as "Top 1..N".
    """
    # The tag is user-typed text, so escape it before embedding it in markup
    # (e.g. a caption containing '&' or '<' must render literally).
    label_html = f"<strong>{html.escape(str(source_tag))}</strong>"
    display(HTML(_build_results_table(label_html, track_ids, ' style="width: 200px;"')))


  from IPython.core.display import display, HTML


### 🎵 **Audio-to-Audio Search**

In [5]:
import torch
import torch.nn.functional as F

from typing import List, Dict

from src.utils import TrainingSample

def compute_audio_embeddings(dataset_list: List[TrainingSample], model: CLAPModel, device: str, batch_size: int = 8) -> Dict[str, torch.Tensor]:
    """Encode every sample's waveform and collect the embeddings per track id.

    Args:
        dataset_list: samples exposing `.id` and `.waveform` (a tensor; all
            waveforms of one track must share a shape so they can be stacked).
        model: model exposing `eval()` and `audio_encoder(ids, waveforms)`.
        device: device the stacked waveforms are moved to before encoding.
        batch_size: maximum number of waveforms passed to the encoder per call.

    Returns:
        Dict mapping each track id (first-seen order) to the concatenation,
        along dim 0, of the encoder outputs for that track's waveforms.
    """
    model.eval()

    # Bucket the waveforms by the track they belong to.
    chunks_by_track = {}
    for sample in dataset_list:
        chunks_by_track.setdefault(sample.id, []).append(sample.waveform)

    embeddings_by_track = {}
    for tid, chunks in chunks_by_track.items():
        # New leading batch dimension: one row per waveform of this track.
        stacked = torch.stack(chunks, dim=0).to(device)

        encoded_parts = []
        with torch.no_grad():
            for start in range(0, stacked.size(0), batch_size):
                mini_batch = stacked[start:start + batch_size]
                # The encoder receives one id per row, so repeat the track id.
                repeated_ids = [tid] * mini_batch.size(0)
                encoded_parts.append(model.audio_encoder(repeated_ids, mini_batch))

        embeddings_by_track[tid] = torch.cat(encoded_parts, dim=0)

    return embeddings_by_track


def compute_top_k_similar_tracks(audio_embeddings: Dict[str, torch.Tensor], top_k: int = 5) -> Dict[str, List[str]]:
    """Find, for every track, the other tracks with the most similar embedding.

    Each track is represented by the mean of its embedding rows (dim 0); tracks
    are compared with cosine similarity.

    Args:
        audio_embeddings: mapping of track id to a 2-D tensor of embeddings
            (one row per chunk of that track).
        top_k: number of neighbours wanted per track. It is capped at the
            number of *other* tracks available.

    Returns:
        Dict mapping each track id to its neighbours' ids, most similar first.
        A track never appears in its own neighbour list.
    """
    track_ids = list(audio_embeddings.keys())
    if not track_ids:
        return {}

    # A track can have at most (n - 1) neighbours; asking topk for more raises.
    k = max(0, min(top_k, len(track_ids) - 1))
    if k == 0:
        return {track_id: [] for track_id in track_ids}

    # Compute mean embedding per track, then L2-normalize so that the dot
    # product below is the cosine similarity.
    embeddings = torch.stack([audio_embeddings[tid].mean(dim=0) for tid in track_ids])
    embeddings = F.normalize(embeddings, p=2, dim=1)
    similarity_matrix = torch.mm(embeddings, embeddings.T)

    # Exclude self-similarity explicitly. Dropping "rank 0" is not reliable:
    # when another track ties with the query (e.g. duplicate audio), the query
    # itself may not be ranked first and would leak into its own results.
    self_mask = torch.eye(len(track_ids), dtype=torch.bool, device=similarity_matrix.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

    neighbour_rows = similarity_matrix.topk(k, dim=1).indices.tolist()
    return {
        track_id: [track_ids[j] for j in row]
        for track_id, row in zip(track_ids, neighbour_rows)
    }


# Embed all loaded samples once, then precompute the 3 most similar tracks for
# every track; the example cells below only look results up in this dict.
# NOTE: the "tok_k" spelling (sic, "top_k") is kept because later cells use this name.
audio_embeddings = compute_audio_embeddings(all_dataset_list, model, device, batch_size=16)
tok_k_audio_to_audio = compute_top_k_similar_tracks(audio_embeddings, top_k=3)

### Example 1 - Source Tags: "melodic acoustic country"

In [6]:
# Named `source_track_id` rather than `id` so the builtin id() is not shadowed
# in the notebook's global namespace.
source_track_id = "00831752-ac48-4201-afc8-d15d2ab3b2fc"

# Render the source track next to its precomputed nearest neighbours.
generate_iframe_table(source_track_id, tok_k_audio_to_audio[source_track_id])

Source,Top 1,Top 2,Top 3
,,,


### Example 2 - Source Tags: "soft mellow lofi jazz hop"

In [7]:
# Named `source_track_id` rather than `id` so the builtin id() is not shadowed
# in the notebook's global namespace.
source_track_id = "0a4b6d10-2c82-4937-a31a-8c9158236dad"

# Render the source track next to its precomputed nearest neighbours.
generate_iframe_table(source_track_id, tok_k_audio_to_audio[source_track_id])

Source,Top 1,Top 2,Top 3
,,,


In [8]:
# Named `source_track_id` rather than `id` so the builtin id() is not shadowed
# in the notebook's global namespace.
source_track_id = "09f12dd0-fa1d-4efa-ba89-25af6009f08c"

# Render the source track next to its precomputed nearest neighbours.
generate_iframe_table(source_track_id, tok_k_audio_to_audio[source_track_id])

Source,Top 1,Top 2,Top 3
,,,


### 💬 **Text-to-Audio Search**

In [9]:
# ==== looking for similar audio from text embeddings ====
# Free-text queries to search with; each one is embedded by the text encoder below.
sample_captions = [
    'positive jazz',
    'chill house',
    'gangsta rap',
    'dark metal'
    # try it with your own captions!
]

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    # Tokenize all captions as one padded batch of PyTorch tensors.
    tokenized_captions = tokenizer(list(sample_captions), return_tensors="pt", padding=True, truncation=True)
    # Move every tokenizer output tensor onto the same device as the model.
    tokenized_captions = {k: v.to(device) for k, v in tokenized_captions.items()}
    # NOTE(review): `ids=None` presumably means "no track ids for ad-hoc text" — confirm
    # against CLAPModel.text_encoder.
    sample_text_embs = model.text_encoder(ids=None, **tokenized_captions)

In [10]:
import torch

def find_top_k_similar_audio(text_embeddings, text_list, audio_dict, k=5):
    """Return, for each text query, the ids of the k most similar audio tracks.

    Each track is represented by the mean of its embedding rows, matching how
    the audio-to-audio search pools chunks, and compared to the text embeddings
    with cosine similarity.

    Args:
        text_embeddings: tensor of shape (B, D), one row per entry of `text_list`.
        text_list: the B query strings, used as keys of the result.
        audio_dict: mapping of track id to its embedding tensor, shape
            (num_chunks, D) (a single (D,) vector is also accepted).
        k: number of tracks to return per query; capped at the number of tracks.

    Returns:
        Dict mapping each query string to a list of track ids, best match first.
    """
    audio_ids = list(audio_dict.keys())
    if not audio_ids:
        return {text: [] for text in text_list}

    # Build exactly one row per track id. Concatenating the raw chunk rows would
    # make row indices diverge from `audio_ids` as soon as any track has more
    # than one chunk, mapping matches to the wrong id (or past the end).
    audio_embeddings = torch.stack([
        audio_dict[a_id].reshape(-1, audio_dict[a_id].shape[-1]).mean(dim=0)
        for a_id in audio_ids
    ])  # Shape (num_audio, D)

    # Normalize text and audio embeddings to unit vectors
    text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=1)  # (B, D)
    audio_embeddings = torch.nn.functional.normalize(audio_embeddings, p=2, dim=1)  # (num_audio, D)

    # Compute cosine similarity: (B, D) @ (D, num_audio) -> (B, num_audio)
    similarity_matrix = torch.matmul(text_embeddings, audio_embeddings.T)

    # topk raises when asked for more entries than there are candidates.
    k = max(0, min(k, len(audio_ids)))
    top_k_indices = torch.topk(similarity_matrix, k, dim=1).indices.tolist()  # B lists of k indices

    # Map results to dictionary
    return {
        text: [audio_ids[idx] for idx in top_k_indices[i]]
        for i, text in enumerate(text_list)
    }

# Look up the 3 best-matching tracks for every caption, then render one table per
# caption with the caption text in the "Source" column.
# NOTE: the "tok_k" spelling (sic, "top_k") mirrors the naming used earlier in the notebook.
tok_k_text_to_audio = find_top_k_similar_audio(sample_text_embs, sample_captions, audio_embeddings, 3)
for source_tags in sample_captions:
    generate_iframe_table_with_source_tag(source_tags, tok_k_text_to_audio[source_tags])

Source,Top 1,Top 2,Top 3
positive jazz,,,


Source,Top 1,Top 2,Top 3
chill house,,,


Source,Top 1,Top 2,Top 3
gangsta rap,,,


Source,Top 1,Top 2,Top 3
dark metal,,,
