This notebook provides the first naive benchmark for the geodesics in the latent space of the autoencoder.

The benchmark evaluates the effect of curvature regularization by measuring how far straight lines are from being geodesics.
\medskip

Let us consider straight line segments $\alpha_{ij} \subset M \subset \mathbb{R}^2$ connecting pairs of points $\Phi_{\theta}(X_i)$
and $\Phi_{\theta}(X_j)$, the latent images of $X_i$ and $X_j$, given by:
$$
\alpha_{ij} ( t ) = t \Phi_{\theta}(X_i) + (1-t) \Phi_{\theta}(X_j) \ .
$$
Let us consider the images of $\alpha_{ij}$ through the decoder $\Psi$. One obtains curves in $\mathbb{R}^D$ given by:
$$
\gamma_{ij} ( t ) = \Psi(t \Phi_{\theta}(X_i) + (1-t) \Phi_{\theta}(X_j)) \ .
$$

For a curve $\gamma_{ij}(t)$ one can consider several functionals.
The first functional integrates the squared norm of the second derivative of $\gamma_{ij}(t)$. If $t$ were a natural parameter, this functional would compute the norm of the acceleration of the curve along the curve. In general it does not, because the metric on $M$ is not Euclidean and thus $t$ is not a natural parameter.
\begin{equation}
    \text{1. }\widetilde E_{ij} = \int\limits_0^1 \|\gamma_{ij}'' (t)\|^2 dt \ .
\end{equation}
The energy functional of $\alpha_{ij} ( t )$:
\begin{equation}
    \text{2. } E_{ij} = \int\limits_0^1 \|\alpha_{ij}' (t)\|_g^2 dt = \int\limits_0^1 \|\gamma_{ij}' (t)\|^2 dt \ ,
\end{equation}
recall $g = J_\Psi^* J_\Psi$ is the pull-back of the Euclidean metric by the decoder $\Psi$.

