NB! This is an old notebook. It computes the scalar curvature of the latent space in a wrong (very imprecise) way, using finite-difference (f.d.) formulas for the derivatives. It has to be redone!

The autoencoder (AE) consists of the encoder $\Phi$ and the decoder $\Psi$.
The latent space of the AE is $R^d$. We define a Riemannian metric in a local chart of the latent space as the pull-back of the Euclidean metric in the output space $R^D$ by the decoder function $\Psi$ of the AE:
\begin{equation*}
    g = (\nabla \Psi)^{*} \, \nabla \Psi .
\end{equation*}

The notebook contains:
1) Loading weights of a pre-trained convolutional AE and plotting its latent space: point plot and manifold plot. If "violent_saving" == True, plots are saved locally.
2) Auxiliary tensors involving higher-order derivatives of the decoder $\Psi$ are computed with f.d.: the metric $g$ and its derivatives, the Riemann tensor $R^{i}_{jkl}$, the Ricci tensor $R_{ij}$ and the scalar curvature.
3) Geodesics shooting via Runge-Kutta approximation. A single plot with a scalar curvature heatmap and geodesics on it is constructed.
4) Prototype of metric evolution by Ricci flow equation 

NB! By default the metric $g$ is the pull-back by the decoder, as described above. One can use any custom metric instead by setting it manually in the "specific_metric" function. That function computes the metric matrix $g(u)$ at a point $u\in \mathbb{R}^d$, given the local coordinates of $u$ in the latent space.

In [None]:
import json
import os
import timeit

import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd
import torch
import torchvision
import yaml

import ricci_regularization

In [None]:
# Experiment configuration: switch the YAML path to change dataset / experiment.
with open('../../experiments/MNIST_Setting_1_config.yaml', 'r') as yaml_file:
#with open('../../experiments/MNIST01_exp7_config.yaml', 'r') as yaml_file:
#with open('../../experiments/Swissroll_exp4_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

violent_saving = False # if False it will not save plots

d = 2 # dimension of the latent space R^d

# Loading data and nn weights

In [None]:
# Load data loaders based on YAML configuration.
# The result is stored in `loaders` (not `dict`) so that the built-in `dict`
# type is not shadowed for the rest of the notebook.
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
test_dataset = loaders.get("test_dataset")  # None if get_dataloaders does not return this key

print("Data loaders created successfully.")
# Relative prefix used to locate the experiments folder and the AE weights.
additional_path="../"

In [None]:
experiment_name = yaml_config["experiment"]["name"]

# Folder for the plots of this experiment; it is only created when saving is enabled.
#Path_pictures = yaml_config["experiment"]["path"]
Path_pictures = additional_path + "../experiments/" + experiment_name
if violent_saving:
    # Check and create directories based on configuration
    if not os.path.exists(Path_pictures):
        os.makedirs(Path_pictures)  # also creates missing parent folders
        print(f"Created directory: {Path_pictures}")
    else:
        print(f"Directory already exists: {Path_pictures}")

curv_w = yaml_config["loss_settings"]["lambda_curv"]  # lambda_curv from the loss settings

dataset_name = yaml_config["dataset"]["name"]
# D is the dimension of the dataset
D = yaml_config["architecture"]["input_dim"]
if dataset_name in ["MNIST01", "Synthetic"]:
    # k is the number of classes, i.e. the number of selected labels
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    k = 10
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
### Initialize the two networks
# Build the torus AE from the config and load its weights; also returns the weights path.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config, additional_path = additional_path)

torus_ae = torus_ae.to("cpu")

print(f"AE weights loaded successfully from {Path_ae_weights}.")

# Encoder Phi and decoder Psi; the metric g studied below is the pull-back by `decoder`.
encoder = torus_ae.encoder_torus
decoder = torus_ae.decoder_torus

# Metric on a grid via torch.roll

In [None]:
# Defining the scale of the grid which is numsteps \times numsteps points

numsteps = 100

