In [None]:
# minimal imports
import torch, yaml
import ricci_regularization
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from ricci_regularization.Schauder import NumericalGeodesics
from ricci_regularization import RiemannianKmeansTools

# Hyperparameters

In [None]:
# Experiment setup
K = 2   # number of clusters
N = 10  # number of points to be clustered

# Geodesic parametrization: "Schauder" or "Interpolation_points"
mode = "Schauder"
if mode not in ("Schauder", "Interpolation_points"):
    raise ValueError(f"Unknown mode {mode!r}; expected 'Schauder' or 'Interpolation_points'.")

# Method-specific parameters
n_max = 3        # Schauder basis complexity (Schauder mode only)
step_count = 50  # number of interpolation steps along each geodesic (both modes)

# Optimization parameters
beta = 1.e-3           # Fréchet mean learning rate (learning_rate_frechet_mean, outer loop)
learning_rate = 1.e-3  # geodesic learning rate (learning_rate_geodesics, inner loop)
num_iter_outer = 10    # number of Fréchet mean updates (outer loop)
num_iter_inner = 15    # geodesic refinement iterations per Fréchet mean update (inner loop)

# Seed the RNGs so that a Restart & Run All reproduces the same run
# (the centroid initialization may be random).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Loading the pretrained AE


In [None]:
# Experiment configuration (path relative to this notebook's directory).
Path_experiment = '../experiments/MNIST_Setting_3_config.yaml'
with open(Path_experiment, 'r') as yaml_file:
    # FullLoader is acceptable here: the config is a local, trusted file.
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)
print("Experiment config loaded successfully.")

# Build the data loaders from the YAML configuration.
# Named `dataloaders` (not `dict`) so the Python builtin `dict` is not shadowed.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
# The test dataset is needed below to pick the points to cluster; fail early and clearly
# instead of with an AttributeError on `None.data` later.
test_dataset = dataloaders.get("test_dataset")
if test_dataset is None:
    raise KeyError("get_dataloaders did not return a 'test_dataset' entry, which is required to pick the points to cluster.")
print("Data loaders created successfully.")

# Loading the pre-trained AE
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)
print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

# Picking the dataset to cluster

In [None]:
# Collect the dataset we want to cluster: the FIRST N samples of the test set
# (this could be done differently, e.g. by picking N random samples), then encode
# them into the AE's latent space.
D = yaml_config["architecture"]["input_dim"]   # flattened input dimension
d = yaml_config["architecture"]["latent_dim"]  # latent dimension
data = test_dataset.data
# NOTE(review): `.data` is the raw stored array and bypasses the dataset's transforms
# (for torchvision MNIST it is uint8 in [0, 255], not normalized). Confirm this matches
# the input scaling the AE was trained with.
if N > len(data):
    raise ValueError(f"N={N} exceeds the number of test samples ({len(data)}).")

# Limit the dataset to the first N samples
subset_indices = list(range(N))
mnist_subset = torch.utils.data.Subset(data, subset_indices)
dataset_batch_size = 128
dataloader = torch.utils.data.DataLoader(mnist_subset, batch_size=dataset_batch_size, shuffle=False)

# Encoding into latent space is done on CPU in evaluation mode
torus_ae.cpu()
torus_ae.eval()

with torch.no_grad():  # inference only, no gradients needed
    encoded_points = torch.cat([
        torus_ae.encoder2lifting(images.reshape(-1, D).to(torch.float32))
        for images in dataloader
    ])
# One latent point per selected sample; later cells index points with torch.arange(N)
assert encoded_points.shape[0] == N


# Setting parameters to optimize

In [None]:
# K initial centroids chosen from the encoded points
initial_centroids = RiemannianKmeansTools.initialize_centers(encoded_points, K, N)
# Working copy updated in place by the outer loop; initial_centroids is kept for comparison
current_centroids = torch.clone(initial_centroids)

if mode == "Interpolation_points":
    geodesic_solver = None
    # Initialize the geodesic segments connecting every centroid to every encoded point
    parameters_of_geodesics = RiemannianKmeansTools.construct_interpolation_points_on_segments_connecting_centers2encoded_data(
            encoded_points,
            initial_centroids,
            num_aux_points = step_count)
elif mode == "Schauder":
    geodesic_solver = NumericalGeodesics(n_max, step_count)
    # Schauder basis with zero boundary values
    N_max = geodesic_solver.schauder_bases["zero_boundary"]["N_max"]
    basis = geodesic_solver.schauder_bases["zero_boundary"]["basis"]
    # One set of basis coefficients per (point, centroid) pair: shape (N, K, N_max, d)
    parameters_of_geodesics = torch.zeros((N, K, N_max, d), requires_grad=True)