And finally, the acceleration functional:
\begin{equation}
    \text{3. } A_{ij} = \int\limits_0^1 \| \nabla_{\alpha_{ij}' (t)} \alpha_{ij}' (t) \|^2 dt \ ,
\end{equation}
All the functionals are averaged over all pairs of points, for example:
\begin{equation}
    \mathcal{E} = \frac{1}{\binom{K}{2}} \sum\limits_{1 \leq i < j \leq K} E_{ij} \ .
\end{equation}

# The first benchmark 

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import ricci_regularization

In [None]:
# Read the experiment configuration. Alternative configurations:
#   '../../experiments/MNIST01_exp8_config.yaml'
#   '../../experiments/Swissroll_exp1_config.yaml'
with open('../../experiments/MNIST_Setting_1_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# When False, plots are not supposed to be saved.
violent_saving = True

# K points are connected pairwise by straight lines
# (K = 100 is a lighter setting; up to about 300 is feasible in practice).
K = 150
# Latent dimension; re-read from the configuration further below.
d = 2
print("number of points to be paiwise connected by straight lines:", K)
print("straight lines constructed", K * (K - 1) // 2)

## Introducing the curve $\gamma (t)$ 

In [None]:
def gamma(t, x, y, decoder):
    """Decode the point of the latent segment between ``x`` and ``y`` at ``t``.

    The latent point is the convex combination ``x*t + y*(1-t)``, so ``t = 1``
    gives ``decoder(x)`` and ``t = 0`` gives ``decoder(y)``.

    Args:
        t: interpolation parameter (scalar or anything broadcastable with x, y).
        x: latent end point reached at ``t = 1``.
        y: latent end point reached at ``t = 0``.
        decoder: callable applied to the interpolated latent point.

    Returns:
        Whatever ``decoder`` returns for the interpolated latent point.
    """
    latent_point = x * t + y * (1 - t)
    return decoder(latent_point)

\begin{equation}
\gamma ' ( t ) \approx \frac{ \gamma ( t + h ) - \gamma ( t - h ) }{ 2 h }
\end{equation}



In [None]:
def gamma_prime(t, h, x, y, decoder):
    """Central finite-difference estimate of the first derivative of gamma.

    Computes ``(gamma(t + h) - gamma(t - h)) / (2 * h)`` for the decoded
    segment between the latent points ``x`` and ``y``.

    Args:
        t: parameter value(s) at which the derivative is estimated.
        h: finite-difference step.
        x: latent end point reached at ``t = 1``.
        y: latent end point reached at ``t = 0``.
        decoder: decoder forwarded to ``gamma``; it is required there, so it
            has to be passed through (the previous version omitted it and
            every call failed with a TypeError).

    Returns:
        The finite-difference quotient, with the shape of ``gamma``'s output.
    """
    forward = gamma(t + h, x, y, decoder=decoder)
    backward = gamma(t - h, x, y, decoder=decoder)
    return (forward - backward) / (2 * h)

\begin{equation}
\gamma''(t) \approx \frac{\gamma(t+h) - 2\gamma(t) + \gamma(t-h)}{h^2}
\end{equation}

In [None]:
def gamma_second(t, h, x, y, decoder):
    """Central finite-difference estimate of the second derivative of gamma.

    Args:
        t: parameter value(s) at which the derivative is estimated.
        h: finite-difference step.
        x: latent end point reached at ``t = 1``.
        y: latent end point reached at ``t = 0``.
        decoder: decoder forwarded to ``gamma``.

    Returns:
        ``(gamma(t + h) - 2 * gamma(t) + gamma(t - h)) / h**2``.
    """
    ahead = gamma(t + h, x, y, decoder=decoder)
    here = gamma(t, x, y, decoder=decoder)
    behind = gamma(t - h, x, y, decoder=decoder)
    return (ahead - 2 * here + behind) / (h ** 2)

 $$
 \widetilde E_{ij} \approx \frac{1}{2 (n-1)} \sum\limits_{k=0}^{n-2} \left( \| \gamma''_{ij}(t_k) \| + \| \gamma''_{ij}(t_{k+1}) \| \right)
 $$

In [None]:
def E_tilde(x_i, x_j, n_partition, decoder):
    """Trapezoidal approximation of the second-derivative functional.

    The segment parameter is sampled at ``t_k = k / (n - 1)``, ``k = 0..n-1``;
    the second derivative of the decoded curve is estimated at every ``t_k``
    with step ``h = 1 / n`` and the norms are combined with the trapezoid rule.

    Args:
        x_i: latent end point reached at ``t = 1``.
        x_j: latent end point reached at ``t = 0``.
        n_partition: number of partition points ``n`` of ``[0, 1]``.
        decoder: decoder forwarded to ``gamma_second``.

    Returns:
        A 0-dim tensor (kept as a tensor because ``.item()`` does not work
        under ``vmap``).
    """
    n = n_partition
    # Column vector of parameter values: a trailing dimension of size 1 makes
    # every latent coordinate use the SAME t. Reshaping to (-1, d) instead
    # paired consecutive t values with different coordinates, so the sampled
    # points were off the segment and only n / d of them were evaluated.
    segment_partition = (1 / (n - 1)) * torch.arange(n, dtype=torch.float32)
    times = segment_partition.reshape(-1, 1)
    gamma_second_array = gamma_second(times, h=1 / n, x=x_i, y=x_j,
                                      decoder=decoder)
    # assumes the decoder output is (n, D), one row per t_k -- TODO confirm
    gamma_second_norm_array = gamma_second_array.norm(dim=1)
    # NOTE(review): the integrand is the norm, not the squared norm used in
    # the definition of the functional at the top of the notebook -- confirm
    # which one is intended.
    trapezoid_terms = gamma_second_norm_array[:-1] + gamma_second_norm_array[1:]
    E_ij = (0.5 / (n - 1)) * torch.sum(trapezoid_terms)
    return E_ij

 $$
 E_{ij} \approx \frac{1}{2} \sum_{k=0}^{n-1} \left\| \gamma_{ij}(t_k+h) - \gamma_{ij}(t_k-h) \right\|_2^2 \ , \qquad t_k = \frac{k}{n-1} \ , \quad h = \frac{1}{n}
 $$

In [None]:
def delta_gamma(t, h, x, y, decoder):
    """Return the symmetric increment ``gamma(t + h) - gamma(t - h)``.

    Args:
        t: parameter value(s) around which the increment is taken.
        h: half-width of the increment.
        x: latent end point reached at ``t = 1``.
        y: latent end point reached at ``t = 0``.
        decoder: decoder forwarded to ``gamma``.
    """
    ahead = gamma(t + h, x, y, decoder)
    behind = gamma(t - h, x, y, decoder)
    return ahead - behind

In [None]:
def E(x_i, x_j, n_partition, decoder):
    """Riemann-sum approximation of the energy of a decoded segment.

    The segment parameter is sampled at ``t_k = k / (n - 1)``, ``k = 0..n-1``,
    and the result is ``0.5 * sum_k ||gamma(t_k + h) - gamma(t_k - h)||^2``
    with ``h = 1 / n``.

    Args:
        x_i: latent end point reached at ``t = 1``.
        x_j: latent end point reached at ``t = 0``.
        n_partition: number of points ``n`` in the partition of ``[0, 1]``.
        decoder: decoder forwarded to ``delta_gamma``.

    Returns:
        A 0-dim tensor (kept as a tensor because ``.item()`` does not work
        under ``vmap``).
    """
    n = n_partition
    # Column vector of parameter values: a trailing dimension of size 1 makes
    # every latent coordinate use the SAME t. Reshaping to (-1, d) instead
    # paired consecutive t values with different coordinates, so the sampled
    # points were off the segment and only n / d of them were evaluated.
    segment_partition = (1 / (n - 1)) * torch.arange(n, dtype=torch.float32)
    times = segment_partition.reshape(-1, 1)
    delta_gamma_array = delta_gamma(times, h=1 / n, x=x_i, y=x_j,
                                    decoder=decoder)
    # assumes the decoder output is (n, D), one row per t_k -- TODO confirm
    delta_gamma_array_norm_array = delta_gamma_array.norm(dim=1)
    E_ij = 0.5 * torch.sum(delta_gamma_array_norm_array ** 2)
    return E_ij

### vmap vectorization

In [None]:
# Batched versions of the two functionals: positional arguments (the start and
# end points) are mapped along dimension 0, keyword arguments are passed
# through unchanged.
E_tilde_vmap = torch.func.vmap(E_tilde, in_dims=0)
E_vmap = torch.func.vmap(E, in_dims=0)

In [None]:
import math

def make_pairs(batch_of_points):
    """Return the start and end points of all unordered pairs of a batch.

    Pairs ``(i, j)`` with ``i < j`` are enumerated in row-major order
    ``(0, 1), (0, 2), ..., (0, n-1), (1, 2), ...``.

    Args:
        batch_of_points: tensor whose first dimension indexes the points.

    Returns:
        ``(start_points, end_points)``: two tensors with ``n * (n - 1) / 2``
        rows each, where row ``p`` holds the first / second point of pair
        ``p``. Both are empty when the batch has fewer than two points.
    """
    # Use the size of the batch itself rather than the global K, so the
    # function is correct for any batch.
    n_points = batch_of_points.shape[0]
    # Strict upper triangle of the (n, n) index grid = all pairs with i < j.
    pair_index = torch.triu_indices(n_points, n_points, offset=1,
                                    device=batch_of_points.device)
    start_points = batch_of_points[pair_index[0]]
    end_points = batch_of_points[pair_index[1]]
    return start_points, end_points


# Loading data and nn weights

In [None]:
# Build the data loaders described by the YAML configuration.
# The result is stored as `dataloaders`: the previous name `dict` shadowed the
# builtin and broke any later `dict(...)` call in the notebook.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
# Assumes 'test_dataset' is a key returned by get_dataloaders; None otherwise.
test_dataset = dataloaders.get("test_dataset")

print("Data loaders created successfully.")
# Prefix prepended to experiment paths below.
additional_path="../"

In [None]:
# Experiment name and the directory that receives plots and reports.
experiment_name = yaml_config["experiment"]["name"]

# Alternative: Path_pictures = yaml_config["experiment"]["path"]
Path_pictures = additional_path + "../experiments/" + yaml_config["experiment"]["name"]
if violent_saving:
    # Make sure the output directory exists and report what happened.
    if os.path.exists(Path_pictures):
        print(f"Directiry already exists: {Path_pictures}")
    else:
        os.mkdir(Path_pictures)
        print(f"Created directory: {Path_pictures}")

# Curvature weight read from the loss settings.
curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
# D is the dimension of the dataset.
D = yaml_config["architecture"]["input_dim"]
if dataset_name in ["MNIST01", "Synthetic"]:
    # k is the number of classes, i.e. the number of selected labels
    # (alternative: k = yaml_config["dataset"]["k"]).
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    k = 10
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
# Size of the full test set.
l = len(test_dataset)

# Randomly pick K test samples; `rest` holds the remaining ones.
first_benchmark_data, rest = torch.utils.data.random_split(test_dataset, [K, l - K])

# Keep entry 0 of every sample (presumably the input, not the label -- confirm),
# stack the samples and flatten each one to D features.
first_benchmark_data = torch.stack(
    [sample[0] for sample in first_benchmark_data]
).reshape(-1, D)

# Latent dimension from the configuration.
d = yaml_config["architecture"]["latent_dim"]

loading the weights

In [None]:
# To load a different experiment, re-read the configuration first, e.g.:
#   with open('../../experiments/Swissroll_exp0_config.yaml', 'r') as yaml_file:
#       yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Fetch the tuned autoencoder together with the path of its weights.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(
    config=yaml_config,
    additional_path=additional_path,
)

# All computations below run on the CPU.
torus_ae = torus_ae.to("cpu")

print(f"AE weights loaded successfully from {Path_ae_weights}.")

# Shorthands for the two halves of the autoencoder.
encoder, decoder = torus_ae.encoder_torus, torus_ae.decoder_torus

# Plotting the curves $\alpha_{ij}$

In [None]:
# Fixed seed so the sampled points are reproducible.
torch.manual_seed(0)
# Points are drawn uniformly from a square box with this side and centre.
square_side = 2*torch.pi
center_of_square = torch.zeros(2)

random_points_latent_space = center_of_square + square_side * (torch.rand(K, 2) - 0.5)
# NumPy copy of the points for plotting.
points_np = random_points_latent_space.numpy()

# Scatter plot of the sampled points.
plt.figure(figsize=(8, 6))
first_coordinate, second_coordinate = points_np[:, 0], points_np[:, 1]
plt.scatter(first_coordinate, second_coordinate, c='blue', marker='o', edgecolor='k')
plt.title('Random Points in Latent Space')
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')
plt.grid(True)
plt.show()

In [None]:
# Use the same uniformly distributed random points, independent of the encoder.
start_points_latent_space, end_points_latent_space = make_pairs(random_points_latent_space)
# Alternative: points that depend on the encoder:
#start_points, end_points = make_pairs(first_benchmark_data)
#start_points_latent_space = encoder(start_points).detach()
#end_points_latent_space = encoder(end_points).detach()

plt.figure(figsize=(8, 6))
plt.scatter(start_points_latent_space[:, 0], start_points_latent_space[:, 1], color='blue', label='Start Points')
plt.scatter(end_points_latent_space[:, 0], end_points_latent_space[:, 1], color='red', label='End Points')
# Draw all segments with a single plot call: for 2D inputs every column is a
# separate line, so column p joins start point p with end point p.
# The colour is given once (the former 'k--' format string also set a colour
# that conflicted with color='blue').
segment_x = torch.stack((start_points_latent_space[:, 0], end_points_latent_space[:, 0])).numpy()
segment_y = torch.stack((start_points_latent_space[:, 1], end_points_latent_space[:, 1])).numpy()
plt.plot(segment_x, segment_y, linestyle='--', color='blue')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title(r'Lines $\alpha_{ij}$ Connecting Points in the latent space')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Evaluate both functionals on every straight line, using 100 partition points.
E_tilde_array = E_tilde_vmap(
    start_points_latent_space,
    end_points_latent_space,
    n_partition=100,
    decoder=decoder,
)
Energy_array = E_vmap(
    start_points_latent_space,
    end_points_latent_space,
    n_partition=100,
    decoder=decoder,
)

# Alternative for end points given in data space (encoded on the fly):
#E_tilde_array = E_tilde_vmap(encoder(start_points),encoder(end_points),n_partition=100,decoder = decoder)
#Energy_array = E_vmap(encoder(start_points),encoder(end_points),n_partition=100,decoder = decoder)
#Distance_ls_pairwize_array_no_curv_pen = (encoder(end_points) - encoder(start_points)).norm(dim=1)
#Distance_RD_pairwize_array_no_curv_pen = (end_points - start_points).norm(dim=1)

# Histograms

Functionals $\widetilde E_{ij}$.

In [None]:
# Histogram of the E_tilde values over all straight lines.
plt.figure(figsize=(10, 6))
plt.hist(E_tilde_array.detach(), bins=30, edgecolor='black', alpha=0.7)

# Axis labels (raw string: '\w' is not a valid escape sequence).
plt.xlabel(r'Functional $\widetilde E_{ij}$ Values')
plt.ylabel('Frequency')
#plt.title(r'Histogram of energy functional $E_{ij}$ values')

# Show grid for better readability
plt.grid(True)
# Honour the saving switch: the output directory is only guaranteed to exist
# when violent_saving is enabled.
if violent_saving:
    plt.savefig(Path_pictures+f"/Histogram_E_tilde_ij_{experiment_name}.pdf",bbox_inches='tight', format = "pdf")
# Display the plot
plt.show()

Statistical analysis

In [None]:

# Summary statistics of the functional E_tilde over all straight lines.
samples = E_tilde_array.detach()  # Detach from the computation graph
N = len(samples)  # Number of samples (straight lines)
mean_value = samples.mean().item()  # Mean value of E_tilde
std_dev = torch.std(samples).item()  # Standard deviation
SE = std_dev / math.sqrt(N)  # Standard error (SE)
# printing here
print("Number of samples):", N)
print(f"Mean value of funtional E_tilde: {mean_value:.3f}")
print(f"Std of funtional E_tilde: {std_dev:.3f}")
print(f"Standard error of mean (SE): {SE:.3f}")
# Define the path to the output file
output_file_path = f"{Path_pictures}/statistical_analysys_E_tilde_{experiment_name}.txt"

# Save the results to the text file. The labels name E_tilde: they previously
# said "energy functional E", a copy-paste slip from the report for E.
with open(output_file_path, 'w') as f:
    f.write(f"Number of straight lines (samples): {N}\n")
    f.write(f"Mean value of functional E_tilde: {mean_value:.3f}\n")
    f.write(f"Standard deviation of E_tilde: {std_dev:.3f}\n")
    f.write(f"Standard error of the mean (SE): {SE:.3f}\n")

print(f"Results saved to {output_file_path}")

Functionals $ E_{ij}$.

In [None]:
# Histogram of the energy values E_ij over all straight lines.
plt.figure(figsize=(10, 6))
plt.hist(Energy_array.detach(), bins=30, edgecolor='black', alpha=0.7)

# Add labels and title
plt.xlabel('Energy functional $E_{ij}$ Values')
plt.ylabel('Frequency')
#plt.title(r'Histogram of energy functional $E_{ij}$ values')

# Show grid for better readability
plt.grid(True)
# Honour the saving switch: the output directory is only guaranteed to exist
# when violent_saving is enabled.
if violent_saving:
    plt.savefig(Path_pictures+f"/Histogram_E_ij_{experiment_name}.pdf",bbox_inches='tight', format = "pdf")
# Display the plot
plt.show()

Statistical analysis

In [None]:
# Summary statistics of the energy functional E over all straight lines.
samples = Energy_array.detach()      # values without the autograd graph
N = samples.shape[0]                 # number of straight lines
mean_value = samples.mean().item()
std_dev = samples.std().item()
SE = std_dev / math.sqrt(N)          # standard error of the mean

# Console report.
print("Number of straight lines(samples):", N)
print(f"Mean value of energy funtional E: {mean_value:.3f}")
print(f"Std of energy funtional E: {std_dev:.3f}")
print(f"Standard error of mean (SE): {SE:.3f}")

# The same figures are written to a text file next to the plots.
output_file_path = f"{Path_pictures}/statistical_analysys_E_{experiment_name}.txt"

with open(output_file_path, 'w') as f:
    f.writelines([
        f"Number of straight lines (samples): {N}\n",
        f"Mean value of energy functional E: {mean_value:.3f}\n",
        f"Standard deviation of E: {std_dev:.3f}\n",
        f"Standard error of the mean (SE): {SE:.3f}\n",
    ])

print(f"Results saved to {output_file_path}")

here the mess starts

In [None]:
# Legacy cell: functionals for weights trained WITH a curvature penalty.
# NOTE(review): the weight files are Swiss-roll checkpoints while the
# configuration loaded above is an MNIST one -- confirm they are compatible.
load_weight_name = "swissroll_curv_w=1_ls=R^2"
#load_weight_name = "swissroll_curv_w=10_ls=R^2"
#load_weight_name = "swissroll_curv_w=10_ls=R^2_20epochs_bs=32"
PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
encoder.load_state_dict(torch.load(PATH_enc))
encoder.eval()
PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
decoder.load_state_dict(torch.load(PATH_dec))
decoder.eval()

# make_pairs needs the batch of points (it was called without arguments); the
# benchmark samples are used, as in the commented alternative of the plotting cell.
start_points, end_points = make_pairs(first_benchmark_data)

# Encode each set of end points once and reuse the result.
encoded_start_points = encoder(start_points)
encoded_end_points = encoder(end_points)

# E requires the decoder, which was missing from the original call.
Energy_pairwise_array_with_curv_pen = E_vmap(encoded_start_points, encoded_end_points, n_partition=100, decoder=decoder)
Distance_ls_pairwize_array_with_curv_pen = (encoded_end_points - encoded_start_points).norm(dim=1)
Distance_RD_pairwize_array_with_curv_pen = (end_points - start_points).norm(dim=1)
# Alternative when start_points / end_points are latent points:
#Energy_pairwise_array_with_curv_pen = E_vmap(start_points,end_points,n_partition=100)
#Distance_ls_pairwize_array_with_curv_pen = (end_points - start_points).norm(dim=1)
#Distance_RD_pairwize_array_with_curv_pen = (decoder(end_points) - decoder(start_points)).norm(dim=1)

In [None]:
# Legacy cell: energies for weights trained with curvature weight 10.
# NOTE(review): the weight files are Swiss-roll checkpoints while the
# configuration loaded above is an MNIST one -- confirm they are compatible.
load_weight_name = "swissroll_curv_w=10_ls=R^2_20epochs_bs=32"
PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
encoder.load_state_dict(torch.load(PATH_enc))
encoder.eval()
PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
decoder.load_state_dict(torch.load(PATH_dec))
decoder.eval()

# make_pairs needs the batch of points (it was called without arguments); the
# benchmark samples are used, as in the commented alternative of the plotting cell.
start_points, end_points = make_pairs(first_benchmark_data)

# E requires the decoder, which was missing from the original call.
Energy_pairwise_array_curv_w10_pen = E_vmap(encoder(start_points), encoder(end_points), n_partition=100, decoder=decoder)


In [None]:
# Display the pairwise distances between the data-space end points that were
# stored by the "with penalty on curvature" cell.
Distance_RD_pairwize_array_with_curv_pen

In [None]:
# Display the pairwise data-space distances for the run without curvature penalty.
# NOTE(review): the only visible assignment to this name is commented out in
# the cell that computes the functionals -- presumably this raises NameError
# unless it is defined elsewhere; confirm.
Distance_RD_pairwize_array_no_curv_pen

# ground truth check

In [None]:
# Ground-truth check: histogram of the difference between the data-space
# distances of the two runs.
# NOTE(review): the "no_curv_pen" array is only assigned in commented-out code
# elsewhere in the notebook -- confirm it is defined before running this cell.
plt.hist((Distance_RD_pairwize_array_no_curv_pen-Distance_RD_pairwize_array_with_curv_pen).detach(), bins = 50)
# Raw string: '\m' is not a valid escape sequence in a normal string literal.
plt.title(r"Distances in $\mathbb{R}^D$ with and without curvature penalization")
plt.show()

In [None]:
# Histogram of the difference between the latent-space distances obtained
# without and with the curvature penalty.
latent_distance_difference = (
    Distance_ls_pairwize_array_no_curv_pen
    - Distance_ls_pairwize_array_with_curv_pen
).detach()
plt.hist(latent_distance_difference, bins=50)
plt.title("Distances in latent space with and without curvature penalization")
plt.show()

## plotting

In [None]:
# Scatter plot of the energies E_ij against the data-space distances for the
# three curvature weights. All LaTeX labels are raw strings: '\l' and '\|' are
# not valid escape sequences in normal string literals.
plt.rcParams.update({'font.size': 24})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
#plt.title(rf"Swissroll: Energy of $C_{{{K}}}^2$ paths $E_{{ij}}$ vs distance in $\mathbb{{R}}^3$")
ax.scatter(Distance_RD_pairwize_array_no_curv_pen.detach().numpy(),Energy_array.detach().numpy(),
            color = "red", s = 10, label = r"$\lambda_{curv} = 0$")
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_with_curv_pen.detach().numpy(),
            color = "blue", s = 10, label = r"$\lambda_{curv} = 1$")
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_curv_w10_pen.detach().numpy(),
            color = "green", s = 10, label = r"$\lambda_{curv} = 10$")
