# Imports

In [None]:
import torch, yaml
import ricci_regularization
import matplotlib.pyplot as plt
import os
import numpy as np
import imageio


# I. Data loading

In [None]:
Path_experiment = '../../experiments/MNIST_Setting_3_config.yaml'
# Other experiment configs that can be swapped in:
#   '../experiments/Synthetic_Setting_1/Synthetic_Setting_1_config.yaml'
#   '../experiments/Swissroll_exp5_config.yaml'
with open(Path_experiment, 'r') as yaml_file:
    # NOTE: FullLoader is acceptable for this repository-local config;
    # prefer yaml.safe_load if configs may ever come from untrusted sources.
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Load data loaders based on YAML configuration.
# Bound to `loaders` (not `dict`) so the builtin `dict` stays usable in later cells.
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
print("Experiment results loaded successfully.")
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
# The encoding cell below reads `test_dataset.data`; `.get` yields None if the key is absent.
test_dataset = loaders.get("test_dataset")
if test_dataset is None:
    print("Warning: get_dataloaders returned no 'test_dataset'; cells using test_dataset will fail.")

print("Data loaders created successfully.")

In [None]:
# Load the tuned torus autoencoder described by the config; `additional_path`
# is passed through unchanged to get_tuned_nn.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(
    config=yaml_config,
    additional_path='../',
)

print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

In [None]:
# Folder that receives the PNG snapshots later assembled into a GIF / video.
shots_folder_name = "/generated_pics"
shots_folder_adress = "".join(
    ['../../experiments/', yaml_config["experiment"]["name"], shots_folder_name]
)
folder_already_exists = os.path.exists(shots_folder_adress)
if not folder_already_exists:
    os.makedirs(shots_folder_adress)
    print("A folder created for saved images to create a gif at:", shots_folder_adress)

In [None]:
# Experiment sizes
batch_size = 128
K = 2  # number of clusters
data = test_dataset.data
N = 7  # number of data points used (full set would be len(data))
m = 10  # points on every geodesic (including both endpoints)
D = yaml_config["architecture"]["input_dim"]
d = yaml_config["architecture"]["latent_dim"]

# Keep only the first N samples.
# NOTE(review): `.data` bypasses the dataset's transforms (for torchvision MNIST this
# is raw uint8 in 0..255) — confirm this matches what the encoder was trained on.
subset_indices = list(range(N))
mnist_subset = torch.utils.data.Subset(data, subset_indices)
dataloader = torch.utils.data.DataLoader(mnist_subset, batch_size=batch_size, shuffle=False)

# Inference on CPU in eval mode
torus_ae.cpu()
torus_ae.eval()

# Flatten each batch to (batch, D), cast to float32 and map it into latent space.
with torch.no_grad():
    encoded_points = torch.cat([
        torus_ae.encoder2lifting(batch.reshape(-1, D).to(torch.float32))
        for batch in dataloader
    ])


In [None]:
# Choose K initial centers among the encoded points and show them in the chart [-pi, pi]^2.
centers = ricci_regularization.initialize_centers(encoded_points, K, N)

latent_x = encoded_points[:, 0]
latent_y = encoded_points[:, 1]
plt.scatter(latent_x, latent_y, label="encoded data")
plt.scatter(centers[:, 0], centers[:, 1], c="red", label="centers", marker='*', s=60)
for set_limits in (plt.xlim, plt.ylim):
    set_limits(-torch.pi, torch.pi)
plt.legend()
plt.show()

In [None]:
# Straight latent segments with m points from every encoded point to every center
# (presumably shape (N, K, m, d), matching the 4-index slicing used on `segments` below).
segments = ricci_regularization.connect_centers2encoded_data_with_segments(encoded_points, centers, m)

In [None]:
# Plot the initial segments; saving_adress/iter presumably make plot_octopus also write a
# PNG snapshot into shots_folder_adress (the GIF cell below collects PNGs from there).
ricci_regularization.plot_octopus(segments, saving_adress=shots_folder_adress, iter=0,silent=False)

# Loss function and optimizer parameters setting

In [None]:
def compute_energy(points_on_geodesics, decoder=torus_ae.decoder_torus, num_data_points = N, num_classes = K, num_aux_points = m, latent_dim = d):
    """Total discrete Euclidean energy of all curves after decoding.

    The curves are decoded to ambient space, consecutive points along dim 2 are
    differenced, and all squared components are summed into a single scalar.
    Summing everything lets one backward pass optimize every curve in parallel.

    Args:
        points_on_geodesics: latent points along the curves, stepped along dim 2
            (presumably (num_data_points, num_classes, num_aux_points, latent_dim);
            the shape is not checked).
        decoder: maps latent points to ambient space.
        num_data_points, num_classes, num_aux_points, latent_dim: unused; kept for
            signature compatibility.

    Returns:
        0-dim tensor: sum over all curves and steps of ||x_{k+1} - x_k||^2.
    """
    ambient = decoder(points_on_geodesics)
    later = ambient[:, :, 1:, :]
    earlier = ambient[:, :, :-1, :]
    return ((later - earlier) ** 2).sum()

