In [3]:
import os
import shutil
import numpy as np
import torch
from facenet_pytorch import InceptionResnetV1
import torchvision.models as models
import torchvision.transforms as transforms
from transformers import pipeline, AutoTokenizer, AutoModel
from PIL import Image
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import glob
import pickle

class ImageClusteringAlgorithm:
    """Cluster a directory of images with K-Means over combined deep features.

    Each image is described by three vectors that are concatenated into one
    row: a FaceNet embedding, the ResNet-18 output, and a DistilBERT embedding
    of a BLIP-generated caption. The FaceNet and caption vectors are projected
    to 1000 dimensions by linear layers so that all three parts have the same
    width.

    Typical usage: ``compute_features`` (or ``load_features``), then
    ``compute_clusters`` (or ``load_clusters``), then ``save_clustered_images``.
    """

    def __init__(self):
        """Load the pretrained models and build the projection layers."""
        self.facenet_model = InceptionResnetV1(pretrained="vggface2").eval()
        self.resnet_model = models.resnet18(pretrained=True).eval()
        self.caption_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
        self.language_model = AutoModel.from_pretrained("distilbert-base-uncased")
        self.language_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        # Define the transformation layers to make the shapes of the output of facenet, resnet, and language models the same
        # NOTE: these layers are freshly constructed here and never trained or
        # loaded from a checkpoint, so they act as random projections and the
        # resulting features differ between kernel sessions unless torch is seeded.
        self.face_transform = torch.nn.Linear(512, 1000)
        self.caption_transform = torch.nn.Linear(768, 1000)
        # Built once here instead of once per image; it does not depend on the image.
        self._preprocess = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])

    def compute_features(self, image_directory):
        """Compute one feature row per image found directly in ``image_directory``.

        Populates ``self.image_directory``, ``self.image_paths`` and
        ``self.features`` (one row per entry of ``self.image_paths``).

        Images that cannot be read (``OSError``) are reported and skipped so
        that a single bad file does not abort a long run.

        Raises:
            ValueError: if no image produced features (e.g. empty directory).
        """
        self.image_directory = os.path.expanduser(image_directory)
        file_extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif")
        found_paths = []
        for extension in tqdm(file_extensions, "Loading image paths"):
            found_paths.extend(glob.glob(os.path.join(self.image_directory, extension)))
        # glob order is filesystem dependent; sort (and de-duplicate) so that
        # row order, and therefore the clustering, is stable between runs.
        found_paths = sorted(set(found_paths))

        features = []
        usable_paths = []
        for image_path in tqdm(found_paths, desc="Computing features"):
            try:
                feature = self._generate_image_features(image_path)
            except OSError as error:
                print(f"Skipping unreadable image {image_path}: {error}")
                continue
            features.append(feature)
            usable_paths.append(image_path)

        if not features:
            raise ValueError(f"No readable images found in {self.image_directory}")
        # Only paths that produced a row are kept, so index i of image_paths
        # always corresponds to row i of features.
        self.image_paths = usable_paths
        self.features = np.vstack(features)

    def compute_clusters(self, num_clusters):
        """Run K-Means on ``self.features`` and group image paths by label.

        Populates ``self.num_clusters``, ``self.labels`` and ``self.clusters``
        (a list of ``num_clusters`` lists of image paths) and prints the
        silhouette score of the clustering.
        """
        self.num_clusters = num_clusters
        print(f"Clustering images into {self.num_clusters} clusters...")
        kmeans = KMeans(n_clusters=self.num_clusters, random_state=42)
        self.labels = kmeans.fit_predict(self.features)
        print("Computing silhouette score...")
        silhouette_avg = silhouette_score(self.features, self.labels)
        print(f"Silhouette score: {silhouette_avg:.4f}")
        # Size the list by num_clusters (not labels.max() + 1) so it always
        # matches the range iterated when the clusters are written out.
        self.clusters = [[] for _ in range(self.num_clusters)]
        for idx, label in enumerate(self.labels):
            self.clusters[label].append(self.image_paths[idx])

    def save_clustered_images(self, output_directory):
        """Copy every clustered image into ``output_directory/cluster_<id>/``.

        Directories are created as needed; existing files with the same name
        are overwritten by ``shutil.copyfile``.
        """
        output_directory = os.path.expanduser(output_directory)
        os.makedirs(output_directory, exist_ok=True)
        for cluster_id, cluster_paths in enumerate(self.clusters):
            cluster_dir = os.path.join(output_directory, f"cluster_{cluster_id}")
            os.makedirs(cluster_dir, exist_ok=True)
            for image_path in cluster_paths:
                new_filename = os.path.join(cluster_dir, os.path.basename(image_path))
                shutil.copyfile(image_path, new_filename)

    def _generate_image_features(self, image_path):
        """Return the combined feature row for one image as a NumPy array.

        The face, ResNet and caption vectors are joined along the feature
        axis, so the result is a single row per image.
        """
        image = Image.open(image_path).convert("RGB")
        image_tensor = self._preprocess(image).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            resnet_features = self.resnet_model(image_tensor)
            face_features = self.face_transform(self.facenet_model(image_tensor))
            caption_features = self.caption_transform(self._generate_caption(image_path))
        # axis=1: every part carries a leading batch dimension of 1, so joining
        # on axis 0 would stack them as three separate rows for the same image
        # and break the one-row-per-path alignment that clustering relies on.
        features = np.concatenate(
            (face_features.detach().numpy(), resnet_features.detach().numpy(), caption_features.detach().numpy()),
            axis=1,
        )
        return features

    def _generate_caption(self, image_path):
        """Caption the image and return the mean-pooled language-model embedding."""
        caption = self.caption_pipeline(image_path)[0]["generated_text"]
        tokens = self.language_tokenizer.encode(caption, return_tensors="pt")
        # Average the per-token hidden states over the sequence dimension.
        vector = self.language_model(tokens).last_hidden_state.mean(dim=1)
        return vector

    def save_features(self, file_name):
        """Pickle the features, image paths and image directory to ``file_name``."""
        data = { "features": self.features, "image_paths": self.image_paths, "image_directory": self.image_directory }
        with open(os.path.expanduser(file_name), "wb") as file:
            pickle.dump(data, file)

    def load_features(self, file_name):
        """Restore onto this instance the state written by ``save_features``."""
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that this notebook wrote itself.
        with open(os.path.expanduser(file_name), "rb") as file:
            data = pickle.load(file)

        self.features = data["features"]
        self.image_paths = data["image_paths"]
        self.image_directory = data["image_directory"]

    def save_clusters(self, file_name):
        """Pickle the cluster count and the per-cluster path lists to ``file_name``."""
        data = { "num_clusters": self.num_clusters, "clusters": self.clusters }
        with open(os.path.expanduser(file_name), "wb") as file:
            pickle.dump(data, file)

    def load_clusters(self, file_name):
        """Restore onto this instance the state written by ``save_clusters``."""
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that this notebook wrote itself.
        with open(os.path.expanduser(file_name), "rb") as file:
            data = pickle.load(file)

        self.num_clusters = data["num_clusters"]
        self.clusters = data["clusters"]

