NB! This is an old notebook. It computes the scalar curvature of the latent space with finite-difference (f.d.) formulas for the derivatives, which is a very imprecise approach. It has to be redone!

The autoencoder (AE) consists of the encoder $\Phi$ and the decoder $\Psi$.
The latent space of the AE is $R^d$. We define a Riemannian metric in a local chart of the latent space as the pull-back of the Euclidean metric in the output space $R^D$ by the decoder function $\Psi$ of the AE:
\begin{equation*}
    g = (\nabla \Psi)^* \, \nabla \Psi .
\end{equation*}

The notebook contains:
1) Loading weights of a pre-trained convolutional AE and plotting its latent space: point plot and manifold plot. If "violent_saving" == True, plots are saved locally.
2) Auxiliary tensors involving higher-order derivatives of the decoder $\Psi$ are computed with f.d.: the metric $g$ and its derivatives, the Riemann tensor $R^{i}_{jkl}$, the Ricci tensor $R_{ij}$ and the scalar curvature.
3) Geodesics shooting via Runge-Kutta approximation. A single plot with a scalar curvature heatmap and geodesics on it is constructed.
4) Prototype of metric evolution by Ricci flow equation 

NB! By default the metric $g$ is the pull-back by the decoder as described above. However, one can use any custom metric by setting it manually in the "specific_metric" function, which computes the metric matrix $g(u)$ at a point $u\in \mathbb{R}^d$ given the local coordinates of $u$ in the latent space.

In [None]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import torch
import ricci_regularization
import yaml

In [None]:
# Read the experiment configuration (a YAML file) into `yaml_config`.
# The commented-out `with` lines are alternative experiment configs; enable exactly one.
# NOTE(review): yaml.FullLoader can build more than plain scalars/lists/maps;
# if the config only holds plain data, yaml.safe_load would be the safer call - confirm.
with open('../../experiments/MNIST_Setting_1_config.yaml', 'r') as yaml_file:
#with open('../../experiments/MNIST01_exp7_config.yaml', 'r') as yaml_file:
#with open('../../experiments/Swissroll_exp4_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

violent_saving = False # if False it will not save plots

# latent space dimension (the latent space is R^d)
d = 2

# Loading data and nn weights

In [None]:
# Load data loaders based on YAML configuration
# NOTE(review): the result is bound to the name `dict`, which shadows the Python
# builtin for the rest of the notebook session; any later `dict(...)` call will fail.
# Consider renaming once all cells that read this variable have been checked.
dict = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = dict["train_loader"]
test_loader = dict["test_loader"]
# .get() yields None instead of raising if the key is absent
test_dataset = dict.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders

print("Data loaders created successfully.")
# prefix prepended to relative paths built in the following cells
additional_path="../"

In [None]:
# `os` is needed for the directory check below; imported here because the
# notebook's import cell does not bring it in.
import os

experiment_name = yaml_config["experiment"]["name"]

#Path_pictures = yaml_config["experiment"]["path"]
# Folder where plots of this experiment are stored
Path_pictures = additional_path + "../experiments/" + yaml_config["experiment"]["name"]
if violent_saving:
    # Check and create directories based on configuration
    if not os.path.exists(Path_pictures):  # Check if the picture path does not exist
        os.mkdir(Path_pictures)  # Create the directory for plots if not yet created
        print(f"Created directory: {Path_pictures}")  # Print directory creation feedback
    else:
        print(f"Directiry already exists: {Path_pictures}")

# weight of the curvature term in the training loss, as stored in the config
curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
D = yaml_config["architecture"]["input_dim"]
# D is the dimension of the dataset
# NOTE: k (and selected_labels) stay undefined for dataset names not listed below.
if dataset_name in ["MNIST01", "Synthetic"]:
    # k is the number of classes, taken from the list of selected labels
    #k = yaml_config["dataset"]["k"]
    k = len(yaml_config["dataset"]["selected_labels"])
    selected_labels = yaml_config["dataset"]["selected_labels"]
elif dataset_name == "MNIST":
    k = 10
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
### Initialize the two networks

# Build the autoencoder described by the config; the helper also returns the
# path of the weights file, which is only used for the log message below.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config, additional_path = additional_path)

# all computations in this notebook run on the CPU
torus_ae = torus_ae.to("cpu")

print(f"AE weights loaded successfully from {Path_ae_weights}.")

# shorthand handles for the two halves of the autoencoder
encoder = torus_ae.encoder_torus
decoder = torus_ae.decoder_torus

In [None]:
# 10 sample points with 2 coordinates each, drawn uniformly from [0, 1)
centers = torch.rand(10, 2)


In [None]:
# Evaluate the packaged rhombus finite-difference routine at the sample centers with step h.
# presumably returns (scalar curvature, metric), as the unpacked names suggest - confirm
Sc, g = ricci_regularization.Sc_g_fd_batch_minigrids_rhombus(centers, function=decoder, h = 0.01)

In [None]:
# Packaged curvature loss at the same centers (h = 0.01, no eps regularisation);
# the value is displayed as the cell output.
ricci_regularization.curvature_loss(points=centers, function=decoder, h = 0.01, eps=0.)

In [None]:
# Reference value for comparison with the cell above: mean over the centers of
# sqrt(det g) * Sc^2, with g and Sc taken from the package's *_jacfwd_vmap functions
# (presumably automatic-differentiation based - confirm).
(torch.sqrt( torch.det( ricci_regularization.metric_jacfwd_vmap(centers, function=decoder) ) ) *torch.square( ricci_regularization.Sc_jacfwd_vmap(centers, function=decoder))).mean()

# Square minigrids

In [None]:
def build_mini_grid_batch(centers: torch.Tensor, h: float) -> torch.Tensor:
    """
    Builds a batch of 7x7 square mini-grids centered at the given batch of points.

    Args:
        centers (torch.Tensor): A 2D tensor with shape (N, 2) representing N centers.
        h (float): The step size for the grid.

    Returns:
        torch.Tensor: A batch of mini-grids of shape (N, 7, 7, 2). The entry
        [n, a, b] equals centers[n] + ((a - 3) * h, (b - 3) * h), i.e. the
        first grid axis runs along x and the second along y.
    """
    side = 7  # nodes per grid side
    half = side // 2
    # Relative offsets from the center (-3h, -2h, ..., 3h). They are created on
    # the device of `centers`, so the sum below also works off the CPU.
    offset = torch.arange(-half, half + 1, device=centers.device) * h
    grid_x, grid_y = torch.meshgrid(offset, offset, indexing='ij')  # 7x7 grid for x and y

    # Stack the coordinates (x, y) together
    mini_grid = torch.stack([grid_x, grid_y], dim=-1).float()  # Shape: (7, 7, 2)

    # Broadcast (1, 7, 7, 2) against (N, 1, 1, 2) to shift the grid to every center
    batch_minigrids = mini_grid.unsqueeze(0) + centers.unsqueeze(1).unsqueeze(1)
    return batch_minigrids

