This notebook builds a quasi-isometric (approximately distance-preserving) embedding of a torus latent space in $\mathbb{R}^3$ using a PyTorch optimizer.

1) The autoencoder (AE) weights and the encoded latent-space data are loaded.
2) A grid is constructed in latent space, and geodesic distances between neighboring grid points are computed via Stochman.
3) The optimizer is tuned, and the embedding is constructed by optimization.
4) The embedding is plotted with matplotlib and trimesh, and the results are saved.

# Data loading

In [None]:
from tqdm.notebook import tqdm
import torch
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
import stochman
from stochman.manifold import EmbeddedManifold
import torch.nn as nn
import matplotlib.cm as cm
import matplotlib.tri as mtri

# If True, figures and tensors produced below (loss plot, embedded grid) are saved to disk.
violent_saving = True

# If True, precomputed geodesic lengths are loaded from Path_pictures instead of
# being recomputed with Stochman.
load_distances = True

# Experiment configuration. Alternative (no curvature penalty):
# experiment_json = '../experiments/MNIST_torus_AEexp34.json'
experiment_json = '../experiments/MNIST01_torus_AEexp7.json'
mydict = ricci_regularization.get_dataloaders_tuned_nn(Path_experiment_json=experiment_json)

In [None]:
# Unpack the tuned autoencoder, the test data loader and the experiment configuration.
# (The `json_cofig` spelling is kept because later cells refer to it.)
torus_ae, test_loader, json_cofig = (
    mydict[key] for key in ("tuned_neural_network", "test_loader", "json_config")
)
Path_pictures = json_cofig["Path_pictures"]
exp_number = json_cofig["experiment_number"]
curv_w = json_cofig["losses"]["curv_w"]

In [None]:
# Flattened input dimension fed to the encoders (784 = 28 * 28; presumably MNIST images).
D = 784
# Used below as the number of discrete colors in the latent scatter plot.
k = json_cofig["dataset"]["parameters"]["k"]
torus_ae.cpu()
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# Inference only: no_grad avoids building an autograd graph for every test batch.
# None of these outputs is differentiated later (they are detached or only plotted).
with torch.no_grad():
    for data, labels in tqdm(test_loader, position=0):
        flat = data.view(-1, D)
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(flat))
        enc_list.append(torus_ae.encoder2lifting(flat))
        colorlist.append(labels)

In [None]:
# Concatenate the per-batch results into tensors covering the whole test set.
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
feature_space_encoding = torch.cat(feature_space_encoding_list)
encoded_points = torch.cat(enc_list)
# Detached copies used for plotting.
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

In [None]:
# Scatter the encoded test set in the 2D latent space, colored by label.
label_cmap = ricci_regularization.discrete_cmap(k, "jet")
latent_x = encoded_points_no_grad[:, 0]
latent_y = encoded_points_no_grad[:, 1]
plt.scatter(latent_x, latent_y, c=color_array, cmap=label_cmap)
plt.show()

# Construct the grid and its triangulation

In [None]:
# Constructing a regular num_points x num_points grid on the square latent domain
# [x_left, x_right] x [y_bottom, y_top].
num_points = 100

# Lower-left corner of the domain (alternatively tried: -torch.pi).
x_left = -2
y_bottom = -2

x_size = -x_left * 2  # alternatively 2 * torch.pi
y_size = -y_bottom * 2  # max shift of geodesics

x_right = x_left + x_size
y_top = y_bottom + y_size

# Offsets along each axis. x offsets span x_size and y offsets span y_size, so the
# construction stays correct if the domain is ever made non-square.
x_offsets = torch.linspace(0, x_size, num_points)
y_offsets = torch.linspace(0, y_size, num_points)

# Left / right boundary points, one per y level: shape (num_points, 2).
starting_points = torch.stack((torch.full_like(y_offsets, x_left), y_bottom + y_offsets), dim=1)
end_points = torch.stack((torch.full_like(y_offsets, x_right), y_bottom + y_offsets), dim=1)

# Bottom / top boundary points, one per x level: shape (num_points, 2).
starting_points_vertical = torch.stack((x_left + x_offsets, torch.full_like(x_offsets, y_bottom)), dim=1)
end_points_vertical = torch.stack((x_left + x_offsets, torch.full_like(x_offsets, y_top)), dim=1)

