# Loading weights

In [None]:
from tqdm.notebook import tqdm
import torch
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
import stochman
from stochman.manifold import EmbeddedManifold
from stochman.curves import CubicSpline

# Experiment configuration to load (this run was trained without a curvature penalty).
experiment_json = "../experiments/MNIST_torus_AEexp34.json"
# Alternative experiment:
# experiment_json = "../experiments/MNIST01_torus_AEexp7.json"
mydict = ricci_regularization.get_dataloaders_tuned_nn(
    Path_experiment_json=experiment_json,
)

In [None]:
# Unpack the trained torus autoencoder, its test loader and the experiment config.
torus_ae = mydict["tuned_neural_network"]
test_loader = mydict["test_loader"]
json_cofig = mydict["json_config"]  # (sic) spelling kept: later cells use this name

# Config entries used for plot titles and output file names.
Path_pictures, exp_number = (
    json_cofig["Path_pictures"],
    json_cofig["experiment_number"],
)
curv_w = json_cofig["losses"]["curv_w"]

In [None]:
# Encode the whole test set once so that the latent space can be visualised.
D = 784  # flattened input dimension (784 = 28*28 pixels for MNIST)
k = json_cofig["dataset"]["parameters"]["k"]  # number of labels, used for the discrete colormaps
torus_ae.cpu()
# Inference only: disable training-mode behaviour (e.g. dropout, if the model has any).
torus_ae.eval()
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# No gradients are needed here. Without no_grad every batch would keep its autograd
# graph alive inside the lists below for the rest of the notebook.
with torch.no_grad():
    for data, labels in tqdm(test_loader, position=0):
        flat_data = data.view(-1, D)
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(flat_data))
        enc_list.append(torus_ae.encoder2lifting(flat_data))
        colorlist.append(labels)

In [None]:
# Stack the per-batch results into single tensors (concatenated along dim 0).
(
    input_dataset,
    recon_dataset,
    encoded_points,
    feature_space_encoding,
) = (
    torch.cat(chunks)
    for chunks in (input_dataset_list, recon_dataset_list, enc_list, feature_space_encoding_list)
)
# Detached copies used for plotting.
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

In [None]:
# Scatter the first two latent coordinates of the encoded test set, coloured by label
# with a k-colour discrete colormap.
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=ricci_regularization.discrete_cmap(k,"jet"))
plt.show()

# Geodesic BVP (boundary value problem) on the latent space with Stochman

In [None]:
from stochman.manifold import EmbeddedManifold

In [None]:
# Geodesics are computed by minimising the curve "energy" measured in the embedding
# (decoder output) space, so the pullback metric never has to be formed explicitly,
# which keeps the algorithm fast.
class Autoencoder(EmbeddedManifold):
    """Latent space of ``torus_ae`` seen as a manifold embedded by ``decoder_torus``."""

    def embed(self, c, jacobian=False):
        """Return ``torus_ae.decoder_torus(c)``.

        Args:
            c: latent point(s) to embed, in whatever batch layout
                ``decoder_torus`` accepts.
            jacobian: if True the caller expects the embedding together with its
                Jacobian. That is not implemented by this class.

        Raises:
            NotImplementedError: if ``jacobian`` is True.
        """
        if jacobian:
            # Returning the bare embedding here would be wrong: a caller doing
            # ``emb, J = self.embed(c, jacobian=True)`` on a batch of two points
            # would silently "unpack" the two embedded rows instead of failing.
            raise NotImplementedError(
                "Autoencoder.embed: the decoder Jacobian is not implemented"
            )
        return torus_ae.decoder_torus(c)

In [None]:
#selected_labels = json_cofig["dataset"]["selected_labels"]

In [None]:
model = Autoencoder()
torch.manual_seed(0)

# Parameter values in [0, 1] at which curves are evaluated (for lengths and plots).
t = torch.linspace(0.,1.,100)

# p0 and p1 can be chosen anywhere in R^2; the metric is 2*pi-periodic.
p0 = torch.tensor([-1.4,-1.]) #+11*torch.pi
p1 = torch.tensor([-1.5,-1.]) #+ 11*torch.pi
# Alternative: a pair of points with different labels (the first ones in the test loader).
#p0 = encoded_points[torch.where(color_array==selected_labels[0])][0].detach()
#p1 = encoded_points[torch.where(color_array==selected_labels[1])][0].detach()
print(f"start:{p0}, \n end {p1}")
# Solve the geodesic boundary value problem p0 -> p1. The returned curve `c` is
# evaluated as c(t) with t a torch tensor of parameter values.
# NOTE(review): `success` is assumed to be Stochman's convergence flag (a tensor) — confirm.
c, success = model.connecting_geodesic(p0, p1)
print("Success:",success.item(),"\n length",model.curve_length(c(t)).item())

