In [None]:
# minimal imports
import torch, yaml, os, json
import ricci_regularization
import matplotlib.pyplot as plt
from tqdm import tqdm
from ricci_regularization import RiemannianKmeansTools
import time
import numpy as np

# Hyperparameters

In [None]:
# Identifier of this K-means run; it is part of the output folder name (Path_pictures).
k_means_setup_number = 5
# Forwarded as `periodicity_mode` to the geodesic energy/length computations and plots.
periodicity_mode = True

# Name of the pretrained autoencoder setting and the location of its YAML config.
pretrained_AE_setting_name = 'MNIST_Setting_3_exp1'
Path_AE_config = f'../experiments/{pretrained_AE_setting_name}_config.yaml'

# Parse the autoencoder configuration (open() reads in text mode by default).
with open(Path_AE_config) as yaml_file:
    yaml_config = yaml.load(yaml_file, yaml.FullLoader)

In [None]:
# experiment setup
# Labels of the data points to be clustered; one cluster per selected label.
selected_labels = yaml_config["dataset"]["selected_labels"]
K = len(selected_labels) # number of clusters
N = 300 # number of points to be clustered

# Seed used when randomly picking the N points to be clustered
random_seed_picking_points = 2 #k_means_setup_number

# Parametrization of the geodesics
mode = "Schauder"
#mode = "Interpolation_points" # alternative option

# specific parameters
n_max = 7  # Schauder basis complexity (only for Schauder)
step_count = 100  # Number of interpolation steps (for both methods)