ax.legend(loc='upper left')
# Alternative x labels:
#ax.set_xlabel(r'$\|\Psi \circ \Theta(X_i) - \Psi \circ \Theta(X_j)\|_2$, $\Psi \circ \Theta(X_i)\in \mathbb{R}^3$')
#ax.set_xlabel(r'$\|X_i - X_j\|_2$, $X_i \in \mathbb{R}^3$')
ax.set_xlabel(r'$\|X_i - X_j\|_2$')
ax.set_ylabel('$E_{ij}$')
fig.savefig("Scatterplot_E_ij_dist_R3_swissroll.pdf",bbox_inches='tight', format = "pdf")
plt.show()

In [None]:
# Scatter plot of the change in energy (with minus without curvature penalty)
# against the data-space distances.
plt.rcParams.update({'font.size': 24})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
#plt.title(rf"Swissroll: enegry of paths change $E_{{ij}}^{{\lambda_{{curv}} = 1}} - E_{{ij}}^{{\lambda_{{curv}} = 0}}$ \n vs distance in $\mathbb{{R}}^3$ for $C_{{{K}}}^2$ paths")
# The distances are converted like every other scatter argument in the
# notebook instead of being passed as a raw tensor.
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),
           (Energy_pairwise_array_with_curv_pen - Energy_array).detach().numpy(),
            color = "magenta", s = 10)