# Grid with the default limits of make_grid (defined in ricci_regularization).
tgrid = ricci_regularization.FiniteDifferences.make_grid(numsteps= numsteps)

In [None]:

# Convergence of the f.d. metric towards the jacfwd (autodiff) metric as the grid is refined.
# takes 42 secs
size = torch.pi / 10 # grid side size
numsteps_values = [7, 15, 30, 50, 80, 160, 320]
mse_errors = []
h_values = []
for numsteps in numsteps_values:
    grid = ricci_regularization.FiniteDifferences.make_grid(numsteps= numsteps,
        xlim_left= -size/2, xlim_right= size/2, ylim_bottom= -size/2, ylim_top= size/2)
    error = ricci_regularization.error_fd_jacfwd_on_grid(
        tensor_fd=ricci_regularization.metric_fd_grid(grid, function=decoder),
        tensor_jacfwd=ricci_regularization.metric_jacfwd_vmap(grid, function=decoder),
        cut=1)
    mse_errors.append(error.item())
    # Step of a grid of side `size`, not of a unit-side grid.
    h_values.append(size / numsteps)

# Create the plot
plt.figure(figsize=(8, 6),dpi=300)
plt.loglog(h_values, mse_errors, marker='o', linestyle='-', color='b', markersize=4)
plt.ylabel("Error")
plt.xlabel("Step h")
plt.title(r'Log-Log plot of step $h$ vs. MSE Error of metric $g$ .', fontsize=14)
# Path_pictures is only created when violent_saving is True.
if violent_saving:
    plt.savefig(Path_pictures+"/metric_error_fd_jacfwd.pdf", bbox_inches='tight', format = "pdf")
plt.show()

## Heatmap of the Frobenius norm of the metric

In [None]:
# Frobenius norm of the f.d. metric on a numsteps x numsteps grid, with the
# outermost ring of grid points excluded.
numsteps = 100
tgrid = ricci_regularization.FiniteDifferences.make_grid(numsteps=numsteps)

metric = ricci_regularization.metric_fd_grid(tgrid, function=decoder)
Newfrob = (
    metric.norm(dim=(1, 2))
    .view(numsteps, numsteps)[1:-1, 1:-1]
    .transpose(0, 1)
)

In [None]:

# Five evenly spaced ticks on [-pi, pi]: both borders plus three in between.
num_ticks = 5
ticks = np.linspace(-np.pi, np.pi, num_ticks)
tick_labels = ['{:.2f}'.format(t) for t in ticks]

# Square heatmap of the Frobenius norm over the latent square [-pi, pi]^2.
plt.figure(figsize=(6, 6))
plt.imshow(Newfrob.detach(), cmap='viridis', origin='lower',
           extent=[-np.pi, np.pi, -np.pi, np.pi])
plt.colorbar(label='Frobenius norm of the metric', shrink=0.7)

plt.xticks(ticks=ticks, labels=tick_labels)
plt.yticks(ticks=ticks, labels=tick_labels)

plt.title('Finite differences: Frobenius norm of metric on a grid')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')

plt.grid(False)  # no grid lines over the heatmap
plt.show()


# Derivatives of the metric: $d g$. Precision analysis

In [None]:
# Grid on [-pi, pi]^2 used for the precision analysis of dg and of the Christoffel derivatives.
numsteps= 100
#size = 1.
size = 2*torch.pi
tgrid = ricci_regularization.FiniteDifferences.make_grid(numsteps= numsteps, 
        xlim_left= -size/2, xlim_right= size/2, ylim_bottom= -size/2, ylim_top= size/2)

metric = ricci_regularization.metric_fd_grid(tgrid, function=decoder)

In [None]:
# Derivatives of the metric on the grid: finite differences vs. forward-mode autodiff (jacfwd).
dg_fd = ricci_regularization.metric_der_fd_grid(tgrid, decoder)
dg_jacfwd = ricci_regularization.metric_der_jacfwd_vmap(tgrid, function = decoder)

