NB! This is an old notebook. It computes the scalar curvature of the latent space in an incorrect (very imprecise) way, using finite-difference (f.d.) formulas for the derivatives. It has to be redone!

The autoencoder (AE) consists of the encoder $\Phi$ and the decoder $\Psi$.
The latent space of the AE is $R^d$. We define a Riemannian metric in a local chart of the latent space as the pull-back of the Euclidean metric in the output space $R^D$ by the decoder function $\Psi$ of the AE:
\begin{equation*}
    g = (\nabla \Psi)^* \, \nabla \Psi .
\end{equation*}

The notebook contains:
1) Loading weights of a pre-trained convolutional AE and plotting its latent space: point plot and manifold plot. If "violent_saving" == True, plots are saved locally.
2) Auxiliary tensors involving higher-order derivatives of the decoder $\Psi$ are computed with f.d.: the metric $g$ and its derivatives, the Riemann tensor $R^{i}_{jkl}$, the Ricci tensor $R_{ij}$ and the scalar curvature.
3) Geodesic shooting via Runge-Kutta approximation. A single plot with a scalar curvature heatmap and the geodesics drawn on it is constructed.
4) Prototype of metric evolution by Ricci flow equation 

NB! By default the metric $g$ is the pull-back by the decoder as described above. However, one can use any custom metric by setting it manually in the "specific_metric" function, which computes the metric matrix $g(u)$ at a point $u \in \mathbb{R}^d$ given the local coordinates of $u$ in the latent space.

In [None]:
import os

import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd
import torch
import torchvision
import yaml

import ricci_regularization

In [None]:
# Read the experiment configuration (YAML); later cells take dataset, architecture,
# loss and experiment settings from `yaml_config`.
with open('../../experiments/MNIST_Setting_1_config.yaml', 'r') as yaml_file:
#with open('../../experiments/MNIST01_exp7_config.yaml', 'r') as yaml_file:
#with open('../../experiments/Swissroll_exp4_config.yaml', 'r') as yaml_file:
    # NOTE(review): FullLoader is acceptable only for trusted, local config files.
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

violent_saving = False # if False it will not save plots

# d: dimension of the latent space R^d (the mini-grid helpers below work with 2D points)
d = 2

# Loading data and nn weights

In [None]:
# Load data loaders based on YAML configuration.
# The result is bound to `data_loaders` (not `dict`) so the builtin `dict`
# is not shadowed for the rest of the notebook.
data_loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = data_loaders["train_loader"]
test_loader = data_loaders["test_loader"]
test_dataset = data_loaders.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders

print("Data loaders created successfully.")
# Relative prefix prepended to experiment paths in the following cells.
additional_path="../"

In [None]:
experiment_name = yaml_config["experiment"]["name"]

# Directory in which plots are stored (only created when saving is enabled).
#Path_pictures = yaml_config["experiment"]["path"]
Path_pictures = additional_path + "../experiments/" + experiment_name
if violent_saving:
    # Check and create directories based on configuration
    if not os.path.exists(Path_pictures):  # Check if the picture path does not exist
        # makedirs also creates missing parent directories, where mkdir would raise
        os.makedirs(Path_pictures)  # Create the directory for plots if not yet created
        print(f"Created directory: {Path_pictures}")  # Print directory creation feedback
    else:
        print(f"Directiry already exists: {Path_pictures}")

# Weight of the curvature term in the loss, as stored in the config.
curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
D = yaml_config["architecture"]["input_dim"]
# D is the dimension of the dataset
# NOTE(review): `k` (number of classes) stays undefined for any dataset name
# other than the ones handled below — confirm whether other datasets need it.
if dataset_name in ["MNIST01", "Synthetic"]:
    # k from the JSON configuration file is the number of classes
    #k = yaml_config["dataset"]["k"]
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    k = 10
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
### Initialize the two networks

# Build the autoencoder described by the config; the helper also returns the
# path of the weights file (printed below).
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config, additional_path = additional_path)

# All computations in this notebook run on the CPU.
torus_ae = torus_ae.to("cpu")

print(f"AE weights loaded successfully from {Path_ae_weights}.")

# Encoder and decoder parts of the autoencoder; `decoder` is the map whose
# pull-back metric is approximated in the cells below.
encoder = torus_ae.encoder_torus
decoder = torus_ae.decoder_torus

# Using mini-grids + minimal verification

In [None]:
def build_mini_grid(center: torch.Tensor, h: float = 1.0) -> torch.Tensor:
    """
    Return the 7x7 lattice of 2D points with spacing ``h`` centred at ``center``.

    Args:
        center (torch.Tensor): The 2D point placed in the middle of the lattice.
        h (float): Distance between neighbouring lattice points. Default is 1.0.

    Returns:
        torch.Tensor: Tensor of shape (7, 7, 2); entry ``[i, j]`` equals
        ``center + ((i - 3) * h, (j - 3) * h)``.
    """
    # Offsets -3h, ..., 3h shared by both coordinates
    steps = torch.arange(-3, 4) * h

    # First coordinate varies along rows, second along columns ('ij' layout)
    first_coord = steps.unsqueeze(1).expand(7, 7)
    second_coord = steps.unsqueeze(0).expand(7, 7)
    lattice = torch.stack((first_coord, second_coord), dim=-1).float()

    # Shift the whole lattice so that its middle point is `center`
    return lattice + center[None, None]

