Prerequisite: weights and architecture of a pre-trained AE.

This notebook builds the quasi-isometric embedding of a local chart of the torus (latent space) into $\mathbb{R}^3$ using a PyTorch optimizer.

Content:
1) Data loading. Weights of the AE and encoded latent space data are loaded.
2) Constructing grid and triangulation.
3) Geodesic distances computation via Schauder basis approximation (or loading) + Embedded grid plotting in 3d.
4) Optimization loop. Embedding is constructed via optimization.
5) Plotting the embedded grid with trimesh. Saving the results.

# 1. Data loading

In [None]:
from tqdm.notebook import tqdm
import torch, yaml
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
import torch.nn as nn
import matplotlib.cm as cm
import matplotlib.tri as mtri

# When True, the loss curve and the optimized grid are additionally written to disk
# by the cells that check this flag.
violent_saving = False

# Name of the pre-trained AE experiment; it selects both the YAML config file
# and the directory where results are stored.
pretrained_AE_setting_name = 'MNIST_Setting_3'
Path_AE_config = f'../experiments/{pretrained_AE_setting_name}_config.yaml'
Path_pictures = f'../experiments/{pretrained_AE_setting_name}'

# NOTE: yaml.FullLoader is not the restricted "safe" loader, so only trusted
# configuration files should be opened here.
with open(Path_AE_config) as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

print(f"pictures will be saved at {Path_pictures}")

# Loading the pretrained AE + creating directory for results


In [None]:
# Load data loaders based on YAML configuration.
# The result is bound to `dataloaders` (not `dict`) so that the builtin `dict`
# type stays usable in the rest of the notebook.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32,
)
print("Experiment results loaded successfully.")

# Loading data
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
# Direct indexing: a missing "test_dataset" entry fails right here with a KeyError
# instead of silently passing None on to DataLoader.
validation_dataset = dataloaders["test_dataset"]
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=100)
print("Data loaders created successfully.")

# Loading the pre-trained AE
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)
print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

# The AE is kept on the CPU in evaluation mode at this point.
torus_ae.cpu()
torus_ae.eval()
print("AE sent to cpu and eval mode activated successfully.")

In [None]:
# The curvature weight is read from the config only when curvature computation
# is explicitly switched on (flag is exactly True); otherwise it is zero.
curvature_enabled = yaml_config["training_mode"]["compute_curvature"] is True
lambda_curv = yaml_config["loss_settings"]["lambda_curv"] if curvature_enabled else 0.
print("This experiment uses curvature loss weight", lambda_curv)

In [None]:
# Point plot built by ricci_regularization.point_plot from the test loader
# (batch_idx=0), using the AE's `encoder_to_lifting` as encoder.
fig = ricci_regularization.point_plot(
    encoder=torus_ae.encoder_to_lifting,
    data_loader=test_loader,
    show_title=False,
    config=yaml_config,
    batch_idx=0,
)
fig.show()

# 2. Constructing grid and triangulation.

In [None]:
# Regular num_points x num_points grid of latent points covering the square
# [x_left, x_right] x [y_bottom, y_top].
num_points = 10

# lower-left corner of the chart
x_left = -2
y_bottom = -2

# side lengths (the square is symmetric about the origin);
# y_size is also the maximal shift of the geodesics
x_size = -2 * x_left
y_size = -2 * y_bottom

# upper-right corner of the chart
x_right = x_left + x_size
y_top = y_bottom + y_size

# Left-most column of the grid: x is fixed at x_left, y sweeps [y_bottom, y_top].
start_points_horizontal = torch.stack(
    [torch.tensor([x_left, y_bottom + offset]) for offset in torch.linspace(0, y_size, num_points)]
)

# Shifting that column by multiples of one x-step produces the other columns.
horizontal_step = torch.tensor([x_size / (num_points - 1), 0])
grid = torch.stack(
    [start_points_horizontal + column * horizontal_step for column in range(num_points)]
)
# Layout: grid[i, j] = (x_i, y_j), i.e. the first index runs along x and the second along y.

# Flat (u, v) parameter grid over the same square. It is only used to determine
# which nodes are joined into triangles; the connectivity is reused when plotting.
u_axis = np.linspace(x_left, x_right, num=num_points)
v_axis = np.linspace(y_bottom, y_top, num=num_points)
u_mesh, v_mesh = np.meshgrid(u_axis, v_axis)
u = u_mesh.flatten()
v = v_mesh.flatten()

