In [None]:
# minimal imports
import os
import torch, yaml, json
from ricci_regularization import RiemannianKmeansTools

In [None]:
# experiment setup
N = 300  # number of points to be clustered
periodicity_mode = True
mode = "Schauder"
#mode = "Interpolation_points" # alternative option

# specific parameters
n_max = 7  # Schauder basis complexity (only used when mode == "Schauder")
step_count = 100  # number of interpolation steps (for both methods)

# optimization parameters
beta = 1.e-4  # learning rate of the Frechet mean update (outer loop)
learning_rate = 1.e-5  # learning rate of the geodesics refinement (inner loop)
num_iter_outer = 75  # number of Frechet mean updates (outer loop)
num_iter_inner = 10  # number of geodesics refinement iterations per Frechet mean update (inner loop)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first tensor allocation.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Run Riemannian k-means for every (AE setting, experiment, k-means setup)
# combination. For each setup the geodesics are stored in "geodesic_curve.pt"
# and the parameters plus clustering results in "params.json", both inside
# the setup folder of the pretrained AE experiment.
setting_numbers = [4]
experiment_numbers = [2]  # alternatives: ["3_alt"], [1, 2, 3]
k_means_setup_numbers = list(range(10))  # alternative: [0, 1, 2, 3, 4]
for setting_number in setting_numbers:
    for experiment_number in experiment_numbers:
        pretrained_AE_setting_name = f'MNIST_Setting_{setting_number}_exp{experiment_number}'
        Path_AE_config = f'../experiments/{pretrained_AE_setting_name}_config.yaml'
        # NOTE(review): yaml.load with FullLoader accepts more than plain YAML
        # types; if the config files only contain plain scalars/lists/dicts,
        # consider yaml.safe_load. Left unchanged to keep parsing identical.
        with open(Path_AE_config, 'r') as yaml_file:
            yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)
        selected_labels = yaml_config["dataset"]["selected_labels"]
        K = len(selected_labels)  # number of clusters = number of selected labels
        torus_ae, validation_dataset = RiemannianKmeansTools.get_validation_dataset(yaml_config)
        for k_means_setup_number in k_means_setup_numbers:
            Path_clustering_setup = f"../experiments/{pretrained_AE_setting_name}/K_means_setup_{k_means_setup_number}"
            # The setup number doubles as the random seed for picking the points.
            encoded_points, ground_truth_labels = RiemannianKmeansTools.load_points_for_clustering(
                validation_dataset=validation_dataset,
                random_seed_picking_points=k_means_setup_number,
                yaml_config=yaml_config,
                torus_ae=torus_ae,
                Path_clustering_setup=Path_clustering_setup,
                N=N,
            )

            # Everything below writes into the setup folder; make sure it
            # exists so a long fit is not lost to a missing directory.
            os.makedirs(Path_clustering_setup, exist_ok=True)

            # Create the scatter plot for points to cluster
            RiemannianKmeansTools.manifold_plot_selected_labels(
                encoded_points,
                ground_truth_labels,
                selected_labels,
                saving_folder=Path_clustering_setup,
                plot_title="Encoded Points Colored by Ground Truth Labels",
                file_saving_name="ground_truth_labels",
                verbose=False,
            )
            params = {
                "K": K,  # Number of clusters
                "N": N,  # Number of points to be clustered
                "selected_labels": selected_labels,  # Labels used for clustering
                "mode": mode,  # Can be "Schauder" or "Interpolation_points"
                "periodicity_mode": periodicity_mode,  # boolean flag
                # Specific parameters
                "n_max": n_max,  # Schauder basis complexity
                "step_count": step_count,  # Number of interpolation steps

                # Optimization parameters
                "beta": beta,  # Frechet mean learning rate
                "learning_rate": learning_rate,  # Learning rate for geodesics
                "num_iter_outer": num_iter_outer,  # Number of Frechet mean updates
                "num_iter_inner": num_iter_inner,  # Geodesic refinement steps per Frechet mean update
                "device": device,
                "torus_ae": torus_ae,
                "d": yaml_config["architecture"]["latent_dim"],
            }

            results = RiemannianKmeansTools.Riemannian_k_means_fit(encoded_points, params)
            loss_history = results["history"]
            geodesic_curve = results["geodesic_curve"]
            labels_Rieamanian_k_means = results["Riemannian_k_means_labels"]
            RiemannianKmeansTools.Riemannian_k_means_losses_plot(
                loss_history, Path_pictures=Path_clustering_setup, verbose=False
            )
            # Optional "octopus" plot of the geodesics, currently disabled:
            # RiemannianKmeansTools.plot_octopus(
            #     geodesic_curve.detach(),
            #     memberships = torch.tensor(labels_Rieamanian_k_means),
            #     ground_truth_labels=ground_truth_labels,
            #     saving_folder=Path_clustering_setup,suffix=0, periodicity_mode=periodicity_mode,
            #     show_points_in_original_local_charts=False, verbose=False, size_of_points=2)

            # The geodesics go into their own file and are then removed from
            # `results` so they do not end up in the JSON dump.
            torch.save(geodesic_curve, os.path.join(Path_clustering_setup, "geodesic_curve.pt"))  # saving all geodesics
            del results["geodesic_curve"]
            del results["history"]
            # Dropped before serialization; presumably torus_ae is a model
            # object that json cannot encode -- TODO confirm.
            del params["torus_ae"]
            results["ground_truth_labels"] = ground_truth_labels.tolist()
            results["encoded_points"] = encoded_points.tolist()
            params.update(results)
            # Save to JSON file
            with open(os.path.join(Path_clustering_setup, "params.json"), "w") as f:
                json.dump(params, f, indent=4)

            print(f"Parameters saved to {Path_clustering_setup}/params.json")