# Image Clustering with ImageBind LLM Embeddings


## 1. Install and Import Libraries

In [1]:
# Install required libraries
!pip install torch torchvision --quiet
!pip install transformers --quiet
!pip install Pillow --quiet
!pip install umap-learn --quiet
!pip install hdbscan --quiet

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import requests
from io import BytesIO
import os
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics import silhouette_score, adjusted_rand_score, normalized_mutual_info_score
from sklearn.metrics import calinski_harabasz_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
import warnings
warnings.filterwarnings('ignore')

# Apply the shared plotting theme and fix NumPy's global seed for repeatable runs
plt.style.use('seaborn-v0_8-whitegrid')
np.random.seed(42)
print("Base libraries imported successfully!")

Base libraries imported successfully!


In [2]:
# Import PyTorch and vision models
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from transformers import CLIPProcessor, CLIPModel
import umap
import hdbscan

# Use a CUDA GPU when PyTorch can see one; otherwise run on the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

Using device: cpu
PyTorch version: 2.9.0+cu126


## 2. Load CLIP Model (Alternative to ImageBind)

Note: ImageBind requires specific setup. We'll use CLIP as it's more readily available and provides similar multimodal embeddings. The concepts and techniques are directly transferable to ImageBind.

In [None]:
# Fetch the pretrained CLIP weights and the matching input processor
print("Loading CLIP model...")
clip_checkpoint = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_checkpoint)
clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)
# Move the weights to the selected device and switch the module to evaluation mode
clip_model = clip_model.to(device)
clip_model.eval()
print("CLIP model loaded successfully!")

Loading CLIP model...


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

## 3. Load Sample Images (CIFAR-10)

In [None]:
# Use the CIFAR-10 test split as a reliable, labelled image source.
# The archive is downloaded into ./data if it is not already there.
print("Loading CIFAR-10 dataset...")

# Transformed view of the data: resize to 224x224, then convert to tensors
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

cifar_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# CIFAR-10 class names; the list position is the integer label
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

print(f"Dataset loaded! Total images: {len(cifar_dataset)}")
print(f"Classes: {class_names}")

In [None]:
# Build the clustering subset: 5 classes, the first 20 images found for each
selected_classes = [0, 1, 3, 5, 8]  # airplane, automobile, cat, dog, ship
selected_class_names = list(map(class_names.__getitem__, selected_classes))
n_per_class = 20

images = []
labels = []
pil_images = []

# Same test split, loaded without the tensor transform so CLIP's processor gets the raw images
cifar_raw = datasets.CIFAR10(root='./data', train=False, download=False)

# Original CIFAR label -> contiguous label 0-4, plus a running tally per selected class
label_to_index = {original: remapped for remapped, original in enumerate(selected_classes)}
class_counts = dict.fromkeys(selected_classes, 0)

for idx in range(len(cifar_raw)):
    img, label = cifar_raw[idx]
    # Unselected labels fall back to the quota itself, so they never pass this test
    if class_counts.get(label, n_per_class) < n_per_class:
        pil_images.append(img)
        labels.append(label_to_index[label])
        class_counts[label] += 1

    # Stop scanning as soon as every selected class has reached its quota
    if min(class_counts.values()) >= n_per_class:
        break

y_true = np.array(labels)
print(f"Selected {len(pil_images)} images from {len(selected_classes)} classes")
print(f"Classes: {selected_class_names}")

In [None]:
# Visualize sample images from each class: one row per class, up to 5 images per row
n_sample_cols = 5
fig, axes = plt.subplots(len(selected_classes), n_sample_cols, figsize=(12, 10),
                         squeeze=False)  # keep a 2-D axes array even for a single row

for class_idx, class_name in enumerate(selected_class_names):
    class_images = [img for img, lbl in zip(pil_images, labels)
                    if lbl == class_idx][:n_sample_cols]

    for img_idx in range(n_sample_cols):
        ax = axes[class_idx, img_idx]
        # Hide ticks and frame by hand: axis('off') would also hide the y-label,
        # so the class name on the first column would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        # Slots beyond the available images stay blank
        if img_idx < len(class_images):
            ax.imshow(class_images[img_idx])

    axes[class_idx, 0].set_ylabel(class_name, fontsize=12)

