In [None]:
import torch, ricci_regularization, json
import numpy as np
import geomstats
# `import geomstats` on its own does not import its subpackages; the attribute
# paths used in this notebook (`geomstats.geometry...`, `geomstats.learning...`)
# only resolve once the submodules themselves have been imported. Import them
# explicitly so the notebook does not depend on another package having loaded
# them as a side effect.
import geomstats.geometry.hypersphere
import geomstats.geometry.product_manifold
import geomstats.learning.kmeans

# Two copies of the circle S^1 (hyperspheres of dimension 1), one per latent angle.
circumference1 = geomstats.geometry.hypersphere.Hypersphere(dim=1)
circumference2 = geomstats.geometry.hypersphere.Hypersphere(dim=1)
# Building the torus as a product $\mathcal{T} = \mathcal{S}^1 \times \mathcal{S}^1$
torus = geomstats.geometry.product_manifold.ProductManifold((circumference1, circumference2))

In [None]:
# Which pretrained auto-encoders and which clustering setups to process.
setting_numbers = [4]
experiment_numbers = [2]
k_means_setup_numbers = list(range(10))

for setting_number in setting_numbers:
    for experiment_number in experiment_numbers:
        for k_means_setup_number in k_means_setup_numbers:
            # 1. Load the clustering setup stored next to the Riemannian K-means results.
            pretrained_AE_setting_name = f'MNIST_Setting_{setting_number}_exp{experiment_number}'
            Path_clustering_setup = f"../experiments/{pretrained_AE_setting_name}/K_means_setup_{k_means_setup_number}"
            with open(f"{Path_clustering_setup}/params.json", "r") as f_Riemannian:
                Riemannian_k_means_params = json.load(f_Riemannian)
            encoded_points_to_cluster = torch.tensor(Riemannian_k_means_params["encoded_points"])
            K = Riemannian_k_means_params["K"]
            N = Riemannian_k_means_params["N"]
            selected_labels = Riemannian_k_means_params["selected_labels"]
            ground_truth_labels = Riemannian_k_means_params["ground_truth_labels"]

            # 2. Geomstats K-means: Euclidean metric on torus latent space.
            # The line below (left disabled) would switch the geomstats backend
            # by setting an environment variable:
            #%env GEOMSTATS_BACKEND=pytorch

            # Put the MNIST latent codes on the torus: each of the two latent
            # columns is converted to extrinsic coordinates on its own circle.
            first_latent_coordinate = encoded_points_to_cluster[:, 0]
            second_latent_coordinate = encoded_points_to_cluster[:, 1]
            circ_1_coordinates = (
                torus.factors[0]
                .intrinsic_to_extrinsic_coords(first_latent_coordinate)
                .reshape(2, -1)
                .T
            )
            circ_2_coordinates = (
                torus.factors[1]
                .intrinsic_to_extrinsic_coords(second_latent_coordinate)
                .reshape(2, -1)
                .T
            )
            # Per point: cos\phi, sin \phi, cos \psi, sin \psi, regrouped as a 2x2 block.
            MNIST_data_on_torus_4d = np.concatenate(
                (circ_1_coordinates, circ_2_coordinates), axis=1
            ).reshape(-1, 2, 2)

            # geomstats calls this "Riemannian K-means", but here it is Euclidean
            # on the chosen local chart of the torus.
            kmeans = geomstats.learning.kmeans.RiemannianKMeans(torus, K, tol=1e-3)
            kmeans.fit(MNIST_data_on_torus_4d)
            kmeans_latent_space_euclidean_labels = kmeans.labels_
            cluster_centers = kmeans.centroids_  # kmeans.cluster_centers_

            # Optional visual check (disabled):
            # ricci_regularization.RiemannianKmeansTools.manifold_plot_selected_labels(encoded_points2plot=encoded_points_to_cluster,
            #         encoded_points_labels=kmeans_latent_space_euclidean_labels,
            #         selected_labels=selected_labels,
            #         plot_title="Encoded points colored by Euclidean K-means via geomstats",
            #         save_plot=False)

            # 3. Save the results next to the setup they were computed from.
            params = {
                "K": K,  # Number of clusters
                "N": N,  # Number of points to be clustered
                "selected_labels": selected_labels,  # Labels used for clustering
                "ground_truth_labels": ground_truth_labels,
                "Euclidean_k_means_labels": kmeans_latent_space_euclidean_labels.tolist(),
                "encoded_points": encoded_points_to_cluster.tolist(),
            }
            saving_path_parameters = f"{Path_clustering_setup}/Euclidean_k_means_params.json"
            with open(saving_path_parameters, "w") as f:
                json.dump(params, f, indent=4)

            print(f"Parameters saved to {saving_path_parameters}")