def diff_by_x_minigrids(tensor_on_batch_minigrids, h):
    """
    Centered first difference along the first grid axis (x).

    The input has shape batch_size * side * side * (anything); the result
    lives on the interior nodes only, i.e. batch_size * (side-2) * (side-2) * (anything).
    """
    # drop the two outer y-columns once, then shift along x only
    interior_in_y = tensor_on_batch_minigrids[:, :, 1:-1]
    ahead = interior_in_y[:, 2:]
    behind = interior_in_y[:, :-2]
    return (ahead - behind) / (2 * h)

def diff_by_y_minigrids(tensor_on_batch_minigrids, h):
    """
    Centered first difference along the second grid axis (y).

    The input has shape batch_size * side * side * (anything); the result
    lives on the interior nodes only, i.e. batch_size * (side-2) * (side-2) * (anything).
    """
    # drop the two outer x-rows once, then shift along y only
    interior_in_x = tensor_on_batch_minigrids[:, 1:-1]
    ahead = interior_in_x[:, :, 2:]
    behind = interior_in_x[:, :, :-2]
    return (ahead - behind) / (2 * h)


def metric_fd_batch_minigrids(centers, function, h=0.01):
    """
    Finite-difference pull-back metric on the interior of a square mini-grid.

    Args:
        centers (torch.Tensor): Grid centers, shape (batch_size, 2).
        function (callable): Map psi evaluated on the (batch_size, 7, 7, 2) grid.
        h (float): Grid step, also used in the centered differences.

    Returns:
        torch.Tensor: Metric of shape (batch_size, 5, 5, 2, 2).
    """
    grids = build_mini_grid_batch(centers, h)
    values = function(grids)
    # the two partial derivatives of psi, each batch_size * 5 * 5 * D
    partials = (diff_by_x_minigrids(values, h), diff_by_y_minigrids(values, h))
    jacobian = torch.stack(partials, dim=-1) # shape batch_size * 5 * 5 * D * 2
    # b is batch_size, g,h are coordinates on the minigrid, D is output of psi dimension, i,j are local coordinates
    return torch.einsum('bghDi,bghDj->bghij', jacobian, jacobian)

def Sc_fd_batch_minigrids (centers, function, h=0.01, eps = 0.0):
    """
    Finite-difference scalar curvature of the pull-back metric at each center.

    A 7x7 square mini-grid of step h is built around every center and centered
    differences are nested on it: psi (7x7) -> metric g (5x5) -> Christoffel
    symbols (3x3) -> derivatives of the Christoffel symbols (center node only).

    Args:
        centers (torch.Tensor): Points of shape (batch_size, 2). Only x and y
            derivatives are taken, so the latent dimension d must be 2.
        function (callable): Map psi applied to the (batch_size, 7, 7, d) grid
            tensor; its output is used as (batch_size, 7, 7, D).
        h (float): Grid step used in every centered difference.
        eps (float): Multiple of the identity added to g before it is inverted.

    Returns:
        torch.Tensor: Scalar curvature at the centers, shape (batch_size,).
    """
    # d is latent dimension
    d = centers.shape[-1]
    # create a batch of minigrids with given centers and step h
    batch_minigrids = build_mini_grid_batch(centers, h) # shape batch_size * 7 * 7 * d
    psi = function(batch_minigrids)
    dpsi = torch.stack((diff_by_x_minigrids(psi, h),
                        diff_by_y_minigrids(psi, h)), dim=-1) # shape batch_size * 5 * 5 * D * d
    del psi

    # compute metric
    # b is batch_size, g,h are coordinates on the minigrid, D is output of psi dimension, i,j are local coordinates
    g = torch.einsum('bghDi,bghDj->bghij', dpsi, dpsi) # shape batch_size * 5 * 5 * d * d
    del dpsi

    # compute metric derivatives: dg[..., m, k, l] is the l-th derivative of g_mk
    dg = torch.stack((diff_by_x_minigrids(g, h),
                      diff_by_y_minigrids(g, h)), dim=-1) # shape batch_size * 3 * 3 * d * d * d

    # compute inverse on the inner 3x3 nodes, where dg is available
    g_inv = torch.inverse(g[:, 1:-1, 1:-1] + eps*torch.eye(d, device=g.device)) # shape batch_size * 3 * 3 * d * d
    del g

    # compute Christoffel symbols
    # b is batch_size, g,h are coordinates on the minigrid, i, m, k, l are local coordinates
    Christoffel = 0.5*(torch.einsum('bghim,bghmkl->bghikl', g_inv, dg)+
              torch.einsum('bghim,bghmlk->bghikl', g_inv, dg)-
              torch.einsum('bghim,bghklm->bghikl', g_inv, dg)
              ) # shape batch_size * 3 * 3 * d * d * d
    del dg

    # compute Christoffel symbols' derivatives (only the center node remains)
    dChristoffel = torch.stack((diff_by_x_minigrids(Christoffel, h),
                                diff_by_y_minigrids(Christoffel, h)), dim=-1) # shape batch_size * 1 * 1 * d * d * d * d
    # Pick the single grid node by explicit index. A bare .squeeze() would also
    # drop the batch axis when batch_size == 1 and break the einsums below.
    dChristoffel = dChristoffel[:, 0, 0] # shape batch_size * d * d * d * d

    # Compute Riemann tensor
    # here we only need christoffels and derivatives at centers
    Christoffel = Christoffel[:, 1, 1] # shape batch_size * d * d * d
    # b is batch_size i, j, k, l, p are local coordinates
    Riemann = torch.einsum("biljk->bijkl", dChristoffel) - torch.einsum("bikjl->bijkl", dChristoffel)
    Riemann = Riemann + (torch.einsum("bikp,bplj->bijkl", Christoffel, Christoffel)
                         - torch.einsum("bilp,bpkj->bijkl", Christoffel, Christoffel))
    # Riemann shape: batch_size * d * d * d * d
    del dChristoffel, Christoffel

    # Compute Ricci
    # b is batch_size c, s, r are local coordinates
    Ricci = torch.einsum("bcscr->bsr", Riemann)
    del Riemann

    # Compute scalar curvature
    # we only need inverse of the metric at one central point:
    g_inv = g_inv[:, 1, 1] # shape batch_size * d * d
    # b is batch_size s, r are local coordinates
    Sc = torch.einsum('bsr,bsr->b', g_inv, Ricci)
    return Sc