def build_mini_grid_batch(centers: torch.Tensor, h: float = 1.0, grid_size: int = 7) -> torch.Tensor:
    """
    Builds a batch of flattened square mini-grids centered at the given points.

    Grid points are stored in row-major order: flat index ``i * grid_size + j``
    holds ``center + ((i - r) * h, (j - r) * h)`` with ``r = grid_size // 2``,
    so the center itself sits at flat index ``(grid_size * grid_size) // 2``
    (24 for the default 7x7 grid).

    Args:
        centers (torch.Tensor): A 2D tensor with shape (N, 2) representing N centers.
        h (float): The step size for the grid.
        grid_size (int): Number of points per side of the mini-grid; must be a
            positive odd number so that the grid has a central point. Default is 7.

    Returns:
        torch.Tensor: A batch of mini-grids of shape (N, grid_size * grid_size, 2).

    Raises:
        ValueError: If ``grid_size`` is not a positive odd integer.
    """
    if grid_size < 1 or grid_size % 2 == 0:
        raise ValueError(f"grid_size must be a positive odd integer, got {grid_size}")

    half = grid_size // 2
    # Offsets are created on the device of `centers` so the final addition
    # also works when the centers do not live on the CPU.
    offset = torch.arange(-half, half + 1, device=centers.device) * h
    grid_x, grid_y = torch.meshgrid(offset, offset, indexing='ij')

    # Stack the coordinates (x, y) and flatten the two grid axes into one
    mini_grid = torch.stack([grid_x, grid_y], dim=-1).float()
    mini_grid = mini_grid.reshape(1, grid_size * grid_size, 2)

    # Broadcast (1, G, 2) against (N, 1, 2) to translate the grid to every center
    return mini_grid + centers.unsqueeze(1)

In [None]:
# Minimal verification: ten random centres, step 0.01.
centers = torch.rand(10,2)
h = 0.01
batch_minigrids = build_mini_grid_batch(centers, h)   

# Flat index 24 is the middle of the flattened 7x7 grid, so it must reproduce the centres.
print("batch of grids built correctly:",torch.equal( batch_minigrids[:,24,:], centers))

In [None]:
# Shape check: decoder output evaluated on the whole batch of mini-grids.
decoder(batch_minigrids).shape

In [None]:
def metric_fd_batch_minigrids(batch_minigrids, function):
    """
    Finite-difference approximation of the pull-back metric on a batch of mini-grids.

    The step is read off the grid itself (distance between flat points 0 and 1).
    Derivatives of ``function`` are central differences obtained by cyclic
    shifts of the flattened grid axis, so values next to the grid border mix
    in wrapped-around points and are not meaningful.

    Args:
        batch_minigrids: Tensor of shape (N, 49, 2) holding flattened 7x7 grids.
        function: Callable applied to ``batch_minigrids``; its output is
            contracted over the last axis.

    Returns:
        Tensor of shape (N, 49, 2, 2) with the metric matrix at every grid point.
    """
    step = (batch_minigrids[0, 1] - batch_minigrids[0, 0]).norm()
    values = function(batch_minigrids)

    # Shift by one flat index / by seven flat indices (one full grid row)
    diff_shift_1 = (torch.roll(values, shifts=-1, dims=1)
                    - torch.roll(values, shifts=1, dims=1)) / (2 * step)
    diff_shift_7 = (torch.roll(values, shifts=-7, dims=1)
                    - torch.roll(values, shifts=7, dims=1)) / (2 * step)

    # Gram entries of the two difference quotients
    E = torch.einsum('ngc,ngc->ng', diff_shift_1, diff_shift_1)
    F = torch.einsum('ngc,ngc->ng', diff_shift_1, diff_shift_7)
    G = torch.einsum('ngc,ngc->ng', diff_shift_7, diff_shift_7)

    # Entries laid out as [[G, F], [F, E]] per grid point
    entries = torch.stack((G, F, F, E), dim=-1)
    return entries.view(-1, 7 * 7, 2, 2)


In [None]:
# FD metric on every mini-grid; display it at the central point (flat index 24) of grid 7.
metric = metric_fd_batch_minigrids(batch_minigrids, decoder)
metric[7][24]

In [None]:
# Reference value at the same centre from the library routine (presumably
# autodiff-based, judging by its name); compare with the FD value shown above.
ricci_regularization.metric_jacfwd(centers[7],decoder)

In [None]:
def diff_by_x_minigrids(tensor_on_batch_minigrids, h):
    """
    Central difference along the first grid coordinate of flattened 7x7 mini-grids.

    Neighbours are fetched by cyclically shifting axis 1 by seven positions
    (one grid row), so entries in the first and last grid rows wrap around.

    Args:
        tensor_on_batch_minigrids: Tensor whose axis 1 enumerates the 49 grid points.
        h: Grid step.

    Returns:
        Tensor of the same shape holding the difference quotients.
    """
    forward = torch.roll(tensor_on_batch_minigrids, shifts=-7, dims=1)
    backward = torch.roll(tensor_on_batch_minigrids, shifts=7, dims=1)
    return (forward - backward) / (2 * h)

