# Data loading

In [None]:
from tqdm.notebook import tqdm
import torch
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
import stochman
from stochman.manifold import EmbeddedManifold
from stochman.curves import CubicSpline
import torch.nn as nn
import matplotlib.cm as cm

import matplotlib.tri as mtri

# Notebook switches.
violent_saving = True  # NOTE(review): not referenced in the visible cells
# False: recompute the geodesic lengths with stochman and save them under Path_pictures;
# True: skip that and reuse the lengths previously saved there.
load_distances = True

# Experiment configuration. Alternative run without curvature penalty:
#experiment_json = '../experiments/MNIST_torus_AEexp34.json'
experiment_json = '../experiments/MNIST01_torus_AEexp7.json'
mydict = ricci_regularization.get_dataloaders_tuned_nn(Path_experiment_json=experiment_json)

In [None]:
# Pull the tuned autoencoder, the test split and the experiment configuration
# out of the dictionary returned by get_dataloaders_tuned_nn.
torus_ae, test_loader, json_cofig = (
    mydict[key] for key in ("tuned_neural_network", "test_loader", "json_config")
)
Path_pictures = json_cofig["Path_pictures"]        # output directory for figures and saved tensors
exp_number = json_cofig["experiment_number"]
curv_w = json_cofig["losses"]["curv_w"]           # curvature loss weight, shown in plot titles

In [None]:
D = 784  # flattened input dimension fed to the encoders (28*28 for MNIST)
k = json_cofig["dataset"]["parameters"]["k"]
torus_ae.cpu()
# NOTE(review): the model is not switched to eval(); confirm it has no dropout/batch-norm.
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# Inference only: no_grad stops autograd from recording a graph for every batch,
# so the stored outputs do not keep the whole test-set computation graph alive.
with torch.no_grad():
    for data, labels in tqdm(test_loader, position=0):
        flat_data = data.view(-1, D)  # flatten once, reused by both encoders
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(flat_data))
        enc_list.append(torus_ae.encoder2lifting(flat_data))
        colorlist.append(labels)

In [None]:
# Stitch the per-batch outputs together along the batch dimension.
input_dataset, recon_dataset, encoded_points, feature_space_encoding = map(
    torch.cat,
    (input_dataset_list, recon_dataset_list, enc_list, feature_space_encoding_list),
)
# Gradient-free copies used for plotting.
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

In [None]:
# 2-D latent encodings coloured by label, with a discrete "jet" colormap built by
# ricci_regularization.discrete_cmap from k.
label_cmap = ricci_regularization.discrete_cmap(k, "jet")
plt.scatter(
    encoded_points_no_grad[:, 0],
    encoded_points_no_grad[:, 1],
    c=color_array,
    cmap=label_cmap,
)
plt.show()

# Computing distances with Stochman

In [None]:
from stochman.manifold import EmbeddedManifold


class Autoencoder(EmbeddedManifold):
    """Latent space seen as a manifold embedded by the torus decoder.

    Geodesics are obtained by minimising the "energy" of the decoded curve in
    the embedding space, so the pull-back metric never has to be formed
    explicitly, which keeps the computation fast.
    """

    def embed(self, c, jacobian=False):
        # Decode latent points. ``jacobian`` is accepted but ignored: only the
        # embedding is returned. NOTE(review): confirm stochman never calls
        # this with jacobian=True along the code paths used here.
        decoded = torus_ae.decoder_torus(c)
        return decoded


model = Autoencoder()

In [None]:
num_geodesics = 100

# Rectangle covered by the grid: [x_left, x_right] x [y_bottom, y_top].
# Alternative bounds tried earlier: -torch.pi instead of -2 (full torus period).
x_left = -2
y_bottom = -2

x_size = -x_left*2    # 2*torch.pi for the full-period variant
y_size = -y_bottom*2  # max shift of geodesics

x_right = x_left + x_size
y_top = y_bottom + y_size

# Equally spaced offsets along each side of the rectangle.
y_offsets = torch.linspace(0, y_size, num_geodesics)
x_offsets = torch.linspace(0, x_size, num_geodesics)