# Rhombus minigrids

In [None]:
# Disabled demo kept as a string literal: a loop-based construction of the
# rhombus |i - 3| + |j - 3| <= 3 inside a 7x7 matrix, printing its indices and
# a 0/1 picture of it.
"""
# Size of the matrix (7x7)
matrix_size = 7
center = matrix_size // 2  # This gives the index 3, which is the center of a 7x7 matrix

# Create the 7x7 matrix filled with zeros
matrix = torch.zeros((matrix_size, matrix_size))

# Get the indices that form the rhombus shape
rhombus_indices = []
for i in range(matrix_size):
    for j in range(matrix_size):
        if abs(i - center) + abs(j - center) <= 3:
            rhombus_indices.append((i, j))
            matrix[i, j] = 1  # Mark the rhombus area for visualization

# Print the indices of the rhombus
print("Indices that form the rhombus:")
print(rhombus_indices)

# Visualize the rhombus
print("Rhombus shape in the 7x7 matrix:")
print(matrix)
"""

In [None]:
def rhombus_mask():
    """
    Boolean 7x7 mask of the rhombus around the central node (3, 3).

    A node (i, j) belongs to the rhombus when its Manhattan distance to the
    center, |i - 3| + |j - 3|, is at most 3.
    """
    radius = 3
    # distance of every row/column index to the central index
    axis_distance = (torch.arange(7) - radius).abs()
    # column vector + row vector broadcasts to the full 7x7 Manhattan distance
    manhattan_distance = axis_distance.unsqueeze(1) + axis_distance.unsqueeze(0)
    return manhattan_distance <= radius

In [None]:
def Sc_fd_batch_minigrids_rhombus_slow (centers, function, h=0.01, eps = 0.0):
    """
    Scalar curvature by finite differences, evaluating psi on a rhombus only.

    psi is evaluated on the 25 rhombus nodes of each 7x7 mini-grid and written
    back into a zero-filled 7x7 tensor, after which the square-grid stencils
    are applied. Nodes outside the rhombus stay at zero; the stencil chain that
    reaches the central node only reads rhombus nodes.

    Args:
        centers (torch.Tensor): Points of shape (batch_size, 2). Only x and y
            derivatives are taken, so the latent dimension d must be 2.
        function (callable): Map psi applied to a (batch_size, 25, d) tensor;
            its output is used as (batch_size, 25, D).
        h (float): Grid step used in every centered difference.
        eps (float): Multiple of the identity added to g before it is inverted.

    Returns:
        torch.Tensor: Scalar curvature at the centers, shape (batch_size,).
    """
    # d is latent dimension
    d = centers.shape[-1]
    batch_size = centers.shape[0]
    # create a batch of minigrids with given centers and step h
    batch_minigrids = build_mini_grid_batch(centers, h) # shape batch_size * 7 * 7 * d

    # Create the rhombus mask (7 * 7, shared by the whole batch)
    mask = rhombus_mask().to(batch_minigrids.device)

    # Extract rhombus nodes for the batch
    rhombus_minigrids_batch = batch_minigrids[:, mask] # shape batch_size * 25 * d

    # Evaluate the decoder psi only on the rhombus
    psi_rhombus = function( rhombus_minigrids_batch ) # shape batch_size * 25 * D
    D = psi_rhombus.shape[-1] # psi output dimension: D

    # Scatter the rhombus values into a zero-filled square grid. new_zeros keeps
    # the device and dtype of psi; reshape (not view) tolerates a non-contiguous psi.
    psi = psi_rhombus.new_zeros(batch_size, 7, 7, D)
    psi[:, mask] = psi_rhombus.reshape(batch_size, -1, D)
    del psi_rhombus

    # compute dpsi
    dpsi = torch.stack((diff_by_x_minigrids(psi, h),
                        diff_by_y_minigrids(psi, h)), dim=-1) # shape batch_size * 5 * 5 * D * d
    del psi

    # compute metric
    # b is batch_size, g,h are coordinates on the minigrid, D is output of psi dimension, i,j are local coordinates
    g = torch.einsum('bghDi,bghDj->bghij', dpsi, dpsi) # shape batch_size * 5 * 5 * d * d
    del dpsi

    # compute metric derivatives: dg[..., m, k, l] is the l-th derivative of g_mk
    dg = torch.stack((diff_by_x_minigrids(g, h),
                      diff_by_y_minigrids(g, h)), dim=-1) # shape batch_size * 3 * 3 * d * d * d

    # compute inverse of g on the inner 3x3 nodes
    g_inv = torch.inverse(g[:, 1:-1, 1:-1] + eps*torch.eye(d, device=g.device)) # shape batch_size * 3 * 3 * d * d
    del g

    # compute Christoffel symbols
    # b is batch_size, g,h are coordinates on the minigrid, i, m, k, l are local coordinates
    Christoffel = 0.5*(torch.einsum('bghim,bghmkl->bghikl', g_inv, dg)+
              torch.einsum('bghim,bghmlk->bghikl', g_inv, dg)-
              torch.einsum('bghim,bghklm->bghikl', g_inv, dg)
              ) # shape batch_size * 3 * 3 * d * d * d
    del dg

    # compute Christoffel symbols' derivatives (only the center node remains)
    dChristoffel = torch.stack((diff_by_x_minigrids(Christoffel, h),
                                diff_by_y_minigrids(Christoffel, h)), dim=-1) # shape batch_size * 1 * 1 * d * d * d * d
    # Pick the single grid node by explicit index. A bare .squeeze() would also
    # drop the batch axis when batch_size == 1 and break the einsums below.
    dChristoffel = dChristoffel[:, 0, 0] # shape batch_size * d * d * d * d

    # compute Riemann tensor
    # here we only need christoffels and derivatives at centers
    Christoffel = Christoffel[:, 1, 1] # shape batch_size * d * d * d
    # b is batch_size i, j, k, l, p are local coordinates
    Riemann = torch.einsum("biljk->bijkl", dChristoffel) - torch.einsum("bikjl->bijkl", dChristoffel)
    Riemann = Riemann + (torch.einsum("bikp,bplj->bijkl", Christoffel, Christoffel)
                         - torch.einsum("bilp,bpkj->bijkl", Christoffel, Christoffel))
    # Riemann shape: batch_size * d * d * d * d
    del dChristoffel, Christoffel

    # compute Ricci
    # b is batch_size c, s, r are local coordinates
    Ricci = torch.einsum("bcscr->bsr", Riemann)
    del Riemann

    # compute scalar curvature
    # we only need inverse of the metric at one central point:
    g_inv = g_inv[:, 1, 1] # shape batch_size * d * d
    # b is batch_size s, r are local coordinates
    Sc = torch.einsum('bsr,bsr->b', g_inv, Ricci)
    return Sc