def compute_lengths(points_on_geodesics, decoder=torus_ae.decoder_torus, num_data_points = N, num_classes = K, num_aux_points = m, latent_dim = d):
    """Euclidean length, in ambient space after decoding, of every discretized curve.

    Args:
        points_on_geodesics: either a 4-D tensor of curves whose dim 2 indexes the
            points along each curve (e.g. (num_data_points, num_classes, num_aux_points,
            latent_dim)), or a 3-D tensor (num_curves, num_aux_points, latent_dim).
            A 3-D input gets a leading axis of size 1.
        decoder: maps latent points to ambient space.
        num_data_points, num_classes, num_aux_points, latent_dim: unused; kept for
            signature compatibility.

    Returns:
        Tensor of shape points_on_geodesics.shape[:2] (or (1, num_curves) for 3-D
        input): the sum over steps of ||x_{k+1} - x_k|| in ambient space.
    """
    # A 3-D input is a single family of curves (e.g. one geodesic per data point).
    # Add a leading axis so the steps stay on dim 2. Checking the rank, rather than
    # comparing against the default (N, K, m, d), avoids unsqueezing 4-D inputs
    # of other sizes.
    if points_on_geodesics.dim() == 3:
        points_on_geodesics = points_on_geodesics.unsqueeze(0)
    decoded_points = decoder(points_on_geodesics)
    tangent_vectors = decoded_points[:, :, 1:, :] - decoded_points[:, :, :-1, :]
    # Length = sum of per-step Euclidean norms. The previous sqrt(sum of squared
    # steps) is sqrt of the discrete energy, not the length: it only equals
    # length / sqrt(#steps) for constant-speed curves. Any trailing ambient axes
    # are flattened into one vector per step.
    step_lengths = tangent_vectors.flatten(start_dim=3).norm(dim=-1)
    computed_lengths = step_lengths.sum(dim=-1)
    return computed_lengths
# Energy of the initial straight segments (displayed as the cell output)
loss_geodesics = compute_energy(points_on_geodesics=segments, decoder=torus_ae.decoder_torus)
loss_geodesics

In [None]:
# Hyperparameters of the geodesic-refinement loop
learning_rate = 0.5e-3
num_iter = 1500

# Start from straight segments between every data point and every center.
# Keep an untouched copy to check later that the endpoints did not move.
segments = ricci_regularization.connect_centers2encoded_data_with_segments(encoded_points, centers, m)
init_segments = segments.clone()
segments = torch.nn.Parameter(segments)  # make the curve points trainable

optimizer = torch.optim.SGD([segments], lr=learning_rate)

# Inner loop (refining geodesics)

In [None]:
# Refine all curves by gradient descent on their total energy, with endpoints pinned.
loss_history = []
for iter_num in range(num_iter):
    optimizer.zero_grad()
    loss_geodesics = compute_energy(points_on_geodesics=segments, decoder=torus_ae.decoder_torus)
    loss_geodesics.backward()

    # Pinned endpoints: index 0 is the data point, index -1 is the center.
    for pinned in (0, -1):
        segments.grad[:, :, pinned, :] = 0.

    optimizer.step()

    loss_value = loss_geodesics.item()
    loss_history.append(loss_value)
    print(f"Iteration #{iter_num + 1}, loss: {loss_value:.3f}")

In [None]:
# Plot the refined curves (detached because `segments` is a Parameter that requires grad)
ricci_regularization.plot_octopus(segments.detach())

In [None]:
# Sanity check: endpoint gradients were zeroed before every SGD step, so the first
# (data point) and last (center) points must equal their initial values.
assert torch.equal( init_segments[:,:,0,:], segments[:,:,0,:])
assert torch.equal( init_segments[:,:,-1,:], segments[:,:,-1,:])

In [None]:
# Inspect the shape of the curve tensor (cell output)
segments.shape

# Outer loop: Frechet mean update + membership update

In [None]:
beta = 1e-2  # Frechet mean learning rate

# Length of every geodesic, shape (N, K); each point joins its nearest center.
with torch.no_grad():
    lengths_of_geod = compute_lengths(segments, torus_ae.decoder_torus)
memberships = torch.argmin(lengths_of_geod, dim=1)  # shape (N,)

