NB! The geomstats package is required.

The latent space of the AE is topologically a $d$-dimensional torus $\mathcal{T}^d$, i.e. it can be considered as a periodic box $[-\pi, \pi]^d$.

The notebook includes 3 K-means clustering applications.
1) For the input data (uses scikit-learn K-means, see https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html)
2) For the output data of the AE (also uses scikit-learn K-means)
3) For the torus latent space of the AE with the Euclidean metric (uses the geomstats package, see https://geomstats.github.io/notebooks/07_practical_methods__riemannian_kmeans.html#)

In this notebook, the data is a subset of the MNIST dataset restricted to 2 selected labels (5 and 8).

F-scores (see https://en.wikipedia.org/wiki/F-score) of the clusterizations against the ground-truth labels are computed to measure the efficiency of each clusterization.

The contents of the notebook are:

1) Setting hyperparameters, dataset loading, plotting embedded data for a pre-trained AE.
2) Geomstats K-means: Euclidean metric on torus latent space
3) K-means in input data space
4) K-means in the output data space
5) F-scores comparison

In [None]:
# prerequisites
%matplotlib inline
import sklearn
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import torch
import yaml,os, ricci_regularization, json
import numpy as np
from tqdm.notebook import tqdm

# 1. Setting hyperparameters, dataset loading, plotting embedded data for a pre-trained AE.

In [None]:
# Which K-means setup of which pre-trained AE to analyse.
k_means_setup_number = 0
pretrained_AE_setting_name = 'MNIST_Setting_3_exp5'
Path_clustering_setup = f"../experiments/{pretrained_AE_setting_name}/K_means_setup_{k_means_setup_number}"
Path_experiment = f'../experiments/{pretrained_AE_setting_name}_config.yaml'

# Which points to cluster: "selected_points" clusters only the N points stored in the
# setup's params.json; "all_points" is not implemented yet.
mode = "selected_points" # clustering only selected points
#mode = "all_points"

# Set to True to save the figures of sections 3 and 4 as PDF files into Path_pictures.
violent_saving = False

In [None]:
# Load this K-means setup: the (already encoded) points to cluster, K, N,
# the labels used for clustering and the ground-truth labels of the points.
with open(f"{Path_clustering_setup}/params.json", "r") as params_file:
    Riemannian_k_means_params = json.load(params_file)

K, N = Riemannian_k_means_params["K"], Riemannian_k_means_params["N"]
selected_labels = Riemannian_k_means_params["selected_labels"]
ground_truth_labels = Riemannian_k_means_params["ground_truth_labels"]
encoded_points_to_cluster = torch.tensor(Riemannian_k_means_params["encoded_points"])

## Plots

In [None]:
# Disabled plot (kept as a string literal, so it does not run): scatter of the first two
# coordinates of the encoded points to cluster, colored by their ground-truth labels.
# To enable it, remove the surrounding triple quotes. The commented savefig line also
# needs Path_pictures, which is only defined later (section 3).
"""
plt.figure(figsize=(8, 6))
plt.title("Encoded points selected for clustering colored by ground truth labels")
plt.scatter(encoded_points_to_cluster[:,0],encoded_points_to_cluster[:,1], c=ground_truth_labels, marker='o', edgecolor='none', cmap=ricci_regularization.discrete_cmap(K, 'jet'))
plt.colorbar(ticks=range(K))
plt.grid(True)
#plt.savefig(f"{Path_pictures}/latent_space.pdf",format="pdf")
"""

# 2. Geomstats K-means: Euclidean metric on torus latent space

In [None]:
#this adds an environmental variable
#%env GEOMSTATS_BACKEND=pytorch

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.kmeans import RiemannianKMeans

In [None]:
# The two S^1 factors of the latent torus, one per latent angle.
circumference1, circumference2 = Hypersphere(dim=1), Hypersphere(dim=1)

Building the torus as a product $\mathcal{T}^2 = \mathcal{S}^1 \times \mathcal{S}^1$

In [None]:
from geomstats.geometry.product_manifold import ProductManifold

