# Imports

In [None]:
import torch, yaml, os
import ricci_regularization
import matplotlib.pyplot as plt
import cv2 #to make videos

# I. Data loading

In [None]:
# Experiment configuration. Alternatives kept for quick switching:
#   '../experiments/Synthetic_Setting_1/Synthetic_Setting_1_config.yaml'
#   '../experiments/Swissroll_exp5_config.yaml'
Path_experiment = '../../experiments/MNIST_Setting_3_config.yaml'
with open(Path_experiment, 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Load data loaders based on YAML configuration.
# Stored as `loaders` (not `dict`) so the built-in `dict` is not shadowed for the rest of the notebook.
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
print("Experiment results loaded successfully.")
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
# 'test_dataset' is read with .get(), so it may be missing; the encoding cell below needs `test_dataset.data`,
# so fail here with a clear message instead of an AttributeError on None later.
test_dataset = loaders.get("test_dataset")
if test_dataset is None:
    raise KeyError("get_dataloaders did not return a 'test_dataset'; it is required to encode samples below.")

print("Data loaders created successfully.")

In [None]:
# Load the tuned torus autoencoder described by the config, together with the path of its weights.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(
    config=yaml_config,
    additional_path='../',
)
print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

In [None]:
# Folder for the PNG frames of the optimization (they are assembled into a video further below).
shots_folder_name = "/generated_pics"
shots_folder_adress = "".join(['../../experiments/', yaml_config["experiment"]["name"], shots_folder_name])
folder_is_missing = not os.path.exists(shots_folder_adress)
if folder_is_missing:
    os.makedirs(shots_folder_adress)
    print("A folder created for saved images to create a gif at:", shots_folder_adress)

In [None]:
# Run settings
batch_size = 128
K = 2   # number of clusters (previously 3)
data = test_dataset.data
N = 10  # number of samples to cluster (previously 15; could be len(data))
m = 10  # points per curve, both endpoints included (m - 2 interior points)
D = yaml_config["architecture"]["input_dim"]
d = yaml_config["architecture"]["latent_dim"]

# Keep only the first N samples, in their original order
subset_indices = list(range(N))
mnist_subset = torch.utils.data.Subset(data, subset_indices)
dataloader = torch.utils.data.DataLoader(mnist_subset, batch_size=batch_size, shuffle=False)

# The autoencoder is moved to CPU and switched to evaluation mode before encoding
torus_ae.cpu()
torus_ae.eval()

# Flatten each batch to D features, cast to float32 and encode; gradients are not needed here
with torch.no_grad():
    encoded_points = torch.cat(
        [torus_ae.encoder2lifting(batch.reshape(-1, D).to(torch.float32)) for batch in dataloader]
    )

In [None]:
# Pick initial cluster centers and show them on top of the encoded data
centers = ricci_regularization.initialize_centers(encoded_points, K, N)

latent_x, latent_y = encoded_points[:, 0], encoded_points[:, 1]
plt.scatter(latent_x, latent_y, label="encoded data")
plt.scatter(centers[:, 0], centers[:, 1], c="red", label="centers", marker='*', s=60)
# Plot window: [-pi, pi] in both latent coordinates
plt.xlim(-torch.pi, torch.pi)
plt.ylim(-torch.pi, torch.pi)
plt.legend()
plt.show()

# Loss function and optimizer parameters setting

In [None]:

def construct_interpolation_points_on_segments_connecting_centers2encoded_data(starting_points, final_points, num_aux_points =10):
    """
    Build the interior points of straight segments joining every starting point to every final point.

    Each segment is sampled at `num_aux_points` equally spaced parameter values in [0, 1]
    (both endpoints included). The two endpoints are then dropped, so every segment contributes
    `num_aux_points - 2` interior points. These interior points are the trainable parameters
    of the curves; the endpoints are supplied separately when a curve is assembled.

    Args:
        starting_points (torch.Tensor): shape (num_starting_points, latent_dim), e.g. encoded data.
        final_points (torch.Tensor): shape (num_final_points, latent_dim), e.g. cluster centers.
        num_aux_points (int): number of sample points per segment, INCLUDING both endpoints.

    Returns:
        torch.Tensor: shape (num_starting_points, num_final_points, num_aux_points - 2, latent_dim).
        The third dimension is empty when num_aux_points <= 2.

    Raises:
        ValueError: if the last (latent) dimensions of the two inputs differ.
    """
    if starting_points.shape[-1] != final_points.shape[-1]:
        raise ValueError(
            f"Mismatch in dimensions: 'starting_points' and 'final_points' must have the same final dimension. "
            f"Got starting_points with shape {starting_points.shape}, final_points with shape {final_points.shape}. "
        )

    # Interpolation parameters in [0, 1], created directly on the inputs' device. Shape: (1, 1, num_aux_points, 1)
    t = torch.linspace(0, 1, steps=num_aux_points, device=starting_points.device).view(1, 1, num_aux_points, 1)

    # Broadcastable views of the endpoints
    starting_points_expanded = starting_points.unsqueeze(1).unsqueeze(2)  # (num_starting_points, 1, 1, latent_dim)
    final_points_expanded = final_points.unsqueeze(0).unsqueeze(2)        # (1, num_final_points, 1, latent_dim)

    # All sample points of all segments. Shape: (num_starting_points, num_final_points, num_aux_points, latent_dim)
    all_points_on_segments = starting_points_expanded + t * (final_points_expanded - starting_points_expanded)

    # Cut off the starting and the final point of every segment
    return all_points_on_segments[:, :, 1:-1, :]

def geodesics_from_parameters_interpolation_points(parameters_of_geodesics, end_points):
    """
    Assemble complete polygonal curves from their interior points and their two endpoints.

    Parameters:
    - parameters_of_geodesics (torch.Tensor): interior points, shape
      (num_starting_points, num_clusters, num_interpolation_points, latent_dim).
    - end_points (list of torch.Tensor): [starting_points, final_points] with shapes
      (num_starting_points, latent_dim) (usually encoded data) and
      (num_clusters, latent_dim) (usually cluster centers).

    Returns:
    - torch.Tensor: shape (num_starting_points, num_clusters, num_interpolation_points + 2, latent_dim).
      Along dim 2, index 0 is the starting point and the last index is the final point.
    """
    n_start, n_clusters, _, dim = parameters_of_geodesics.shape
    starts, finals = end_points

    # Every curve (i, j) gets starts[i] prepended and finals[j] appended
    endpoint_shape = (n_start, n_clusters, 1, dim)
    first_points = starts[:, None, None].expand(endpoint_shape)
    last_points = finals[None, :, None].expand(endpoint_shape)

    return torch.cat([first_points, parameters_of_geodesics, last_points], dim=2)

In [None]:
# Initial interior points: straight segments from every encoded point to every center (m points each)
parameters_of_geodesics = construct_interpolation_points_on_segments_connecting_centers2encoded_data(
    starting_points=encoded_points, final_points=centers, num_aux_points=m
)

In [None]:
# Full initial curves: encoded point + interior points + center
geodesics = geodesics_from_parameters_interpolation_points(
    parameters_of_geodesics,
    end_points=[encoded_points, centers],
)

In [None]:
# Plot the initial curves and save the frame with suffix 0 into the shots folder
ricci_regularization.plot_octopus(
    geodesics,
    saving_folder=shots_folder_adress,
    suffix=0,
    verbose=True,
)
# NOTE(review): author's naming idea — 'silent' -> 'show_plot'; 'verbose' as in a -v flag.