# Geodesic from each point to its OWN center: pair row n with column memberships[n].
# `segments[:, memberships]` alone would combine every row with every membership
# and give shape (N, N, m, d).
batch_indices = torch.arange(N)
meaningful_geodesics = segments[batch_indices, memberships, :, :].detach()  # (N, m, d)

# Unit direction of the last step of each meaningful geodesic, shape (N, d).
# TODO: consider a weighted average over the last few steps instead of only the last.
v = meaningful_geodesics[:, -1, :] - meaningful_geodesics[:, -2, :]
# The clamp avoids 0/0 (NaN) when the last step has collapsed to a point.
v = v / v.norm(dim=1, keepdim=True).clamp_min(1e-12)

# Move each center against its Frechet-loss gradient, with step size beta.
with torch.no_grad():
    for i in range(K):
        cluster_mask = memberships == i
        v_i = v[cluster_mask]
        # Lengths of the members' geodesics to THEIR center i (not to center 0).
        l_i = lengths_of_geod[cluster_mask, i]
        FM_gradient = torch.sum(l_i.unsqueeze(-1) * v_i, dim=0)  # shape (d,)
        # The center is the last point of every geodesic of cluster i.
        segments[:, i, -1, :] += - beta * FM_gradient
        print(f"\nNorm of gradient of FM loss for {i}-th cluster", FM_gradient.norm().item())

In [None]:
# Debug: shape of v_i for the last cluster handled by the loop above
v_i.shape

In [None]:
# Debug: shape of the broadcast length factor for the last cluster handled above
l_i.unsqueeze(-1).shape

In [None]:
# Debug: shapes of the curve tensor and the membership vector
print("segments shape:", segments.shape)
print("memberships shape:", memberships.shape)

In [None]:
# Print the current length of every geodesic and the cluster memberships
print("lengths of geodesics", compute_lengths(segments) )
print("memberships", memberships)

In [None]:
# Plot all curves together with memberships and each point's geodesic to its assigned
# center (how plot_octopus renders these arguments is defined in ricci_regularization)
ricci_regularization.plot_octopus(segments.detach(),memberships,meaningful_geodesics.detach())

# version -1 without local charts update

In [None]:
# Hyperparameters of version -1 (no local-chart update)
beta = 0.5e-3  # Frechet mean learning rate
learning_rate = 1e-3
num_iter_outer = 200
num_geod_iter = 10  # geodesic refinement iterations per Frechet-mean update

# Fresh per-run state, so this cell can be re-executed from scratch
memberships = meaningful_geodesics = None
loss_history, meaningful_geodesics_loss_history, norm_FM_grad_history = [], [], []

# Straight initial segments between every data point and every initial center
centers = ricci_regularization.initialize_centers(encoded_points, K, N)
segments = ricci_regularization.connect_centers2encoded_data_with_segments(encoded_points, centers, m)
init_segments = segments.clone()  # untouched copy of the initial segments

# The curve points are the optimized parameters
segments = torch.nn.Parameter(segments)
optimizer = torch.optim.SGD([segments], lr=learning_rate)

In [None]:
# Plot the initial segments of this run
ricci_regularization.plot_octopus(segments.detach())

In [None]:
# TODO: add a stopping criterion, e.g. based on the change in energy between iterations