In [None]:
# Error between f.d. and jacfwd dg; `cut` presumably drops border layers (TODO confirm in error_fd_jacfwd_on_grid).
ricci_regularization.error_fd_jacfwd_on_grid(dg_fd,dg_jacfwd, cut = 2)

In [None]:
# Grid step used above: side length divided by the number of steps.
print("Step size:", size/numsteps)

# Christoffel symbols

In [None]:

# Derivatives of the Christoffel symbols: finite differences vs. jacfwd on the same grid.
dCh_fd = ricci_regularization.Ch_der_fd(tgrid, function=decoder)
dCh_jacfwd = ricci_regularization.Ch_der_jacfwd_vmap(tgrid, function=decoder)

In [None]:
# Error between f.d. and jacfwd Christoffel derivatives (higher order, hence a larger `cut`)
ricci_regularization.error_fd_jacfwd_on_grid(tensor_fd= dCh_fd,
                                             tensor_jacfwd=dCh_jacfwd, cut=3)

# Curvature precision

In [None]:
# Two grids on [-pi/4, pi/4]^2: tgrid7 uses 5*numsteps = 100 steps per side, tgrid uses numsteps = 20.
numsteps = 20
size = torch.pi / 2
#size = 2 * torch.pi
tgrid7 = ricci_regularization.FiniteDifferences.make_grid(numsteps= 5* numsteps, 
        xlim_left= -size/2, xlim_right= size/2, ylim_bottom= -size/2, ylim_top= size/2)

tgrid = ricci_regularization.FiniteDifferences.make_grid(numsteps= numsteps, 
        xlim_left= -size/2, xlim_right= size/2, ylim_bottom= -size/2, ylim_top= size/2)

In [None]:
# 0.3 secs for 100 by 100 grid
# f.d. scalar curvature on the fine grid tgrid7
Sc_fd = ricci_regularization.Sc_fd(tgrid7, function=decoder)

In [None]:
# 5.6 secs for 100 by 100 grid
# Evaluate on the same grid as Sc_fd (tgrid7). The error below compares the two
# tensors point by point, so both must live on identical grids.
Sc_jacfwd = ricci_regularization.Sc_jacfwd_vmap(tgrid7, function=decoder)

In [None]:
# Error between f.d. and jacfwd scalar curvature on the grid
Sc_error = ricci_regularization.error_fd_jacfwd_on_grid(tensor_fd= Sc_fd,
                                             tensor_jacfwd=Sc_jacfwd, cut=3)

In [None]:
# Report the scalar-curvature error as a Python float.
print("scalar curvature MSE error:", Sc_error.item())

In [None]:
# Convergence of the f.d. scalar curvature towards jacfwd on a grid of side `size`.
size = torch.pi / 5 # grid side size
numsteps_values = [7, 15, 30, 50, 80, 160]
mse_errors = []
mean_curvature_values = []
std_curvature_values = []
h_values = []
for numsteps in numsteps_values:
    grid = ricci_regularization.FiniteDifferences.make_grid(numsteps= numsteps,
        xlim_left= -size/2, xlim_right= size/2, ylim_bottom= -size/2, ylim_top= size/2)
    # Distance between grid points 0 and numsteps-1, divided by numsteps.
    linsize = (grid[numsteps-1] - grid[0]).norm()
    h_values.append(linsize.item() / numsteps)

    Sc_jacfwd = ricci_regularization.Sc_jacfwd_vmap(grid, function=decoder)
    Sc_fd = ricci_regularization.Sc_fd(grid, function=decoder)

    # Scale of |R| (jacfwd), used as a reference for the size of the error.
    abs_Sc = torch.abs(Sc_jacfwd)
    mean_curvature_values.append(abs_Sc.mean().item())
    std_curvature_values.append(abs_Sc.std().item())

    # Each tensor goes to its matching keyword: f.d. -> tensor_fd, autodiff -> tensor_jacfwd.
    error = ricci_regularization.error_fd_jacfwd_on_grid(tensor_fd=Sc_fd,
                                                         tensor_jacfwd=Sc_jacfwd,
                                                         cut=3)
    mse_errors.append(error.item())

    # Free the per-resolution tensors before the next (larger) grid.
    del grid, Sc_jacfwd, Sc_fd, abs_Sc, error