In [None]:
# TODO: rewrite as compute_energy(parameters_of_geodesics, end_points, decoder) — keep the parameters separate from the end points, because the former are updated and the latter are not
# geodesics_from_parameters_interpolation_points(parameters_of_geodesics, end_points)
# geodesics_from_parameters_schauder(parameters_of_geodesics, end_points)
# pass the geodesic parametrization mode (interpolating_points or schauder) as a parameter to all these functions
def compute_energy(parameters_of_geodesics, end_points, decoder=None):
    """
    Total discrete Euclidean energy in R^D of all decoded curves.

    The curves are assembled from the interior points `parameters_of_geodesics` and the
    fixed `end_points`. They are kept separate because only the former are optimized.
    The curves are then decoded, and for each curve the squared distances between
    consecutive decoded points are summed.

    Args:
        parameters_of_geodesics: interior points, shape
            (num_starting_points, num_clusters, num_interpolation_points, latent_dim).
        end_points: [starting_points, final_points]
            (see geodesics_from_parameters_interpolation_points).
        decoder: callable applied to the assembled curves. Defaults to `torus_ae.decoder_torus`,
            looked up at CALL time so a reloaded `torus_ae` is picked up.
            NOTE(review): the indexing below assumes the decoder keeps the first three
            dimensions of its input — confirm.

    Returns:
        torch.Tensor: a single scalar — the energies of ALL curves summed together.
    """
    if decoder is None:
        # Resolved here rather than as a default argument: a default is frozen at definition
        # time and would keep using a stale decoder after the autoencoder is reloaded.
        decoder = torus_ae.decoder_torus
    # TODO: add an option for a Schauder-basis parametrization of the curves
    points_on_geodesics = geodesics_from_parameters_interpolation_points(parameters_of_geodesics, end_points)
    decoded_points = decoder(points_on_geodesics)
    # Differences between consecutive decoded points along each curve
    steps = decoded_points[:, :, 1:, :] - decoded_points[:, :, :-1, :]
    # All curves are summed into one scalar, so a single backward() updates every curve at once
    return (steps ** 2).sum()

def compute_lengths(parameters_of_geodesics, end_points, decoder=None):
    """
    Euclidean length in R^D of every decoded curve.

    The curves are assembled from the interior points and the end points, decoded, and the
    length of each polygonal curve is computed as the sum of the norms of its segments.
    The previous implementation returned sqrt(sum of squared segment lengths), i.e.
    sqrt(energy). That underestimates the length, by a factor sqrt(num_segments) for a
    uniformly sampled curve.

    Args:
        parameters_of_geodesics: interior points, shape
            (num_starting_points, num_clusters, num_interpolation_points, latent_dim).
        end_points: [starting_points, final_points]
            (see geodesics_from_parameters_interpolation_points).
        decoder: callable applied to the assembled curves. Defaults to `torus_ae.decoder_torus`,
            looked up at call time.
            NOTE(review): the indexing below assumes the decoder keeps the first three
            dimensions of its input — confirm.

    Returns:
        torch.Tensor: shape (num_starting_points, num_clusters), the length of each curve.
    """
    if decoder is None:
        # Resolved at call time so that a reloaded `torus_ae` is used (a default argument would be frozen).
        decoder = torus_ae.decoder_torus
    # TODO: add an option for a Schauder-basis parametrization of the curves
    points_on_geodesics = geodesics_from_parameters_interpolation_points(parameters_of_geodesics, end_points)
    decoded_points = decoder(points_on_geodesics)
    tangent_vectors = decoded_points[:, :, 1:, :] - decoded_points[:, :, :-1, :]
    # Length of a polygonal curve = sum over its segments of the Euclidean segment length
    return tangent_vectors.norm(dim=-1).sum(dim=-1)
# Energy of the initial straight-segment curves (one scalar summed over all curves)
loss_geodesics = compute_energy(
    parameters_of_geodesics,
    end_points=[encoded_points, centers],
    decoder=torus_ae.decoder_torus,
)
loss_geodesics

# Version 0: without local chart updates

In [None]:
# Optimization hyper-parameters
beta = 1.e-3            # step size of the Frechet mean (cluster center) update
learning_rate = 1.e-3   # SGD step size of the geodesic refinement
num_iter_outer = 100    # number of cluster-center updates (previously 10)
num_iter_inner = 15     # geodesic refinement iterations per center update

# State produced by the training loop (empty until the first outer iteration)
cluster_index_of_each_point = None
meaningful_geodesics = None

# Per-outer-iteration histories
meaningful_geodesics_loss_history = []
meaningful_geodesics_loss_history_by_cluster = []
norm_Frechet_mean_gradient_history = []