# NOTE(review): with meshgrid's default indexing the flattened nodes have u varying
# fastest, whereas the torch `grid` built earlier has y varying fastest along its
# last grid axis. Vertex k of this triangulation therefore corresponds to the
# transposed node of `grid`; the connectivity still describes a valid grid mesh
# because both directions use num_points nodes. Confirm this is intended.
tri = mtri.Triangulation(u, v)

# 3. Geodesic distances computation via Schauder (or loading) + Grid plotting.

In [None]:
n_max = 7  # depth of Schauder basis
step_count = 100  # number of interpolation points on each geodesic
geodesic_solver = ricci_regularization.Schauder.NumericalGeodesics(n_max, step_count)

# Prefer the GPU, but fall back to the CPU so the notebook also runs on machines
# without CUDA instead of failing on the first `.to("cuda")`.
optimization_device = "cuda" if torch.cuda.is_available() else "cpu"
num_geodesic_optimization_epochs = 200
learning_rate = 0.01

# The decoder is evaluated during geodesic optimization, so the AE is moved to the same device.
torus_ae.to(optimization_device)

# Optimizer description handed to the geodesic solver.
optimizer_info = {
    "name": "Adam",   # optimizer class name as string
    "args": {
        "lr": learning_rate    # learning rate
        # "betas": (0.9, 0.999)  # optional for Adam
    }
}

In [None]:
# Geodesics between neighbouring grid nodes are found by minimizing the "energy"
# of curves in the latent space. The pull-back metric is never computed explicitly,
# which keeps the algorithm fast.

# End points of the segments joining nodes adjacent along the first grid axis.
left_points_on_horizontal_segments = grid[:-1].reshape(-1, 2)
right_points_on_horizontal_segments = grid[1:].reshape(-1, 2)
# End points of the segments joining nodes adjacent along the second grid axis.
# These feed the *vertical* geodesics, despite the "horizontal" in their names.
bottom_points_on_horizontal_segments = grid[:, :-1].reshape(-1, 2)
top_points_on_horizontal_segments = grid[:, 1:].reshape(-1, 2)

# Solver settings shared by both batches of geodesics.
shared_solver_kwargs = {
    "generator": torus_ae.decoder_torus,
    "optimizer_info": optimizer_info,
    "epochs": num_geodesic_optimization_epochs,
    "device": optimization_device,
}

# Only the second value returned by the solver is kept.
_, horizontal_geodesics_connecting_grid_nodes = geodesic_solver.computeGeodesicInterpolationBatch(
    m1_batch=left_points_on_horizontal_segments,
    m2_batch=right_points_on_horizontal_segments,
    display_info="horizontal geodesic optimization",
    **shared_solver_kwargs,
)

_, vertical_geodesics_connecting_grid_nodes = geodesic_solver.computeGeodesicInterpolationBatch(
    m1_batch=bottom_points_on_horizontal_segments,
    m2_batch=top_points_on_horizontal_segments,
    display_info="vertical geodesic optimization",
    **shared_solver_kwargs,
)

In [None]:
import os

# Lengths of the geodesic segments between neighbouring grid nodes, measured
# through the decoder. The curves are first moved to the optimization device.
horizontal_geodesics_connecting_grid_nodes = horizontal_geodesics_connecting_grid_nodes.to(optimization_device)
horizontal_lengths = ricci_regularization.RiemannianKmeansTools.compute_lengths(
    curve=horizontal_geodesics_connecting_grid_nodes,
    decoder=torus_ae.decoder_torus,
)

vertical_geodesics_connecting_grid_nodes = vertical_geodesics_connecting_grid_nodes.to(optimization_device)
vertical_lengths = ricci_regularization.RiemannianKmeansTools.compute_lengths(
    curve=vertical_geodesics_connecting_grid_nodes,
    decoder=torus_ae.decoder_torus,
)

# torch.save does not create missing parent directories, so make sure the
# results directory exists before writing into it.
os.makedirs(Path_pictures, exist_ok=True)
torch.save(horizontal_lengths, Path_pictures+'/horizontal_lengths.pt')
torch.save(vertical_lengths, Path_pictures+'/vertical_lengths.pt')

