This notebook builds a report for a pretrained torus autoencoder (AE).

NB! All plots are only suitable for a latent space of dimension d=2.
All plots are saved to disk only if `violent_saving == True`.

This notebook contains:

0) Loading data and nn weights
1) Data and reconstruction loss scatterplots
2) Histograms of curvature values over data and over a dense grid
3) Automatically scaled heatmaps over the latent space and scatterplots of datasets $\{X_i\}$ for:
    1) absolute value of scalar curvature $|R(X_i)|$
    2) scalar curvature $R(X_i)$
    3) square root of metric matrix determinant $\sqrt{\mathrm{det}G(X_i)}$ 
    4) half trace of metric matrix: $0.5 \cdot \mathrm{tr}G(X_i)$
4) Scatterplots of data vs heatmaps over the whole latent space, with a shared colorbar scale for each data/heatmap pair:
    1) absolute value of scalar curvature $|R(X_i)|$
    2) scalar curvature $R(X_i)$
    3) square root of metric matrix determinant $\sqrt{\mathrm{det}G(X_i)}$ 
    4) half trace of metric matrix: $0.5 \cdot \mathrm{tr}G(X_i)$
5) Jacobian of the encoder and decoder Frobenius norms
6) Merge PDFs: all saved plots are merged into a single PDF report if `build_report == True`.

In [None]:
# prerequisites
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
from tqdm.notebook import tqdm
from pypdf import PdfWriter
import matplotlib
import ricci_regularization
import json, yaml
import os

from sklearn import datasets
from torchvision import datasets, transforms

# 0. Loading data and nn weights

In [None]:

# Report switches.
# Which split feeds every data-dependent plot: True -> test loader, False -> train loader.
use_test_data_for_plots = True
# Write figures to disk; when False, plots are only displayed.
violent_saving = True
# Merge the saved PDF figures into one report at the end of the notebook.
build_report = True

In [None]:
# Experiment configuration to report on; swap the path for another experiment,
# e.g. '../experiments/MNIST01_exp7_config.yaml'.
config_path = '../experiments/Swissroll_exp1_config.yaml'
with open(config_path, 'r') as yaml_file:
    # safe_load builds only plain YAML types (dicts, lists, strings, numbers), which is all a
    # config file needs; it never instantiates arbitrary Python objects from tags.
    yaml_config = yaml.safe_load(yaml_file)

In [None]:
# Build the data loaders described in the YAML configuration. The returned mapping is kept under
# a descriptive name: assigning it to `dict` would shadow the built-in type for the whole kernel.
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
# .get: None if get_dataloaders does not provide a 'test_dataset' entry
test_dataset = loaders.get("test_dataset")

print("Data loaders created successfully.")

# Load the tuned torus autoencoder specified by the same configuration.
torus_ae = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)

print("AE weights loaded successfully.")

In [None]:
experiment_name = yaml_config["experiment"]["name"]

# Figures are written to ../experiments/<experiment name>.
#Path_pictures = yaml_config["experiment"]["path"]
Path_pictures = f"../experiments/{experiment_name}"
if violent_saving:
    if os.path.exists(Path_pictures):
        print(f"Directory already exists: {Path_pictures}")
    else:
        # makedirs also creates missing parent directories (os.mkdir would raise)
        os.makedirs(Path_pictures, exist_ok=True)
        print(f"Created directory: {Path_pictures}")

# presumably the weight of the curvature term in the training loss ("lambda_curv")
curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
# D is the dimension of the dataset (input dimension of the AE)
D = yaml_config["architecture"]["input_dim"]
if dataset_name in ["MNIST", "MNIST01", "Synthetic"]:
    # Labelled datasets: k, the number of classes, is the number of selected labels in the YAML config.
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
print("Experiment name:", experiment_name)
print("Plots saved at:", Path_pictures)