# Left/right edges: horizontal curves go from (x_left, y) to (x_right, y).
starting_points = torch.stack([torch.full_like(y_offsets, x_left), y_bottom + y_offsets], dim=1)
end_points = torch.stack([torch.full_like(y_offsets, x_right), y_bottom + y_offsets], dim=1)

# Bottom/top edges: vertical curves go from (x, y_bottom) to (x, y_top).
# The x offsets span x_size, so a non-square rectangle is handled correctly.
starting_points_vertical = torch.stack([x_left + x_offsets, torch.full_like(x_offsets, y_bottom)], dim=1)
end_points_vertical = torch.stack([x_left + x_offsets, torch.full_like(x_offsets, y_top)], dim=1)

In [None]:
# Node (i, j) of the grid is starting_points[j] shifted right by i steps:
# grid[i, j] = (x_left + i*dx, y_bottom + j*dy). Built by broadcasting.
horizontal_step = torch.tensor([x_size/(num_geodesics-1), 0])
x_index = torch.arange(num_geodesics).view(-1, 1, 1)
grid = starting_points.unsqueeze(0) + x_index * horizontal_step

In [None]:
# Inspection only: `grid` is built from plain tensors (requires_grad is False),
# so no gradient is ever accumulated on it and this evaluates to None.
grid.grad

In [None]:
if not load_distances:
    # Horizontal geodesics: connect every node grid[i, j] to its right-hand
    # neighbour grid[i + 1, j]; both ends are flattened into one batch of pairs.
    c, success = model.connecting_geodesic(
        grid[:-1, :, :].reshape(-1, 2), grid[1:, :, :].reshape(-1, 2)
    )
    c.plot()
    plt.scatter(grid[:, :, 0], grid[:, :, 1], c="blue")
    plt.show()

In [None]:
if not load_distances:
    # Vertical geodesics: connect every node grid[i, j] to the node above it,
    # grid[i, j + 1]; both ends are flattened into one batch of pairs.
    c_vert, success = model.connecting_geodesic(
        grid[:, :-1, :].reshape(-1, 2), grid[:, 1:, :].reshape(-1, 2)
    )
    c_vert.plot()
    plt.scatter(grid[:, :, 0], grid[:, :, 1], c="blue")
    plt.show()

In [None]:
if not load_distances:
    # Overlay both families of geodesics on the grid nodes.
    c_vert.plot()
    c.plot()
    plt.scatter(grid[:, :, 0], grid[:, :, 1], c="blue")
    plt.show()  # must be called; a bare `plt.show` attribute access renders nothing

In [None]:
# Curve-parameter values at which the geodesics are sampled to measure their length.
t = torch.linspace(start=0, end=1, steps=20)

# Distances computation and shaping breakdown (skip this)

In [None]:
#grid[:-1,:,:]

In [None]:
# horizontal geodesics 
#c(t)[:,0,:]

In [None]:
# vertical geodesics
#c_vert(t)[:,0,:]

In [None]:
if not load_distances:
    # Lengths of every horizontal / vertical geodesic as returned by stochman's
    # curve_length, evaluated on the curves sampled at the parameter values t.
    horizontal_lengths = model.curve_length(c(t)).detach()
    vertical_lengths = model.curve_length(c_vert(t)).detach()
    print(vertical_lengths)
    print(vertical_lengths.reshape(num_geodesics, num_geodesics - 1))
    # Cache the results so later runs can set load_distances = True.
    torch.save(horizontal_lengths, Path_pictures + '/horizontal_lengths.pt')
    torch.save(vertical_lengths, Path_pictures + '/vertical_lengths.pt')

In [None]:
# Load the cached geodesic lengths (written by the computation cell above).
horizontal_lengths, vertical_lengths = (
    torch.load(Path_pictures + '/horizontal_lengths.pt'),
    torch.load(Path_pictures + '/vertical_lengths.pt'),
)