# grid[i, j] = (x_left + i * dx, y_bottom + j * dy): the first index runs along x,
# the second along y. Shape (num_points, num_points, 2).
horizontal_step = torch.tensor([x_size / (num_points - 1), 0])
grid = torch.stack([starting_points + i * horizontal_step for i in range(num_points)])

# Triangulate parameter space to determine which vertices form triangles.
# indexing="ij" makes the flattened vertex index i * num_points + j match
# grid.reshape(-1, ...), so tri.triangles indexes the flattened (embedded) grid
# directly instead of relying on the square grid's transpose symmetry.
u = np.linspace(x_left, x_right, endpoint=True, num=num_points)
v = np.linspace(y_bottom, y_top, endpoint=True, num=num_points)
u, v = np.meshgrid(u, v, indexing="ij")
u, v = u.flatten(), v.flatten()
tri = mtri.Triangulation(u, v)

# Geodesic distance computation via Stochman (or loading from disk)

In [None]:
# Geodesics are computed by minimizing curve "energy" in the embedding of the manifold,
# so there is no need to compute the pullback metric, and thus the algorithm is fast.
class Autoencoder(EmbeddedManifold):
    """Stochman embedded manifold whose embedding map is ``torus_ae.decoder_torus``."""

    def embed(self, c, jacobian=False):
        """Map latent points ``c`` through the torus decoder.

        Args:
            c: latent points passed straight to ``torus_ae.decoder_torus``.
            jacobian: computing the Jacobian is not supported.

        Returns:
            The decoder output for ``c``.

        Raises:
            NotImplementedError: if ``jacobian`` is requested. Previously the flag was
                silently ignored and only the embedding was returned.
        """
        if jacobian:
            raise NotImplementedError("Autoencoder.embed does not compute the Jacobian.")
        return torus_ae.decoder_torus(c)


model = Autoencoder()

In [None]:
if not load_distances:
    # Horizontal geodesics: connect every grid point to its right-hand neighbour,
    # grid[i, j] -> grid[i + 1, j].
    horizontal_starts = grid[:-1, :, :].reshape(-1, 2)
    horizontal_ends = grid[1:, :, :].reshape(-1, 2)
    c, success = model.connecting_geodesic(horizontal_starts, horizontal_ends)
    c.plot()
    plt.scatter(grid[:, :, 0], grid[:, :, 1], c="blue")
    plt.show()

In [None]:
if not load_distances:
    # Vertical geodesics: connect every grid point to its upper neighbour,
    # grid[i, j] -> grid[i, j + 1].
    vertical_starts = grid[:, :-1, :].reshape(-1, 2)
    vertical_ends = grid[:, 1:, :].reshape(-1, 2)
    c_vert, success = model.connecting_geodesic(vertical_starts, vertical_ends)
    c_vert.plot()
    plt.scatter(grid[:, :, 0], grid[:, :, 1], c="blue")
    plt.show()

In [None]:
if not load_distances:
    # Overlay both families of geodesics on the latent grid points.
    c_vert.plot()
    c.plot()
    plt.scatter(grid[:, :, 0], grid[:, :, 1], c="blue")
    # plt.show must be called; a bare `plt.show` attribute reference does nothing.
    plt.show()

In [None]:
# Geodesic lengths between neighbouring grid points, cached as .pt files in Path_pictures.
horizontal_lengths_path = Path_pictures + '/horizontal_lengths.pt'
vertical_lengths_path = Path_pictures + '/vertical_lengths.pt'
if not load_distances:
    # Sample each geodesic at 20 parameter values to measure its length.
    t = torch.linspace(0, 1, 20)
    horizontal_lengths = model.curve_length(c(t)).detach()
    vertical_lengths = model.curve_length(c_vert(t)).detach()
    torch.save(horizontal_lengths, horizontal_lengths_path)
    torch.save(vertical_lengths, vertical_lengths_path)
else:
    horizontal_lengths = torch.load(horizontal_lengths_path)
    vertical_lengths = torch.load(vertical_lengths_path)

In [None]:
# Side length of the initial (flat, square) embedding: the shortest total geodesic
# length of any full grid line, horizontal or vertical.
with torch.no_grad():
    # horizontal_lengths[i, j]: segment grid[i, j] -> grid[i + 1, j]
    horizontal_lengths_reshaped = horizontal_lengths.reshape(num_points - 1, num_points)
    # vertical_lengths[i, j]: segment grid[i, j] -> grid[i, j + 1]
    vertical_lengths_reshaped = vertical_lengths.reshape(num_points, num_points - 1)
    minimal_horizontal_length = horizontal_lengths_reshaped.sum(dim=0).min()
    minimal_vertical_length = vertical_lengths_reshaped.sum(dim=1).min()