In [None]:
def proper_indices(printing = False):
    """
    Row-major rhombus numbers listed in column-major order.

    The 25 nodes of the rhombus |i - 3| + |j - 3| <= 3 in a 7x7 grid are
    numbered 0..24 row by row; cells outside the rhombus hold -1. The function
    returns these numbers read column by column (a 1D int tensor of length 25).

    Args:
        printing (bool): If True, print the numbered grid, its transpose and
            the row-major numbering.
    """
    side = 7
    middle = 3
    rows, cols = torch.meshgrid(torch.arange(side), torch.arange(side), indexing="ij")

    # The rhombus is where the Manhattan distance to (3, 3) is at most 3
    inside = (rows - middle).abs() + (cols - middle).abs() <= 3

    # -1 marks cells outside the rhombus; rhombus cells get 0..24 in row-major order
    numbering = torch.full((side, side), -1, dtype=torch.int)
    row_major = torch.arange(25, dtype=torch.int)
    numbering[inside] = row_major

    transposed = numbering.T
    # explicit comparison kept so that only True (or a value equal to it) prints
    if printing == True:
        print("Rhombus Tensor with Non-Zero Cells:")
        print(numbering)
        print("Transposed indices:")
        print(transposed)
        print("Transposed indexing:", row_major)

    # reading the transpose row by row walks the original grid column by column
    return transposed[transposed != -1].flatten()

def metric_fd_batch_minigrids_rhombus_slow (centers, function, h=0.01, eps = 0.0):
    """
    Early rhombus-based finite-difference metric (prototype).

    psi is evaluated on the 25 rhombus nodes of every 7x7 mini-grid, stored in
    row-major order. Differences are then taken between entries two positions
    apart: in column-major order for "x" and in row-major order for "y".

    NOTE(review): entries two positions apart are true grid neighbours only
    inside one column (resp. row), and the k-th "x" and "y" differences do not
    refer to the same node, so the result is not a pointwise metric.

    Args:
        centers (torch.Tensor): Grid centers, shape (batch_size, 2).
        function (callable): Map psi applied to a (batch_size, 25, d) tensor.
        h (float): Grid step.
        eps (float): Accepted but not used.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, 23, 2, 2).
    """
    batch_size, d = centers.shape[0], centers.shape[-1]
    grids = build_mini_grid_batch(centers, h) # shape batch_size * 7 * 7 * d

    # one rhombus mask shared by all batch elements
    per_batch_mask = rhombus_mask().unsqueeze(0).expand(batch_size, 7, 7)
    rhombus_points = grids[per_batch_mask].view(batch_size, -1, d) # shape batch_size * 25 * d

    # Evaluate the decoder psi only on the rhombus
    values = function( rhombus_points ) # shape batch_size * 25 * D

    column_major = proper_indices()
    two_h = 2 * h
    diff_x = ( values[:, column_major[2:], :] - values[:, column_major[:-2], :] ) / two_h # shape batch_size * 23 * D
    diff_y = ( values[:, 2:, :] - values[:, :-2, :] ) / two_h # shape batch_size * 23 * D
    jacobian = torch.stack((diff_x, diff_y), dim=-1) # shape batch_size * 23 * D * d

    # b is batch_size, g is the node index, D is output of psi dimension, i,j are local coordinates
    return torch.einsum('bgDi,bgDj->bgij', jacobian, jacobian) # shape batch_size * 23 * d * d

In [None]:
def indices (minigrid_size = 7):
    """
    Rhombus mask and neighbour index sets for centered differences on it.

    The rhombus of radius c = minigrid_size // 2 around the central node of a
    minigrid_size x minigrid_size grid is numbered 0, 1, ... in row-major
    order. For the smaller rhombus of radius c - 1, the function lists the
    numbers of the nodes shifted by one step in +x, -x, +y, -y, and of the
    unshifted nodes, all in the same row-major order.

    Args:
        minigrid_size (int): Side of the square grid (odd: 7, 5, 3, ...).

    Returns:
        tuple: (mask, indices_x_next, indices_x_prev, indices_y_next,
        indices_y_prev, indices_central), where mask is a boolean
        minigrid_size x minigrid_size tensor and the rest are 1D int tensors.
    """
    center = minigrid_size // 2
    grid_x, grid_y = torch.meshgrid(torch.arange(minigrid_size), torch.arange(minigrid_size), indexing="ij")

    def diamond(cx, cy, radius):
        # nodes whose Manhattan distance to (cx, cy) is at most radius
        return torch.abs(grid_x - cx) + torch.abs(grid_y - cy) <= radius

    # full rhombus, numbered in row-major order; -1 marks cells outside it
    mask = diamond(center, center, center)
    rhombus_tensor = torch.full((minigrid_size, minigrid_size), -1, dtype=torch.int)
    num_rhombus_points = ( minigrid_size * minigrid_size ) // 2 + 1
    rhombus_tensor[mask] = torch.arange(num_rhombus_points, dtype=torch.int)

    # smaller rhombus moved one step in +x, -x, +y, -y, and left in place
    inner_radius = center - 1
    shifted_centers = (
        (center + 1, center),
        (center - 1, center),
        (center, center + 1),
        (center, center - 1),
        (center, center),
    )
    neighbour_sets = [rhombus_tensor[diamond(cx, cy, inner_radius)] for cx, cy in shifted_centers]
    return (mask, *neighbour_sets)

