# This notebook builds a PDF report for a pretrained autoencoder (AE) with a torus latent space

In [None]:
# prerequisites
%matplotlib inline
import sklearn
from sklearn import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import math
import numpy as np
from tqdm.notebook import tqdm
import os

# Hyperparameters

In [None]:
experiment_number = 0
experiment_name = "MNIST_torus_AE"
violent_saving = True # if False, plots are not saved to disk
build_report = True   # if True, the saved PDFs are merged into one report at the end
Path_experiments = "/home/alazarev/CodeProjects/Experiments/"
Path_pictures = f"{Path_experiments}{experiment_name}/experiment{experiment_number}"
# makedirs also creates missing parent folders, and with exist_ok=True it is a no-op when
# the folder already exists, so this cell can be re-run without editing it.
os.makedirs(Path_pictures, exist_ok=True)

# Hyperparameters for the synthetic dataset
D = 784       # ambient (input) dimension
d = 2         # latent space dimension
k = 3         # number of 2d planes in dimension D
n = 6*(10**3) # number of points in each plane
shift_class = 0.0
var_class = 1.0
# Inter-class spread: creates a random Gaussian shift proportional to intercl_var.
# It has to be greater than 0.04.
intercl_var = 0.1

# Number of workers in DataLoader
num_workers = 10
# Swiss roll dataset: noise level and number of points
sr_noise = 1e-6
sr_numpoints = 18000 # = k*n

Z_DIM = d          # dimension of the latent variables
split_ratio = 0.2  # fraction of the dataset held out as the test split

In [None]:
import pdfkit
import json
import html

# Load the hyperparameters saved for this experiment and render them as the first report page.
with open(f'{Path_experiments}json_files/hyperparameters_exp{experiment_number}.json') as json_file:
    hyperparameters = json.load(json_file)
# Machine-specific paths are not part of the report; pop(..., None) tolerates files that lack them.
for path_key in ('Path_pictures', 'Path_weights'):
    hyperparameters.pop(path_key, None)
# pdfkit renders HTML, which collapses whitespace: keep the indented JSON readable inside <pre>
# and escape it so characters such as '<' or '&' in values are shown literally.
hyperparameters_html = f"<pre>{html.escape(json.dumps(hyperparameters, indent=4))}</pre>"
pdfkit.from_string(hyperparameters_html, output_path=f"{Path_pictures}/hyperparameters_exp{experiment_number}.pdf")

In [None]:
# Settings of the training run, read back from its saved hyperparameters.
set_name, batch_size = hyperparameters["set_name"], hyperparameters["batch_size"]

# Dataset loading

In [None]:
import sys
sys.path.append('../') # have to go 1 level up
import ricci_regularization as RR

In [None]:
if set_name == "MNIST":
    # MNIST: the official training split is re-split below into train/test parts;
    # the official test split is kept in test_dataset (used by plot_ae_outputs).
    train_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)
elif set_name == "Synthetic":
    # Generate the dataset via RR.SyntheticDataset
    torch.manual_seed(0) # reproducibility
    my_dataset = RR.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class, intercl_var=intercl_var, var_class=var_class)
    # NOTE(review): `create` is accessed without a call — presumably a property of
    # RR.SyntheticDataset that returns the dataset; confirm.
    train_dataset = my_dataset.create
elif set_name == "Swissroll":
    D = 3
    train_dataset = sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise)
    sr_points = torch.from_numpy(train_dataset[0]).to(torch.float32)
    # second array of make_swiss_roll, stored as the (continuous) label of each point
    sr_colors = torch.from_numpy(train_dataset[1]).to(torch.float32)
    from torch.utils.data import TensorDataset
    train_dataset = TensorDataset(sr_points,sr_colors)
else:
    raise ValueError(f"Unknown set_name {set_name!r}: expected 'MNIST', 'Synthetic' or 'Swissroll'")