edge = torch.minimum(minimal_horizontal_length, minimal_vertical_length)
print("edge length:", edge.item())

In [None]:
# Initial grid embedding: rescale the latent grid so that its x-extent equals `edge`.
with torch.no_grad():
    stretched_grid = grid * edge / x_size
# Small seeded random height (z) perturbation, used to lift the flat grid into 3D.
torch.manual_seed(666)
eps = 1e-2 * edge / num_points
z_perturbation = eps * torch.randn(num_points, num_points, 1)

# Simplified optimization without nn.Module

In [None]:
# plotting
def plot_triang(embedded_grid, additional_comment='', savefig=False, plot_number=0):
    """Plot a 3D embedded grid as a triangulated surface colored by height.

    Args:
        embedded_grid: tensor indexed as ``[i, j, coord]`` whose last axis holds
            (x, y, z); its flattened vertices are connected using the module-level
            triangulation ``tri``.
        additional_comment: text appended to the plot title.
        savefig: if truthy, also save the figure as
            ``<Path_pictures>/3dembedding_optimization_history/plot<plot_number>.pdf``.
            The directory is created if it does not exist.
        plot_number: index used in the saved file name.

    Returns:
        The module-level triangulation ``tri`` used for the surface.
    """
    import os

    x = embedded_grid[:, :, 0].flatten().detach()
    y = embedded_grid[:, :, 1].flatten().detach()
    z = embedded_grid[:, :, 2].flatten().detach()

    fig = plt.figure(figsize=(10, 10), dpi=300)
    ax = fig.add_subplot(1, 2, 1, projection='3d')

    # Raw f-string so the LaTeX backslashes (\lambda, \mathrm) are not parsed as
    # (invalid) string escape sequences; the resulting text is unchanged.
    ax.set_title(rf"3d embedding of a grid on torus with $\lambda_{{\mathrm{{curv}}}} = ${curv_w}." +
                 additional_comment)

    # The triangles from parameter space determine which x, y, z points are
    # connected by an edge.
    surface = ax.plot_trisurf(x, y, z, triangles=tri.triangles, cmap=cm.Spectral,
                              vmax=z.max(), vmin=z.min())
    ax.set_zlim(-0.5, 0.5)
    ax.view_init(0, 30)

    cbar = fig.colorbar(surface, shrink=0.1)
    cbar.set_label("Height")
    cbar.set_ticks(ticks=[z.min(), 0., z.max()])
    cbar.set_ticklabels(ticklabels=[f'{z.min():.3f}', '0', f'{z.max():.3f}'])

    if savefig:
        save_dir = Path_pictures + "/3dembedding_optimization_history"
        # savefig fails with FileNotFoundError when the target directory is missing.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_dir + f"/plot{plot_number}.pdf", format="pdf")

    plt.show()
    return tri

In [None]:
def training_loop_simple(epoch, params, embedded_grid, horizontal_lengths_reshaped, vertical_lengths_reshaped, num_iter=1,
                         mode="diagnostic", loss_history=None, learning_rate=1e+1):
    """Fit grid-edge lengths of ``embedded_grid`` to target lengths with plain SGD.

    Each iteration minimizes ``100 * (MSE_h + MSE_v)``. MSE_h compares the Euclidean
    lengths of the edges ``embedded_grid[i + 1, j] - embedded_grid[i, j]`` with
    ``horizontal_lengths_reshaped``. MSE_v compares the lengths of the edges
    ``embedded_grid[i, j + 1] - embedded_grid[i, j]`` with ``vertical_lengths_reshaped``.

    Args:
        epoch: epoch index; only used for the saved plot number (``epoch + 1``).
        params: tensors optimized by ``torch.optim.SGD``. They should include
            ``embedded_grid`` for the grid to actually move.
        embedded_grid: tensor indexed as ``[i, j, coord]`` that the loss is computed on.
        horizontal_lengths_reshaped: target lengths for edges along the first axis.
        vertical_lengths_reshaped: target lengths for edges along the second axis.
        num_iter: number of optimizer steps.
        mode: ``"diagnostic"`` plots (and saves) the surface and plots the loss curve
            after the loop; any other value skips all plotting.
        loss_history: list extended in place with one loss per step; a new list is
            created when None.
        learning_rate: SGD learning rate.

    Returns:
        The (possibly newly created) loss history list.
    """
    history = [] if loss_history is None else loss_history
    # Plain SGD (no momentum), rebuilt on every call.
    optimizer = torch.optim.SGD(params, lr=learning_rate)

    for _ in range(num_iter):
        optimizer.zero_grad()

        horizontal_gaps = embedded_grid[1:, :, :] - embedded_grid[:-1, :, :]
        vertical_gaps = embedded_grid[:, 1:, :] - embedded_grid[:, :-1, :]
        horizontal_error = horizontal_lengths_reshaped - horizontal_gaps.norm(dim=-1)
        vertical_error = vertical_lengths_reshaped - vertical_gaps.norm(dim=-1)
        loss = 1e2 * (horizontal_error.square().mean() + vertical_error.square().mean())

        loss.backward()
        optimizer.step()
        history.append(loss.item())

    if mode != "diagnostic":
        return history

    # Diagnostic output: current surface (saved to disk) followed by the loss curve.
    plot_triang(embedded_grid, plot_number=epoch+1, savefig=True,
                additional_comment=f'\n After {len(history)} iterations.')

    plt.figure()
    plt.plot(history, label='Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Time')
    plt.legend()
    plt.show()

    return history