In [None]:
from ricci_regularization import RiemannianKmeansTools

# Plot settings shared by every decision-boundary figure below.
knn_neighbours_count = 7
point_size = 100
background_opacity = 0.4
colormap = 'jet'

# For every pretrained auto-encoder: compare the Riemannian and Euclidean
# K-means labelings of each clustering setup (F-measure + k-NN decision
# boundary plots), then write aggregate F-measure statistics over all setups.
for setting_number in setting_numbers:
    for experiment_number in experiment_numbers:
        # F-measures collected over all K-means setups of this auto-encoder.
        f_measure_euclidean_list = []
        f_measure_riemannian_list = []
        pretrained_AE_setting_name = f'MNIST_Setting_{setting_number}_exp{experiment_number}'
        for k_means_setup_number in k_means_setup_numbers:
            Path_clustering_setup = f"../experiments/{pretrained_AE_setting_name}/K_means_setup_{k_means_setup_number}"
            with open(Path_clustering_setup + "/params.json", "r") as f_Riemannian:
                Riemannian_k_means_params = json.load(f_Riemannian)
            with open(Path_clustering_setup + "/Euclidean_k_means_params.json", "r") as f_Euclidean:
                Euclidean_k_means_params = json.load(f_Euclidean)
            contour_levels_count = Euclidean_k_means_params["K"] + 1  # should be k+1

            ground_truth_labels = torch.tensor(Riemannian_k_means_params['ground_truth_labels'])
            # Check that ground truth labels are saved correctly (are the same in both methods).
            assert torch.equal(ground_truth_labels,
                               torch.tensor(Euclidean_k_means_params['ground_truth_labels']))

            labels_assigned_by_Euclidean_k_means = torch.tensor(Euclidean_k_means_params['Euclidean_k_means_labels'])
            labels_assigned_by_Riemannian_k_means = torch.tensor(Riemannian_k_means_params['Riemannian_k_means_labels'])
            encoded_points = torch.tensor(Riemannian_k_means_params['encoded_points'])
            # Both methods must have clustered exactly the same points.
            assert torch.equal(encoded_points,
                               torch.tensor(Euclidean_k_means_params['encoded_points']))
            selected_labels = Riemannian_k_means_params["selected_labels"]

            # F-measure of each clustering against the ground truth.
            # The result files are written inside `with` blocks so the handles
            # are closed deterministically instead of being left to the
            # garbage collector.
            f_measure_riemannian = RiemannianKmeansTools.compute_f_measure(
                labels_assigned_by_Riemannian_k_means, ground_truth_labels)
            f_measure_riemannian_list.append(f_measure_riemannian)
            with open(f"{Path_clustering_setup}/f_measure_Riemannian.txt", "w") as f_measure_file:
                f_measure_file.write(str(f_measure_riemannian))

            f_measure_euclidean = RiemannianKmeansTools.compute_f_measure(
                labels_assigned_by_Euclidean_k_means, ground_truth_labels)
            with open(f"{Path_clustering_setup}/f_measure_Euclidean.txt", "w") as f_measure_file:
                f_measure_file.write(str(f_measure_euclidean))
            f_measure_euclidean_list.append(f_measure_euclidean)

            # Voronoi cells
            # labels by ground truth labels
            # NOTE(review): unlike the two calls below, this one does not pass
            # `ground_truth_labels=` - presumably the argument is optional; confirm.
            RiemannianKmeansTools.plot_knn_decision_boundary(encoded_points,
                labels_for_coloring=ground_truth_labels,
                contour_levels_count=contour_levels_count,
                neighbours_number=knn_neighbours_count, selected_labels=selected_labels,
                saving_folder=Path_clustering_setup, cmap_background=colormap, cmap_points=colormap,
                background_opacity=background_opacity, points_size=point_size,
                plot_title=f"Points in latent space colored by ground truth labels, \nVoronoi's cells colored by {knn_neighbours_count}-NN.",
                file_saving_name=f"Decision_boundary_ground_truth_labels_{pretrained_AE_setting_name}",
                verbose=False)
            # labels by the Riemannian clustering
            RiemannianKmeansTools.plot_knn_decision_boundary(encoded_points, ground_truth_labels=ground_truth_labels,
                labels_for_coloring=labels_assigned_by_Riemannian_k_means, contour_levels_count=contour_levels_count,
                neighbours_number=knn_neighbours_count, selected_labels=selected_labels,
                points_size=point_size, background_opacity=background_opacity,
                saving_folder=Path_clustering_setup, cmap_background=colormap, cmap_points=colormap,
                plot_title=f"Points in latent space colored by labels assigned by\n Riemannian k-means, Voronoi's cells colored by {knn_neighbours_count}-NN.",
                file_saving_name=f"Decision_boundary_Riemannian_k-means_labels_{pretrained_AE_setting_name}",
                verbose=False)
            # labels by the Euclidean clustering
            RiemannianKmeansTools.plot_knn_decision_boundary(encoded_points, ground_truth_labels=ground_truth_labels,
                labels_for_coloring=labels_assigned_by_Euclidean_k_means, contour_levels_count=contour_levels_count,
                neighbours_number=knn_neighbours_count, selected_labels=selected_labels,
                points_size=point_size, background_opacity=background_opacity,
                saving_folder=Path_clustering_setup, cmap_background=colormap, cmap_points=colormap,
                plot_title=f"Points in latent space colored by labels assigned by\n Euclidean k-means, Voronoi's cells colored by {knn_neighbours_count}-NN.",
                file_saving_name=f"Decision_boundary_Euclidean_k-means_labels_{pretrained_AE_setting_name}",
                verbose=False)
        # end for (K-means setups)

        # Aggregate statistics over all setups of this auto-encoder.
        n_samples = len(k_means_setup_numbers)
        f_measures_euclidean = np.array(f_measure_euclidean_list)
        f_measures_riemannian = np.array(f_measure_riemannian_list)
        # NOTE(review): np.std defaults to ddof=0 (population std); the standard
        # error of the mean is usually computed from the sample std (ddof=1).
        # Left unchanged to keep the reported numbers stable - confirm intent.
        # The mapping is deliberately not called `dict`, which would shadow the
        # builtin for the rest of the kernel session. Values are converted to
        # plain Python floats so json can serialise them whatever the numpy
        # scalar dtype is.
        f_measure_stats = {
            "euclidean_F_measure_mean": float(f_measures_euclidean.mean()),
            "euclidean_F_measure_std": float(f_measures_euclidean.std()),
            "euclidean_F_measure_SEM": float(f_measures_euclidean.std() / np.sqrt(n_samples)),
            "riemannian_F_measure_mean": float(f_measures_riemannian.mean()),
            "riemannian_F_measure_std": float(f_measures_riemannian.std()),
            "riemannian_F_measure_SEM": float(f_measures_riemannian.std() / np.sqrt(n_samples)),
        }
        with open(f"../experiments/{pretrained_AE_setting_name}/f_measure_stats.json", "w") as f:
            json.dump(f_measure_stats, f, indent=4)