In [None]:
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline
plt.gray();
from matplotlib.pyplot import imshow

from sklearn.cluster import KMeans, MeanShift, estimate_bandwidth
from sklearn.mixture import GaussianMixture as GMM

from sklearn.datasets import load_digits, fetch_olivetti_faces, fetch_openml
from skimage.data import lfw_subset

In [None]:
def show_dataset(X, N, max_size=400):
    """Plot the first ``max_size`` samples of ``X`` as a grid of N x N images.

    Parameters
    ----------
    X : array-like
        One flattened image per row; each row must hold ``N * N`` values
        because it is reshaped to ``(N, N)``. The slice ``X[:max_size]`` is
        converted with ``np.asarray`` so that row-sliceable containers such
        as a pandas DataFrame are iterated row by row.
    N : int
        Side length of each square image.
    max_size : int, default=400
        Maximum number of samples to draw.

    Notes
    -----
    Images are laid out 20 per row. The grid has 20 rows (the original fixed
    layout) unless more than 400 images are requested, in which case extra
    rows are added and the figure height grows accordingly instead of
    ``plt.subplot`` failing on an out-of-range index.
    """
    images = np.asarray(X[:max_size])

    n_cols = 20
    # Ceiling division without importing math; never fewer than 20 rows so
    # the default layout and figure size are unchanged.
    n_rows = max(20, -(-len(images) // n_cols))

    plt.figure(figsize=(20, n_rows))
    for idx, img in enumerate(images):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(img.reshape((N, N)), vmax=1)
        plt.axis('off')

    plt.show()

In [None]:
def show_clusters(X, N, max_clusters=20, max_size=99, clustering=None):
    """Plot each cluster centre followed by the samples assigned to it.

    Intended for estimators that expose ``cluster_centers_`` and ``labels_``
    (used in this notebook with k-means and mean shift). One figure is drawn
    per cluster: the centre occupies the first grid slot, titled with the
    cluster index, and the member images follow.

    Parameters
    ----------
    X : array-like
        One flattened image per row (``N * N`` values each), in the same
        order as ``clustering.labels_``. Converted with ``np.asarray`` so
        boolean row selection and row iteration behave the same for
        ndarrays and DataFrames.
    N : int
        Side length of each square image.
    max_clusters : int, default=20
        Maximum number of clusters to draw.
    max_size : int, default=99
        Maximum number of member images drawn per cluster.
    clustering : fitted estimator, optional
        Model to visualise. When omitted, the notebook-level variable named
        ``clustering`` is used, which is how existing calls behave.

    Raises
    ------
    NameError
        If ``clustering`` is not passed and no global ``clustering`` exists.
    """
    if clustering is None:
        # Backward-compatible fallback to the notebook-global model.
        try:
            clustering = globals()['clustering']
        except KeyError:
            raise NameError("name 'clustering' is not defined") from None

    data = np.asarray(X)
    labels = np.asarray(clustering.labels_)
    n_cols = 20

    for cluster_id, center in enumerate(clustering.cluster_centers_[:max_clusters]):
        members = data[labels == cluster_id][:max_size]

        # Slot 1 holds the centre. Keep the original 20 rows unless the
        # centre plus members would not fit in 400 slots.
        n_rows = max(20, -(-(len(members) + 1) // n_cols))

        plt.figure(figsize=(20, n_rows / 2))

        plt.subplot(n_rows, n_cols, 1)
        plt.imshow(center.reshape((N, N)), vmax=1)
        plt.axis('off')
        plt.title('{}'.format(cluster_id))

        # A separate slot counter keeps the cluster index intact.
        for slot, img in enumerate(members, start=2):
            plt.subplot(n_rows, n_cols, slot)
            plt.imshow(img.reshape((N, N)), vmax=1)
            plt.axis('off')
        plt.show()

In [None]:
def show_clusters_GMM(X, N, max_clusters=20, max_size=99, clustering=None):
    """Plot each mixture component mean followed by the samples assigned to it.

    Intended for estimators that expose ``means_`` and ``predict`` (used in
    this notebook with Gaussian mixtures). Labels are obtained with
    ``clustering.predict(X)``. One figure is drawn per component: the mean
    occupies the first grid slot, titled with the component index, and the
    member images follow.

    Parameters
    ----------
    X : array-like
        One flattened image per row (``N * N`` values each). Passed as-is to
        ``predict`` and then converted with ``np.asarray`` for row selection
        and iteration.
    N : int
        Side length of each square image.
    max_clusters : int, default=20
        Maximum number of components to draw.
    max_size : int, default=99
        Maximum number of member images drawn per component.
    clustering : fitted estimator, optional
        Model to visualise. When omitted, the notebook-level variable named
        ``clustering`` is used, which is how existing calls behave.

    Raises
    ------
    NameError
        If ``clustering`` is not passed and no global ``clustering`` exists.
    """
    if clustering is None:
        # Backward-compatible fallback to the notebook-global model.
        try:
            clustering = globals()['clustering']
        except KeyError:
            raise NameError("name 'clustering' is not defined") from None

    # Predict on the caller's object first, then index a plain array.
    labels = np.asarray(clustering.predict(X))
    data = np.asarray(X)
    n_cols = 20

    for component_id, mean in enumerate(clustering.means_[:max_clusters]):
        members = data[labels == component_id][:max_size]

        # Slot 1 holds the mean. Keep the original 20 rows unless the mean
        # plus members would not fit in 400 slots.
        n_rows = max(20, -(-(len(members) + 1) // n_cols))

        plt.figure(figsize=(20, n_rows / 2))

        plt.subplot(n_rows, n_cols, 1)
        plt.imshow(mean.reshape((N, N)), vmax=1)
        plt.axis('off')
        plt.title('{}'.format(component_id))

        # A separate slot counter keeps the component index intact.
        for slot, img in enumerate(members, start=2):
            plt.subplot(n_rows, n_cols, slot)
            plt.imshow(img.reshape((N, N)), vmax=1)
            plt.axis('off')
        plt.show()

# digits

In [None]:
# Load the digits dataset and divide the pixel values by 16
# (presumably the maximum intensity, giving values in [0, 1] — TODO confirm).
# `X` is the notebook-global data matrix used by every cell in this section.
dataset = load_digits()
X = dataset.data / 16
print(X.shape[0]) # number of samples

In [None]:
# Preview the samples as 8x8 images (show_dataset draws at most 400 by default).
show_dataset(X, N=8)

## kmeans

In [None]:
# Build a 10-cluster k-means model; it is fitted in the next cell.
# NOTE(review): no random_state is passed, so results presumably vary between runs.
n_clusters=10
clustering = KMeans(n_clusters=n_clusters)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 8x8 images.
# show_clusters picks up the notebook-global `clustering` fitted above.
show_clusters(X, N=8)

## mean shift

In [None]:
# Estimate a bandwidth from X and run mean shift with half of it.
# NOTE(review): the 1/2 factor looks hand-tuned — confirm the intent.
# Rebinding `clustering` replaces the k-means model from the previous section.
bw = estimate_bandwidth(X)
clustering = MeanShift(bandwidth=bw/2)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Number of cluster centres found by mean shift.
len(clustering.cluster_centers_)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 8x8 images (first 20 clusters by default).
show_clusters(X, N=8)

## GMM

In [None]:
# Build a 10-component Gaussian mixture with diagonal covariances;
# n_init=10 requests ten initialisations. Fitted in the next cell.
n_clusters=10
clustering = GMM(n_components=n_clusters, covariance_type='diag', n_init=10)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each component mean with the samples predicted to belong to it.
# show_clusters_GMM uses `clustering.predict(X)` and `clustering.means_`.
show_clusters_GMM(X, N=8)

# Olivetti faces

In [None]:
# Fetch the Olivetti faces without shuffling and use the flattened images as-is
# (no rescaling is applied here; presumably the values are already in [0, 1] — TODO confirm).
# Rebinding `X` replaces the digits data for the cells that follow.
dataset = fetch_olivetti_faces(shuffle=False)
X = dataset.data
print(X.shape[0]) # number of samples

In [None]:
# Preview the samples as 64x64 images (show_dataset draws at most 400 by default).
show_dataset(X, N=64)

## kmeans

In [None]:
# Build a 10-cluster k-means model; it is fitted in the next cell.
# NOTE(review): no random_state is passed, so results presumably vary between runs.
n_clusters=10
clustering = KMeans(n_clusters=n_clusters)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 64x64 images.
# show_clusters picks up the notebook-global `clustering` fitted above.
show_clusters(X, N=64)

## mean shift

In [None]:
# Estimate a bandwidth from X and run mean shift with 60% of it.
# NOTE(review): the 0.6 factor looks hand-tuned — confirm the intent.
bw = estimate_bandwidth(X)
clustering = MeanShift(bandwidth=bw*0.6)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Number of cluster centres found by mean shift.
len(clustering.cluster_centers_)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 64x64 images (first 20 clusters by default).
show_clusters(X, N=64)

## GMM

In [None]:
# Build a 10-component Gaussian mixture with diagonal covariances;
# n_init=10 requests ten initialisations. Fitted in the next cell.
n_clusters=10
clustering = GMM(n_components=n_clusters, covariance_type='diag', n_init=10)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each component mean with the samples predicted to belong to it (64x64 images).
show_clusters_GMM(X, N=64)

# LFW

In [None]:
# Load the LFW subset shipped with scikit-image and flatten every image
# into a row of 25*25 = 625 values. Rebinding `X` replaces the Olivetti data.
X = lfw_subset().reshape((-1, 25*25))
print(X.shape[0]) # number of samples

In [None]:
# Preview the samples as 25x25 images (show_dataset draws at most 400 by default).
show_dataset(X, N=25)

## kmeans

In [None]:
# Build a 10-cluster k-means model without fitting it here. The fit happens
# once, in the timed cell that follows, as in the other sections; chaining
# .fit(X) at construction made the model train twice.
n_clusters=10
clustering = KMeans(n_clusters=n_clusters)

In [None]:
# Time a fit of the k-means model on X.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 25x25 images.
# max_size=400 raises the per-cluster cap above the default of 99.
show_clusters(X, N=25, max_size=400)

## mean shift

In [None]:
# Estimate a bandwidth from X and run mean shift with 60% of it.
# NOTE(review): the 0.6 factor looks hand-tuned — confirm the intent.
bw = estimate_bandwidth(X)
clustering = MeanShift(bandwidth=bw*0.6)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Number of cluster centres found by mean shift.
len(clustering.cluster_centers_)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 25x25 images (first 20 clusters by default).
show_clusters(X, N=25)

## GMM

In [None]:
# Build a 10-component Gaussian mixture with diagonal covariances;
# n_init=10 requests ten initialisations. Fitted in the next cell.
n_clusters=10
clustering = GMM(n_components=n_clusters, covariance_type='diag', n_init=10)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each component mean with the samples predicted to belong to it (25x25 images).
show_clusters_GMM(X, N=25)

# MNIST

In [None]:
# Fetch MNIST from OpenML and divide the pixel values by 255.
# as_frame=False requests NumPy arrays: with the default as_frame='auto' the
# data comes back as a pandas DataFrame, and iterating a DataFrame yields
# column labels, which breaks the row-wise `.reshape` in the plotting helpers.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data / 255
print(X.shape[0]) # number of samples

In [None]:
# Preview the samples as 28x28 images (show_dataset draws at most 400 by default).
show_dataset(X, N=28)

## kmeans

In [None]:
# Build a 20-cluster k-means model with a single initialisation (n_init=1),
# presumably to keep the fit time manageable on this larger dataset — TODO confirm.
# NOTE(review): no random_state is passed, so results presumably vary between runs.
n_clusters=20
clustering = KMeans(n_clusters=n_clusters, n_init=1)

In [None]:
# Fit on X; %time reports how long the call takes.
%time clustering.fit(X)

In [None]:
# Display the iteration count (`n_iter_`) recorded by the fit.
clustering.n_iter_

In [None]:
# Draw each centre with its members as 28x28 images.
# show_clusters picks up the notebook-global `clustering` fitted above.
show_clusters(X, N=28)