<a href="https://colab.research.google.com/github/samipn/clustering_demos/blob/main/image_clustering_imagebind.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment (h): Image Clustering with ImageBind Embeddings

This notebook shows how to use Meta's ImageBind model to obtain image embeddings, cluster them with K-Means, and evaluate clustering quality.


In [1]:
!pip install --quiet git+https://github.com/facebookresearch/ImageBind.git
!pip install --quiet timm einops


  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m63.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m55.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [2]:
import os
import torch
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.data import load_and_transform_vision_data

device = "cuda" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


Downloading imagebind weights to .checkpoints/imagebind_huge.pth ...


100%|██████████| 4.47G/4.47G [00:29<00:00, 160MB/s]


ImageBindModel(
  (modality_preprocessors): ModuleDict(
    (vision): RGBDTPreprocessor(
      (cls_token): tensor((1, 1, 1280), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Sequential(
          (0): PadIm2Video()
          (1): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
        )
      )
      (pos_embedding_helper): SpatioTemporalPosEmbeddingHelper(
        (pos_embed): tensor((1, 257, 1280), requires_grad=True)
        
      )
    )
    (text): TextPreprocessor(
      (pos_embed): tensor((1, 77, 1024), requires_grad=True)
      (mask): tensor((77, 77), requires_grad=False)
      
      (token_embedding): Embedding(49408, 1024)
    )
    (audio): AudioPreprocessor(
      (cls_token): tensor((1, 1, 768), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10), bias=False)
        (norm_layer): LayerNorm((768,), eps=1e-05, elementwise_affine=

In [13]:
# Load images from folder and compute embeddings
image_folder = "/content/images"  # TODO: put your images here (or mount Google Drive)

# Create the image folder if it doesn't exist
if not os.path.exists(image_folder):
    os.makedirs(image_folder)
    print(f"Created directory: {image_folder}")

# Clear existing images in the folder to ensure fresh download of samples
for f in os.listdir(image_folder):
    file_path = os.path.join(image_folder, f);
    if os.path.isfile(file_path):
        os.remove(file_path);

image_paths = [
    os.path.join(image_folder, f)
    for f in os.listdir(image_folder)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
];

# If no images are found (or if cleared), download some sample images
if not image_paths:
    print("No images found. Downloading sample images...");
    # Updated sample image URLs with highly reliable direct links
    sample_image_urls = [
        "https://images.dog.ceo/breeds/puggle/IMG_069300.jpg", # Dog
        "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg", # Cat
        "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6f/Nissan_GT-R_R35_IMG_1537.jpg/1280px-Nissan_GT-R_R35_IMG_1537.jpg" # Car
    ];
    for url in sample_image_urls:
        !wget -P {image_folder} {url}

    # Re-scan for image paths after downloading
    image_paths = [
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ];


print("Found images:");
for p in image_paths:
    print(p);

vision_inputs = load_and_transform_vision_data(image_paths, device=device);

with torch.no_grad():
    embeddings_dict = model({ModalityType.VISION: vision_inputs});
image_embeddings = embeddings_dict[ModalityType.VISION].cpu().numpy();
print("Image embeddings shape:", image_embeddings.shape);


No images found. Downloading sample images...
--2025-12-02 00:24:56--  https://images.dog.ceo/breeds/puggle/IMG_069300.jpg
Resolving images.dog.ceo (images.dog.ceo)... 104.21.17.246, 172.67.178.228, 2606:4700:3034::6815:11f6, ...
Connecting to images.dog.ceo (images.dog.ceo)|104.21.17.246|:443... connected.
HTTP request sent, awaiting response... 404 Not Found
2025-12-02 00:24:56 ERROR 404: Not Found.

--2025-12-02 00:24:56--  https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 208.80.153.240, 2620:0:860:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|208.80.153.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 215264 (210K) [image/jpeg]
Saving to: ‘/content/images/1200px-Cat03.jpg’


2025-12-02 00:24:57 (2.21 MB/s) - ‘/content/images/1200px-Cat03.jpg’ saved [215264/215264]

--2025-12-02 00:24:57--  https://upload.wikimedia.org/wikipedia/commons/thum

In [14]:
# Cluster image embeddings & evaluate
num_clusters = 1  # adjust based on your dataset
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
labels = kmeans.fit_predict(image_embeddings)

sil = silhouette_score(image_embeddings, labels) if num_clusters > 1 else float("nan")
print("Silhouette score:", sil)

for path, label in zip(image_paths, labels):
    print(f"Image: {os.path.basename(path)} -> Cluster {label}")


Silhouette score: nan
Image: 1200px-Cat03.jpg -> Cluster 0