# T^2 = S^1 x S^1 built from the two circle factors (order matters: factor 0 <-> latent coordinate 0).
torus_factors = (circumference1, circumference2)
torus = ProductManifold(torus_factors)

Putting the MNIST data on the torus

In [None]:
# Map each latent angle to extrinsic coordinates on its S^1 factor:
# column 0 goes through factor 0, column 1 through factor 1; each result is one row per point.
circ_1_coordinates, circ_2_coordinates = [
    circle.intrinsic_to_extrinsic_coords(encoded_points_to_cluster[:, column]).reshape(2, -1).T
    for column, circle in ((0, torus.factors[0]), (1, torus.factors[1]))
]
# Shape (n_points, 2, 2); per point: [[cos phi, sin phi], [cos psi, sin psi]].
MNIST_data_on_torus_4d = np.concatenate((circ_1_coordinates, circ_2_coordinates), axis=1).reshape(-1, 2, 2)

In [None]:
# Geomstats calls this estimator "Riemannian" K-means; on the chosen local chart of the
# torus it amounts to Euclidean K-means.
kmeans = RiemannianKMeans(torus, K, tol=1e-3)
kmeans.fit(MNIST_data_on_torus_4d)

# Fitted results: one cluster id per point, and the centroids
# (alternative attribute name: kmeans.cluster_centers_).
kmeans_latent_space_euclidean_labels = kmeans.labels_
cluster_centers = kmeans.centroids_

In [None]:
# Plot the encoded points with the cluster ids found by the geomstats K-means
# (plotting helper from ricci_regularization; figure is not saved).
ricci_regularization.RiemannianKmeansTools.manifold_plot_selected_labels(
    encoded_points2plot=encoded_points_to_cluster,
    encoded_points_labels=kmeans_latent_space_euclidean_labels,
    selected_labels=selected_labels,
    plot_title="Encoded points colored by Euclidean K-means via geomstats",
    save_plot=False,
)

## Saving the results

In [None]:
# Store the setup together with the Euclidean (geomstats) K-means labels as JSON
# next to this setup's params.json.
saving_path_parameters = f"{Path_clustering_setup}/Euclidean_k_means_params.json"

params = {
    "K": K,                               # number of clusters
    "N": N,                               # number of points to be clustered
    "selected_labels": selected_labels,   # labels used for clustering
    "ground_truth_labels": ground_truth_labels,
    "Euclidean_k_means_labels": kmeans_latent_space_euclidean_labels.tolist(),
    "encoded_points": encoded_points_to_cluster.tolist(),
}

with open(saving_path_parameters, "w") as params_json:
    json.dump(params, params_json, indent=4)

print(f"Parameters saved to {saving_path_parameters}")

In [None]:
# Intentional stop for "Run All": the sections below (input/output-space K-means and
# the F-score comparison) are still work in progress.
stop_message = "Stopping point: Review output before proceeding."
raise Exception(stop_message)

# 3. K-means in input data space

## Data loading

In [None]:
# Read the clustering setup (the same params.json that is loaded at the top of the notebook).
with open(f"{Path_clustering_setup}/params.json", "r") as f:
    json_config = json.load(f)
K = json_config["K"]

