# This notebook builds a report for a pretrained torus autoencoder (AE)

In [None]:
# prerequisites
%matplotlib inline
import sklearn
from sklearn import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import math
import numpy as np
from tqdm.notebook import tqdm
import os

# Hyperparameters

In [None]:
# Experiment configuration: choose which JSON file drives this report.
#experiment_json = f'../experiments/Swissroll_torus_AEexp91.json'
experiment_json = f'../experiments/MNIST01_torus_AEexp8.json'

# Output switches
violent_saving = True  # if False it will not save plots
build_report = True    # if True, the last cell merges the saved PDFs into one report

# Load the configuration and echo it for the record
import json

with open(experiment_json) as config_file:
    json_config = json.load(config_file)
print(json.dumps(json_config, indent=2))

# Entries used throughout the notebook for naming and output locations
Path_experiments, experiment_name, experiment_number, Path_pictures = (
    json_config[key]
    for key in ("Path_experiments", "experiment_name",
                "experiment_number", "Path_pictures")
)

# # Number of workers in DataLoader
# num_workers = 10

In [None]:
# Export only the report-relevant sections of the configuration:
# they are written to dummy_config.json, then rendered to a PDF page by pdfkit.
import pdfkit

keys2print = ['experiment_name', 'experiment_number', 'dataset',
              'architecture', 'optimization_parameters', 'losses',
              'OOD_parameters', 'training_results_on_test_data']
json_config2print = dict((key, json_config[key]) for key in keys2print)
print(json_config2print)

dummy_config_path = f'{Path_experiments}/dummy_config.json'
with open(dummy_config_path, 'w') as dump_file:
    json.dump(json_config2print, dump_file, indent=4)

#pdfkit.from_string(json.dumps(json_config2print),output_path=f"{Path_pictures}/hyperparameters_exp{experiment_number}.pdf")
pdfkit.from_file(dummy_config_path,
                 output_path=f"{Path_pictures}/hyperparameters_exp{experiment_number}.pdf")

In [None]:
# Dataset name plus the optimisation settings needed to build the data loaders.
set_name = json_config["dataset"]["name"]
_optimization_parameters = json_config["optimization_parameters"]
split_ratio, batch_size = (
    _optimization_parameters[key] for key in ("split_ratio", "batch_size")
)

# Dataset loading

In [None]:
# import sys
# sys.path.append('../') # have to go 1 level up
import ricci_regularization

In [None]:
# Build the training data for the configured dataset. Each branch sets `D`
# (the flattened input dimension) and `train_dataset`; the MNIST branches also
# load `test_dataset`, which plot_ae_outputs uses later.
if set_name == "MNIST":
    #MNIST_SIZE = 28
    # MNIST Dataset
    D = 784
    train_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)
elif set_name == "MNIST01":
    D = 784
    full_mnist_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)

    # Keep only training samples whose label is in `selected_labels`.
    # The mask starts all-False because no target equals -1.
    mask = (full_mnist_dataset.targets == -1)
    selected_labels = json_config["dataset"]["selected_labels"]
    for label in selected_labels:
        mask = mask | (full_mnist_dataset.targets == label)
    indices01 = torch.where(mask)[0]

    from torch.utils.data import Subset
    train_dataset = Subset(full_mnist_dataset, indices01) # MNIST restricted to the selected labels

elif set_name == "Synthetic":
    k = json_config["dataset"]["parameters"]["k"]
    n = json_config["dataset"]["parameters"]["n"]
    d = json_config["architecture"]["latent_dim"]
    D = json_config["architecture"]["input_dim"]
    shift_class = json_config["dataset"]["parameters"]["shift_class"]
    intercl_var = json_config["dataset"]["parameters"]["intercl_var"]
    var_class = json_config["dataset"]["parameters"]["var_class"]
    # Generate dataset
    # via classes
    torch.manual_seed(0) # reproducibility
    my_dataset = ricci_regularization.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class, intercl_var=intercl_var, var_class=var_class)

    # NOTE(review): `create` is accessed without a call; presumably it is an
    # attribute/property of SyntheticDataset holding the dataset — confirm.
    train_dataset = my_dataset.create
