NB! The Stochman package is required.
Install it with: pip install stochman

This notebook consists of:

0) Imports. Loading AE weights.
1) Demo: a geodesic connecting two points (geodesic bvp via Stochman).
2) Grid of geodesics in latent space (parallelized computation): setting grid parameters + plotting.
3) Logarithmic map: demo
4) Grid of geodesics in a logarithmic chart.
5) Geodesic shooting via a Runge-Kutta scheme (here the first-order one, i.e. explicit Euler): approximation of the solution of the Cauchy problem for the geodesic ODE. Shooting is used to construct exponential maps.
6) Geodesic length ratio benchmark: straight lines vs geodesics in a logarithmic chart. Average ratio between the lengths of grid geodesics and of the images, under the exponential map, of the straight lines connecting their endpoints in a logarithmic chart with a single base point.
7) Geodesic length ratio for multiple random geodesics and different log-map base points. As in 6), the length ratio is computed, now for random geodesics, each with its own random base point for the logarithmic map.
8) Fréchet mean (slow), computed via a Sturm-like method. See the description of the method in Section 3 of https://www-sop.inria.fr/asclepios/events/MFCA15/Papers/MFCA15_4_2.pdf.
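
Item 8 relies on an inductive (Sturm-type) mean: starting from $m_1 = x_1$, each update moves the current estimate the fraction $1/(k+1)$ of the way along the geodesic towards the next data point, $m_{k+1} = \gamma_{m_k \to x_{k+1}}\big(1/(k+1)\big)$. The NumPy sketch below is a hypothetical illustration, not the notebook's implementation: the function names are made up, the points are visited cyclically (Sturm's algorithm samples them at random), and `geodesic_point` stands for whatever geodesic solver is available, here simply straight segments and arcs of the unit circle. In Euclidean space, after whole passes over the data, the result is the arithmetic mean, which makes a convenient sanity check.

```python
import numpy as np


def inductive_mean(points, geodesic_point, n_passes=3):
    """Sturm-type inductive mean: m_1 = x_1, m_{k+1} = geodesic_point(m_k, x_{k+1}, 1 / (k + 1))."""
    # Visit the points cyclically, n_passes times
    sequence = [x for _ in range(n_passes) for x in points]
    mean = sequence[0]
    for k, x in enumerate(sequence[1:], start=1):
        # Move the fraction 1/(k+1) of the way from the current estimate towards x
        mean = geodesic_point(mean, x, 1.0 / (k + 1))
    return mean


def euclidean_geodesic_point(a, b, s):
    """Point at fraction s of the straight segment from a to b."""
    return a + s * (b - a)


def circle_geodesic_point(a, b, s):
    """Point at fraction s of the shorter arc from angle a to angle b on the unit circle."""
    delta = (b - a + np.pi) % (2.0 * np.pi) - np.pi
    return a + s * delta


points = np.array([[0.0, 0.0], [2.0, 0.0], [1.0, 3.0]])
print(inductive_mean(points, euclidean_geodesic_point))  # arithmetic mean (1, 1), up to rounding
print(inductive_mean(np.array([3.0, -3.0]), circle_geodesic_point, n_passes=1))  # close to pi
```

On a curved latent space, `geodesic_point` would evaluate a computed geodesic at parameter $s$; each update then costs one geodesic solve, which is why such a mean is slow.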


The geodesics are computed by solving a geodesic boundary value problem (BVP) with the help of the Stochman package.

Let $( M, g )$ be a Riemannian manifold. Given points $ p, q \in M $, find a curve $\gamma : [0,1] \to M $ such that:
\begin{equation}
%\label{eq:geodesic_eq}
\begin{aligned}
    \gamma(0) &= p \ , \\
    \gamma(1) &= q \ , \\
    \nabla_{\dot{\gamma}} \dot{\gamma} &= 0 \ .
\end{aligned}
\end{equation}
where $ \nabla $ is the Levi-Civita connection associated with $ g $.

In local coordinates $ (x^1, x^2, \ldots, x^n) $, the geodesic equation is:
\begin{align}
    \frac{d^2 x^i}{dt^2} + \Gamma^i_{jk} \frac{dx^j}{dt} \frac{dx^k}{dt} = 0
\end{align}
where $ \Gamma^i_{jk} $ are the Christoffel symbols of $ g $, and summation over the repeated indices $ j, k $ is implied.
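
To make the index convention concrete, here is a small NumPy sketch (an illustration, not part of the notebook or of `ricci_regularization`) that computes $\Gamma^l_{ij} = \tfrac12 g^{lk}(\partial_i g_{kj} + \partial_j g_{ki} - \partial_k g_{ij})$ for a metric given as a function, using central finite differences. The array layout `Gamma[l, i, j]`, upper index first, mirrors how the notebook indexes `Ch_at_u[:, l, i, j]` in section 5; the polar-coordinate metric is only a test case with known symbols.

```python
import numpy as np


def christoffel_symbols(metric, x, h=1e-5):
    """Christoffel symbols Gamma[l, i, j] = Gamma^l_{ij} of `metric` at the point `x`.

    Uses Gamma^l_{ij} = 1/2 g^{lk} (d_i g_{kj} + d_j g_{ki} - d_k g_{ij}),
    with the metric derivatives approximated by central finite differences.
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    g_inv = np.linalg.inv(metric(x))
    # dg[k, i, j] = d g_{ij} / d x^k
    dg = np.empty((n, n, n))
    for k in range(n):
        e = np.zeros(n)
        e[k] = h
        dg[k] = (metric(x + e) - metric(x - e)) / (2.0 * h)
    # bracket[k, i, j] = d_i g_{kj} + d_j g_{ki} - d_k g_{ij}
    bracket = np.einsum("ikj->kij", dg) + np.einsum("jki->kij", dg) - dg
    return 0.5 * np.einsum("lk,kij->lij", g_inv, bracket)


def polar_metric(x):
    """Euclidean metric of the plane in polar coordinates (r, theta): diag(1, r^2)."""
    r = x[0]
    return np.diag([1.0, r ** 2])


gamma = christoffel_symbols(polar_metric, [2.0, 0.7])
# Known values: Gamma^r_{theta theta} = -r, Gamma^theta_{r theta} = Gamma^theta_{theta r} = 1/r
print(gamma[0, 1, 1], gamma[1, 0, 1], gamma[1, 1, 0])
```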

The length functional $ L $ for a curve $ \gamma $ is given by:
\begin{align*}
    L[\gamma] = \int_0^1 \sqrt{g_{\gamma(t)}(\dot{\gamma}(t), \dot{\gamma}(t))} \, dt
\end{align*}

The energy functional $E$ for a curve $\gamma$ is given by:
\begin{align*}
E[\gamma] = \int_0^1 g_{\gamma(t)}(\dot{\gamma}(t), \dot{\gamma}(t)) \, dt
\end{align*}

Among curves parametrized on $[0,1]$, the minimizers of the energy functional $ E $ are exactly the length-minimizing curves (minimizers of $ L $) traversed at constant speed; more generally, geodesics are the critical points of $ E $. In the Stochman package, geodesics connecting two points are found as minimizers of the energy functional. Technically, they are approximated by cubic splines, and the geodesic is obtained by solving an optimization problem over the spline coefficients.
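
The relation between $L$ and $E$ follows from the Cauchy-Schwarz inequality on $[0,1]$, writing $\|\dot\gamma(t)\|_g = \sqrt{g_{\gamma(t)}(\dot\gamma(t), \dot\gamma(t))}$:
\begin{align*}
L[\gamma]^2 = \left( \int_0^1 1 \cdot \|\dot{\gamma}(t)\|_g \, dt \right)^2
\le \int_0^1 1^2 \, dt \int_0^1 \|\dot{\gamma}(t)\|_g^2 \, dt = E[\gamma] \ ,
\end{align*}
with equality if and only if $\|\dot\gamma\|_g$ is constant. Since reparametrizing a curve does not change $L$, while its constant-speed reparametrization attains $E = L^2$, minimizing $E$ over curves from $p$ to $q$ yields a length minimizer traversed at constant speed.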

The latent space of the AE is topologically a $ 2 $-dimensional torus $\mathcal{T}^2$, i.e., it can be considered as a periodic box $[-\pi, \pi]^2$. We define a Riemannian metric on the latent space as the pull-back of the Euclidean metric in the output space $\mathbb{R}^D$ by the decoder function $\Psi$ of the AE:
\begin{equation}
    g = (\nabla \Psi)^{\top} \nabla \Psi \ ,
\end{equation}
where $\nabla \Psi$ is the $D \times 2$ Jacobian matrix of $\Psi$.
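
In coordinates, $g_{ij} = \sum_{a=1}^{D} \partial_i \Psi^a \, \partial_j \Psi^a$, i.e. $g = J^\top J$ with $J$ the decoder Jacobian. The NumPy sketch below computes this with finite differences for two hypothetical toy "decoders" (stand-ins for `torus_ae.decoder_torus`, which is not used here). The flat torus in $\mathbb{R}^4$ is a useful reference point: it is $2\pi$-periodic in both latent coordinates and its pull-back metric is the identity everywhere, so its geodesics are straight lines.

```python
import numpy as np


def jacobian_fd(psi, x, h=1e-6):
    """Central finite-difference Jacobian of psi: R^d -> R^D at x, shape (D, d)."""
    x = np.asarray(x, dtype=float)
    columns = []
    for k in range(x.size):
        e = np.zeros_like(x)
        e[k] = h
        columns.append((psi(x + e) - psi(x - e)) / (2.0 * h))
    return np.stack(columns, axis=1)


def pullback_metric(psi, x):
    """Pull-back of the Euclidean metric of R^D by psi: g(x) = J(x)^T J(x), shape (d, d)."""
    J = jacobian_fd(psi, x)
    return J.T @ J


def flat_torus(x):
    """Hypothetical toy decoder: the flat torus embedded in R^4."""
    return np.array([np.cos(x[0]), np.sin(x[0]), np.cos(x[1]), np.sin(x[1])])


def paraboloid(x):
    """Hypothetical toy decoder: the graph of z = x^2 over the plane, in R^3."""
    return np.array([x[0], x[1], x[0] ** 2])


print(pullback_metric(flat_torus, [0.3, -1.2]))  # close to the 2x2 identity
print(pullback_metric(paraboloid, [0.5, 2.0]))   # close to [[1 + 4 x^2, 0], [0, 1]] = [[2, 0], [0, 1]]
```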

The main objectives here are to show that:
1) The geometry of the latent space is different from Euclidean.
2) Adding curvature penalization forces the geodesics in logarithmic charts to become closer to straight lines.

# 0. Imports. Loading AE weights.

In [None]:
from tqdm.notebook import tqdm
import torch
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
from stochman.manifold import EmbeddedManifold
from stochman.curves import CubicSpline
import json, yaml, os

In [None]:
violent_saving = True
#with open('../experiments/MNIST_Setting_2_config.yaml', 'r') as yaml_file:
#with open('../experiments/Synthetic_Setting_1/Synthetic_Setting_1_config.yaml', 'r') as yaml_file:
with open('../experiments/Swissroll_exp2_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

In [None]:
# Load data loaders based on YAML configuration
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)  # named `loaders` to avoid shadowing the built-in `dict`
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
test_dataset = loaders.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders

print("Data loaders created successfully.")

torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)

print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

In [None]:
experiment_name = yaml_config["experiment"]["name"]

#Path_pictures = yaml_config["experiment"]["path"]
Path_pictures = "../experiments/" + yaml_config["experiment"]["name"]
if violent_saving:
    # Create the directory for plots if it does not exist yet
    if not os.path.exists(Path_pictures):
        os.makedirs(Path_pictures)
        print(f"Created directory: {Path_pictures}")  # Print directory creation feedback
    else:
        print(f"Directory already exists: {Path_pictures}")

curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
D = yaml_config["architecture"]["input_dim"]
# D is the dimension of the input data
if dataset_name == "MNIST01":
    # k is the number of classes, given by the selected labels in the YAML configuration
    #k = yaml_config["dataset"]["k"]
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    selected_labels = np.arange(10)
    k = 10
elif dataset_name == "Synthetic":
    k = yaml_config["dataset"]["k"] 
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
# Move the torus autoencoder to CPU
torus_ae.cpu()

# Initialize lists to store various data
colorlist = []  # List to store labels
enc_list = []   # List to store encoded representations
feature_space_encoding_list = []  # List to store feature space encodings
input_dataset_list = []  # List to store input datasets
recon_dataset_list = []  # List to store reconstructed datasets

# Iterate through the test_loader and collect inputs, reconstructions, encodings and labels
for (data, labels) in tqdm(test_loader, position=0):
    input_dataset_list.append(data)  # Append input data to input_dataset_list
    recon_dataset_list.append(torus_ae(data)[0])  # Append reconstructed data to recon_dataset_list
    feature_space_encoding_list.append(torus_ae.encoder_torus(data.view(-1, D)))  # Append feature space encoding to feature_space_encoding_list
    enc_list.append(torus_ae.encoder2lifting(data.view(-1, D)))  # Append encoded representations to enc_list
    colorlist.append(labels)  # Append labels to colorlist

# Concatenate lists to form complete datasets and encoded representations
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
feature_space_encoding = torch.cat(feature_space_encoding_list)
encoded_points_no_grad = encoded_points.detach()  # Detach encoded points from the computation graph
color_array = torch.cat(colorlist).detach()  # Detach color array from the computation graph

# Plot the encoded points in feature space
plt.figure(figsize=(9,9))
if dataset_name == "Swissroll":
    my_cmap = "jet"
else:
    my_cmap = ricci_regularization.discrete_cmap(k,"jet",bright_colors=True)
    
plt.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array, alpha=0.6, cmap = my_cmap)
plt.ylim(-torch.pi, torch.pi)
plt.xlim(-torch.pi, torch.pi)
plt.show()  # Display the plot

# alternative plotting
#fig = ricci_regularization.point_plot(encoder=torus_ae.encoder2lifting, data=test_dataset, config=yaml_config,batch_idx=0,device="cpu")
#fig.show()


# 1. Demo: a geodesic connecting two points (geodesic bvp via Stochman).

In [None]:
#from stochman.manifold import EmbeddedManifold
# Geodesics are computed by minimizing the curve "energy" measured in the embedding
# (the decoder output space), so there is no need to compute the pullback metric,
# and thus the algorithm is fast.
# Define the embedding by the AE decoder
class Autoencoder(EmbeddedManifold):
    def embed(self, c, jacobian=False):
        # Only the embedding is returned; the `jacobian` flag is not used here
        return torus_ae.decoder_torus(c)
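
A minimal sketch of the idea in the cell above, assuming a curve is represented by samples $c(t_k)$ with $t_k = k/N$ and $\Delta t = 1/N$: then $E[c] \approx \sum_k \|\Psi(c(t_{k+1})) - \Psi(c(t_k))\|^2 / \Delta t$ and $L[c] \approx \sum_k \|\Psi(c(t_{k+1})) - \Psi(c(t_k))\|$, which need only decoder evaluations and no metric tensor. This is an illustration in NumPy with a hypothetical toy decoder (`flat_torus`), not Stochman's actual objective, whose discretization and spline parametrization may differ.

```python
import numpy as np


def flat_torus(x):
    """Hypothetical toy decoder acting on points of shape (..., 2); the flat torus in R^4."""
    return np.stack([np.cos(x[..., 0]), np.sin(x[..., 0]),
                     np.cos(x[..., 1]), np.sin(x[..., 1])], axis=-1)


def discrete_energy_and_length(psi, curve_points):
    """Energy and length of a sampled curve, measured in the embedding space of psi.

    curve_points has shape (N + 1, d) and holds c(t_k) at t_k = k / N.
    """
    n_segments = curve_points.shape[0] - 1
    dt = 1.0 / n_segments
    embedded = psi(curve_points)                                # (N + 1, D)
    steps = np.linalg.norm(np.diff(embedded, axis=0), axis=-1)  # ||psi(c_{k+1}) - psi(c_k)||
    energy = np.sum(steps ** 2) / dt                            # approximates E[c]
    length = np.sum(steps)                                      # approximates L[c]
    return energy, length


t = np.linspace(0.0, 1.0, 1001)[:, None]
p0, p1 = np.array([0.0, -2.0]), np.array([1.5, 1.5])
straight = p0 + t * (p1 - p0)  # straight line in latent coordinates
energy, length = discrete_energy_and_length(flat_torus, straight)
print(energy, length)  # close to 14.5 and sqrt(14.5) for this flat metric
```

Because the flat torus has the identity pull-back metric, the straight line has constant speed $\|p_1 - p_0\| = \sqrt{14.5}$, so $E \approx L^2$ here; for a non-flat decoder, the energy minimizer bends away from the straight line.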

In [None]:
model = Autoencoder()
torch.manual_seed(0)

# for plotting
t = torch.linspace(0.,1.,100)

# p0 and p1 can be chosen anywhere in R^2, since the metric is 2*pi-periodic
p0 = torch.tensor([0.,-2.]) #+11*torch.pi
p1 = torch.tensor([1.5,1.5]) #+ 11*torch.pi
# find a pair of points with different labels (first in test loader) 
#p0 = encoded_points[torch.where(color_array==selected_labels[0])][0].detach()
#p1 = encoded_points[torch.where(color_array==selected_labels[1])][0].detach()
print(f"start:{p0}, \n end {p1}")
c, success = model.connecting_geodesic(p0, p1) # here the parameter t in c(t) should be a torch.tensor
print("Success:",success.item(),"\n length",model.curve_length(c(t)).item())

In [None]:
points_on_geodesic = c(t).detach()
straight_line = CubicSpline(p0,p1)
straight_line_points2plot = straight_line(t).detach()

geod_length = model.curve_length(c(t)).item()
straight_line_length = model.curve_length(straight_line(t)).item()

In [None]:
plt.figure(figsize=(9,9))
plt.title("Geodesic bvp: straight line vs geodesic.", fontsize=20)
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,alpha=0.6,cmap=my_cmap,zorder=0)
plt.plot(points_on_geodesic[:,0],points_on_geodesic[:,1],c="blue",label=f"Geodesic.\nLength:{geod_length:.3f}",zorder=1e4,linewidth=3.)
plt.plot(straight_line_points2plot[:,0],straight_line_points2plot[:,1],c="green",label=f"Straight line.\nLength:{straight_line_length:.3f}",zorder=1e5,linewidth=3.)
plt.legend(loc="upper left", fontsize=15)
plt.ylim(-torch.pi, torch.pi)
plt.xlim(-torch.pi, torch.pi)
if violent_saving == True:
    plt.savefig(f'{Path_pictures}/geodesic_vs_straight.pdf',bbox_inches='tight',format='pdf')
    plt.savefig(f'{Path_pictures}/geodesic_vs_straight.jpg',bbox_inches='tight',format='jpeg',dpi=400)
plt.show()

# 2. Grid of geodesics in latent space (parallelized computation).

## Creating a grid

In [None]:
# old grid
def create_grid_version_linspace(num_geodesics, x_left=-torch.pi/2, y_bottom=-torch.pi/2,
                                 x_size=torch.pi, y_size=torch.pi):
    """
    Creates a grid of geodesics in a rectangular box, using linspace to distribute the endpoints:
    horizontal geodesics connect evenly spaced points on the left and right sides of the box,
    and vertical geodesics connect evenly spaced points on its bottom and top sides.

    Parameters:
    - num_geodesics (int): Number of geodesics in each direction.
    - x_left, y_bottom (float): Left and bottom coordinates of the box.
    - x_size, y_size (float): Width and height of the box.

    Returns:
    - horizontal_geodesics (stochman.curves.CubicSpline): Batch of the horizontal geodesics.
    - vertical_geodesics (stochman.curves.CubicSpline): Batch of the vertical geodesics.
    """
    x_right = x_left + x_size  # Right coordinate of the box
    y_top = y_bottom + y_size  # Top coordinate of the box

    # Compute starting and ending points for horizontal geodesics
    starting_points = torch.cat([torch.tensor([x_left, y_bottom + k]) for k in torch.linspace(0, y_size, num_geodesics)]).reshape(num_geodesics, 2)
    end_points = torch.cat([torch.tensor([x_right, y_bottom + k]) for k in torch.linspace(0, y_size, num_geodesics)]).reshape(num_geodesics, 2)

    # Compute starting and ending points for vertical geodesics
    starting_points_vertical = torch.cat([torch.tensor([x_left + k, y_bottom]) for k in torch.linspace(0, x_size, num_geodesics)]).reshape(num_geodesics, 2)
    end_points_vertical = torch.cat([torch.tensor([x_left + k, y_top]) for k in torch.linspace(0, x_size, num_geodesics)]).reshape(num_geodesics, 2)

    # Compute horizontal geodesics
    horizontal_geodesics, _ = model.connecting_geodesic(starting_points, end_points)

    # Compute vertical geodesics
    vertical_geodesics, _ = model.connecting_geodesic(starting_points_vertical, end_points_vertical)

    return horizontal_geodesics, vertical_geodesics



In [None]:
# Create grid version 1
def create_grid_version_1(num_geodesics, x_left=-torch.pi/2, y_bottom=-torch.pi/2,
                          x_size=torch.pi, y_size=torch.pi):
    """
    Creates a grid of geodesics in a box bounded by 4 border geodesics defined by the parameters.
    The border geodesics connecting the corners are computed first; num_geodesics points, equally
    spaced in the spline parameter, are taken on each border, and opposite points are then
    connected by geodesics.

    Parameters:
    - num_geodesics (int): Number of geodesics in each direction.
    - x_left, y_bottom (float): Left and bottom coordinates of the box.
    - x_size, y_size (float): Width and height of the box.

    Returns:
    - horizontal_geodesics (stochman.curves.CubicSpline): Batch of the horizontal geodesics.
    - vertical_geodesics (stochman.curves.CubicSpline): Batch of the vertical geodesics.
    """
    x_right = x_left + x_size
    y_top = y_bottom + y_size

    # Collect 4 corner points to connect them with geodesics which are on the border of the box
    starting_points_border = torch.cat([torch.tensor([x_left, y_bottom + k]) for k in torch.linspace(0, y_size, 2)]).reshape(2, 2)
    end_points_border = torch.cat([torch.tensor([x_right, y_bottom + k]) for k in torch.linspace(0, y_size, 2)]).reshape(2, 2)
    
    starting_points_vertical_border = torch.cat([torch.tensor([x_left + k, y_bottom]) for k in torch.linspace(0, x_size, 2)]).reshape(2, 2)
    end_points_vertical_border = torch.cat([torch.tensor([x_left + k, y_top]) for k in torch.linspace(0, x_size, 2)]).reshape(2, 2)
    
    # Connect geodesics for horizontal and vertical borders
    horizontal_geodesics_border, _ = model.connecting_geodesic(starting_points_border, end_points_border)
    vertical_geodesics_border, _ = model.connecting_geodesic(starting_points_vertical_border, end_points_vertical_border)
    
    # Take num_geodesics points on each border, equally spaced in the spline parameter t (not necessarily in arc length)
    t = torch.linspace(0, 1, num_geodesics)
    geodesics2plot_horizontal_border = horizontal_geodesics_border(t).detach()
    geodesics2plot_vertical_border = vertical_geodesics_border(t).detach()
    
    # Extract starting and ending points for horizontal and vertical geodesics
    starting_points = geodesics2plot_vertical_border[0, :, :]
    end_points = geodesics2plot_vertical_border[1, :, :]
    starting_points_vertical = geodesics2plot_horizontal_border[0, :, :]
    end_points_vertical = geodesics2plot_horizontal_border[1, :, :]
    
    # Compute horizontal geodesics
    horizontal_geodesics, _ = model.connecting_geodesic(starting_points, end_points)

    # Compute vertical geodesics
    vertical_geodesics, _ = model.connecting_geodesic(starting_points_vertical, end_points_vertical)
    
    return horizontal_geodesics, vertical_geodesics

# Function to plot a grid of geodesics (e.g. the one produced by create_grid_version_1)
def plot_grid(geodesics2plot_horizontal, geodesics2plot_vertical,
              savefig=False, show_data=False):
    num_geodesics = geodesics2plot_horizontal.shape[0]
    if show_data == True:
        plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=my_cmap)
    #plt.title(f"Geodesic grid on MNIST with {k} labels with $\lambda_{{\mathrm{{curv}}}}={curv_w}$")
    for i in range(num_geodesics):
        plt.plot(geodesics2plot_horizontal[i,:,0],geodesics2plot_horizontal[i,:,1],c="black")
        plt.plot(geodesics2plot_vertical[i,:,0],geodesics2plot_vertical[i,:,1],c="black")
    plt.xlim(-torch.pi,torch.pi)
    plt.ylim(-torch.pi,torch.pi)
    if savefig == True:
        plt.savefig(f'{Path_pictures}/multiple_geodesics_{experiment_name}.pdf',bbox_inches='tight',format='pdf')
    return plt, geodesics2plot_horizontal, geodesics2plot_vertical

In [None]:
# Alternative grid parameters:
# x_left = -0.5 #-torch.pi/2
# y_bottom = 1.5 #-torch.pi/2
# x_size = 1. #torch.pi
# y_size = 1. #torch.pi
x_center = -2.
y_center = 1.
x_size= 1.75 #torch.pi
y_size= 1.75 #torch.pi

x_left = x_center - x_size/2 #-2.0 #-torch.pi/2
y_bottom = y_center - y_size/2 #-2.0 #-torch.pi/2

x_right = x_center + x_size/2 
y_top = y_center + y_size/2

# set num_geodesics
num_geodesics = 11

# create a geodesic grid version 1:
horizontal_geodesics, vertical_geodesics = create_grid_version_1(num_geodesics=num_geodesics, 
        x_left=x_left, y_bottom=y_bottom,
        x_size=x_size, y_size=y_size)

# Define the number of approximation points for plotting
num_approximation_points = 20
t = torch.linspace(0, 1, num_approximation_points)

# Evaluate the horizontal and vertical grid geodesics at the parameter values t
geodesics2plot_horizontal = horizontal_geodesics(t).detach()
geodesics2plot_vertical = vertical_geodesics(t).detach()

# Call the function to generate the plot
p, geodesics2plot_horizontal, geodesics2plot_vertical = plot_grid(geodesics2plot_horizontal,geodesics2plot_vertical, show_data=True)

p.show()

In [None]:
# curvature along geodesics (not used yet)
#scalar_curvature_on_geodesics_bvp = ricci_regularization.Sc_jacfwd_vmap(geodesics2plot_horizontal.reshape(-1,2),function=torus_ae.decoder_torus)
#scalar_curvature_on_geodesics_bvp = scalar_curvature_on_geodesics_bvp.reshape(num_geodesics,num_approximation_points).detach()

In [None]:
plt.figure(dpi=300,figsize=(9,9))
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],
            alpha=0.6, c = color_array,cmap = my_cmap)
#plt.title(f"Geodesic grid on MNIST with {k} labels with $\lambda_{{\mathrm{{curv}}}}={curv_w}$")
for i in range(num_geodesics):
    plt.plot(geodesics2plot_horizontal[i,:,0],geodesics2plot_horizontal[i,:,1],c="black",linewidth=2.5)
    plt.plot(geodesics2plot_vertical[i,:,0],geodesics2plot_vertical[i,:,1],c="black",linewidth=2.5)
plt.xlim(-torch.pi,torch.pi)
plt.ylim(-torch.pi,torch.pi)
if violent_saving == True:
    plt.savefig(f'{Path_pictures}/multiple_geodesics_{experiment_name}_{num_geodesics}.pdf',bbox_inches='tight',format='pdf')
    #plt.savefig(f'{Path_pictures}/multiple_geodesics_{experiment_name}_{num_geodesics}.jpg',bbox_inches='tight',format='jpeg',dpi=400)
plt.show()

# 3. Logarithmic map: demo

$\log_{p_0}(p_1) = v$, where $\gamma$ is the geodesic such that $\gamma(0) = p_0$ and $\gamma(1) = p_1$, and $v = \dot \gamma(0)$. Since a geodesic has constant speed, $\|v\|_g$ equals the length of the geodesic.
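
For intuition, here is the log map in the simplest reference case, the flat torus (Euclidean metric on the periodic box $[-\pi,\pi]^2$), sketched in NumPy. It is a hypothetical comparison case, not the AE metric: there, $\log_{p_0}(p_1)$ is the shortest representative of $p_1 - p_0$ modulo $2\pi$ in each coordinate, and its norm is the geodesic distance.

```python
import numpy as np


def flat_torus_logmap(p, q):
    """Log map on the flat torus R^2 / (2*pi*Z)^2 (hypothetical reference case).

    Returns the shortest representative of q - p, with each coordinate wrapped
    into [-pi, pi); it is ambiguous on the cut locus, where a coordinate
    difference equals pi.
    """
    diff = np.asarray(q, dtype=float) - np.asarray(p, dtype=float)
    return (diff + np.pi) % (2.0 * np.pi) - np.pi


v = flat_torus_logmap([3.0, 0.0], [-3.0, 0.0])
print(v, np.linalg.norm(v))  # v = (2*pi - 6, 0): crossing the boundary of the box is shorter
```

For the demo points $p_0 = (0,-2)$ and $p_1 = (1.5, 1.5)$, the flat log map gives $(1.5,\ 3.5 - 2\pi)$, i.e. the shortest flat geodesic wraps around the box. Since the notebook works in $\mathbb{R}^2$ with a $2\pi$-periodic metric, a boundary value solver started from a straight line may converge to a geodesic that is not the shortest among all periodic copies of $p_1$.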

In [None]:
p0

In [None]:
p1

In [None]:
geod,_ = model.connecting_geodesic(p0,p1)

Checking that $\log_{p_0}(p_1) = \dot \gamma(0)$:

In [None]:
model.logmap(p0.unsqueeze(0).detach(),p1.unsqueeze(0).detach())

In [None]:
geod.deriv(torch.zeros(1))

# 4. Grid of geodesics in a logarithmic chart

In [None]:
# choosing base point for logarithmic map
#base_point = torch.tensor([-0.5,-0.8]).unsqueeze(0)

In [None]:
# find the base point for the logarithmic chart at the crossing of the central geodesics
def find_closest_points(A, B):
    # A is of shape (N, 2), B is of shape (M, 2)
    
    # Compute the pairwise differences between points in A and B
    diff = A[:, None, :] - B[None, :, :]  # Shape: (N, M, 2)
    
    # Compute the squared Euclidean distances
    dist_squared = torch.sum(diff ** 2, dim=-1)  # Shape: (N, M)
    
    # Find the indices of the minimum distance
    min_dist_idx = torch.argmin(dist_squared)
    
    # Convert the flat index to row and column indices
    min_row, min_col = divmod(min_dist_idx.item(), B.shape[0])
    
    # Return the closest points and their distance
    return A[min_row], B[min_col], torch.sqrt(dist_squared[min_row, min_col])

In [None]:
central_index = num_geodesics//2
num_approximation_points = 100
t = torch.linspace(0.,1.,num_approximation_points)

# re-evaluate the horizontal and vertical grid geodesics at the new parameter values
geodesics2plot_horizontal = horizontal_geodesics(t).detach()
geodesics2plot_vertical = vertical_geodesics(t).detach()

base_point,_,_ = find_closest_points(geodesics2plot_vertical[central_index],geodesics2plot_horizontal[central_index])
print("base point:",base_point)

In [None]:
# this takes 38 secs for num_geod=12 and num_approximation_points = 20 

num_approximation_points = 20
t = torch.linspace(0.,1.,num_approximation_points)

# re-evaluate the horizontal and vertical grid geodesics at the new parameter values
geodesics2plot_horizontal = horizontal_geodesics(t).detach()
geodesics2plot_vertical = vertical_geodesics(t).detach()


horizontal_geodesics2plot_logmap = model.logmap(base_point.repeat(num_geodesics*num_approximation_points,1),geodesics2plot_horizontal.reshape(-1,2))
vertical_geodesics2plot_logmap = model.logmap(base_point.repeat(num_geodesics*num_approximation_points,1),geodesics2plot_vertical.reshape(-1,2))

In [None]:
horizontal_geodesics2plot_logmap = horizontal_geodesics2plot_logmap.reshape(num_geodesics,num_approximation_points,-1)
vertical_geodesics2plot_logmap = vertical_geodesics2plot_logmap.reshape(num_geodesics,num_approximation_points,-1)

In [None]:
# central cross of base geodesics
"""
left_base = geodesics2plot_horizontal[num_geodesics//2][0].unsqueeze(0)
right_base = geodesics2plot_horizontal[num_geodesics//2][-1].unsqueeze(0)
top_base = geodesics2plot_vertical[num_geodesics//2][-1].unsqueeze(0)
bottom_base = geodesics2plot_vertical[num_geodesics//2][0].unsqueeze(0)
print(f"Left {left_base}, top {top_base}, right {right_base}, bottom {bottom_base}")
base_geod,_ = model.connecting_geodesic(base_point.repeat(4,1), torch.cat((left_base,top_base,right_base,bottom_base)))
base_geod_points = base_geod(t).detach()
# takes 5 secs for num_approximation_points = 20
base_geod2plot_logmap = model.logmap(base_point.repeat(4*num_approximation_points,1),base_geod_points.reshape(-1,2))

base_geod2plot_logmap = base_geod2plot_logmap.reshape(4,num_approximation_points,2)
base_point_x = base_point.squeeze()[0]
base_point_y = base_point.squeeze()[1]
"""

In [None]:
plt.figure(dpi=300)
#plt.title(f"Geodesic grid in $T_{{p_0}} M$ after log map with base point $p_0$, \n experiment # {exp_number} with $\lambda_{{\mathrm{{curv}}}}={curv_w}$")
for i in range(num_geodesics):
    color = "orange"
    if i == central_index:
        color = "red"
    plt.plot(horizontal_geodesics2plot_logmap[i,:,0],horizontal_geodesics2plot_logmap[i,:,1],c=color,label="Central geodesics" if i == central_index else "")
    plt.plot(vertical_geodesics2plot_logmap[i,:,0],vertical_geodesics2plot_logmap[i,:,1],c=color)
    #plt.scatter(horizontal_geodesics2plot_logmap[i,:,0], horizontal_geodesics2plot_logmap[i,:,1])
    #plt.scatter(vertical_geodesics2plot_logmap[i,:,0], vertical_geodesics2plot_logmap[i,:,1],c="black")
# base "cross" of geodesics
#for j in range(4):
#    plt.plot(base_geod2plot_logmap[j,:,0],base_geod2plot_logmap[j,:,1],c="red",label="Geodesics through base point" if j==0 else "")
#plt.scatter(base_point[:,0],base_point[:,1],marker = "*",c="blue",s = 120,label = f"Base point $p_0$ = ({base_point_x:.1f}, {base_point_y:.1f})",zorder = 3)
plt.legend()
if violent_saving == True:
    plt.savefig(f'{Path_pictures}/geodesic_grid_logmap_{experiment_name}_{num_geodesics}.pdf',bbox_inches='tight',format='pdf')
    #plt.savefig(f'{Path_pictures}/geodesic_grid_logmap_exp{exp_number}_{num_geodesics}.jpg',bbox_inches='tight',format='jpeg',dpi=400)
plt.show()

# break here

In [None]:
raise RuntimeError("Break here: intentional stop of 'Run all'.")

# 5. Geodesic shooting via Runge-Kutta.

In [None]:
def geod_vect(x, dxdt):
    """
    Right-hand side of the geodesic equation, written as a first-order system.

    Parameters:
    - x (tensor): Current positions, shape (n, 2).
    - dxdt (tensor): Current velocities, shape (n, 2).

    Returns:
    - dudt (tensor): Time derivative of the positions (equal to the velocities).
    - dvdt (tensor): Time derivative of the velocities, -Gamma^l_{ij} v^i v^j.
    """
    u = x  # Current positions
    v = dxdt  # Current velocities
    dudt = v  # Rate of change of positions is velocity

    # Christoffel symbols at u, indexed as Ch_at_u[:, l, i, j] = Gamma^l_{ij}
    Ch_at_u = ricci_regularization.Ch_jacfwd_vmap(u, function=torus_ae.decoder_torus)
    # Geodesic equation: d^2 x^l/dt^2 = -Gamma^l_{ij} dx^i/dt dx^j/dt (summed over i, j)
    dvdt = -torch.einsum('nlij,ni,nj->nl', Ch_at_u, v, v)

    return dudt, dvdt


def rungekutta_vect(f, initial_point_array, initial_speed_array, t, args=()):
    """
    Integrates the ODE system with the explicit (forward) Euler scheme, i.e. the first-order Runge-Kutta method.

    Parameters:
    - f (function): Function defining the ODE system.
    - initial_point_array (tensor): Initial positions.
    - initial_speed_array (tensor): Initial velocities.
    - t (array-like): Array of time points.
    - args (tuple, optional): Additional arguments for the ODE function `f`.

    Returns:
    - x (tensor): Array of positions over time.
    - dxdt (tensor): Array of velocities over time.
    """
    n = len(t)  # Number of time steps
    x = torch.zeros((n, *tuple(initial_point_array.shape)))  # Initialize position array
    dxdt = torch.zeros((n, *tuple(initial_speed_array.shape)))  # Initialize velocity array

    x[0] = initial_point_array  # Set initial position
    dxdt[0] = initial_speed_array  # Set initial velocity

    # Iterate through time steps and apply explicit Euler updates (first-order Runge-Kutta)
    for i in range(n - 1):
        dudt, dvdt = f(x[i], dxdt[i], *args)  # Compute derivatives
        dt = t[i+1] - t[i]  # Time step size
        x[i+1] = x[i] + dt * dudt  # Update positions
        dxdt[i+1] = dxdt[i] + dt * dvdt  # Update velocities

    return x, dxdt
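
`rungekutta_vect` advances the state with forward Euler steps, whose global error is $O(\Delta t)$. A classical fourth-order Runge-Kutta (RK4) step, with global error $O(\Delta t^4)$, could replace the Euler update. The NumPy sketch below is an illustration with hypothetical function names, not the notebook's code: it shoots a geodesic of the Euclidean plane written in polar coordinates $(r,\theta)$, where the answer is known. Starting at $r=1$, $\theta=0$ with velocity $(\dot r,\dot\theta)=(0,1)$ traces the straight line $(1,t)$ in Cartesian coordinates, so at $t=1$ it must reach $r=\sqrt2$, $\theta=\pi/4$.

```python
import numpy as np


def polar_christoffel(x):
    """Gamma[l, i, j] of the Euclidean plane in polar coordinates x = (r, theta)."""
    r = x[0]
    gamma = np.zeros((2, 2, 2))
    gamma[0, 1, 1] = -r       # Gamma^r_{theta theta}
    gamma[1, 0, 1] = 1.0 / r  # Gamma^theta_{r theta}
    gamma[1, 1, 0] = 1.0 / r  # Gamma^theta_{theta r}
    return gamma


def geodesic_rhs(state, christoffel):
    """First-order form of the geodesic ODE; state = (x, v) stacked in one array."""
    x, v = state[:2], state[2:]
    acceleration = -np.einsum("lij,i,j->l", christoffel(x), v, v)
    return np.concatenate([v, acceleration])


def shoot_rk4(x0, v0, christoffel, t_final=1.0, n_steps=1000):
    """Geodesic shooting with classical RK4 steps; returns position and velocity at t_final."""
    state = np.concatenate([x0, v0]).astype(float)
    dt = t_final / n_steps
    for _ in range(n_steps):
        k1 = geodesic_rhs(state, christoffel)
        k2 = geodesic_rhs(state + 0.5 * dt * k1, christoffel)
        k3 = geodesic_rhs(state + 0.5 * dt * k2, christoffel)
        k4 = geodesic_rhs(state + dt * k3, christoffel)
        state = state + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return state[:2], state[2:]


x_end, v_end = shoot_rk4(np.array([1.0, 0.0]), np.array([0.0, 1.0]), polar_christoffel)
print(x_end, v_end)  # close to (sqrt(2), pi/4) and (1/sqrt(2), 1/2)
```

Each RK4 step evaluates the Christoffel symbols four times instead of once, so for the AE metric, where every evaluation involves decoder Jacobians, it pays off mainly when far fewer time steps suffice.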



In [None]:
#from torch.nn.functional import normalize # if one needs to normalize initial speeds

num_approximation_points = 101 # number of time points (controls the accuracy of the approximation)
max_parameter_value = 1 #3 # final time of the integration (how far to go)
time_array = torch.linspace(0, max_parameter_value, num_approximation_points)

#num_geodesics = 20

starting_points = torch.tensor([-0.,0.]).repeat(num_geodesics,1) # common starting point
#starting_points = p0.repeat(num_geodesics,1) # common starting point

maxtangent = 2 # max initial slope of the geodesics
starting_speeds = torch.cat([torch.tensor([1.,0. + k]) for k in torch.linspace(-maxtangent,maxtangent,num_geodesics) ]).reshape(num_geodesics,2)
#starting_speeds = c.deriv(torch.zeros(1)).reshape(num_geodesics,2)
#starting_speeds = normalize(starting_speeds) #make norms of all speeds equal

geodesics2plot,_ = rungekutta_vect(f=geod_vect,initial_point_array=starting_points,
                                   initial_speed_array=starting_speeds,t=time_array)
geodesics2plot = geodesics2plot.detach()

In [None]:
scalar_curvature_on_geodesics = ricci_regularization.Sc_jacfwd_vmap(geodesics2plot.reshape(-1,2),function=torus_ae.decoder_torus)
scalar_curvature_on_geodesics = scalar_curvature_on_geodesics.reshape(num_approximation_points,num_geodesics).detach()

In [None]:
# Use one color normalization for all geodesics so that the colors are comparable across geodesics
curv_norm = matplotlib.colors.SymLogNorm(linthresh=1e-2,
                                         vmin=scalar_curvature_on_geodesics.min().item(),
                                         vmax=scalar_curvature_on_geodesics.max().item())
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=my_cmap)
for i in range(num_geodesics):
    #plt.scatter(geodesics2plot[:,i,0],geodesics2plot[:,i,1],c=time_array,cmap="jet")
    plt.scatter(geodesics2plot[:,i,0],geodesics2plot[:,i,1],c=scalar_curvature_on_geodesics[:,i],cmap="viridis",norm=curv_norm)
    plt.plot(geodesics2plot[:,i,0],geodesics2plot[:,i,1],c="black")
plt.colorbar(label="scalar curvature along geodesics")
plt.show()

## Reconstructing geodesics from the logarithmic chart via exponential maps, computed by geodesic shooting (solving the geodesic ODE with initial conditions) with the Runge-Kutta method

In [None]:
#points,_ = rungekutta_vect(geod_vect,base_point.repeat(200,1),horizontal_geodesics2plot_logmap.reshape(200,2),t=time_array)
num_recon_points_on_all_geodesics = horizontal_geodesics2plot_logmap.reshape(-1,2).shape[0]
points,_ = rungekutta_vect(geod_vect,base_point.repeat(num_recon_points_on_all_geodesics,1),horizontal_geodesics2plot_logmap.reshape(-1,2),t=time_array)

In [None]:
points = points.reshape(len(time_array), num_geodesics, -1, 2)  # (time, geodesic, point on the geodesic, coordinate)
points = points.detach()

In [None]:
plt.title("Log map verification: reconstructing geodesics in the latent space \nfrom the log. chart using the exponential map via geodesic shooting \nwith the Runge-Kutta method")
for i in range(num_geodesics):
    plt.plot(points[-1,i,:,0], points[-1,i,:,1],c="blue",label="Reconstructed geodesics" if i==0 else "")
    plt.plot(geodesics2plot_horizontal[i,:,0],geodesics2plot_horizontal[i,:,1],c="orange", label="Original geodesics" if i==0 else "")
plt.legend(loc="lower right")
plt.xlim(-torch.pi,torch.pi)
plt.ylim(-torch.pi,torch.pi)
plt.show()

# 6. Geodesic length ratio benchmark. Straight lines vs geodesics in a logarithmic chart.

In [None]:
num_points = 20
vectors = horizontal_geodesics2plot_logmap[:,-1,:]-horizontal_geodesics2plot_logmap[:,0,:]

In [None]:
t = torch.linspace(0,1,num_points)

In [None]:
straight_lines = horizontal_geodesics2plot_logmap[:,0,:] + torch.tensordot(t.unsqueeze(0),vectors,dims=0).reshape(num_points,num_geodesics,2)
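
The `tensordot(..., dims=0)` call above is an outer product followed by a reshape. A small NumPy check on hypothetical random data (NumPy's `tensordot(..., axes=0)` has the same outer-product semantics) shows that it equals plain broadcasting, `lines[p, i] = start[i] + t[p] * vector[i]`, which may be easier to read; in PyTorch the broadcast form is written the same way with `t[:, None, None]`.

```python
import numpy as np

num_points, num_geodesics = 5, 3
rng = np.random.default_rng(0)
starts = rng.normal(size=(num_geodesics, 2))   # hypothetical log-chart start points
vectors = rng.normal(size=(num_geodesics, 2))  # end point minus start point

t = np.linspace(0.0, 1.0, num_points)
# Outer product as in the cell above: shape (1, num_points, num_geodesics, 2) before the reshape
lines_tensordot = starts + np.tensordot(t[None, :], vectors, axes=0).reshape(num_points, num_geodesics, 2)
# Same result with explicit broadcasting: lines[p, i] = starts[i] + t[p] * vectors[i]
lines_broadcast = starts[None, :, :] + t[:, None, None] * vectors[None, :, :]
print(np.allclose(lines_tensordot, lines_broadcast))  # True
```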

In [None]:
for i in range(num_geodesics):
    plt.plot(horizontal_geodesics2plot_logmap[i,:,0],horizontal_geodesics2plot_logmap[i,:,1],c="orange")
    plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
geodesic_lengths = model.curve_length(horizontal_geodesics(t))
print(f"geodesic lengths: {geodesic_lengths}")

In [None]:
exp_base_point_x_i,_ = rungekutta_vect(f=geod_vect,initial_point_array=base_point.repeat(num_points*num_geodesics,1),
                    initial_speed_array=straight_lines.reshape(-1,2), t=time_array)

In [None]:
y = exp_base_point_x_i[-1].reshape(num_points, num_geodesics, 2).detach()
y.shape

In [None]:
plt.title("Images of the log-chart straight lines under the exponential map")
for i in range(num_geodesics):
    plt.scatter(y[:,i,0],y[:,i,1])
    #plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
geodesics_y_i,_ = model.connecting_geodesic(y[:-1].reshape(-1,2), y[1:].reshape(-1,2))

In [None]:
log_straight_lines_length_approx = model.curve_length(geodesics_y_i(t)).reshape(num_points-1,num_geodesics).sum(dim = 0)
print(f"straight_lines_length_approx: {log_straight_lines_length_approx}")

In [None]:
log_straight_lines_length_approx / geodesic_lengths

In [None]:
geodesic_lengths

In [None]:
geod_length_ratio = (geodesic_lengths/log_straight_lines_length_approx).mean().item()
print(f"geodesic length ratio:\n{geod_length_ratio}")

In [None]:
dict = {"geod_length_ratio":geod_length_ratio}

In [None]:
if violent_saving == True:
    with open(f'{Path_pictures}/geodesic_length_ratio_{experiment_name}.json', 'w') as json_file:
        json.dump(dict, json_file, indent=4)

# 7. Geodesic length ratio for random multiple geodesics and different logmap base points.

In [None]:
torch.manual_seed(0)

num_approximation_points = 20
t = torch.linspace(0,1,num_approximation_points)

num_geodesics = 7
#selecting geodesic start/end points and log map base points randomly
random_starting_points = torch.pi*(torch.rand(num_geodesics,2)-0.5)
random_end_points = torch.pi*(torch.rand(num_geodesics,2)-0.5)
base_points = torch.pi*(torch.rand(num_geodesics,2)-0.5)

random_geodesics, success = model.connecting_geodesic(random_starting_points, random_end_points)
random_geodesics2plot = random_geodesics(t).detach()

In [None]:
random_geodesics2plot.shape

In [None]:
plt.title("Random geodesics and basepoints")
for i in range(num_geodesics):
    plt.scatter(base_points[i,0],base_points[i,1])
    plt.plot(random_geodesics2plot[i,:,0],random_geodesics2plot[i,:,1])
plt.xlim(-torch.pi, torch.pi)
plt.ylim(-torch.pi, torch.pi)
plt.show()

In [None]:
random_geodesics2plot_logmap = model.logmap(base_points.repeat(1,num_approximation_points).reshape(num_approximation_points*num_geodesics,2),random_geodesics2plot.reshape(num_approximation_points*num_geodesics,2))
random_geodesics2plot_logmap = random_geodesics2plot_logmap.reshape(num_geodesics,num_approximation_points,-1)
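
The two `repeat` calls in this section build base-point batches with different row orders, and each must match the flattening order of the tensor it is paired with. `torch.Tensor.repeat` tiles like `numpy.tile` (not like `numpy.repeat`), so the orders can be checked with a small NumPy sketch on hypothetical data:

```python
import numpy as np

num_geodesics, num_points = 3, 4
base_points = np.arange(2.0 * num_geodesics).reshape(num_geodesics, 2)  # hypothetical base points

# torch: base_points.repeat(1, N).reshape(N * G, 2), as in the log-map cell above
blocked = np.tile(base_points, (1, num_points)).reshape(-1, 2)
# -> row i * N + p holds base point i, i.e. each base point repeated N times in a row,
#    matching random_geodesics2plot.reshape(-1, 2) with shape (G, N, 2)
same_as = np.repeat(base_points, num_points, axis=0)

# torch: base_points.repeat(N, 1), as in the exponential-map cell further down
interleaved = np.tile(base_points, (num_points, 1))
# -> row p * G + i holds base point i, matching straight_lines.reshape(-1, 2) with shape (N, G, 2)
print(np.array_equal(blocked, same_as))  # True
```

In PyTorch, `torch.repeat_interleave(base_points, num_approximation_points, dim=0)` expresses the first ordering more directly.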

In [None]:
model.logmap(base_points,base_points)

In [None]:
random_geodesic_lengths = model.curve_length(random_geodesics(t))
print(f"geodesic lengths: {random_geodesic_lengths}")

In [None]:
num_points = 10 # number of intermediate points on the image of a geodesic in the log chart

vectors = random_geodesics2plot_logmap[:,-1,:]-random_geodesics2plot_logmap[:,0,:]
t = torch.linspace(0,1,num_points)
straight_lines = random_geodesics2plot_logmap[:,0,:] + torch.tensordot(t.unsqueeze(0),vectors,dims=0).reshape(num_points,num_geodesics,2)

In [None]:
plt.title("Geodesics and straight lines in logmap")
for i in range(num_geodesics):
    plt.plot(random_geodesics2plot_logmap[i,:,0],random_geodesics2plot_logmap[i,:,1],c="orange")
    plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
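# Shoot geodesics from each base point b_i, using the points of the i-th log-chart straight line
# as initial velocities; assuming time_array runs from 0 to 1, the last RK step approximates Exp_{b_i}(v).
exp_base_point_x_i, _ = rungekutta_vect(
    f=geod_vect,
    initial_point_array=base_points.repeat(num_points, 1),  # (num_points*num_geodesics, 2), same order as straight_lines.reshape(-1, 2)
    initial_speed_array=straight_lines.reshape(-1, 2),
    t=time_array,
)
y = exp_base_point_x_i[-1].reshape(num_points, num_geodesics, 2).detach()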
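`rungekutta_vect` and `geod_vect` are defined earlier in the notebook and are not shown here. Purely as an illustration of geodesic shooting, the sketch below integrates the geodesic equation for a hypothetical conformal metric $g(x) = e^{2\varphi(x)} I$ on $\mathbb{R}^2$ with $\varphi(x) = \tfrac12 |x|^2$ (not the decoder-induced metric of the autoencoder), using classical fourth-order Runge-Kutta on $t \in [0,1]$. For this metric the geodesic equation reduces to $\ddot x = -2(\nabla\varphi\cdot\dot x)\,\dot x + |\dot x|^2\nabla\varphi$, and $e^{2\varphi(x)}|\dot x|^2$ is conserved along geodesics, which gives a convenient numerical check.

```python
import torch

def grad_phi(x):
    # Hypothetical conformal factor phi(x) = |x|^2 / 2, hence grad phi(x) = x.
    return x

def geodesic_rhs(x, v):
    # Acceleration from the geodesic equation for g = exp(2 phi) I in 2D:
    # x'' = -Gamma(x)[x', x'] = -2 (grad phi . x') x' + |x'|^2 grad phi.
    gp = grad_phi(x)
    dot = (gp * v).sum(dim=-1, keepdim=True)
    speed_sq = (v * v).sum(dim=-1, keepdim=True)
    return -2.0 * dot * v + speed_sq * gp

def rk4_shoot(x0, v0, num_steps=200):
    # Classical RK4 on the first-order system (x, v)' = (v, geodesic_rhs(x, v)) over t in [0, 1].
    # Works on batches of shape (batch, 2). Returns (x(1), x'(1)); x(1) approximates Exp_{x0}(v0).
    x, v = x0.clone(), v0.clone()
    h = 1.0 / num_steps
    for _ in range(num_steps):
        k1x, k1v = v, geodesic_rhs(x, v)
        k2x, k2v = v + 0.5 * h * k1v, geodesic_rhs(x + 0.5 * h * k1x, v + 0.5 * h * k1v)
        k3x, k3v = v + 0.5 * h * k2v, geodesic_rhs(x + 0.5 * h * k2x, v + 0.5 * h * k2v)
        k4x, k4v = v + h * k3v, geodesic_rhs(x + h * k3x, v + h * k3v)
        x = x + h / 6 * (k1x + 2 * k2x + 2 * k3x + k4x)
        v = v + h / 6 * (k1v + 2 * k2v + 2 * k3v + k4v)
    return x, v

def energy(x, v):
    # g_x(v, v) = exp(2 phi(x)) |v|^2 = exp(|x|^2) |v|^2, constant along exact geodesics.
    return torch.exp((x * x).sum(dim=-1)) * (v * v).sum(dim=-1)

x0 = torch.tensor([[0.1, -0.2], [0.0, 0.0]], dtype=torch.float64)
v0 = torch.tensor([[0.3, 0.4], [0.5, 0.0]], dtype=torch.float64)
x1, v1 = rk4_shoot(x0, v0)
print(x1)                              # approximate Exp_{x0}(v0) for both initial conditions
print(energy(x0, v0), energy(x1, v1))  # should agree to many digits
```

The second initial condition is radial, so by symmetry its geodesic stays on the horizontal axis.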

In [None]:
geodesics_y_i,_ = model.connecting_geodesic(y[:-1].reshape(-1,2), y[1:].reshape(-1,2))
geodesics_y_i2plot = geodesics_y_i(t).reshape((num_points-1),num_geodesics,num_points,2).detach()

In [None]:
plt.title("Images of the log-chart straight lines under the exp maps at the corresponding base points")
for i in range(num_geodesics):
    plt.plot(y[:,i,0],y[:,i,1])
    plt.scatter(y[:,i,0],y[:,i,1])
    #plt.plot(geodesics_y_i2plot[:,i,:,0],geodesics_y_i2plot[:,i,:,1],c="black")
    #plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
log_straight_lines_length_approx = model.curve_length(geodesics_y_i(t)).reshape(num_points-1,num_geodesics).sum(dim = 0)
print(f"straight_lines_length_approx: {log_straight_lines_length_approx}")
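For reference, the length functional $L[\gamma]$ from the introduction can be approximated on a sampled curve by summing $\sqrt{\Delta x^\top G(\bar x)\,\Delta x}$ over consecutive samples, with the metric matrix $G$ evaluated at segment midpoints $\bar x$. The notebook instead sums `model.curve_length` over short connecting geodesics; the sketch below is a simpler midpoint-rule alternative, shown with a hypothetical `euclidean_metric` callable so that the results can be checked by hand.

```python
import math
import torch

def euclidean_metric(x):
    # Hypothetical metric field G(x); here the identity at every point, shape (..., 2, 2).
    return torch.eye(2, dtype=x.dtype).expand(*x.shape[:-1], 2, 2)

def discrete_curve_length(points, metric_fn):
    # points: (num_samples, 2) samples gamma(t_0), ..., gamma(t_{n-1}) of a curve.
    # Midpoint rule for L[gamma]: sum over segments of sqrt(dx^T G(midpoint) dx).
    dx = points[1:] - points[:-1]
    mid = 0.5 * (points[1:] + points[:-1])
    G = metric_fn(mid)
    seg_sq = torch.einsum("ni,nij,nj->n", dx, G, dx)
    return seg_sq.clamp_min(0).sqrt().sum()

# Straight segment from (0, 0) to (3, 4): exact length 5.
t = torch.linspace(0, 1, 11, dtype=torch.float64)
segment = torch.stack([3 * t, 4 * t], dim=-1)
print(discrete_curve_length(segment, euclidean_metric))

# Quarter of the unit circle: the length approaches pi / 2 as the sampling is refined.
theta = torch.linspace(0, math.pi / 2, 1000, dtype=torch.float64)
arc = torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)
print(discrete_curve_length(arc, euclidean_metric))
```

With a non-Euclidean `metric_fn`, the same function approximates the Riemannian length; its accuracy depends on how finely the curve is sampled.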

In [None]:
random_geodesic_lengths/log_straight_lines_length_approx

In [None]:
random_geod_length_ratio = (random_geodesic_lengths/log_straight_lines_length_approx).mean().item()
print(f"geodesic length ratio:\n{random_geod_length_ratio}")
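Written out, the printed ratio is as follows, where $N$ is `num_geodesics`, $t_0 = 0 < \dots < t_{P-1} = 1$ are the `num_points` parameter values, $b_i$ is the base point paired with the geodesic $\gamma_i$, and $\sigma_{k,i}$ is the geodesic connecting $y_{k,i}$ and $y_{k+1,i}$:

\begin{align*}
    y_{k,i} &= \operatorname{Exp}_{b_i}\!\big((1-t_k)\,\operatorname{Log}_{b_i}\gamma_i(0) + t_k\,\operatorname{Log}_{b_i}\gamma_i(1)\big) \ , \\
    \tilde L_i &= \sum_{k=0}^{P-2} L[\sigma_{k,i}] \ , \\
    \rho &= \frac{1}{N}\sum_{i=1}^{N} \frac{L[\gamma_i]}{\tilde L_i} \ .
\end{align*}

A value of $\rho$ close to 1 means that, for these samples, straight lines in the logarithmic charts are nearly as short as the geodesics.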

In [None]:
fig,(ax1,ax2) = plt.subplots(ncols=2,figsize=(12,6))
fig.suptitle(fr"Experiment # {exp_number} with $\lambda_{{\mathrm{{curv}}}}={curv_w}.$")
ax1.set_title("Random geodesics and basepoints")
ax2.set_title("Images of these geodesics through logmaps \n w.r.t. corresponding base points")
for i in range(num_geodesics):
    p = ax1.plot(random_geodesics2plot[i,:,0],random_geodesics2plot[i,:,1])
    automatic_color = p[-1].get_color()  
    ax1.scatter(base_points[i,0],base_points[i,1],c = automatic_color)
    ax2.scatter(random_geodesics2plot_logmap[i,:,0],random_geodesics2plot_logmap[i,:,1],c = automatic_color)
    ax2.plot(random_geodesics2plot_logmap[i,:,0],random_geodesics2plot_logmap[i,:,1],c = automatic_color)
fig.text(0.1,0,f"Geodesic length ratio: {random_geod_length_ratio:.4f}")
if violent_saving == True:
    plt.savefig(f'{Path_pictures}/random_geodesics_exp{exp_number}.pdf',bbox_inches='tight',format='pdf')
plt.show()

In [None]:
# the accuracy here has to be less than the threshold in the algorithm
results = {"geod_length_ratio": random_geod_length_ratio}
if violent_saving == True:
    with open(f'{Path_pictures}/random_geodesic_length_ratio_exp{exp_number}.json', 'w') as json_file:
        json.dump(results, json_file, indent=4)

In [None]:
from pypdf import PdfWriter
build_report = False
if build_report == True:
    pdfs = [f'{Path_pictures}/multiple_geodesics_exp{exp_number}.pdf',f'{Path_pictures}/geodesic_grid_logmap_exp{exp_number}.pdf',f"{Path_pictures}/random_geodesics_exp{exp_number}.pdf"]

    merger = PdfWriter()

    for pdf in pdfs:
        merger.append(pdf)

    merger.write(f"{Path_pictures}/report_exp_{exp_number}.pdf")
    merger.close()

# 8. Fréchet mean (slow, via a Sturm-like method).

In [None]:
num_cluster_points = 3
cluster = torch.pi*(torch.rand(num_cluster_points,2)-0.5)
plt.scatter(cluster[:,0],cluster[:,1])
plt.show()

In [None]:
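# Sturm-like inductive mean: after i points, the running estimate is moved along the
# geodesic towards the (i+1)-th point by the fraction 1/(i+1) of its length.
frechet_mean = cluster[0]
for i in range(1,num_cluster_points):
    geodesic,_ = model.connecting_geodesic(frechet_mean, cluster[i])
    frechet_mean = geodesic(torch.tensor([1 / (i + 1)]))
frechet_mean = frechet_mean.detach()
print("frechet_mean:", frechet_mean)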
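In a flat space the recursion in the cell above reproduces the arithmetic mean exactly, since with $m_1 = x_1$ the update $m_{k+1} = m_k + \tfrac{1}{k+1}(x_{k+1} - m_k)$ is the running-average formula. The sketch below checks this with a hypothetical `euclidean_geodesic_point` helper that returns points on straight segments; plugging in a manifold geodesic gives the inductive (Sturm-like) estimate. In Sturm's scheme the data points are drawn at random and the recursion is iterated for many steps, so a single pass, as above, is in general only an approximation of the Fréchet mean on a curved space.

```python
import torch

def inductive_mean(points, geodesic_point):
    # points: (n, d). geodesic_point(p, q, s) returns gamma(s) on the geodesic with gamma(0) = p, gamma(1) = q.
    # After i points, the estimate is moved towards points[i] by the fraction 1 / (i + 1).
    mean = points[0]
    for i in range(1, points.shape[0]):
        mean = geodesic_point(mean, points[i], 1.0 / (i + 1))
    return mean

def euclidean_geodesic_point(p, q, s):
    # Hypothetical helper: in flat space the geodesic is the straight segment p + s (q - p).
    return p + s * (q - p)

torch.manual_seed(0)
cluster = torch.pi * (torch.rand(3, 2) - 0.5)
print(inductive_mean(cluster, euclidean_geodesic_point))
print(cluster.mean(dim=0))  # same point, up to rounding
```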

In [None]:
frechet_mean = frechet_mean.squeeze()
plt.scatter(cluster[:,0],cluster[:,1])
plt.scatter(frechet_mean[0],frechet_mean[1], c = "red",marker = "*",s=200)
plt.show()