In [None]:
# Squared lengths of the horizontal geodesic segments
# (presumably displayed for comparison with the energies computed in the next cell).
horizontal_lengths.square()

In [None]:
# Energies of the horizontal geodesic segments, requested without reduction
# (presumably one value per segment; confirm against compute_energy).
horizontal_energies = ricci_regularization.RiemannianKmeansTools.compute_energy(
    curve=horizontal_geodesics_connecting_grid_nodes,
    decoder=torus_ae.decoder_torus,
    reduction="none",
)

In [None]:
# Display the unreduced energies of the horizontal geodesic segments.
horizontal_energies

In [None]:
# Side length of the initial flat square embedding: the smallest total geodesic
# length found among all horizontal and all vertical grid lines.
with torch.no_grad():
    # rows index the segment along the first grid axis, columns the grid line
    horizontal_lengths_reshaped = horizontal_lengths.reshape(num_points - 1, num_points)
    # rows index the grid line, columns the segment along the second grid axis
    vertical_lengths_reshaped = vertical_lengths.reshape(num_points, num_points - 1)
    # add up the segments of every grid line, then keep the shortest line
    minimal_horizontal_length = torch.sum(horizontal_lengths_reshaped, dim=0).min()
    minimal_vertical_length = torch.sum(vertical_lengths_reshaped, dim=1).min()
edge = torch.min(minimal_horizontal_length, minimal_vertical_length).cpu()
print(f"edge length: {edge.item()}")

## Embedded grid plotting in 3d. 

In [None]:
# plotting
def plot_triang(embedded_grid,additional_comment='',savefig=False, plot_number=0,lambda_curv = None,
                zmax = 0.5,zmin = -0.5, view_angle_horizontal = 0., view_angle_vertical = 30):
    """Plot an embedded grid as a triangulated 3D surface colored by height.

    Relies on two notebook-level names: ``tri`` (triangle connectivity of the
    grid nodes) and, when saving, ``Path_pictures`` (output directory).

    Args:
        embedded_grid: Tensor indexed as ``[i, j, c]`` whose last axis holds the
            x, y, z coordinates (components 0, 1, 2) of every grid node.
        additional_comment: Text appended to the plot title.
        savefig: If truthy, the figure is also written to
            ``Path_pictures/3dembedding_plot_{plot_number}.pdf``.
        plot_number: Suffix used in the name of the saved file.
        lambda_curv: Value shown in the title; ``None`` is rendered as "?".
        zmax: Upper limit of the z axis.
        zmin: Lower limit of the z axis.
        view_angle_horizontal: Passed as the FIRST argument of ``ax.view_init``,
            which matplotlib interprets as the elevation angle.
        view_angle_vertical: Passed as the SECOND argument of ``ax.view_init``,
            which matplotlib interprets as the azimuth angle.

    Returns:
        The notebook-level triangulation ``tri`` used to draw the surface.
    """
    # Hand plain numpy arrays to matplotlib (no autograd history, no GPU memory).
    coords = embedded_grid.detach().cpu().numpy()
    x = coords[:, :, 0].flatten()
    y = coords[:, :, 1].flatten()
    z = coords[:, :, 2].flatten()
    z_low, z_high = float(z.min()), float(z.max())

    fig = plt.figure(figsize = (10,10),dpi=300)
    ax = fig.add_subplot(1, 2, 1, projection='3d')

    if lambda_curv is None:
        lambda_curv = "?"
    ax.set_title(rf"3d embedding of a grid on torus with $\lambda_{{\mathrm{{curv}}}} = ${lambda_curv}."+
                    additional_comment)

    # Plot the surface. The triangles in parameter space determine which x, y, z
    # points are connected by an edge; the color range spans the actual heights.
    p = ax.plot_trisurf(x, y, z, triangles=tri.triangles, cmap=cm.Spectral,vmax=z_high,vmin=z_low)
    ax.set_zlim(zmin, zmax)
    ax.view_init(view_angle_horizontal, view_angle_vertical)

    # Add ticks for x, y, z: extremes and midpoint for x/y, extremes and zero for z
    xticks = np.linspace(x.min(), x.max(), 3)
    yticks = np.linspace(y.min(), y.max(), 3)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_zticks([z_low, 0., z_high])

    ax.set_xticklabels([f"{val:.1f}" for val in xticks])
    ax.set_yticklabels([f"{val:.1f}" for val in yticks])

    # Colorbar
    cbar = fig.colorbar(p,shrink = 0.1)
    cbar.set_label("Height")
    cbar.set_ticks(ticks=[z_low, 0., z_high])
    cbar.set_ticklabels(ticklabels=[f'{z_low:.3f}','0', f'{z_high:.3f}'])

    if savefig:
        fig.savefig(Path_pictures+f"/3dembedding_plot_{plot_number}.pdf",format="pdf")

    plt.show()
    return tri