In [None]:
# Outer loop: alternate geodesic refinement (inner loop) with a Frechet-mean step on the centers
for iter_outer in range(num_iter_outer):
    # Inner loop (refining geodesics with both endpoints held fixed)
    for iter_inner in range(num_geod_iter):

        optimizer.zero_grad()  # Zero gradients

        # Compute the loss: total discrete energy of all decoded curves
        loss_geodesics = compute_energy(points_on_geodesics=segments, decoder=torus_ae.decoder_torus)

        # Backpropagation: compute gradients
        loss_geodesics.backward()

        # Zero out gradients for the first and last points (don't want them updated)
        segments.grad[:, :, 0, :] = 0.  # First points along 'geodesics' (data points)
        segments.grad[:, :, -1, :] = 0.  # Last points along 'geodesics' (centers)

        # Update parameters
        optimizer.step()

        # Store the loss value
        loss_history.append(loss_geodesics.item())
        # saving plots
        #ricci_regularization.plot_octopus(segments,memberships=memberships,meaningful_geodesics=meaningful_geodesics,
        #             saving_adress=shots_folder_adress, iter=iter_outer*num_geod_iter + iter_inner,silent=True)
        #print(f"Iteration #{iter_inner + 1}, loss: {loss_geodesics.item():.3f}")

    # Membership and Frechet-mean updates; no gradients are needed here.
    with torch.no_grad():
        lengths_of_geod = compute_lengths(segments, torus_ae.decoder_torus)  # shape (N, K)
        memberships = torch.argmin(lengths_of_geod, dim=1)  # nearest center per point, shape (N,)

        batch_indices = torch.arange(N)
        # Geodesic from each point to its own center, shape (N, m, d)
        meaningful_geodesics = segments[batch_indices, memberships, :, :].detach()

        # saving the lengths of meaningful geodesics
        meaningful_geodesics_lengths = compute_lengths(meaningful_geodesics)
        meaningful_geodesics_loss_history.append(meaningful_geodesics_lengths.mean().item())

        # Unit direction of the last step of each meaningful geodesic, shape (N, d).
        # The clamp avoids 0/0 when the last step has collapsed.
        v = meaningful_geodesics[:, -1, :] - meaningful_geodesics[:, -2, :]
        v = v / v.norm(dim=1, keepdim=True).clamp_min(1e-12)

        # Update cluster centers with weight beta.
        # Reset once per outer iteration (not per cluster) so all K clusters are averaged.
        average_FM_grad_norm = 0.
        for i in range(K):
            cluster_mask = memberships == i
            v_i = v[cluster_mask]
            l_i = lengths_of_geod[cluster_mask, i]  # lengths to the cluster's OWN center
            FM_gradient = torch.sum(l_i.unsqueeze(-1) * v_i, dim=0)  # shape (d,)
            # Move the i-th center (the last point on each of its geodesics)
            segments[:, i, -1, :] += - beta * FM_gradient
            average_FM_grad_norm += FM_gradient.norm().item() / K
            # TODO: also save each cluster's FM gradient norm separately
    norm_FM_grad_history.append(average_FM_grad_norm)

# Plotting losses

In [None]:
# Training curves of the version -1 run: one panel per recorded history
fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns

# Average (over clusters) norm of the Frechet-mean gradient, per outer iteration
axes[0].plot(norm_FM_grad_history, marker='o', label='FM Grad History')
axes[0].set_title('Average shift of centers (Fréchet mean gradient norm)')
axes[0].set_xlabel('Outer loop iterations')
axes[0].set_ylabel('Gradient norm')
axes[0].legend()

# Mean length of the geodesics from each point to its assigned center
axes[1].plot(meaningful_geodesics_loss_history, marker='o', label='Geodesics Loss History', color='orange')
axes[1].set_title('Meaningful geodesics length')
axes[1].set_xlabel('Outer loop iterations')
axes[1].legend()

# loss_history holds compute_energy(...) values: summed squared-step energy of all curves
axes[2].plot(loss_history, label='All geodesics energy', color='green')
axes[2].set_title('All geodesics energy')
axes[2].set_xlabel(f'All iterations: {num_geod_iter} inner per outer loop iter')
axes[2].legend()

# Adjust layout
plt.tight_layout()
plt.show()

In [None]:
# Plot the final curves with memberships. `segments` is a Parameter that requires grad,
# so detach it as every other plot_octopus call in this notebook does
# (tensors requiring grad cannot be converted to numpy for plotting).
ricci_regularization.plot_octopus(segments.detach(), memberships=memberships, meaningful_geodesics=meaningful_geodesics)

# Creating a GIF

In [None]:
# Assemble the PNG snapshots in shots_folder_adress (sorted by file name) into an animated GIF.
# NOTE(review): lexicographic sorting only gives the right frame order if the snapshot
# file names are zero-padded — confirm against plot_octopus's naming scheme.
png_files = sorted(f for f in os.listdir(shots_folder_adress) if f.endswith('.png'))
if not png_files:
    print("No PNG images found in", shots_folder_adress, "- GIF not created.")
else:
    images = [imageio.imread(os.path.join(shots_folder_adress, file)) for file in png_files]

    # Create the GIF next to the snapshots
    output_gif = "output_animation.gif"
    output_gif_path = os.path.join(shots_folder_adress, output_gif)
    # NOTE(review): the unit of `duration` (seconds vs milliseconds) depends on the
    # imageio version/plugin — confirm the intended frame speed.
    imageio.mimsave(output_gif_path, images, duration=0.001)

    print(f"GIF created and saved as {output_gif_path}")

# Video

In [None]:
import cv2
import os

In [None]:
# Source folder of the PNG frames and file name of the video to produce
images_folder = shots_folder_adress
output_video = "output_video.avi"
frame_rate = 30  # frames per second

# PNG frame file names, sorted by name
images = sorted(name for name in os.listdir(images_folder) if name.endswith(".png"))
if len(images) == 0:
    raise ValueError("No PNG images found in the specified directory.")


In [None]:

# Write the PNG frames listed in `images` into a video; the first frame fixes the size.
first_image_path = os.path.join(images_folder, images[0])
frame = cv2.imread(first_image_path)
if frame is None:  # cv2.imread returns None instead of raising on failure
    raise ValueError(f"Could not read image: {first_image_path}")
height, width, layers = frame.shape

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'XVID')  # You can use other codecs like 'mp4v'
video = cv2.VideoWriter(output_video, fourcc, frame_rate, (width, height))
try:
    for image in images:
        img_path = os.path.join(images_folder, image)
        frame = cv2.imread(img_path)
        if frame is None:
            print(f"Skipping unreadable image: {img_path}")
            continue
        # VideoWriter silently drops frames whose size differs from the declared one
        if frame.shape[:2] != (height, width):
            frame = cv2.resize(frame, (width, height))
        video.write(frame)
finally:
    video.release()
# No OpenCV windows are opened here, so cv2.destroyAllWindows() is not needed
# (it also raises on headless OpenCV builds).
print(f"Video saved as {output_video}")


In [None]:
# Mean latent distance each center moved from its initial position. Centers are the
# last point of the geodesics; row 0 is used (the FM update shifts that point by the
# same amount in every row of a cluster).
average_cluster_center_shift_norm = (segments[0,:,-1,:] - init_segments[0,:,-1,:]).detach().norm(dim = 1).mean()
print("Average center shift:", average_cluster_center_shift_norm.item())

# Recentering local charts

In [None]:
# A single test center (shape (1, 2)) used to try the local-chart recentering below.
# The commented rows are other candidate centers kept for reference.
new_centers = torch.tensor([[ 2., -2.5]])
#        [-2.5458,  2.2106],
#        [-1.0967,  2.2219]])

In [None]:
# Straight segments from every encoded point to the test center. Use the notebook-wide
# `m` (number of points per segment) rather than a duplicated literal.
new_segments = ricci_regularization.connect_centers2encoded_data_with_segments(encoded_points, centers=new_centers, num_aux_points=m)

In [None]:
# Plot the segments towards the test center (before any chart shift)
ricci_regularization.plot_octopus(segments=new_segments)#, xlim=None, ylim=None)

In [None]:
def compute_updated_segments(segments):
    """Redraw every straight segment in the local chart of its end point (center).

    Latent coordinates are 2*pi-periodic in every coordinate. The check is done per
    (data point, cluster) pair and per coordinate: if the start point is more than
    pi away from the end point, the start point is moved by +-2*pi towards it. Each
    segment is then re-interpolated linearly between its (possibly shifted) start
    point and its unchanged end point.

    Args:
        segments: tensor of shape (N, K, m, d). N is data points, K is clusters,
            m is points per segment, d is latent dimension. Only the first (index 0)
            and last (index -1) points along dim 2 are read.

    Returns:
        Tensor of shape (N, K, m, d) on the same device as ``segments``.
    """
    m = segments.shape[2]  # number of points on each segment

    start_points = segments[:, :, 0, :]  # (N, K, d)
    end_points = segments[:, :, -1, :]  # (N, K, d)

    # Shift each coordinate by +-2*pi (towards the end point) where start and end are
    # further apart than pi. This is done per (point, cluster) pair, so every cluster
    # keeps its own shifted start point.
    delta = end_points - start_points
    shift = torch.where(delta.abs() > torch.pi,
                        torch.sgn(delta) * 2 * torch.pi,
                        torch.zeros_like(delta))
    shifted_start = (start_points + shift).unsqueeze(2)  # (N, K, 1, d)
    ends = end_points.unsqueeze(2)  # (N, K, 1, d)

    # Interpolation parameters t in [0, 1], broadcast over (N, K, ., d)
    t = torch.linspace(0, 1, steps=m, dtype=segments.dtype, device=segments.device).view(1, 1, m, 1)
    return shifted_start + t * (ends - shifted_start)  # (N, K, m, d)
def mod_pi(segments):  # plotting helper only; proper local-chart handling is still TODO
    """Wrap every coordinate into the chart [-pi, pi) of the 2*pi-periodic torus."""
    period = 2 * torch.pi
    # Tensor `%` is torch.remainder, which returns a result with the divisor's sign.
    return (segments + torch.pi) % period - torch.pi

In [None]:
# Move each segment's start point into its center's chart (detached: used only for plots/energy)
updated_segments = compute_updated_segments(new_segments).detach()
#updated_segments_mod_pi = mod_pi(updated_segments)