In [None]:
"""
# oldstyle loading using json
#experiment_json = f'../experiments/MNIST01_torus_AEexp7.json'
experiment_json = f'../experiments/Swissroll_torus_AEexp86.json'
mydict = ricci_regularization.get_dataloaders_tuned_nn(Path_experiment_json=experiment_json)

torus_ae = mydict["tuned_neural_network"]
train_loader = mydict["train_loader"]
test_loader = mydict["test_loader"]
json_config = mydict["json_config"]
Path_pictures = json_config["Path_pictures"]
Path_experiments = json_config["Path_experiments"]
experiment_name = json_config["experiment_name"]
experiment_number = json_config["experiment_number"]
try:
    curv_w = json_config["losses"]["curv_w"]
except KeyError:
    curv_w = json_config["optimization_parameters"]["curv_w"]

dataset_name = json_config["dataset"]["name"]
D = json_config["architecture"]["input_dim"]
# D is the dimension of the dataset
if dataset_name == "MNIST" or dataset_name == "MNIST01" or dataset_name == "Synthetic":
    # k from the JSON configuration file is the number of classes
    k = json_config["dataset"]["parameters"]["k"]
    selected_labels = json_config["dataset"]["selected_labels"]

# DUMP ONLY REPORTING PARTS to include in the pdf report
try:
    keys2print = ['experiment_name','experiment_number','dataset',
 'architecture', 'optimization_parameters', 'losses', 'OOD_parameters', 'training_results_on_test_data']
    json_config2print = {key : json_config[key] for key in keys2print}
except KeyError:
    keys2print = ["experiment_name","experiment_number","dataset", "optimization_parameters"]
    json_config2print = {key : json_config[key] for key in keys2print}
print(json_config2print)
"""

# 1. Data and reconstruction loss scatterplots

In [None]:
"""
#inspiration for torus_ae.encoder2lifting
def circle2anglevectorized(zLatentTensor,Z_DIM = Z_DIM):
    cosphi = zLatentTensor[:, 0:Z_DIM]
    sinphi = zLatentTensor[:, Z_DIM:2*Z_DIM]
    phi = torch.acos(cosphi)*torch.sgn(torch.asin(sinphi))
    return phi
"""

In [None]:
# Run the AE over the chosen split and collect inputs, reconstructions, latent codes and labels.
loader = test_loader if use_test_data_for_plots else train_loader
torus_ae.cpu()
input_dataset_list = []
recon_dataset_list = []
enc_list = []
colorlist = []
for (data, labels) in tqdm(loader, position=0):
    input_dataset_list.append(data)
    # Reconstructions are only used for the reconstruction error, which is detached anyway;
    # detaching here avoids keeping one autograd graph per batch alive for the whole dataset.
    recon_dataset_list.append(torus_ae(data)[0].detach())
    # Latent codes are left attached, as before: they are passed to the decoder-geometry helpers later.
    enc_list.append(torus_ae.encoder2lifting(data.view(-1, D)))
    colorlist.append(labels)

In [None]:
# Merge the per-batch lists collected by the encoding loop into single tensors.
input_dataset, recon_dataset, encoded_points = (
    torch.cat(chunks) for chunks in (input_dataset_list, recon_dataset_list, enc_list)
)
color_array = torch.cat(colorlist).detach()

# Plotting copy of the latent codes: no autograd history, rescaled by 1/pi so that
# latent \in [-1,1] (grid reparametrization for plotting).
encoded_points_no_grad = encoded_points.detach() / math.pi

In [None]:
# Scatter of the latent codes in the rescaled [-1, 1]^2 plotting coordinates, coloured by label.
plt.rcParams.update({'font.size': 20})
plt.figure(figsize=(9, 9), dpi=400)

if dataset_name == "Swissroll":
    # Swissroll: continuous 'jet' colormap for the colour values
    latent_cmap = 'jet'
else:
    # labelled datasets: one discrete colour per class (k classes)
    latent_cmap = ricci_regularization.discrete_cmap(k, 'jet')
plt.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array,
            marker='o', edgecolor='none', cmap=latent_cmap)
plt.xticks([-1., -0.5, 0., 0.5, 1.])
plt.yticks([-1., -0.5, 0., 0.5, 1.])
plt.ylim(-1., 1.)
plt.xlim(-1., 1.)
plt.grid(True)
if violent_saving:
    plt.savefig(f"{Path_pictures}/latent_space.pdf", format="pdf", bbox_inches='tight')
    # The JPEG copy is also a file write, so it obeys the same switch; with saving disabled the
    # output directory may not even exist.
    plt.savefig(f"{Path_pictures}/latent_space_{experiment_name}.jpg", bbox_inches='tight', format="jpeg")