In [None]:
def metric_fd_batch_minigrids_rhombus (centers, function, h=0.01, eps = 0.0):
    """
    Finite-difference pull-back metric using rhombus-shaped mini-grids.

    psi is evaluated on the 25 rhombus nodes of every 7x7 mini-grid; centered
    differences then give the metric on the 13 nodes of the next smaller rhombus.

    Args:
        centers (torch.Tensor): Grid centers, shape (batch_size, 2).
        function (callable): Map psi applied to a (batch_size, 25, d) tensor.
        h (float): Grid step, also used in the centered differences.
        eps (float): Accepted but not used.

    Returns:
        torch.Tensor: Metric of shape (batch_size, 13, 2, 2).
    """
    n_centers, latent_dim = centers.shape[0], centers.shape[-1]
    grids = build_mini_grid_batch(centers, h) # shape batch_size * 7 * 7 * d

    # rhombus mask plus the neighbour index sets for the two centered differences
    rhombus, fwd_x, bwd_x, fwd_y, bwd_y, _ = indices()
    # the same mask applies to every batch element
    per_batch_mask = rhombus.unsqueeze(0).expand(n_centers, 7, 7)
    rhombus_points = grids[per_batch_mask].view(n_centers, -1, latent_dim) # shape batch_size * 25 * d

    # Evaluate the decoder psi only on the rhombus
    values = function( rhombus_points ) # shape batch_size * 25 * D

    two_h = 2 * h
    diff_x = ( values[:, fwd_x] - values[:, bwd_x] ) / two_h # shape batch_size * 13 * D
    diff_y = ( values[:, fwd_y] - values[:, bwd_y] ) / two_h # shape batch_size * 13 * D
    jacobian = torch.stack((diff_x, diff_y), dim=-1) # shape batch_size * 13 * D * d

    # b is batch_size, g is the node index, D is output of psi dimension, i,j are local coordinates
    return torch.einsum('bgDi,bgDj->bgij', jacobian, jacobian) # shape batch_size * 13 * d * d

In [None]:
def Ch_fd_batch_minigrids_rhombus (centers, function, h=0.01, eps = 0.0):
    """
    Finite-difference Christoffel symbols using rhombus-shaped mini-grids.

    Two nested centered differences are taken: psi on the 25-node rhombus gives
    the metric on the 13-node rhombus, whose differences give the Christoffel
    symbols on the 5-node rhombus (the center and its four neighbours).

    Args:
        centers (torch.Tensor): Grid centers, shape (batch_size, 2).
        function (callable): Map psi applied to a (batch_size, 25, d) tensor.
        h (float): Grid step, also used in the centered differences.
        eps (float): Multiple of the identity added to the metric before inversion.

    Returns:
        torch.Tensor: Christoffel symbols of shape (batch_size, 5, 2, 2, 2),
        with the upper index first.
    """
    n_centers, latent_dim = centers.shape[0], centers.shape[-1]
    two_h = 2 * h
    grids = build_mini_grid_batch(centers, h) # shape batch_size * 7 * 7 * d

    # stage 1: psi on the 25-node rhombus -> metric on the 13-node rhombus
    rhombus, fwd_x, bwd_x, fwd_y, bwd_y, _ = indices( minigrid_size = 7)
    per_batch_mask = rhombus.unsqueeze(0).expand(n_centers, 7, 7)
    rhombus_points = grids[per_batch_mask].view(n_centers, -1, latent_dim) # shape batch_size * 25 * d
    values = function( rhombus_points ) # shape batch_size * 25 * D
    diff_x = ( values[:, fwd_x] - values[:, bwd_x] ) / two_h # shape batch_size * 13 * D
    diff_y = ( values[:, fwd_y] - values[:, bwd_y] ) / two_h # shape batch_size * 13 * D
    jacobian = torch.stack((diff_x, diff_y), dim=-1) # shape batch_size * 13 * D * d
    # b is batch_size, g is the node index, D is output of psi dimension, i,j are local coordinates
    metric = torch.einsum('bgDi,bgDj->bgij', jacobian, jacobian) # shape batch_size * 13 * d * d

    # stage 2: metric on the 13-node rhombus -> its derivatives on the 5-node rhombus
    _, fwd_x, bwd_x, fwd_y, bwd_y, inner = indices(minigrid_size = 5)
    d_metric = torch.stack((
        ( metric[:, fwd_x] - metric[:, bwd_x] ) / two_h,
        ( metric[:, fwd_y] - metric[:, bwd_y] ) / two_h,
    ), dim = -1) # shape batch_size * 5 * d * d * d

    # inverse metric at the 5 nodes where the derivatives are available
    identity = torch.eye(latent_dim, device=metric.device)
    inverse_metric = torch.inverse(metric[:, inner] + eps*identity) # shape batch_size * 5 * d * d

    # Christoffel symbols; b is batch_size, g is the node index, i, m, k, l are local coordinates
    term_kl = torch.einsum('bgim,bgmkl->bgikl', inverse_metric, d_metric)
    term_lk = torch.einsum('bgim,bgmlk->bgikl', inverse_metric, d_metric)
    term_m = torch.einsum('bgim,bgklm->bgikl', inverse_metric, d_metric)
    return 0.5*(term_kl + term_lk - term_m) # shape batch_size * 5 * d * d * d