# optimization parameters
beta = 2.e-4 # Frechet mean learning rate (outer loop)
learning_rate = 0.1e-4 # learning rate for the geodesics (inner loop)
num_iter_outer = 50 # number of Frechet mean updates (outer loop)
num_iter_inner = 10 # number of geodesics refinement iterations per 1 Frechet mean update (inner loop)
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `torus_ae.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading the pretrained AE + creating directory for results


In [None]:

    
#experiment_name = yaml_config["experiment"]["name"]
# Folder that receives all plots and saved results of this K-means run.
Path_pictures = "../experiments/" + pretrained_AE_setting_name + f"/K_means_setup_{k_means_setup_number}"
# Check and create directories based on configuration
if not os.path.exists(Path_pictures):  # Check if the picture path does not exist
    # makedirs also creates missing parent folders; os.mkdir would raise
    # FileNotFoundError if the experiment folder itself did not exist yet.
    os.makedirs(Path_pictures, exist_ok=True)
    print(f"Created directory: {Path_pictures}")  # Print directory creation feedback
else:
    print(f"Directiry already exists: {Path_pictures}")

# Load data loaders based on YAML configuration.
# The result is deliberately NOT bound to the name `dict`: that would shadow
# the builtin `dict` for every later cell of the kernel session.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
print("Experiment results loaded successfully.")
# Loading data
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
test_dataset = dataloaders.get("test_dataset")  # None if get_dataloaders returns no 'test_dataset' key
print("Data loaders created successfully.")

# Loading the pre-trained AE
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)
print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)
# Keep the model on the CPU for encoding the data; it is moved to `device`
# right before the K-means algorithm runs.
torus_ae.cpu()
torus_ae.eval()
print("AE sent to cpu and eval mode activated successfully.")

# Picking the dataset to be clustered

In [None]:
# The data set to be clustered consists of N random encoded points with labels
# in selected_labels (the filtering cell draws them from train_loader).
# This could be done differently, e.g. by simply picking random points.
#
# Dimensions read from the autoencoder architecture config:
#   D - input dimension (data batches are reshaped to (-1, D) before encoding)
#   d - latent dimension (last axis of the geodesic parameter tensor)
architecture_config = yaml_config["architecture"]
D, d = architecture_config["input_dim"], architecture_config["latent_dim"]

In [None]:
# Encode every point of train_loader whose label is in selected_labels, then
# randomly pick N of the encoded points to be clustered.
# Clusters can be unbalanced.
list_encoded_data_filtered = []
list_labels_filtered = []
# The label tensor does not change between batches, so build it once.
selected_labels_tensor = torch.tensor(selected_labels)
# Encoding is pure inference: no_grad avoids building an autograd graph per batch.
with torch.no_grad():
    for data, label in train_loader:
        # mask selecting only the samples whose label is in selected_labels
        mask_batch = torch.isin(label, selected_labels_tensor)
        data_filtered = data[mask_batch]
        labels_filtered = label[mask_batch]
        enc_images = torus_ae.encoder2lifting(data_filtered.reshape(-1, D)).detach()
        list_encoded_data_filtered.append(enc_images)
        list_labels_filtered.append(labels_filtered)
all_encoded_data_filtered = torch.cat(list_encoded_data_filtered)
all_labels_filtered = torch.cat(list_labels_filtered)

# Slicing [:N] would silently return fewer than N points; later cells allocate
# tensors with N rows and index with arange(N), so fail early and clearly.
if len(all_encoded_data_filtered) < N:
    raise ValueError(
        f"Only {len(all_encoded_data_filtered)} points with labels {selected_labels} "
        f"are available, but N = {N} points were requested."
    )

# randomly picking N points with selected labels (seeded for reproducibility)
torch.manual_seed(random_seed_picking_points)
indices = torch.randperm(len(all_encoded_data_filtered))[:N]  # Randomly shuffle and pick first N
encoded_points = all_encoded_data_filtered[indices]
ground_truth_labels = all_labels_filtered[indices]

In [None]:
# Manifold plot of ALL encoded points whose label is in selected_labels.
# Output location and file name are passed via saving_folder / file_saving_name
# (presumably the figure is written there - confirm in RiemannianKmeansTools).
RiemannianKmeansTools.manifold_plot_selected_labels(
    all_encoded_data_filtered,
    all_labels_filtered,
    selected_labels,
    saving_folder=Path_pictures,
    plot_title="Manifold plot for all points with selected labels",
    file_saving_name="Manifold_plot_selected_labels",
)

In [None]:
# Same plot restricted to the N points that will actually be clustered,
# coloured by their ground-truth labels.
RiemannianKmeansTools.manifold_plot_selected_labels(
    encoded_points,
    ground_truth_labels,
    selected_labels,
    saving_folder=Path_pictures,
    plot_title="Encoded Points Colored by Ground Truth Labels",
    file_saving_name="ground_truth_labels",
)

# Setting parameters to optimize

In [None]:
# Initial centroids are kept untouched for later comparison; the working copy
# `current_centroids` is the one updated by the outer loop of the algorithm.
initial_centroids = RiemannianKmeansTools.initialize_centers(encoded_points, K, N)
current_centroids = torch.clone(initial_centroids)

if mode == "Interpolation_points":
    geodesic_solver = None
    # Initialize geodesic segments
    parameters_of_geodesics = RiemannianKmeansTools.construct_interpolation_points_on_segments_connecting_centers2encoded_data(
            encoded_points,
            initial_centroids,
            num_aux_points = step_count)
elif mode == "Schauder":
    geodesic_solver = ricci_regularization.Schauder.NumericalGeodesics(n_max, step_count)
    # Get Schauder basis
    N_max = geodesic_solver.schauder_bases["zero_boundary"]["N_max"]
    basis = geodesic_solver.schauder_bases["zero_boundary"]["basis"]
    # One coefficient block per (point, centroid) pair: shape (N, K, N_max, d)
    parameters_of_geodesics = torch.zeros((N, K, N_max, d), requires_grad=True)
else:
    # Without this branch an unknown mode would only surface as a NameError below.
    raise ValueError(f'Unknown mode {mode!r}: expected "Schauder" or "Interpolation_points"')

# Snapshot of the initial parameters, detached so it carries no autograd history.
init_parameters = parameters_of_geodesics.detach().clone()
# Set optimizer params.
# nn.Parameter wraps the tensor without copying it, so `parameters` and
# `parameters_of_geodesics` share storage and both see the optimizer updates.
parameters = torch.nn.Parameter(parameters_of_geodesics)

optimizer = torch.optim.SGD([parameters], lr=learning_rate)

# Filled in by the algorithm cell
cluster_index_of_each_point = None
geodesics_to_nearest_centroids = None

# per-iteration diagnostics (one dict per outer-loop iteration)
history = []


# visualizing initialization (optional)
plt.title("K-means initialization")
plt.scatter(encoded_points[:,0],encoded_points[:,1], label = "encoded data")
plt.scatter(initial_centroids[:,0], initial_centroids[:,1], c="red", label = "initial_centroids", marker='*', s = 60)
plt.xlim(-torch.pi, torch.pi)
plt.ylim(-torch.pi, torch.pi)
plt.legend()
plt.show()

# The algorithm  

In [None]:
# Riemannian K-means.
# Each outer iteration
#   1. refines the geodesic parameters with `num_iter_inner` SGD steps on the
#      summed energy of all (point, centroid) curves,
#   2. assigns every point to the centroid with the shortest curve,
#   3. moves every centroid one gradient step (rate `beta`) using the
#      distance-weighted end directions of the curves of its assigned points,
#   4. appends diagnostics to `history`.

# timing
start_time = time.time()
# sending the nn to selected device (usually it should be cuda)
torus_ae.to(device)
# ----------------------------
# Riemannian K-means Algorithm
# ----------------------------
# Row indices 0..N-1. They are paired with the per-point cluster index to pick
# exactly one entry per point; indexing with the cluster index alone, e.g.
# geodesic_curve[:, cluster_index_of_each_point, :, :], would produce a tensor
# of shape (N,N,step_count,d).
batch_indices = torch.arange(N)
# Outer loop
t = tqdm(range(num_iter_outer), desc="Outer Loop iteration: 0")
for iter_outer in t:
    # Inner loop (refining geodesics)
    for iter_inner in range(num_iter_inner):
        optimizer.zero_grad()  # Zero gradients
        # Compute the loss
        energies_of_geodesics = RiemannianKmeansTools.compute_energy(
                mode = mode,
                parameters_of_geodesics=parameters,
                end_points = [encoded_points, current_centroids],
                decoder = torus_ae.decoder_torus,
                geodesic_solver = geodesic_solver,
                reduction="none", device=device,
                periodicity_mode=periodicity_mode)
        loss_geodesics = energies_of_geodesics.sum()
        # Backpropagation: compute gradients
        loss_geodesics.backward()
        # Update parameters
        optimizer.step()
    # end inner loop
    # Only the values are needed from here on; detaching releases the autograd graph.
    energies_of_geodesics = energies_of_geodesics.detach().cpu()
    # compute geodesic_curve of shape (N,K,step_count,d)
    # compute a vector of length of all geodesics shape (N,K)
    with torch.no_grad():
        geodesic_curve, lengths_of_geodesics = RiemannianKmeansTools.compute_lengths(
                mode = mode,
                parameters_of_geodesics=parameters,
                end_points = [encoded_points, current_centroids],
                decoder = torus_ae.decoder_torus,
                geodesic_solver = geodesic_solver,
                reduction="none", device=device,
                periodicity_mode=periodicity_mode,
                return_geodesic_curve=True)
    lengths_of_geodesics = lengths_of_geodesics.cpu() # shape (N,K)
    geodesic_curve = geodesic_curve.cpu()

    # retrieve the class membership of each point by finding the closest cluster centroid
    cluster_index_of_each_point = torch.argmin(lengths_of_geodesics, dim=1) # shape (N)
    # pick only geodesics connecting points to cluster relevant centroids where the points are assigned
    geodesics_to_nearest_centroids = geodesic_curve[batch_indices, cluster_index_of_each_point, :, :].detach() # shape (N,step_count,d)
    # lengths of exactly those geodesics, shape (N)
    geodesics_to_nearest_centroids_lengths = lengths_of_geodesics[batch_indices, cluster_index_of_each_point]

    # v is the direction to move the cluster centroids: the last segment of each
    # geodesic (the end attached to the centroid), normalized # shape (N,d)
    v = geodesics_to_nearest_centroids[:,-1,:] - geodesics_to_nearest_centroids[:,-2,:]
    # Clamp the norm so that a degenerate (zero-length) last segment, e.g. a point
    # coinciding with its centroid, yields a zero direction instead of NaN.
    v_norm = torch.clamp(v.norm(dim=1, keepdim=True), min=torch.finfo(v.dtype).eps)
    v = v / v_norm

    # Compute weighted Frechet mean gradient for each cluster.
    # Each direction is weighted by the length of the geodesic to the point's OWN
    # assigned centroid (not by the length to centroid 0, which is only correct
    # for the points assigned to cluster 0).
    weighted_v = geodesics_to_nearest_centroids_lengths.unsqueeze(-1) * v  # Shape: (N, d)
    # Create a one-hot encoding of the cluster indices
    one_hot_clusters = torch.nn.functional.one_hot(cluster_index_of_each_point, num_classes=K).float()  # Shape: (N, K)
    # Sum the weighted directions of the points assigned to each cluster
    Frechet_mean_gradient = one_hot_clusters.T @ weighted_v  # Shape: (K, d)
    # Update cluster centroids
    with torch.no_grad():
        current_centroids += - beta * Frechet_mean_gradient  # Update all centroids simultaneously

    # Compute average Frechet mean gradient norm among the K clusters on step iter_outer
    average_Frechet_mean_gradient_norm = (Frechet_mean_gradient.norm(dim=1).mean()).item()

    # intra-class variance contribution of every point: energy of its geodesic to the assigned centroid, divided by N
    intraclass_variance = (1/N) * energies_of_geodesics[batch_indices, cluster_index_of_each_point]

    # compute the sum of geodesic length for each cluster
    # (scatter_add_ accumulates the per-point values into the slot of their cluster)
    length_of_geodesics_to_nearest_centroids_by_cluster = torch.zeros(K, dtype=geodesics_to_nearest_centroids_lengths.dtype)
    length_of_geodesics_to_nearest_centroids_by_cluster.scatter_add_(0, cluster_index_of_each_point, geodesics_to_nearest_centroids_lengths)

    # compute the Intra-class variance, i.e. sum of geodesic energy for each cluster
    intraclass_variance_by_cluster = torch.zeros(K, dtype=geodesics_to_nearest_centroids_lengths.dtype)
    intraclass_variance_by_cluster.scatter_add_(0, cluster_index_of_each_point, intraclass_variance)

    history_item = {
        "intraclass_variance"                              : intraclass_variance.detach().sum().numpy(),
        "intraclass_variance_by_cluster"                   : intraclass_variance_by_cluster.unsqueeze(0).detach().numpy(),
        "norm_Frechet_mean_gradient"                       : average_Frechet_mean_gradient_norm,
        "geodesics_to_nearest_centroids_lengths"           : geodesics_to_nearest_centroids_lengths.detach().sum().numpy(),
        "geodesics_to_nearest_centroids_lengths_by_cluster": length_of_geodesics_to_nearest_centroids_by_cluster.unsqueeze(0).detach().numpy()
    }
    history.append( history_item )
    t.set_description(f"Outer Loop iteration: {iter_outer+1}, Centroid gradient norm:{average_Frechet_mean_gradient_norm:.4f}, Total geodesic energy:{loss_geodesics:.4f}")  # Update description dynamically
#timing
end_time = time.time()
algorithm_execution_time = end_time - start_time

In [None]:
# Turn `history` (one dict per outer-loop iteration) into one list per recorded
# quantity, preserving the iteration order.
norm_Frechet_mean_gradient_history = [
    record["norm_Frechet_mean_gradient"] for record in history
]
geodesics_to_nearest_centroids_lengths_by_cluster_history = [
    record["geodesics_to_nearest_centroids_lengths_by_cluster"] for record in history
]
geodesics_to_nearest_centroids_lengths_history = [
    record["geodesics_to_nearest_centroids_lengths"] for record in history
]
intraclass_variance_by_cluster_history = [
    record["intraclass_variance_by_cluster"] for record in history
]
intraclass_variance_history = [
    record["intraclass_variance"] for record in history
]

# Losses

In [None]:
# Plot the diagnostics recorded at every outer-loop iteration, one panel each:
#   left   - average norm of the centroid (Frechet mean) gradients
#   middle - lengths of geodesics to the nearest centroids, per cluster and in total
#   right  - intra-class variances, per cluster and in total
# The figure is saved as kmeans_losses.pdf in Path_pictures; the final values
# are printed afterwards.

fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns
ax_gradient, ax_lengths, ax_variance = axes

# --- left panel: gradient norm ---
ax_gradient.plot(norm_Frechet_mean_gradient_history, marker='o', markersize=3)
ax_gradient.set_title('Average norm of gradients of centroids')
ax_gradient.set_xlabel('Outer loop iterations')
ax_gradient.set_ylabel('Loss')

# One distinct colour per cluster, sampled from the 'jet' colormap
colors = plt.cm.jet(torch.linspace(0, 1, K))

# --- middle panel: geodesic lengths ---
# Stack the per-iteration rows into an array with one column per cluster
lengths_of_geodesics_to_nearest_centroids_concatenated = np.concatenate(geodesics_to_nearest_centroids_lengths_by_cluster_history)
for cluster_id in range(K):
    ax_lengths.plot(
        lengths_of_geodesics_to_nearest_centroids_concatenated[:, cluster_id],
        marker='o',
        markersize=3,
        label=f'Lengths of geodesics in cluster {cluster_id}',
        color=colors[cluster_id],
    )
# total over all clusters
ax_lengths.plot(
    geodesics_to_nearest_centroids_lengths_history,
    marker='o',
    markersize=3,
    label='Lengths of geodesics in all clusters',
    color='green',
)
ax_lengths.set_title('Lengths of geodesics to nearest centroids')
ax_lengths.set_xlabel('Outer loop iterations')
ax_lengths.set_ylabel('Loss')
ax_lengths.legend(loc= 'upper right')

# --- right panel: intra-class variances ---
intraclass_variance_concatenated = np.concatenate(intraclass_variance_by_cluster_history)
for cluster_id in range(K):
    ax_variance.plot(
        intraclass_variance_concatenated[:, cluster_id],
        marker='o',
        markersize=3,
        label=f'Variance of geodesics of cluster {cluster_id} ',
        color=colors[cluster_id],
    )
# total over all clusters
ax_variance.plot(
    intraclass_variance_history,
    marker='o',
    markersize=3,
    label='Intra-class variance',
    color='green',
)
ax_variance.set_title('Intra-class variances')
ax_variance.set_xlabel('Outer loop iterations')
ax_variance.set_ylabel('Loss')
ax_variance.legend()

# Adjust layout, save and show
plt.tight_layout()
plt.savefig(f"{Path_pictures}/kmeans_losses.pdf",bbox_inches='tight', format="pdf")
plt.show()


# Values reached at the last outer-loop iteration
print('Final values of losses:')
print('-----------------------')
print(f'Intra-class variance: {intraclass_variance_history[-1]:.3f}')
print(f'Lengths of geodesics to nearest centroids: {geodesics_to_nearest_centroids_lengths_history[-1]:.3f}')
print(f'Centroid gradient average norm: {norm_Frechet_mean_gradient_history[-1]:.3f}')
# beta times the gradient norm is the size of the last centroid update step
print(f'Centroid shift average norm: {beta*norm_Frechet_mean_gradient_history[-1]:.5f}')

# Plotting results

In [None]:
# Rebuild the discretized curves from the optimized parameters and the final
# centroids, then draw them with plot_octopus.
# `parameters` is the tensor the optimizer actually updated; it is used
# explicitly (detached) instead of relying on `parameters_of_geodesics`
# sharing its storage. No gradients are needed for plotting, so the
# computation is kept out of autograd.
with torch.no_grad():
    if mode == "Interpolation_points":
        geodesic_curve = RiemannianKmeansTools.geodesics_from_parameters_interpolation_points(
            parameters.detach(),
            end_points = [encoded_points, current_centroids])

    elif mode == "Schauder":
        geodesic_curve = RiemannianKmeansTools.geodesics_from_parameters_schauder(
            geodesic_solver,
            parameters.detach(),
            end_points = [encoded_points, current_centroids],periodicity_mode=periodicity_mode)
# memberships: index of the cluster every point was assigned to in the last iteration
RiemannianKmeansTools.plot_octopus(
    geodesic_curve.detach(),
    memberships = cluster_index_of_each_point,
    saving_folder=Path_pictures,suffix=0, periodicity_mode=periodicity_mode)

In [None]:
# Displacement of every centroid over the whole run: final minus initial position.
# The same tensor is used for the printout and for the norm, so both agree in
# sign (printing initial - current would show the negated shift).
total_centroid_shifts = (current_centroids - initial_centroids).detach()
print("Total centroid shifts during training:\n", total_centroid_shifts)
average_cluster_center_shift_norm = total_centroid_shifts.norm(dim = 1).mean()
print("Average norm of these shifts:", average_cluster_center_shift_norm.item())

# Saving parameters of the experiment

In [None]:
# Collect the experiment description and results, grouped by topic, and write
# them to params.json in Path_pictures. Tensors are converted to nested lists
# so that they are JSON serializable.

# What is clustered and how the geodesics are parametrized
experiment_setup = {
    "K": K,  # Number of clusters
    "N": N,  # Number of points to be clustered
    "selected_labels": selected_labels,  # Labels used for clustering
    "mode": mode,  # Can be "Schauder" or "Interpolation_points"
}

# Discretization of the geodesics
discretization = {
    "n_max": n_max,  # Schauder basis complexity
    "step_count": step_count,  # Number of interpolation steps
}

# Optimization settings and measured run time
optimization = {
    "beta": beta,  # Frechet mean learning rate
    "learning_rate": learning_rate,  # Learning rate for geodesics
    "num_iter_outer": num_iter_outer,  # Number of Frechet mean updates
    "num_iter_inner": num_iter_inner,  # Number of geodesic refinement iterations
    "time_secs": algorithm_execution_time,  # Wall-clock time of the algorithm cell
}

# Clustered points with ground-truth and assigned labels
clustering_outcome = {
    "ground_truth_labels": ground_truth_labels.tolist(),
    "Riemannian_k_means_labels": cluster_index_of_each_point.tolist(),
    "encoded_points": encoded_points.tolist(),
}

# Merge in this order so the key order of the JSON file is stable
params = {**experiment_setup, **discretization, **optimization, **clustering_outcome}

# Save to JSON file
params_path = Path_pictures+"/params.json"
with open(params_path, "w") as f:
    json.dump(params, f, indent=4)

print(f"Parameters saved to {Path_pictures}/params.json")


Saving separately all the optimized geodesics, shape (N, K, step_count, d)

In [None]:
# Save the discretized geodesic curves.
# geodesic_curve may still be attached to the autograd graph (it is built from
# parameters that require gradients), so it is detached first: the file then
# holds plain data rather than a tensor flagged as requiring gradients.
torch.save(geodesic_curve.detach(), Path_pictures+"/geodesic_curve.pt")
print(f"Discretized geodesic curves saved to {Path_pictures}/geodesic_curve.pt")