In [None]:
# Sampled points of the geodesic, detached for plotting.
points_on_geodesic = c(t).detach()

In [None]:
# Baseline: a CubicSpline built only from the end points, used as the "straight" curve.
straight_line = CubicSpline(p0,p1)
straight_line_points2plot = straight_line(t).detach()

# Lengths of both curves, measured by the model (i.e. in the decoder embedding).
geod_length = model.curve_length(c(t)).item()
straight_line_length = model.curve_length(straight_line(t)).item()

In [None]:
# Geodesic (orange) vs. straight line (green) on top of the encoded test set.
plt.title("Geodesic bvp")
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=ricci_regularization.discrete_cmap(k,"jet"))
plt.plot(points_on_geodesic[:,0],points_on_geodesic[:,1],c="orange",label=f"geodesic.\nLength:{geod_length:.3f}")
plt.plot(straight_line_points2plot[:,0],straight_line_points2plot[:,1],c="green",label=f"straight.\nLength:{straight_line_length:.3f}")
plt.legend()
#plt.savefig(f'{Path_pictures}/geodesic_vs_straight.pdf',bbox_inches='tight',format='pdf')
plt.show()

# Geodesic shooting

In [None]:
def geod_vect(x, dxdt):
    """Right-hand side of the geodesic equation written as a first-order system.

    With u = x and v = dxdt, the geodesic equation
    d^2 u^l / dt^2 = -Gamma^l_{ij}(u) v^i v^j becomes (du/dt, dv/dt) = (v, a),
    where a^l = -Gamma^l_{ij}(u) v^i v^j.

    Args:
        x: batch of latent points, indexed as ``x[n, i]``.
        dxdt: batch of velocities at those points, same layout as ``x``.

    Returns:
        (dudt, dvdt): ``dudt`` is ``dxdt`` itself; ``dvdt`` has shape
        ``(n, dim)`` with ``dim = dxdt.shape[1]``.
    """
    # Symbols indexed as Ch[n, l, i, j] (the coefficient of v^i v^j in component l
    # at the n-th point), computed from the decoder by ricci_regularization.
    # NOTE(review): shape (n, dim, dim, dim) assumed — confirm for dim > 2.
    Ch_at_u = ricci_regularization.Ch_jacfwd_vmap(
        x, function=torus_ae.decoder_torus, device=torch.device("cpu")
    )
    # Vectorised form of the former triple loop over l, i, j; works for any latent dim
    # and keeps the dtype of the inputs instead of forcing float32.
    dvdt = -torch.einsum("nlij,ni,nj->nl", Ch_at_u, dxdt, dxdt)
    return dxdt, dvdt

