This notebook produces the box plots comparing automatic differentiation (Autograd) and finite differences in FP32 and FP64 precision. The plots appear in Section 4.3.3 of Chapter 4 of my thesis.

In [1]:
import json
import math
import os
import timeit

import torch, ricci_regularization
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
from matplotlib.lines import Line2D
from ricci_regularization import Sc_g_fd_batch_minigrids_rhombus

In [None]:
# Seed torch's global RNG so that the randomly initialised model is reproducible.
torch.manual_seed(0)

# Directory that receives every figure and JSON file written by this notebook.
Path_pictures = "../../experiments/AD_FD"
# Create it up front: plt.savefig and open(..., 'w') do not create missing directories.
os.makedirs(Path_pictures, exist_ok=True)

# Choose dtype
dtype = torch.float32
#dtype = torch.float64

# Choose latent dimension
d = 2 # for FD using rhombus it can only be == 2
torus_ae = ricci_regularization.Architectures.TorusAE(
        x_dim=784,
        h_dim1=512,
        h_dim2=256,
        z_dim=d,
        dtype=dtype
    )

# Unit roundoff (half the machine epsilon) of each floating-point format,
# drawn as a reference level in the error plots.
fp32_error_level = 5.96e-8   # 2**-24
fp64_error_level = 1.11e-16  # 2**-53 (was 1.11e-15, which is ten times too large)

In [None]:
# Shorthand for the decoder of the torus autoencoder; it is the `function`
# argument handed to the curvature/metric routines in the cells below.
decoder = torus_ae.decoder_torus

In [None]:
# Display the device that holds the model's first parameter (sanity check).
next(torus_ae.parameters()).device

# Computing relative errors for different step sizes h of FD

In [None]:
torch.manual_seed(0)
tensor_name = "R"  # "R": scalar curvature, "g": metric
batch_size = 1024
# Random latent points, uniform in [-0.95*pi, 0.95*pi)^d, at which AD and FD are compared.
centers = 1.9*torch.pi*(torch.rand(batch_size, d, dtype=dtype) - 0.5)

# Log-spaced FD step sizes h, chosen separately for each precision and tensor.
h_grids = {
    (torch.float64, "g"): np.logspace(-9, -2, 7),      # 1e-9 ... 1e-2
    (torch.float64, "R"): np.logspace(-5, -1, 7),      # 1e-5 ... 1e-1
    (torch.float32, "g"): 5 * np.logspace(-5, -1, 7),  # 5e-5 ... 5e-1
    (torch.float32, "R"): 5 * np.logspace(-3, -1, 7),  # 5e-3 ... 5e-1
}
if (dtype, tensor_name) not in h_grids:
    raise ValueError(f"No step sizes defined for dtype={dtype}, tensor_name={tensor_name!r}")
h_values = h_grids[(dtype, tensor_name)]

# Automatic-differentiation reference. It does not depend on h, so it is
# computed once instead of once per step size.
if tensor_name == "R":
    tensor_jacfwd = ricci_regularization.Sc_jacfwd_vmap(centers, function=decoder)[0].detach()
else:
    tensor_jacfwd = ricci_regularization.metric_jacfwd_vmap(centers, function=decoder).detach()
abs_tensor_jacfwd = torch.abs(tensor_jacfwd)

# Per-step-size statistics (one entry per h, stored as Python floats) ...
errors = []                # mean squared error FD vs AD
mean_relative_errors = []  # mean of |FD - AD| / |AD|
mean_abs_values = []       # mean of |AD| (same for every h; kept per h for plotting)
mae_errors = []            # mean absolute error FD vs AD
# ... and the full element-wise relative errors (one tensor per h).
distribution_of_relative_errors = []

for h in h_values:
    # The FD routine returns the pair (scalar curvature, metric); keep the requested one.
    with torch.no_grad():
        curvature_fd, metric_fd = Sc_g_fd_batch_minigrids_rhombus(centers, function=decoder, h=h)
    tensor_fd = curvature_fd if tensor_name == "R" else metric_fd

    abs_difference = torch.abs(tensor_fd - tensor_jacfwd)
    relative_error = abs_difference / abs_tensor_jacfwd

    errors.append(torch.nn.functional.mse_loss(tensor_fd, tensor_jacfwd).item())
    mean_abs_values.append(abs_tensor_jacfwd.mean().item())
    mae_errors.append(abs_difference.mean().item())
    mean_relative_errors.append(relative_error.mean().item())
    distribution_of_relative_errors.append(relative_error)

