<a href="https://colab.research.google.com/github/samipn/clustering_demos/blob/main/audio_clustering_imagebind.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment (i): Audio Clustering with ImageBind Embeddings

This notebook uses ImageBind to extract audio embeddings, clusters them with K-Means, and evaluates clustering quality.

> **Note:** You must provide audio files in `/content/audio` (or adjust the path) when running in Colab.


In [8]:
!pip install --quiet git+https://github.com/facebookresearch/ImageBind.git
!pip install --quiet timm einops soundfile librosa torchcodec

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import torch
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.data import load_and_transform_audio_data

device = "cuda" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


Downloading imagebind weights to .checkpoints/imagebind_huge.pth ...


100%|██████████| 4.47G/4.47G [00:17<00:00, 282MB/s]


ImageBindModel(
  (modality_preprocessors): ModuleDict(
    (vision): RGBDTPreprocessor(
      (cls_token): tensor((1, 1, 1280), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Sequential(
          (0): PadIm2Video()
          (1): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
        )
      )
      (pos_embedding_helper): SpatioTemporalPosEmbeddingHelper(
        (pos_embed): tensor((1, 257, 1280), requires_grad=True)
        
      )
    )
    (text): TextPreprocessor(
      (pos_embed): tensor((1, 77, 1024), requires_grad=True)
      (mask): tensor((77, 77), requires_grad=False)
      
      (token_embedding): Embedding(49408, 1024)
    )
    (audio): AudioPreprocessor(
      (cls_token): tensor((1, 1, 768), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10), bias=False)
        (norm_layer): LayerNorm((768,), eps=1e-05, elementwise_affine=

In [4]:
# Load audio files from folder and compute embeddings
audio_folder = "/content/audio"  # TODO: put your audio files here (or mount Google Drive)

# Create the audio folder if it doesn't exist
if not os.path.exists(audio_folder):
    os.makedirs(audio_folder)
    print(f"Created directory: {audio_folder}. Please upload your audio files here.")

audio_paths = [
    os.path.join(audio_folder, f)
    for f in os.listdir(audio_folder)
    if f.lower().endswith(('.wav', '.mp3', '.flac'))
]

print("Found audio files:")
for p in audio_paths:
    print(p)

# If no audio files are found, provide a message and exit gracefully
if not audio_paths:
    print(f"No audio files found in {audio_folder}. Please upload your audio files (.wav, .mp3, .flac) to this directory.")
else:
    audio_inputs = load_and_transform_audio_data(audio_paths, device=device)

    with torch.no_grad():
        embeddings_dict = model({ModalityType.AUDIO: audio_inputs})
    audio_embeddings = embeddings_dict[ModalityType.AUDIO].cpu().numpy()
    print("Audio embeddings shape:", audio_embeddings.shape)


Created directory: /content/audio. Please upload your audio files here.
Found audio files:
No audio files found in /content/audio. Please upload your audio files (.wav, .mp3, .flac) to this directory.


In [9]:
import requests
import os

audios_to_download = {
    'sample_audio_1.mp3': 'https://file-examples.com/storage/fe39414f4963503b17c7625/2017/11/file_example_MP3_700KB.mp3',
    'sample_audio_2.mp3': 'https://file-examples.com/storage/fe39414f4963503b17c7625/2017/11/file_example_MP3_1MG.mp3'
}

for filename, url in audios_to_download.items():
    filepath = os.path.join(audio_folder, filename)
    if not os.path.exists(filepath):
        print(f"Downloading {filename}...")
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status() # Raise an exception for bad status codes
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Downloaded {filename} to {filepath}")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {filename}: {e}")
    else:
        print(f"{filename} already exists at {filepath}")

# Re-run the audio file detection after downloading
audio_paths = [
    os.path.join(audio_folder, f)
    for f in os.listdir(audio_folder)
    if f.lower().endswith(('.wav', '.mp3', '.flac'))
]

print("Found audio files after download attempt:")
for p in audio_paths:
    print(p)

# Proceed with embeddings if files are found
if audio_paths:
    audio_inputs = load_and_transform_audio_data(audio_paths, device=device)

    with torch.no_grad():
        embeddings_dict = model({ModalityType.AUDIO: audio_inputs})
    audio_embeddings = embeddings_dict[ModalityType.AUDIO].cpu().numpy()
    print("Audio embeddings shape:", audio_embeddings.shape)
else:
    print("No audio files found even after download attempt. Please check the URLs or upload manually.")

Downloading sample_audio_1.mp3...
Error downloading sample_audio_1.mp3: 403 Client Error: Forbidden for url: https://file-examples.com/storage/fe39414f4963503b17c7625/2017/11/file_example_MP3_700KB.mp3
Downloading sample_audio_2.mp3...
Error downloading sample_audio_2.mp3: 403 Client Error: Forbidden for url: https://file-examples.com/storage/fe39414f4963503b17c7625/2017/11/file_example_MP3_1MG.mp3
Found audio files after download attempt:
/content/audio/song1.mp3
/content/audio/song16.mp3
Audio embeddings shape: (2, 1024)


In [13]:
# Cluster audio embeddings & evaluate
num_clusters = 2  # Adjusted to be less than or equal to the number of samples (2)
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
labels = kmeans.fit_predict(audio_embeddings)

# Calculate silhouette score only if it's meaningful (i.e., 1 < num_clusters < n_samples)
sil = silhouette_score(audio_embeddings, labels) if (num_clusters > 1 and num_clusters < audio_embeddings.shape[0]) else float("nan")
print("Silhouette score:", sil)

for path, label in zip(audio_paths, labels):
    print(f"Audio: {os.path.basename(path)} -> Cluster {label}")

Silhouette score: nan
Audio: song1.mp3 -> Cluster 0
Audio: song16.mp3 -> Cluster 1