def diff_by_y_minigrids(tensor_on_batch_minigrids, h):
    """
    Central difference along the second grid coordinate of flattened mini-grids.

    Neighbours are fetched by cyclically shifting axis 1 by one position, so
    entries at the ends of a grid row pick up points from adjacent rows.

    Args:
        tensor_on_batch_minigrids: Tensor whose axis 1 enumerates the grid points.
        h: Grid step.

    Returns:
        Tensor of the same shape holding the difference quotients.
    """
    forward = torch.roll(tensor_on_batch_minigrids, shifts=-1, dims=1)
    backward = torch.roll(tensor_on_batch_minigrids, shifts=1, dims=1)
    return (forward - backward) / (2 * h)

def metric_der_fd_batch_minigrids(batch_minigrids, function):
    """
    Finite-difference derivatives of the FD metric on a batch of mini-grids.

    Args:
        batch_minigrids: Batch of flattened mini-grids, shape (N, 49, 2).
        function: Map whose pull-back metric is differentiated.

    Returns:
        Tensor with one extra trailing axis of size 2: index 0 holds the
        derivative from ``diff_by_x_minigrids``, index 1 the one from
        ``diff_by_y_minigrids``.
    """
    # Grid step recovered from the first two points of the first grid
    step = (batch_minigrids[0, 1] - batch_minigrids[0, 0]).norm()
    g = metric_fd_batch_minigrids(batch_minigrids, function=function)
    partials = (diff_by_x_minigrids(g, h=step), diff_by_y_minigrids(g, h=step))
    return torch.stack(partials, dim=-1)


In [None]:
# FD metric derivatives at the central point (flat index 24) of mini-grid 7.
metric_der_fd_batch_minigrids(batch_minigrids, decoder)[7][24]

In [None]:
# Library reference value at the same centre, for comparison with the FD result above.
ricci_regularization.metric_der_jacfwd(centers[7],decoder)

In [None]:
def metric_inv_batch_minigrids(batch_minigrids, function, eps=0.0):
    """
    Inverse of the FD metric at every point of every mini-grid.

    Args:
        batch_minigrids: Batch of flattened mini-grids.
        function: Map whose pull-back metric is inverted.
        eps: Multiple of the identity added to the metric before inversion.

    Returns:
        Tensor of the same shape as the metric, holding the matrix inverses.
    """
    metric = metric_fd_batch_minigrids(batch_minigrids, function)
    # eps * I, created on the same device as the metric
    shift = eps * torch.eye(metric.shape[-1], device=metric.device)
    return torch.inverse(metric + shift)

#metric_inv_jacfd_vmap = torch.func.vmap(metric_inv_fd)

def Ch_fd_batch_minigrids(batch_minigrids, function, eps=0.0):
    """
    Christoffel symbols of the FD metric on a batch of mini-grids.

    With ``dg[..., m, k, l]`` the derivative of ``g_mk`` in direction ``l``,
    the result is ``0.5 * g^{im} (d_l g_mk + d_k g_ml - d_m g_kl)``.

    Args:
        batch_minigrids: Batch of flattened mini-grids.
        function: Map whose pull-back metric is used.
        eps: Regularisation passed on to the metric inversion.

    Returns:
        Tensor indexed as (batch, grid point, i, k, l).
    """
    inverse_metric = metric_inv_batch_minigrids(batch_minigrids, function, eps=eps)
    d_metric = metric_der_fd_batch_minigrids(batch_minigrids, function)

    term_dl_gmk = torch.einsum('bgim,bgmkl->bgikl', inverse_metric, d_metric)
    term_dk_gml = torch.einsum('bgim,bgmlk->bgikl', inverse_metric, d_metric)
    term_dm_gkl = torch.einsum('bgim,bgklm->bgikl', inverse_metric, d_metric)
    return 0.5 * (term_dl_gmk + term_dk_gml - term_dm_gkl)
#Ch_fd_vmap = torch.func.vmap(Ch_fd)

def Ch_der_fd_batch_minigrids (grid, function, eps=0.0):
    """
    Finite-difference derivatives of the Christoffel symbols on a batch of mini-grids.

    Args:
        grid: Batch of flattened mini-grids, shape (N, 49, 2).
        function: Map whose pull-back metric defines the Christoffel symbols.
        eps: Regularisation passed on to the metric inversion.

    Returns:
        Tensor with one extra trailing axis of size 2: index 0 holds the
        derivative from ``diff_by_x_minigrids``, index 1 the one from
        ``diff_by_y_minigrids``.
    """
    # The step must come from the grid passed in; reading it from the notebook
    # global `batch_minigrids` gave a wrong scale whenever the two differed.
    h = (grid[0,1] - grid[0,0]).norm()

    Ch = Ch_fd_batch_minigrids(grid, function=function, eps=eps)
    dChdx = diff_by_x_minigrids(Ch, h)
    dChdy = diff_by_y_minigrids(Ch, h)
    dCh = torch.cat((dChdx.unsqueeze(-1), dChdy.unsqueeze(-1)), dim = -1)
    return dCh