# Initial centers stay untouched (compared against at the end); new_centers is the copy being updated
centers = ricci_regularization.initialize_centers(encoded_points, K, N)
new_centers = centers.clone()

# Interior points of straight segments from each encoded point to each center, plus a snapshot of them
parameters_of_geodesics = construct_interpolation_points_on_segments_connecting_centers2encoded_data(encoded_points, centers, num_aux_points=m)
init_parameters = parameters_of_geodesics.clone()

# The interior points become the trainable parameter, optimized with plain SGD.
# NOTE(review): later cells read `parameters_of_geodesics` after training, which relies on
# nn.Parameter sharing storage with this tensor — confirm.
parameters = torch.nn.Parameter(parameters_of_geodesics)
optimizer = torch.optim.SGD([parameters], lr=learning_rate)

In [None]:
# Curves before optimization, plotted without gradient tracking
points_on_geodesics = geodesics_from_parameters_interpolation_points(
    parameters_of_geodesics,
    end_points=[encoded_points, centers],
)
ricci_regularization.plot_octopus(points_on_geodesics.detach())

In [None]:
# TODO: add a stopping criterion, e.g. based on the change (delta) of the energy between iterations

In [None]:
# Outer loop: alternate between refining the curves (inner loop) and moving the cluster centers.
for iter_outer in range(num_iter_outer):
    # Inner loop: refine the interior points of all curves by SGD on their total energy.
    for iter_inner in range(num_iter_inner):
        optimizer.zero_grad()
        loss_geodesics = compute_energy(parameters_of_geodesics=parameters, end_points=[encoded_points, new_centers], decoder=torus_ae.decoder_torus)
        loss_geodesics.backward()
        optimizer.step()
        # Saving a frame per iteration (used to make the video afterwards) is disabled because it slows training:
        #ricci_regularization.plot_octopus(points_on_geodesics,memberships=cluster_index_of_each_point,meaningful_geodesics=meaningful_geodesics,
        #             saving_folder=shots_folder_adress, suffix=iter_outer*num_iter_inner + iter_inner,verbose=False)

    # The rest of the iteration only reads the refined curves: no gradients are needed.
    with torch.no_grad():
        # Lengths of all curves from the optimized parameters, shape (N, K)
        lengths_of_geodesics = compute_lengths(parameters_of_geodesics=parameters, end_points=[encoded_points, new_centers], decoder=torus_ae.decoder_torus)

        # Membership: index of the center with the shortest curve, shape (N,)
        cluster_index_of_each_point = torch.argmin(lengths_of_geodesics, dim=1)

        # Length of the curve to the ASSIGNED center of each point, shape (N,)
        meaningful_geodesics_lengths = lengths_of_geodesics.gather(1, cluster_index_of_each_point.unsqueeze(1)).squeeze(1)

        # Curves from each point to its assigned center, shape (N, m, d)
        points_on_geodesics = geodesics_from_parameters_interpolation_points(parameters, end_points=[encoded_points, new_centers])
        batch_indices = torch.arange(points_on_geodesics.shape[0])
        meaningful_geodesics = points_on_geodesics[batch_indices, cluster_index_of_each_point]

        # Unit direction of the last segment of each assigned curve (arriving at the center), shape (N, d).
        # normalize() clamps the norm, so a zero-length last segment gives 0 instead of NaN.
        v = torch.nn.functional.normalize(meaningful_geodesics[:, -1, :] - meaningful_geodesics[:, -2, :], dim=1)

        # Frechet mean gradient per cluster: sum over its members of (length to the assigned center) * v.
        # (The previous version weighted every point by its length to cluster 0.)
        weighted_v = meaningful_geodesics_lengths.unsqueeze(-1) * v  # (N, d)
        one_hot_clusters = torch.nn.functional.one_hot(cluster_index_of_each_point, num_classes=K).to(weighted_v.dtype)  # (N, K)
        Frechet_mean_gradient = one_hot_clusters.T @ weighted_v  # (K, d)

        # Update all centers simultaneously
        new_centers -= beta * Frechet_mean_gradient

    # Average Frechet mean gradient norm over the clusters
    average_Frechet_mean_gradient_norm = Frechet_mean_gradient.norm(dim=1).mean().item()
    norm_Frechet_mean_gradient_history.append(average_Frechet_mean_gradient_norm)

    # Total length of all assigned ("meaningful") curves
    meaningful_geodesics_loss_history.append(meaningful_geodesics_lengths.sum().item())

    # Total length of the assigned curves of each cluster, stored with shape (1, K)
    total_length_of_meaningful_geodesics_by_cluster = torch.zeros(K, dtype=meaningful_geodesics_lengths.dtype, device=meaningful_geodesics_lengths.device)
    total_length_of_meaningful_geodesics_by_cluster.scatter_add_(0, cluster_index_of_each_point, meaningful_geodesics_lengths)
    meaningful_geodesics_loss_history_by_cluster.append(total_length_of_meaningful_geodesics_by_cluster.unsqueeze(0))