plt.show()

In [None]:
if dataset_name in ("MNIST", "MNIST01"):
    # Image datasets only: plot AE outputs for the selected labels, using the MNIST test split
    # read from ../datasets/ (download disabled).
    test_dataset = datasets.MNIST(root='../datasets/', train=False,
                                  transform=transforms.ToTensor(), download=False)
    recon_figure = ricci_regularization.plot_ae_outputs_selected(
        test_dataset=test_dataset,
        encoder=torus_ae.encoder2lifting,
        decoder=torus_ae.decoder_torus,
        selected_labels=selected_labels,
    )
    if violent_saving == True:
        recon_figure.savefig(f"{Path_pictures}/recon_images.jpg", bbox_inches='tight', format="jpeg")

## Reconstruction loss computation

In [None]:
# Per-sample reconstruction MSE: mean over the D input coordinates of the squared error.
# Squaring the (signed) difference directly avoids the sqrt-then-square round trip of norm()**2.
reconstruction_error = input_dataset.view(-1, D) - recon_dataset.detach()
mse_array = reconstruction_error.pow(2).mean(dim=1)

## Metric losses computation

In [None]:
# Decoder geometry at the encoded data points, detached from autograd: scalar curvature,
# metric matrices, and each metric matrix's determinant and trace.
decoder_fn = torus_ae.decoder_torus
curvature_array = ricci_regularization.Sc_jacfwd_vmap(encoded_points, function=decoder_fn).detach()
metric_array = ricci_regularization.metric_jacfwd_vmap(encoded_points, function=decoder_fn).detach()
det_array = metric_array.det()
# trace of every matrix in the batch: sum of its diagonal
trace_array = metric_array.diagonal(dim1=-2, dim2=-1).sum(dim=-1)

In [None]:
# latent \in [-\pi,\pi]. Grid parameters for *evaluating* the decoder geometry.
# The decoder works on the encoder's own latent scale (encoded_points, the same tensor the
# curvature/metric on the data are computed from), so the bounding box is measured there.
# encoded_points_no_grad has already been divided by pi for plotting; using it here would put the
# evaluation grid on a box shrunk by a factor pi towards the origin instead of around the data.
latent = encoded_points.detach()
left = latent[:, 0].min()
right = latent[:, 0].max()
bottom = latent[:, 1].min()
top = latent[:, 1].max()

xsize = right - left
ysize = top - bottom
xcenter = 0.5 * (left + right)
ycenter = 0.5 * (bottom + top)

In [None]:
# Evaluate metric determinant, metric trace and scalar curvature of the decoder on a
# linsize x linsize grid spanning the bounding box computed above.
linsize = 200  # grid points per axis; the plots reshape the results to (linsize, linsize)

grid_on_ls = ricci_regularization.make_grid(linsize, xsize=xsize, ysize=ysize, xcenter=xcenter, ycenter=ycenter)

grid_numpoints = grid_on_ls.shape[0]
bs = 4000  # grid points per evaluation batch
metric_det_list = []
metric_trace_list = []
curv_list = []
# Step through the grid in chunks of bs points. The final chunk may be shorter, so every grid
# point is evaluated even when grid_numpoints is not a multiple of bs; a dropped remainder would
# make the later (linsize, linsize) reshapes fail.
for start in range(0, grid_numpoints, bs):
    batch_of_grid = grid_on_ls[start:start + bs]
    metric_on_batch_of_grid = ricci_regularization.metric_jacfwd_vmap(batch_of_grid, function=torus_ae.decoder_torus)
    metric_det_on_batch_of_grid = torch.det(metric_on_batch_of_grid)
    metric_trace_on_batch_of_grid = torch.func.vmap(torch.trace)(metric_on_batch_of_grid)
    curv_on_batch_of_grid = ricci_regularization.Sc_jacfwd_vmap(batch_of_grid, function=torus_ae.decoder_torus)
    metric_det_list.append(metric_det_on_batch_of_grid.tolist())
    metric_trace_list.append(metric_trace_on_batch_of_grid.tolist())
    curv_list.append(curv_on_batch_of_grid.tolist())