In [None]:
# FD derivatives of the Christoffel symbols at the central point (flat index 24) of mini-grid 7.
Ch_der_fd_batch_minigrids(batch_minigrids, decoder)[7][24]

In [None]:
# Library reference value at the same centre, for comparison with the FD result above.
ricci_regularization.Ch_der_jacfwd(centers[7],decoder)

In [None]:
# Riemann curvature tensor (3,1)
def Riem_fd_batch_minigrids(u, function,eps=0.0):
    """
    Riemann tensor assembled from FD Christoffel symbols and their FD derivatives.

    Args:
        u: Batch of flattened mini-grids.
        function: Map whose pull-back metric is used.
        eps: Regularisation passed on to the metric inversion.

    Returns:
        Tensor indexed as (batch, grid point, i, j, k, l).
    """
    gamma = Ch_fd_batch_minigrids(u, function, eps=eps)
    d_gamma = Ch_der_fd_batch_minigrids(u, function, eps=eps)

    # Derivative part: pure index reshuffles of d_gamma, written as axis permutations
    # (axes (i, l, j, k) -> (i, j, k, l) and (i, k, j, l) -> (i, j, k, l))
    curvature = d_gamma.permute(0, 1, 2, 4, 5, 3) - d_gamma.permute(0, 1, 2, 4, 3, 5)

    # Quadratic part in the Christoffel symbols
    curvature += (torch.einsum("bgikp,bgplj->bgijkl", gamma, gamma)
                  - torch.einsum("bgilp,bgpkj->bgijkl", gamma, gamma))
    return curvature

def Ric_fd_batch_minigrids(u, function, eps=0.0):
    """
    Ricci tensor: contraction of the FD Riemann tensor over its first and third tensor indices.

    Args:
        u: Batch of flattened mini-grids.
        function: Map whose pull-back metric is used.
        eps: Regularisation passed on to the metric inversion.

    Returns:
        Tensor indexed as (batch, grid point, s, r).
    """
    riemann_tensor = Riem_fd_batch_minigrids(u, function, eps=eps)
    return torch.einsum("bgcscr->bgsr", riemann_tensor)

def Sc_fd_batch_minigrids_slow(u, function, eps=0.0):
    """
    Scalar curvature: full contraction of the inverse FD metric with the FD Ricci tensor.

    Built from the separate helper functions, each of which re-evaluates the
    quantities it needs.

    Args:
        u: Batch of flattened mini-grids.
        function: Map whose pull-back metric is used.
        eps: Regularisation passed on to the metric inversion.

    Returns:
        Tensor indexed as (batch, grid point).
    """
    ricci_tensor = Ric_fd_batch_minigrids(u, function=function, eps=eps)
    inverse_metric = metric_inv_batch_minigrids(u, function=function, eps=eps)
    return torch.einsum('bgsr,bgsr->bg', inverse_metric, ricci_tensor)


# REDO this
# Scratch code: central differences of the decoder output taken by slicing the
# 7x7 grid, which keeps only the 5x5 interior and therefore needs no wrap-around.
# Uses the notebook globals `decoder`, `batch_minigrids`, `D` and `h`.

minigrid_side = 5
psi = decoder(batch_minigrids)
psi_matrix = psi.reshape(-1,7,7,D)
#psi_next_x =  psi_matrix[:,:,2:,:].reshape(-1, minigrid_side * minigrid_side, D)
#psi_prev_x =  psi_matrix[:,:,:-2,:].reshape(-1, minigrid_side * minigrid_side, D)
#psi_next_y =  psi_matrix[:,2:,:,:].reshape(-1, minigrid_side * minigrid_side, D)
#psi_prev_y =  psi_matrix[:,:-2,:,:].reshape(-1, minigrid_side * minigrid_side, D)
# "x" neighbours are taken along axis 2, "y" neighbours along axis 1 of psi_matrix
psi_next_x =  psi_matrix[:,1:-1,2:,:]
psi_prev_x =  psi_matrix[:,1:-1,:-2,:]
psi_next_y =  psi_matrix[:,2:,1:-1,:]
psi_prev_y =  psi_matrix[:,:-2,1:-1,:]


dpsidx = (psi_next_x - psi_prev_x)/(2 * h)
dpsidy = (psi_next_y - psi_prev_y)/(2 * h)

def diff_by_x_minigrids_fast(tensor_on_batch_minigrids, h, minigrid_side):
    """
    Central difference on the interior of each mini-grid, computed by slicing.

    The flattened grid axis is unfolded into (minigrid_side, minigrid_side) and
    the difference is taken along the second of these two axes; only interior
    points are returned, so no wrap-around values appear.

    All axes after the grid axis are preserved, so the function works for
    tensors with any number of trailing axes (e.g. a (N, G, 2, 2) metric).
    Previously only a single trailing axis was handled correctly.

    NOTE(review): this differentiates along the axis that `diff_by_y_minigrids`
    (shift by one flat index) uses, not the one `diff_by_x_minigrids` uses —
    confirm the intended x/y convention before swapping one for the other.

    Args:
        tensor_on_batch_minigrids: Tensor of shape (N, minigrid_side**2, ...).
        h: Grid step.
        minigrid_side: Number of grid points per side.

    Returns:
        Tensor of shape (N, minigrid_side - 2, minigrid_side - 2, ...).
    """
    trailing_shape = tensor_on_batch_minigrids.shape[2:]
    as_matrix = tensor_on_batch_minigrids.reshape(
        -1, minigrid_side, minigrid_side, *trailing_shape
    )
    tensor_next_x = as_matrix[:, 1:-1, 2:]
    tensor_prev_x = as_matrix[:, 1:-1, :-2]
    tensor_dx = (tensor_next_x - tensor_prev_x)/(2*h)

    return tensor_dx