In [None]:
# Initial 3D embedding: the stretched flat grid plus the small z-perturbation as a
# third coordinate. It is a leaf tensor that the optimizer updates in place.
embedded_grid = torch.cat([stretched_grid, z_perturbation], dim=-1).requires_grad_(True)

# The only optimized parameter is the embedded grid itself.
params = [embedded_grid]


In [None]:
# Optimization: num_epochs rounds of 100 SGD steps each. Every round ends with the
# diagnostic plots, and the surface is saved as plot<epoch + 1>.pdf.
num_epochs = 5
loss_history = []
for epoch in range(num_epochs):
    loss_history = training_loop_simple(
        epoch,
        params,
        embedded_grid,
        horizontal_lengths_reshaped,
        vertical_lengths_reshaped,
        loss_history=loss_history,
        num_iter=100,
        mode="diagnostic",
        learning_rate=1e0,
    )

In [None]:

# Loss curve over all optimization steps, optionally saved next to the surface plots.
plt.plot(loss_history, label='Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.legend()
if violent_saving:
    loss_plot_path = Path_pictures + "/3dembedding_optimization_history/loss.pdf"
    plt.savefig(loss_plot_path)
plt.show()

In [None]:
# Detached copy of the optimized embedding, optionally persisted to disk.
optimized_grid = embedded_grid.detach()
if violent_saving:
    grid_path = Path_pictures + "/embedded_grid.pt"
    torch.save(optimized_grid, grid_path)

# Plotting the embedded grid with trimesh

In [None]:
import trimesh
import matplotlib.pyplot as plt
%matplotlib notebook

# Create the trimesh object directly from vertices and faces: the vertices are the
# flattened optimized grid (converted explicitly to NumPy), and the faces are the
# parameter-space triangles.
vertices = optimized_grid.reshape(-1, 3).cpu().numpy()
mesh = trimesh.Trimesh(vertices=vertices, faces=tri.triangles)

# Color the vertices by height. z is normalized to [0, 1]; a perfectly flat mesh
# (zero z-range) would divide by zero, so it gets a constant color instead.
z_coords = mesh.vertices[:, 2]
z_min = z_coords.min()
z_max = z_coords.max()
z_range = z_max - z_min
if z_range > 0:
    normalized_z = (z_coords - z_min) / z_range
else:
    normalized_z = np.zeros_like(z_coords)

colormap = matplotlib.colormaps.get_cmap("jet")
# The colormap returns RGBA floats in [0, 1]; convert to uint8 RGBA in [0, 255].
colors = colormap(normalized_z)
vertex_colors = (colors * 255).astype(np.uint8)
mesh.visual.vertex_colors = vertex_colors

scene = trimesh.Scene(mesh)

# Initial camera pose: a pure translation (no rotation) by
# (-0.2, -zoom_out_factor, -5.5). Increase zoom_out_factor to move the camera
# further along -y.
zoom_out_factor = 12.0
camera_transform = trimesh.transformations.translation_matrix([-0.2, -zoom_out_factor, -5.5])
scene.camera_transform = camera_transform

# Display the mesh in a viewer window from this initial pose.
scene.show()