In [None]:
import os
import shutil
import numpy as np
import torch
from facenet_pytorch import InceptionResnetV1
import torchvision.models as models
import torchvision.transforms as transforms
from transformers import pipeline, AutoTokenizer, AutoModel
from PIL import Image
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import glob
import pickle

# Format tag stored alongside the data in every pickle this notebook writes.
version = "0.1"

# Pick the compute device: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps` is looked up defensively because PyTorch builds that
# predate the MPS backend do not have the attribute at all, and the original
# one-line conditional raised AttributeError there whenever CUDA was absent.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

class ImageClusteringAlgorithm:
    """Cluster the images of a directory with K-Means over learned features.

    Three feature groups are computed per image and concatenated column-wise
    in this fixed order:

    1. face   - ``InceptionResnetV1`` (vggface2) output passed through ``face_transform``
    2. resnet - raw output of a pretrained ``resnet18``
    3. caption - a BLIP caption, embedded with DistilBERT (mean over tokens),
       passed through ``caption_transform``

    Typical use: ``compute_features`` (or ``load_features``), then
    ``compute_clusters`` (or ``load_clusters``), then ``save_clustered_images``.

    NOTE: the ``use_*`` flags of ``compute_clusters`` slice ``self.features``
    assuming all three groups are present. If ``compute_features`` was called
    with a group disabled, call ``compute_clusters`` with its default flags
    only; otherwise the wrong columns are selected.
    """

    def __init__(self):
        """Load all pretrained models and move them to the module-level ``device``."""
        self.facenet_model = InceptionResnetV1(pretrained="vggface2").eval().to(device)
        self.resnet_model = models.resnet18(pretrained=True).eval().to(device)
        self.caption_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=device)
        self.language_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
        self.language_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        # Projection layers that bring the face (512) and caption (768) vectors
        # to the same width as the resnet output. They are never trained, so
        # they act as fixed random projections; seeding them keeps features
        # comparable between instances and kernel restarts.
        self.face_transform = self._seeded_linear(512, 1000, seed=0).to(device)
        self.caption_transform = self._seeded_linear(768, 1000, seed=1).to(device)

    @staticmethod
    def _seeded_linear(in_features, out_features, seed):
        """Return a ``torch.nn.Linear`` whose parameters depend only on ``seed``.

        Weights and bias are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features))
        using a private generator, so the global RNG state is not reseeded.
        """
        layer = torch.nn.Linear(in_features, out_features)
        generator = torch.Generator().manual_seed(seed)
        bound = in_features ** -0.5
        with torch.no_grad():
            layer.weight.uniform_(-bound, bound, generator=generator)
            layer.bias.uniform_(-bound, bound, generator=generator)
        return layer

    def compute_features(self, image_directory, use_facenet=True, use_resnet=True, use_caption=True):
        """Compute one feature row per image found directly in ``image_directory``.

        Sets ``self.image_directory``, ``self.image_paths`` and ``self.features``
        (rows are in the same order as ``self.image_paths``).

        Args:
            image_directory: Directory to scan (``~`` is expanded, not recursive).
            use_facenet, use_resnet, use_caption: Feature groups to keep in the
                stored matrix. See the class docstring for the caveat on
                disabling a group here.

        Raises:
            ValueError: If no image with a supported extension is found, or if
                every feature group is disabled.
        """
        self.image_directory = os.path.expanduser(image_directory)
        file_extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif")
        found_paths = []
        for extension in tqdm(file_extensions, "Loading image paths..."):
            found_paths.extend(glob.glob(os.path.join(self.image_directory, extension)))
        # glob order is platform dependent; sorting makes the row order (and
        # therefore the seeded K-Means result) reproducible.
        self.image_paths = sorted(set(found_paths))
        if not self.image_paths:
            raise ValueError(f"No images with a supported extension found in {self.image_directory!r}")
        features = [
            self._generate_image_features(image_path, use_facenet, use_resnet, use_caption)
            for image_path in tqdm(self.image_paths, desc="Computing features...")
        ]
        self.features = np.vstack(features)

    def compute_clusters(self, num_clusters, use_facenet=True, use_resnet=True, use_caption=True):
        """Run K-Means on the selected feature groups and print the silhouette score.

        Sets ``self.num_clusters``, ``self.labels`` and ``self.clusters`` (a list
        of ``num_clusters`` lists of image paths). ``self.features`` is not modified.

        Args:
            num_clusters: Number of K-Means clusters.
            use_facenet, use_resnet, use_caption: Feature groups to cluster on.
        """
        self.num_clusters = num_clusters
        selected_features = self._select_features(self.features, use_facenet, use_resnet, use_caption)
        print(f"Clustering images into {self.num_clusters} clusters...")
        kmeans = KMeans(n_clusters=self.num_clusters, random_state=42)
        self.labels = kmeans.fit_predict(selected_features)
        print("Computing silhouette score...")
        silhouette_avg = silhouette_score(selected_features, self.labels)
        print(f"Silhouette score: {silhouette_avg:.4f}")
        # Size the list by num_clusters rather than labels.max() + 1 so that a
        # trailing empty cluster cannot make the list shorter than num_clusters.
        self.clusters = [[] for _ in range(self.num_clusters)]
        for idx, label in enumerate(self.labels):
            self.clusters[label].append(self.image_paths[idx])

    def save_clustered_images(self, output_directory):
        """Copy every image into ``<output_directory>/cluster_<id>/``.

        Args:
            output_directory: Destination root (``~`` is expanded); created if missing.
        """
        output_directory = os.path.expanduser(output_directory)
        os.makedirs(output_directory, exist_ok=True)
        # Iterate the stored clusters themselves so a cluster list that is
        # shorter than num_clusters (e.g. from an older pickle) cannot index
        # out of range.
        for cluster_id, cluster_paths in tqdm(enumerate(self.clusters), "Saving clusters...", total=len(self.clusters)):
            cluster_dir = os.path.join(output_directory, f"cluster_{cluster_id}")
            os.makedirs(cluster_dir, exist_ok=True)
            for image_path in cluster_paths:
                new_filename = os.path.join(cluster_dir, os.path.basename(image_path))
                shutil.copyfile(image_path, new_filename)

    def _generate_image_features(self, image_path, use_facenet, use_resnet, use_caption):
        """Return the selected feature columns for one image as a ``(1, n)`` array.

        All three feature groups are always computed; the flags only decide
        which columns are returned.
        """
        # Close the underlying file as soon as the pixels are converted;
        # otherwise a handle stays open per image until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
        image_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            resnet_features = self.resnet_model(image_tensor)
            face_features = self.face_transform(self.facenet_model(image_tensor))
            caption_features = self.caption_transform(self._generate_caption(image_path))
        # Column order (face, resnet, caption) must match _select_features.
        features = np.concatenate(
            (
                face_features.cpu().detach().numpy(),
                resnet_features.cpu().detach().numpy(),
                caption_features.cpu().detach().numpy(),
            ),
            axis=1,
        )
        return self._select_features(features, use_facenet, use_resnet, use_caption)

    def _generate_caption(self, image_path):
        """Caption the image and return the mean of the caption's token embeddings."""
        caption = self.caption_pipeline(image_path)[0]["generated_text"]
        tokens = self.language_tokenizer.encode(caption, return_tensors="pt").to(device)
        # Inference only: do not build an autograd graph even if a caller
        # forgets to wrap this call in no_grad.
        with torch.no_grad():
            vector = self.language_model(tokens).last_hidden_state.mean(dim=1)
        return vector

    def _select_features(self, features, use_facenet=True, use_resnet=True, use_caption=True):
        """Return the columns of ``features`` that belong to the enabled groups.

        ``features`` is expected to hold the face, resnet and caption groups
        side by side in that order; group widths are read from the models.

        Raises:
            ValueError: If all three flags are false.
        """
        if not (use_facenet or use_resnet or use_caption):
            raise ValueError("At least one of use_facenet, use_resnet, use_caption must be True")
        feature_sizes = [self.face_transform.out_features, self.resnet_model.fc.out_features, self.caption_transform.out_features]
        feature_start_indices = np.cumsum([0] + feature_sizes)
        selected_features = []
        if use_facenet:
            selected_features.append(features[:, feature_start_indices[0]:feature_start_indices[1]])
        if use_resnet:
            selected_features.append(features[:, feature_start_indices[1]:feature_start_indices[2]])
        if use_caption:
            selected_features.append(features[:, feature_start_indices[2]:])
        return np.concatenate(selected_features, axis=1)

    def save_features(self, file_name):
        """Pickle the features, image paths, image directory and version tag to ``file_name``."""
        file_name = os.path.expanduser(file_name)
        data = { "features": self.features, "image_paths": self.image_paths, "image_directory": self.image_directory, "version": version }
        with open(file_name, "wb") as file:
            pickle.dump(data, file)

    def load_features(self, file_name):
        """Restore the state written by ``save_features``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        this notebook produced itself.
        """
        file_name = os.path.expanduser(file_name)
        with open(file_name, "rb") as file:
            data = pickle.load(file)
        self.features = data["features"]
        self.image_paths = data["image_paths"]
        self.image_directory = data["image_directory"]

    def save_clusters(self, file_name):
        """Pickle the cluster count, cluster lists and version tag to ``file_name``."""
        file_name = os.path.expanduser(file_name)
        data = { "num_clusters": self.num_clusters, "clusters": self.clusters, "version": version }
        with open(file_name, "wb") as file:
            pickle.dump(data, file)

    def load_clusters(self, file_name):
        """Restore the state written by ``save_clusters``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        this notebook produced itself.
        """
        # Expand "~" like the three sibling save/load methods do.
        file_name = os.path.expanduser(file_name)
        with open(file_name, "rb") as file:
            data = pickle.load(file)
        # Read the count from its own key; it was previously taken from
        # "clusters", which left a list in num_clusters.
        self.num_clusters = data["num_clusters"]
        self.clusters = data["clusters"]