def diff_by_y_minigrids_fast(tensor_on_batch_minigrids, h, minigrid_side):
    """
    Central difference on the interior of each mini-grid, computed by slicing.

    The flattened grid axis is unfolded into (minigrid_side, minigrid_side) and
    the difference is taken along the first of these two axes; only interior
    points are returned, so no wrap-around values appear.

    All axes after the grid axis are preserved, so the function works for
    tensors with any number of trailing axes (e.g. a (N, G, 2, 2) metric).
    Previously only a single trailing axis was handled correctly.

    NOTE(review): this differentiates along the axis that `diff_by_x_minigrids`
    (shift by one grid row) uses, not the one `diff_by_y_minigrids` uses —
    confirm the intended x/y convention before swapping one for the other.

    Args:
        tensor_on_batch_minigrids: Tensor of shape (N, minigrid_side**2, ...).
        h: Grid step.
        minigrid_side: Number of grid points per side.

    Returns:
        Tensor of shape (N, minigrid_side - 2, minigrid_side - 2, ...).
    """
    trailing_shape = tensor_on_batch_minigrids.shape[2:]
    as_matrix = tensor_on_batch_minigrids.reshape(
        -1, minigrid_side, minigrid_side, *trailing_shape
    )
    tensor_next_y = as_matrix[:, 2:, 1:-1]
    tensor_prev_y = as_matrix[:, :-2, 1:-1]
    tensor_dy = (tensor_next_y - tensor_prev_y)/(2*h)

    return tensor_dy


def metric_fd_batch_minigrids(batch_minigrids, function, h=None):
    """
    Finite-difference approximation of the pull-back metric on a batch of mini-grids.

    Derivatives of ``function`` are central differences obtained by cyclic
    shifts of the flattened grid axis, so values next to the grid border mix
    in wrapped-around points and are not meaningful.

    Args:
        batch_minigrids: Tensor of shape (N, 49, 2) holding flattened 7x7 grids.
        function: Callable applied to ``batch_minigrids``; its output is
            contracted over the last axis.
        h: Grid step. When omitted it is measured on the grid itself (distance
            between flat points 0 and 1). A fixed default of 0.01 would give a
            wrongly scaled metric for callers that pass a grid with a different
            spacing and no ``h``.

    Returns:
        Tensor of shape (N, 49, 2, 2) with the metric matrix at every grid point.
    """
    if h is None:
        h = (batch_minigrids[0, 1] - batch_minigrids[0, 0]).norm()

    psi = function(batch_minigrids)
    # Shift by one flat index / by seven flat indices (one full grid row)
    psi_next_x =  psi.roll(-1,1)
    psi_prev_x =  psi.roll(1,1)
    psi_next_y =  psi.roll(-7,1)
    psi_prev_y =  psi.roll(7,1)

    dpsidx = (psi_next_x - psi_prev_x)/(2 * h)
    dpsidy = (psi_next_y - psi_prev_y)/(2 * h)
    E = torch.einsum('bgD,bgD->bg',dpsidx, dpsidx)
    F = torch.einsum('bgD,bgD->bg',dpsidx, dpsidy)
    G = torch.einsum('bgD,bgD->bg',dpsidy, dpsidy)

    # Entries laid out as [[G, F], [F, E]] per grid point
    metric = torch.stack((G, F, F, E), dim=-1)
    metric = metric.view(-1, 7 * 7, 2, 2)
    return metric