else:
    raise ValueError(f"Unknown mode {mode!r}; expected 'Schauder' or 'Interpolation_points'.")

# Snapshot of the initial segments, detached so it carries no autograd history
init_parameters = parameters_of_geodesics.detach().clone()
# Leaf parameter handed to the optimizer. nn.Parameter shares storage with
# parameters_of_geodesics, so optimizer updates are visible through both names.
parameters = torch.nn.Parameter(parameters_of_geodesics)

optimizer = torch.optim.SGD([parameters], lr=learning_rate)

# Filled in by the outer loop
cluster_index_of_each_point = None
meaningful_geodesics = None

# Histories, one entry per outer-loop iteration
meaningful_geodesics_loss_history = []
meaningful_geodesics_loss_history_by_cluster = []
norm_Frechet_mean_gradient_history = []

# Optional visualization of the initialization (first two latent coordinates)
show_initialization = False
if show_initialization:
    plt.scatter(encoded_points[:,0],encoded_points[:,1], label = "encoded data")
    plt.scatter(initial_centroids[:,0], initial_centroids[:,1], c="red", label = "initial_centroids", marker='*', s = 60)
    plt.xlim(-torch.pi, torch.pi)
    plt.ylim(-torch.pi, torch.pi)
    plt.legend()
    plt.show()

# The algorithm  

In [None]:
# Outer loop: alternate between
#   1) refining the N×K geodesics (every point to every centroid) by gradient descent on their energy,
#   2) assigning each point to the centroid reached by the shortest geodesic,
#   3) moving every centroid one step along a proxy of its cluster's Fréchet-mean gradient.
for iter_outer in tqdm(range(num_iter_outer), desc="Frechet mean updates"):
    # Inner loop: refine the geodesics for the current centroids
    for iter_inner in range(num_iter_inner):
        optimizer.zero_grad()
        loss_geodesics = RiemannianKmeansTools.compute_energy(
                mode = mode,
                parameters_of_geodesics=parameters,
                end_points = [encoded_points, current_centroids],
                decoder = torus_ae.decoder_torus,
                geodesic_solver = geodesic_solver)
        loss_geodesics.backward()
        optimizer.step()
    # end inner loop

    # Lengths of all geodesics, shape (N, K). Detached: they only serve as weights and for
    # logging; keeping their autograd graph would make the per-cluster history below
    # retain one graph per outer iteration.
    lengths_of_geodesics = RiemannianKmeansTools.compute_lengths(
            mode = mode,
            parameters_of_geodesics=parameters,
            end_points = [encoded_points, current_centroids],
            decoder = torus_ae.decoder_torus,
            geodesic_solver = geodesic_solver).detach()

    # Cluster membership of each point: index of the closest centroid, shape (N,)
    cluster_index_of_each_point = torch.argmin(lengths_of_geodesics, dim=1)
    batch_indices = torch.arange(lengths_of_geodesics.shape[0])

    # Length of the geodesic from each point to ITS OWN centroid, shape (N,)
    meaningful_geodesics_lengths = lengths_of_geodesics[batch_indices, cluster_index_of_each_point]

    # Discretized geodesic curves built from the optimized parameter
    if mode == "Interpolation_points":
        geodesic_curve = RiemannianKmeansTools.geodesics_from_parameters_interpolation_points(
                parameters_of_geodesics = parameters,
                end_points = [encoded_points, current_centroids])
    elif mode == "Schauder":
        geodesic_curve = RiemannianKmeansTools.geodesics_from_parameters_schauder(
                geodesic_solver = geodesic_solver,
                parameters_of_geodesics = parameters,
                end_points = [encoded_points, current_centroids])
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'Schauder' or 'Interpolation_points'.")

    # Keep only the geodesic from each point to its assigned centroid, shape (N, m, d)
    meaningful_geodesics = geodesic_curve[batch_indices, cluster_index_of_each_point, :, :].detach()

    # v: unit direction of the last segment of each meaningful geodesic, shape (N, d).
    # Assumes curves run from the encoded point (index 0) to the centroid (index -1) — TODO confirm.
    last_segments = meaningful_geodesics[:, -1, :] - meaningful_geodesics[:, -2, :]
    last_segment_norms = last_segments.norm(dim=1, keepdim=True)
    # A zero-length last segment (e.g. a point coinciding with its centroid) would give 0/0 = NaN
    # and poison its cluster's update; such points contribute a zero direction instead
    # (their geodesic length, i.e. their weight, is typically ~0 anyway).
    v = torch.where(last_segment_norms > 0,
                    last_segments / last_segment_norms,
                    torch.zeros_like(last_segments))

    #---------------------------------------------------------------
    # Update cluster centroids with weight beta:
    #---------------------------------------------------------------
    # Proxy of each cluster's Fréchet-mean gradient, shape (K, d): the sum over the cluster's
    # points of (length of the geodesic to the ASSIGNED centroid) × v.
    weighted_v = meaningful_geodesics_lengths.unsqueeze(-1) * v  # (N, d)
    one_hot_clusters = torch.nn.functional.one_hot(
        cluster_index_of_each_point, num_classes=K).to(weighted_v.dtype)  # (N, K)
    Frechet_mean_gradient = one_hot_clusters.T @ weighted_v  # (K, d)

    # Move all centroids simultaneously
    with torch.no_grad():
        current_centroids -= beta * Frechet_mean_gradient

    # Average (over clusters) Fréchet-mean gradient norm
    average_Frechet_mean_gradient_norm = Frechet_mean_gradient.norm(dim=1).mean().item()
    norm_Frechet_mean_gradient_history.append(average_Frechet_mean_gradient_norm)

    # Total length of the meaningful geodesics, over all clusters and per cluster (as a (1, K) row)
    meaningful_geodesics_loss_history.append(meaningful_geodesics_lengths.sum().item())
    total_length_of_meaningful_geodesics_by_cluster = torch.zeros(
        K, dtype=meaningful_geodesics_lengths.dtype, device=meaningful_geodesics_lengths.device)
    total_length_of_meaningful_geodesics_by_cluster.scatter_add_(0, cluster_index_of_each_point, meaningful_geodesics_lengths)
    meaningful_geodesics_loss_history_by_cluster.append(total_length_of_meaningful_geodesics_by_cluster.unsqueeze(0))

