NB: the Geomstats package is required.

The latent space of the autoencoder (AE) is topologically a $d$-dimensional torus $\mathcal{T}^d$. It can therefore be regarded as the periodic box $[-\pi, \pi]^d$ with opposite faces identified. We use $d=2$.
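In the periodic-box picture, coordinates are only defined modulo $2\pi$, and the flat distance between two latent points takes the shortest way around each coordinate. The minimal NumPy sketch below shows both operations. The helper names `wrap_to_box` and `flat_torus_distance` are hypothetical illustrations and do not come from this project or from Geomstats.

```python
import numpy as np

def wrap_to_box(x):
    """Map arbitrary angles to the periodic box [-pi, pi) (hypothetical helper)."""
    return (np.asarray(x, dtype=float) + np.pi) % (2 * np.pi) - np.pi

def flat_torus_distance(a, b):
    """Flat-torus distance on [-pi, pi)^d with opposite faces identified (hypothetical helper).

    Each coordinate difference is wrapped, so |diff| <= pi, i.e. the shorter arc is used.
    """
    diff = np.abs(wrap_to_box(np.asarray(a, dtype=float) - np.asarray(b, dtype=float)))
    return np.sqrt(np.sum(diff ** 2, axis=-1))
```

For example, the points $(3, 0)$ and $(-3, 0)$ lie far apart in the plain box. On the torus, however, they are only $2\pi - 6 \approx 0.28$ apart.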

This notebook performs K-means clustering on the torus with the flat (Euclidean) metric, implemented with Geomstats.

The data in this notebook is a subset of the MNIST dataset, restricted to the labels specified in the YAML configuration file.

1) Setting hyperparameters, loading the dataset, and plotting the embedded data for a pre-trained AE.
2) Geomstats K-means with the Euclidean metric on the torus latent space, and saving the results.

In [None]:
# prerequisites
%matplotlib inline
import json
import os

import matplotlib.pyplot as plt  # needed by the optional plotting cell below
import numpy as np
import torch
import yaml

import ricci_regularization

# 1. Setting hyperparameters, dataset loading, plotting embedded data for a pre-trained AE.

In [None]:
# TODO: allow clustering arbitrary data, not only the points used by Riemannian K-means;
# the results would then need to be saved under a different path.
k_means_setup_number = 0
pretrained_AE_setting_name = 'MNIST_Setting_3_exp5'
Path_clustering_setup = f"../experiments/{pretrained_AE_setting_name}/K_means_setup_{k_means_setup_number}"
Path_experiment = f'../experiments/{pretrained_AE_setting_name}_config.yaml'
mode = "selected_points" # clustering only selected points
#mode = "all_points"

In [None]:
# load the points (and their labels) that were clustered in the Riemannian K-means setup
with open(f"{Path_clustering_setup}/params.json", "r") as f_Riemannian:
    Riemannian_k_means_params = json.load(f_Riemannian)
encoded_points_to_cluster = torch.tensor(Riemannian_k_means_params["encoded_points"])
K = Riemannian_k_means_params["K"]
N = Riemannian_k_means_params["N"]
selected_labels = Riemannian_k_means_params["selected_labels"]
ground_truth_labels = Riemannian_k_means_params["ground_truth_labels"]
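It can be useful to check that the loaded file is consistent before clustering. The sketch below validates a dictionary like `Riemannian_k_means_params`, using only the keys read above. The function name `check_clustering_params` is hypothetical. The checks assume that `N` counts the encoded points, that each point has $d=2$ angular coordinates, and that there is one ground-truth label per point.

```python
import numpy as np

def check_clustering_params(params, d=2):
    """Hypothetical sanity check for the dictionary loaded from params.json.

    Assumes N counts the encoded points, each point has d latent coordinates (angles),
    and there is one ground-truth label per point. Returns the points as an (N, d) array.
    """
    points = np.asarray(params["encoded_points"], dtype=float)
    n, k = params["N"], params["K"]
    if points.shape != (n, d):
        raise ValueError(f"expected encoded_points of shape {(n, d)}, got {points.shape}")
    if len(params["ground_truth_labels"]) != n:
        raise ValueError("expected one ground-truth label per encoded point")
    if not 1 <= k <= n:
        raise ValueError(f"K={k} must lie between 1 and N={n}")
    return points
```

A possible usage is `check_clustering_params(Riemannian_k_means_params)`, which raises a `ValueError` if any of these assumptions fail.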

## Plots

In [None]:
# Optional plot, disabled by wrapping it in a string; remove the triple quotes to enable it.
"""
plt.figure(figsize=(8, 6))
plt.title("Encoded points selected for clustering colored by ground truth labels")
plt.scatter(encoded_points_to_cluster[:,0],encoded_points_to_cluster[:,1], c=ground_truth_labels, marker='o', edgecolor='none', cmap=ricci_regularization.discrete_cmap(K, 'jet'))
plt.colorbar(ticks=range(K))
plt.grid(True)
#plt.savefig(f"{Path_pictures}/latent_space.pdf",format="pdf")
"""

# 2. Geomstats K-means: Euclidean metric on torus latent space

In [None]:
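# Uncomment to switch Geomstats to the PyTorch backend. This sets an environment variable
# and only takes effect if it runs before geomstats is imported.
#%env GEOMSTATS_BACKEND=pytorch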

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.kmeans import RiemannianKMeans
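The `%env` magic in the cell above can also be replaced by plain Python. Geomstats reads the `GEOMSTATS_BACKEND` variable when it is first imported, and NumPy is the default backend. For this reason, the following lines would have to be executed before the import cell above. If Geomstats has already been imported, restart the kernel first. This is a minimal sketch.

```python
import os

# Must run before the first `import geomstats...` in this session;
# Geomstats selects its backend at import time (NumPy is the default).
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
```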

In [None]:
# two copies of the unit circle S^1
circumference1 = Hypersphere(dim=1)
circumference2 = Hypersphere(dim=1)

Build the torus as a product of two circles, $\mathcal{T}^2 = \mathcal{S}^1 \times \mathcal{S}^1$.

In [None]:
from geomstats.geometry.product_manifold import ProductManifold
torus = ProductManifold((circumference1,circumference2))
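The product structure determines the metric. The squared distance on $\mathcal{T}^2$ is the sum of the squared arc-length distances on the two circle factors. This is why clustering with this product metric matches the flat (wrap-around Euclidean) geometry of the periodic box. The NumPy sketch below computes this product distance for points stored as `[..., 2, 2]` arrays with rows $(\cos\phi, \sin\phi)$ and $(\cos\psi, \sin\psi)$, which is the layout used in the next cells. The function names are hypothetical, and the code is not Geomstats' own implementation.

```python
import numpy as np

def circle_distance(p, q):
    """Arc-length distance between unit vectors p, q in R^2 (points of S^1), shape [..., 2]."""
    cross = p[..., 0] * q[..., 1] - p[..., 1] * q[..., 0]  # sin of the signed angle
    dot = np.sum(p * q, axis=-1)                           # cos of the signed angle
    return np.abs(np.arctan2(cross, dot))                  # angle in [0, pi]

def product_torus_distance(x, y):
    """Product-metric distance on S^1 x S^1 for points of shape [..., 2, 2] (hypothetical helper)."""
    d1 = circle_distance(x[..., 0, :], y[..., 0, :])  # first circle factor (phi)
    d2 = circle_distance(x[..., 1, :], y[..., 1, :])  # second circle factor (psi)
    return np.sqrt(d1 ** 2 + d2 ** 2)
```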

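Put the encoded MNIST points on the torus: each angular coordinate $\phi$ is embedded as the point $(\cos\phi, \sin\phi)$ of the unit circle in $\mathbb{R}^2$.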

In [None]:
circ_1_coordinates = torus.factors[0].intrinsic_to_extrinsic_coords(encoded_points_to_cluster[:,0]).reshape(2,-1).T
circ_2_coordinates = torus.factors[1].intrinsic_to_extrinsic_coords(encoded_points_to_cluster[:,1]).reshape(2,-1).T
MNIST_data_on_torus_4d = np.concatenate((circ_1_coordinates, circ_2_coordinates), axis=1).reshape(-1, 2, 2)  # per point: [[cos(phi), sin(phi)], [cos(psi), sin(psi)]]
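The layout described in the comment of the cell above can be reproduced directly with NumPy. This is worth doing as a cross-check, because the `reshape(2, -1).T` step depends on the exact output shape returned by `intrinsic_to_extrinsic_coords`. The sketch below assumes that the encoded points are an `(N, 2)` array of angles $(\phi, \psi)$ in $[-\pi, \pi]$. The function names are hypothetical.

```python
import numpy as np

def angles_to_torus_extrinsic(angles):
    """Embed (phi, psi) angles, shape (N, 2), as [[cos phi, sin phi], [cos psi, sin psi]], shape (N, 2, 2)."""
    angles = np.asarray(angles, dtype=float)
    return np.stack((np.cos(angles), np.sin(angles)), axis=-1)

def torus_extrinsic_to_angles(points):
    """Inverse map: recover the angles, shape (N, 2), with values in (-pi, pi]."""
    points = np.asarray(points, dtype=float)
    return np.arctan2(points[..., 1], points[..., 0])
```

For example, `np.allclose(MNIST_data_on_torus_4d, angles_to_torus_extrinsic(encoded_points_to_cluster.numpy()))` should be `True` if the Geomstats conversion produced the intended layout.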

In [None]:
# Geomstats calls this Riemannian K-means. With the product metric of two circles the torus is flat,
# so distances agree with the wrap-around Euclidean distance in the box [-pi, pi]^2.
kmeans = RiemannianKMeans(torus, K, tol=1e-3)
kmeans.fit(MNIST_data_on_torus_4d)
kmeans_latent_space_euclidean_labels = kmeans.labels_
cluster_centers = kmeans.centroids_  # depending on the Geomstats version, this may be kmeans.cluster_centers_
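To make the algorithm concrete, here is a self-contained sketch of Lloyd-style K-means on the flat torus that works directly on the angles. It is not Geomstats' implementation, and it differs in two stated ways. First, it uses a farthest-point initialisation. Second, it updates each centroid with the per-coordinate circular mean. To our understanding, Geomstats instead uses a Fréchet-mean estimator for the centroids. The two updates are generally close but not identical. Labels may therefore differ slightly, and cluster numbering is arbitrary. All names are hypothetical.

```python
import numpy as np

def wrap(x):
    """Map angles to the periodic box [-pi, pi)."""
    return (x + np.pi) % (2 * np.pi) - np.pi

def sq_torus_dist(points, centers):
    """Squared flat-torus distances, shape (N, k), between (N, d) points and (k, d) centers."""
    diff = wrap(points[:, None, :] - centers[None, :, :])
    return np.sum(diff ** 2, axis=-1)

def flat_torus_kmeans(angles, k, max_iter=100, tol=1e-3, seed=0):
    """Lloyd-style K-means on the flat torus [-pi, pi)^d (illustrative sketch)."""
    angles = wrap(np.asarray(angles, dtype=float))
    rng = np.random.default_rng(seed)
    # Farthest-point initialisation: a random first center, then repeatedly the point
    # farthest (in torus distance) from all centers chosen so far.
    centers = [angles[rng.integers(len(angles))]]
    for _ in range(1, k):
        nearest = np.min(sq_torus_dist(angles, np.array(centers)), axis=1)
        centers.append(angles[np.argmax(nearest)])
    centers = np.array(centers)
    for _ in range(max_iter):
        labels = np.argmin(sq_torus_dist(angles, centers), axis=1)  # assignment step
        new_centers = centers.copy()
        for j in range(k):
            members = angles[labels == j]
            if len(members) > 0:  # an empty cluster keeps its previous center
                # update step: circular mean of each angular coordinate
                new_centers[j] = np.arctan2(np.sin(members).mean(axis=0),
                                            np.cos(members).mean(axis=0))
        shift = np.max(np.abs(wrap(new_centers - centers)))
        centers = new_centers
        if shift < tol:
            break
    labels = np.argmin(sq_torus_dist(angles, centers), axis=1)  # labels for the final centers
    return labels, centers
```

A possible usage for comparison with the Geomstats labels is `labels_check, _ = flat_torus_kmeans(encoded_points_to_cluster.numpy(), K)`.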

In [None]:
ricci_regularization.RiemannianKmeansTools.manifold_plot_selected_labels(encoded_points2plot=encoded_points_to_cluster,
        encoded_points_labels=kmeans_latent_space_euclidean_labels,
        selected_labels=selected_labels,
        plot_title="Encoded points colored by Euclidean K-means via geomstats",
        save_plot=False)

## Saving the results

In [None]:
# Define experiment parameters
params = {
    "K": K,  # number of clusters
    "N": N,  # number of points to be clustered
    "selected_labels": selected_labels,  # MNIST labels selected for clustering
    "ground_truth_labels": ground_truth_labels,  # true label of each clustered point
    "Euclidean_k_means_labels": kmeans_latent_space_euclidean_labels.tolist(),  # clusters found by Geomstats K-means
    "encoded_points": encoded_points_to_cluster.tolist(),  # latent coordinates (angles) of the clustered points
}

# Save to JSON file
saving_path_parameters = f"{Path_clustering_setup}/Euclidean_k_means_params.json"
with open(saving_path_parameters, "w") as f:
    json.dump(params, f, indent=4)

print(f"Parameters saved to {saving_path_parameters}")
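To reuse these results later, for example to compare them with the Riemannian K-means labels, the JSON file can be reloaded and its lists turned back into arrays. The sketch below is a minimal reloader for the keys written above. The function name `load_euclidean_kmeans_results` is hypothetical.

```python
import json
import numpy as np

def load_euclidean_kmeans_results(path):
    """Reload the saved JSON; return the dict plus NumPy arrays of the points and cluster labels."""
    with open(path, "r") as f:
        params = json.load(f)
    points = np.asarray(params["encoded_points"], dtype=float)           # shape (N, 2)
    labels = np.asarray(params["Euclidean_k_means_labels"], dtype=int)   # shape (N,)
    if len(points) != params["N"] or len(labels) != params["N"]:
        raise ValueError("saved points/labels do not match N")
    return params, points, labels
```

A possible usage is `params, points, labels = load_euclidean_kmeans_results(saving_path_parameters)`.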