def Sc_fd_batch_minigrids_rhombus (centers, function, h=0.01, eps = 0.0):
    """
    Finite-difference scalar curvature using rhombus-shaped mini-grids.

    Three nested centered differences are taken on shrinking rhombi: psi on 25
    nodes -> metric g on 13 nodes -> Christoffel symbols on 5 nodes -> their
    derivatives at the single central node.

    Args:
        centers (torch.Tensor): Points of shape (batch_size, 2). Only x and y
            derivatives are taken, so the latent dimension d must be 2.
        function (callable): Map psi applied to a (batch_size, 25, d) tensor;
            its output is used as (batch_size, 25, D).
        h (float): Grid step used in every centered difference.
        eps (float): Multiple of the identity added to g before it is inverted.

    Returns:
        torch.Tensor: Scalar curvature at the centers, shape (batch_size,).
    """
    d = centers.shape[-1]
    # create a batch of minigrids with given centers and step h
    batch_minigrids = build_mini_grid_batch(centers, h) # shape batch_size * 7 * 7 * d

    # Create the rhombus mask and indices for differentiation
    mask, indices_x_next, indices_x_prev, indices_y_next, indices_y_prev, _ = indices( minigrid_size = 7)

    # Extract the rhombus nodes; the 7 * 7 mask is shared by the whole batch
    rhombus_minigrids_batch = batch_minigrids[:, mask] # shape batch_size * 25 * d

    # Evaluate the decoder psi only on the rhombus
    psi = function( rhombus_minigrids_batch ) # shape batch_size * 25 * D

    # compute dpsi
    dpsi = torch.stack((
        ( psi[:, indices_x_next] - psi[:, indices_x_prev] ) / ( 2 * h ),
        ( psi[:, indices_y_next] - psi[:, indices_y_prev] ) / ( 2 * h ),
    ), dim = -1) # shape batch_size * 13 * D * d
    del psi

    # compute metric
    # b is batch_size, g is the node index, D is output of psi dimension, i,j are local coordinates
    g = torch.einsum('bgDi,bgDj->bgij', dpsi, dpsi) # shape batch_size * 13 * d * d
    del dpsi

    # compute metric derivatives
    # Get new indices for differentiation
    _, indices_x_next, indices_x_prev, indices_y_next, indices_y_prev, indices_central  = indices(minigrid_size = 5)
    dg = torch.stack((
        ( g[:, indices_x_next] - g[:, indices_x_prev] ) / ( 2 * h ),
        ( g[:, indices_y_next] - g[:, indices_y_prev] ) / ( 2 * h ),
    ), dim = -1) # shape batch_size * 5 * d * d * d

    # compute inverse of g at the 5 nodes where dg is available
    g_inv = torch.inverse(g[:, indices_central] + eps*torch.eye(d, device=g.device)) # shape batch_size * 5 * d * d
    del g

    # compute Christoffel symbols
    # b is batch_size, g is the node index, i, m, k, l are local coordinates
    Christoffel = 0.5*(torch.einsum('bgim,bgmkl->bgikl', g_inv, dg)+
              torch.einsum('bgim,bgmlk->bgikl', g_inv, dg)-
              torch.einsum('bgim,bgklm->bgikl', g_inv, dg)
              ) # shape batch_size * 5 * d * d * d
    del dg

    # compute Christoffel symbols' derivatives
    # Get new indices for differentiation
    _, indices_x_next, indices_x_prev, indices_y_next, indices_y_prev, indices_central  = indices(minigrid_size = 3)
    dChristoffel = torch.stack((
        ( Christoffel[:, indices_x_next] - Christoffel[:, indices_x_prev] ) / ( 2 * h ),
        ( Christoffel[:, indices_y_next] - Christoffel[:, indices_y_prev] ) / ( 2 * h ),
    ), dim = -1) # shape batch_size * 1 * d * d * d * d
    # Drop only the node axis. A bare .squeeze() would also drop the batch axis
    # when batch_size == 1 and break the einsums below.
    dChristoffel = dChristoffel.squeeze(1) # shape batch_size * d * d * d * d

    # compute Riemann tensor
    # cutting the shape of Christoffels to compute Riemann
    assert indices_central[0] == 2 # the central index should be 2 indeed
    Christoffel = Christoffel[:, indices_central].squeeze(1) # new shape: batch_size * d * d * d
    # b is batch_size i, j, k, l, p are local coordinates
    Riemann = torch.einsum("biljk->bijkl", dChristoffel) - torch.einsum("bikjl->bijkl", dChristoffel)
    Riemann = Riemann + (torch.einsum("bikp,bplj->bijkl", Christoffel, Christoffel)
                         - torch.einsum("bilp,bpkj->bijkl", Christoffel, Christoffel))
    # Riemann shape: batch_size * d * d * d * d
    del dChristoffel, Christoffel

    # compute Ricci
    # b is batch_size c, s, r are local coordinates
    Ricci = torch.einsum("bcscr->bsr", Riemann)
    del Riemann

    # compute scalar curvature
    # cutting the shape of the inverse of the metric. Only needed at one central point:
    g_inv = g_inv[:, indices_central].squeeze(1) # shape batch_size * d * d
    # b is batch_size s, r are local coordinates
    Sc = torch.einsum('bsr,bsr->b', g_inv, Ricci)
    return Sc

# Timing: metric vs scalar curvature computation

In [None]:
# Benchmark: wall-clock time of metric_fd_batch_minigrids versus
# Sc_fd_batch_minigrids for several batch sizes.
import timeit
import json

# Number of timed runs per measurement; the totals are divided by it below
iterations = 1

batch_sizes = [32, 64, 128, 256, 512]  # Different batch sizes to test

# Initialize a list to hold timing results
timing_results = []

# One pool of random centers; each batch size uses a prefix of it
h = 0.01  # Step size (arbitrary)
centers = torch.randn(max(batch_sizes), 2)  # Example centers, random values

# Loop through different batch sizes
for batch_size in batch_sizes:
    # Take the first batch_size centers
    current_centers = centers[:batch_size]
    
    # Timing for metric_fd_batch_minigrids
    # (timeit runs in its own namespace, hence the import from __main__ in setup)
    time_metric_fd = timeit.timeit(
        stmt="metric_fd_batch_minigrids(current_centers, decoder, h)",
        setup="from __main__ import metric_fd_batch_minigrids, current_centers, decoder, h",
        number=iterations
    )

    # Timing for Sc_fd_batch_minigrids (h is left at the function's default, 0.01)
    time_sc_fd = timeit.timeit(
        stmt="Sc_fd_batch_minigrids(current_centers, decoder)",
        setup="from __main__ import Sc_fd_batch_minigrids, current_centers, decoder",
        number=iterations
    )

    # Append the results to the timing_results list
    timing_results.append({
        "batch_size": batch_size,
        "metric_fd_avg_time": time_metric_fd / iterations,
        "Sc_fd_avg_time": time_sc_fd / iterations,
    })

# Output the timing results
for result in timing_results:
    print(f"Batch Size: {result['batch_size']}, Metric_fd_avg_time: {result['metric_fd_avg_time']:.6f} sec, "
          f"Sc_fd_avg_time: {result['Sc_fd_avg_time']:.6f} sec")

plotting


In [None]:
# Plot the timings collected in `timing_results` against the batch size.
# Extract the batch sizes, metric_fd times, and Sc_fd times
# (note: this rebinds the global `batch_sizes` to the sizes actually measured)
batch_sizes = [result['batch_size'] for result in timing_results]
metric_fd_times = [result['metric_fd_avg_time'] for result in timing_results]
sc_fd_times = [result['Sc_fd_avg_time'] for result in timing_results]

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(batch_sizes, metric_fd_times, label='Metric $g$', marker='o')
plt.plot(batch_sizes, sc_fd_times, label='Scalar curvature $R$', marker='s')