elif set_name == "Swissroll":
    sr_noise = json_config["dataset"]["parameters"]["sr_noise"]
    sr_numpoints = json_config["dataset"]["parameters"]["sr_numpoints"]

    D = 3
    train_dataset =  sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise)
    sr_points = torch.from_numpy(train_dataset[0]).to(torch.float32)
    #sr_points = torch.cat((sr_points,torch.zeros(sr_numpoints,D-3)),dim=1)
    sr_colors = torch.from_numpy(train_dataset[1]).to(torch.float32)
    from torch.utils.data import TensorDataset
    train_dataset = TensorDataset(sr_points,sr_colors)
else:
    # Without this, `train_dataset`/`D` would be undefined and the failure
    # would surface later as a confusing NameError.
    raise ValueError(
        f"Unknown dataset name {set_name!r}; expected one of "
        "'MNIST', 'MNIST01', 'Synthetic', 'Swissroll'"
    )

# Random train/test split of the (possibly filtered) training data.
m = len(train_dataset)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m-int(m*split_ratio), int(m*split_ratio)])

test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Torus AE structure

In [None]:
# Instantiate the autoencoder named in the config (hidden sizes are fixed here).
latent_dim = json_config["architecture"]["latent_dim"]
input_dim  = json_config["architecture"]["input_dim"]
architecture_type = json_config["architecture"]["name"]

if architecture_type == "TorusAE":
    torus_ae   = ricci_regularization.Architectures.TorusAE(x_dim=input_dim, h_dim1= 512, h_dim2=256, z_dim=latent_dim)
elif architecture_type == "TorusConvAE":
    torus_ae   = ricci_regularization.Architectures.TorusConvAE(x_dim=input_dim, h_dim1= 512, h_dim2=256, z_dim=latent_dim,pixels=28)
else:
    # Fail here rather than with a NameError on `torus_ae` below.
    raise ValueError(
        f"Unknown architecture {architecture_type!r}; expected 'TorusAE' or 'TorusConvAE'"
    )
if torch.cuda.is_available():
    torus_ae.cuda()

### Loading the saved weights

In [None]:
# Load the trained weights whose path is recorded in the experiment config.
# NO! Use the path ../experiments/<Your experiment>/nn_weights/
PATH_ae_wights = json_config["weights_saved_at"]
# map_location="cpu" lets a checkpoint saved on a GPU be read on a CPU-only
# machine; load_state_dict then copies the tensors onto torus_ae's own device.
torus_ae.load_state_dict(torch.load(PATH_ae_wights, map_location="cpu"))
torus_ae.eval()

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : int
        Number of discrete colors (bins).
    base_cmap : str, matplotlib Colormap or None
        Colormap to sample; None selects matplotlib's default colormap.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        A colormap named ``<base name><N>`` built from N colors sampled
        evenly from ``base_cmap``.
    """
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a name, a Colormap instance or None, and unlike
    # plt.cm.get_cmap it was not removed in matplotlib 3.9.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Call from_list on LinearSegmentedColormap itself: ListedColormap bases
    # (e.g. 'viridis') have no from_list method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)

# Torus latent space

In [None]:
"""
#inspiration for torus_ae.encoder2lifting
def circle2anglevectorized(zLatentTensor,Z_DIM = Z_DIM):
    cosphi = zLatentTensor[:, 0:Z_DIM]
    sinphi = zLatentTensor[:, Z_DIM:2*Z_DIM]
    phi = torch.acos(cosphi)*torch.sgn(torch.asin(sinphi))
    return phi
