In [2]:
import os
import shutil
import numpy as np
import torch
from facenet_pytorch import InceptionResnetV1
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import glob
import pickle
import warnings

# Silence all warnings to keep cell output readable.
# NOTE(review): this also hides genuinely useful library warnings (deprecations,
# convergence issues); consider narrowing the filter by category/module.
warnings.filterwarnings("ignore")

# Choose the compute backend: CUDA if present, otherwise Apple MPS (Metal), otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Version tag written into the pickled features/clusters files.
version = "0.1"

In [11]:
class ImageClusteringAlgorithm:
    """Cluster the images in a directory by visual similarity.

    Each image is embedded by concatenating the output of a FaceNet
    ``InceptionResnetV1`` (``vggface2`` weights) with the output of a pretrained
    torchvision ResNet-50. The embeddings are clustered with k-means, and the number
    of clusters is chosen by maximising the silhouette score.

    Typical workflow::

        ica = ImageClusteringAlgorithm()
        ica.compute_features("~/photos")       # or ica.load_features(...)
        ica.compute_clusters(max_clusters=20)   # or ica.load_clusters(...)
        ica.save_clustered_images("~/photos_clustered")
    """

    # File extensions picked up by compute_features (matched case-insensitively).
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif")

    def __init__(self):
        self.facenet_model = InceptionResnetV1(pretrained="vggface2").eval().to(device)
        self.resnet_model = models.resnet50(pretrained=True).eval().to(device)
        # Preprocessing shared by both models; built once rather than once per image.
        # NOTE(review): these are ImageNet mean/std at 224x224. FaceNet models are often
        # fed smaller face crops with their own standardisation. Confirm that the face
        # half of the embedding is meaningful with this preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def compute_features(self, image_directory):
        """Embed every supported image directly inside ``image_directory``.

        Sets ``self.image_directory`` (with ``~`` expanded) and ``self.image_paths``
        (sorted). Also sets ``self.features``, a 2-D array with one row per entry of
        ``self.image_paths``, in the same order.

        Raises:
            FileNotFoundError: if ``image_directory`` does not exist.
            ValueError: if it contains no file with a supported extension.
        """
        self.image_directory = os.path.expanduser(image_directory)
        self.image_paths = self._list_image_paths(self.image_directory)
        if not self.image_paths:
            raise ValueError(
                f"No images with extensions {self.IMAGE_EXTENSIONS} found in {self.image_directory}"
            )
        features = [
            self._generate_image_features(image_path)
            for image_path in tqdm(self.image_paths, desc="Computing features...")
        ]
        self.features = np.vstack(features)

    def compute_clusters(self, max_clusters, find_optimal_num_clusters=True, silhouette_sample_size=None):
        """Cluster ``self.features`` with k-means.

        Args:
            max_clusters: largest k to try, or the exact k when
                ``find_optimal_num_clusters`` is False.
            find_optimal_num_clusters: if True, try every k in ``[2, max_clusters]``
                (capped at ``n_samples - 1``) and keep the one with the highest
                silhouette score. If False, fit only ``k = max_clusters``.
            silhouette_sample_size: if given, estimate each silhouette score on a
                random subset of this many samples. The exact score costs
                O(n_samples**2) per k. ``None`` (the default) uses every sample.

        Sets ``self.num_clusters``, ``self.labels`` and ``self.clusters``.
        ``self.clusters`` is a list of ``num_clusters`` lists of image paths.

        Raises:
            ValueError: if no k in the requested range is valid. The silhouette
                score needs ``2 <= k <= n_samples - 1``.
        """
        n_samples = len(self.features)
        min_clusters = 2 if find_optimal_num_clusters else max_clusters
        # k beyond n_samples - 1 would make KMeans/silhouette_score raise mid-search.
        upper = min(max_clusters, n_samples - 1)
        if min_clusters < 2 or min_clusters > upper:
            raise ValueError(
                f"Cannot cluster {n_samples} images with max_clusters={max_clusters}: "
                "the silhouette score needs 2 <= k <= n_samples - 1"
            )

        # Silhouette scores lie in [-1, 1]. Starting at 0 (as before) silently rejected
        # runs where every score was <= 0 and left the cluster attributes unset.
        best_score = -np.inf
        best_k = None
        best_labels = None
        for k in tqdm(range(min_clusters, upper + 1), "Calculating optimal clusters..."):
            labels = KMeans(n_clusters=k, random_state=42).fit_predict(self.features)
            score = silhouette_score(
                self.features, labels, sample_size=silhouette_sample_size, random_state=42
            )
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(f"Silhouette score: {score:.4f}")
            if score > best_score:
                tqdm.write(f"new best cluster: {k}, silhouette score: {score:.4f}")
                best_score, best_k, best_labels = score, k, labels

        self.num_clusters = best_k
        self.labels = best_labels
        self.clusters = self._group_by_label(best_labels, self.image_paths, best_k)

    def save_clustered_images(self, output_directory, dry_run=False):
        """Copy each clustered image into ``output_directory``.

        Files are named ``cluster_<cluster id>_<index in cluster><original extension>``.
        Both numbers are zero-padded so that files sort by cluster. With
        ``dry_run=True`` the target names are printed instead of copied. The output
        directory is created in both modes.
        """
        output_directory = os.path.expanduser(output_directory)
        os.makedirs(output_directory, exist_ok=True)
        cluster_width = len(str(self.num_clusters))
        for cluster_id in tqdm(range(self.num_clusters), "Saving clusters..."):
            members = self.clusters[cluster_id]
            element_width = len(str(len(members)))
            for element_id, image_path in enumerate(members):
                extension = os.path.splitext(image_path)[1]
                new_file_name = os.path.join(
                    output_directory,
                    f"cluster_{cluster_id:0{cluster_width}d}_{element_id:0{element_width}d}{extension}",
                )
                if dry_run:
                    tqdm.write(new_file_name)
                else:
                    shutil.copyfile(image_path, new_file_name)

    def _generate_image_features(self, image_path):
        """Return a single-row 2-D array: FaceNet output followed by ResNet-50 output."""
        # The context manager closes the file handle. convert() returns a new, fully
        # loaded image that remains valid after the file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            resnet_features = self.resnet_model(image_tensor)
            face_features = self.facenet_model(image_tensor)
        # NOTE(review): torchvision's resnet50 head outputs unnormalised ImageNet class
        # logits. In Euclidean k-means these likely dominate the FaceNet embedding.
        # Consider normalising each half before concatenating (and bump `version`,
        # since cached features would change).
        return np.concatenate((face_features.cpu().numpy(), resnet_features.cpu().numpy()), axis=1)

    def save_features(self, file_name):
        """Pickle features, image paths, source directory and ``version`` to ``file_name``."""
        data = {"features": self.features, "image_paths": self.image_paths, "image_directory": self.image_directory, "version": version}
        self._write_pickle(file_name, data)

    def load_features(self, file_name):
        """Restore ``features``, ``image_paths`` and ``image_directory`` saved by save_features.

        Only load files you created yourself: unpickling can execute arbitrary code.
        """
        data = self._read_pickle(file_name)
        self.features = data["features"]
        self.image_paths = data["image_paths"]
        self.image_directory = data["image_directory"]

    def save_clusters(self, file_name):
        """Pickle ``num_clusters``, ``clusters`` and ``version`` to ``file_name``."""
        data = {"num_clusters": self.num_clusters, "clusters": self.clusters, "version": version}
        self._write_pickle(file_name, data)

    def load_clusters(self, file_name):
        """Restore ``num_clusters`` and ``clusters`` saved by save_clusters.

        ``self.labels`` is not part of the file and is left unchanged. Only load files
        you created yourself: unpickling can execute arbitrary code.
        """
        data = self._read_pickle(file_name)
        self.num_clusters = data["num_clusters"]
        self.clusters = data["clusters"]

    @classmethod
    def _list_image_paths(cls, image_directory):
        """Return sorted paths of supported image files directly inside ``image_directory``.

        Extensions are compared case-insensitively, so ``IMG_0001.JPG`` is found on
        case-sensitive file systems. Hidden dot-files (e.g. macOS ``._*`` AppleDouble
        files) are skipped, as glob's ``*`` pattern would skip them. Sorting makes the
        row order of the features deterministic.
        """
        with os.scandir(image_directory) as entries:
            paths = [
                os.path.join(image_directory, entry.name)
                for entry in entries
                if entry.is_file()
                and not entry.name.startswith(".")
                and os.path.splitext(entry.name)[1].lower() in cls.IMAGE_EXTENSIONS
            ]
        return sorted(paths)

    @staticmethod
    def _group_by_label(labels, image_paths, num_clusters):
        """Return ``num_clusters`` lists; list ``c`` holds the paths whose label is ``c``.

        The list is sized by ``num_clusters``, not ``labels.max() + 1``, so that an
        unused top label cannot break iteration over ``range(num_clusters)``.
        """
        clusters = [[] for _ in range(num_clusters)]
        for label, image_path in zip(labels, image_paths):
            clusters[label].append(image_path)
        return clusters

    @staticmethod
    def _write_pickle(file_name, data):
        """Pickle ``data`` to ``file_name``, expanding a leading ``~``."""
        with open(os.path.expanduser(file_name), "wb") as file:
            pickle.dump(data, file)

    @staticmethod
    def _read_pickle(file_name):
        """Unpickle and return the object stored in ``file_name``, expanding a leading ``~``."""
        with open(os.path.expanduser(file_name), "rb") as file:
            return pickle.load(file)