# log10 of the element-wise relative errors, flattened to 1-D (one tensor per h).
log_distribution_of_relative_errors = [torch.log10(x.flatten()) for x in distribution_of_relative_errors]

In [None]:
# Number of relative-error samples for the first step size
# (the entries of the list are already flattened to 1-D).
log_distribution_of_relative_errors[0].shape

# Log of distribution of relative errors

In [None]:
# Update plot configurations
plt.rcParams.update({'font.size': 16})
plt.figure(figsize=(8, 6))

# One box per step size, showing the distribution of log10(relative error).
plt.boxplot(log_distribution_of_relative_errors, showmeans=True, meanline=True)

# Dashed reference line at the unit roundoff of the active precision.
# The label is derived from the constant so that text and line cannot disagree.
roundoff_levels = {
    torch.float32: ("FP32", fp32_error_level),
    torch.float64: ("FP64", fp64_error_level),
}
if dtype in roundoff_levels:
    precision_name, precision_level = roundoff_levels[dtype]
    level_mantissa, level_exponent = f"{precision_level:.2e}".split("e")
    plt.axhline(
        y=math.log10(precision_level),
        color='r',
        linestyle='--',
        linewidth=1.5,
        label=f"{precision_name} error ({level_mantissa}e{int(level_exponent)})",
    )

# Axis labels
plt.xlabel('Step size (h)')
plt.ylabel(f'Log of relative error of ${tensor_name}$')

# Label box i with its step size in "mantissa x 10^exponent" form.
# Boxes sit at positions 1..len(h_values).
step_labels = []
for h in h_values:
    step_exponent = int(np.floor(np.log10(h)))
    step_labels.append(f'{h / 10.0**step_exponent:.0f} $\\cdot 10^{{{step_exponent}}}$')
plt.xticks(np.arange(1, len(h_values) + 1), step_labels)

# Proxy artists explaining the mean (dashed) and median lines drawn inside each box.
legend_elements = [
    Line2D([0], [0], color='green', linestyle='--', label='Mean'),
    Line2D([0], [0], color='orange', label='Median')
]
# Combine them with the labelled artists already on the axes (the reference line).
existing_handles, _existing_labels = plt.gca().get_legend_handles_labels()
plt.legend(handles=legend_elements + existing_handles, loc="lower right")
plt.grid(True)

# Save and show plot
plt.savefig(Path_pictures+"/fd_"+f'{dtype}'+f"relative_error_boxplot_{tensor_name}.pdf", bbox_inches='tight', format = "pdf")
plt.show()

# Histogram of the optimal case

In [None]:
# Set seed for reproducibility of the normal comparison sample
torch.manual_seed(0)

# log10 relative errors for the third step size (index 2 of h_values)
data = log_distribution_of_relative_errors[2]
mean = data.mean()
std = data.std()

# Histogram of the data next to an equally sized normal sample with the same mean and std.
plt.hist(data.numpy(), bins=100, density=True, alpha=0.8, label="Relative errors")
normal_samples = std * torch.randn(data.numel()) + mean
plt.hist(normal_samples.numpy(), bins=100, density=True, alpha=0.5, label="Normal samples")

# Centre and rescale the data. The factor 5 is kept from the original analysis;
# it only changes the scale of the values.
normalized_data = 5*(data - mean) / std

# Drop the smallest and largest values before testing. This has to happen here,
# before the Shapiro-Wilk call that uses the trimmed sample, so the cell runs on a fresh kernel.
# NOTE(review): the reported statistic and p-value refer to the trimmed sample.
n_trimmed_per_tail = 100
normalized_data_wo_outlyers = torch.sort(normalized_data).values[n_trimmed_per_tail:-n_trimmed_per_tail]

# Perform Shapiro-Wilk test
statistic, p_value = scipy.stats.shapiro(normalized_data_wo_outlyers.numpy())
# Print the results
print(f"Shapiro-Wilk statistic: {statistic}")
print(f"P-value: {p_value}")