# Losses

In [None]:
# Convergence diagnostics, one value per outer-loop iteration:
#   left  - average norm of the per-cluster Fréchet-mean gradient proxy (each centroid step is beta times it),
#   right - total length of the point-to-assigned-centroid geodesics, per cluster and over all clusters.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # two panels regardless of K

# Left panel: Fréchet-mean gradient norm history
axes[0].plot(norm_Frechet_mean_gradient_history, marker='o', markersize=3, label='Frechet mean update history')
axes[0].set_title('Average shift of centers (proxy of Fréchet mean gradient norm)')
axes[0].set_xlabel('Outer loop iterations')
axes[0].set_ylabel('Average gradient norm')
axes[0].legend()

# Right panel: meaningful geodesic lengths, one distinct 'jet' color per cluster
colors = plt.cm.jet(np.linspace(0, 1, K))
# (num_iter_outer, K) array of per-cluster totals
lengths_of_meaningful_geodesics_concatenated = torch.cat(
    meaningful_geodesics_loss_history_by_cluster, dim=0).detach().cpu().numpy()
for i in range(K):
    axes[1].plot(lengths_of_meaningful_geodesics_concatenated[:, i], marker='o', markersize=3,
                 label=f'Cluster {i} geodesics length', color=colors[i])
# Total over all clusters; black so it cannot clash with a cluster color
axes[1].plot(meaningful_geodesics_loss_history, marker='o', markersize=3,
             label='All clusters geodesics length', color='black')
axes[1].set_title('Meaningful geodesics length')
axes[1].set_xlabel('Outer loop iterations')
axes[1].set_ylabel('Geodesic length')
axes[1].legend()

plt.tight_layout()
plt.show()

# Plotting

In [None]:
# Recompute the discretized geodesics from the optimized parameter for plotting.
# NOTE(review): current_centroids were moved once more after the last geodesic refinement,
# so these curves end at the updated centroids, while `cluster_index_of_each_point` and
# `meaningful_geodesics` refer to the centroids before that final step.
if mode == "Interpolation_points":
    geodesic_curve = RiemannianKmeansTools.geodesics_from_parameters_interpolation_points(
        parameters,
        end_points = [encoded_points, current_centroids])
elif mode == "Schauder":
    geodesic_curve = RiemannianKmeansTools.geodesics_from_parameters_schauder(
        geodesic_solver,
        parameters,
        end_points = [encoded_points, current_centroids])
else:
    raise ValueError(f"Unknown mode {mode!r}; expected 'Schauder' or 'Interpolation_points'.")

RiemannianKmeansTools.plot_octopus(
    geodesic_curve.detach(),
    memberships = cluster_index_of_each_point,
    meaningful_geodesics = meaningful_geodesics)

In [None]:
# Displacement of each centroid over the whole run (final - initial), shape (K, d).
# The printed vectors and the averaged norm use the same sign convention.
center_shifts = (current_centroids - initial_centroids).detach()
print("center shifts:\n", center_shifts)
average_cluster_center_shift_norm = center_shifts.norm(dim = 1).mean()
print("Average center's shift:", average_cluster_center_shift_norm.item())