In [None]:
# Arrange the per-segment lengths on the grid and find the shortest total length
# of a full horizontal row and of a full vertical column; these set the size of
# the initial rectangle. The layouts assume curve_length preserves the batch
# order of the flattened node pairs. TODO confirm.
with torch.no_grad():
    # [i, j]: segment grid[i, j] -> grid[i + 1, j]
    horizontal_lengths_reshaped = horizontal_lengths.reshape(num_geodesics - 1, num_geodesics)
    # [i, j]: segment grid[i, j] -> grid[i, j + 1]
    vertical_lengths_reshaped = vertical_lengths.reshape(num_geodesics, num_geodesics - 1)
    minimal_horizontal_length = torch.sum(horizontal_lengths_reshaped, dim=0).min()
    minimal_vertical_length = torch.sum(vertical_lengths_reshaped, dim=1).min()

In [None]:
# Side length of the initial square: the shorter of the two minimal totals.
edge = torch.minimum(minimal_horizontal_length, minimal_vertical_length)
print("edge length:", edge.item())


In [None]:
# Initial embedding: rescale the parameter grid so its side has length `edge`.
with torch.no_grad():
    stretched_grid = grid * edge / x_size

# Small random heights (z-coordinates) added to the planar grid, seeded for reproducibility.
torch.manual_seed(0)
eps = edge * 1e-2 / num_geodesics
z_perturbation = eps * torch.randn(num_geodesics, num_geodesics, 1)

flat_grid = torch.cat([stretched_grid, z_perturbation], dim=-1).requires_grad_()

# The embedding nn.Module

In [None]:
class MyModel(nn.Module):
    """Learnable 3-D embedding of the parameter grid.

    The single parameter ``emb`` holds one 3-D point per grid node and has
    shape (rows, cols, 3). By default it is initialised from the notebook
    globals ``stretched_grid`` (planar x, y) and ``z_perturbation`` (heights).

    Args:
        initial_grid: optional (rows, cols, 3) tensor used as the initial
            embedding instead of the notebook globals. It is copied, so the
            parameter never aliases the caller's tensor.
    """

    def __init__(self, initial_grid=None):
        super().__init__()
        if initial_grid is None:
            initial_grid = torch.cat((stretched_grid, z_perturbation), dim=2)
        self.emb = nn.Parameter(initial_grid.detach().clone())

    def _surface_coordinates(self):
        """Return detached, row-major flattened x, y, z coordinates of ``emb``."""
        points = self.emb.detach()
        return (
            points[:, :, 0].flatten(),
            points[:, :, 1].flatten(),
            points[:, :, 2].flatten(),
        )

    def plot(self):
        """Scatter the embedded nodes in 3-D (z-range limited to +/- the global ``edge``)."""
        fig = plt.figure(figsize=(10, 10), dpi=300)
        ax = fig.add_subplot(1, 2, 1, projection='3d')
        ax.set_title("Grid embedding")
        x, y, z = self._surface_coordinates()
        ax.set_zlim(-edge.item(), edge.item())
        ax.scatter(x, y, z)
        return

    def plot_triang(self, additional_comment='', savefig=False, plot_number=0):
        """Plot this embedding as a triangulated surface coloured by height.

        Args:
            additional_comment: text appended to the title.
            savefig: if truthy, save the figure as
                ``plot{plot_number}.pdf`` in ``Path_pictures/3dembedding_optimization_history``.
            plot_number: index used in the saved file name.
        """
        rows, cols = self.emb.shape[0], self.emb.shape[1]
        # Only the connectivity of this parameter-space triangulation is used.
        # meshgrid puts u along columns and v along rows, which matches the
        # row-major flatten() of the embedding, so vertex indices line up.
        u, v = np.meshgrid(
            np.linspace(-2.0, 2.0, endpoint=True, num=cols),
            np.linspace(-2.0, 2.0, endpoint=True, num=rows),
        )
        tri = mtri.Triangulation(u.flatten(), v.flatten())

        # Use this instance's embedding (not a notebook global), evaluated once.
        x, y, z = self._surface_coordinates()

        fig = plt.figure(figsize=(10, 10), dpi=300)
        ax = fig.add_subplot(1, 2, 1, projection='3d')
        # Raw f-string: keeps the LaTeX backslashes without invalid-escape warnings.
        ax.set_title(rf"3d embedding of a grid on torus with $\lambda_{{\mathrm{{curv}}}} = ${curv_w}."
                     + additional_comment)

        p = ax.plot_trisurf(x, y, z, triangles=tri.triangles, cmap=cm.Spectral, vmax=z.max(), vmin=z.min())
        ax.set_zlim(-0.5, 0.5)
        ax.view_init(0, 30)

        cbar = fig.colorbar(p, shrink=0.1)
        cbar.set_label("Height")
        cbar.set_ticks(ticks=[z.min(), 0., z.max()])
        cbar.set_ticklabels(ticklabels=[f'{z.min():.3f}', '0', f'{z.max():.3f}'])

        if savefig:
            import os
            # matplotlib does not create missing directories.
            os.makedirs(Path_pictures + "/3dembedding_optimization_history", exist_ok=True)
            plt.savefig(Path_pictures + f"/3dembedding_optimization_history/plot{plot_number}.pdf", format="pdf")

        plt.show()
        return

    def forward(self):
        """Return the embedding parameter itself (shape (rows, cols, 3))."""
        return self.emb