m = len(train_dataset)
# Derive the train size from the test size so the two lengths always sum to m;
# int(m - m*r) + int(m*r) can be m - 1 due to float truncation, which random_split rejects.
test_size = int(m*split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - test_size, test_size])

test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Autoencoder structure (torus latent space; the class is named `VAE`, but the model has no sampling step)

In [None]:
class VAE(nn.Module):
    """Autoencoder whose latent space is a torus (despite its name, it does no sampling).

    The encoder maps a flattened input of size ``x_dim`` to ``z_dim`` raw angles. Each angle
    is represented on the circle by its (cos, sin) pair, so the decoder's first layer takes
    ``2*z_dim`` inputs. Hidden and output activations are ``torch.sin``.

    Args:
        x_dim: flattened input/output dimension.
        h_dim1: width of the first encoder layer and of the last decoder hidden layer.
        h_dim2: width of the second encoder layer and of the first decoder hidden layer.
        z_dim: number of latent angles (dimension of the torus).
    """
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()
        # Stored so the methods do not depend on notebook globals (D, Z_DIM).
        # Plain ints are neither parameters nor buffers, so state_dict keys are unchanged.
        self.x_dim = x_dim
        self.z_dim = z_dim
        # sin is the layer activation; (cos, sin) are the circle charts of each latent angle
        self.non_linearity = torch.sin
        self.non_linearity2 = torch.cos
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc3 = nn.Linear(h_dim2, z_dim)
        # decoder part
        # Double dimension as the circle is mimicked using cos and sin charts
        self.fc4 = nn.Linear(2*z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def _angles_to_circle(self, h):
        """Concatenate [cos(h), sin(h)] along dim 1 (dim 0 is the batch dimension)."""
        return torch.cat((self.non_linearity2(h), self.non_linearity(h)), 1)

    def encoder_torus(self, x):
        """Raw latent angles (output of fc3), not reduced to (-pi, pi].

        The original author warned against using these values directly as latent
        coordinates; use encoder2lifting for angles in (-pi, pi].
        """
        h = self.non_linearity(self.fc1(x))
        h = self.non_linearity(self.fc2(h))
        return self.fc3(h)

    def encoder(self, x):
        """Latent point on the torus as [cos(angles), sin(angles)], shape (batch, 2*z_dim)."""
        return self._angles_to_circle(self.encoder_torus(x))

    def encoder2lifting(self, x):
        """Latent angles lifted to (-pi, pi], shape (batch, z_dim).

        atan2(sin, cos) recovers the angle from its circle charts. Unlike
        acos(cos)*sign(sin), it returns pi (not 0) when sin == 0 and cos < 0, and it has
        finite gradients where cos == +-1.
        """
        h = self.encoder_torus(x)
        return torch.atan2(self.non_linearity(h), self.non_linearity2(h))

    def decoder(self, z):
        """Decode circle-chart coordinates of shape (batch, 2*z_dim) to (batch, x_dim)."""
        h = self.non_linearity(self.fc4(z))
        h = self.non_linearity(self.fc5(h))
        return self.non_linearity(self.fc6(h))

    def decoder_torus(self, z):
        """Decode latent angles of shape (batch, z_dim) to (batch, x_dim)."""
        return self.decoder(self._angles_to_circle(z))

    def forward(self, x):
        """Return (reconstruction, circle-chart latent code) for a batch of inputs."""
        z = self.encoder(x.view(-1, self.x_dim))
        return self.decoder(z), z

# Layer widths must match the architecture of the saved weights that are loaded next.
vae = VAE(x_dim=D, h_dim1= 512, h_dim2=256, z_dim=Z_DIM)
# The model is intentionally left on the CPU. The data loaders yield CPU tensors, and the later
# cells pass the results straight to matplotlib/numpy, which need CPU tensors. Calling
# vae.cuda() here made `vae(data)` fail with a device mismatch on GPU machines.

### Loading the saved weights

In [None]:
PATH_vae = f'../nn_weights/exp{experiment_number}.pt'
# map_location="cpu" lets weights saved from a GPU run load on a CPU-only machine;
# load_state_dict then copies them onto the model's own device.
vae.load_state_dict(torch.load(PATH_vae, map_location="cpu"))
vae.eval()

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Args:
        N: number of discrete colors.
        base_cmap: colormap name, Colormap instance, or None (matplotlib's default colormap).

    Returns:
        A LinearSegmentedColormap named ``<base name><N>`` with N colors sampled evenly
        from ``base_cmap``.
    """
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a name, a Colormap instance or None;
    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in matplotlib 3.9.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Call from_list on LinearSegmentedColormap explicitly: ListedColormap bases
    # (e.g. 'viridis') have no from_list method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)

# Torus latent space

In [None]:
# Reference only: the string below is never executed. It keeps the standalone helper
# that VAE.encoder2lifting was derived from. That helper recovers the angles phi from
# a latent tensor laid out as [cos(phi), sin(phi)] along dim 1.
"""
#inspiration for vae.encoder2lifting
def circle2anglevectorized(zLatentTensor,Z_DIM = Z_DIM):
    cosphi = zLatentTensor[:, 0:Z_DIM]
    sinphi = zLatentTensor[:, Z_DIM:2*Z_DIM]
    phi = torch.acos(cosphi)*torch.sgn(torch.asin(sinphi))
    return phi
"""

In [None]:
# Number of discrete classes, used for the discrete colormaps (not set for "Swissroll").
classes_per_set = {"Synthetic": k, "MNIST": 10}
if set_name in classes_per_set:
    N = classes_per_set[set_name]

In [None]:
# Encode the training split batch by batch, keeping the inputs, their reconstructions,
# the lifted latent angles and the labels.
colorlist = []
enc_list = []
input_dataset_list = []
recon_dataset_list = []
for data, labels in tqdm(train_loader, position=0):
    input_dataset_list.append(data)
    # Reconstructions only feed the detached reconstruction error, so no autograd graph is
    # needed. Keeping one graph per batch would hold every intermediate activation of the
    # whole training set in memory.
    with torch.no_grad():
        recon_dataset_list.append(vae(data)[0])
    enc_list.append(vae.encoder2lifting(data.view(-1, D)))
    colorlist.append(labels)

In [None]:
# Stack the per-batch results into whole-dataset tensors.
input_dataset, recon_dataset, encoded_points = (
    torch.cat(input_dataset_list),
    torch.cat(recon_dataset_list),
    torch.cat(enc_list),
)
encoded_points_no_grad = encoded_points.detach()  # graph-free copy for plotting
color_array = torch.cat(colorlist).detach()

In [None]:
# Lifted latent angles of the training split, colored by class label.
fig, ax = plt.subplots(figsize=(8, 6))
latent_scatter = ax.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                            c=color_array, marker='o', edgecolor='none',
                            cmap=discrete_cmap(N, 'jet'))
fig.colorbar(latent_scatter, ax=ax, ticks=range(N))
ax.set(title="AE latent space (lifted angles)", xlabel="latent angle 1", ylabel="latent angle 2")
ax.grid(True)
# Respect the global switch: violent_saving=False must not write plots to disk.
if violent_saving:
    fig.savefig(f"{Path_pictures}/latent_space.pdf", format="pdf")
plt.show()

## Reconstruction loss

In [None]:
# Per-point mean squared reconstruction error: ||x - x_rec||^2 / D for each sample.
# The residual is signed (not an absolute error); detach it so no autograd graph is built.
residual_tensor = (input_dataset.view(-1, D) - recon_dataset).detach()
mse_array = residual_tensor.pow(2).mean(dim=1)

# Curvature and metric computation

In [None]:
# Scalar curvature and metric tensor G of the torus decoder at each encoded point,
# plus det(G) and tr(G) of every metric matrix.
decoder_fn = vae.decoder_torus
curvature_array = RR.Sc_jacfwd_vmap(encoded_points, function=decoder_fn).detach()
metric_array = RR.metric_jacfwd_vmap(encoded_points, function=decoder_fn).detach()
det_array = metric_array.det()
trace_array = torch.einsum('jii->j', metric_array)

In [None]:
# Bounding box of the encoded points; it places the evaluation grid and the plot limits.
latent = encoded_points_no_grad
left, right = latent[:, 0].min(), latent[:, 0].max()
bottom, top = latent[:, 1].min(), latent[:, 1].max()

xsize, ysize = right - left, top - bottom
xcenter, ycenter = 0.5*(left + right), 0.5*(bottom + top)

In [None]:
linsize = 200  # number of grid points per axis (linsize x linsize grid)

import torch.func as TF
# Regular grid covering the bounding box of the encoded points
grid_on_ls = RR.make_grid(linsize, xsize=xsize, ysize=ysize,
                          xcenter=xcenter, ycenter=ycenter)
# Metric tensor G of the torus decoder at each grid point, with its determinant and trace
metric_on_grid = RR.metric_jacfwd_vmap(grid_on_ls, function=vae.decoder_torus)
metric_det_on_grid = metric_on_grid.det()
metric_trace_on_grid = TF.vmap(torch.trace)(metric_on_grid)
# Scalar curvature at each grid point
curv_on_the_grid = RR.Sc_jacfwd_vmap(grid_on_ls, function = vae.decoder_torus)

# Reconstruction quality

In [None]:
def plot_ae_outputs(encoder,decoder,n=10,dataset=None):
    """Plot the first sample of each class 0..n-1 (top row) and its reconstruction (bottom row).

    Args:
        encoder: callable mapping a (1, num_features) batch to latent codes.
        decoder: callable mapping latent codes back to a (1, num_features) batch.
        n: number of classes / columns; every label 0..n-1 must occur in ``dataset``.
        dataset: dataset with a ``targets`` tensor whose items are (image, label) pairs.
            Defaults to the global ``test_dataset``.
    """
    if dataset is None:
        dataset = test_dataset
    targets = dataset.targets.numpy()
    # index of the first sample of each class
    t_idx = {i: np.where(targets == i)[0][0] for i in range(n)}
    fig, axes = plt.subplots(2, n, figsize=(16, 4.5), squeeze=False)
    for i in range(n):
        img = dataset[t_idx[i]][0].unsqueeze(0)
        with torch.no_grad():
            # Flatten to (1, num_features) for the encoder, then restore the image's own shape
            # instead of a hard-coded 28x28.
            rec_img = decoder(encoder(img.reshape(1, -1))).reshape(img.shape)
        for ax, picture in ((axes[0, i], img), (axes[1, i], rec_img)):
            ax.imshow(picture.cpu().squeeze().numpy(), cmap='gist_gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    axes[0, n//2].set_title('Original images')
    axes[1, n//2].set_title('Reconstructed images')
    plt.show()

In [None]:
plot_ae_outputs(encoder=vae.encoder2lifting, decoder=vae.decoder_torus)

In [None]:
import matplotlib
from matplotlib import ticker

plt.rcParams.update({'font.size': 16})

size_of_points = 20
fig, (ax00, ax0) = plt.subplots(ncols=2, nrows=1, figsize=(15,6), dpi=300)
fig.tight_layout(pad=2.0)

# Left panel: latent space colored by the dataset's own labels.
ax00.title.set_text("AE latent space")
if set_name in ("Synthetic", "MNIST"):
    p00 = ax00.scatter(encoded_points_no_grad[:,0], encoded_points_no_grad[:,1], c=color_array,
                       alpha=0.5, s=size_of_points, marker='o', edgecolor='none',
                       cmap=discrete_cmap(N, "jet"))
    fig.colorbar(p00, ax=ax00, label="initial color", ticks=np.arange(N))
else:
    # Continuous labels (e.g. Swissroll). Use the detached points (matplotlib cannot convert
    # tensors that require grad) and the full color_array (`labels` only holds the last batch).
    p00 = ax00.scatter(encoded_points_no_grad[:,0], encoded_points_no_grad[:,1], c=color_array,
                       alpha=0.5, s=size_of_points, marker='o', edgecolor='none', cmap='jet')
    fig.colorbar(p00, ax=ax00, label="initial color")

# Right panel: per-point reconstruction error (mse_array = squared L2 error divided by D).
ax0.title.set_text("Reconstruction loss")
p0 = ax0.scatter(encoded_points_no_grad[:,0], encoded_points_no_grad[:,1], c=mse_array,
                 alpha=0.5, s=size_of_points, marker='o', edgecolor='none', cmap='jet')
cb = fig.colorbar(p0, ax=ax0, label="MSE per point")
tick_locator = ticker.MaxNLocator(nbins=10)
cb.locator = tick_locator
cb.update_ticks()

if violent_saving:
    fig.savefig(f'{Path_pictures}/init_colors_recon_loss.pdf', bbox_inches='tight', format='pdf')
plt.show()

In [None]:
# Sandbox (random data, not part of the report): how a LogNorm colorbar looks
# when its ticks come from MaxNLocator.
xs = torch.rand(5)
ys = torch.rand(5)
point_colors = torch.rand(5)
p = plt.scatter(xs, ys, c=point_colors, norm=matplotlib.colors.LogNorm())
cb = plt.colorbar(p)
tick_locator = ticker.MaxNLocator(nbins=5)
cb.locator = tick_locator
cb.update_ticks()
plt.show()

In [None]:
# Distribution of the scalar curvature at the encoded training points.
fig, ax = plt.subplots()
ax.hist(curvature_array, bins=60)
ax.set(title="Scalar curvature at encoded points", xlabel="scalar curvature", ylabel="count")
plt.show()

In [None]:
# Distribution of the scalar curvature over the whole latent grid.
fig, ax = plt.subplots()
ax.hist(curv_on_the_grid.detach(), bins=200)
ax.set(title="Scalar curvature on the latent grid", xlabel="scalar curvature", ylabel="count")
plt.show()

In [None]:
# Heatmaps of curvature and metric quantities on the latent grid; each panel has its own
# (unshared) color scale. Tick labels map grid indices back to latent coordinates,
# optionally offset by xshift / yshift.
xshift = 0.0
yshift = 0.0
numticks = 5
tick_decimals = 2 if set_name == "Synthetic" else 1
plt.rcParams.update({'font.size': 16})

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12), dpi=300)
fig.tight_layout(pad=2.0)

xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks)
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)
xtick_labels = [f"{elem:.{tick_decimals}f}" for elem in (xticks + xshift).tolist()]
ytick_labels = [f"{elem:.{tick_decimals}f}" for elem in (yticks + yshift).tolist()]
# positions of those ticks in the pixel (grid index) coordinates of the linsize x linsize images
ticks_places = np.linspace(0, 1, numticks)*(linsize-1)

curvature_image = curv_on_the_grid.detach().reshape(linsize, linsize)

im1 = ax1.imshow(abs(curvature_image), origin="lower", cmap="jet",
                 norm=matplotlib.colors.LogNorm())
fig.colorbar(im1, ax=ax1, shrink=1, label="curvature abs value")
ax1.set_title("Absolute value of scalar curvature")

im2 = ax2.imshow(curvature_image, origin="lower", cmap="jet",
                 norm=matplotlib.colors.SymLogNorm(linthresh=abs(0.01*curv_on_the_grid.mean()).item()))
fig.colorbar(im2, ax=ax2, shrink=1, label="curvature")
ax2.set_title("Scalar curvature")

# Raw strings: "\s" and "\c" are invalid escape sequences (SyntaxWarning on Python 3.12+).
im3 = ax3.imshow(torch.sqrt(metric_det_on_grid).detach().reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im3, ax=ax3, shrink=1, label=r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow((0.5*metric_trace_on_grid).detach().reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im4, ax=ax4, shrink=1, label=r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

axs = (ax1, ax2, ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    fig.savefig(f'{Path_pictures}/heatmaps_not_scaled.pdf', bbox_inches='tight', format='pdf')
plt.show()

### Scalar curvature

In [None]:
# Color-scale limits shared by the curvature scatters and heatmaps below.
curv_abs_on_grid = abs(curv_on_the_grid)

min_curvature = curv_on_the_grid.min().item()
max_curvature = curv_on_the_grid.max().item()
# linear range of the symmetric-log scale: 1% of the magnitude of the mean curvature
linthresh_curvature = 0.01*abs(curv_on_the_grid.mean()).item()

max_abs_curvature = curv_abs_on_grid.max().item()
# lower limit of the log scale for |curvature|: 1% of the mean |curvature|
min_abs_curvature = 0.01*curv_abs_on_grid.mean().item()

In [None]:

# Curvature: scatter at the encoded points (left column) next to the heatmap on the whole
# latent grid (right column). Both use the same color limits, so the columns are comparable.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12), dpi=300)
fig.tight_layout(pad=2.0)

xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks)
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)
xtick_labels = [f"{elem:.{tick_decimals}f}" for elem in (xticks + xshift).tolist()]
ytick_labels = [f"{elem:.{tick_decimals}f}" for elem in (yticks + yshift).tolist()]
# tick positions in the pixel (grid index) coordinates of the heatmaps
ticks_places = np.linspace(0, 1, numticks)*(linsize-1)

ax1.set_title("Absolute value of scalar curvature")
p1 = ax1.scatter(latent[:,0], latent[:,1], c=abs(curvature_array),
                 alpha=1, s=size_of_points, marker='o',
                 edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature,
                                                vmax=max_abs_curvature))
# Explicit ax=: older matplotlib attaches an ax-less colorbar to the current (last) axes.
fig.colorbar(p1, ax=ax1, label="curvature abs value")

ax2.set_title("Absolute value of scalar curvature overall")
im1 = ax2.imshow(abs(curv_on_the_grid.detach().reshape(linsize,linsize)),
                 origin="lower", cmap="jet",
                 norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature,
                                                vmax=max_abs_curvature))
fig.colorbar(im1, ax=ax2, shrink=1, label="curvature abs value")

ax3.set_title("Scalar curvature")
p2 = ax3.scatter(latent[:,0], latent[:,1], c=curvature_array,
                 alpha=1, s=size_of_points, marker='o',
                 edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin=min_curvature,
                                                   vmax=max_curvature))
fig.colorbar(p2, ax=ax3, label="curvature")

ax4.set_title("Scalar curvature overall")
im2 = ax4.imshow(curv_on_the_grid.detach().reshape(linsize,linsize),
                 origin="lower", cmap="jet",
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin=min_curvature,
                                                   vmax=max_curvature))
fig.colorbar(im2, ax=ax4, shrink=1, label="curvature")

# Scatter panels live in latent coordinates: clip to the data box, ticks at the label values.
axs = (ax1, ax3)
for ax in axs:
    ax.set_ylim(bottom, top)
    ax.set_xlim(left, right)
    ax.set_xticks(list(map(float, xtick_labels)), labels=xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels=ytick_labels)

# Heatmap panels live in pixel coordinates: place the latent-coordinate labels at grid indices.
axs = (ax2, ax4)
for ax in axs:
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    fig.savefig(f'{Path_pictures}/curvature_heatmaps.pdf', bbox_inches='tight', format='pdf')
plt.show()

In [None]:
# Metric quantities: sqrt(det G) (top row) and 0.5*tr(G) (bottom row). Each row shows a scatter
# at the encoded points (left) and a heatmap on the whole latent grid (right). The scatters use
# the grid maxima as vmax, so each row shares its upper color limit.
fig, ((ax1, ax3), (ax2, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12), dpi=300)
fig.tight_layout(pad=2.0)

# abs() as for det_array below, so both panels show the same quantity and sqrt never receives
# a negative determinant (which would plot as NaN).
sqrt_det_on_grid = torch.sqrt(abs(metric_det_on_grid)).detach()
half_trace_on_grid = (0.5*metric_trace_on_grid).detach()

# Raw strings: "\s" and "\c" are invalid escape sequences (SyntaxWarning on Python 3.12+).
ax1.set_title(r"$\sqrt{det(G)}$")
p = ax1.scatter(latent[:,0], latent[:,1],
                c=torch.sqrt(abs(det_array)), alpha=1, s=size_of_points,
                marker='o', edgecolor='none', cmap='jet',
                vmax=sqrt_det_on_grid.max().item())
fig.colorbar(p, ax=ax1, label=r"$\sqrt{det(G)}$")

ax2.set_title(r"0.5$\cdot$tr(G)")
q = ax2.scatter(latent[:,0], latent[:,1],
                c=0.5*trace_array, alpha=1, s=size_of_points,
                marker='o', edgecolor='none', cmap='jet',
                vmax=half_trace_on_grid.max().item())
fig.colorbar(q, ax=ax2, label=r"0.5$\cdot$tr(G)")

im3 = ax3.imshow(sqrt_det_on_grid.reshape(linsize,linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im3, ax=ax3, shrink=1, label=r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow(half_trace_on_grid.reshape(linsize,linsize),
                 origin="lower", cmap="jet", norm=None,
                 vmax=half_trace_on_grid.max().item())
fig.colorbar(im4, ax=ax4, shrink=1, label=r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

# Heatmaps: latent-coordinate labels placed at grid indices.
axs = (ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

# Scatters: clip to the data box, ticks at the label values.
axs = (ax1, ax2)
for ax in axs:
    ax.set_ylim(bottom, top)
    ax.set_xlim(left, right)
    ax.set_xticks(list(map(float, xtick_labels)), labels=xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels=ytick_labels)

if violent_saving:
    fig.savefig(f'{Path_pictures}/metric_det_trace.pdf', bbox_inches='tight', format='pdf')
plt.show()

### Merge the saved PDFs into one report

In [None]:
# PdfMerger was deprecated in pypdf 3.x and removed in pypdf 5.0. PdfWriter provides the
# append / write / close methods used below, so fall back to it under the old name.
try:
    from pypdf import PdfMerger
except ImportError:
    from pypdf import PdfWriter as PdfMerger

In [None]:
if build_report:
    import warnings

    # Report sections, in order. losses_exp*.pdf is not produced in this notebook
    # (presumably by the training run — TODO confirm).
    pdfs = [f"{Path_pictures}/hyperparameters_exp{experiment_number}.pdf",
            f'{Path_pictures}/losses_exp{experiment_number}.pdf',
            f'{Path_pictures}/init_colors_recon_loss.pdf',
            f'{Path_pictures}/curvature_heatmaps.pdf',
            f'{Path_pictures}/metric_det_trace.pdf']

    merger = PdfMerger()
    try:
        for pdf in pdfs:
            if os.path.exists(pdf):
                merger.append(pdf)
            else:
                # e.g. violent_saving=False skips saving the plots: report the gap instead of
                # aborting the whole merge.
                warnings.warn(f"Report section not found, skipping: {pdf}")
        merger.write(f"{Path_pictures}/report_{experiment_name}_exp_{experiment_number}.pdf")
    finally:
        # always release the open input files, even if append/write fails
        merger.close()