def Sc_fd_batch_minigrids(centers, function, h=0.01, eps=0.0):
    """
    Scalar curvature by finite differences, computed in a single pass.

    A mini-grid with step ``h`` is built around every center; the metric, its
    inverse, the Christoffel symbols, the Riemann and Ricci tensors and finally
    the scalar curvature are evaluated on it. Intermediate tensors are deleted
    as soon as they are no longer needed.

    Args:
        centers: Tensor of shape (N, 2) with the grid centers.
        function: Map whose pull-back metric is used.
        h: Grid step.
        eps: Multiple of the identity added to the metric before inversion.

    Returns:
        Tensor indexed as (center, grid point). The derivatives wrap around at
        the grid border, so only entries away from the border are meaningful.
    """
    # Mini-grids around the given centers
    grids = build_mini_grid_batch(centers, h)

    # Metric and its (optionally regularised) inverse
    metric = metric_fd_batch_minigrids(grids, function, h)
    regulariser = eps * torch.eye(metric.shape[-1], device=metric.device)
    metric_inverse = torch.inverse(metric + regulariser)

    # First derivatives of the metric, stacked on a new trailing axis
    d_metric = torch.stack(
        (diff_by_x_minigrids(metric, h=h), diff_by_y_minigrids(metric, h=h)),
        dim=-1,
    )
    del metric

    # Christoffel symbols
    gamma = 0.5 * (
        torch.einsum('bgim,bgmkl->bgikl', metric_inverse, d_metric)
        + torch.einsum('bgim,bgmlk->bgikl', metric_inverse, d_metric)
        - torch.einsum('bgim,bgklm->bgikl', metric_inverse, d_metric)
    )
    del d_metric

    # First derivatives of the Christoffel symbols
    d_gamma = torch.stack(
        (diff_by_x_minigrids(gamma, h), diff_by_y_minigrids(gamma, h)),
        dim=-1,
    )

    # Riemann tensor: derivative part (index reshuffles) plus quadratic part
    riemann = d_gamma.permute(0, 1, 2, 4, 5, 3) - d_gamma.permute(0, 1, 2, 4, 3, 5)
    riemann += (torch.einsum("bgikp,bgplj->bgijkl", gamma, gamma)
                - torch.einsum("bgilp,bgpkj->bgijkl", gamma, gamma))
    del d_gamma, gamma

    # Ricci tensor
    ricci = torch.einsum("bgcscr->bgsr", riemann)
    del riemann

    # Scalar curvature
    scalar = torch.einsum('bgsr,bgsr->bg', metric_inverse, ricci)
    del ricci, metric_inverse
    return scalar


In [None]:
# Shape check of the FD scalar curvature (one value per centre and grid point).
Sc_fd_batch_minigrids(centers, decoder).shape

In [None]:
# Fresh sample of 100 random centres and the mini-grids around them (step 0.01).
centers = torch.rand(100,2)
h = 0.01
batch_minigrids = build_mini_grid_batch(centers, h)


In [None]:
# Decoder output on all grid points; display its shape.
psi = decoder(batch_minigrids)
psi.shape

In [None]:
# Shape check of the FD Christoffel symbols.
Ch_fd_batch_minigrids(batch_minigrids,decoder).shape

In [None]:
# Unfold the flattened grid axis (49) back into a 7x7 layout.
psi_matrix = psi.reshape(-1,7,7,D)

In [None]:
# FD metric at the central point (flat index 24) of the first mini-grid.
# NOTE(review): no `h` is passed, so the function's default step is used; it
# matches the grid only while the grids above are built with h = 0.01.
metric_fd_batch_minigrids(batch_minigrids,decoder)[0][24]

In [None]:
#psi_next_x = psi_matrix[:,:,1:,:]
#print(psi_next_x)

# Scratch computation of the FD metric on the 5x5 interior of each 7x7 grid,
# using slicing instead of cyclic shifts (no wrap-around values).
minigrid_side = 5
psi = decoder(batch_minigrids)
psi_matrix = psi.reshape(-1,7,7,D)
#psi_next_x =  psi_matrix[:,:,2:,:].reshape(-1, minigrid_side * minigrid_side, D)
#psi_prev_x =  psi_matrix[:,:,:-2,:].reshape(-1, minigrid_side * minigrid_side, D)
#psi_next_y =  psi_matrix[:,2:,:,:].reshape(-1, minigrid_side * minigrid_side, D)
#psi_prev_y =  psi_matrix[:,:-2,:,:].reshape(-1, minigrid_side * minigrid_side, D)
# "x" neighbours are taken along axis 2, "y" neighbours along axis 1 of psi_matrix
psi_next_x =  psi_matrix[:,1:-1,2:,:]
psi_prev_x =  psi_matrix[:,1:-1,:-2,:]
psi_next_y =  psi_matrix[:,2:,1:-1,:]
psi_prev_y =  psi_matrix[:,:-2,1:-1,:]


dpsidx = (psi_next_x - psi_prev_x)/(2 * h)
dpsidy = (psi_next_y - psi_prev_y)/(2 * h)
# Gram entries of the two difference quotients at every interior point
E = torch.einsum('bghD,bghD->bgh',dpsidx, dpsidx)
F = torch.einsum('bghD,bghD->bgh',dpsidx, dpsidy)
G = torch.einsum('bghD,bghD->bgh',dpsidy, dpsidy)

# Entries laid out as [[G, F], [F, E]]; the 5x5 interior is flattened to 25 points
metric = torch.cat((G.unsqueeze(-1), F.unsqueeze(-1), F.unsqueeze(-1), E.unsqueeze(-1)),-1)
metric = metric.view(-1, minigrid_side * minigrid_side, 2, 2)

In [None]:
# Shape of the interior metric for the first mini-grid.
metric[0].shape

In [None]:
# Shape experiment: difference of the flattened decoder output shifted by two positions.
(psi[:,2:,:] - psi[:,:-2,:]).shape

In [None]:
# Shape experiment: flattened decoder output without its last 14 grid points.
psi[:,:-14,:].shape

In [None]:
# Interior metric of the first mini-grid at flat index 13 of the 5x5 interior.
# NOTE(review): the middle of a flattened 5x5 grid is index 12, not 13 — confirm which point is intended.
metric[0][13]