In [None]:
# Trainable grid embedding, optimised with plain SGD.
embedded_grid = MyModel()
optimizer = torch.optim.SGD(params=embedded_grid.parameters(), lr=0.1)

In [None]:
# Checking the adequacy of the initial guess. Set to True to compare the target
# geodesic lengths with the neighbour distances of the initial 3-D embedding.
# (Previously disabled by wrapping the code in a string literal, which the
# notebook echoed as the cell output.)
check_initial_guess = False
if check_initial_guess:
    with torch.no_grad():
        initial_embedding = embedded_grid()
        print("lengths of horizontal geodesics:\n", horizontal_lengths_reshaped)
        print("initial euclidean lengths in 3d embedding:\n",
              (initial_embedding[1:, :, :] - initial_embedding[:-1, :, :]).norm(dim=-1))
        print("lengths of vertical geodesics:\n", vertical_lengths_reshaped)
        print("initial euclidean lengths in 3d embedding:\n",
              (initial_embedding[:, 1:, :] - initial_embedding[:, :-1, :]).norm(dim=-1))

# The training loop

In [None]:
def training_loop(epoch, num_iter=1, mode="diagnostic", loss_history=None, *,
                  grid_model=None, grid_optimizer=None,
                  horizontal_targets=None, vertical_targets=None, loss_scale=1e2):
    """Fit the grid embedding so neighbour distances match the geodesic lengths.

    Each iteration minimises
    ``loss_scale * (mean((h_target - |Δ_i emb|)^2) + mean((v_target - |Δ_j emb|)^2))``,
    where Δ_i / Δ_j are differences between neighbouring nodes along the first
    and second grid axis.

    Args:
        epoch: kept for backward compatibility; not used.
        num_iter: number of optimisation steps.
        mode: "diagnostic" plots the loss history afterwards; any other value
            skips the plot.
        loss_history: list to append the per-iteration losses to. It is
            mutated in place and returned. A fresh list is used when None.
        grid_model, grid_optimizer, horizontal_targets, vertical_targets:
            optional overrides for the notebook globals ``embedded_grid``,
            ``optimizer``, ``horizontal_lengths_reshaped`` and
            ``vertical_lengths_reshaped``.
        loss_scale: multiplier applied to the summed losses.

    Returns:
        The list of loss values (``loss_history`` when one was given).
    """
    model = embedded_grid if grid_model is None else grid_model
    opt = optimizer if grid_optimizer is None else grid_optimizer
    h_target = horizontal_lengths_reshaped if horizontal_targets is None else horizontal_targets
    v_target = vertical_lengths_reshaped if vertical_targets is None else vertical_targets
    # A None default avoids the shared-mutable-default pitfall: a `[]` default
    # would keep accumulating losses across unrelated calls.
    loss_values = [] if loss_history is None else loss_history

    for _ in range(num_iter):
        opt.zero_grad()

        # Evaluate the embedding once per step and reuse it for both directions.
        embedded = model()
        horizontal_grid_distances = (embedded[1:, :, :] - embedded[:-1, :, :]).norm(dim=-1)
        vertical_grid_distances = (embedded[:, 1:, :] - embedded[:, :-1, :]).norm(dim=-1)

        loss_horizontal = (h_target - horizontal_grid_distances).square().mean()
        loss_vertical = (v_target - vertical_grid_distances).square().mean()
        loss = loss_scale * (loss_horizontal + loss_vertical)

        loss.backward()
        opt.step()
        loss_values.append(loss.item())

    if mode == "diagnostic":
        plt.figure()
        plt.plot(loss_values, label='Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Time')
        plt.legend()
        plt.show()

    return loss_values