# Interpret the p-value at the 5% level
if p_value > 0.05:
    shapiro_result = "The data is likely normally distributed (fail to reject H0)."
else:
    shapiro_result = "The data is likely not normally distributed (reject H0)."

# Add the legend
plt.legend(loc="center left")

# Format the p-value as "mantissa x 10^exponent"; log10 is undefined for an exact 0.
if p_value > 0:
    p_exponent = int(np.floor(np.log10(p_value)))
    p_value_text = f"{p_value / 10.0**p_exponent:.0f} $\\cdot 10^{{{p_exponent}}}$"
else:
    p_value_text = "0"

# Add the Shapiro-Wilk test result as text below the axes
plt.text(0., -0.15, f"Shapiro-Wilk statistic: {statistic:.3f}", transform=plt.gca().transAxes)
plt.text(0., -0.25, f"P-value: {p_value_text}", transform=plt.gca().transAxes)
plt.text(0., -0.35, f"{shapiro_result}", transform=plt.gca().transAxes)

# Save and show plot (the 'hisogram' spelling is kept so the output path stays unchanged)
plt.savefig(Path_pictures+"/fd_"+f'{dtype}'+f"relative_error_hisogram_{tensor_name}.pdf", bbox_inches='tight', format = "pdf")
plt.show()


In [None]:
# Sort the normalized log-errors and discard the 100 smallest and the 100 largest values.
normalized_data_wo_outlyers = torch.sort(normalized_data)[0][100:-100]

In [None]:
# Normal Q-Q plot of the trimmed, normalized log-errors.
# The pyplot module is passed as the plotting target instead of the discouraged `pylab`.
scipy.stats.probplot(normalized_data_wo_outlyers.numpy(), dist="norm", plot=plt)

# Log of mean relative errors only

In [None]:
# Update plot configurations
plt.rcParams.update({'font.size': 16})
plt.figure(figsize=(8, 6))

# Mean relative error versus step size on log-log axes.
# float() accepts both Python floats and 0-d tensors.
plt.loglog(h_values, [float(err) for err in mean_relative_errors], marker='o', label="Relative error")

# Dashed reference line at the unit roundoff of the active precision.
# The label is derived from the constant so that text and line cannot disagree.
roundoff_levels = {
    torch.float32: ("FP32", fp32_error_level),
    torch.float64: ("FP64", fp64_error_level),
}
if dtype in roundoff_levels:
    precision_name, precision_level = roundoff_levels[dtype]
    level_mantissa, level_exponent = f"{precision_level:.2e}".split("e")
    plt.axhline(
        y=precision_level,
        color='r',
        linestyle='--',
        linewidth=1.5,
        label=f"{precision_name} error ({level_mantissa}e{int(level_exponent)})",
    )

# Axis labels (the y-label follows the tensor that was actually analysed)
plt.xlabel('Step size (h)')
plt.ylabel(f'Relative error of ${tensor_name}$')

# Put ticks exactly at the step sizes, labelled in "mantissa x 10^exponent" form.
step_labels = []
for h in h_values:
    step_exponent = int(np.floor(np.log10(h)))
    step_labels.append(f'{h / 10.0**step_exponent:.0f} $\\cdot 10^{{{step_exponent}}}$')
plt.xticks(h_values, step_labels)

# Legend and grid
plt.legend(loc = "center right")
plt.grid(True)

# Save and show plot
plt.savefig(Path_pictures+"/fd_"+f'{dtype}'+f"relative_error_{tensor_name}.pdf", bbox_inches='tight', format = "pdf")
plt.show()


# Absolute errors MAE, MSE

In [None]:
# Absolute error measures (MSE, MAE) and the mean magnitude of the AD reference versus h.
plt.figure(figsize=(8, 6))
# float() accepts both Python floats and 0-d tensors.
plt.loglog(h_values, [float(value) for value in errors], marker='o', label="MSE Error")
plt.loglog(h_values, [float(value) for value in mae_errors], marker='o', label="MAE Error")
plt.loglog(h_values, [float(value) for value in mean_abs_values], marker='o', label=f"Mean value of $|{tensor_name}|$")