# Plotting losses

In [None]:
# TODO
# two curves in one plot:
# 1. meaningful geodesics length for cluster 1
# 2. meaningful geodesics length for cluster 2
# 3. all meaningful geodesics
# !!! add a plot of the conditional variance

In [None]:
# Two panels: (left) average Frechet mean gradient norm, (right) lengths of the assigned ("meaningful") curves.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # 1 row, 2 columns; the width does not depend on K

# Left panel: average norm of the Frechet mean gradient per outer iteration
axes[0].plot(norm_Frechet_mean_gradient_history, marker='o', markersize=3, label='Frechet mean update history')
axes[0].set_title('Average shift of centers (proxy of Fréchet mean gradient norm)')
axes[0].set_xlabel('Outer loop iterations')
axes[0].set_ylabel('Average gradient norm')
axes[0].legend()

# Right panel: per-cluster totals (one color per cluster from the 'jet' colormap) and the overall total
colors = plt.cm.jet(torch.linspace(0, 1, K))
if meaningful_geodesics_loss_history_by_cluster:  # torch.cat fails on an empty list (no outer iterations run)
    lengths_of_meaningful_geodesics_concatenated = torch.cat(meaningful_geodesics_loss_history_by_cluster, dim=0).detach()  # (iterations, K)
    for i in range(K):
        axes[1].plot(lengths_of_meaningful_geodesics_concatenated[:, i], marker='o', markersize=3,
                     label=f'Cluster {i} geodesics length', color=colors[i])
axes[1].plot(meaningful_geodesics_loss_history, marker='o', markersize=3, label='All clusters geodesics length', color='green')
axes[1].set_title('Meaningful geodesics length')
axes[1].set_xlabel('Outer loop iterations')
axes[1].set_ylabel('Length')
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# Final curves, with memberships and the assigned (meaningful) curves passed for highlighting
geodesics = geodesics_from_parameters_interpolation_points(
    parameters_of_geodesics,
    end_points=[encoded_points, new_centers],
)
ricci_regularization.plot_octopus(
    geodesics.detach(),
    memberships=cluster_index_of_each_point,
    meaningful_geodesics=meaningful_geodesics,
)

In [None]:
# How far each center moved from its initial position during training
center_shifts = centers - new_centers
print("center shifts:\n", center_shifts)
average_cluster_center_shift_norm = center_shifts.detach().norm(dim=1).mean()
print("Average center's shift:", average_cluster_center_shift_norm.item())

In [None]:
# stop here: this raise deliberately halts a "run all" so the results above can be reviewed.
# NOTE(review): the cells below (video export, local-chart experiments) are presumably run by hand.
raise Exception("Stopping point: Review output before proceeding.")

# Video

In [None]:
# Frames are the PNG files of the shots folder; the video is written to the working directory
images_folder = shots_folder_adress
output_video = "output_video.avi"

# Video settings. Frames are ordered by file name.
# NOTE(review): plain string sorting puts e.g. "10" before "2" unless suffixes are zero-padded — confirm.
frame_rate = 30
images = sorted(name for name in os.listdir(images_folder) if name.endswith(".png"))
if len(images) == 0:
    raise ValueError("No PNG images found in the specified directory.")