# Add labels and title
plt.xlabel('Batch Size')
plt.ylabel('Average Time (seconds)')
plt.title('Timing metric $g$ vs scalar curvature $R$ computation with f.d. for different batch sizes')
plt.legend()
# Set x-ticks to be the actual batch size values
plt.xticks(batch_sizes)  # Setting the x-ticks to match batch sizes
# Show the plot
plt.grid(True)
# Saving is disabled; uncomment the next line to write the figure into Path_pictures
#plt.savefig(Path_pictures+'/timing_metric_Sc_minigrids.pdf', bbox_inches='tight')
plt.show()

# Timing: metric computation breakdown

In [None]:
# Break the f.d. metric computation into stages and time each one separately:
#   1) decoder forward pass on the mini-grids,
#   2) finite differences of the decoder output + assembling the metric,
#   3) the complete metric_fd_batch_minigrids call.

# Number of timed executions per measurement (totals are divided by it below)
iterations = 1

batch_sizes = [32, 64, 128, 256, 512]  # Different batch sizes to test

# Initialize a list to hold timing results
timing_results = []

h = 0.01  # Step size (arbitrary)
minigrid_side = 7  # NOTE(review): not referenced anywhere in this cell
centers = torch.randn(max(batch_sizes), 2)  # Example centers, random values, one 2-d point per row

# Build the mini-grids once for the largest batch; smaller batches reuse a prefix
batch_minigrids = build_mini_grid_batch(centers, h=h)

# Statement timed in step 2: f.d. derivatives of psi followed by the metric assembly
diff_stmt = (
    "dpsidx = diff_by_x_minigrids(psi, h)\n"
    "dpsidy = diff_by_y_minigrids(psi, h)\n"
    "dpsi = torch.cat((dpsidx.unsqueeze(-1), dpsidy.unsqueeze(-1)), -1)\n"
    "metric = torch.einsum('bghDi,bghDj->bghij', dpsi, dpsi)"
)

# Loop through different batch sizes
for batch_size in batch_sizes:
    # The timed statements import these names from __main__,
    # so they must be module-level variables
    current_centers = centers[:batch_size]
    current_batch_minigrids = batch_minigrids[:batch_size]
    # Decoder output on the mini-grids; it is the input of the step 2 statement
    psi = decoder(current_batch_minigrids)

    # Step 1: Time for decoder(batch_minigrids)
    time_decoder = timeit.timeit(
        stmt="psi = decoder(current_batch_minigrids)",
        setup="from __main__ import decoder, current_batch_minigrids",
        number=iterations
    )

    # Step 2: Time for diff_by_x_minigrids, diff_by_y_minigrids and the einsum
    # that assembles the metric (only the names the statement uses are imported)
    time_diff = timeit.timeit(
        stmt=diff_stmt,
        setup="from __main__ import torch, diff_by_x_minigrids, diff_by_y_minigrids, h, psi",
        number=iterations
    )

    # Step 3: Time for metric_fd_batch_minigrids
    time_metric_fd = timeit.timeit(
        stmt="metric_fd_batch_minigrids(current_centers, decoder, h)",
        setup="from __main__ import metric_fd_batch_minigrids, current_centers, decoder, h",
        number=iterations
    )

    # Append the per-call averages to the timing_results list
    timing_results.append({
        "batch_size": batch_size,
        "decoder_time": time_decoder / iterations,
        "diff_time": time_diff / iterations,
        "metric_fd_time": time_metric_fd / iterations,
    })

# Output the timing results; the labels match the keys stored above
for result in timing_results:
    print(f"Batch Size: {result['batch_size']}, Decoder Time: {result['decoder_time']:.6f} sec, "
          f"Diff Time: {result['diff_time']:.6f} sec, "
          f"Metric_fd Time: {result['metric_fd_time']:.6f} sec")


Plotting

In [None]:

# Split the benchmark records into parallel lists, one per plotted stage
batch_sizes = [entry["batch_size"] for entry in timing_results]
decoder_times = [entry["decoder_time"] for entry in timing_results]
diff_times = [entry["diff_time"] for entry in timing_results]
metric_fd_times = [entry["metric_fd_time"] for entry in timing_results]

plt.figure(figsize=(10, 6))

# One curve per timed stage, all against the batch size
plt.plot(batch_sizes, decoder_times, marker='o', label='Decoder Time')
plt.plot(batch_sizes, diff_times, marker='o', label='Diff Time')
plt.plot(batch_sizes, metric_fd_times, marker='o', label='Metric FD Time')

# Title and axis annotations
plt.title('Timing Results for Different Batch Sizes')
plt.ylabel('Time (s)')
plt.xlabel('Batch Size')

# Put the x-ticks exactly at the benchmarked batch sizes
plt.xticks(batch_sizes)
plt.grid(True)
plt.legend()

# Uncomment to save the figure next to the other pictures
#plt.savefig(Path_pictures+'/timing_metric_batch_minigrids.pdf', bbox_inches='tight')
plt.show()

# Timing: fd vs jacfwd

In [None]:
import timeit
import json

# Benchmark three scalar curvature routines against each other:
# f.d. on square mini-grids, f.d. on rhombus mini-grids, and jacfwd.

# Number of timed executions per measurement (totals are divided by it below)
iterations = 100

batch_sizes = [16, 32, 64, 128, 256, 512]  # Different batch sizes to test

# One dict of per-call averages is collected per batch size
timing_results = []

h = 0.01  # Step size (arbitrary)
centers = torch.randn(max(batch_sizes), 2)  # Example centers, random values
# Mini-grids around all centers; the statements timed below do not reference this variable
batch_minigrids = build_mini_grid_batch(centers, h=h)