In [None]:
import os

num_epochs = 10
loss_history = []
# plot_triang(savefig=True) writes into this sub-directory; matplotlib won't create it.
os.makedirs(Path_pictures + "/3dembedding_optimization_history", exist_ok=True)
embedded_grid.plot_triang()
for epoch in range(num_epochs):
    loss_history = training_loop(epoch + 1, num_iter=1000, loss_history=loss_history)
    # Snapshot of the surface after each block of 1000 iterations.
    embedded_grid.plot_triang(plot_number=epoch + 1, savefig=True,
                              additional_comment=f'\n After {len(loss_history)} iterations.')

In [None]:
import os

# Loss over all training iterations, saved next to the per-epoch surface snapshots.
os.makedirs(Path_pictures + "/3dembedding_optimization_history", exist_ok=True)
plt.plot(loss_history, label='Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.legend()
plt.savefig(Path_pictures + "/3dembedding_optimization_history/loss.pdf")
plt.show()

In [None]:
# Display the loss recorded at the last training iteration.
loss_history[-1]

# Plotting the embedded grid with matplotlib triangulation

In [None]:
# Snapshot of the trained embedding parameter, detached from autograd.
optimized_grid = embedded_grid.emb.detach()

In [None]:
# Parameter-space triangulation; only its connectivity matters. meshgrid puts u
# along columns and v along rows, matching the row-major flatten() of the
# embedded grid below, so triangle vertex indices refer to embedded nodes.
u = np.linspace(-2.0, 2.0, endpoint=True, num=num_geodesics)
v = np.linspace(-2.0, 2.0, endpoint=True, num=num_geodesics)
u, v = np.meshgrid(u, v)
u, v = u.flatten(), v.flatten()

# Evaluate the model once and reuse the detached output for all three coordinates.
embedded_points = embedded_grid().detach()
x = embedded_points[:, :, 0].flatten()
y = embedded_points[:, :, 1].flatten()
z = embedded_points[:, :, 2].flatten()

tri = mtri.Triangulation(u, v)

fig = plt.figure(figsize=(10, 10), dpi=300)
# The triangles in parameter space determine which x, y, z points are connected.
ax = fig.add_subplot(1, 2, 1, projection='3d')
# Raw f-string: keeps the LaTeX backslashes without invalid-escape warnings.
ax.set_title(rf"3d embedding of a grid on torus with $\lambda_{{\mathrm{{curv}}}} = ${curv_w}.")

p = ax.plot_trisurf(x, y, z, triangles=tri.triangles, cmap=cm.Spectral, vmax=z.max(), vmin=z.min())
ax.set_zlim(-0.5, 0.5)
ax.view_init(0, 30)

cbar = fig.colorbar(p, shrink=0.1)
cbar.set_label("Height")
cbar.set_ticks(ticks=[z.min(), 0., z.max()])
cbar.set_ticklabels(ticklabels=[f'{z.min():.3f}', '0', f'{z.max():.3f}'])

plt.savefig(Path_pictures + f"/torus3d_embedding_curw_w={curv_w}_num_geod={num_geodesics}.pdf", format="pdf")

plt.show()

# Plotting the embedded grid with trimesh

In [None]:
import trimesh

# Interactive 3-D view of the optimized embedding (trimesh does not use the
# matplotlib backend). Vertices are the grid nodes in row-major order, which is
# the vertex order `tri.triangles` was built for.
triangulation_mesh = trimesh.Trimesh(
    vertices=optimized_grid.reshape(-1, 3).numpy(),
    faces=tri.triangles,
)

# Display the triangulated surface
triangulation_mesh.show()