In [None]:
# Log-log plot: MSE of the f.d. scalar curvature vs. step h, with the mean and
# std of |R| (jacfwd) as a reference scale for the error.
plt.figure(figsize=(8, 6), dpi=300)
plt.loglog(h_values, mse_errors, marker='o', linestyle='-', label="MSE of $R$ (f.d. vs jacfwd)",
            color='b', markersize=4)
plt.loglog(h_values, mean_curvature_values, marker='o', linestyle='-', label ="mean value of $R$",
            color='r', markersize=4)
plt.loglog(h_values, std_curvature_values, marker='o', linestyle='-', label ="std of $R$",
            color='orange', markersize=4)

plt.ylabel("Error")
plt.xlabel("Step h")
plt.legend()

# Set only major ticks at h_values (no minor ticks)
ax = plt.gca()
ax.set_xticks(h_values)  # Set x-axis major ticks to h_values
ax.set_xticks([], minor=True)  # Disable minor ticks on the x-axis

# Set custom x-axis tick labels and rotate them for better visibility
ax.set_xticklabels([f"{h:.3g}" for h in h_values], rotation=45, ha='right')

plt.title(r'Log-Log plot of step $h$ vs. MSE Error of scalar curvature $R$.', fontsize=14)
# Path_pictures is only created when violent_saving is True.
if violent_saving:
    plt.savefig(Path_pictures + "/R_error_fd_jacfwd.pdf", bbox_inches='tight', format="pdf")
plt.show()

# Using mini-grids + minimal verification

In [None]:
def build_mini_grid(center: torch.Tensor, h: float = 1.0) -> torch.Tensor:
    """
    Return a 7x7 grid of points with spacing h centred on `center`.

    Args:
        center (torch.Tensor): Coordinates of the grid center (a 2D point).
        h (float): Spacing between neighbouring grid points. Default is 1.0.

    Returns:
        torch.Tensor: Tensor of shape (7, 7, 2); entry [i, j] is
        center + ((i - 3) * h, (j - 3) * h).
    """
    steps = torch.arange(-3, 4) * h  # -3h, -2h, ..., 3h
    rows, cols = torch.meshgrid(steps, steps, indexing='ij')
    offsets = torch.stack((rows, cols), dim=-1).float()  # (7, 7, 2)
    # Prepend two singleton axes so the center broadcasts over the 7x7 grid.
    return offsets + center[None, None]

def build_mini_grid_batch(centers: torch.Tensor, h: float = 1.0, grid_size: int = 7) -> torch.Tensor:
    """
    Builds a batch of flattened square mini-grids centered at the given points.

    Each mini-grid has ``grid_size x grid_size`` points with spacing ``h``,
    flattened in row-major (``indexing='ij'``) order. The flat index
    ``i * grid_size + j`` holds ``center + (offset[i], offset[j])``. A step of
    ``grid_size`` in the flat index therefore moves along the first coordinate,
    a step of 1 moves along the second, and the center sits at index
    ``grid_size**2 // 2`` (24 for the default 7x7 grid).

    Args:
        centers (torch.Tensor): A 2D tensor with shape (N, 2) representing N centers.
        h (float): The step size for the grid. Default is 1.0.
        grid_size (int): Odd number of points per side. Default is 7, which is
            what the mini-grid f.d. helpers in this notebook assume.

    Returns:
        torch.Tensor: A batch of mini-grids of shape (N, grid_size * grid_size, 2).

    Raises:
        ValueError: If grid_size is not a positive odd integer.
    """
    if grid_size < 1 or grid_size % 2 == 0:
        raise ValueError(f"grid_size must be a positive odd integer, got {grid_size}")
    half = grid_size // 2
    offset = torch.arange(-half, half + 1) * h  # (-half*h, ..., half*h)
    grid_x, grid_y = torch.meshgrid(offset, offset, indexing='ij')

    mini_grid = torch.stack([grid_x, grid_y], dim=-1).float()  # (grid_size, grid_size, 2)
    mini_grid = mini_grid.reshape(1, grid_size * grid_size, 2)  # (1, G, 2)

    # (1, G, 2) + (N, 1, 2) -> (N, G, 2)
    return mini_grid + centers.unsqueeze(1)