In [None]:
# Library reference metric at the first centre, for comparison with the FD values above.
ricci_regularization.metric_jacfwd(centers[0], decoder)

In [None]:
# Scratch copy of the body of `metric_fd_batch_minigrids`, run at cell level.
# `function` only exists as a parameter name inside the helper functions, so
# the decoder is called directly here (the original line raised NameError).
psi = decoder(batch_minigrids)
# Neighbours via cyclic shifts of the flattened grid axis (by 1 and by one row of 7)
psi_next_x =  psi.roll(-1,1)
psi_prev_x =  psi.roll(1,1)
psi_next_y =  psi.roll(-7,1)
psi_prev_y =  psi.roll(7,1)

dpsidx = (psi_next_x - psi_prev_x)/(2 * h)
dpsidy = (psi_next_y - psi_prev_y)/(2 * h)
E = torch.einsum('bgD,bgD->bg',dpsidx, dpsidx)
F = torch.einsum('bgD,bgD->bg',dpsidx, dpsidy)
G = torch.einsum('bgD,bgD->bg',dpsidy, dpsidy)

# Entries laid out as [[G, F], [F, E]] per grid point
metric = torch.cat((G.unsqueeze(-1), F.unsqueeze(-1), F.unsqueeze(-1), E.unsqueeze(-1)),-1)
metric = metric.view(-1, 7 * 7, 2, 2)

In [None]:
# Shape experiment: 7x7 decoder output with the last two columns removed.
psi_matrix[:,:,:-2,:].shape

In [None]:
# Display the decoder output cyclically shifted by one position along the grid axis.
psi.roll(-1,1)

In [None]:
# NOTE(review): `lo` is not defined anywhere in this notebook, so this cell raises
# NameError — presumably a deliberate stop marker for "Run All"; confirm or remove.
lo

In [None]:
# FD scalar curvature from the single-pass implementation (mini-grids are built internally from the centres).
Sc_fd_batch_minigrids(centers, decoder)

In [None]:
# Same quantity from the helper-by-helper implementation, which takes the prebuilt mini-grids.
Sc_fd_batch_minigrids_slow(batch_minigrids,decoder)

In [None]:
#Sc_fd_batch_minigrids(batch_minigrids, decoder)[7][24]
#ricci_regularization.Sc_jacfwd(centers[7],decoder)

errors

In [None]:
def error_fd_jacfwd_batch_minigrids(tensor_fd, tensor_jacfwd):
    """
    Mean squared error between FD values at the grid centers and reference values.

    Args:
        tensor_fd: Tensor whose axis 1 enumerates the points of a flattened
            odd-sided square mini-grid (49 points for the 7x7 grids used here).
        tensor_jacfwd: Reference tensor with one entry per mini-grid, shaped
            like ``tensor_fd[:, center]``.

    Returns:
        Scalar tensor with the mean squared error.
    """
    # Middle of the flattened grid: 24 for 49 points, derived rather than hard-coded
    center_index = tensor_fd.shape[1] // 2
    tensor_fd_central = tensor_fd[:, center_index]

    return torch.nn.functional.mse_loss(tensor_fd_central, tensor_jacfwd)

In [None]:
# FD metric on the mini-grids and the library reference at the centres.
# NOTE(review): no `h` is passed to the FD routine, so its default step is used;
# it matches the grid only while the grids are built with h = 0.01.
metric_fd = metric_fd_batch_minigrids(batch_minigrids, function=decoder)
metric_jacfwd = ricci_regularization.metric_jacfwd_vmap(centers, function=decoder)

In [None]:
# MSE between the FD metric at the grid centres and the reference metric.
error_fd_jacfwd_batch_minigrids(metric_fd, metric_jacfwd)

In [None]:
# Same evaluation as in the previous cell (duplicate cell).
error_fd_jacfwd_batch_minigrids(metric_fd, metric_jacfwd)

# Error plot

In [None]:
# NOTE(review): `la` is not defined anywhere in this notebook, so this cell raises
# NameError — presumably a deliberate stop marker for "Run All"; confirm or remove.
la

In [None]:
# Five step sizes, log-spaced between 1e-4 and 1e-1.
h_values = np.logspace(-4, -1, 5) 

In [None]:
# Display the step sizes.
h_values

In [None]:
#tensor_fd = Sc_fd_batch_minigrids(batch_minigrids, function= decoder)  # Simulate FD grid
#tensor_jacfwd = ricci_regularization.Sc_jacfwd_vmap(batch_minigrids,function= decoder)

In [None]:
h = 0.01
batch_size = 64  # Just as an example
centers = torch.rand(batch_size, 2)
batch_minigrids = build_mini_grid_batch(centers = centers, h = h)
# Sc_fd_batch_minigrids builds the mini-grids itself, so it must be given the
# centres (not the prebuilt grids) together with the step actually in use.
Sc_fd_batch_minigrids(centers, function= decoder, h = h).shape

In [None]:
# Shape of the library reference scalar curvature evaluated at the centres.
ricci_regularization.Sc_jacfwd_vmap(centers,function= decoder).shape