In [None]:

# The frame size is taken from the first image; cv2.imread returns None for unreadable files.
first_image_path = os.path.join(images_folder, images[0])
frame = cv2.imread(first_image_path)
if frame is None:
    raise ValueError(f"Could not read image {first_image_path}")
height, width, layers = frame.shape

# Define the codec and create the VideoWriter ('mp4v' and other codecs can be used as well)
fourcc = cv2.VideoWriter_fourcc(*'XVID')
video = cv2.VideoWriter(output_video, fourcc, frame_rate, (width, height))
if not video.isOpened():
    raise RuntimeError(f"Could not open a video writer for {output_video}")

# Always release the writer, even if a frame fails, so the file handle is not leaked.
# (No cv2.destroyAllWindows(): no windows are opened here, and it may raise on headless OpenCV builds.)
try:
    for image in images:
        img_path = os.path.join(images_folder, image)
        frame = cv2.imread(img_path)
        if frame is None:
            print(f"Skipping unreadable image {img_path}")
            continue
        if frame.shape[:2] != (height, width):
            # VideoWriter expects every frame to match the (width, height) it was opened with
            frame = cv2.resize(frame, (width, height))
        video.write(frame)
finally:
    video.release()
print(f"Video saved as {output_video}")


# Recentering local charts (to be done)

In [None]:
# A single hand-picked center for the local-chart experiments below
# (earlier runs also used the centers [-2.5458, 2.2106] and [-1.0967, 2.2219]).
new_centers = torch.tensor([[2.0, -2.5]])

In [None]:
# Straight segments from every encoded point to the chosen center(s), 10 points each
new_segments = ricci_regularization.connect_centers2encoded_data_with_segments(
    encoded_points,
    centers=new_centers,
    num_aux_points=10,
)

In [None]:
# Plot the new segments; xlim / ylim are left at plot_octopus' defaults (could pass xlim=None, ylim=None)
ricci_regularization.plot_octopus(segments=new_segments)

In [None]:
def compute_updated_segments(segments):
    """
    Re-draw every segment so that it takes the short way around the torus.

    For each segment (i, j) and each latent coordinate: if the start and the end differ by
    more than pi, the start is shifted by +/- 2*pi in the direction of the end (into a
    neighbouring copy of the chart). Each segment is then re-interpolated linearly between
    its own (possibly shifted) start and its unchanged end, with the same number of points.
    The shift is applied per segment (i, j). The previous version computed per-(i, j) shifts
    but then reused the start of segment (i, 0) for every j.

    Args:
        segments (torch.Tensor): shape (N, K, m, d) — N starting points, K end points,
            m points per segment (endpoints included), latent dimension d.
            Only the first and last point of each segment are read.

    Returns:
        torch.Tensor: a new tensor of shape (N, K, m, d); `segments` is not modified.
    """
    m = segments.shape[2]  # points per segment
    starts = segments[:, :, 0, :]  # (N, K, d)
    ends = segments[:, :, -1, :]   # (N, K, d)
    delta = ends - starts

    # Move the start by 2*pi towards the end wherever a coordinate gap exceeds pi (vectorized over i, j, dim)
    shift = torch.where(delta.abs() > torch.pi, torch.sgn(delta) * 2 * torch.pi, torch.zeros_like(delta))
    shifted_starts = (starts + shift).unsqueeze(2)  # (N, K, 1, d)

    # Linear interpolation parameters (m values in [0, 1]) on the segments' device. Shape: (1, 1, m, 1)
    t = torch.linspace(0, 1, steps=m, device=segments.device).view(1, 1, m, 1)

    return shifted_starts + t * (ends.unsqueeze(2) - shifted_starts)  # (N, K, m, d)
def mod_pi(segments):
    """
    Wrap every coordinate back into the initial chart, i.e. into [-pi, pi) up to rounding.

    Intended only for plotting; leaving and re-entering the local chart is still to be fixed.
    """
    shifted = segments + torch.pi
    return shifted.remainder(2 * torch.pi) - torch.pi