plt.xlabel('Step size (h)')
plt.ylabel('Error ')
# Describe the analysed tensor in the title instead of always naming R.
tensor_descriptions = {"R": "scalar curvature $R$", "g": "metric $g$"}
tensor_description = tensor_descriptions.get(tensor_name, f"${tensor_name}$")
plt.title(f'{dtype}: Errors vs. Step Size for f.d. on minigrid for {tensor_description}')

# Two significant digits keep the labels meaningful for very small steps
# (a fixed '.3f' would print 1e-5 as "0.000").
plt.xticks(h_values, [f'{h:.2g}' for h in h_values])
plt.legend(loc = "center left")
plt.grid(True)

# NOTE(review): this file name does not contain tensor_name, so runs for "R" and "g"
# write to the same path.
plt.savefig(Path_pictures+"/fd_"+f'{dtype}'+"_error.pdf", bbox_inches='tight', format = "pdf")
plt.show()

# Timing AD vs FD with different batch size 

In [None]:
# Number of repetitions averaged for each measurement
iterations = 100

batch_sizes = [16, 32, 64, 128, 256, 512]  # Different batch sizes to test

# One dict per batch size with the average time of each method
timing_results = []

h = 0.01  # FD step size used by the timed finite-difference loss
# Random latent points. They must have the decoder's dtype and latent dimension,
# otherwise the float64 configuration fails with a dtype mismatch.
centers = torch.randn(max(batch_sizes), d, dtype=dtype)
# Mini-grids around the centers. NOTE(review): not used by the timed calls below.
batch_minigrids = ricci_regularization.build_mini_grid_batch(centers, h=h)

# Loop through different batch sizes
for batch_size in batch_sizes:
    # Use the first `batch_size` points so that all runs share the same inputs
    current_centers = centers[:batch_size]

    # Total time of `iterations` evaluations of the finite-difference curvature loss.
    # Callables are timed directly, so no `from __main__ import ...` setup string is needed.
    time_fd_fast = timeit.timeit(
        lambda: ricci_regularization.curvature_loss(current_centers, h=h, eps=0.0, function=decoder),
        number=iterations
    )

    # Total time of `iterations` evaluations of the automatic-differentiation curvature loss.
    time_jacfwd = timeit.timeit(
        lambda: ricci_regularization.curvature_loss_jacfwd(current_centers, function=decoder),
        number=iterations
    )

    # Store the per-call averages
    timing_results.append({
        "batch_size": batch_size,
        "Sc_fd_rhombus_avg_time": time_fd_fast / iterations,
        "Sc_jacfwd_avg_time": time_jacfwd / iterations,
    })

In [None]:
# Write the timing table as indented JSON into the output directory.
timing_json_path = Path_pictures+'/timing_results_batch_minigrids.json'
with open(timing_json_path, 'w') as timing_file:
    timing_file.write(json.dumps(timing_results, indent=4))

# Echo one measurement per line
for timing_entry in timing_results:
    print(timing_entry)

In [None]:
# Display the list of batch sizes currently bound to `batch_sizes`.
batch_sizes

In [None]:
# Plot the average run time per batch size; the first measurement is left out of the figure.
plotted_results = timing_results[1:]
batch_sizes = [entry['batch_size'] for entry in plotted_results]
sc_fd_rhombus_times = [entry['Sc_fd_rhombus_avg_time'] for entry in plotted_results]
sc_jacfwd_times = [entry['Sc_jacfwd_avg_time'] for entry in plotted_results]

plt.figure(figsize=(10, 6))

# One curve per method: finite differences (FD) and automatic differentiation (AD)
timing_curves = (
    (sc_fd_rhombus_times, 'o', 'FD'),
    (sc_jacfwd_times, 's', 'AD'),
)
for curve_times, curve_marker, curve_label in timing_curves:
    plt.plot(batch_sizes, curve_times, marker=curve_marker, label=curve_label, linestyle='-')

# Axis labels, grid and legend
plt.xlabel('Batch Size')
plt.ylabel('Average Time (seconds)')
#plt.title('Timing curvature loss $\widehat\mathcal{L}_\mathrm{curv}$ computation: fd on minigrids vs jacfwd')
plt.grid()
plt.legend()
# Put the x-ticks exactly at the plotted batch sizes
plt.xticks(batch_sizes)

# Save the plot
plt.savefig(Path_pictures+'/timing_AD_FD.pdf', bbox_inches='tight')
plt.show()