In [None]:
# Sanity check on 10 random centers: flat index 24 is the center of each 7x7 mini-grid.
torch.manual_seed(0)  # reproducible centers (also fixes the later random draws)
centers = torch.rand(10,2)
h = 0.01
batch_minigrids = build_mini_grid_batch(centers, h)

print("batch of grids built correctly:",torch.equal( batch_minigrids[:,24,:], centers))

In [None]:
# Decoder output shape on the mini-grids; the f.d. helpers below contract it as (N, 49, D) — TODO confirm.
decoder(batch_minigrids).shape

In [None]:
def metric_fd_batch_minigrids(batch_minigrids, function):
    """Pull-back metric g = J^T J at every point of a batch of 7x7 mini-grids.

    The Jacobian of `function` is approximated by central differences in the
    flat mini-grid layout. A shift of 7 in the flat index moves along the first
    latent coordinate (x); a shift of 1 moves along the second (y). The shifts
    are cyclic (``roll``), so only interior points, in particular the center
    (index 24), carry meaningful values.

    Args:
        batch_minigrids: tensor of shape (N, 49, 2).
        function: map applied to ``batch_minigrids``; its output is contracted
            as (N, 49, D).

    Returns:
        Tensor of shape (N, 49, 2, 2) holding [[g_xx, g_xy], [g_xy, g_yy]].
    """
    step = (batch_minigrids[0, 1] - batch_minigrids[0, 0]).norm()
    values = function(batch_minigrids)

    # Central differences along y (flat shift 1) and along x (flat shift 7).
    d_dy = (values.roll(-1, 1) - values.roll(1, 1)) / (2 * step)
    d_dx = (values.roll(-7, 1) - values.roll(7, 1)) / (2 * step)

    g_xx = torch.einsum('bgD,bgD->bg', d_dx, d_dx)
    g_xy = torch.einsum('bgD,bgD->bg', d_dy, d_dx)
    g_yy = torch.einsum('bgD,bgD->bg', d_dy, d_dy)

    metric = torch.stack((g_xx, g_xy, g_xy, g_yy), dim=-1)
    return metric.view(-1, 7 * 7, 2, 2)


In [None]:
# f.d. metric at the center (flat index 24) of mini-grid number 7; compare with jacfwd below.
metric = metric_fd_batch_minigrids(batch_minigrids, decoder)
metric[7][24]

In [None]:
# Reference value via forward-mode autodiff at the same center.
ricci_regularization.metric_jacfwd(centers[7],decoder)

In [None]:
def diff_by_x_minigrids(tensor_on_batch_minigrids, h):
    """Central difference along the first latent coordinate (x) on 7x7 mini-grids.

    One step in x is a shift of 7 in the flat mini-grid index (dim 1). The shift
    is cyclic, so only interior points give meaningful values.
    """
    forward = torch.roll(tensor_on_batch_minigrids, shifts=-7, dims=1)
    backward = torch.roll(tensor_on_batch_minigrids, shifts=7, dims=1)
    return (forward - backward) / (2 * h)