In [17]:
# Run configuration. Cache file names embed the feature-format `version` from the setup
# cell, so bumping it forces features and clusters to be recomputed.
image_directory = os.path.expanduser("~/Downloads/JPGs")
features_file = os.path.join(image_directory, f"Features_v{version}.pkl")
clusters_file = os.path.join(image_directory, f"Clusters_v{version}.pkl")
output_directory = os.path.expanduser("~/Downloads/ClusteredJPGs")
max_clusters = 100  # largest k tried by the silhouette search
copy_clustered_images = False  # set True to copy the clustered images into output_directory

ica = ImageClusteringAlgorithm()

# Features: reuse the cache when present, because extraction runs both CNNs on every image.
features_recomputed = not os.path.exists(features_file)
if features_recomputed:
    ica.compute_features(image_directory)
    ica.save_features(features_file)
else:
    ica.load_features(features_file)

# Clusters: a cached result is only valid for the features it was computed from,
# so recompute whenever the features were just regenerated.
if features_recomputed or not os.path.exists(clusters_file):
    ica.compute_clusters(max_clusters)
    ica.save_clusters(clusters_file)
else:
    ica.load_clusters(clusters_file)

if copy_clustered_images:
    ica.save_clustered_images(output_directory)

Calculating optimal clusters...:   1%|          | 1/99 [00:05<08:30,  5.21s/it]