def rungekutta_vect(f, initial_point_array, initial_speed_array, t, args=(), method="euler"):
    """Integrate a second-order ODE on the time grid ``t`` for a batch of initial conditions.

    ``f(x, dxdt, *args)`` must return ``(dx/dt, d(dxdt)/dt)`` for the current state.
    Despite the function's name, the default ``method="euler"`` takes explicit
    (forward) Euler steps, which is what this notebook has always used. Pass
    ``method="rk4"`` for the classical 4th-order Runge-Kutta scheme. It makes four
    evaluations of ``f`` per step, so it costs roughly 4x as much.

    Args:
        f: right-hand side of the system, e.g. ``geod_vect``.
        initial_point_array: initial positions, e.g. ``(num_geodesics, dim)``.
        initial_speed_array: initial velocities, same layout as the positions.
        t: 1-D tensor of (possibly non-uniform) time values.
        args: extra positional arguments forwarded to ``f``.
        method: ``"euler"`` (default) or ``"rk4"``.

    Returns:
        (x, dxdt): positions of shape ``(len(t), *initial_point_array.shape)``.
        Velocities have shape ``(len(t), *initial_speed_array.shape)``. Both keep
        the dtype/device of the initial arrays.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    if method not in ("euler", "rk4"):
        raise ValueError(f"unknown integration method: {method!r}")
    n = len(t)
    # new_zeros keeps the dtype/device of the inputs (torch.zeros would force float32).
    x = initial_point_array.new_zeros((n, *initial_point_array.shape))
    dxdt = initial_speed_array.new_zeros((n, *initial_speed_array.shape))
    if n == 0:
        return x, dxdt
    x[0] = initial_point_array
    dxdt[0] = initial_speed_array
    for i in range(n - 1):
        h = t[i + 1] - t[i]
        if method == "euler":
            dudt, dvdt = f(x[i], dxdt[i], *args)
        else:
            k1u, k1v = f(x[i], dxdt[i], *args)
            k2u, k2v = f(x[i] + h / 2 * k1u, dxdt[i] + h / 2 * k1v, *args)
            k3u, k3v = f(x[i] + h / 2 * k2u, dxdt[i] + h / 2 * k2v, *args)
            k4u, k4v = f(x[i] + h * k3u, dxdt[i] + h * k3v, *args)
            dudt = (k1u + 2 * k2u + 2 * k3u + k4u) / 6
            dvdt = (k1v + 2 * k2v + 2 * k3v + k4v) / 6
        x[i + 1] = x[i] + h * dudt
        dxdt[i + 1] = dxdt[i] + h * dvdt
    return x, dxdt

In [None]:
from torch.nn.functional import normalize

# Geodesic shooting: integrate the geodesic equation from p0 with a fan of initial velocities.
num_approximation_points = 101 # how good the approximation is
max_parameter_value = 1 #3 # how far to go
time_array = torch.linspace(0, max_parameter_value, num_approximation_points)

num_geodesics = 200

#starting_points = torch.tensor([-2.,0.]).repeat(num_geodesics,1) # common starting point
starting_points = p0.repeat(num_geodesics,1) # common starting point

maxtangent = 2 # max slope of geodesics
# Initial velocities (1, s) for slopes s evenly spaced in [-maxtangent, maxtangent].
# Vectorised with torch.stack. The former comprehension reused the name `k`
# (the label count elsewhere) as its loop variable and built one tensor per geodesic.
slopes = torch.linspace(-maxtangent, maxtangent, num_geodesics)
starting_speeds = torch.stack((torch.ones_like(slopes), slopes), dim=1)
#starting_speeds = c.deriv(torch.zeros(1)).reshape(num_geodesics,2)
#starting_speeds = normalize(starting_speeds) #make norms of all speeds equal

# Shoot all geodesics at once; result shape is (num_approximation_points, num_geodesics, 2).
geodesics2plot,_ = rungekutta_vect(f=geod_vect,initial_point_array=starting_points,
                                   initial_speed_array=starting_speeds,t=time_array)
geodesics2plot = geodesics2plot.detach()

In [None]:
#scalar_curvature_on_geodesics = ricci_regularization.Sc_jacfwd_vmap(geodesics2plot.reshape(-1,2),function=torus_ae.decoder_torus,device=torch.device("cpu"))
#scalar_curvature_on_geodesics = scalar_curvature_on_geodesics.reshape(num_approximation_points,num_geodesics).detach()

In [None]:
# Shot geodesics (black) over the encoded test set. The colormap uses k colours like the
# other latent-space plots (it was hard-coded to 2, which mis-bins labels when k != 2).
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=ricci_regularization.discrete_cmap(k,"jet"))
for i in range(num_geodesics):
    #plt.scatter(geodesics2plot[:,i,0],geodesics2plot[:,i,1],c=time_array,cmap="jet")
    #plt.scatter(geodesics2plot[:,i,0],geodesics2plot[:,i,1],c=scalar_curvature_on_geodesics[:,i],cmap="viridis",norm=matplotlib.colors.SymLogNorm(linthresh=1e-2))
    plt.plot(geodesics2plot[:,i,0],geodesics2plot[:,i,1],c="black")
#plt.colorbar(label="scalar curvature along geodesics")
plt.show()

# Geodesic BVP for several geodesics

In [None]:
num_geodesics = 10

# Grid square [x_left, x_right] x [y_bottom, y_top].
x_left = -torch.pi/2#-torch.pi #-2.0
y_bottom = -torch.pi/2#-torch.pi #-2.0

x_size = -x_left*2#2*torch.pi # 4.
y_size = -y_bottom*2#2*torch.pi #4. # max shift of geodesics

x_right = x_left + x_size
y_top = y_bottom + y_size

# Evenly spaced levels of the horizontal (constant y) grid lines.
y_levels = y_bottom + torch.linspace(0, y_size, num_geodesics)
# Vertical (constant x) grid lines are spread across the grid *width* x_size, not y_size,
# so that non-square grids are handled correctly.
x_levels = x_left + torch.linspace(0, x_size, num_geodesics)

# Horizontal grid lines: (x_left, y) -> (x_right, y).
starting_points = torch.stack((torch.full_like(y_levels, x_left), y_levels), dim=1)
end_points = torch.stack((torch.full_like(y_levels, x_right), y_levels), dim=1)

# Vertical grid lines: (x, y_bottom) -> (x, y_top).
starting_points_vertical = torch.stack((x_levels, torch.full_like(x_levels, y_bottom)), dim=1)
end_points_vertical = torch.stack((x_levels, torch.full_like(x_levels, y_top)), dim=1)

In [None]:
# Solve the batched boundary value problems for the horizontal and the vertical grid lines.
horizontal_geodesicts, success_horizontal = model.connecting_geodesic(starting_points, end_points)
vertical_geodesicts, success_vertical = model.connecting_geodesic(starting_points_vertical, end_points_vertical)
# Report both solves. Previously the vertical flag overwrote the horizontal one,
# so a failed horizontal solve went unnoticed.
success = bool(success_horizontal.item()) and bool(success_vertical.item())
print("Success:", success)

In [None]:
# Number of parameter values at which each grid geodesic is sampled.
num_approximation_points = 20
t = torch.linspace(0,1,num_approximation_points)

In [None]:
# Sample the grid geodesics at t, detached for plotting.
geodesics2plot_horisontal = horizontal_geodesicts(t).detach()
geodesics2plot_vertical = vertical_geodesicts(t).detach()

In [None]:
# Scalar curvature (ricci_regularization.Sc_jacfwd_vmap for decoder_torus) at every
# sampled point of the horizontal geodesics, reshaped to
# (num_geodesics, num_approximation_points).
# NOTE(review): the result is not used anywhere else in this notebook.
scalar_curvature_on_geodesics_bvp = ricci_regularization.Sc_jacfwd_vmap(geodesics2plot_horisontal.reshape(-1,2),function=torus_ae.decoder_torus,device=torch.device("cpu"))
scalar_curvature_on_geodesics_bvp = scalar_curvature_on_geodesics_bvp.reshape(num_geodesics,num_approximation_points).detach()

In [None]:
# Encoded test set with the horizontal/vertical geodesic grid on top; saved as a PDF.
plt.figure(dpi=300)
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=ricci_regularization.discrete_cmap(k,"jet"))
# rf-string: "\l" and "\m" are invalid escape sequences in a plain string literal
# (DeprecationWarning, SyntaxWarning since Python 3.12). The rendered text is unchanged.
plt.title(rf"Geodesic grid on MNIST with {k} labels with $\lambda_{{\mathrm{{curv}}}}={curv_w}$")
for i in range(num_geodesics):
    plt.plot(geodesics2plot_horisontal[i,:,0],geodesics2plot_horisontal[i,:,1],c="black")
    plt.plot(geodesics2plot_vertical[i,:,0],geodesics2plot_vertical[i,:,1],c="black")
plt.xlim(-torch.pi,torch.pi)
plt.ylim(-torch.pi,torch.pi)
plt.savefig(f'{Path_pictures}/multiple_geodesics_exp{exp_number}.pdf',bbox_inches='tight',format='pdf')
plt.show()

# Logarithmic map

$\log_{p_0}(p_1) = v$, where $\gamma$ is the geodesic such that $\gamma(0) = p_0$ and $\gamma(1) = p_1$, and $v = \dot\gamma(0)$; moreover, $\|v\|$ (measured in the metric at $p_0$) equals the length of the geodesic.

In [None]:
# Geodesic p0 -> p1 (same boundary value problem as before).
geod,_ = model.connecting_geodesic(p0,p1)

In [None]:
# Log map of p1 with base point p0; both passed as batches of one point.
model.logmap(p0.unsqueeze(0).detach(),p1.unsqueeze(0).detach())

In [None]:
# Initial velocity of the geodesic at parameter 0. By definition of the log map it
# should agree with log_{p0}(p1) from the previous cell.
geod.deriv(torch.zeros(1))

In [None]:
# Two more points, drawn uniformly from [0, 1)^2 by torch.rand.
p2 = torch.rand(2)
p3 = torch.rand(2)
#p2 = encoded_points[torch.where(color_array==selected_labels[0])][1].detach()
#p3 = encoded_points[torch.where(color_array==selected_labels[1])][1].detach()

In [None]:
# Geodesics p0 -> p1 and p2 -> p3.
geod,_ = model.connecting_geodesic(p0,p1)
new_geod,_ = model.connecting_geodesic(p2,p3)

In [None]:
# Sample both geodesics at num_points_on_new_geod parameter values in [0, 1].
num_points_on_new_geod = 20
points_on_geod = geod(torch.linspace(0,1,num_points_on_new_geod))
points_on_geod = points_on_geod.detach()

points_on_new_geod = new_geod(torch.linspace(0,1,num_points_on_new_geod))
points_on_new_geod = points_on_new_geod.detach()

In [None]:
# Log-map images (base point p0, repeated once per sampled point) of both geodesics.
geod_p0_p1_log_at_p0 = model.logmap(p0.repeat(num_points_on_new_geod,1).detach(),points_on_geod)
geod_p2_p3_log_at_p0 = model.logmap(p0.repeat(num_points_on_new_geod,1).detach(),points_on_new_geod)

In [None]:
p0

In [None]:
# Log-map images of the four end points with base point p0
# (log_{p0}(p0) is expected to be close to zero).
p0_logmap = model.logmap(p0.unsqueeze(0), p0.unsqueeze(0)).squeeze()
p1_logmap = model.logmap(p0.unsqueeze(0), p1.unsqueeze(0)).squeeze()
p2_logmap = model.logmap(p0.unsqueeze(0), p2.unsqueeze(0)).squeeze()
p3_logmap = model.logmap(p0.unsqueeze(0), p3.unsqueeze(0)).squeeze()

In [None]:
# Both geodesics and the four end points in the tangent space at p0.
plt.title("Two geodesics connecting $p_0$ and $p_1$, $p_2$ and $p_3$ \n after log map with base point $p_0$")
plt.scatter(geod_p0_p1_log_at_p0[:,0],geod_p0_p1_log_at_p0[:,1],c="red")
plt.plot(geod_p0_p1_log_at_p0[:,0],geod_p0_p1_log_at_p0[:,1],c="red")
plt.scatter(geod_p2_p3_log_at_p0[:,0],geod_p2_p3_log_at_p0[:,1])
plt.plot(geod_p2_p3_log_at_p0[:,0],geod_p2_p3_log_at_p0[:,1])
plt.scatter(p0_logmap[0],p0_logmap[1],marker = "*",c="blue",s = 120,label = f"Base point $p_0$",zorder = 3)
plt.scatter(p1_logmap[0],p1_logmap[1],marker = "*",c="green",s = 120,label = f"$p_1$",zorder = 3)
plt.scatter(p2_logmap[0],p2_logmap[1],marker = "*",c="magenta",s = 120,label = f"$p_2$",zorder = 3)
plt.scatter(p3_logmap[0],p3_logmap[1],marker = "*",c="yellow",s = 120,label = f"$p_3$",zorder = 3)
plt.legend()
plt.show()

# Grid of geodesics in log map

In [None]:
# choosing base point for logarithmic map
# Base point of the logarithmic map: the origin, stored as a batch of one 2-D point.
base_point = torch.zeros(1, 2)

In [None]:
# Grid dimensions are read from the sampled grid geodesics themselves, so this cell
# stays consistent if the grid or sampling cells are changed (they were re-typed here).
num_geodesics, num_approximation_points = geodesics2plot_horisontal.shape[:2]
num_grid_points = num_geodesics * num_approximation_points

# Log map (w.r.t. base_point) of every sampled point of the horizontal/vertical grid geodesics.
horizontal_geodesics2plot_logmap = model.logmap(base_point.repeat(num_grid_points,1),geodesics2plot_horisontal.reshape(-1,2))
vertical_geodesics2plot_logmap = model.logmap(base_point.repeat(num_grid_points,1),geodesics2plot_vertical.reshape(-1,2))

In [None]:
# Back to per-geodesic layout: (num_geodesics, num_approximation_points, 2).
horizontal_geodesics2plot_logmap = horizontal_geodesics2plot_logmap.reshape(num_geodesics,num_approximation_points,-1)
vertical_geodesics2plot_logmap = vertical_geodesics2plot_logmap.reshape(num_geodesics,num_approximation_points,-1)

In [None]:
# Points half the grid width/height to the left, right, top and bottom of the base point.
left_base = base_point - torch.tensor([x_size/2,0.])
right_base = base_point + torch.tensor([x_size/2,0.])
top_base = base_point + torch.tensor([0., y_size/2])
bottom_base = base_point - torch.tensor([0., y_size/2])
print(f"Left {left_base}, top {top_base}, right {right_base}, bottom {bottom_base}")

In [None]:
# Four geodesics from the base point to those points, solved as one batch
# (order: left, top, right, bottom).
base_geod,_ = model.connecting_geodesic(base_point.repeat(4,1), torch.cat((left_base,top_base,right_base,bottom_base)))

In [None]:
# Sample them at the current parameter values t.
base_geod_points = base_geod(t).detach()

In [None]:
# Log-map images of those samples with respect to the base point.
base_geod2plot_logmap = model.logmap(base_point.repeat(4*num_approximation_points,1),base_geod_points.reshape(-1,2))

In [None]:
base_geod2plot_logmap = base_geod2plot_logmap.reshape(4,num_approximation_points,2)
# Scalar coordinates of the base point (0-dim tensors), used in the legend label.
base_point_x = base_point.squeeze()[0]
base_point_y = base_point.squeeze()[1]

In [None]:
# Geodesic grid and the four base-point geodesics in the tangent space at the base point.
plt.figure(dpi=300)
# "\\" yields a literal backslash for the LaTeX commands ("\l"/"\m" are invalid escapes).
# A raw string cannot be used here because the title needs the real newline "\n".
plt.title(f"Geodesic grid in $T_{{p_0}} M$ after log map with base point $p_0$, \n experiment # {exp_number} with $\\lambda_{{\\mathrm{{curv}}}}={curv_w}$")
for i in range(num_geodesics):
    plt.plot(horizontal_geodesics2plot_logmap[i,:,0],horizontal_geodesics2plot_logmap[i,:,1],c="orange")
    plt.plot(vertical_geodesics2plot_logmap[i,:,0],vertical_geodesics2plot_logmap[i,:,1],c="orange")
    plt.scatter(horizontal_geodesics2plot_logmap[i,:,0], horizontal_geodesics2plot_logmap[i,:,1])
    plt.scatter(vertical_geodesics2plot_logmap[i,:,0], vertical_geodesics2plot_logmap[i,:,1],c="black")
plt.plot(base_geod2plot_logmap[3,:,0],base_geod2plot_logmap[3,:,1],c="red",label="Geodesics through base point")
for j in range(3):
    plt.plot(base_geod2plot_logmap[j,:,0],base_geod2plot_logmap[j,:,1],c="red")
# .item() so that the legend shows plain numbers instead of "tensor(...)".
plt.scatter(base_point[:,0],base_point[:,1],marker = "*",c="blue",s = 120,label = f"Base point $p_0$ = ({base_point_x.item()}, {base_point_y.item()})",zorder = 3)
plt.legend()
plt.savefig(f'{Path_pictures}/geodesic_grid_logmap_exp{exp_number}.pdf',bbox_inches='tight',format='pdf')
plt.show()

# Reconstructing geodesics

In [None]:
# Shoot geodesics from the base point, using as initial velocities the log-map vectors of
# every sampled point of the horizontal grid geodesics. The batch size is derived from the
# tensor instead of the hard-coded 200 (= num_geodesics * num_approximation_points).
logmap_velocities = horizontal_geodesics2plot_logmap.reshape(-1, 2)
points,_ = rungekutta_vect(geod_vect,base_point.repeat(logmap_velocities.shape[0],1),logmap_velocities,t=time_array)

In [None]:
# Reshape to (len(time_array), num_geodesics, num_approximation_points, 2): one shooting
# trajectory per sampled point of the horizontal grid geodesics.
points = points.reshape(-1,num_geodesics,num_approximation_points,2)
points = points.detach()

In [None]:
# End points of the shot trajectories (blue) vs. the horizontal grid geodesics (orange).
# They coincide when exp_{p0}(log_{p0}(q)) recovers q.
for i in range(num_geodesics):
    plt.plot(points[-1,i,:,0], points[-1,i,:,1],c="blue")
    plt.plot(geodesics2plot_horisontal[i,:,0],geodesics2plot_horisontal[i,:,1],c="orange")
plt.show()

# straight lines in logmap

In [None]:
num_points = 20
# For each horizontal grid geodesic: the vector from the log-map image of its first
# sampled point to that of its last sampled point.
vectors = horizontal_geodesics2plot_logmap[:,-1,:]-horizontal_geodesics2plot_logmap[:,0,:]

In [None]:
t = torch.linspace(0,1,num_points)

In [None]:
# Straight segments in the tangent space between those two log-map images, shape
# (num_points, num_geodesics, 2). tensordot with dims=0 is the outer product t[p] * vectors[g, :].
straight_lines = horizontal_geodesics2plot_logmap[:,0,:] + torch.tensordot(t.unsqueeze(0),vectors,dims=0).reshape(num_points,num_geodesics,2)

In [None]:
# Log-map images of the grid geodesics (orange) vs. the straight segments (black).
for i in range(num_geodesics):
    plt.plot(horizontal_geodesics2plot_logmap[i,:,0],horizontal_geodesics2plot_logmap[i,:,1],c="orange")
    plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
# Lengths of the horizontal grid geodesics, sampled at the new t.
geodesic_lengths = model.curve_length(horizontal_geodesicts(t))
print(f"geodesic lengths: {geodesic_lengths}")

In [None]:
# Map every point of the straight segments back to latent space: shoot from the base point
# with that tangent vector over time_array and keep the trajectory.
exp_base_point_x_i,_ = rungekutta_vect(f=geod_vect,initial_point_array=base_point.repeat(num_points*num_geodesics,1),
                    initial_speed_array=straight_lines.reshape(-1,2), t=time_array)

In [None]:
# End points of the shot trajectories, laid out as (num_points, num_geodesics, 2).
y = exp_base_point_x_i[-1].reshape(num_points, num_geodesics, 2).detach()
y.shape

In [None]:
plt.title("Geodesics and straight lines in logmap")
for i in range(num_geodesics):
    plt.scatter(y[:,i,0],y[:,i,1])
    #plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
# Connect consecutive mapped points y[j] -> y[j+1] of every segment by geodesics.
geodesics_y_i,_ = model.connecting_geodesic(y[:-1].reshape(-1,2), y[1:].reshape(-1,2))

In [None]:
# Approximate length of each mapped segment: sum of the lengths of its num_points-1 pieces.
log_straight_lines_length_approx = model.curve_length(geodesics_y_i(t)).reshape(num_points-1,num_geodesics).sum(dim = 0)
print(f"straight_lines_length_approx: {log_straight_lines_length_approx}")

In [None]:
# Per-geodesic ratio: geodesic length / approximate length of the mapped straight segment.
geodesic_lengths/log_straight_lines_length_approx

In [None]:
geod_length_ratio = (geodesic_lengths/log_straight_lines_length_approx).mean().item()
print(f"geodesic length ratio:\n{geod_length_ratio}")

In [None]:
dict = {"geod_length_ratio":geod_length_ratio}

In [None]:
import json
with open(f'{Path_pictures}/geodesic_length_ratio_exp{exp_number}.json', 'w') as json_file:
    json.dump(dict, json_file, indent=4)

# Random multiple geodesics and different logmap base points

In [None]:
torch.manual_seed(0)

num_approximation_points = 20
t = torch.linspace(0,1,num_approximation_points)

num_geodesics = 7
# Select geodesic start/end points and log-map base points uniformly at random
# in [-pi/2, pi/2)^2.
random_starting_points = torch.pi*(torch.rand(num_geodesics,2)-0.5)
random_end_points = torch.pi*(torch.rand(num_geodesics,2)-0.5)
base_points = torch.pi*(torch.rand(num_geodesics,2)-0.5)

# Solve the batched boundary value problems and sample each geodesic at t.
random_geodesicts, success = model.connecting_geodesic(random_starting_points, random_end_points)
random_geodesicts2plot = random_geodesicts(t).detach()

In [None]:
random_geodesicts2plot.shape

In [None]:
plt.title("Random geodesics and basepoints")
for i in range(num_geodesics):
    plt.scatter(base_points[i,0],base_points[i,1])
    plt.plot(random_geodesicts2plot[i,:,0],random_geodesicts2plot[i,:,1])
plt.xlim(-torch.pi, torch.pi)
plt.ylim(-torch.pi, torch.pi)
plt.show()

In [None]:
# Log-map image of every sampled point of geodesic i with respect to base_points[i].
# repeat(1, P).reshape(P*G, 2) repeats each base point P times consecutively, which
# matches the (geodesic, point) row order of random_geodesicts2plot.reshape(-1, 2).
random_geodesics2plot_logmap = model.logmap(base_points.repeat(1,num_approximation_points).reshape(num_approximation_points*num_geodesics,2),random_geodesicts2plot.reshape(num_approximation_points*num_geodesics,2))
random_geodesics2plot_logmap = random_geodesics2plot_logmap.reshape(num_geodesics,num_approximation_points,-1)

In [None]:
# Sanity check: log map of each base point at itself (expected to be close to zero).
model.logmap(base_points,base_points)

In [None]:
random_geodesic_lengths = model.curve_length(random_geodesicts(t))
print(f"geodesic lengths: {random_geodesic_lengths}")

In [None]:
num_points = 10 # number of points on the straight segment that replaces each geodesic's log-map image

# Straight segments (num_points, num_geodesics, 2) between the log-map images of each
# geodesic's first and last sampled points.
vectors = random_geodesics2plot_logmap[:,-1,:]-random_geodesics2plot_logmap[:,0,:]
t = torch.linspace(0,1,num_points)
straight_lines = random_geodesics2plot_logmap[:,0,:] + torch.tensordot(t.unsqueeze(0),vectors,dims=0).reshape(num_points,num_geodesics,2)

In [None]:
plt.title("Geodesics and straight lines in logmap")
for i in range(num_geodesics):
    plt.plot(random_geodesics2plot_logmap[i,:,0],random_geodesics2plot_logmap[i,:,1],c="orange")
    plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
# Shoot from each geodesic's own base point with the segment's tangent vectors.
# repeat(num_points, 1) tiles the base points so that row p*num_geodesics + g belongs
# to geodesic g, matching the row order of straight_lines.reshape(-1, 2).
exp_base_point_x_i,_ = rungekutta_vect(f=geod_vect,initial_point_array=base_points.repeat(num_points,1).reshape(num_points*num_geodesics,2),
                    initial_speed_array=straight_lines.reshape(-1,2), t=time_array)
y = exp_base_point_x_i[-1].reshape(num_points,num_geodesics, 2)
y = y.detach()

In [None]:
# Geodesic pieces between consecutive mapped points, sampled as
# (num_points-1, num_geodesics, num_points, 2).
geodesics_y_i,_ = model.connecting_geodesic(y[:-1].reshape(-1,2), y[1:].reshape(-1,2))
geodesics_y_i2plot = geodesics_y_i(t).reshape((num_points-1),num_geodesics,num_points,2).detach()

In [None]:
plt.title("Images of straight lines in logmap through exp maps with appropriate basepoints ")
for i in range(num_geodesics):
    plt.plot(y[:,i,0],y[:,i,1])
    plt.scatter(y[:,i,0],y[:,i,1])
    #plt.plot(geodesics_y_i2plot[:,i,:,0],geodesics_y_i2plot[:,i,:,1],c="black")
    #plt.plot(straight_lines[:,i,0],straight_lines[:,i,1],c="black")

In [None]:
# Approximate length of each mapped segment: sum of the lengths of its num_points-1 pieces.
log_straight_lines_length_approx = model.curve_length(geodesics_y_i(t)).reshape(num_points-1,num_geodesics).sum(dim = 0)
print(f"straight_lines_length_approx: {log_straight_lines_length_approx}")

In [None]:
# Per-geodesic ratio: geodesic length / approximate length of the mapped straight segment.
random_geodesic_lengths/log_straight_lines_length_approx

In [None]:
random_geod_length_ratio = (random_geodesic_lengths/log_straight_lines_length_approx).mean().item()
print(f"geodesic length ratio:\n{random_geod_length_ratio}")

In [None]:
# Left: random geodesics and their base points. Right: their log-map images with respect
# to the corresponding base points (same colour per geodesic). Saved as a PDF.
fig,(ax1,ax2) = plt.subplots(ncols=2,figsize=(12,6))
# rf-string: "\l" and "\m" are invalid escape sequences in a plain string literal
# (SyntaxWarning since Python 3.12). The rendered title is unchanged.
fig.suptitle(rf"Experiment # {exp_number} with $\lambda_{{\mathrm{{curv}}}}={curv_w}.$")
ax1.set_title("Random geodesics and basepoints")
ax2.set_title("Images of these geodesics through logmaps \n w.r.t. corresponding base points")
for i in range(num_geodesics):
    p = ax1.plot(random_geodesicts2plot[i,:,0],random_geodesicts2plot[i,:,1])
    automatic_color = p[-1].get_color()
    ax1.scatter(base_points[i,0],base_points[i,1],c = automatic_color)
    ax2.scatter(random_geodesics2plot_logmap[i,:,0],random_geodesics2plot_logmap[i,:,1],c = automatic_color)
    ax2.plot(random_geodesics2plot_logmap[i,:,0],random_geodesics2plot_logmap[i,:,1],c = automatic_color)
fig.text(0.1,0,f"Geodesic length ratio:{random_geod_length_ratio:.4f}")
plt.savefig(f'{Path_pictures}/random_geodesics_exp{exp_number}.pdf',bbox_inches='tight',format='pdf')
plt.show()

In [None]:
# The accuracy here has to be less than the threshold used in the algorithm.
# Named so that it does not shadow the builtin `dict`.
random_geod_length_ratio_record = {"geod_length_ratio":random_geod_length_ratio}

import json
with open(f'{Path_pictures}/random_geodesic_length_ratio_exp{exp_number}.json', 'w', encoding="utf-8") as json_file:
    json.dump(random_geod_length_ratio_record, json_file, indent=4)

In [None]:
from pypdf import PdfWriter
# Merge the three figures saved above into a single PDF report.
build_report = True
if build_report:
    pdfs = [f'{Path_pictures}/multiple_geodesics_exp{exp_number}.pdf',f'{Path_pictures}/geodesic_grid_logmap_exp{exp_number}.pdf',f"{Path_pictures}/random_geodesics_exp{exp_number}.pdf"]

    merger = PdfWriter()
    # Close the writer even if one of the input PDFs is missing or unreadable.
    try:
        for pdf in pdfs:
            merger.append(pdf)

        merger.write(f"{Path_pictures}/report_exp_{exp_number}.pdf")
    finally:
        merger.close()

# Fréchet mean

In [None]:
# A small random cluster of points in [-pi/2, pi/2)^2.
num_cluster_points = 3
cluster = torch.pi*(torch.rand(num_cluster_points,2)-0.5)
plt.scatter(cluster[:,0],cluster[:,1])
plt.show()

In [None]:
# Incremental (inductive) mean: step from the current estimate towards the i-th point
# along the connecting geodesic by the fraction 1/(i+1). In flat space this reproduces the
# arithmetic mean exactly. On a curved manifold a single pass is in general only an
# approximation of the Fréchet mean.
frechet_mean = cluster[0]
for i in range(1,num_cluster_points):
    geodesic,_ = model.connecting_geodesic(frechet_mean, cluster[i])
    frechet_mean = geodesic(torch.tensor([1 / (i + 1)]))
frechet_mean = frechet_mean.detach()
print("frechet_mean:", frechet_mean)

In [None]:
# Drop the extra dimension added by evaluating the curve at a single parameter value.
frechet_mean = frechet_mean.squeeze()
plt.scatter(cluster[:,0],cluster[:,1])
plt.scatter(frechet_mean[0],frechet_mean[1], c = "red",marker = "*",s=200)
plt.show()