def diff_by_y_minigrids(tensor_on_batch_minigrids, h):
    """Central difference along the second latent coordinate (y) on 7x7 mini-grids.

    One step in y is a shift of 1 in the flat mini-grid index (dim 1). The shift
    is cyclic, so only interior points give meaningful values.
    """
    forward = torch.roll(tensor_on_batch_minigrids, shifts=-1, dims=1)
    backward = torch.roll(tensor_on_batch_minigrids, shifts=1, dims=1)
    return (forward - backward) / (2 * h)

def metric_der_fd_batch_minigrids(batch_minigrids, function):
    """f.d. derivatives of the mini-grid metric: dg[..., i, j, k] = d g_ij / d u^k.

    Returns a tensor of shape (N, 49, 2, 2, 2); the last axis is (d/dx, d/dy).
    """
    step = (batch_minigrids[0, 1] - batch_minigrids[0, 0]).norm()
    g = metric_fd_batch_minigrids(batch_minigrids, function=function)
    return torch.stack(
        (diff_by_x_minigrids(g, h=step), diff_by_y_minigrids(g, h=step)),
        dim=-1,
    )


In [None]:
# f.d. metric derivatives at the center of mini-grid number 7.
metric_der_fd_batch_minigrids(batch_minigrids, decoder)[7][24]

In [None]:
# Reference metric derivatives via jacfwd at the same center.
ricci_regularization.metric_der_jacfwd(centers[7],decoder)

In [None]:
def metric_inv_batch_minigrids(batch_minigrids, function, eps=0.0):
    """Inverse of the f.d. mini-grid metric, optionally regularized.

    Args:
        batch_minigrids: batch of 7x7 mini-grids, shape (N, 49, 2).
        function: decoder whose pull-back metric is inverted.
        eps: multiple of the identity added to g before inversion.

    Returns:
        Tensor of shape (N, 49, 2, 2) holding (g + eps * I)^{-1}.
    """
    g = metric_fd_batch_minigrids(batch_minigrids, function)
    identity = torch.eye(g.shape[-1], dtype=g.dtype, device=g.device)
    return torch.inverse(g + eps * identity)

def Ch_fd_batch_minigrids (batch_minigrids, function, eps = 0.0):
    """Christoffel symbols Ch[..., i, k, l] = Gamma^i_{kl} at every mini-grid point.

    Gamma^i_{kl} = 1/2 g^{im} (d_l g_{mk} + d_k g_{ml} - d_m g_{kl}), where
    dg[..., m, k, l] = d_l g_{mk} as built by metric_der_fd_batch_minigrids.
    """
    inv_g = metric_inv_batch_minigrids(batch_minigrids, function, eps=eps)
    dg = metric_der_fd_batch_minigrids(batch_minigrids, function)
    term_l = torch.einsum('bgim,bgmkl->bgikl', inv_g, dg)  # g^{im} d_l g_{mk}
    term_k = torch.einsum('bgim,bgmlk->bgikl', inv_g, dg)  # g^{im} d_k g_{ml}
    term_m = torch.einsum('bgim,bgklm->bgikl', inv_g, dg)  # g^{im} d_m g_{kl}
    return 0.5 * (term_l + term_k - term_m)

def Ch_der_fd_batch_minigrids (grid, function, eps=0.0):
    """f.d. derivatives of the Christoffel symbols on a batch of mini-grids.

    Args:
        grid: batch of 7x7 mini-grids, shape (N, 49, 2).
        function: decoder defining the pull-back metric.
        eps: regularization passed on to the metric inversion.

    Returns:
        dCh[..., i, k, l, m] = d Gamma^i_{kl} / d u^m, shape (N, 49, 2, 2, 2, 2).
    """
    # Take the step from the `grid` argument, not from a notebook global, so that
    # grids with any spacing are differentiated correctly.
    h = (grid[0, 1] - grid[0, 0]).norm()

    Ch = Ch_fd_batch_minigrids(grid, function=function, eps=eps)
    dChdx = diff_by_x_minigrids(Ch, h)
    dChdy = diff_by_y_minigrids(Ch, h)
    return torch.stack((dChdx, dChdy), dim=-1)

