This notebook provides Evaluations 1-4 for the geodesics in the latent space of the autoencoder. The results are used in Sections 4.5.2 and 4.5.3 of my thesis.

Evaluation 5 is still a work in progress.

The benchmark evaluates the effect of curvature regularization by measuring how far straight lines are from being geodesics.

Let us consider straight line segments $\alpha_{ij} \subset M \subset \mathbb{R}^2$ connecting the latent representations $\Phi_{\theta}(X_i)$ and $\Phi_{\theta}(X_j)$ of pairs of points $X_i$
and $X_j$, given by:
$$
\alpha_{ij} ( t ) = t \Phi_{\theta}(X_i) + (1-t) \Phi_{\theta}(X_j) \ .
$$
Let us consider the images of $\alpha_{ij}$ through the decoder $\Psi$. One obtains curves in $\mathbb{R}^D$ given by:
$$
\gamma_{ij} ( t ) = \Psi(t \Phi_{\theta}(X_i) + (1-t) \Phi_{\theta}(X_j)) \ .
$$

For a curve $\gamma_{ij}(t)$ one can consider several functionals.
The first functional integrates the squared norm of the second derivative of $\gamma_{ij}(t)$. If $t$ were a natural parameter, it would measure the acceleration along the curve. It does not, however, since the metric on $M$ is not Euclidean and thus $t$ is not a natural parameter.
\begin{equation}
    \text{1. }\widetilde E_{ij} = \int\limits_0^1 \|\gamma_{ij}'' (t)\|^2 dt \ .
\end{equation}
The energy functional of $\alpha_{ij} ( t )$:
\begin{equation}
    \text{2. } E_{ij} = \int\limits_0^1 \|\alpha_{ij}' (t)\|_g^2 dt = \int\limits_0^1 \|\gamma_{ij}' (t)\|^2 dt \ ,
\end{equation}
where $g = J_\Psi^* J_\Psi$ is the pull-back of the Euclidean metric by the decoder $\Psi$.

And finally, the acceleration functional:
\begin{equation}
    \text{3. } A_{ij} = \int\limits_0^1 \| \nabla_{\alpha_{ij}' (t)} \alpha_{ij}' (t) \|^2 dt \ ,
\end{equation}
All the functionals are averaged over the pairs; for the energy, for instance:
\begin{equation}
    \mathcal{E} = \frac{1}{\binom{K}{2}} \sum\limits_{1 \leq i < j \leq K} E_{ij} \ .
\end{equation}
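
The identity $\|\alpha_{ij}'\|_g^2 = \|\gamma_{ij}'\|^2$ in the energy functional follows from $\gamma_{ij}' = J_\Psi \alpha_{ij}'$. A minimal plain-Python sketch, assuming a toy decoder $\Psi(u, v) = (u, v, u^2 + v^2)$ that is not the trained network (all function names here are illustrative), checks it at one point:

```python
# Toy decoder Psi(u, v) = (u, v, u**2 + v**2): an assumption for illustration only.
def jacobian_psi(u, v):
    """Jacobian of the toy decoder, stored as a list of 3 rows of length 2."""
    return [[1.0, 0.0], [0.0, 1.0], [2.0 * u, 2.0 * v]]

def pullback_metric(J):
    """g = J^T J, the 2x2 pull-back of the Euclidean metric."""
    return [[sum(row[a] * row[b] for row in J) for b in range(2)] for a in range(2)]

def norm_sq_g(g, w):
    """Squared norm w^T g w of a latent tangent vector w."""
    return sum(w[a] * g[a][b] * w[b] for a in range(2) for b in range(2))

def norm_sq_pushforward(J, w):
    """Squared Euclidean norm of J w, i.e. of gamma'(t) when alpha'(t) = w."""
    return sum((row[0] * w[0] + row[1] * w[1]) ** 2 for row in J)

J = jacobian_psi(0.3, -0.7)
w = [1.5, 0.4]  # alpha'(t) is constant along a straight segment
lhs = norm_sq_g(pullback_metric(J), w)
rhs = norm_sq_pushforward(J, w)
print(lhs, rhs)  # the two values agree up to rounding
```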

# The first benchmark 

In [None]:
import torch
import matplotlib.pyplot as plt
import os
import yaml
import ricci_regularization

In [None]:
with open('../../experiments/MNIST_Setting_1_config.yaml', 'r') as yaml_file:
#with open('../../experiments/MNIST01_exp7_config.yaml', 'r') as yaml_file:
#with open('../../experiments/Swissroll_exp4_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

violent_saving = False # if False it will not save plots

# number of points to be pairwise connected by straight lines
#K = 100 # this can go up to 300 in practice
K = 10
d = 2
print("number of points to be pairwise connected by straight lines:", K)
print("straight lines constructed", K * (K - 1) // 2)

# Loading data and nn weights

In [None]:
# Load data loaders based on YAML configuration
# Set the random seed for data loader reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

try:
    dtype = getattr(torch, yaml_config["architecture"]["weights_dtype"]) # Convert the string to actual torch dtype
except KeyError:
    dtype = torch.float32
# named `loaders` to avoid shadowing the built-in `dict`
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=dtype
)
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
test_dataset = loaders.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders

print("Data loaders created successfully.")
additional_path="../"

In [None]:
experiment_name = yaml_config["experiment"]["name"]

#Path_pictures = yaml_config["experiment"]["path"]
Path_pictures = additional_path + "../experiments/" + yaml_config["experiment"]["name"]
if violent_saving:
    # Check and create directories based on configuration
    if not os.path.exists(Path_pictures):  # Check if the picture path does not exist
        os.mkdir(Path_pictures)  # Create the directory for plots if not yet created
        print(f"Created directory: {Path_pictures}")  # Print directory creation feedback
    else:
        print(f"Directory already exists: {Path_pictures}")

curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
D = yaml_config["architecture"]["input_dim"]
# D is the dimension of the dataset
if dataset_name in ["MNIST01", "Synthetic"]:
    # k from the JSON configuration file is the number of classes
    #k = yaml_config["dataset"]["k"]
    k = len(yaml_config["dataset"]["selected_labels"])
    selected_labels = yaml_config["dataset"]["selected_labels"]
elif dataset_name == "MNIST":
    k = 10
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
l = len(test_dataset)

first_benchmark_data,rest = torch.utils.data.random_split(test_dataset, [K, l - K])

#reformating
first_benchmark_data = torch.stack([first_benchmark_data[i][0] for i in range(len(first_benchmark_data))])
first_benchmark_data = first_benchmark_data.reshape(-1,D)

d = yaml_config["architecture"]["latent_dim"]

Loading the weights

In [None]:
#with open('../../experiments/Swissroll_exp0_config.yaml', 'r') as yaml_file:
#    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config, additional_path = additional_path)

torus_ae = torus_ae.to("cpu")

print(f"AE weights loaded successfully from {Path_ae_weights}.")

encoder = torus_ae.encoder_torus
decoder = torus_ae.decoder_torus

## Introducing the curve $\gamma (t)$ 

In [None]:
# evaluates gamma at t on the latent segment from x (t = 0) to y (t = 1)
def gamma (t, x, y, decoder):
    return decoder( y * t + x * ( 1 - t ) )

\begin{equation}
\gamma ' ( t ) \approx \frac{ \gamma ( t + h ) - \gamma ( t - h ) }{ 2 h }
\end{equation}



In [None]:
# first derivative of gamma that connects x and y at t 
def gamma_prime(t, h,  x, y, decoder):
    return (gamma(t+h, x, y,decoder) - gamma(t-h, x, y, decoder))/( 2 * h)
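
The central difference is second-order accurate: halving $h$ should divide the error by roughly four. A small plain-Python sketch, using $\sin$ as a stand-in for one coordinate of $\gamma$ (an assumption made only so that the exact derivative is known), illustrates this:

```python
import math

def central_diff(f, t, h):
    """Central difference approximation of f'(t) with step h."""
    return (f(t + h) - f(t - h)) / (2.0 * h)

t0 = 0.2
exact = math.cos(t0)  # derivative of sin at t0

err_h = abs(central_diff(math.sin, t0, 1e-2) - exact)
err_half = abs(central_diff(math.sin, t0, 5e-3) - exact)

# For an O(h^2) scheme the ratio of the two errors is close to 4.
ratio = err_h / err_half
print(ratio)
```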

\begin{equation}
\gamma''(t) \approx \frac{\gamma(t+h) - 2\gamma(t) + \gamma(t-h)}{h^2}
\end{equation}

In [None]:
def gamma_second(t, h,  x, y, decoder ):
    return (gamma(t+h, x, y, decoder = decoder) - 2*gamma(t, x, y, decoder = decoder) + gamma(t-h, x, y, decoder = decoder))/(h**2)

 $$
 \widetilde E_{ij} \approx \frac{1}{2 (n-1)} \sum\limits_{k=0}^{n-2} \left( \| \gamma''_{ij}(t_k) \|^2 + \| \gamma''_{ij}(t_{k+1}) \|^2 \right)
 $$

In [None]:
def E_tilde(x_i,x_j,n_partition, decoder):
    n = n_partition
    segment_partition = ( 1 / (n-1) ) * torch.arange(n, dtype=torch.float32)
    segment_partition_dim_d = segment_partition.repeat(d,1).T
    gamma_second_array = gamma_second (segment_partition_dim_d,
              h = 1/n, x = x_i, y = x_j,
              decoder = decoder)
    gamma_second_norm_array = gamma_second_array.norm(dim=1)
    E_ij = ( 0.5 / ( n - 1 ) ) * torch.sum( (gamma_second_norm_array[:-1]**2 + gamma_second_norm_array[1:]**2) )
    # return E_ij.item() does not work with vmap
    return E_ij
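
The sum used in `E_tilde` is the trapezoidal rule applied to the squared norms on $n$ equally spaced nodes of $[0, 1]$. A plain-Python sketch of the same quadrature, with the hypothetical stand-in values $\|\gamma''(t_k)\| = 2 t_k$ chosen so that the exact integral $4/3$ is known:

```python
def trapezoid_of_squares(values):
    """(0.5 / (n - 1)) * sum_k (v_k^2 + v_{k+1}^2) for n equally spaced nodes on [0, 1]."""
    n = len(values)
    return (0.5 / (n - 1)) * sum(
        values[k] ** 2 + values[k + 1] ** 2 for k in range(n - 1)
    )

n = 101
nodes = [k / (n - 1) for k in range(n)]
values = [2.0 * t for t in nodes]  # stand-in for ||gamma''(t_k)||, not decoder output

approx = trapezoid_of_squares(values)
exact = 4.0 / 3.0  # integral of (2 t)^2 over [0, 1]
print(approx, exact)
```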

\begin{equation}
\begin{aligned}
 E_{ij} 
 &\approx \sum_{k=0}^{N} h \left\| \frac{ \gamma(kh+h) - \gamma(kh-h)}{ 2 h } \right\|_2^2 \\
 &= \frac{1}{4 h} \sum_{k=0}^{N} \left\| \gamma(kh+h) - \gamma(kh-h) \right\|_2^2
\end{aligned}
\end{equation}


In [None]:
def delta_gamma(t, h,  x, y, decoder):
    return (gamma(t+h, x, y, decoder) - gamma(t-h, x, y, decoder))

In [None]:
def E(x_i,x_j,n_partition, decoder):
    # n_partition is number of points in partition for 
    # the integral approximation by its Riemann sum
    n = n_partition
    segment_partition = (1/(n-1))*torch.arange(n,dtype=torch.float32)
    segment_partition_dim_d = segment_partition.repeat(d,1).T
    delta_gamma_array = delta_gamma (segment_partition_dim_d,
              h = 1/ (n + 1), x = x_i, y = x_j,
              decoder = decoder)
    delta_gamma_array_norm_array = delta_gamma_array.norm(dim=1)
    # h = 1/ (N + 1)
    E_ij = 0.25 * (n + 1) * torch.sum( delta_gamma_array_norm_array ** 2 )
    # return E_ij.item() does not work with vmap
    return E_ij
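
The simplification $h \, \| \Delta / (2h) \|^2 = \| \Delta \|^2 / (4h)$ can be checked on a curve with known energy. The sketch below is plain Python and assumes a toy curve $\gamma(t) = (t, t^2)$, whose energy is $7/3$, and interior nodes $t_k = kh$, $k = 1, \dots, N$, with $h = 1/(N+1)$; it is not the notebook's `E`, which evaluates the decoder on a slightly different grid.

```python
def toy_gamma(t):
    """Toy curve in R^2 with gamma'(t) = (1, 2 t); its energy on [0, 1] is 7/3."""
    return (t, t * t)

def energy_riemann(curve, N):
    """(1 / (4 h)) * sum_k ||curve(kh + h) - curve(kh - h)||^2 over interior nodes."""
    h = 1.0 / (N + 1)
    total = 0.0
    for k in range(1, N + 1):
        t = k * h
        right, left = curve(t + h), curve(t - h)
        total += sum((r - l) ** 2 for r, l in zip(right, left))
    return total / (4.0 * h)

approx_energy = energy_riemann(toy_gamma, N=999)
exact_energy = 7.0 / 3.0
print(approx_energy, exact_energy)
```

The remaining gap is the first-order error of the Riemann sum itself; the central differences are exact for this quadratic curve.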

$$A_{ij} = \int_0^1 \| \gamma''_T \|^2 dt
\approx \frac{1}{N - 1} \sum_{k = 1}^N \left\| \mathrm{Proj}_{ \langle v_1, v_2 \rangle }\left( \frac{\gamma(kh+h) - 2\gamma(kh) + \gamma(kh-h)}{h^2} \right) \right\|_2^2 \ ,
$$
where $\gamma''_T$ denotes the tangential component of $\gamma''$, and $v_1$ and $v_2$ form an orthonormal basis of $\langle d \Psi e_1 , d \Psi e_2 \rangle$ at the point $\alpha_{ij}(kh)$.

The second-order central difference $\gamma(t+h) - 2\gamma(t) + \gamma(t-h)$ is computed by:

In [None]:
def delta_gamma_second(t, h,  x, y, decoder):
    gamma_right = gamma(t+h, x, y, decoder)
    gamma_left = gamma(t - h, x, y, decoder)
    gamma_central = gamma(t, x, y, decoder)
    return gamma_left + gamma_right - 2 * gamma_central

In [None]:
# rewrite using delta_second of the gamma
def A(x_i,x_j,n_partition, decoder):
    # n_partition is number of points in partition for 
    # the integral approximation by its Riemann sum
    n = n_partition
    segment_partition = (1/(n-1))*torch.arange(n, dtype=torch.float32)
    segment_partition_dim_d = segment_partition.repeat(d,1).T
    alpha_array = x_j * segment_partition_dim_d + x_i * ( 1 - segment_partition_dim_d )

    dPsi = torch.func.vmap( torch.func.jacfwd(decoder)) 
    Q, _ = torch.linalg.qr(dPsi(alpha_array))  # reduced QR; replaces the deprecated Tensor.qr()
    gamma_second_array = gamma_second (segment_partition_dim_d,
                h = 1/(n - 1), x = x_i, y = x_j,
                decoder = decoder)
    # parallel multiplication of batches of matrices Q and vectors gamma_second
    # evaluated at the points of the segment partition
    A_ij = (1 / ( n - 1 )) * ( torch.matmul(Q.transpose(-1,-2), gamma_second_array.unsqueeze(-1))**2  ).sum()
    return A_ij.detach()
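
The product $Q^\top \gamma''$ gives the coordinates of the tangential component in an orthonormal basis of the tangent plane, so its squared norm equals $\|\mathrm{Proj}(\gamma'')\|^2$. A plain-Python sketch of this step, assuming the toy decoder $\Psi(u, v) = (u, v, u^2 + v^2)$ and using Gram-Schmidt in place of the QR factorization (the basis may differ from the QR one by signs, which does not change the squared norm):

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def orthonormalize(c1, c2):
    """Gram-Schmidt on two linearly independent vectors; returns (v1, v2)."""
    n1 = math.sqrt(dot(c1, c1))
    v1 = [x / n1 for x in c1]
    r = dot(v1, c2)
    u2 = [y - r * x for x, y in zip(v1, c2)]
    n2 = math.sqrt(dot(u2, u2))
    v2 = [x / n2 for x in u2]
    return v1, v2

# Jacobian columns of the toy decoder at (u, v) = (0.3, -0.7): d Psi e_1 and d Psi e_2.
c1 = [1.0, 0.0, 0.6]
c2 = [0.0, 1.0, -1.4]
v1, v2 = orthonormalize(c1, c2)

w = [0.2, -1.0, 3.0]  # hypothetical stand-in for gamma''(t_k)
coeffs = [dot(v1, w), dot(v2, w)]  # plays the role of Q^T w
tangential = [coeffs[0] * a + coeffs[1] * b for a, b in zip(v1, v2)]
normal = [x - y for x, y in zip(w, tangential)]

tangential_norm_sq = coeffs[0] ** 2 + coeffs[1] ** 2
print(tangential_norm_sq, dot(w, w))
```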

In [None]:
import math

def make_pairs(batch_of_points):
    # enumerate all unordered pairs (i, j), i < j, of the given points
    n_points = len(batch_of_points)  # use the batch size instead of the global K
    start_points_list = []
    end_points_list = []
    for i in range(n_points):
        for j in range(i+1, n_points):
            start_points_list.append(batch_of_points[i].unsqueeze(0))
            end_points_list.append(batch_of_points[j].unsqueeze(0))
    start_points = torch.cat(start_points_list, dim = 0)
    end_points = torch.cat(end_points_list)
    return start_points, end_points
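
The nested loop visits the $\binom{K}{2}$ unordered index pairs $i < j$ in lexicographic order. As a small plain-Python check (the helper names are illustrative), `itertools.combinations` yields the same pairs in the same order:

```python
from itertools import combinations

def pairs_nested(K):
    """Index pairs (i, j) with i < j, enumerated as in the nested loop."""
    return [(i, j) for i in range(K) for j in range(i + 1, K)]

def pairs_combinations(K):
    """The same pairs produced by itertools.combinations."""
    return list(combinations(range(K), 2))

K_demo = 10
nested = pairs_nested(K_demo)
combos = pairs_combinations(K_demo)
print(len(nested), nested == combos)
```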


In [None]:
h=1e-2
(1/h**4)*delta_gamma_second(t=0,x=torch.tensor([[0., 0.]]), y =torch.tensor([[1., 1.]]), decoder=decoder, h=h).norm()**2

In [None]:
# unstable once h < 1.e-3
gamma_second(t=0,x=torch.tensor([[0., 0.]]), y =torch.tensor([[1., 1.]]), decoder=decoder, h=h).norm()**2

In [None]:
f = lambda x : torch.sin(x)
def ddf (x,h):
    return f(x+h) - 2*f(x) +  f(x-h)
def f_second (x,h):
    return (f(x+h) - 2*f(x) +  f(x-h))/(h**2)

In [None]:
x_0 = torch.zeros(1) + 0.2
print(f"ddf({x_0.item()}) using h = 1.e-3")
print(ddf(x_0, h = 1.e-3).item())
print(f"ddf({x_0.item()}) using h = 1.e-5")
print(ddf(x_0, h = 1.e-5).item())

print(f"f''({x_0.item()}) using h = 1.e-3")
print(f_second(x_0, h = 1.e-3).item())
print(f"f''({x_0.item()}) using h = 1.e-5")
print(f_second(x_0, h = 1.e-5).item())
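
The instability for small $h$ comes from cancellation in single precision: near $0.2$ the spacing of float32 values is about $1.5 \cdot 10^{-8}$, while the true second difference is of order $h^2$. The sketch below emulates float32 rounding in plain Python through `struct`; it is an assumption-laden emulation of the effect, not a reproduction of the torch arithmetic.

```python
import math
import struct

def f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def second_diff_f32(x, h):
    """sin(x+h) - 2 sin(x) + sin(x-h) with every intermediate rounded to float32."""
    x, h = f32(x), f32(h)
    right = f32(math.sin(f32(x + h)))
    mid = f32(math.sin(x))
    left = f32(math.sin(f32(x - h)))
    return f32(f32(right - f32(2.0 * mid)) + left)

def second_derivative_f32(x, h):
    """Second difference divided by h^2."""
    return second_diff_f32(x, h) / (h * h)

exact = -math.sin(f32(0.2))
err_coarse = abs(second_derivative_f32(0.2, 1e-2) - exact)
err_fine = abs(second_derivative_f32(0.2, 1e-5) - exact)
print(err_coarse, err_fine)  # the smaller step gives the larger error
```

With $h = 10^{-5}$ the second difference can only be an integer multiple of the float32 spacing, so after division by $h^2$ the estimate is either zero or of order $10^2$.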

# Plotting the curves $\alpha_{ij}$

In [None]:
torch.manual_seed(0)
# define a square box where points are sampled
square_side = 2*torch.pi
center_of_square = torch.zeros(2)

random_points_latent_space = square_side * ( torch.rand(K, 2) - 0.5 )  + center_of_square
# Convert the tensor to a numpy array
points_np = random_points_latent_space.numpy()

# Plot the points
plt.figure(figsize=(8, 6))
plt.scatter(points_np[:, 0], points_np[:, 1], c='blue', marker='o', edgecolor='k')
plt.title('Random Points in Latent Space')
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')
plt.grid(True)
plt.show()

In [None]:
# use the same Random points (uniformly distributed) not depending on the encoder:

start_points_latent_space, end_points_latent_space = make_pairs(random_points_latent_space)
# Random points depending on the encoder:
#start_points, end_points = make_pairs(first_benchmark_data)
#start_points_latent_space = encoder(start_points).detach()
#end_points_latent_space = encoder(end_points).detach()
"""
plt.figure(figsize=(8, 6))
plt.scatter(start_points_latent_space[:, 0], start_points_latent_space[:, 1], color='blue', label='Start Points')
plt.scatter(end_points_latent_space[:, 0], end_points_latent_space[:, 1], color='red', label='End Points')
for start, end in zip(start_points_latent_space, end_points_latent_space):
    plt.plot([start[0], end[0]], [start[1], end[1]], 'k--', color='blue')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title(r'Lines $\alpha_{ij}$ Connecting Points in the latent space')
plt.legend()
plt.grid(True)
plt.show()
"""

### vmap vectorization and computation of functionals $\widetilde E$, $E$ and $A$

In [None]:
E_tilde_vmap = torch.func.vmap(E_tilde)
E_vmap = torch.func.vmap(E)
A_vmap = torch.func.vmap(A)

# computing functionals
n_partition = 100

E_tilde_array = E_tilde_vmap (start_points_latent_space, end_points_latent_space, n_partition=n_partition, decoder = decoder)
Energy_array = E_vmap (start_points_latent_space, end_points_latent_space, n_partition=n_partition, decoder = decoder)
Acceleration_array = A_vmap (start_points_latent_space, end_points_latent_space, n_partition=n_partition, decoder = decoder)

#E_tilde_array = E_tilde_vmap(encoder(start_points),encoder(end_points),n_partition=100,decoder = decoder)
#Energy_array = E_vmap(encoder(start_points),encoder(end_points),n_partition=100,decoder = decoder)
#Distance_ls_pairwize_array_no_curv_pen = (encoder(end_points) - encoder(start_points)).norm(dim=1)
#Distance_RD_pairwize_array_no_curv_pen = (end_points - start_points).norm(dim=1)

# Histograms and statistical analysis

In [None]:
def plot_histogram(samples, samples_name: str, n_bins = None, xlim = None, show_title_labels = False):
    
    samples = samples.detach()  # Detach from the computation graph
    N = len(samples)  # Number of samples (straight lines)
    if n_bins is None:
        num_bins = 1 + math.ceil( math.log2(N) ) # Sturges' rule
    else:
        num_bins = n_bins
    # num_bins = 30 was used before

    # Create a histogram
    plt.figure(figsize=(10, 6))
    plt.hist(samples.detach(), bins=num_bins, edgecolor='black', alpha=0.7, density=False)

    # Add labels and title
    #plt.xlabel(f'${samples_name}_{{ij}}$ Values')
    if show_title_labels:
        plt.title(f'{samples_name} values')
        plt.ylabel('Frequency')
    if xlim is not None:
        plt.xlim(0, xlim)

    # Show grid for better readability
    plt.grid(True)
    file_name = Path_pictures+f"/Histogram_{samples_name}_ij_{experiment_name}.pdf"
    if violent_saving:  # save only when saving is enabled in the configuration cell
        plt.savefig(file_name, bbox_inches='tight', format = "pdf")
        print("Histogram saved at:", file_name)
    return plt

def statistical_analysis(samples: torch.Tensor, samples_name: str):
    # prints and saves the mean, standard deviation and standard error of the samples
    samples = samples.detach()  # Detach from the computation graph
    N = len(samples)  # Number of samples (straight lines)
    mean_value = samples.mean().item()  # Mean value of the energy functional
    std_dev = torch.std(samples).item()  # Standard deviation
    SE = std_dev / math.sqrt(N)  # Standard error (SE)
    # printing here
    print("Number of straight lines(samples):", N)
    print(f"Mean value of {samples_name}: {mean_value:.3f}")
    print(f"Std of {samples_name}: {std_dev:.3f}")
    print(f"Standard error of mean (SE): {SE:.4f}")
    # Define the path to the output file
    output_file_path = f"{Path_pictures}/statistical_analysys_{samples_name}_{experiment_name}.txt"

    # Save the results to the text file
    with open(output_file_path, 'w') as f:
        f.write(f"Number of straight lines (samples): {N}\n")
        f.write(f"Mean value of {samples_name}: {mean_value:.3f}\n")
        f.write(f"Standard deviation of {samples_name}: {std_dev:.3f}\n")
        f.write(f"Standard error of the mean (SE): {SE:.4f}\n")

    print(f"Results saved to {output_file_path}")
    return
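
For reference, the two summary quantities used above can be written in plain Python. This is a sketch with illustrative helper names; `statistics.stdev` applies Bessel's correction, which matches the default of `torch.std`.

```python
import math
import statistics

def sturges_bins(n_samples):
    """Number of histogram bins by Sturges' rule: 1 + ceil(log2(N))."""
    return 1 + math.ceil(math.log2(n_samples))

def mean_and_standard_error(samples):
    """Sample mean and standard error of the mean, SE = std / sqrt(N)."""
    mean = statistics.fmean(samples)
    std = statistics.stdev(samples)  # divides by N - 1 (Bessel's correction)
    return mean, std / math.sqrt(len(samples))

toy_samples = [1.0, 2.0, 3.0, 4.0, 5.0]  # toy values, not results from the notebook
toy_mean, toy_se = mean_and_standard_error(toy_samples)
print(sturges_bins(45), toy_mean, toy_se)
```

For $K = 10$ there are 45 straight lines, so Sturges' rule gives 7 bins.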

Functionals $\widetilde E_{ij}$.

In [None]:
p_E_tilde = plot_histogram(samples=E_tilde_array, samples_name="E_tilde",n_bins=30)
p_E_tilde.show()
statistical_analysis(E_tilde_array, "E_tilde")

Functionals $ E_{ij}$.

In [None]:
p_E = plot_histogram(samples=Energy_array, samples_name="E",n_bins=30)
p_E.show()
statistical_analysis(Energy_array, "E")

Functionals $ A_{ij}$.

In [None]:
p_A = plot_histogram(samples=Acceleration_array, samples_name="A", xlim=400e+3)
p_A.show()
statistical_analysis(Acceleration_array, "A")

# Saving results

In [None]:
# Save the tensors to a single file
torch.save({'E^1': E_tilde_array, 'E^2': Energy_array, 'E^3': Acceleration_array}, Path_pictures+'/3Functionals.pt')
print("results saved at:", Path_pictures+'/3Functionals.pt')


Now let us compute energy ratios between the geodesics $\beta_{ij}$ and the straight line segments $\alpha_{ij}$. Note that:
$$
R_{ij}^{(1)}
= \frac{\int_0^1 \| \dot \beta_{ij}(t)\|_g^2 dt}{\int_0^1 \| \dot \gamma_{ij} (t)\|_2^2 dt} 
= \frac{\int_0^1 \| \dot \beta_{ij}(t)\|_g^2 dt}{\int_0^1 \| \dot \alpha_{ij} (t)\|_g^2 dt} \ .
$$

# Energy ratios: $ \frac{\int_0^1 \| \dot \beta_{ij}(t)\|_g^2 dt}{\int_0^1 \| \dot \gamma_{ij} (t)\|_2^2 dt} \ . $

A demo for two points

In [None]:
from ricci_regularization import NumericalGeodesics

In [None]:
geodesic_solver = NumericalGeodesics(n_max=7, step_count=100)

In [None]:
num_epochs = 100
optimization_device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is available

torus_ae = torus_ae.to(optimization_device)
# alphas are linear segments, betas are geodesics
alpha_array, beta_array = geodesic_solver.computeGeodesicInterpolationBatch(generator=torus_ae.decoder_torus,
                                                  m1_batch=start_points_latent_space, 
                                                  m2_batch=end_points_latent_space,
                                                  epochs=num_epochs, device=optimization_device)

In [None]:
# computing energies
torus_ae = torus_ae.to("cpu")
alpha_energies_array = ricci_regularization.compute_energy(alpha_array, decoder=torus_ae.decoder_torus,
                                                           reduction="none")
beta_energies_array = ricci_regularization.compute_energy(beta_array, decoder=torus_ae.decoder_torus,
                                                           reduction="none")
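
The implementation of `ricci_regularization.compute_energy` is not shown in this notebook. As a rough idea of what such a quantity measures, here is a plain-Python sketch of a discrete curve energy; `discrete_energy` and `toy_decoder` are hypothetical stand-ins, not the library's functions, and neither toy curve is claimed to be a geodesic.

```python
import math

def discrete_energy(points, decoder):
    """Approximate the energy of t -> decoder(c(t)) on [0, 1] for a curve c sampled
    at n equally spaced times: (n - 1) * sum_k ||decoder(p_{k+1}) - decoder(p_k)||^2."""
    images = [decoder(p) for p in points]
    n = len(images)
    total = 0.0
    for a, b in zip(images[:-1], images[1:]):
        total += sum((y - x) ** 2 for x, y in zip(a, b))
    return (n - 1) * total

def toy_decoder(p):
    """Toy decoder Psi(u, v) = (u, v, u**2 + v**2), an assumption for illustration."""
    u, v = p
    return (u, v, u * u + v * v)

n = 201
ts = [k / (n - 1) for k in range(n)]
# Two latent curves from (-1, 0) to (1, 0): a straight segment and a half circle.
straight = [(-1.0 + 2.0 * t, 0.0) for t in ts]
arc = [(math.cos(math.pi * (1.0 - t)), math.sin(math.pi * (1.0 - t))) for t in ts]

energy_straight = discrete_energy(straight, toy_decoder)  # close to 28/3
energy_arc = discrete_energy(arc, toy_decoder)            # close to pi^2
ratio = energy_straight / energy_arc
print(energy_straight, energy_arc, ratio)
```

The ratio of two such energies for curves with common endpoints is the kind of quantity that is histogrammed at the end of the notebook.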

Saving geodesic energies

In [None]:
torch.save(beta_energies_array.detach(), Path_pictures+'/geodesic_energy_array.pt')
print("results saved at:", Path_pictures+'/geodesic_energy_array.pt')

In [None]:
p = plot_histogram(beta_energies_array / alpha_energies_array, samples_name="Naive_energy_ratio")
p.show()
statistical_analysis(beta_energies_array / alpha_energies_array, samples_name="Naive_energy_ratio")