"""

In [None]:
#Classes
# Number of classes, read from the dataset parameters. It sizes the discrete
# colormaps and colorbar ticks in the latent-space plots.
_dataset_parameters = json_config["dataset"]["parameters"]
N = _dataset_parameters["k"]


In [None]:
# Pass the held-out split through the trained AE on the CPU. For each batch,
# collect the raw inputs, the reconstructions (first output of torus_ae), the
# latent angles from encoder2lifting, and the labels.
torus_ae.cpu()
#zlist = []
colorlist, enc_list = [], []
input_dataset_list, recon_dataset_list = [], []
for batch_inputs, batch_labels in tqdm(test_loader, position=0):
    # (iterate over train_loader instead to inspect the training split)
    input_dataset_list.append(batch_inputs)
    recon_dataset_list.append(torus_ae(batch_inputs)[0])
    enc_list.append(torus_ae.encoder2lifting(batch_inputs.view(-1, D)))
    colorlist.append(batch_labels)

In [None]:
#x = torch.cat(zlist)
#enc = circle2anglevectorized(x).detach()
# Concatenate the per-batch lists collected above into whole-split tensors.
input_dataset, recon_dataset, encoded_points = (
    torch.cat(chunks)
    for chunks in (input_dataset_list, recon_dataset_list, enc_list)
)
# Detached copy for plotting. encoded_points itself is passed unchanged to
# the curvature/metric routines later.
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()
#assert torch.equal(enc,enc_tensor)

In [None]:
# Scatter of the encoded test points in raw latent coordinates, colored by label.
# Swissroll uses a continuous 'jet' colormap. Every other dataset uses an
# N-bin discrete colormap with one colorbar tick per class.
plt.figure(figsize=(8, 6))

if set_name == "Swissroll":
    plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=color_array, marker='o', edgecolor='none', cmap= 'jet')
else:
    plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=color_array, marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
    plt.colorbar(ticks=range(N))
plt.grid(True)
if violent_saving == True:
    # latent_space.pdf is not among the files merged into the final report
    plt.savefig(f"{Path_pictures}/latent_space.pdf",format="pdf")

## Reconstruction loss

In [None]:
# Per-sample reconstruction error on the held-out split.
# NOTE: despite its name, abs_error_tensor holds the signed difference
# (input - reconstruction), not its absolute value.
abs_error_tensor = input_dataset.view(-1,D) - recon_dataset
# Squared L2 norm along dim 1, divided by D: the mean squared error per sample.
mse_array = abs_error_tensor.norm(dim=1).detach()
mse_array = mse_array**2/D
#F.mse_loss(input_dataset.view(-1,D)[0],recon_dataset[0],reduction='mean')

# Metric and curvature computation

In [None]:
# Decoder geometry at the encoded test points. This computes the scalar
# curvature and the metric matrix from metric_jacfwd_vmap (presumably the
# pullback metric of decoder_torus — confirm in ricci_regularization), plus
# the per-point determinant and trace of that metric.
curvature_array = ricci_regularization.Sc_jacfwd_vmap(encoded_points,device = torch.device("cpu"),function=torus_ae.decoder_torus).detach()
metric_array = ricci_regularization.metric_jacfwd_vmap(encoded_points,function=torus_ae.decoder_torus).detach()
det_array = torch.det(metric_array)
trace_array = torch.einsum('jii->j',metric_array)  # trace of each point's metric matrix

In [None]:
# latent \in [-\pi,\pi]. Bounding box of the encoded test points in raw
# latent angles; it parameterises the evaluation grid in the next cell.
latent = encoded_points_no_grad
left, right = latent[:, 0].min(), latent[:, 0].max()
bottom, top = latent[:, 1].min(), latent[:, 1].max()

xsize, ysize = right - left, top - bottom
xcenter, ycenter = 0.5 * (left + right), 0.5 * (bottom + top)

In [None]:
# Evaluate the decoder geometry on a linsize x linsize grid that covers the
# latent bounding box above: metric determinant, metric trace and scalar
# curvature, processed in batches of bs points.
linsize = 200 #200

import torch.func as TF
grid_on_ls = ricci_regularization.make_grid(linsize,xsize=xsize,ysize=ysize,xcenter=xcenter,ycenter=ycenter)

grid_numpoints = grid_on_ls.shape[0]
bs = 4000
metric_det_list = []
metric_trace_list = []
curv_list = []
# Stride through the grid so the last, possibly partial, batch is included.
# range(grid_numpoints//bs) dropped that batch whenever grid_numpoints was not
# a multiple of bs, which broke the later reshape(linsize, linsize).
# (Presumably grid_numpoints == linsize**2 — confirm against make_grid.)
for start in range(0, grid_numpoints, bs):
    batch_of_grid = grid_on_ls[start:start + bs]
    metric_on_batch_of_grid = ricci_regularization.metric_jacfwd_vmap(batch_of_grid,function=torus_ae.decoder_torus)
    metric_det_on_batch_of_grid = torch.det(metric_on_batch_of_grid)
    metric_trace_on_batch_of_grid = TF.vmap(torch.trace)(metric_on_batch_of_grid)
    curv_on_batch_of_grid = ricci_regularization.Sc_jacfwd_vmap(batch_of_grid,device = torch.device("cpu"), function = torus_ae.decoder_torus)
    metric_det_list.append(metric_det_on_batch_of_grid.tolist())
    metric_trace_list.append(metric_trace_on_batch_of_grid.tolist())
    curv_list.append(curv_on_batch_of_grid.tolist())
metric_det_on_grid = np.concatenate(metric_det_list)
metric_trace_on_grid = np.concatenate(metric_trace_list)
curv_on_the_grid = np.concatenate(curv_list)
# Unbatched variant, kept as an inert string literal for reference:
"""
metric_on_grid = ricci_regularization.metric_jacfwd_vmap(grid_on_ls,function=torus_ae.decoder_torus)
metric_det_on_grid = torch.det(metric_on_grid)
metric_trace_on_grid = TF.vmap(torch.trace)(metric_on_grid)
curv_on_the_grid = ricci_regularization.Sc_jacfwd_vmap(grid_on_ls,device = torch.device("cpu"), function = torus_ae.decoder_torus)
"""

In [None]:
# latent \in [-1,1]. Divide the latent angles by pi for plotting; assuming the
# raw angles lie in [-pi, pi], the result lies in [-1, 1]. Then recompute the
# bounding box in these rescaled coordinates.
encoded_points_no_grad = encoded_points_no_grad / math.pi
latent = encoded_points_no_grad
(left, right), (bottom, top) = (
    (latent[:, axis].min(), latent[:, axis].max()) for axis in (0, 1)
)

xsize, ysize = right - left, top - bottom
xcenter, ycenter = 0.5 * (left + right), 0.5 * (bottom + top)

# Recon loss

In [None]:
def plot_ae_outputs(encoder,decoder,n=10):
    """Show the first test image of each class 0..n-1 above its reconstruction.

    Relies on the notebook globals ``test_dataset`` (a dataset exposing a
    ``targets`` tensor) and ``D`` (the flattened image size).

    Parameters
    ----------
    encoder : callable
        Maps a flattened ``(1, D)`` image tensor to a latent code.
    decoder : callable
        Maps the latent code back to ``D`` values, which are reshaped to the
        original image's shape for display.
    n : int, default 10
        Number of classes shown. Every class 0..n-1 must occur in
        ``test_dataset``.
    """
    plt.figure(figsize=(16,4.5))
    targets = test_dataset.targets.numpy()
    # index of the first test sample of each class 0..n-1
    t_idx = {i:np.where(targets==i)[0][0] for i in range(n)}
    for i in range(n):
        # top row: original image
        ax = plt.subplot(2,n,i+1)
        img = test_dataset[t_idx[i]][0].unsqueeze(0)
        with torch.no_grad():
            # Flatten for the encoder, then restore the image's own shape
            # rather than assuming 28x28.
            rec_img = decoder(encoder(img.reshape(1,D))).reshape(img.shape)
        plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2:
            ax.set_title('Original images')
        # bottom row: reconstruction
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2:
            ax.set_title('Reconstructed images')
    plt.show()

In [None]:
# Reconstructions are plotted only for the full MNIST set: plot_ae_outputs
# shows one test image per class 0..9 (its default n=10).
if set_name == "MNIST":
    plot_ae_outputs(torus_ae.encoder2lifting,torus_ae.decoder_torus)

In [None]:
import matplotlib
from matplotlib import ticker

# Side-by-side figure of the (rescaled) latent points. Left: colored by label.
# Right: colored by per-sample reconstruction MSE on a log scale.

plt.rcParams.update({'font.size': 16})

size_of_points = 20
fig, (ax00,ax0)= plt.subplots(ncols=2, nrows=1,figsize=(15,6),dpi=300)
# (ax3,ax4) can  be added

fig.tight_layout(pad=2.0)

ax00.title.set_text("AE latent space")
# Labelled datasets get a discrete colorbar with one bin per class. MNIST01 is
# included, matching the earlier latent-space plot, which uses the discrete
# map for every dataset except Swissroll.
if set_name in ("Synthetic", "MNIST", "MNIST01"):
    p00 = ax00.scatter( encoded_points_no_grad[:,0], encoded_points_no_grad[:,1], c=color_array, alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap=discrete_cmap(N, "jet"))
    fig.colorbar(p00, ax=ax00, label="initial color", ticks=(np.arange(N)))
else:
    p00 = ax00.scatter( encoded_points_no_grad[:,0], encoded_points_no_grad[:,1], c=color_array, alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap='jet')
    fig.colorbar(p00, ax=ax00, label="initial color")
ax00.grid(True)
ax0.title.set_text("Reconstruction loss")
p0 = ax0.scatter( encoded_points_no_grad[:,0], encoded_points_no_grad[:,1], c=mse_array, alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap='jet',norm=matplotlib.colors.LogNorm())
ax0.grid(True)
# Explicit ax=: before matplotlib 3.6, colorbar() without it stole space from
# the current axes rather than from the mappable's axes.
cb = fig.colorbar(p0, ax=ax0, label="squared l2 norm errors")
#tick_locator = ticker.MaxNLocator(nbins=10)
#cb.locator = tick_locator
#cb.update_ticks()

if violent_saving:
    fig.savefig(f'{Path_pictures}/init_colors_recon_loss.pdf',bbox_inches='tight',format='pdf')
plt.show()

In [None]:
# Histogram of scalar curvature at the encoded test points;
# number of bins = ceil(sqrt(number of points)).
plt.hist(curvature_array, bins = math.ceil(math.sqrt(curvature_array.shape[0])))
plt.show()

In [None]:
# Histogram of scalar curvature over the whole latent evaluation grid (200 bins).
plt.hist(curv_on_the_grid, bins = 200)
plt.show()

In [None]:
# Extreme upper quantile of curvature**2 * sqrt(det(metric)) at the test
# points; the value is shown as the cell output.
torch.quantile((curvature_array**2*torch.sqrt(det_array)),.999999999)

In [None]:
# Histogram of curvature**2 * sqrt(det(metric)) at the test points;
# number of bins = ceil(sqrt(number of points)).
plt.hist(curvature_array**2*torch.sqrt(det_array).detach(),bins=math.ceil(math.sqrt(curvature_array.shape[0])))
plt.show()

In [None]:
#xcenter = 0.0 
#ycenter = 0.0
# Heatmaps over the full latent grid with per-panel color scales: |curvature|,
# signed curvature, sqrt(det G) and 0.5*tr(G). Tick labels show the rescaled
# latent coordinates at numticks evenly spaced pixel positions.
xshift = 0.0
yshift = 0.0
numticks = 5
if set_name == "Synthetic":
    tick_decimals = 2
else:
    tick_decimals = 1
plt.rcParams.update({'font.size': 16})

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12),dpi=300)

fig.tight_layout(pad=2.0)

xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks) 
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)

xtick_labels = (xticks+xshift).tolist()
ytick_labels = (yticks+yshift).tolist()

xtick_labels = [ '%.{0}f'.format(tick_decimals) % elem for elem in xtick_labels ]
ytick_labels = [ '%.{0}f'.format(tick_decimals) % elem for elem in ytick_labels]

# Pixel indices (0 .. linsize-1) at which the image ticks are placed.
ticks_places = np.linspace(0, 1, numticks)*(linsize-1)

im1 = ax1.imshow(abs(curv_on_the_grid.reshape(linsize,linsize)),
                 origin="lower",cmap="jet",
                 norm = matplotlib.colors.LogNorm())
fig.colorbar(im1,ax = ax1, shrink = 1, label = "curvature abs value")
ax1.set_title("Absolute value of scalar curvature")

im2 = ax2.imshow(curv_on_the_grid.reshape(linsize,linsize),
                 origin="lower",cmap="jet",
                 norm = matplotlib.colors.SymLogNorm(linthresh=abs(0.01*curv_on_the_grid.mean()).item()))
fig.colorbar(im2,ax = ax2, shrink = 1, label = "curvature")
ax2.set_title("Scalar curvature")

# Raw strings: "\s" and "\c" are invalid escape sequences in ordinary string
# literals (SyntaxWarning since Python 3.12). The rendered text is unchanged.
im3 = ax3.imshow((np.sqrt(metric_det_on_grid)).reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None)
fig.colorbar(im3,ax = ax3, shrink = 1, label = r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow((0.5*(metric_trace_on_grid)).reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None)
fig.colorbar(im4, ax = ax4, shrink = 1, label = r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

axs = (ax1, ax2, ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places,labels = xtick_labels)
    ax.set_yticks(ticks_places,labels = ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/heatmaps_not_scaled.pdf',bbox_inches='tight',format='pdf')
plt.show()

### Scalar curvature

In [None]:
# Shared color-scale limits, so the scatter and grid curvature plots in the
# next cell are comparable. SymLogNorm's linear threshold is 1% of |mean
# curvature| on the grid; LogNorm's lower bound is 1% of the mean |curvature|.
max_curvature = curv_on_the_grid.max().item()
min_curvature = curv_on_the_grid.min().item()
linthresh_curvature = 0.01*abs(curv_on_the_grid.mean()).item()
linthresh_curvature  # bare expression: displays nothing because it is not the cell's last line

max_abs_curvature = abs(curv_on_the_grid).max().item()
min_abs_curvature = 0.01*abs(curv_on_the_grid).mean().item()

In [None]:

# Curvature at the test points (scatter, left column) next to curvature over
# the whole latent grid (imshow, right column). Both use the color limits from
# the previous cell so the panels are comparable.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12),dpi=300)

fig.tight_layout(pad=2.0)

xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks) 
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)

xtick_labels = (xticks+xshift).tolist()
ytick_labels = (yticks+yshift).tolist()

xtick_labels = [ '%.{0}f'.format(tick_decimals) % elem for elem in xtick_labels]
ytick_labels = [ '%.{0}f'.format(tick_decimals) % elem for elem in ytick_labels]

ticks_places = np.linspace(0, 1, numticks)*(linsize-1)


ax1.title.set_text("Absolute value of scalar curvature")
p1 = ax1.scatter( latent[:,0], latent[:,1], c=abs(curvature_array), 
                 alpha=1, s = size_of_points, marker='o', 
                 edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.LogNorm(vmin = min_abs_curvature, 
                                                vmax = max_abs_curvature))
# Explicit ax= attaches each colorbar to its own panel. Before matplotlib 3.6,
# omitting it stole space from the current axes instead.
fig.colorbar(p1, ax = ax1, label="curvature abs value")

ax2.title.set_text("Absolute value of scalar curvature overall")
im1 = ax2.imshow(abs(curv_on_the_grid.reshape(linsize,linsize)),
                 origin="lower",cmap="jet",
                 norm = matplotlib.colors.LogNorm(vmin = min_abs_curvature, 
                                                  vmax = max_abs_curvature))
fig.colorbar(im1,ax = ax2, shrink = 1, label = "curvature abs value")

ax3.title.set_text("Scalar curvature")
p2 = ax3.scatter( latent[:,0], latent[:,1], c=curvature_array, 
                 alpha=1, s = size_of_points, marker='o', 
                 edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin = min_curvature, 
                                                   vmax = max_curvature))
fig.colorbar(p2, ax = ax3, label="curvature")

ax4.title.set_text("Scalar curvature overall")
im2 = ax4.imshow(curv_on_the_grid.reshape(linsize,linsize),
                 origin="lower",cmap="jet",
                 norm = matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin = min_curvature, 
                                                   vmax = max_curvature))
fig.colorbar(im2,ax = ax4, shrink = 1, label = "curvature")

# Scatter panels are in data coordinates: put the ticks at the label values.
axs = (ax1, ax3)
for ax in axs:
    ax.set_ylim(bottom,top)
    ax.set_xlim(left,right)
    ax.set_xticks(list(map(float, xtick_labels)), labels = xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels = ytick_labels)

# Image panels are in pixel indices: put the ticks at ticks_places.
axs = (ax2, ax4)
for ax in axs:
    ax.set_xticks(ticks_places,labels = xtick_labels)
    ax.set_yticks(ticks_places,labels = ytick_labels)
if violent_saving:
    plt.savefig(f'{Path_pictures}/curvature_heatmaps.pdf',bbox_inches='tight',format='pdf')
plt.show()

# Heatmaps of the encoder and decoder Jacobian norms

In [None]:
from torch.func import jacrev,jacfwd

In [None]:
# Encoder metric (via reverse-mode Jacobians) on the first validation_set_size
# held-out inputs only, which limits the cost. Its per-point trace is plotted
# in the next cell.
validation_set_size = 4000
metric_array_encoder = ricci_regularization.metric_jacrev_vmap(input_dataset[:validation_set_size],function=torus_ae.encoder2lifting,latent_space_dim=D).detach()
trace_array_encoder = torch.einsum('jii->j',metric_array_encoder)

In [None]:
# Jacobian-norm maps of the encoder (left, first validation_set_size points)
# and the decoder (right). Assuming each metric is built from the Jacobian J
# as J^T J or J J^T (confirm in ricci_regularization), its trace equals the
# *squared* Frobenius norm of J, so the labels show ||.||_F^2.
fig, axes = plt.subplots(ncols=2,nrows=1, figsize=(15,6))
p0 = axes[1].scatter( latent[:,0], latent[:,1],
                c=trace_array, alpha=1, s = size_of_points, 
                marker='o', edgecolor='none', cmap='jet', norm= matplotlib.colors.LogNorm())
# fig.colorbar with an explicit ax=: plt.colorbar(mappable) without it used the
# current axes (axes[1]) before matplotlib 3.6, even for the axes[0] mappable.
cb0 = fig.colorbar(p0, ax=axes[1], label=r"$\|\nabla \Psi \|_F^2 = \mathrm{tr} (G) $")
axes[1].set_title("Jacobian of the decoder")

p1 = axes[0].scatter( latent[:validation_set_size,0], latent[:validation_set_size,1],
                c=trace_array_encoder, alpha=1, s = size_of_points, 
                marker='o', edgecolor='none', cmap='jet',norm= matplotlib.colors.LogNorm())
cb1 = fig.colorbar(p1, ax=axes[0], label=r"$\|\nabla \Phi \|_F^2$")
axes[0].set_title("Jacobian of the encoder")
if violent_saving:
    plt.savefig(f'{Path_pictures}/jac_norms_encoder_decoder.pdf',bbox_inches='tight',format='pdf')
plt.show()

# Metric

In [None]:
# sqrt(det G) and 0.5*tr(G): at the test points (scatter, left column) and
# over the latent grid (imshow, right column). The scatter color limits are
# capped at the grid maxima.
fig, ((ax1,ax3),(ax2,ax4))= plt.subplots(ncols=2,nrows=2,figsize = (15,12),dpi=300)

fig.tight_layout(pad=2.0)

# Raw strings: "\s" and "\c" are invalid escape sequences in ordinary string
# literals (SyntaxWarning since Python 3.12). The rendered text is unchanged.
ax1.title.set_text(r"$\sqrt{det(G)}$")
p = ax1.scatter( latent[:,0], latent[:,1],
                c=torch.sqrt(abs(det_array)), alpha=1, s = size_of_points, 
                marker='o', edgecolor='none', cmap='jet',
                vmax=np.sqrt(metric_det_on_grid.max()))
# Explicit ax= attaches each colorbar to its own panel. Before matplotlib 3.6,
# omitting it stole space from the current axes instead.
fig.colorbar(p, ax=ax1, label=r"$\sqrt{det(G)}$")
ax2.title.set_text(r"0.5$\cdot$tr(G)")
q = ax2.scatter( latent[:,0], latent[:,1], 
                c=0.5*(trace_array), alpha=1, s= size_of_points, 
                marker='o', edgecolor='none', cmap='jet',
                vmax=0.5*metric_trace_on_grid.max().item())
fig.colorbar(q, ax=ax2, label=r"0.5$\cdot$tr(G)")

im3 = ax3.imshow((np.sqrt(metric_det_on_grid)).reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None)
fig.colorbar(im3,ax = ax3, shrink = 1, label = r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow((0.5*(metric_trace_on_grid)).reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None,
                 vmax=0.5*metric_trace_on_grid.max().item())
fig.colorbar(im4, ax = ax4, shrink = 1, label = r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

# Image panels are in pixel indices.
axs = (ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places,labels = xtick_labels)
    ax.set_yticks(ticks_places,labels = ytick_labels)

# Scatter panels are in data coordinates.
axs = (ax1, ax2)
for ax in axs:
    ax.set_ylim(bottom,top)
    ax.set_xlim(left,right)
    ax.set_xticks(list(map(float, xtick_labels)), labels = xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels = ytick_labels)

if violent_saving:
    #plt.savefig(f'{Path_pictures}/metric_det_trace.eps',bbox_inches='tight',format='eps')
    plt.savefig(f'{Path_pictures}/metric_det_trace.pdf',bbox_inches='tight',format='pdf')
plt.show()

### Merge pdfs

In [None]:
from pypdf import PdfWriter

In [None]:
# Merge the individual PDFs into a single report. The hyperparameter page and
# the figures are written earlier in this notebook (the figures only when
# violent_saving is set). losses_exp<N>.pdf is not produced here — presumably
# it comes from the training run — and must already exist in Path_pictures.
if build_report:
    pdfs = [f"{Path_pictures}/hyperparameters_exp{experiment_number}.pdf",f'{Path_pictures}/losses_exp{experiment_number}.pdf',f'{Path_pictures}/init_colors_recon_loss.pdf', f'{Path_pictures}/curvature_heatmaps.pdf', f'{Path_pictures}/metric_det_trace.pdf', f'{Path_pictures}/jac_norms_encoder_decoder.pdf']

    # Report every absent input at once instead of failing on the first append.
    missing = [pdf for pdf in pdfs if not os.path.isfile(pdf)]
    if missing:
        raise FileNotFoundError(f"Cannot build the report; missing PDF file(s): {missing}")

    merger = PdfWriter()
    try:
        for pdf in pdfs:
            merger.append(pdf)
        merger.write(f"{Path_pictures}/report_{experiment_name}_exp_{experiment_number}.pdf")
    finally:
        # Release the writer's resources even if appending or writing fails.
        merger.close()