for batch_size in batch_sizes:
    # The timed statements pull `current_centers` in via `from __main__ import ...`,
    # so the slice has to be bound to a module-level name
    current_centers = centers[:batch_size]

    # f.d. on square mini-grids
    time_fd = timeit.timeit(
        "Sc_fd_batch_minigrids(current_centers, function=decoder)",
        "from __main__ import Sc_fd_batch_minigrids, current_centers, decoder",
        number=iterations,
    )

    # f.d. on rhombus mini-grids
    time_fd_fast = timeit.timeit(
        "Sc_fd_batch_minigrids_rhombus(current_centers, function=decoder)",
        "from __main__ import Sc_fd_batch_minigrids_rhombus, current_centers, decoder",
        number=iterations,
    )

    # jacfwd-based computation from the ricci_regularization package
    time_jacfwd = timeit.timeit(
        "ricci_regularization.Sc_jacfwd_vmap(current_centers, function=decoder)",
        "from __main__ import ricci_regularization, current_centers, decoder",
        number=iterations,
    )

    # Keep the per-call averages for this batch size
    record = {"batch_size": batch_size}
    record["Sc_fd_avg_time"] = time_fd / iterations
    record["Sc_fd_rhombus_avg_time"] = time_fd_fast / iterations
    record["Sc_jacfwd_avg_time"] = time_jacfwd / iterations
    timing_results.append(record)

In [None]:
# Serialize the benchmark records first, then write them to the pictures folder
serialized_results = json.dumps(timing_results, indent=4)
with open(Path_pictures+'/timing_results_batch_minigrids.json', 'w') as f:
    f.write(serialized_results)

# Echo every record (one dict per batch size) to the cell output
for result in timing_results:
    print(result)

In [None]:
# Split the benchmark records into parallel lists, one per plotted method
batch_sizes = [entry['batch_size'] for entry in timing_results]
sc_fd_times = [entry['Sc_fd_avg_time'] for entry in timing_results]
sc_fd_rhombus_times = [entry['Sc_fd_rhombus_avg_time'] for entry in timing_results]
sc_jacfwd_times = [entry['Sc_jacfwd_avg_time'] for entry in timing_results]

plt.figure(figsize=(10, 6))

# Average time per call for both f.d. variants and for jacfwd
plt.plot(batch_sizes, sc_fd_times, label='fd on square mini_grids', marker='o', linestyle='-')
plt.plot(batch_sizes, sc_fd_rhombus_times, label='fd on rhombus mini_grids', marker='o', linestyle='-')
plt.plot(batch_sizes, sc_jacfwd_times, label='jacfwd', marker='s', linestyle='-')

# Title and axis annotations
plt.title(f'Timing scalar curvature $R$ computation: fd on minigrids vs jacfwd')
plt.xlabel('Batch Size')
plt.ylabel('Average Time (seconds)')

# Put the x-ticks exactly at the benchmarked batch sizes
plt.xticks(batch_sizes)
plt.legend()
plt.grid()

# Write the figure to the pictures folder before displaying it
plt.savefig(Path_pictures+'/timing_jacfwd_fd_square_rhombus.pdf', bbox_inches='tight')
plt.show()

# Error plot

In [None]:
# Compare the f.d. scalar curvature against the jacfwd one (used as the
# reference) at the same random points, for a range of f.d. step sizes h.
torch.manual_seed(0)
batch_size = 128  # Just as an example
centers = torch.rand(batch_size, 2)  # random evaluation points, one 2-d point per row

# Step sizes on a logarithmic scale from 1e-3 to 1e-1
h_values = np.logspace(-3, -1, 10)
errors = []                # MSE between f.d. and jacfwd values, one float per h
mean_relative_errors = []  # mean relative error in %, one float per h
mean_abs_values = []       # mean |reference| (same for every h), one float per h
mae_errors = []            # mean absolute error, one float per h

# The jacfwd reference does not depend on h, so it is computed once
# instead of once per step size.
# NOTE(review): this assumes the decoder is deterministic between calls (e.g. in eval mode).
tensor_jacfwd = ricci_regularization.Sc_jacfwd_vmap(centers, function=decoder).detach()
abs_reference = torch.abs(tensor_jacfwd)
mean_abs_reference = torch.mean(abs_reference).item()

for h in h_values:
    # f.d. approximation with the current step size; no gradients are needed
    with torch.no_grad():
        tensor_fd = Sc_fd_batch_minigrids(centers, function=decoder, h=h)
    abs_diff = torch.abs(tensor_fd - tensor_jacfwd)

    # Everything is stored as a plain Python float so the lists can be passed
    # to matplotlib and f-string formatting regardless of the tensors' device
    errors.append(torch.nn.functional.mse_loss(tensor_fd, tensor_jacfwd).item())
    mean_abs_values.append(mean_abs_reference)
    mae_errors.append(abs_diff.mean().item())
    # NOTE(review): points where the reference is exactly 0 make this ratio inf/nan
    mean_relative_errors.append((100 * (abs_diff / abs_reference).mean()).item())

In [None]:
# Log-log plot of the mean relative error (in %) against the f.d. step size
plt.figure(figsize=(8, 6))
plt.loglog(h_values, mean_relative_errors, label="Relative error in %", marker='o')

# Title and axis annotations
plt.title(f'FP32: Mean relative error of scalar curvature $R$ computation f.d. vs jacfwd, batch size: {batch_size}.')
plt.ylabel('Relative error of $|R|$ in %')
plt.xlabel('Step size (h)')

# Ticks exactly at the sampled step sizes (x) and at the measured errors (y)
plt.xticks(h_values, [f'{step:.3f}' for step in h_values])
plt.yticks(mean_relative_errors, [f'{rel_err:.2f}' for rel_err in mean_relative_errors])
plt.grid(True)
plt.legend()

# Write the figure to the pictures folder before displaying it
plt.savefig(Path_pictures+"/fd_relative_error.pdf", format = "pdf", bbox_inches='tight')
plt.show()

In [None]:
# Log-log plot of the absolute error measures against the f.d. step size,
# with the mean magnitude of the reference values for scale
plt.figure(figsize=(8, 6))
plt.loglog(h_values, errors, label="MSE Error", marker='o')
plt.loglog(h_values, mae_errors, label="MAE Error", marker='o')
plt.loglog(h_values, mean_abs_values, label="Mean value of $|R|$", marker='o')

# Title and axis annotations
plt.title('FP32: Errors vs. Step Size for f.d. on minigrid for scalar curvature $R$')
plt.ylabel('Error ')
plt.xlabel('Step size (h)')

# Put the x-ticks exactly at the sampled step sizes
plt.xticks(h_values, [f'{step:.3f}' for step in h_values])
plt.grid(True)
plt.legend(loc = "center left")

# Write the figure to the pictures folder before displaying it
plt.savefig(Path_pictures+"/fd_minigrid_error.pdf", format = "pdf", bbox_inches='tight')
plt.show()