# 4. Optimization loop

In [None]:
def Quasi_isometric_embedding(epoch, params, embedded_grid, horizontal_lengths_reshaped, vertical_lengths_reshaped, num_iter=1,
                         mode="diagnostic", loss_history=None, learning_rate=1e+1, lambda_curv = lambda_curv):
    """
    Optimize a 3D embedding of a grid so that it is as isometric as possible
    (w.r.t. the geodesic distances on the manifold).

    The Euclidean distances between adjacent embedded grid nodes are pulled
    towards the given target horizontal and vertical lengths by minimizing
    ``1e2 * (mean squared horizontal error + mean squared vertical error)``
    with ``torch.optim.SGD``. A fresh optimizer is created on every call.

    Args:
        epoch (int): Current training epoch (only used to number the diagnostic plot).
        params (iterable): Tensors handed to the optimizer. ``embedded_grid`` has to
            be among them, otherwise the optimizer step does not change the embedding.
        embedded_grid (torch.Tensor): Embedding of the grid, indexed as
            ``[i, j, c]`` with the coordinates on the last axis; updated in place
            by the optimizer.
        horizontal_lengths_reshaped (torch.Tensor): Target (geodesic) distances between
            nodes adjacent along the first grid axis; must be broadcastable against
            ``embedded_grid[1:] - embedded_grid[:-1]`` reduced over the last axis.
        vertical_lengths_reshaped (torch.Tensor): Target (geodesic) distances between
            nodes adjacent along the second grid axis.
        num_iter (int, optional): Number of optimization iterations to run.
            Default is 1.
        mode (str, optional): If "diagnostic", prints the loss at every iteration
            and plots the embedding and the loss curve at the end. Any other value
            runs silently. Default is "diagnostic".
        loss_history (list, optional): List to which the loss values are appended.
            If None, a new list is created.
        learning_rate (float, optional): Learning rate of the SGD optimizer.
            Default is 1e+1.
        lambda_curv (optional): Value displayed in the title of the diagnostic plot.
            Defaults to the notebook-level ``lambda_curv`` as it was when this
            function was defined.

    Returns:
        list: ``loss_history`` extended by one loss value per iteration.
    """
    if loss_history is None:
        loss_history = []  # List to store loss values

    # Plain SGD over the embedding coordinates
    optimizer = torch.optim.SGD(params, lr=learning_rate)
    diagnostic = mode == "diagnostic"

    for iter_num in range(num_iter):
        optimizer.zero_grad()  # Zero gradients

        # Euclidean distances between neighbouring embedded nodes along each grid axis
        horizontal_grid_distances = (embedded_grid[1:, :, :] - embedded_grid[:-1, :, :]).norm(dim=-1)
        vertical_grid_distances = (embedded_grid[:, 1:, :] - embedded_grid[:, :-1, :]).norm(dim=-1)

        # Mean squared mismatch between target and embedded distances
        loss_horizontal = (horizontal_lengths_reshaped - horizontal_grid_distances).square().mean()
        loss_vertical = (vertical_lengths_reshaped - vertical_grid_distances).square().mean()

        # Sum the losses (scaled by a constant factor)
        loss = 1e2 * (loss_horizontal + loss_vertical)

        # Backpropagation: compute gradients
        loss.backward()

        # Update parameters
        optimizer.step()

        # Read the scalar once (a single device-to-host transfer) and reuse it
        # for both the history and the log line.
        loss_value = loss.item()
        loss_history.append(loss_value)

        # Print diagnostics if needed
        if diagnostic:
            print(f"Iteration #{iter_num + 1}, loss: {loss_value:.3f}")

    # Plot the current embedding and the loss values if in diagnostic mode
    if diagnostic:
        plot_triang(embedded_grid, plot_number=epoch+1,savefig=False,
                              additional_comment=f'\n After {len(loss_history)} iterations.',
                              lambda_curv=lambda_curv)

        plt.figure()
        plt.plot(loss_history, label='Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Time')
        plt.legend()
        plt.show()

    return loss_history