In [None]:
# Segments adapted to the local charts, detached from any autograd graph
updated_segments = compute_updated_segments(new_segments)
updated_segments = updated_segments.detach()
# Optional: wrap back into the initial chart for plotting
#updated_segments_mod_pi = mod_pi(updated_segments)

In [None]:
# Start points of the segments before (green, drawn on top) and after (magenta) the chart shift
starts_before = new_segments[:, :, 0, :]
starts_after = updated_segments[:, :, 0, :]
plt.scatter(starts_before[..., 0], starts_before[..., 1], c='green', zorder=10, label="before shift")
plt.scatter(starts_after[..., 0], starts_after[..., 1], c='magenta', s=100, label="after shift")
# TODO: add an arrow from the non-updated to the updated segments
plt.xlim(-2 * torch.pi, 2 * torch.pi)
plt.ylim(-2 * torch.pi, 2 * torch.pi)
plt.legend()
plt.show()


In [None]:
# Segments before the chart shift, plotted with xlim / ylim = 2*pi
wide_limit = 2 * torch.pi
ricci_regularization.plot_octopus(new_segments, xlim=wide_limit, ylim=wide_limit)
# TODO: add a grid plotting option

In [None]:
# Segments after the chart shift, plotted with xlim / ylim = 2*pi
ricci_regularization.plot_octopus(
    updated_segments,
    xlim=2 * torch.pi,
    ylim=2 * torch.pi,
)

In [None]:
# compute_energy takes interior points plus separate end points (not full segments), so split them:
# curves are handled one cluster j at a time with end_points = [starts of column j, center j].
# NOTE(review): assumes every segment of column j ends at the same point (its center) — confirm.
sum(
    compute_energy(
        new_segments[:, j:j + 1, 1:-1, :],
        end_points=[new_segments[:, j, 0, :], new_segments[0, j:j + 1, -1, :]],
    )
    for j in range(new_segments.shape[1])
)

In [None]:
# compute_energy takes interior points plus separate end points (not full segments), so split them.
# Curves are handled one cluster j at a time because, after the chart shift, the start of a point's
# segment can differ between clusters.
# NOTE(review): assumes every segment of column j ends at the same point (its center) — confirm.
sum(
    compute_energy(
        updated_segments[:, j:j + 1, 1:-1, :],
        end_points=[updated_segments[:, j, 0, :], updated_segments[0, j:j + 1, -1, :]],
    )
    for j in range(updated_segments.shape[1])
)

In [None]:
#ricci_regularization.plot_octopus(updated_segments_mod_pi,xlim= torch.pi, ylim=torch.pi)

In [None]:
#compute_energy(updated_segments_mod_pi)

In [None]:
# 3 x 3 tiling: copies of the updated segments translated by every combination of -2pi, 0, 2pi in x and y
shift_array = [ -2 * torch.pi, 0., 2 * torch.pi]
segments_array = [
    updated_segments + torch.tensor([shift_x, shift_y])
    for shift_x in shift_array
    for shift_y in shift_array
]

In [None]:
# Plot all translated copies; only the window [-pi, pi]^2 is shown.
# Loop bounds come from each tensor's own shape: the global K (number of clusters) need not match the
# number of centers used here (a single center above). The centers of each copy are stored in
# `tile_centers`, so the global `centers` (initial cluster centers) is not overwritten.
for segments in segments_array:
    num_points, num_centers = segments.shape[0], segments.shape[1]
    for i in range(num_points):
        for j in range(num_centers):
            plt.plot(segments[i, j, :, 0], segments[i, j, :, 1], '-', marker='o', c="orange", markersize=3)
    # Centers are the last points of the segments; data points are the first points
    tile_centers = segments[0, :, -1, :]
    plt.scatter(tile_centers[:, 0], tile_centers[:, 1], c="red", label="centers", marker='*', edgecolor='black', s=170, zorder=10)
    plt.scatter(segments[:, 0, 0, 0], segments[:, 0, 0, 1], c="green", label="data points", marker='o', s=30, zorder=10)
plt.xlim(-torch.pi, torch.pi)
plt.ylim(-torch.pi, torch.pi)