In [None]:
# f.d. Christoffel derivatives at the center of mini-grid number 7.
Ch_der_fd_batch_minigrids(batch_minigrids, decoder)[7][24]

In [None]:
# Reference Christoffel derivatives via jacfwd at the same center.
ricci_regularization.Ch_der_jacfwd(centers[7],decoder)

In [None]:
# Riemann curvature tensor (3,1)
def Riem_fd_batch_minigrids(u, function,eps=0.0):
    """Riemann tensor Riem[..., i, j, k, l] = R^i_{jkl} at every mini-grid point.

    R^i_{jkl} = d_k Gamma^i_{lj} - d_l Gamma^i_{kj}
                + Gamma^i_{kp} Gamma^p_{lj} - Gamma^i_{lp} Gamma^p_{kj},
    with Ch[..., i, k, l] = Gamma^i_{kl} and Ch_der[..., i, k, l, m] = d_m Gamma^i_{kl}.
    """
    Ch = Ch_fd_batch_minigrids(u, function, eps=eps)
    # NOTE: Ch_der_fd_batch_minigrids recomputes Ch internally, so the decoder is evaluated again here.
    Ch_der = Ch_der_fd_batch_minigrids(u, function, eps=eps)

    # Derivative terms: Ch_der[i,l,j,k] = d_k Gamma^i_{lj} and Ch_der[i,k,j,l] = d_l Gamma^i_{kj}.
    Riem = torch.einsum("bgiljk->bgijkl",Ch_der) - torch.einsum("bgikjl->bgijkl",Ch_der)
    # Quadratic terms: Gamma^i_{kp} Gamma^p_{lj} - Gamma^i_{lp} Gamma^p_{kj}.
    Riem += torch.einsum("bgikp,bgplj->bgijkl", Ch, Ch) - torch.einsum("bgilp,bgpkj->bgijkl", Ch, Ch)
    return Riem

def Ric_fd_batch_minigrids(u, function, eps=0.0):
    """Ricci tensor Ric[..., j, l] = R^i_{jil}: contraction of the 1st and 3rd Riemann indices."""
    riem = Riem_fd_batch_minigrids(u, function, eps=eps)
    return torch.einsum("bgijil->bgjl", riem)

def Sc_fd_batch_minigrids (u, function, eps = 0.0):
    """Scalar curvature R = g^{jl} Ric_{jl} at every mini-grid point, shape (N, 49)."""
    ric = Ric_fd_batch_minigrids(u, function=function, eps=eps)
    g_inv = metric_inv_batch_minigrids(u, function=function, eps=eps)
    return torch.einsum('bgjl,bgjl->bg', g_inv, ric)


In [None]:
# f.d. scalar curvature at the center of mini-grid number 7.
Sc_fd_batch_minigrids(batch_minigrids, decoder)[7][24]

In [None]:
# Reference scalar curvature via jacfwd at the same center.
ricci_regularization.Sc_jacfwd(centers[7],decoder)

## Errors: f.d. on mini-grids vs. jacfwd at the centers

In [None]:
def error_fd_jacfwd_batch_minigrids(tensor_fd, tensor_jacfwd):
    """MSE between f.d. values at the mini-grid centers and jacfwd values at the centers.

    Args:
        tensor_fd: f.d. tensor on a batch of mini-grids, shape (N, G, ...), where
            G is the (odd square) number of points per mini-grid. Only the
            central point ``G // 2`` is used (index 24 for 7x7 mini-grids).
        tensor_jacfwd: reference tensor at the centers, shape (N, ...).

    Returns:
        0-dim tensor with the mean squared error.
    """
    center_index = tensor_fd.shape[1] // 2
    tensor_fd_central = tensor_fd[:, center_index]
    # Public API path (torch.functional.F only works through an internal import).
    return torch.nn.functional.mse_loss(tensor_fd_central, tensor_jacfwd)