In [None]:
torch.manual_seed(0)
batch_size = 128  # Just as an example
centers = torch.rand(batch_size, 2)  # random latent points at which the two methods are compared

# The FD scalar curvature is computed for a range of step sizes h
h_values = np.logspace(-4, -1, 10)  # Step sizes in logarithmic scale from 1e-4 to 1e-1
errors = []

# The reference values do not depend on h, so they are computed once
tensor_jacfwd = ricci_regularization.Sc_jacfwd_vmap(centers,function= decoder)

for h_value in h_values:
    h = float(h_value)
    # Sc_fd_batch_minigrids builds the mini-grids around `centers` itself; it
    # has to receive the centres and the current step. Passing prebuilt grids
    # without `h` evaluated every step size with the default h = 0.01.
    tensor_fd = Sc_fd_batch_minigrids(centers, function= decoder, h = h)
    # Compute the error for this step size (FD value at the grid centre vs reference)
    error = error_fd_jacfwd_batch_minigrids(tensor_fd, tensor_jacfwd)
    errors.append(error.item())  # Store the error as a scalar


In [None]:
# Now we plot the error vs. h
plt.figure(figsize=(8, 6))
plt.loglog(h_values, errors, marker='o', label="MSE Error")
plt.xlabel('Step size (h)')
plt.ylabel('Error (MSE)')
plt.title('Error vs. Step Size for Finite Differences on minigrid for scalar curvature computation')
plt.legend()
plt.grid(True, which="both", ls="--")
# Save only when saving is enabled: the target directory is created only in
# that case, so an unconditional savefig fails when it does not exist.
if violent_saving:
    plt.savefig(Path_pictures+"/fd_minigrid_error.pdf", bbox_inches='tight', format = "pdf")
plt.show()

# Timing

In [None]:
# Fresh sample of 100 random centres and their mini-grids (step 0.01) for the timing section.
centers = torch.rand(100,2)
h = 0.01
batch_minigrids = build_mini_grid_batch(centers, h)   

In [None]:
# Sc_fd_batch_minigrids builds the mini-grids itself, so it takes the centres
# (not the prebuilt grids) and the step in use.
Sc_fd = Sc_fd_batch_minigrids(centers, function=decoder, h=h)
Sc_jacfwd = ricci_regularization.Sc_jacfwd_vmap(centers, function=decoder)

In [None]:
import timeit
import json

# Define the number of iterations for averaging
iterations = 100

batch_sizes = [8, 16, 32,40, 64, 128, 256]  # Different batch sizes to test

# Initialize a list to hold timing results
timing_results = []

h = 0.01  # Step size (arbitrary)
centers = torch.randn(max(batch_sizes), 2)  # Example centers, random values

# Loop through different batch sizes
for batch_size in batch_sizes:
    # Restrict the centres to the current batch size
    current_centers = centers[:batch_size]

    # Timing for Sc_fd. The function builds the mini-grids around the centres
    # itself, so it is timed on `current_centers` with the step `h`; passing
    # prebuilt grids as centres would time a different (ill-shaped) computation.
    # `globals=globals()` lets the statement see the notebook names directly.
    time_fd = timeit.timeit(
        stmt="Sc_fd_batch_minigrids(current_centers, function=decoder, h=h)",
        globals=globals(),
        number=iterations
    )

    # Timing for Sc_jacfwd
    time_jacfwd = timeit.timeit(
        stmt="ricci_regularization.Sc_jacfwd_vmap(current_centers, function=decoder)",
        globals=globals(),
        number=iterations
    )

    # Append the results to the timing_results list
    timing_results.append({
        "batch_size": batch_size,
        "Sc_fd_avg_time": time_fd / iterations,
        "Sc_jacfwd_avg_time": time_jacfwd / iterations,
    })

# Save results to a JSON file
with open('timing_results_batch_minigrids.json', 'w') as f:
    json.dump(timing_results, f, indent=4)

# Print the timing results
for result in timing_results:
    print(result)

In [None]:
# Plotting the results
batch_sizes = [result['batch_size'] for result in timing_results]
sc_fd_times = [result['Sc_fd_avg_time'] for result in timing_results]
sc_jacfwd_times = [result['Sc_jacfwd_avg_time'] for result in timing_results]

plt.figure(figsize=(10, 6))

# Plot average times for Sc_fd and Sc_jacfwd_vmap
plt.plot(batch_sizes, sc_fd_times, marker='o', label='fd on mini_grids', linestyle='-')
plt.plot(batch_sizes, sc_jacfwd_times, marker='s', label='jacfwd', linestyle='-')

# Adding labels and title
plt.ylabel('Average Time (seconds)')
plt.xlabel('Batch Size')
plt.title('Timing scalar curvature $R$ computation: fd on minigrids vs jacfwd')
plt.grid()
plt.legend()
# Set x-ticks to be the actual batch size values
plt.xticks(batch_sizes)  # Setting the x-ticks to match batch sizes

# Save the plot only when saving is enabled: the target directory is created
# only in that case, so an unconditional savefig fails when it does not exist.
if violent_saving:
    plt.savefig(Path_pictures+'/timing_results_batch_minigrids.pdf', bbox_inches='tight')
plt.show()