In [None]:
# Segment start points before and after being moved into the centers' charts
before_x, before_y = new_segments[:, :, 0, 0], new_segments[:, :, 0, 1]
after_x, after_y = updated_segments[:, :, 0, 0], updated_segments[:, :, 0, 1]
plt.scatter(before_x, before_y, c='green', zorder=10, label="before shift")
plt.scatter(after_x, after_y, c='magenta', s=100, label="after shift")
# TODO: draw an arrow from each original start point to its shifted position
limit = 2 * torch.pi
plt.xlim(-limit, limit)
plt.ylim(-limit, limit)
plt.legend()
plt.show()


In [None]:
# Unshifted segments, shown on the doubled window [-2*pi, 2*pi]^2
ricci_regularization.plot_octopus(new_segments, xlim=2*torch.pi,ylim=2*torch.pi)
# TODO: add a grid plotting option

In [None]:
# Re-charted segments, shown on the doubled window [-2*pi, 2*pi]^2
ricci_regularization.plot_octopus(updated_segments,xlim=2*torch.pi, ylim=2*torch.pi)

In [None]:
# Total energy of the segments before re-charting
compute_energy(new_segments)

In [None]:
# Total energy of the segments after re-charting
compute_energy(updated_segments)

In [None]:
#ricci_regularization.plot_octopus(updated_segments_mod_pi,xlim= torch.pi, ylim=torch.pi)

In [None]:
#compute_energy(updated_segments_mod_pi)

In [None]:
# Nine translated copies of the re-charted segments: every combination of
# {-2*pi, 0, 2*pi} in x (outer) and y (inner), to visualize the periodic picture.
shift_array = [ -2 * torch.pi, 0., 2 * torch.pi]
segments_array = [
    updated_segments + shift_x * torch.tensor([1., 0.]) + shift_y * torch.tensor([0., 1.])
    for shift_x in shift_array
    for shift_y in shift_array
]

In [None]:
# Draw the nine periodic copies, then crop to the fundamental chart [-pi, pi]^2.
# The loop variables do not reuse the names `segments`/`centers`, which would
# overwrite the notebook-wide optimized segments and centers.
for shifted_segments in segments_array:
    # Use the copy's own sizes: these segments have as many clusters as new_centers
    # rows, which need not equal the global K.
    num_points, num_clusters = shifted_segments.shape[0], shifted_segments.shape[1]
    for i in range(num_points):
        for j in range(num_clusters):
            plt.plot(shifted_segments[i, j, :, 0], shifted_segments[i, j, :, 1], '-', marker='o', c="orange", markersize=3)
    # Centers are the last points of the geodesics; data points are the first points.
    shifted_centers = shifted_segments[0, :, -1, :]
    plt.scatter(shifted_centers[:, 0], shifted_centers[:, 1], c="red", label="centers", marker='*', edgecolor='black', s=170, zorder=10)
    plt.scatter(shifted_segments[:, 0, 0, 0], shifted_segments[:, 0, 0, 1], c="green", label="data points", marker='o', s=30, zorder=10)
plt.xlim(-torch.pi, torch.pi)
plt.ylim(-torch.pi, torch.pi)
plt.show()

# version 0 (with periodicity)

In [None]:
# Hyperparameters of version 0 (with periodicity)
beta = 0.5e-3  # Frechet mean learning rate
learning_rate = 1e-3
num_iter_outer = 1
num_geod_iter = 20  # geodesic refinement iterations per Frechet-mean update

# Fresh per-run state
memberships = meaningful_geodesics = None
loss_history, meaningful_geodesics_loss_history, norm_FM_grad_history = [], [], []

# Straight initial segments between every data point and every initial center
centers = ricci_regularization.initialize_centers(encoded_points, K, N)
segments = ricci_regularization.connect_centers2encoded_data_with_segments(encoded_points, centers, m)
init_segments = segments.clone()  # untouched copy of the initial segments

# The curve points are the optimized parameters
segments = torch.nn.Parameter(segments)
optimizer = torch.optim.SGD([segments], lr=learning_rate)
ricci_regularization.plot_octopus(segments.detach())