In [None]:
# Metric on all mini-grid points (f.d.) and at the centers only (jacfwd).
metric_fd = metric_fd_batch_minigrids(batch_minigrids, function=decoder)
metric_jacfwd = ricci_regularization.metric_jacfwd_vmap(centers, function=decoder)

In [None]:
# MSE of the f.d. metric at the mini-grid centers vs. jacfwd.
error_fd_jacfwd_batch_minigrids(metric_fd, metric_jacfwd)

In [None]:
# Same comparison for the scalar curvature R: f.d. on mini-grids vs. jacfwd at the centers.
error_fd_jacfwd_batch_minigrids(Sc_fd_batch_minigrids(batch_minigrids, function=decoder),
                                ricci_regularization.Sc_jacfwd_vmap(centers, function=decoder))

## Experiments

In [None]:
# 10000 random centers in [0, 1)^2, each with a 7x7 mini-grid of step h.
centers = torch.rand(10000,2)
h = 0.01
batch_minigrids = build_mini_grid_batch(centers, h)   

In [None]:
# Scalar curvature at the 10000 centers: f.d. on mini-grids vs. jacfwd.
Sc_fd = Sc_fd_batch_minigrids(batch_minigrids, function=decoder)
Sc_jacfwd = ricci_regularization.Sc_jacfwd_vmap(centers, function=decoder)
# Report the discrepancy between the two estimates.
error_fd_jacfwd_batch_minigrids(Sc_fd, Sc_jacfwd)

In [None]:
# Timing of the scalar curvature R: f.d. on mini-grids vs. jacfwd, for several batch sizes.
# Number of repetitions averaged per measurement
iterations = 100

batch_sizes = [8, 16, 32, 40, 64, 128, 256]  # Different batch sizes to test

# One dict per batch size with the average times
timing_results = []

h = 0.01  # mini-grid step size
centers = torch.randn(max(batch_sizes), 2)  # random centers (standard normal)
batch_minigrids = build_mini_grid_batch(centers, h=h)

for batch_size in batch_sizes:
    current_centers = centers[:batch_size]
    current_batch_minigrids = batch_minigrids[:batch_size]

    # Time callables directly; the lambdas run immediately, inside this iteration.
    time_fd = timeit.timeit(
        lambda: Sc_fd_batch_minigrids(current_batch_minigrids, function=decoder),
        number=iterations,
    )
    time_jacfwd = timeit.timeit(
        lambda: ricci_regularization.Sc_jacfwd_vmap(current_centers, function=decoder),
        number=iterations,
    )

    timing_results.append({
        "batch_size": batch_size,
        "Sc_fd_avg_time": time_fd / iterations,
        "Sc_jacfwd_avg_time": time_jacfwd / iterations,
    })

# Save results to a JSON file
with open('timing_results_batch_minigrids.json', 'w') as f:
    json.dump(timing_results, f, indent=4)

# Print the timing results
for result in timing_results:
    print(result)

In [None]:
# Plot the average computation time of R for both methods vs. batch size.
batch_sizes = [result['batch_size'] for result in timing_results]
sc_fd_times = [result['Sc_fd_avg_time'] for result in timing_results]
sc_jacfwd_times = [result['Sc_jacfwd_avg_time'] for result in timing_results]

plt.figure(figsize=(10, 6))

plt.plot(batch_sizes, sc_fd_times, marker='o', label='fd on mini_grids', linestyle='-')
plt.plot(batch_sizes, sc_jacfwd_times, marker='s', label='jacfwd', linestyle='-')

plt.ylabel('Average Time (seconds)')
plt.xlabel('Batch Size')
plt.title('Timing scalar curvature $R$ computation: fd on minigrids vs jacfwd')
plt.grid()
plt.legend()
plt.xticks(batch_sizes)  # one tick per tested batch size

# Path_pictures is only created when violent_saving is True.
if violent_saving:
    plt.savefig(Path_pictures+'/timing_results_batch_minigrids.pdf', bbox_inches='tight')
plt.show()