In [None]:
# End-to-end run: reuse the cached features/clusters when their pickles exist,
# otherwise compute them and write the cache, then export the clustered images.
ica = ImageClusteringAlgorithm()
image_directory = os.path.expanduser("~/Downloads/JPGs")
features_file = os.path.expanduser("~/Downloads/JPGs/FaceResCapF_v0.1.pkl")
clusters_file = os.path.expanduser("~/Downloads/JPGs/FaceResCapC_v0.1.pkl")

# Stage 1: per-image features.
if os.path.exists(features_file):
    ica.load_features(features_file)
else:
    ica.compute_features(image_directory)
    ica.save_features(features_file)

# Stage 2: cluster assignment.
if os.path.exists(clusters_file):
    ica.load_clusters(clusters_file)
else:
    num_clusters = 50
    ica.compute_clusters(num_clusters)
    ica.save_clusters(clusters_file)

# Stage 3: copy every image into its cluster folder.
ica.save_clustered_images("~/Downloads/ClusteredJPGs2")

In [None]:
# Ablation: compare silhouette scores when clustering on all feature groups,
# without the face features, and without the resnet features.
ica = ImageClusteringAlgorithm()
features_file = os.path.expanduser("~/Downloads/JPGs/FaceResCapF_v0.1.pkl")
clusters_file = os.path.expanduser("~/Downloads/JPGs/FaceResCapC_v0.1.pkl")
num_clusters = 50

# compute_clusters only reads ica.features, so one load serves every run;
# reloading the same pickle before each run was redundant.
ica.load_features(features_file)
for feature_flags in ({}, {"use_facenet": False}, {"use_resnet": False}):
    # Label each run so the silhouette score printed below it is attributable.
    print(f"Feature flags: {feature_flags or 'all features'}")
    ica.compute_clusters(num_clusters, **feature_flags)