# Alternative x labels:
#ax.set_xlabel(r'$\|\Psi \circ \Theta(X_i) - \Psi \circ \Theta(X_j)\|_2$, $\Psi \circ \Theta(X_i)\in \mathbb{R}^3$')
#ax.set_xlabel(r'$\|X_i - X_j\|_2$, $X_i \in \mathbb{R}^3$')
# Raw string: '\|' is not a valid escape sequence in a normal string literal.
ax.set_xlabel(r'$\|X_i - X_j\|_2$')
ax.set_ylabel('Change in $E_{ij}$')
#ax.set_ylabel(r'$E_{ij}^{\lambda_{curv} = 1} - E_{ij}^{\lambda_{curv} = 0}$')
fig.savefig("Scatterplot_change_E_ij_dist_R3_swissroll.pdf",bbox_inches='tight', format = "pdf")
plt.show()

### logscale

In [None]:
# Energies against data-space distances with a logarithmic y axis.
# All LaTeX strings are raw ('\m', '\l', '\|' are not valid escape sequences);
# the title stays an f-string so K is still interpolated.
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
plt.title(rf"Swissroll: Energy of $C_{{{K}}}^2$ paths $E_{{ij}}$ vs distance in $\mathbb{{R}}^3$")
ax.scatter(Distance_RD_pairwize_array_no_curv_pen.detach().numpy(),Energy_array.detach().numpy(),
            color = "red", s = 10, label = r"$\lambda_{curv} = 0$")
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_with_curv_pen.detach().numpy(),
            color = "blue", s = 10, label = r"$\lambda_{curv} = 1$")