# Read the configuration of the pre-trained AE experiment.
# NOTE(review): yaml.FullLoader is acceptable for trusted local configs; prefer
# yaml.safe_load if configs may come from untrusted sources.
with open(Path_experiment, 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

experiment_name = yaml_config["experiment"]["name"]
Path_pictures = f"../experiments/{experiment_name}/Euclidean_K_means_setup_{k_means_setup_number}"
# Create the output directory for figures/results (including any missing parent directories).
if not os.path.exists(Path_pictures):
    os.makedirs(Path_pictures)
    print(f"Created directory: {Path_pictures}")
else:
    print(f"Directory already exists: {Path_pictures}")

# Build the data loaders described by the YAML configuration.
# Named dataloaders_dict (not `dict`) so the builtin `dict` is not shadowed.
dataloaders_dict = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
print("Experiment results loaded successfully.")
train_loader = dataloaders_dict["train_loader"]
test_loader = dataloaders_dict["test_loader"]
test_dataset = dataloaders_dict.get("test_dataset")  # None if get_dataloaders does not return it
print("Data loaders created successfully.")

# Load the pre-trained AE and report where its weights came from.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)
print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

In [None]:
# Run the pre-trained AE over the whole test set and keep the inputs, reconstructions,
# encodings and labels, all in test_loader order.
D = yaml_config["architecture"]["input_dim"]
torus_ae.cpu()
torus_ae.eval()  # inference mode (no dropout / batch-norm updates, if the AE has any)
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# Gradients are not needed; without no_grad every stored output would keep its autograd
# graph (and intermediate activations) alive for the whole test set.
with torch.no_grad():
    for data, labels in tqdm(test_loader, position=0):  # use train_loader here for the training set
        input_dataset_list.append(data)
        # NOTE(review): element [0] of the AE output is assumed to be the reconstruction.
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(data.view(-1, D)))
        enc_list.append(torus_ae.encoder_to_lifting(data.view(-1, D)))
        colorlist.append(labels)

input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
feature_space_encoding = torch.cat(feature_space_encoding_list)
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

In [None]:
# DISABLED (wrapped in a string literal, so nothing below executes): an earlier way of
# choosing the N points to cluster. It first picks a random subset (seed 0) of test images
# whose labels are in selected_labels and encodes it. It then separately encodes the
# train_loader points with labels in selected_labels and draws N of them at random (seed 0).
# NOTE(review): encoded_points_selected and ground_truth_labels_selected are defined only
# inside this string, so no other cell can use them while it stays disabled.
"""
# picking only the selected part
N = json_config["N"]
selected_labels = json_config["selected_labels"]
# we use some random N points of the test dataset that we will cluster
# This could be done differently, e.g. by simply picking random points
D = yaml_config["architecture"]["input_dim"]
d = yaml_config["architecture"]["latent_dim"]

# Extract data and ground truth labels of the subset
all_data = test_dataset.data
all_labels = test_dataset.targets
mask = torch.isin(all_labels, torch.tensor(selected_labels)) # mask will be used to chose only labels in selected_labels
# Filter dataset
data_filtered = all_data[mask]
labels_filtered = all_labels[mask]
torch.manual_seed(0)
indices = torch.randperm(len(data_filtered))[:N]  # Randomly shuffle and pick first N
mnist_subset = data_filtered[indices]
ground_truth_labels = labels_filtered[indices]

# meaningless alternative 
#data = test_dataset.data
#subset_indices = list(range(N))
#mnist_subset = torch.utils.data.Subset(data, subset_indices)

# constructing dataloader for the mnist_subset
dataset_batch_size = 128
dataloader = torch.utils.data.DataLoader(mnist_subset, batch_size=dataset_batch_size, shuffle=False)
# encoding into latent space
torus_ae.cpu()
torus_ae.eval()

# Encode samples into latent space
encoded_points = []
with torch.no_grad():  # No need to compute gradients
    for images in dataloader:
#        print(images.shape)
        latent = torus_ae.encoder_to_lifting( (images.reshape(-1, D)).to(torch.float32) )  # Pass images through the encoder
        encoded_points.append(latent)
encoded_points = torch.cat(encoded_points)
#filtering poins to choose N of them with labels in selected_labels
#clusters can be unbalanced
list_encoded_data_filtered = []
list_labels_filtered = []
for data,label in train_loader:
    mask_batch = torch.isin(label, torch.tensor(selected_labels)) # mask will be used to chose only labels in selected_labels
    data_filtered = data[mask_batch]
    labels_filtered = label[mask_batch]
    enc_images = torus_ae.encoder_to_lifting(data_filtered.reshape(-1, D)).detach()
    list_encoded_data_filtered.append(enc_images)
    list_labels_filtered.append(labels_filtered)
    #print(labels_filtered)
all_encoded_data_filtered = torch.cat(list_encoded_data_filtered)
all_labels_filtered = torch.cat(list_labels_filtered)
#randomly picking N points with selected labels
torch.manual_seed(0)
indices = torch.randperm(len(all_encoded_data_filtered))[:N]  # Randomly shuffle and pick first N
encoded_points_selected = all_encoded_data_filtered[indices]
ground_truth_labels_selected = all_labels_filtered[indices]
"""