In [4]:
# Driver cell: compute (or reload) features and clusters, then copy the images
# into one folder per cluster.
ica = ImageClusteringAlgorithm()
image_directory = "~/Downloads/JPGs"
# os.path.exists does not expand "~", so the cache paths are expanded up front;
# otherwise the checks below are always False and the expensive feature
# computation is repeated on every run.
features_file = os.path.expanduser("~/Downloads/JPGs/ImageClusteringAlgorithmFeatures.pkl")
clusters_file = os.path.expanduser("~/Downloads/JPGs/ImageClusteringAlgorithmClusters.pkl")

# Features are the slow step, so reuse the cached file when it exists.
if not os.path.exists(features_file):
    ica.compute_features(image_directory)
    ica.save_features(features_file)
else:
    ica.load_features(features_file)

# Likewise reuse a cached clustering when one exists.
if not os.path.exists(clusters_file):
    num_clusters = 50
    ica.compute_clusters(num_clusters)
    ica.save_clusters(clusters_file)
else:
    ica.load_clusters(clusters_file)

ica.save_clustered_images("~/Downloads/ClusteredJPGs")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Loading image paths: 100%|██████████| 6/6 [00:00<00:00, 106.77it/s]
Computing features:   1%|          | 75/8617 [02:14<4:21:37,  1.84s/it]