# flat NumPy arrays with one value per grid point
metric_det_on_grid = np.concatenate(metric_det_list)
metric_trace_on_grid = np.concatenate(metric_trace_list)
curv_on_the_grid = np.concatenate(curv_list)

In [None]:
# Plotting coordinates: bounding box of the latent codes in the rescaled [-1, 1] units
# (encoded_points_no_grad is already divided by pi); used for axis limits and tick labels.
latent = encoded_points_no_grad
x_coords, y_coords = latent[:, 0], latent[:, 1]
left, right = x_coords.min(), x_coords.max()
bottom, top = y_coords.min(), y_coords.max()

xsize, ysize = right - left, top - bottom
xcenter, ycenter = 0.5 * (left + right), 0.5 * (bottom + top)

## Losses plotting

In [None]:
# Left: latent codes coloured by their label ("initial color").
# Right: per-sample reconstruction MSE (mse_array) on a log colour scale.
plt.rcParams.update({'font.size': 16})

size_of_points = 20
fig, (ax00, ax0) = plt.subplots(ncols=2, nrows=1, figsize=(15, 6), dpi=300)
fig.tight_layout(pad=2.0)

ax00.title.set_text("AE latent space")
# Same labelled-dataset list for which k is defined from the config: discrete colormap with one
# colour and one colorbar tick per class.
if dataset_name in ["Synthetic", "MNIST", "MNIST01"]:
    p00 = ax00.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array,
                       alpha=0.5, s=size_of_points, marker='o', edgecolor='none',
                       cmap=ricci_regularization.discrete_cmap(k, "jet"))
    # ax= pins the colorbar to its own axes; without it, older Matplotlib versions take the space
    # from the current (last created) axes instead.
    fig.colorbar(p00, ax=ax00, label="initial color", ticks=np.arange(k))
else:
    p00 = ax00.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array,
                       alpha=0.5, s=size_of_points, marker='o', edgecolor='none', cmap='jet')
    fig.colorbar(p00, ax=ax00, label="initial color")
ax00.grid(True)

ax0.title.set_text("Reconstruction loss")
p0 = ax0.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=mse_array,
                 alpha=0.5, s=size_of_points, marker='o', edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.LogNorm())
ax0.grid(True)
# mse_array is the squared l2 error divided by D, i.e. a per-sample mean squared error
fig.colorbar(p0, ax=ax0, label="mean squared error")

if violent_saving:
    fig.savefig(f'{Path_pictures}/init_colors_recon_loss.pdf', bbox_inches='tight', format='pdf')
plt.show()

# 2. Histograms of curvature values over data and over a dense grid

In [None]:
# Distribution of scalar curvature at the encoded data points (square-root rule for the bin count).
# The title names the split actually used, which is the train split when use_test_data_for_plots is False.
split_name = "test" if use_test_data_for_plots else "train"
plt.title(f"Histogram of curvature values over {len(curvature_array)} {split_name} data samples")
plt.hist(curvature_array, bins=math.ceil(math.sqrt(curvature_array.shape[0])))
plt.xlabel("scalar curvature")
plt.show()

In [None]:
# Distribution of scalar curvature over all points of the linsize x linsize evaluation grid
# (these are grid points, not data samples).
plt.title("Histogram of curvature values over grid cells \n" + rf"of a ${linsize}\times{linsize}$ latent grid")
plt.hist(curv_on_the_grid, bins=200)
plt.xlabel("scalar curvature")
plt.show()

# 3. Automatically scaled heatmaps over the latent space and scatterplots of datasets $\{X_i\}$

In [None]:
# Heatmaps of the grid quantities, each with its own automatic colour scale.
# Tick labels are latent coordinates in the rescaled [-1, 1] plotting units.
xshift = 0.0
yshift = 0.0
numticks = 5
# two decimals for the Synthetic dataset, one otherwise
tick_decimals = 2 if dataset_name == "Synthetic" else 1
plt.rcParams.update({'font.size': 16})

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15, 12), dpi=300)
fig.tight_layout(pad=2.0)