Silhouette score: 0.1132
new best cluster: 2, silhouette score: 0.1132


Calculating optimal clusters...:   2%|▏         | 2/99 [00:11<09:09,  5.66s/it]

Silhouette score: 0.0914


Calculating optimal clusters...:   3%|▎         | 3/99 [00:17<09:27,  5.91s/it]

Silhouette score: 0.0928


Calculating optimal clusters...:   4%|▍         | 4/99 [00:24<10:02,  6.34s/it]

Silhouette score: 0.0722


Calculating optimal clusters...:   5%|▌         | 5/99 [00:31<10:20,  6.60s/it]

Silhouette score: 0.0688


Calculating optimal clusters...:   6%|▌         | 6/99 [00:39<10:58,  7.08s/it]

Silhouette score: 0.0635


Calculating optimal clusters...:   7%|▋         | 7/99 [00:46<10:59,  7.17s/it]

Silhouette score: 0.0626


Calculating optimal clusters...:   8%|▊         | 8/99 [00:55<11:26,  7.54s/it]

Silhouette score: 0.0636


Calculating optimal clusters...:   9%|▉         | 9/99 [01:04<12:00,  8.01s/it]

Silhouette score: 0.0614


Calculating optimal clusters...:  10%|█         | 10/99 [01:14<13:03,  8.80s/it]

Silhouette score: 0.0639


Calculating optimal clusters...:  11%|█         | 11/99 [01:24<13:24,  9.14s/it]

Silhouette score: 0.0627


Calculating optimal clusters...:  12%|█▏        | 12/99 [01:35<13:54,  9.59s/it]

