## Fit Clusters
* This is a stripped-down version of https://github.com/pbaylies/stylegan2-ada-pytorch/blob/main/categorizer.py
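* It operates on the prebuilt CLIP ViT-B/32 image embeddings from LAION-400M, but should work for image embeddings in general.
* It fits PCA, ICA, and a Gaussian mixture model on the first embedding shard (`images/img_emb_0.npy`), then assigns a cluster label to every embedding in each shard and saves the labels as `labels_<n>.npy`.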

In [None]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.decomposition import FastICA
import pickle
import os

In [None]:
# Shard 0 of the embeddings: used both to fit the models below and to produce labels_0.npy.
ie1 = np.load(file="images/img_emb_0.npy")

In [None]:
# Fit a 256-component PCA on shard 0, or reuse a previously pickled model,
# then project shard 0 into PCA space as float32.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# pca_model.pkl that this notebook produced itself.
if os.path.isfile("pca_model.pkl"):
    with open("pca_model.pkl", "rb") as model_file:
        pca = pickle.load(model_file)
else:
    pca = PCA(n_components=256)
    pca.fit(ie1)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open("pca_model.pkl", "wb") as model_file:
        pickle.dump(pca, model_file)
pca_features = np.float32(pca.transform(ie1))

In [None]:
# Fit a 256-component FastICA on shard 0, or reuse a previously pickled model,
# then project shard 0 into ICA space as float32.
# NOTE(review): pickle.load can execute arbitrary code; only load an
# ica_model.pkl that this notebook produced itself.
# Check for the ICA pickle itself (previously this tested pca_model.pkl,
# which crashed with FileNotFoundError when only the PCA model was cached).
if os.path.isfile("ica_model.pkl"):
    with open("ica_model.pkl", "rb") as model_file:
        ica = pickle.load(model_file)
else:
    ica = FastICA(n_components=256, max_iter=1000, tol=2e-4)
    ica.fit(ie1)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open("ica_model.pkl", "wb") as model_file:
        pickle.dump(ica, model_file)
ica_features = np.float32(ica.transform(ie1))

In [None]:
# Side-by-side PCA and ICA projections form the feature vector clustered below.
more_features = np.concatenate([pca_features, ica_features], axis=1)

In [None]:
from sklearn.mixture import GaussianMixture
# Fit a 64-component Gaussian mixture (one covariance matrix shared by all
# components, covariance_type='tied') on the PCA+ICA features of shard 0,
# or reuse a previously pickled model.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# gmm_model.pkl that this notebook produced itself.
if os.path.isfile("gmm_model.pkl"):
    with open("gmm_model.pkl", "rb") as model_file:
        gmm = pickle.load(model_file)
else:
    gmm = GaussianMixture(n_components=64, covariance_type='tied', verbose=2, max_iter=200)
    gmm.fit(more_features)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open("gmm_model.pkl", "wb") as model_file:
        pickle.dump(gmm, model_file)
# Hard cluster assignment (one component index per embedding) for shard 0.
labels = gmm.predict(more_features)
# Note that it's also possible to use predict_proba() to get the probabilities (not calibrated) for *all* the labels:
# https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html#sklearn.mixture.GaussianMixture.predict_proba
np.save("labels_0.npy", labels)

In [None]:
# Label shards 1..409 with the already-fitted PCA, ICA and GMM models,
# writing one labels_<n>.npy per shard.
for n in range(1, 410):
    print(n)  # progress indicator
    # Drop the previous shard first so two shards are never held at once.
    ie1 = None
    ie1 = np.load(f"images/img_emb_{n}.npy")
    pca_features = np.float32(pca.transform(ie1))
    ica_features = np.float32(ica.transform(ie1))
    more_features = np.concatenate(
        (pca_features, ica_features),
        axis=1,
    )
    labels = gmm.predict(more_features)
    np.save(f"labels_{n}.npy", labels)