plt.suptitle('Sample Images from Each Category', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Generate Image Embeddings with CLIP

In [None]:
def get_clip_image_embeddings(images, model, processor, device, batch_size=32):
    """
    Generate L2-normalized CLIP embeddings for a list of PIL images.

    Args:
        images: Sequence of images (must support ``len`` and slicing).
        model: Object exposing ``get_image_features(**inputs)`` that returns a
            2-D tensor of per-image features.
        processor: Callable turning a batch of images into a dict of tensors.
        device: Device the processed inputs are moved to before the forward pass.
        batch_size: Number of images per forward pass; must be >= 1.

    Returns:
        NumPy array with one unit-length row per input image. For empty input
        an array of shape ``(0, D)`` is returned, where ``D`` is
        ``model.config.projection_dim`` when available and 0 otherwise.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # np.vstack cannot stack an empty list, so handle empty input explicitly
    if len(images) == 0:
        config = getattr(model, "config", None)
        embedding_dim = getattr(config, "projection_dim", 0) or 0
        return np.empty((0, embedding_dim), dtype=np.float32)

    all_embeddings = []

    # No gradients are needed for inference
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch_images = images[start:start + batch_size]

            # Preprocess the batch and move every tensor to the target device
            inputs = processor(images=batch_images, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            image_features = model.get_image_features(**inputs)

            # Scale each row to unit length; the clamp avoids NaNs for an all-zero row
            norms = image_features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
            image_features = image_features / norms

            all_embeddings.append(image_features.cpu().numpy())

    return np.vstack(all_embeddings)

# Embed every selected image with the loaded CLIP model (default batch size)
print("Generating CLIP embeddings...")
image_embeddings = get_clip_image_embeddings(
    images=pil_images,
    model=clip_model,
    processor=clip_processor,
    device=device,
)
print(f"Embeddings shape: {image_embeddings.shape}")

## 5. Visualize Embeddings

In [None]:
# Project the embeddings to 2-D with UMAP; this projection is only used for plotting
print("Reducing dimensions with UMAP...")
umap_params = dict(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
umap_reducer = umap.UMAP(**umap_params)
embeddings_2d = umap_reducer.fit_transform(image_embeddings)
print("UMAP reduction complete!")

In [None]:
# Scatter the 2-D projection, one colour per ground-truth class
fig, ax = plt.subplots(figsize=(12, 8))

colors = ['blue', 'red', 'green', 'orange', 'purple']
for class_idx, class_name in enumerate(selected_class_names):
    mask = y_true == class_idx
    class_points = embeddings_2d[mask]
    ax.scatter(class_points[:, 0], class_points[:, 1],
               c=colors[class_idx], label=class_name,
               alpha=0.7, s=80, edgecolors='k')

ax.set_xlabel('UMAP Dimension 1')
ax.set_ylabel('UMAP Dimension 2')
ax.set_title('CLIP Image Embeddings (UMAP Visualization)', fontsize=14, fontweight='bold')
ax.legend(title='Category')
fig.tight_layout()
plt.show()

## 6. K-Means Clustering

In [None]:
# K-Means on the full-dimensional embeddings, with one cluster per selected class
n_clusters = len(selected_classes)

kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
labels_kmeans = kmeans.fit_predict(image_embeddings)

# ARI and NMI compare against the ground truth; silhouette uses the embeddings only
ari_kmeans = adjusted_rand_score(labels_true=y_true, labels_pred=labels_kmeans)
nmi_kmeans = normalized_mutual_info_score(labels_true=y_true, labels_pred=labels_kmeans)
sil_kmeans = silhouette_score(image_embeddings, labels_kmeans)

print("K-Means Clustering Results:")
print("="*50)
print(f"Adjusted Rand Index: {ari_kmeans:.4f}")
print(f"Normalized Mutual Info: {nmi_kmeans:.4f}")
print(f"Silhouette Score: {sil_kmeans:.4f}")

In [None]:
# Side-by-side view: ground-truth classes (left) vs. K-Means assignments (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
ax_true, ax_kmeans = axes

# Left panel: one scatter call per class so each class gets a legend entry
for class_idx, class_name in enumerate(selected_class_names):
    mask = y_true == class_idx
    ax_true.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                    c=colors[class_idx], label=class_name,
                    alpha=0.7, s=80, edgecolors='k')
ax_true.set_title('True Labels')
ax_true.legend()

# Right panel: a single scatter call coloured by cluster id
scatter = ax_kmeans.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                            c=labels_kmeans, cmap='viridis',
                            alpha=0.7, s=80, edgecolors='k')
ax_kmeans.set_title(f'K-Means Clustering\nARI: {ari_kmeans:.3f}')
fig.colorbar(scatter, ax=ax_kmeans, label='Cluster')

# Both panels share the same axis labels
for panel in (ax_true, ax_kmeans):
    panel.set_xlabel('UMAP Dimension 1')
    panel.set_ylabel('UMAP Dimension 2')

plt.suptitle('Image Clustering Results', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Hierarchical Clustering

In [None]:
# Agglomerative clustering with Ward linkage, cut at the same number of clusters as K-Means
hclust = AgglomerativeClustering(linkage='ward', n_clusters=n_clusters)
labels_hclust = hclust.fit_predict(image_embeddings)

# Same three metrics as for K-Means so the methods can be compared
ari_hclust = adjusted_rand_score(labels_true=y_true, labels_pred=labels_hclust)
nmi_hclust = normalized_mutual_info_score(labels_true=y_true, labels_pred=labels_hclust)
sil_hclust = silhouette_score(image_embeddings, labels_hclust)

print("Hierarchical Clustering Results:")
print("="*50)
print(f"Adjusted Rand Index: {ari_hclust:.4f}")
print(f"Normalized Mutual Info: {nmi_hclust:.4f}")
print(f"Silhouette Score: {sil_hclust:.4f}")

In [None]:
# Create dendrogram
from scipy.cluster.hierarchy import dendrogram, linkage

Z = linkage(image_embeddings, method='ward')

# Colour the tree so that one coloured subtree corresponds to one expected category.
# dendrogram() colours links below `color_threshold`. A fixed value such as 15 is
# larger than any Ward merge height reachable with ~100 unit-norm embeddings, which
# would draw the whole tree in a single colour. The threshold is therefore set to the
# height of the merge that reduces the tree to (n_groups - 1) clusters.
n_groups = len(selected_class_names)
if 1 < n_groups <= len(Z) + 1:
    dendrogram_threshold = Z[-(n_groups - 1), 2]
else:
    dendrogram_threshold = None  # let SciPy use its default threshold

plt.figure(figsize=(16, 8))

# Leaf labels: "<class name>_<image index>"
dendrogram_labels = [f"{selected_class_names[class_idx]}_{i}"
                     for i, class_idx in enumerate(y_true)]

dendrogram(Z, labels=dendrogram_labels, leaf_rotation=90, leaf_font_size=6,
           color_threshold=dendrogram_threshold)
plt.title('Hierarchical Clustering Dendrogram (CLIP Embeddings)', fontsize=14, fontweight='bold')
plt.xlabel('Image')
plt.ylabel('Distance')
plt.tight_layout()
plt.show()

## 8. HDBSCAN Clustering

In [None]:
# Density-based clustering; HDBSCAN picks the number of clusters and marks noise as -1
hdbscan_clusterer = hdbscan.HDBSCAN(min_cluster_size=5, min_samples=3)
labels_hdbscan = hdbscan_clusterer.fit_predict(image_embeddings)

# Count real clusters (the noise label -1 is not a cluster) and the noise points
n_clusters_found = len(set(labels_hdbscan) - {-1})
n_noise = (labels_hdbscan == -1).sum()

print(f"HDBSCAN found {n_clusters_found} clusters")
print(f"Noise points: {n_noise}")

# Metrics are computed on the non-noise points only, and only when at least
# two clusters were found
valid_mask = labels_hdbscan >= 0
if valid_mask.any() and n_clusters_found > 1:
    kept_truth = y_true[valid_mask]
    kept_labels = labels_hdbscan[valid_mask]

    ari_hdbscan = adjusted_rand_score(kept_truth, kept_labels)
    nmi_hdbscan = normalized_mutual_info_score(kept_truth, kept_labels)
    sil_hdbscan = silhouette_score(image_embeddings[valid_mask], kept_labels)

    print(f"\nAdjusted Rand Index: {ari_hdbscan:.4f}")
    print(f"Normalized Mutual Info: {nmi_hdbscan:.4f}")
    print(f"Silhouette Score: {sil_hdbscan:.4f}")

## 9. Finding Optimal Number of Clusters

In [None]:
# Elbow method and Silhouette analysis: fit K-Means for each K and record both curves
k_range = range(2, 10)
inertias = []
silhouettes = []

for k in k_range:
    kmeans_temp = KMeans(n_clusters=k, random_state=42, n_init=10)
    labels_temp = kmeans_temp.fit_predict(image_embeddings)
    inertias.append(kmeans_temp.inertia_)
    silhouettes.append(silhouette_score(image_embeddings, labels_temp))

# Plot both curves with identical decoration
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], inertias, 'bo-', 'Inertia', 'Elbow Method'),
    (axes[1], silhouettes, 'go-', 'Silhouette Score', 'Silhouette Analysis'),
]
for ax, values, line_style, y_label, panel_title in panels:
    ax.plot(k_range, values, line_style, linewidth=2, markersize=8)
    ax.set_xlabel('Number of Clusters (K)')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    # Draw the marker at n_clusters itself so the line and its label cannot disagree
    # (the position used to be hard-coded to 5).
    ax.axvline(x=n_clusters, color='red', linestyle='--', label=f'K={n_clusters} (True)')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Optimal Number of Clusters Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# K with the highest silhouette score
optimal_k = list(k_range)[np.argmax(silhouettes)]
print(f"Optimal K by Silhouette: {optimal_k}")

## 10. Visualize Cluster Contents

In [None]:
# Show sample images from each K-Means cluster: one row per cluster, up to 5 images
n_sample_cols = 5
fig, axes = plt.subplots(n_clusters, n_sample_cols, figsize=(12, 12),
                         squeeze=False)  # keep a 2-D axes array even for a single row

for cluster_idx in range(n_clusters):
    cluster_mask = labels_kmeans == cluster_idx
    cluster_images = [img for img, in_cluster in zip(pil_images, cluster_mask) if in_cluster]
    cluster_true_labels = y_true[cluster_mask]

    # Name the cluster after the most frequent ground-truth class among its members
    if len(cluster_true_labels) > 0:
        majority_class = selected_class_names[np.bincount(cluster_true_labels).argmax()]
    else:
        majority_class = "Empty"

    for img_idx in range(n_sample_cols):
        ax = axes[cluster_idx, img_idx]
        # Hide ticks and frame by hand: axis('off') would also hide the y-label,
        # so the cluster / majority-class annotation would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        # Slots beyond the available images stay blank
        if img_idx < len(cluster_images):
            ax.imshow(cluster_images[img_idx])

    axes[cluster_idx, 0].set_ylabel(f'Cluster {cluster_idx}\n({majority_class})',
                                    fontsize=10, rotation=0, ha='right', va='center')

plt.suptitle('Sample Images from Each Cluster (K-Means)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 11. Confusion Matrix

In [None]:
# Cross-tabulate true categories (rows) against K-Means cluster ids (columns)
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true, labels_kmeans)

heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=range(n_clusters),
    yticklabels=selected_class_names,
)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, **heatmap_options)
plt.xlabel('Predicted Cluster')
plt.ylabel('True Category')
plt.title('Confusion Matrix: Categories vs Clusters', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 12. Image Similarity Analysis

In [None]:
# Compute cosine similarity matrix
from sklearn.metrics.pairwise import cosine_similarity

similarity_matrix = cosine_similarity(image_embeddings)

# Sort by true labels so images of the same category sit next to each other.
# A stable sort keeps the original order inside each category.
sorted_indices = np.argsort(y_true, kind='stable')
sorted_similarity = similarity_matrix[sorted_indices][:, sorted_indices]

plt.figure(figsize=(12, 10))
sns.heatmap(sorted_similarity, cmap='RdYlBu_r', vmin=0, vmax=1)

# Add category separators at the real class boundaries. They are derived from the
# label counts, so they stay correct even if a category has fewer images than
# requested or the number of categories differs from the number of clusters.
category_boundaries = np.cumsum(np.bincount(y_true))[:-1]
for boundary in category_boundaries:
    plt.axhline(y=boundary, color='black', linewidth=2)
    plt.axvline(x=boundary, color='black', linewidth=2)

plt.title('Image Similarity Matrix (CLIP Embeddings)', fontsize=14, fontweight='bold')
plt.xlabel('Image (sorted by category)')
plt.ylabel('Image (sorted by category)')
plt.tight_layout()
plt.show()

print("Diagonal blocks show high similarity within categories.")

## 13. Results Summary

In [None]:
# Collect one record per clustering method, then build the comparison table once
method_rows = [
    {'Method': 'K-Means', 'ARI': ari_kmeans, 'NMI': nmi_kmeans, 'Silhouette': sil_kmeans},
    {'Method': 'Hierarchical', 'ARI': ari_hclust, 'NMI': nmi_hclust, 'Silhouette': sil_hclust},
]

# HDBSCAN metrics only exist when it produced at least two clusters of non-noise points
if valid_mask.sum() > 0 and n_clusters_found > 1:
    method_rows.append({
        'Method': 'HDBSCAN', 'ARI': ari_hdbscan, 'NMI': nmi_hdbscan, 'Silhouette': sil_hdbscan
    })

results = pd.DataFrame(method_rows)

banner = "="*70
print(banner)
print("IMAGE CLUSTERING RESULTS SUMMARY")
print(banner)
print("\nAll Methods Comparison:")
print(results.to_string(index=False))

# The "best" method is the row with the highest ARI
best_row_index = results['ARI'].idxmax()
best_method = results.loc[best_row_index]
print(f"\nBest Method: {best_method['Method']}")
print(f"ARI: {best_method['ARI']:.4f}, NMI: {best_method['NMI']:.4f}")

In [None]:
# Visualize results comparison: grouped bars, one group per method, one bar per metric
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(results))
width = 0.25

metric_styles = [
    ('ARI', 'steelblue', -width),
    ('NMI', 'coral', 0.0),
    ('Silhouette', 'green', width),
]
for metric, bar_color, offset in metric_styles:
    ax.bar(x + offset, results[metric], width, label=metric, color=bar_color)

ax.set_xlabel('Clustering Method')
ax.set_ylabel('Score')
ax.set_title('Image Clustering Methods Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(results['Method'])
ax.legend()

# ARI and Silhouette can be negative, so extend the lower limit when needed
# instead of clipping those bars at 0.
lowest_score = float(results[[m for m, _, _ in metric_styles]].min().min())
y_lower = min(0.0, lowest_score - 0.05)
ax.set_ylim([y_lower, 1])
if y_lower < 0:
    ax.axhline(0, color='black', linewidth=0.8)  # make the zero baseline visible
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

In [None]:
# Final text report: every entry below is printed on its own line, in order
summary_lines = [
    "="*70,
    "IMAGE CLUSTERING WITH LLM EMBEDDINGS - SUMMARY",
    "="*70,

    "\n1. EMBEDDING MODEL:",
    "   - CLIP (clip-vit-base-patch32): Vision-language model",
    "   - Produces 512-dimensional embeddings",
    "   - Note: ImageBind follows similar principles for multimodal embeddings",

    "\n2. CLUSTERING ALGORITHMS:",
    "   - K-Means: Partition-based clustering",
    "   - Hierarchical: Agglomerative with Ward linkage",
    "   - HDBSCAN: Density-based clustering",

    "\n3. QUALITY METRICS:",
    "   - Adjusted Rand Index (ARI): Agreement with ground truth",
    "   - Normalized Mutual Information (NMI): Information overlap",
    "   - Silhouette Score: Cluster cohesion and separation",

    "\n4. DATASET:",
    f"   - CIFAR-10 subset: {len(pil_images)} images",
    f"   - {len(selected_classes)} classes: {', '.join(selected_class_names)}",

    "\n5. KEY FINDINGS:",
    f"   - Best method: {best_method['Method']} (ARI: {best_method['ARI']:.4f})",
    "   - CLIP embeddings effectively capture visual semantics",
    "   - Similar images cluster together in embedding space",
    "   - Hierarchical clustering provides interpretable structure",

    "\n6. APPLICATIONS:",
    "   - Image organization and retrieval",
    "   - Visual content recommendation",
    "   - Duplicate detection",
    "   - Dataset curation and cleaning",

    "\n" + "="*70,
]

for summary_line in summary_lines:
    print(summary_line)