In [None]:
# Choosing the dataset to cluster.
if mode == "selected_points":
    # Use the N encoded points and ground-truth labels loaded from params.json at the top
    # of the notebook (encoded_points_to_cluster, ground_truth_labels) unchanged.
    # The cell that recomputed them as encoded_points_selected / ground_truth_labels_selected
    # is disabled, so those names are not defined and must not be used here.
    pass
elif mode == "all_points":
    # TODO: cluster all encoded test points (encoded_points_no_grad, labels in color_array).
    raise NotImplementedError('mode = "all_points" is not implemented yet')
else:
    raise ValueError(f'Unknown mode {mode!r}; expected "selected_points" or "all_points"')


In [None]:
# Euclidean K-means directly on the flattened input images (R^D) of the whole test set.
input_points_flat = input_dataset.reshape(-1, D).detach()
kmeans_input_space = KMeans(n_clusters=K, random_state=0, n_init="auto")
kmeans_input_space.fit(input_points_flat)
kmeans_input_space_labels = kmeans_input_space.labels_
print(f"k-means clusterisation to {K} clusters")

In [None]:
# Top: test points in the latent space, colored by their input-space K-means cluster id.
# Bottom: the same points, colored by whether their cluster agrees with the ground truth.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 12))
p1 = ax1.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                 c=kmeans_input_space_labels, marker='o', edgecolor='none',
                 cmap=ricci_regularization.discrete_cmap(K, 'jet'))
fig.colorbar(p1, ax=ax1, ticks=range(K))
ax1.title.set_text(f"K-means clusterization on input data, K = {K}, \n Euclidean metric in input space $R^D$")
ax1.grid(True)

# Ground truth for the plotted points: color_array holds the test-set labels in the same
# order as encoded_points_no_grad and kmeans_input_space_labels.
# Label values (e.g. 5 and 8) are mapped to 0, 1, ... so they are comparable with cluster ids.
# NOTE(review): the comparison below is only meaningful for exactly 2 clusters / 2 label values.
true_ids = np.unique(np.asarray(color_array), return_inverse=True)[1]
mismatch = np.abs(np.asarray(kmeans_input_space_labels) - true_ids)
# Cluster ids are arbitrary. If fewer than half the points mismatch, correct = not mismatch.
# Otherwise the two ids are swapped and the mismatch mask itself marks the correct points.
if mismatch.sum() < len(true_ids) // 2:
    correctly_detected_labels = np.logical_not(mismatch).astype(int)
else:
    correctly_detected_labels = mismatch

# Two colours only (incorrect / correct); ticks sit at the centres of the two colour bins.
p2 = ax2.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                 c=correctly_detected_labels, marker='o', edgecolor='none',
                 cmap=plt.get_cmap("viridis", 2))
cbar = fig.colorbar(p2, ax=ax2, ticks=[0.25, 0.75])
cbar.ax.set_yticklabels(["incorrect", "correct"])
ax2.title.set_text(f"Input-space K-means vs ground truth, K = {K}")
ax2.grid(True)
if violent_saving:
    fig.savefig(f"{Path_pictures}/Kmeans_input_space.pdf", format="pdf")

# 4. K-means in the output data space

In [None]:
# Euclidean K-means on the AE reconstructions of the whole test set. They are flattened to
# (n_points, D), like the input-space clustering (a no-op if they are already 2-D).
kmeans_recon_space = KMeans(n_clusters=K, random_state=0, n_init="auto").fit(recon_dataset.reshape(-1, D).detach())
kmeans_recon_space_labels = kmeans_recon_space.labels_
print(f"k-means clusterisation to {K} clusters")

In [None]:
# Test points in the latent space, colored by the K-means clusters found on the reconstructions.
fig, ax = plt.subplots(figsize=(8, 6))
recon_scatter = ax.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                           c=kmeans_recon_space_labels, marker='o', edgecolor='none',
                           cmap=ricci_regularization.discrete_cmap(K, 'jet'))