Silhouette score: 0.0633


Calculating optimal clusters...:  13%|█▎        | 13/99 [01:46<14:25, 10.06s/it]

Silhouette score: 0.0616


Calculating optimal clusters...:  14%|█▍        | 14/99 [01:58<15:01, 10.61s/it]

Silhouette score: 0.0608


Calculating optimal clusters...:  15%|█▌        | 15/99 [02:08<14:52, 10.63s/it]

Silhouette score: 0.0600


Calculating optimal clusters...:  16%|█▌        | 16/99 [02:20<14:57, 10.81s/it]

Silhouette score: 0.0587


Calculating optimal clusters...:  17%|█▋        | 17/99 [02:31<15:06, 11.05s/it]

Silhouette score: 0.0606


Calculating optimal clusters...:  18%|█▊        | 18/99 [02:44<15:27, 11.45s/it]

Silhouette score: 0.0590


Calculating optimal clusters...:  19%|█▉        | 19/99 [02:56<15:26, 11.59s/it]

Silhouette score: 0.0576


Calculating optimal clusters...:  20%|██        | 20/99 [03:11<16:41, 12.67s/it]

Silhouette score: 0.0617


Calculating optimal clusters...:  21%|██        | 21/99 [03:23<16:24, 12.62s/it]

Silhouette score: 0.0547


Calculating optimal clusters...:  22%|██▏       | 22/99 [03:37<16:38, 12.97s/it]

Silhouette score: 0.0559


Calculating optimal clusters...:  23%|██▎       | 23/99 [03:51<16:39, 13.15s/it]

Silhouette score: 0.0565


Calculating optimal clusters...:  24%|██▍       | 24/99 [04:05<16:53, 13.52s/it]

Silhouette score: 0.0581


Calculating optimal clusters...:  25%|██▌       | 25/99 [04:20<17:19, 14.05s/it]

Silhouette score: 0.0564


Calculating optimal clusters...:  26%|██▋       | 26/99 [04:36<17:36, 14.47s/it]

Silhouette score: 0.0573


Calculating optimal clusters...:  27%|██▋       | 27/99 [04:51<17:38, 14.70s/it]

Silhouette score: 0.0567


Calculating optimal clusters...:  28%|██▊       | 28/99 [05:07<17:58, 15.19s/it]

Silhouette score: 0.0590


Calculating optimal clusters...:  29%|██▉       | 29/99 [05:24<18:08, 15.55s/it]

Silhouette score: 0.0567


Calculating optimal clusters...:  30%|███       | 30/99 [05:41<18:30, 16.10s/it]

Silhouette score: 0.0577


Calculating optimal clusters...:  31%|███▏      | 31/99 [05:58<18:31, 16.34s/it]

Silhouette score: 0.0575


Calculating optimal clusters...:  32%|███▏      | 32/99 [06:15<18:37, 16.68s/it]

Silhouette score: 0.0569


Calculating optimal clusters...:  33%|███▎      | 33/99 [06:33<18:33, 16.87s/it]

Silhouette score: 0.0568


Calculating optimal clusters...:  34%|███▍      | 34/99 [06:50<18:30, 17.08s/it]

Silhouette score: 0.0559


Calculating optimal clusters...:  35%|███▌      | 35/99 [07:10<18:54, 17.72s/it]

Silhouette score: 0.0560


Calculating optimal clusters...:  36%|███▋      | 36/99 [07:28<18:57, 18.06s/it]

Silhouette score: 0.0582


Calculating optimal clusters...:  37%|███▋      | 37/99 [07:48<19:01, 18.42s/it]

Silhouette score: 0.0574


Calculating optimal clusters...:  38%|███▊      | 38/99 [08:07<19:00, 18.69s/it]

Silhouette score: 0.0549


Calculating optimal clusters...:  39%|███▉      | 39/99 [08:27<18:55, 18.93s/it]

Silhouette score: 0.0553


Calculating optimal clusters...:  40%|████      | 40/99 [08:47<19:07, 19.45s/it]

Silhouette score: 0.0576


Calculating optimal clusters...:  41%|████▏     | 41/99 [09:08<19:08, 19.81s/it]

Silhouette score: 0.0547