ax.legend(loc='upper left')
#ax.set_xlabel(r'$\|\Psi \circ \Theta(X_i) - \Psi \circ \Theta(X_j)\|_2$, $\Psi \circ \Theta(X_i)\in \mathbb{R}^3$')
ax.set_xlabel(r'$\|X_i - X_j\|_2$, $X_i \in \mathbb{R}^3$')
ax.set_ylabel('$E_{ij}$')
#ax.set_xscale('log')
ax.set_yscale('log')
fig.savefig("Scatterplot_E_ij_dist_R3_swissroll_ylogscale.pdf", format = "pdf")
plt.show()

In [None]:
# Energies against latent-space distances, without and with curvature penalty.
# LaTeX labels are raw strings: '\l', '\|', '\T' and '\m' are not valid escape
# sequences in normal string literals.
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
plt.title(f"Swissroll: Energy of $C_{{{K}}}^2$ paths $E_{{ij}}$ vs distance in latent space")
ax.scatter(Distance_ls_pairwize_array_no_curv_pen.detach().numpy(),Energy_array.detach().numpy(),
            color = "red", s = 10, label = r"$\lambda_{curv} = 0$")
ax.scatter(Distance_ls_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_with_curv_pen.detach().numpy(),
            color = "blue", s = 10, label = r"$\lambda_{curv} = 1$")
ax.legend(loc='upper left')
ax.set_xlabel(r'$\|\Theta(X_i) - \Theta(X_j)\|_2$, $\Theta(X_i) \in \mathbb{R}^2$')
ax.set_ylabel('$E_{ij}$')
fig.savefig("Scatterplot_E_ij_dist_ls_swissroll.pdf", format = "pdf")
plt.show()