 This notebook compares two torus embeddings, one trained without and one with curvature regularization, to visualize the effect of curvature regularization.

In [None]:
import torch, json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.tri as mtri

In [None]:
import os


def _load_experiment(json_path):
    """Load one experiment JSON config.

    Returns a tuple ``(config, pictures_path, curv_w)`` where ``config`` is the
    parsed JSON, ``pictures_path`` is its ``"Path_pictures"`` entry and
    ``curv_w`` is its ``["losses"]["curv_w"]`` entry.
    """
    with open(json_path, encoding="utf-8") as json_file:
        config = json.load(json_file)
    return config, config["Path_pictures"], config["losses"]["curv_w"]


# load the experiment without curvature regularization
experiment_json_not_regularized = '../experiments/MNIST01_torus_AEexp7.json'
json_cofig_not_regularized, Path_pictures_not_regularized, curv_w_not_regularized = (
    _load_experiment(experiment_json_not_regularized)
)

# load the experiment with curvature regularization
experiment_json_regularized = '../experiments/MNIST01_torus_AEexp8.json'
json_cofig_regularized, Path_pictures_regularized, curv_w_regularized = (
    _load_experiment(experiment_json_regularized)
)

# loading grids; os.path.join copes with trailing slashes / empty prefixes, and
# map_location="cpu" lets grids saved from a GPU run open on a CPU-only machine.
embedded_grid_not_regularized = torch.load(
    os.path.join(Path_pictures_not_regularized, 'embedded_grid.pt'),
    map_location="cpu",
)
embedded_grid_regularized = torch.load(
    os.path.join(Path_pictures_regularized, 'embedded_grid.pt'),
    map_location="cpu",
)

In [None]:
# Both grids are drawn with one shared triangulation, so their full shapes
# must match. Checking only the first dimension would let a grid with a
# different number of columns through and mis-index the triangles.
if embedded_grid_not_regularized.shape != embedded_grid_regularized.shape:
    raise ValueError(
        "An error occurred: different grid sizes "
        f"({tuple(embedded_grid_not_regularized.shape)} vs "
        f"{tuple(embedded_grid_regularized.shape)})"
    )
# Number of points along the first grid axis.
num_points = embedded_grid_not_regularized.shape[0]
n_rows, n_cols = embedded_grid_not_regularized.shape[:2]

# Create a (u, v) parametrisation to build the triangulation. Grid indices are
# used instead of embedded x/y coordinates: a regular index grid always yields
# a triangulation connecting neighbouring grid points, whereas embedded
# coordinates sampled along one row/column are not guaranteed to be distinct
# or monotone. indexing="ij" makes the flattened (u, v) order match
# embedded_grid[:, :, k].flatten() (row-major over the first two axes).
u, v = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
u, v = u.flatten(), v.flatten()

# Create the triangulation (Delaunay on the index grid); its .triangles
# connectivity is what the 3D surface plots reuse.
tri = mtri.Triangulation(u, v)

# Create the figure: two 3D panels side by side plus a narrow axis for the
# shared colour bar.
fig = plt.figure(figsize=(15, 7), dpi=300)
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 0.05], wspace=0.3)
ax1 = fig.add_subplot(gs[0, 0], projection='3d')
ax2 = fig.add_subplot(gs[0, 1], projection='3d')
cax = fig.add_subplot(gs[0, 2])

# List of grids (left panel: not regularized, right panel: regularized)
grids = [embedded_grid_not_regularized, embedded_grid_regularized]

# List of axes
axes = [ax1, ax2]

# Convert each grid to NumPy once. .detach().cpu() makes this work for tensors
# that track gradients or live on a GPU (.numpy() rejects both).
grids_np = [grid.detach().cpu().numpy() for grid in grids]

# Global min and max of the z coordinates, so both panels share one colour scale
z_min = min(float(grid_np[:, :, 2].min()) for grid_np in grids_np)
z_max = max(float(grid_np[:, :, 2].max()) for grid_np in grids_np)

# Loop over grids
for ax, grid_np in zip(axes, grids_np):
    # Extract flattened coordinates (same row-major order as the triangulation)
    x = grid_np[:, :, 0].flatten()
    y = grid_np[:, :, 1].flatten()
    z = grid_np[:, :, 2].flatten()

    # Plot the surface using the shared connectivity and colour range
    p = ax.plot_trisurf(x, y, z, triangles=tri.triangles, cmap=cm.jet, vmax=z_max, vmin=z_min)
    # NOTE(review): fixed z-axis range; it does not adapt to z_min/z_max — confirm intended.
    ax.set_zlim(-0.5, 0.5)
    ax.view_init(30, 30)

# Set titles for each subplot
ax1.set_title(f"3D embedding of grid 1 on torus with λ_curv = {curv_w_not_regularized}.")
ax2.set_title(f"3D embedding of grid 2 on torus with λ_curv = {curv_w_regularized}.")

# Add common color bar (both surfaces use the same vmin/vmax, so the last
# mappable represents both)
cbar = fig.colorbar(p, cax=cax)
cbar.set_label("Height")
cbar.set_ticks([z_min, 0., z_max])
cbar.set_ticklabels([f'{z_min:.3f}', '0', f'{z_max:.3f}'])

plt.show()