Calculating optimal clusters...:  42%|████▏     | 42/99 [09:29<19:03, 20.07s/it]

Silhouette score: 0.0571


Calculating optimal clusters...:  43%|████▎     | 43/99 [09:50<19:10, 20.54s/it]

Silhouette score: 0.0565


Calculating optimal clusters...:  44%|████▍     | 44/99 [10:12<19:15, 21.01s/it]

Silhouette score: 0.0555


Calculating optimal clusters...:  45%|████▌     | 45/99 [10:35<19:16, 21.41s/it]

Silhouette score: 0.0560


Calculating optimal clusters...:  46%|████▋     | 46/99 [10:56<18:50, 21.33s/it]

Silhouette score: 0.0558


Calculating optimal clusters...:  47%|████▋     | 47/99 [11:17<18:30, 21.36s/it]

Silhouette score: 0.0524


Calculating optimal clusters...:  48%|████▊     | 48/99 [11:39<18:13, 21.43s/it]

Silhouette score: 0.0544


Calculating optimal clusters...:  49%|████▉     | 49/99 [12:01<18:05, 21.70s/it]

Silhouette score: 0.0552


Calculating optimal clusters...:  51%|█████     | 50/99 [12:25<18:15, 22.35s/it]

Silhouette score: 0.0566


Calculating optimal clusters...:  52%|█████▏    | 51/99 [12:48<18:03, 22.57s/it]

Silhouette score: 0.0561


Calculating optimal clusters...:  53%|█████▎    | 52/99 [13:12<17:59, 22.96s/it]

Silhouette score: 0.0541


Calculating optimal clusters...:  54%|█████▎    | 53/99 [13:36<17:45, 23.17s/it]

Silhouette score: 0.0543


Calculating optimal clusters...:  55%|█████▍    | 54/99 [14:01<17:52, 23.83s/it]

Silhouette score: 0.0537


Calculating optimal clusters...:  56%|█████▌    | 55/99 [14:26<17:43, 24.18s/it]

Silhouette score: 0.0555


Calculating optimal clusters...:  57%|█████▋    | 56/99 [14:52<17:44, 24.77s/it]

Silhouette score: 0.0526


Calculating optimal clusters...:  58%|█████▊    | 57/99 [15:18<17:33, 25.08s/it]

Silhouette score: 0.0539


Calculating optimal clusters...:  59%|█████▊    | 58/99 [15:44<17:24, 25.47s/it]

Silhouette score: 0.0544


Calculating optimal clusters...:  60%|█████▉    | 59/99 [16:10<17:07, 25.69s/it]

Silhouette score: 0.0565


Calculating optimal clusters...:  61%|██████    | 60/99 [16:37<16:55, 26.04s/it]

Silhouette score: 0.0552


Calculating optimal clusters...:  62%|██████▏   | 61/99 [17:04<16:40, 26.33s/it]

Silhouette score: 0.0556


Calculating optimal clusters...:  63%|██████▎   | 62/99 [17:32<16:27, 26.69s/it]

Silhouette score: 0.0538


Calculating optimal clusters...:  64%|██████▎   | 63/99 [17:59<16:05, 26.83s/it]

Silhouette score: 0.0534


Calculating optimal clusters...:  65%|██████▍   | 64/99 [18:28<15:58, 27.37s/it]

Silhouette score: 0.0528


Calculating optimal clusters...:  66%|██████▌   | 65/99 [18:56<15:42, 27.72s/it]

Silhouette score: 0.0542


Calculating optimal clusters...:  67%|██████▋   | 66/99 [19:25<15:27, 28.12s/it]

Silhouette score: 0.0554


Calculating optimal clusters...:  68%|██████▊   | 67/99 [19:54<15:06, 28.33s/it]

Silhouette score: 0.0528


Calculating optimal clusters...:  69%|██████▊   | 68/99 [20:25<14:59, 29.03s/it]

Silhouette score: 0.0536


Calculating optimal clusters...:  70%|██████▉   | 69/99 [20:55<14:41, 29.39s/it]

Silhouette score: 0.0552


Calculating optimal clusters...:  71%|███████   | 70/99 [21:25<14:17, 29.56s/it]

Silhouette score: 0.0563