# Evenly spaced tick values across the plotting bounding box, formatted as text labels.
xticks = np.linspace(xcenter - 0.5 * xsize, xcenter + 0.5 * xsize, numticks)
yticks = np.linspace(ycenter - 0.5 * ysize, ycenter + 0.5 * ysize, numticks)
xtick_labels = [f"{elem:.{tick_decimals}f}" for elem in (xticks + xshift).tolist()]
ytick_labels = [f"{elem:.{tick_decimals}f}" for elem in (yticks + yshift).tolist()]
# Heatmap pixel indices (first to last column/row) where those labels are placed.
ticks_places = np.linspace(0, 1, numticks) * (linsize - 1)

im1 = ax1.imshow(abs(curv_on_the_grid.reshape(linsize, linsize)),
                 origin="lower", cmap="jet",
                 norm=matplotlib.colors.LogNorm())
fig.colorbar(im1, ax=ax1, shrink=1, label="curvature abs value")
ax1.set_title("Absolute value of scalar curvature")

# symmetric log scale, linear within 1% of |mean curvature| around zero
im2 = ax2.imshow(curv_on_the_grid.reshape(linsize, linsize),
                 origin="lower", cmap="jet",
                 norm=matplotlib.colors.SymLogNorm(linthresh=abs(0.01 * curv_on_the_grid.mean()).item()))
fig.colorbar(im2, ax=ax2, shrink=1, label="curvature")
ax2.set_title("Scalar curvature")

# abs() before sqrt: a numerically computed det(G) may come out slightly negative, which would
# otherwise give NaN pixels.
sqrt_det_on_grid = np.sqrt(np.abs(metric_det_on_grid))
im3 = ax3.imshow(sqrt_det_on_grid.reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
# raw strings: "\s" and "\c" are invalid escape sequences in ordinary string literals
fig.colorbar(im3, ax=ax3, shrink=1, label=r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow((0.5 * metric_trace_on_grid).reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im4, ax=ax4, shrink=1, label=r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

for ax in (ax1, ax2, ax3, ax4):
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/heatmaps_not_scaled.pdf', bbox_inches='tight', format='pdf')
plt.show()

# 4. Scatterplots of data vs heatmaps over the whole latent space. Shared colorbar scaling

In [None]:
# Colour limits taken from the grid and shared by the data scatter plots and the grid heatmaps.
# Signed curvature (SymLogNorm): full grid range, linear band of 1% of |mean curvature| around 0.
min_curvature = curv_on_the_grid.min().item()
max_curvature = curv_on_the_grid.max().item()
linthresh_curvature = 0.01 * abs(curv_on_the_grid.mean()).item()

# |curvature| (LogNorm): upper limit is the grid maximum; lower limit is 1% of the mean |curvature|
# (a log scale needs a positive lower bound).
abs_curv_on_grid = abs(curv_on_the_grid)
max_abs_curvature = abs_curv_on_grid.max().item()
min_abs_curvature = 0.01 * abs_curv_on_grid.mean().item()

In [None]:
# Curvature on the data (scatter plots, left column) vs over the whole grid (heatmaps, right
# column). Each scatter/heatmap pair uses identical colour-normalisation limits.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15, 12), dpi=300)
fig.tight_layout(pad=2.0)

# Tick values spanning the plotting bounding box, formatted with tick_decimals decimals.
xticks = np.linspace(xcenter - 0.5 * xsize, xcenter + 0.5 * xsize, numticks)
yticks = np.linspace(ycenter - 0.5 * ysize, ycenter + 0.5 * ysize, numticks)
xtick_labels = [f"{elem:.{tick_decimals}f}" for elem in (xticks + xshift).tolist()]
ytick_labels = [f"{elem:.{tick_decimals}f}" for elem in (yticks + yshift).tolist()]
# Heatmap pixel indices at which the same labels are placed.
ticks_places = np.linspace(0, 1, numticks) * (linsize - 1)

# |R| on the data. Every colorbar gets an explicit ax=: without it, older Matplotlib versions
# attach it next to the current (last created) axes rather than the plotted one.
ax1.set_title("Absolute value of scalar curvature")
p1 = ax1.scatter(latent[:, 0], latent[:, 1], c=abs(curvature_array),
                 alpha=1, s=size_of_points, marker='o',
                 edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature,
                                                vmax=max_abs_curvature))
fig.colorbar(p1, ax=ax1, label="curvature abs value")

# |R| over the grid
ax2.set_title("Absolute value of scalar curvature overall")
im1 = ax2.imshow(abs(curv_on_the_grid.reshape(linsize, linsize)),
                 origin="lower", cmap="jet",
                 norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature,
                                                vmax=max_abs_curvature))
fig.colorbar(im1, ax=ax2, shrink=1, label="curvature abs value")

# signed R on the data
ax3.set_title("Scalar curvature")
p2 = ax3.scatter(latent[:, 0], latent[:, 1], c=curvature_array,
                 alpha=1, s=size_of_points, marker='o',
                 edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin=min_curvature,
                                                   vmax=max_curvature))