In [None]:
# Outer loop: alternate geodesic refinement (inner loop) with a Frechet-mean step on the centers
for iter_outer in range(num_iter_outer):
    # TODO: re-express the segments in the centers' local charts to account for periodicity
    #segments = compute_updated_segments(segments)
    # Inner loop (refining geodesics with both endpoints held fixed)
    for iter_inner in range(num_geod_iter):

        optimizer.zero_grad()  # Zero gradients

        # Compute the loss: total discrete energy of all decoded curves
        loss_geodesics = compute_energy(points_on_geodesics=segments, decoder=torus_ae.decoder_torus)

        # Backpropagation: compute gradients
        loss_geodesics.backward()

        # Zero out gradients for the first and last points (don't want them updated)
        segments.grad[:, :, 0, :] = 0.  # First points along 'geodesics' (data points)
        segments.grad[:, :, -1, :] = 0.  # Last points along 'geodesics' (centers)

        # Update parameters
        optimizer.step()

        # Store the loss value
        loss_history.append(loss_geodesics.item())
        #ricci_regularization.plot_octopus(segments,memberships=memberships,meaningful_geodesics=meaningful_geodesics,
        #             saving_adress=shots_folder_adress, iter=iter_outer*num_geod_iter + iter_inner,silent=True)
        #print(f"Iteration #{iter_inner + 1}, loss: {loss_geodesics.item():.3f}")

    # Membership and Frechet-mean updates; no gradients are needed here.
    with torch.no_grad():
        lengths_of_geod = compute_lengths(segments, torus_ae.decoder_torus)  # shape (N, K)
        memberships = torch.argmin(lengths_of_geod, dim=1)  # nearest center per point, shape (N,)

        batch_indices = torch.arange(N)
        # Geodesic from each point to its own center, shape (N, m, d)
        meaningful_geodesics = segments[batch_indices, memberships, :, :].detach()

        # saving the lengths of meaningful geodesics
        meaningful_geodesics_lengths = compute_lengths(meaningful_geodesics)
        meaningful_geodesics_loss_history.append(meaningful_geodesics_lengths.mean().item())

        # Unit direction of the last step of each meaningful geodesic, shape (N, d).
        # The clamp avoids 0/0 when the last step has collapsed.
        v = meaningful_geodesics[:, -1, :] - meaningful_geodesics[:, -2, :]
        v = v / v.norm(dim=1, keepdim=True).clamp_min(1e-12)

        # Update cluster centers with weight beta.
        # Reset once per outer iteration (not per cluster) so all K clusters are averaged.
        average_FM_grad_norm = 0.
        for i in range(K):
            cluster_mask = memberships == i
            v_i = v[cluster_mask]
            l_i = lengths_of_geod[cluster_mask, i]  # lengths to the cluster's OWN center
            FM_gradient = torch.sum(l_i.unsqueeze(-1) * v_i, dim=0)  # shape (d,)
            # Move the i-th center (the last point on each of its geodesics)
            segments[:, i, -1, :] += - beta * FM_gradient
            average_FM_grad_norm += FM_gradient.norm().item() / K
    norm_FM_grad_history.append(average_FM_grad_norm)

In [None]:
# Training curves of the version 0 run: one panel per recorded history
fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns

# Average (over clusters) norm of the Frechet-mean gradient, per outer iteration
axes[0].plot(norm_FM_grad_history, marker='o', label='FM Grad History')
axes[0].set_title('Average shift of centers (Fréchet mean gradient norm)')
axes[0].set_xlabel('Outer loop iterations')
axes[0].set_ylabel('Gradient norm')
axes[0].legend()

# Mean length of the geodesics from each point to its assigned center
axes[1].plot(meaningful_geodesics_loss_history, marker='o', label='Geodesics Loss History', color='orange')
axes[1].set_title('Meaningful geodesics length')
axes[1].set_xlabel('Outer loop iterations')
axes[1].legend()

# loss_history holds compute_energy(...) values: summed squared-step energy of all curves
axes[2].plot(loss_history, label='All geodesics energy', color='green')
axes[2].set_title('All geodesics energy')
axes[2].set_xlabel(f'All iterations: {num_geod_iter} inner per outer loop iter')
axes[2].legend()

# Adjust layout
plt.tight_layout()
plt.show()

# Experimental helpers (work in progress; not used by the cells above)

In [None]:
def set_optimizer_parameters(learning_rate, max_iterations):
    """
    Bundle the optimizer settings into a dictionary and print it.

    Parameters:
        learning_rate (float): Step size for gradient descent.
        max_iterations (int): Upper bound on the number of optimization iterations.

    Returns:
        dict: {'learning_rate': learning_rate, 'max_iterations': max_iterations}.
    """
    # Built as a comprehension rather than via `dict(...)`, because a notebook cell
    # may rebind the name `dict` at module level.
    settings = (('learning_rate', learning_rate), ('max_iterations', max_iterations))
    optimizer_params = {key: value for key, value in settings}
    print(f"Optimizer parameters set: {optimizer_params}")
    return optimizer_params