Calculating optimal clusters...:  72%|███████▏  | 71/99 [21:55<13:52, 29.72s/it]

Silhouette score: 0.0533


Calculating optimal clusters...:  73%|███████▎  | 72/99 [22:28<13:49, 30.74s/it]

Silhouette score: 0.0543


Calculating optimal clusters...:  74%|███████▎  | 73/99 [22:59<13:21, 30.83s/it]

Silhouette score: 0.0525


Calculating optimal clusters...:  75%|███████▍  | 74/99 [23:31<13:00, 31.20s/it]

Silhouette score: 0.0550


Calculating optimal clusters...:  76%|███████▌  | 75/99 [24:04<12:37, 31.56s/it]

Silhouette score: 0.0520


Calculating optimal clusters...:  77%|███████▋  | 76/99 [24:37<12:16, 32.04s/it]

Silhouette score: 0.0535


Calculating optimal clusters...:  78%|███████▊  | 77/99 [25:10<11:50, 32.29s/it]

Silhouette score: 0.0549


Calculating optimal clusters...:  79%|███████▉  | 78/99 [25:43<11:22, 32.48s/it]

Silhouette score: 0.0540


Calculating optimal clusters...:  80%|███████▉  | 79/99 [26:15<10:50, 32.55s/it]

Silhouette score: 0.0556


Calculating optimal clusters...:  81%|████████  | 80/99 [26:48<10:21, 32.73s/it]

Silhouette score: 0.0539


Calculating optimal clusters...:  82%|████████▏ | 81/99 [27:22<09:56, 33.11s/it]

Silhouette score: 0.0532


Calculating optimal clusters...:  83%|████████▎ | 82/99 [27:56<09:26, 33.33s/it]

Silhouette score: 0.0546


Calculating optimal clusters...:  84%|████████▍ | 83/99 [28:31<08:58, 33.63s/it]

Silhouette score: 0.0539


Calculating optimal clusters...:  85%|████████▍ | 84/99 [29:06<08:30, 34.06s/it]

Silhouette score: 0.0540


Calculating optimal clusters...:  86%|████████▌ | 85/99 [29:41<08:02, 34.46s/it]

Silhouette score: 0.0533


Calculating optimal clusters...:  87%|████████▋ | 86/99 [30:17<07:35, 35.00s/it]

Silhouette score: 0.0550


Calculating optimal clusters...:  88%|████████▊ | 87/99 [30:56<07:14, 36.19s/it]

Silhouette score: 0.0564


Calculating optimal clusters...:  89%|████████▉ | 88/99 [31:33<06:39, 36.34s/it]

Silhouette score: 0.0531


Calculating optimal clusters...:  90%|████████▉ | 89/99 [32:10<06:06, 36.62s/it]

Silhouette score: 0.0570


Calculating optimal clusters...:  91%|█████████ | 90/99 [32:47<05:30, 36.76s/it]

Silhouette score: 0.0545


Calculating optimal clusters...:  92%|█████████▏| 91/99 [33:25<04:56, 37.09s/it]

Silhouette score: 0.0572


Calculating optimal clusters...:  93%|█████████▎| 92/99 [34:03<04:20, 37.28s/it]

Silhouette score: 0.0552


Calculating optimal clusters...:  94%|█████████▍| 93/99 [34:41<03:45, 37.55s/it]

Silhouette score: 0.0540


Calculating optimal clusters...:  95%|█████████▍| 94/99 [35:21<03:11, 38.36s/it]

Silhouette score: 0.0538


Calculating optimal clusters...:  96%|█████████▌| 95/99 [36:01<02:34, 38.69s/it]

Silhouette score: 0.0548


Calculating optimal clusters...:  97%|█████████▋| 96/99 [36:41<01:56, 38.99s/it]

Silhouette score: 0.0549


Calculating optimal clusters...:  98%|█████████▊| 97/99 [37:20<01:18, 39.21s/it]

Silhouette score: 0.0547


Calculating optimal clusters...:  99%|█████████▉| 98/99 [38:00<00:39, 39.39s/it]

Silhouette score: 0.0537


Calculating optimal clusters...: 100%|██████████| 99/99 [38:39<00:00, 23.43s/it]

Silhouette score: 0.0557