fig.colorbar(recon_scatter, ax=ax, ticks=range(K))
ax.set_title(f"K-means clusterization on reconstructed data, K = {K}, \n Euclidean metric in output space $R^D$")
ax.grid(True)
if violent_saving:
    # The file name states that these clusters come from the output (reconstruction) space.
    fig.savefig(f"{Path_pictures}/Kmeans_output_space.pdf", format="pdf")

# 5. F-score comparison

In [None]:
# Swap the two cluster ids (0 <-> 1) of each clustering. With K = 2 this is the only other
# way to match the arbitrary cluster ids to the ground-truth labels.
(kmeans_latent_space_euclidean_permuted_labels,
 kmeans_recon_space_permuted_labels,
 kmeans_input_space_permuted_labels) = (
    abs(cluster_ids - 1)
    for cluster_ids in (kmeans_latent_space_euclidean_labels,
                        kmeans_recon_space_labels,
                        kmeans_input_space_labels)
)

In [None]:


def f1_up_to_label_swap(true_labels, cluster_labels):
    """F1 score of a 2-cluster assignment against ground truth, maximized over swapping cluster ids.

    Parameters
    ----------
    true_labels : array-like of shape (n,)
        Ground-truth labels. Their distinct values (e.g. 5 and 8) are mapped to 0, 1, ...
        in increasing order so they can be compared with K-means cluster ids.
    cluster_labels : array-like of shape (n,)
        Cluster ids (0 or 1) for the same points, in the same order.

    Returns
    -------
    float
        max(F1(ids), F1(swapped ids)), where swapping maps id c to |c - 1|.
    """
    true_ids = np.unique(np.asarray(true_labels), return_inverse=True)[1]
    cluster_ids = np.asarray(cluster_labels)
    return max(sklearn.metrics.f1_score(true_ids, cluster_ids),
               sklearn.metrics.f1_score(true_ids, np.abs(cluster_ids - 1)))

# The latent-space clustering ran on the N selected points (ground truth from params.json).
# The input- and output-space clusterings ran on the whole test set (ground truth in color_array).
# NOTE(review): the three scores are therefore not yet computed on the same points — TODO.
F_score_latent_space_eucl = f1_up_to_label_swap(ground_truth_labels, kmeans_latent_space_euclidean_labels)
F_score_input_space_eucl = f1_up_to_label_swap(color_array, kmeans_input_space_labels)
F_score_recon_space_eucl = f1_up_to_label_swap(color_array, kmeans_recon_space_labels)

# NOTE(review): json_config is the setup's params.json. The keys read from it at the top of
# the notebook do not include "losses", so this value may need to come from yaml_config — confirm.
curv_w = json_config["losses"]["curv_w"]
print(f"Curvature penalization weight: {curv_w}")
print(f"F-score Euclidean k-means in latent space vs ground truth: \n{F_score_latent_space_eucl}")
print(f"F-score Euclidean k-means in input data space vs ground truth: \n{F_score_input_space_eucl}")
print(f"F-score Euclidean k-means in reconstructed data space vs ground truth: \n{F_score_recon_space_eucl}")

In [None]:
# Summary of the F-scores, saved next to the figures of this K-means setup.
# Named f_scores_summary (not `dict`) so the builtin `dict` is not shadowed.
f_scores_summary = {
    "curv_w" : curv_w,
    "labels": selected_labels,
    "F-score Euclidean k-means in latent space vs ground truth" : F_score_latent_space_eucl,
    "F-score Euclidean k-means in reconstructed data space vs ground truth" : F_score_recon_space_eucl,
    "F-score Euclidean k-means in input data space vs ground truth" : F_score_input_space_eucl
}
# The file is keyed by the experiment name; `experiment_number` is not defined in this notebook.
with open(f'{Path_pictures}/K-means_{experiment_name}.json', 'w') as json_file:
    json.dump(f_scores_summary, json_file, indent=4)