In [None]:
# Initial embedding: the flat latent grid rescaled by edge / x_size,
# then moved to the optimization device.
with torch.no_grad():
    stretched_grid = (grid * edge / x_size).to(optimization_device)

# Small random perturbation of the height (z) so the surface does not start perfectly flat.
torch.manual_seed(666)
eps = 1e-2 * edge / num_points
height_noise = torch.randn((num_points, num_points, 1), device=optimization_device)
z_perturbation = eps * height_noise

# Append z as third coordinate; the resulting tensor is the quantity being optimized.
embedded_grid = torch.cat([stretched_grid, z_perturbation], dim=-1).requires_grad_(True)

# Parameters handed to the optimizer
params = [embedded_grid]


In [None]:
# Run the embedding optimization: num_epochs rounds of 100 iterations each,
# with diagnostics printed/plotted after every round.
num_epochs = 10
loss_history = []
learing_rate_embedding = 1e-2

for epoch in range(num_epochs):
    loss_history = Quasi_isometric_embedding(
        epoch=epoch,
        params=params,
        embedded_grid=embedded_grid,
        horizontal_lengths_reshaped=horizontal_lengths_reshaped,
        vertical_lengths_reshaped=vertical_lengths_reshaped,
        num_iter=100,
        mode="diagnostic",
        loss_history=loss_history,
        learning_rate=learing_rate_embedding,
    )

In [None]:

import os

# Loss curve over all optimization iterations
plt.plot(loss_history, label='Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.legend()
if violent_saving:
    # savefig does not create missing directories, so create the history
    # sub-directory before writing the figure into it.
    history_dir = Path_pictures+"/3dembedding_optimization_history"
    os.makedirs(history_dir, exist_ok=True)
    plt.savefig(history_dir+"/loss.pdf")
plt.show()

In [None]:
# Detached view of the optimized node coordinates (no autograd history)
optimized_grid = embedded_grid.detach()

# grid saving (only when persistent saving is switched on)
if violent_saving:
    torch.save(optimized_grid, f"{Path_pictures}/embedded_grid_{pretrained_AE_setting_name}.pt")

# Final surface plot; this one is always written to disk (savefig=True).
plot_triang(
    embedded_grid,
    additional_comment=f'\n After {len(loss_history)} iterations.',
    savefig=True,
    plot_number="final",
    lambda_curv=lambda_curv,
)
    

# 5. Plotting the embedded grid with trimesh

In [None]:
import trimesh
import matplotlib.pyplot as plt
%matplotlib notebook

optimized_grid = optimized_grid.cpu()
# Create the trimesh object directly from vertices and faces.
# Vertices are passed as a plain numpy array of shape (num_nodes, 3).
mesh = trimesh.Trimesh(vertices=optimized_grid.detach().reshape(-1,3).numpy(), faces=tri.triangles)

# Extract z-coordinates of the vertices
z_coords = mesh.vertices[:, 2]

# Normalize the z-coordinates to the range [0, 1]
z_min = z_coords.min()
z_max = z_coords.max()
z_range = z_max - z_min
if z_range > 0:
    normalized_z = (z_coords - z_min) / z_range
else:
    # Perfectly flat surface: avoid 0/0 (NaN colors) and use one color everywhere
    normalized_z = np.zeros(len(z_coords))

# Get a colormap from matplotlib
colormap = matplotlib.colormaps.get_cmap("jet")

# Map the normalized z-coordinates to RGBA colors using the colormap
colors = colormap(normalized_z)

# Convert colors to 0-255 range and RGBA format
vertex_colors = (colors * 255).astype(np.uint8)

# Assign these colors to the mesh's vertices
mesh.visual.vertex_colors = vertex_colors

# Create a scene with the mesh
scene = trimesh.Scene(mesh)

# Define the initial camera transformation matrix:
# a pure translation, shifted by -zoom_out_factor along the y-axis to zoom out.
zoom_out_factor = 12.0  # Increase this value to zoom out more
camera_transform = trimesh.transformations.translation_matrix([-0.2, -zoom_out_factor, -5.5])

# Set the camera transform in the scene
scene.camera_transform = camera_transform

# Display the mesh in a viewer window with the specified initial camera placement
scene.show()