def update_frechet_mean(data, memberships, K):
    """
    Update the cluster centers as the mean of their member points.

    NOTE: this is the Euclidean mean of `data` (the Frechet mean for the flat metric),
    not a Riemannian Frechet mean. A cluster with no members gets an all-zero center.

    Parameters:
        data (torch.Tensor): Data points, shape (n_samples, n_features).
        memberships (torch.Tensor): Cluster index of each point, shape (n_samples,).
        K (int): Number of clusters.

    Returns:
        torch.Tensor: Updated cluster centers, shape (K, n_features), on data's device.
        The dtype is data's dtype for floating-point input, else the default float dtype.
    """
    _, n_features = data.shape  # also enforces 2-D input
    # Follow the input's device and floating dtype. Always using CPU float32 loses
    # float64 precision and fails for data on another device.
    dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
    updated_centers = torch.zeros((K, n_features), dtype=dtype, device=data.device)

    for k in range(K):
        cluster_points = data[memberships == k]  # points assigned to cluster k
        if cluster_points.size(0) > 0:
            # cast first: mean() is undefined for integer tensors
            updated_centers[k] = cluster_points.to(dtype).mean(dim=0)

    print(f"Updated Frechet means: {updated_centers}")
    return updated_centers

def geodesic_update(data, centers, memberships, learning_rate):
    """
    Move every data point a fraction `learning_rate` of the way towards every center.

    This is a placeholder "geodesic" step: a straight-line interpolation
    data[i] + learning_rate * (centers[k] - data[i]) for every pair (i, k).

    Parameters:
        data (torch.Tensor): Data points, shape (n_samples, n_features).
        centers (torch.Tensor): Current cluster centers, shape (K, n_features).
        memberships (torch.Tensor): Membership array, shape (n_samples,). Currently
            unused; kept for interface compatibility.
        learning_rate (float): Fraction of the way to move towards each center.

    Returns:
        torch.Tensor: Shape (n_samples, K, n_features). The dtype follows PyTorch
        type promotion of the inputs, instead of always float32.

    Raises:
        ValueError: if `data` is not 2-D.
    """
    if data.dim() != 2:
        raise ValueError(f"data must be 2-D (n_samples, n_features), got shape {tuple(data.shape)}")

    # Broadcast (n, 1, f) against (1, K, f) instead of a Python double loop.
    starts = data.unsqueeze(1)
    geodesics = starts + learning_rate * (centers.unsqueeze(0) - starts)

    print(f"Updated geodesics: {geodesics}")
    return geodesics

In [None]:
def initialize_geodesics(data, centers, m=20):
    """
    Initialize straight-line curves from every data point to every cluster center.

    Parameters:
        data (torch.Tensor): Data points, shape (n_samples, n_features).
        centers (torch.Tensor): Cluster centers, shape (K, n_features).
        m (int): Number of points on each curve, endpoints included. With m == 1
            the single point is the data point itself.

    Returns:
        torch.Tensor: Points on the curves, shape (n_samples, K, m, n_features).
        Point j is (1 - t_j) * data[i] + t_j * centers[k] with t_j = j / (m - 1).
        The result uses the promoted floating dtype of the inputs.
    """
    n_samples, n_features = data.shape  # also enforces 2-D input

    # Promoted floating dtype of the inputs (integer inputs fall back to the default float)
    dtype = torch.promote_types(data.dtype, centers.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    start = data.to(dtype)[:, None, None, :]  # (n, 1, 1, f)
    end = centers.to(dtype)[None, :, None, :]  # (1, K, 1, f)
    # max(..., 1) avoids the division by zero of j / (m - 1) when m == 1
    t = (torch.arange(m, dtype=dtype, device=data.device) / max(m - 1, 1)).view(1, 1, m, 1)
    geodesics = (1 - t) * start + t * end  # (n, K, m, f)

    print(f"Geodesics shape: {geodesics.shape}")
    return geodesics

In [None]:
# Disabled usage example for the helpers above, kept as an inert string literal.
# NOTE(review): it looks out of date and should be checked before re-enabling.
# - It unpacks a bare `initialize_centers(data, K)` into two values. This notebook
#   instead calls ricci_regularization.initialize_centers(points, K, N) and uses
#   the result as a single tensor.
# - torch.rand(10, 2) creates 2-D points, not 3-D as its inline comment says.
"""
# Example data
data = torch.rand(10, 2)  # 10 points in 3D
K = 3  # Number of clusters
learning_rate = 0.01
max_iterations = 100

# Initialization
centers, probabilities = initialize_centers(data, K)

# Optimizer parameters
optimizer_params = set_optimizer_parameters(learning_rate, max_iterations)

# Dummy memberships (random assignment for initialization)
memberships = torch.randint(0, K, (data.shape[0],))

# Geodesic update
geodesics = geodesic_update(data, centers, memberships, optimizer_params['learning_rate'])

# Frechet mean update
updated_centers = update_frechet_mean(data, memberships, K)
"""