<center> <h1> Clustering using Conditional Variational Autoencoder</h1> </center>

In [None]:
import numpy as np
import torch
import pandas as pd
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.mixture import GaussianMixture
import warnings
warnings.filterwarnings("ignore")

# Seed every RNG the notebook touches: numpy (scikit-learn estimators without an
# explicit random_state fall back to numpy's global RandomState) and torch
# (DataLoader shuffling and any sampling the cVAE may perform).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

sns.set_style('darkgrid')

In [None]:
from clustering_utils import umap, tSNE, compute_most_represented_class_per_cluster, substitute_classes_labels
from cVAE_architecture import cVAE

# Run-wide settings shared by every section below.
# NOTE(review): `device` is not referenced by any cell in this notebook; all
# encoding currently runs on the CPU tensors produced by the data loaders.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
batch_size = 128
# Discrete 10-colour colormap (one colour per class / GMM component).
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap(name, lut)` is the supported equivalent.
cmap = plt.get_cmap('viridis', 10)

## I. MNIST

In [None]:
# Preprocessing: ToTensor turns each PIL image into a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Load the MNIST dataset (downloaded into ./mnist_data on first run).
train_dataset = torchvision.datasets.MNIST(
    root='./mnist_data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(
    root='./mnist_data', train=False, transform=transform, download=True)

# The loaders are only used to materialise each split as a single tensor, so
# shuffling buys nothing. It would make the row order (and therefore the GMM
# initialisation and any index-based inspection) change from run to run.
# `batch_size` comes from the configuration cell.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Each batch is [images, labels]; zip(*loader) groups all image batches and all
# label batches, which are then concatenated along the batch axis.
X_train, Y_train = map(torch.cat, zip(*train_loader))
X_test, Y_test = map(torch.cat, zip(*test_loader))

# Print shapes for verification
print("X_train shape:", X_train.shape)
print("Y_train shape:", Y_train.shape)
print("X_test shape:", X_test.shape)
print("Y_test shape:", Y_test.shape)

## I.1 Latent_dim = 16

In [None]:
# Rebuild the 16-d conditional VAE and restore its trained weights.
latent_dim = 16
autoencoder_cvae = cVAE(num_labels=10, latent_dim=latent_dim)
# map_location='cpu': the data tensors live on the CPU, and this lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
autoencoder_cvae.load_state_dict(torch.load(
    f'./cVAE_models/cVAE_MNIST_zdim_{latent_dim}_epochs_15.pth', map_location='cpu'))
# Inference mode (disables dropout / freezes batch-norm stats if the architecture has any).
autoencoder_cvae.eval();

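Clustering with a Gaussian mixture model (scikit-learn `GaussianMixture`) fitted on the training-set latent codes and applied to the test-set codes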

In [None]:
# Encode both splits with the cVAE, passing the true labels as the condition.
# No gradients are needed, so skip building the autograd graph that would
# otherwise be kept alive for every encoded image.
with torch.no_grad():
    z_test = autoencoder_cvae.encode(X_test, Y_test)
    z_train = autoencoder_cvae.encode(X_train, Y_train)
n_clusters = 10
z_test = z_test.cpu().numpy()
z_train = z_train.cpu().numpy()
y_test = Y_test.cpu().numpy()
# Fit the GMM on the training codes, then assign a component to every test code.
# A fixed random_state makes the initialisation reproducible.
clustering = GaussianMixture(
    n_components=n_clusters, covariance_type='full', random_state=42).fit(z_train)
clustering_labels = clustering.predict(z_test)

In [None]:
# Visualise the test latent codes with the project's UMAP helper, passing both
# the GMM cluster assignments and the ground-truth labels.
umap(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=16,
    data="MNIST",
)

In [None]:
# Same test latent codes, visualised with the project's t-SNE helper.
tSNE(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=16,
    data="MNIST",
)

Samples of each cluster

In [None]:
# Samples from each cluster: up to `nb_samples` test images per GMM component,
# one row per component.
nb_samples = 10

# squeeze=False keeps `axes` 2-D even when n_clusters or nb_samples is 1.
fig, axes = plt.subplots(n_clusters, nb_samples, figsize=(15, 8), squeeze=False)
for cluster_id in range(n_clusters):
    # Rows are indexed by GMM component id (not by true class). A component can
    # hold fewer than `nb_samples` points, or none at all.
    member_idx = np.flatnonzero(clustering_labels == cluster_id)[:nb_samples]
    for sample_index in range(nb_samples):
        ax = axes[cluster_id, sample_index]
        ax.axis('off')
        if sample_index < len(member_idx):
            ax.imshow(X_test[member_idx[sample_index]][0], cmap='gray')

plt.show()

In [None]:
# Give each GMM cluster a class label via the project helper (per its name,
# presumably the class most represented in that cluster), then score the
# induced labelling against the ground truth.
class_equivalence = compute_most_represented_class_per_cluster(
    clustering_labels, y_test)
print("class_equivalence: ", class_equivalence)
y_pred = substitute_classes_labels(clustering_labels, class_equivalence)
acc_16_mnist = accuracy_score(y_test, y_pred)
# Labelled like the sibling accuracy cells.
print("accuracy = ", acc_16_mnist)

In [None]:
# Internal cluster-validity indices of the GMM partition, measured on the test
# latent codes (definitions are given at the end of the notebook).
s_16_mnist, db_16_mnist, ch_16_mnist = (
    score(z_test, clustering_labels)
    for score in (silhouette_score, davies_bouldin_score, calinski_harabasz_score)
)
print(f"Silhouette: {s_16_mnist:.4f}, DB: {db_16_mnist:.4f}, CH: {ch_16_mnist:.4f}")

## I.2 Latent_dim = 32

In [None]:
# Rebuild the 32-d conditional VAE and restore its trained weights.
latent_dim = 32
autoencoder_cvae_32 = cVAE(num_labels=10, latent_dim=latent_dim)
# map_location='cpu': the data tensors live on the CPU, and this lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
autoencoder_cvae_32.load_state_dict(torch.load(
    f'./cVAE_models/cVAE_MNIST_zdim_{latent_dim}_epochs_15.pth', map_location='cpu'))
# Inference mode (disables dropout / freezes batch-norm stats if the architecture has any).
autoencoder_cvae_32.eval();

In [None]:
# Encode both splits with the 32-d cVAE, passing the true labels as the condition.
# No gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    z_test = autoencoder_cvae_32.encode(X_test, Y_test)
    z_train = autoencoder_cvae_32.encode(X_train, Y_train)
n_clusters = 10
z_test = z_test.cpu().numpy()
z_train = z_train.cpu().numpy()
y_test = Y_test.cpu().numpy()
# Fit the GMM on the training codes, then assign a component to every test code.
# A fixed random_state makes the initialisation reproducible.
clustering = GaussianMixture(
    n_components=n_clusters, covariance_type='full', random_state=42).fit(z_train)
clustering_labels = clustering.predict(z_test)

In [None]:
# Visualise the 32-d test latent codes with the project's UMAP helper, passing
# both the GMM cluster assignments and the ground-truth labels.
umap(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=32,
    data="MNIST",
)

In [None]:
# Same 32-d test latent codes, visualised with the project's t-SNE helper.
tSNE(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=32,
    data="MNIST",
)

In [None]:
# Samples from each cluster: up to `nb_samples` test images per GMM component,
# one row per component.
nb_samples = 10

# squeeze=False keeps `axes` 2-D even when n_clusters or nb_samples is 1.
fig, axes = plt.subplots(n_clusters, nb_samples, figsize=(15, 8), squeeze=False)
for cluster_id in range(n_clusters):
    # Rows are indexed by GMM component id (not by true class). A component can
    # hold fewer than `nb_samples` points, or none at all.
    member_idx = np.flatnonzero(clustering_labels == cluster_id)[:nb_samples]
    for sample_index in range(nb_samples):
        ax = axes[cluster_id, sample_index]
        ax.axis('off')
        if sample_index < len(member_idx):
            ax.imshow(X_test[member_idx[sample_index]][0], cmap='gray')

plt.show()

In [None]:
# Map cluster ids to class predictions with the project helpers (per the helper
# name, presumably each cluster's most represented true class), then score them.
class_equivalence = compute_most_represented_class_per_cluster(clustering_labels, y_test)
print("class_equivalence: ", class_equivalence)
y_pred = substitute_classes_labels(clustering_labels, class_equivalence)
acc_32_mnist = accuracy_score(y_true=y_test, y_pred=y_pred)
print("accuracy = ", acc_32_mnist)

In [None]:
# Internal cluster-validity indices of the GMM partition, measured on the test
# latent codes (definitions are given at the end of the notebook).
s_32_mnist, db_32_mnist, ch_32_mnist = (
    score(z_test, clustering_labels)
    for score in (silhouette_score, davies_bouldin_score, calinski_harabasz_score)
)

print(f"Silhouette: {s_32_mnist:.4f}, DB: {db_32_mnist:.4f}, CH: {ch_32_mnist:.4f}")

## II. Fashion-MNIST

In [None]:
# Preprocessing: ToTensor turns each PIL image into a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Load the Fashion-MNIST dataset (downloaded into ./FashionMNIST_data on first run).
train_dataset = torchvision.datasets.FashionMNIST(
    root='./FashionMNIST_data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.FashionMNIST(
    root='./FashionMNIST_data', train=False, transform=transform, download=True)

# The loaders are only used to materialise each split as a single tensor, so
# shuffling buys nothing. It would make the row order (and therefore the GMM
# initialisation and the index-based sample shown below) change between runs.
# `batch_size` comes from the configuration cell.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Each batch is [images, labels]; zip(*loader) groups all image batches and all
# label batches, which are then concatenated along the batch axis.
X_train, Y_train = map(torch.cat, zip(*train_loader))
X_test, Y_test = map(torch.cat, zip(*test_loader))

# Print shapes for verification
print("X_train shape:", X_train.shape)
print("Y_train shape:", Y_train.shape)
print("X_test shape:", X_test.shape)
print("Y_test shape:", Y_test.shape)

In [None]:
# Sanity check: display one Fashion-MNIST training image and its class id.
sample_idx = 6
img = X_train[sample_idx][0, :, :]
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.show()
# .item() prints the plain integer label instead of a tensor repr.
print("label = ", Y_train[sample_idx].item())

## II.1 latent_dim = 16

In [None]:
# Rebuild the 16-d conditional VAE trained on Fashion-MNIST and restore its weights.
latent_dim = 16
autoencoder_cvae = cVAE(num_labels=10, latent_dim=latent_dim)
# map_location='cpu': the data tensors live on the CPU, and this lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE(review): the file name says "MNIST" although it lives in cVAEFM_models;
# confirm this is the Fashion-MNIST checkpoint.
autoencoder_cvae.load_state_dict(torch.load(
    f'./cVAEFM_models/cVAE_MNIST_zdim_{latent_dim}_epochs_25.pth', map_location='cpu'))
# Inference mode (disables dropout / freezes batch-norm stats if the architecture has any).
autoencoder_cvae.eval();

In [None]:
# Encode both splits with the cVAE, passing the true labels as the condition.
# No gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    z_test = autoencoder_cvae.encode(X_test, Y_test)
    z_train = autoencoder_cvae.encode(X_train, Y_train)
n_clusters = 10
z_test = z_test.cpu().numpy()
z_train = z_train.cpu().numpy()
y_test = Y_test.cpu().numpy()
# Fit the GMM on the training codes, then assign a component to every test code.
# A fixed random_state makes the initialisation reproducible.
clustering = GaussianMixture(
    n_components=n_clusters, covariance_type='full', random_state=42).fit(z_train)
clustering_labels = clustering.predict(z_test)

In [None]:
# Visualise the Fashion-MNIST test latent codes with the project's UMAP helper,
# passing both the GMM cluster assignments and the ground-truth labels.
umap(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=16,
    data="FashionMNIST",
)

In [None]:
# Same test latent codes, visualised with the project's t-SNE helper.
tSNE(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=16,
    data="FashionMNIST",
)

In [None]:
# Up to `nb_samples` test images per GMM component, one row per component.
nb_samples = 10

# squeeze=False keeps `axes` 2-D even when n_clusters or nb_samples is 1.
fig, axes = plt.subplots(n_clusters, nb_samples, figsize=(15, 8), squeeze=False)
for cluster_id in range(n_clusters):
    # Rows are indexed by GMM component id (not by true class). A component can
    # hold fewer than `nb_samples` points, or none at all.
    member_idx = np.flatnonzero(clustering_labels == cluster_id)[:nb_samples]
    for sample_index in range(nb_samples):
        ax = axes[cluster_id, sample_index]
        ax.axis('off')
        if sample_index < len(member_idx):
            ax.imshow(X_test[member_idx[sample_index]][0], cmap='gray')

plt.show()

In [None]:
# Map cluster ids to class predictions with the project helpers (per the helper
# name, presumably each cluster's most represented true class), then score them.
class_equivalence = compute_most_represented_class_per_cluster(clustering_labels, y_test)
print("class_equivalence: ", class_equivalence)
y_pred = substitute_classes_labels(clustering_labels, class_equivalence)
acc_16_fashion = accuracy_score(y_true=y_test, y_pred=y_pred)
print("accuracy = ", acc_16_fashion)

In [None]:
# Internal cluster-validity indices of the GMM partition, measured on the test
# latent codes (definitions are given at the end of the notebook).
s_16_fashion, db_16_fashion, ch_16_fashion = (
    score(z_test, clustering_labels)
    for score in (silhouette_score, davies_bouldin_score, calinski_harabasz_score)
)

print(f"Silhouette: {s_16_fashion:.4f}, DB: {db_16_fashion:.4f}, CH: {ch_16_fashion:.4f}")

## II.2 latent_dim = 32

In [None]:
# Rebuild the 32-d conditional VAE trained on Fashion-MNIST and restore its weights.
latent_dim = 32
autoencoder_cvae = cVAE(num_labels=10, latent_dim=latent_dim)
# map_location='cpu': the data tensors live on the CPU, and this lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE(review): the file name says "MNIST" although it lives in cVAEFM_models;
# confirm this is the Fashion-MNIST checkpoint.
autoencoder_cvae.load_state_dict(torch.load(
    f'./cVAEFM_models/cVAE_MNIST_zdim_{latent_dim}_epochs_25.pth', map_location='cpu'))
# Inference mode (disables dropout / freezes batch-norm stats if the architecture has any).
autoencoder_cvae.eval();

In [None]:
# Encode both splits with the 32-d cVAE, passing the true labels as the condition.
# No gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    z_test = autoencoder_cvae.encode(X_test, Y_test)
    z_train = autoencoder_cvae.encode(X_train, Y_train)
n_clusters = 10
z_test = z_test.cpu().numpy()
z_train = z_train.cpu().numpy()
y_test = Y_test.cpu().numpy()
# Fit the GMM on the training codes, then assign a component to every test code.
# A fixed random_state makes the initialisation reproducible.
clustering = GaussianMixture(
    n_components=n_clusters, covariance_type='full', random_state=42).fit(z_train)
clustering_labels = clustering.predict(z_test)

In [None]:
# Visualise the 32-d Fashion-MNIST test latent codes with the project's UMAP
# helper, passing both the GMM cluster assignments and the ground-truth labels.
umap(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=32,
    data="FashionMNIST",
)

In [None]:
# Same 32-d test latent codes, visualised with the project's t-SNE helper.
tSNE(
    z_test,
    clustering_labels=clustering_labels,
    true_labels=y_test,
    cmap=cmap,
    latent_space=32,
    data="FashionMNIST",
)

In [None]:
# Up to `nb_samples` test images per GMM component, one row per component.
nb_samples = 10

# squeeze=False keeps `axes` 2-D even when n_clusters or nb_samples is 1.
fig, axes = plt.subplots(n_clusters, nb_samples, figsize=(15, 8), squeeze=False)
for cluster_id in range(n_clusters):
    # Rows are indexed by GMM component id (not by true class). A component can
    # hold fewer than `nb_samples` points, or none at all.
    member_idx = np.flatnonzero(clustering_labels == cluster_id)[:nb_samples]
    for sample_index in range(nb_samples):
        ax = axes[cluster_id, sample_index]
        ax.axis('off')
        if sample_index < len(member_idx):
            ax.imshow(X_test[member_idx[sample_index]][0], cmap='gray')

plt.show()

In [None]:
# Map cluster ids to class predictions with the project helpers (per the helper
# name, presumably each cluster's most represented true class), then score them.
class_equivalence = compute_most_represented_class_per_cluster(clustering_labels, y_test)
print("class_equivalence: ", class_equivalence)
y_pred = substitute_classes_labels(clustering_labels, class_equivalence)
acc_32_fashion = accuracy_score(y_true=y_test, y_pred=y_pred)
print("accuracy = ", acc_32_fashion)

In [None]:
# Internal cluster-validity indices of the GMM partition, measured on the test
# latent codes (definitions are given at the end of the notebook).
s_32_fashion, db_32_fashion, ch_32_fashion = (
    score(z_test, clustering_labels)
    for score in (silhouette_score, davies_bouldin_score, calinski_harabasz_score)
)

print(f"Silhouette: {s_32_fashion:.4f}, DB: {db_32_fashion:.4f}, CH: {ch_32_fashion:.4f}")

In [None]:
print(" ")

# One row per (dataset, latent size) run; columns are the run name and its metrics.
columns = ['Dataset_Z_latentDim', 'Accuracy', 'Silhouette', 'DB', 'CH']
rows = [
    ('Mnist_Z_16', acc_16_mnist, s_16_mnist, db_16_mnist, ch_16_mnist),
    ('Mnist_Z_32', acc_32_mnist, s_32_mnist, db_32_mnist, ch_32_mnist),
    ('FashionMnist_Z_16', acc_16_fashion, s_16_fashion, db_16_fashion, ch_16_fashion),
    ('FashionMnist_Z_32', acc_32_fashion, s_32_fashion, db_32_fashion, ch_32_fashion),
]
# Transpose the rows into a column -> values mapping for the DataFrame.
data = {name: [row[col_idx] for row in rows] for col_idx, name in enumerate(columns)}

df = pd.DataFrame(data)

# DataFrame.to_markdown requires the optional `tabulate` package.
print(df.to_markdown(index=False))

<blockquote> 

**Silhouette Score:**
The silhouette score measures how similar a sample is to its own cluster (cohesion) compared to the nearest other cluster (separation). Per-sample values range from -1 to 1, and the reported score is their mean; a high value indicates that samples are well matched to their own cluster and poorly matched to neighboring clusters.

**Davies-Bouldin Index:**
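The Davies-Bouldin index is the average, over all clusters, of the similarity between each cluster and its most similar cluster. Here similarity is the ratio of within-cluster scatter to the distance between cluster centroids. The index therefore rewards both compactness (small intra-cluster distances) and separation (large inter-cluster distances). Its minimum is 0, and a lower value indicates better clustering.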

**Calinski-Harabasz Index:**
The Calinski-Harabasz index, also known as the Variance Ratio Criterion, measures the ratio of between-cluster variance to within-cluster variance. It tends to be higher when clusters are well-separated and compact. A higher value suggests better clustering.

</blockquote> 