fig.colorbar(p2, ax=ax3, label="curvature")

# signed R over the grid
ax4.set_title("Scalar curvature overall")
im2 = ax4.imshow(curv_on_the_grid.reshape(linsize, linsize),
                 origin="lower", cmap="jet",
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin=min_curvature,
                                                   vmax=max_curvature))
fig.colorbar(im2, ax=ax4, shrink=1, label="curvature")

# Scatter axes are in latent coordinates: clip to the data box and tick at the label values.
for ax in (ax1, ax3):
    ax.set_ylim(bottom, top)
    ax.set_xlim(left, right)
    ax.set_xticks([float(label) for label in xtick_labels], labels=xtick_labels)
    ax.set_yticks([float(label) for label in ytick_labels], labels=ytick_labels)

# Heatmap axes are in pixel indices: put the same labels at the matching pixels.
for ax in (ax2, ax4):
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/curvature_heatmaps.pdf', bbox_inches='tight', format='pdf')
plt.show()


## Metric losses: 

In [None]:
# Metric quantities on the data (scatter plots, left column) vs over the whole grid (heatmaps,
# right column). Each scatter shares vmin/vmax with the heatmap next to it, so both use one scale.
fig, ((ax1, ax3), (ax2, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15, 12), dpi=300)
fig.tight_layout(pad=2.0)

# Grid values. abs() before sqrt because a numerically computed det(G) may be slightly negative,
# which would give NaN pixels; the data values below get the same abs().
sqrt_det_on_grid = np.sqrt(np.abs(metric_det_on_grid))
half_trace_on_grid = 0.5 * metric_trace_on_grid
sqrt_det_vmin, sqrt_det_vmax = sqrt_det_on_grid.min().item(), sqrt_det_on_grid.max().item()
half_trace_vmin, half_trace_vmax = half_trace_on_grid.min().item(), half_trace_on_grid.max().item()

# sqrt(det G) on the data. Colorbars get an explicit ax= so they sit next to the plotted axes
# in every Matplotlib version. Raw strings avoid the invalid "\s" / "\c" escapes.
ax1.title.set_text(r"$\sqrt{det(G)}$")
p = ax1.scatter(latent[:, 0], latent[:, 1],
                c=torch.sqrt(abs(det_array)), alpha=1, s=size_of_points,
                marker='o', edgecolor='none', cmap='jet',
                vmin=sqrt_det_vmin, vmax=sqrt_det_vmax)
fig.colorbar(p, ax=ax1, label=r"$\sqrt{det(G)}$")

# 0.5 tr(G) on the data
ax2.title.set_text(r"0.5$\cdot$tr(G)")
q = ax2.scatter(latent[:, 0], latent[:, 1],
                c=0.5 * trace_array, alpha=1, s=size_of_points,
                marker='o', edgecolor='none', cmap='jet',
                vmin=half_trace_vmin, vmax=half_trace_vmax)
fig.colorbar(q, ax=ax2, label=r"0.5$\cdot$tr(G)")

# sqrt(det G) over the grid
im3 = ax3.imshow(sqrt_det_on_grid.reshape(linsize, linsize),
                 origin="lower", cmap="jet",
                 vmin=sqrt_det_vmin, vmax=sqrt_det_vmax)
fig.colorbar(im3, ax=ax3, shrink=1, label=r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

# 0.5 tr(G) over the grid
im4 = ax4.imshow(half_trace_on_grid.reshape(linsize, linsize),
                 origin="lower", cmap="jet",
                 vmin=half_trace_vmin, vmax=half_trace_vmax)
fig.colorbar(im4, ax=ax4, shrink=1, label=r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

# Heatmap axes are in pixel indices: place the latent-coordinate labels at the matching pixels.
for ax in (ax3, ax4):
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

# Scatter axes are in latent coordinates: clip to the data box and tick at the label values.
for ax in (ax1, ax2):
    ax.set_ylim(bottom, top)
    ax.set_xlim(left, right)
    ax.set_xticks([float(label) for label in xtick_labels], labels=xtick_labels)
    ax.set_yticks([float(label) for label in ytick_labels], labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/metric_det_trace.pdf', bbox_inches='tight', format='pdf')
plt.show()


# 5. Jacobian of the encoder and decoder Frobenius norms

In [None]:
# Encoder geometry on (at most) the first validation_set_size_limit samples of the chosen split.
validation_set_size_limit = 4000
# Flatten to (-1, D) exactly as the encoding loop does before calling encoder2lifting, in case
# input_dataset is not already flat (a no-op for flat data such as Swissroll).
encoder_inputs = input_dataset[:validation_set_size_limit].view(-1, D)
metric_array_encoder = ricci_regularization.metric_jacrev_vmap(encoder_inputs, function=torus_ae.encoder2lifting, latent_space_dim=D).detach()
# per-sample trace of the encoder metric matrices
trace_array_encoder = torch.einsum('jii->j', metric_array_encoder)

In [None]:
# Squared Frobenius norms of the encoder (left) and decoder (right) Jacobians on the data.
# Assuming the metric helpers return G = J^T J, tr(G) = ||J||_F^2, so the labels show squared norms.
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(15, 6))

# Encoder: only the first validation_set_size_limit samples were evaluated; use the matching codes.
p1 = axes[0].scatter(latent[:validation_set_size_limit, 0], latent[:validation_set_size_limit, 1],
                     c=trace_array_encoder, alpha=1, s=size_of_points,
                     marker='o', edgecolor='none', cmap='jet', norm=matplotlib.colors.LogNorm())
# explicit ax= keeps each colorbar next to its own panel in every Matplotlib version
fig.colorbar(p1, ax=axes[0], label=r"$\|\nabla \Phi \|_F^2$")
axes[0].set_title("Jacobian of the encoder")

# Decoder: trace of the decoder metric at every encoded data point.
p0 = axes[1].scatter(latent[:, 0], latent[:, 1],
                     c=trace_array, alpha=1, s=size_of_points,
                     marker='o', edgecolor='none', cmap='jet', norm=matplotlib.colors.LogNorm())
fig.colorbar(p0, ax=axes[1], label=r"$\|\nabla \Psi \|_F^2 = \mathrm{tr} (G) $")
axes[1].set_title("Jacobian of the decoder")

if violent_saving:
    plt.savefig(f'{Path_pictures}/jac_norms_encoder_decoder.pdf', bbox_inches='tight', format='pdf')
plt.show()

# 6. Merge PDFs

In [None]:
# Merge the saved figures into a single PDF report. Only files that exist on disk are merged:
# losses.pdf is not written by this notebook (presumably it comes from the training run), and no
# figure is saved at all when violent_saving is False.
if build_report:
    pdfs = [f'{Path_pictures}/losses.pdf', f'{Path_pictures}/init_colors_recon_loss.pdf',
            f'{Path_pictures}/curvature_heatmaps.pdf', f'{Path_pictures}/metric_det_trace.pdf',
            f'{Path_pictures}/jac_norms_encoder_decoder.pdf']
    existing_pdfs = [pdf for pdf in pdfs if os.path.isfile(pdf)]
    for pdf in pdfs:
        if pdf not in existing_pdfs:
            print(f"Skipping missing file: {pdf}")

    if existing_pdfs:
        merger = PdfWriter()
        try:
            for pdf in existing_pdfs:
                merger.append(pdf)
            merger.write(f"{Path_pictures}/report_{experiment_name}.pdf")
        finally:
            # close the writer even if appending or writing fails
            merger.close()
    else:
        print("No PDFs found to merge; report not written.")
