Prerequisite: weights and architecture of a pre-trained AE.

This notebook builds a report for a pretrained torus AE.

This notebook can only be run after AE_torus_training.ipynb has been run.

NB! All plots are only suitable for a latent space of dimension $d=2$.
All plots are saved if 'violent_saving' == True.

This notebook contains:

0) Loading data and nn weights
1) Latent space scatter plots colored by reconstruction loss and encoder's Jacobian Frobenius norm
2) Histograms of scalar curvature $R$ and $R^2\sqrt{\det G}$ values over data and over a dense grid
3) Automatically scaled heatmaps over the latent space and scatterplots of datasets $\{X_i\}$ for:
    1) absolute value of scalar curvature $|R(X_i)|$
    2) scalar curvature $R(X_i)$
    3) square root of metric matrix determinant $\sqrt{\mathrm{det}G(X_i)}$ 
    4) half trace of metric matrix: $0.5 \cdot \mathrm{tr}G(X_i)$
4) Scatterplots of data vs heatmaps over the whole latent space. Unique colorbar scaling:
    1) absolute value of scalar curvature $|R(X_i)|$
    2) scalar curvature $R(X_i)$
    3) square root of metric matrix determinant $\sqrt{\mathrm{det}G(X_i)}$ 
    4) half trace of metric matrix: $0.5 \cdot \mathrm{tr}G(X_i)$
5) Merge PDFs: all plots are merged into a single PDF report if 'build_report' == True.

In [None]:
# prerequisites
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
from tqdm.notebook import tqdm
from pypdf import PdfWriter
import matplotlib
import ricci_regularization
import yaml
import os

from sklearn import datasets
from torchvision import datasets, transforms

In [None]:
# Working modes of the notebook (edit these flags before running the cells below).
build_report = True  # merge the saved PDFs into a single report at the end
violent_saving = True  # if False, no plot is written to disk
normalize_to_unit_square = True  # rescale every latent-space plot to fit the unit square
use_test_data_for_plots = True  # if False, the train loader is used for the scatter plots

In [None]:
# Report which compute device PyTorch can see.
cuda_is_available = torch.cuda.is_available()
print("CUDA available:", cuda_is_available)
print("Torch CUDA version:", torch.version.cuda)
if cuda_is_available:
    current_device_name = torch.cuda.get_device_name(0)
else:
    current_device_name = "CPU"
print("Current device:", current_device_name)

# 0. Loading data and nn weights

In [None]:
# Select the experiment whose configuration (and trained AE) is reported.
AE_setting_name = 'MNIST_Setting_3'
#AE_setting_name = 'Swissroll_Setting_1'
config_path = f'../experiments/{AE_setting_name}_config.yaml'

# Parse the experiment's YAML configuration.
# NOTE(review): the config is a local file; yaml.safe_load would suffice if it contains
# no Python-specific tags — FullLoader is kept for compatibility.
with open(config_path, 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Confirm which file was loaded.
print(f"YAML Configuration loaded successfully from \n: {yaml_file.name}")

In [None]:
# Build the train/test data loaders described in the YAML configuration.
# The result is named `dataloaders` (not `dict`) so the builtin `dict` stays usable
# in every later cell of the notebook.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]

print("Data loaders created successfully.")

# Load the tuned torus AE together with the path its weights were read from.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)

print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

In [None]:
# Directory in which every plot of this experiment is saved.
Path_pictures = "../experiments/" + AE_setting_name
if violent_saving:
    # Create the plot directory (including missing parents) only when saving is enabled.
    if os.path.exists(Path_pictures):
        print(f"Directory already exists: {Path_pictures}")
    else:
        os.makedirs(Path_pictures, exist_ok=True)
        print(f"Created directory: {Path_pictures}")

# loss_settings.lambda_curv from the training config (presumably the curvature-loss weight)
curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
# D is the dimension of the dataset (the AE's input dimension)
D = yaml_config["architecture"]["input_dim"]
if dataset_name == "MNIST_subset":
    # k is the number of selected classes
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    selected_labels = np.arange(10)
    k = 10
elif dataset_name == "Synthetic":
    k = yaml_config["dataset"]["k"]
print("Experiment name:", AE_setting_name)
print("Plots saved at:", Path_pictures)

# 1. Latent space scatter plots colored by reconstruction and contractive loss

In [None]:
# Pick the split used for the latent-space scatter plots.
loader = test_loader if use_test_data_for_plots == True else train_loader
# All evaluations below run on the CPU.
torus_ae.cpu()

# Per-batch inputs, reconstructions, latent encodings and labels.
input_dataset_list, recon_dataset_list, enc_list, colorlist = [], [], [], []
for data, labels in tqdm(loader, position=0):
    input_dataset_list.append(data)
    recon_dataset_list.append(torus_ae(data)[0])
    enc_list.append(torus_ae.encoder_to_lifting(data.view(-1, D)))
    colorlist.append(labels)

# Stack the per-batch results into single tensors.
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

In [None]:
# Latent-space scatter plot of the test split.
plt.rcParams.update({'font.size': 20})
fig = ricci_regularization.point_plot(
    torus_ae.encoder_to_lifting, test_loader, 0, yaml_config, figsize=(9, 7)
)
if violent_saving:
    test_plot_path = f"{Path_pictures}/latent_space_{AE_setting_name}_test_data.pdf"
    plt.savefig(test_plot_path, format="pdf", bbox_inches='tight')
plt.show()

In [None]:
# Latent-space scatter plot of the train split (no title).
plt.rcParams.update({'font.size': 20})
fig = ricci_regularization.point_plot(
    torus_ae.encoder_to_lifting, train_loader, 0, yaml_config,
    figsize=(9, 7), show_title=False,
)
if violent_saving:
    train_plot_path = f"{Path_pictures}/latent_space_{AE_setting_name}_train_data.pdf"
    plt.savefig(train_plot_path, format="pdf", bbox_inches='tight')
plt.show()

In [None]:
# Reconstructions of sample digits (only for MNIST-based datasets).
# `datasets` here is torchvision.datasets: its import comes after sklearn's and shadows it.
if dataset_name in ("MNIST", "MNIST_subset"):
    selected_labels = (
        torch.arange(10).tolist()
        if dataset_name == "MNIST"
        else yaml_config["dataset"]["selected_labels"]
    )
    test_dataset = datasets.MNIST(
        root='../../datasets/',
        train=False,
        transform=transforms.ToTensor(),
        download=False,
    )
    axes = ricci_regularization.plot_ae_outputs_selected(
        test_dataset=test_dataset,
        encoder=torus_ae.encoder_to_lifting,
        decoder=torus_ae.decoder_torus,
        selected_labels=selected_labels,
    )
    recon_figure = axes.get_figure()
    if violent_saving:
        recon_figure.savefig(f"{Path_pictures}/recon_images.pdf", bbox_inches='tight', format="pdf")
plt.show()

## Reconstruction loss computation

In [None]:
# Per-sample reconstruction error: squared L2 norm of the residual divided by D.
residuals = input_dataset.view(-1, D) - recon_dataset
mse_array = residuals.norm(dim=1).detach() ** 2 / D

## Metric losses computation

In [None]:
# Scalar curvature R and metric G of the decoder at every encoded data point.
curvature_array, metric_array = ricci_regularization.curvature_loss_jacfwd(
    encoded_points,
    function=torus_ae.decoder_torus,
    eps=0.,
    reduction="curvature_metric",
)
curvature_array, metric_array = curvature_array.detach(), metric_array.detach()
# det G and tr G per point (metric_array is a batch of square matrices).
metric_det_array = torch.det(metric_array)
metric_trace_array = torch.func.vmap(torch.trace)(metric_array)

In [None]:
# Same quantities on a dense square grid over the latent space
# (presumably [-pi, pi]^2, given xsize = ysize = 2*pi centred at 0).
linsize = 200  # grid nodes per row (and per column)

grid_on_ls = ricci_regularization.make_grid(linsize,
                                            xsize=2*torch.pi,
                                            ysize=2*torch.pi,
                                            xcenter=0.,
                                            ycenter=0.)

grid_numpoints = grid_on_ls.shape[0]
# The grid is processed in batches to keep memory bounded.
bs = 4000
metric_det_list = []
metric_trace_list = []
curv_list = []
# Stride through the grid; the last batch may be shorter than bs, so no grid point is
# dropped when grid_numpoints is not a multiple of bs (the reshapes below need all of them).
for start in range(0, grid_numpoints, bs):
    batch_of_grid = grid_on_ls[start:start + bs]
    # curvature and metric on this batch of grid points (same call as for the data points)
    curv_on_batch_of_grid, metric_on_batch_of_grid = ricci_regularization.curvature_loss_jacfwd(
        batch_of_grid,
        function=torus_ae.decoder_torus,
        eps=0.,
        reduction="curvature_metric")
    curv_on_batch_of_grid = curv_on_batch_of_grid.detach()
    metric_on_batch_of_grid = metric_on_batch_of_grid.detach()

    metric_det_list.append(torch.det(metric_on_batch_of_grid).cpu().numpy())
    metric_trace_list.append(torch.einsum('jii->j', metric_on_batch_of_grid).cpu().numpy())
    curv_list.append(curv_on_batch_of_grid.cpu().numpy())

# float64 numpy arrays, as produced previously via .tolist()
metric_det_on_grid = np.concatenate(metric_det_list).astype(np.float64)
metric_trace_on_grid = np.concatenate(metric_trace_list).astype(np.float64)
curv_on_the_grid = np.concatenate(curv_list).astype(np.float64)

In [None]:
# Choose the plotting box of the latent space and rescale the encoded points to it.
# The points are recomputed from `encoded_points` so that re-running this cell does not
# divide by pi a second time.
encoded_points_no_grad = encoded_points.detach()
if normalize_to_unit_square:
    # latent coordinates (assumed to lie in [-pi, pi]) are mapped to [-1, 1]
    encoded_points_no_grad = encoded_points_no_grad / math.pi
    left, right = -1., 1.
    bottom, top = -1., 1.
else:
    left, right = -torch.pi, torch.pi
    bottom, top = -torch.pi, torch.pi

# Size and centre of the plotting box.
xsize = right - left
ysize = top - bottom
xcenter = 0.5*(left + right)
ycenter = 0.5*(bottom + top)

In [None]:
# Latent-space scatter plots colored by the initial labels and by reconstruction loss.

plt.rcParams.update({'font.size': 16})

size_of_points = 20
fig, (ax00, ax0) = plt.subplots(ncols=2, nrows=1, figsize=(15, 6), dpi=300)
fig.tight_layout(pad=2.0)

# Left panel: initial colors (labels).
ax00.title.set_text("AE latent space")
if dataset_name == "Synthetic" or dataset_name == "MNIST":
    # one discrete color per class (labels assumed to be 0..k-1)
    p00 = ax00.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array,
                       alpha=0.5, s=size_of_points, marker='o', edgecolor='none',
                       cmap=ricci_regularization.discrete_cmap(k, "jet"))
    # ax= is explicit: with matplotlib < 3.6 the colorbar would otherwise steal space
    # from the current axes (ax0) instead of ax00.
    fig.colorbar(p00, ax=ax00, label="initial color", ticks=np.arange(k))
else:
    p00 = ax00.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array,
                       alpha=0.5, s=size_of_points, marker='o', edgecolor='none', cmap='jet')
    fig.colorbar(p00, ax=ax00, label="initial color")
ax00.grid(True)

# Right panel: per-sample reconstruction loss on a logarithmic color scale.
ax0.title.set_text("Reconstruction loss")
p0 = ax0.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=mse_array,
                 alpha=0.5, s=size_of_points, marker='o', edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.LogNorm())
ax0.grid(True)
cb = fig.colorbar(p0, ax=ax0, label="squared l2 norm errors")

if violent_saving:
    fig.savefig(f'{Path_pictures}/init_colors_recon_loss.pdf', bbox_inches='tight', format='pdf')
plt.show()

In [None]:
# Frobenius norm of the encoder's Jacobian at every data point (related to the contractive
# loss), shown as a latent-space scatter with a logarithmic color scale.
import matplotlib.ticker as ticker

encoder_jac_norm = ricci_regularization.Jacobian_norm_jacrev_vmap(
    input_dataset,
    function=torus_ae.encoder_torus,
    input_dim=yaml_config["architecture"]["input_dim"])
encoder_jac_norm_no_grad = encoder_jac_norm.detach()

# Limits of the log color scale (tensor reductions rather than Python's builtin min/max,
# which iterate the tensor element by element).
vmin = encoder_jac_norm_no_grad.min().item()
vmax = encoder_jac_norm_no_grad.max().item()

plt.rcParams.update({'font.size': 12})
# One call creates the figure and axes, so figsize/dpi apply to the figure actually drawn.
fig, ax = plt.subplots(figsize=(9, 9), dpi=200)

sc = ax.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                c=encoder_jac_norm_no_grad, marker='o', edgecolor='none', s=size_of_points,
                cmap='jet', norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax))
ax.set_xlim(left, right)
ax.set_ylim(bottom, top)
ax.set_title("Jacobian of the encoder Frobenius norm")

cbar = fig.colorbar(sc, ax=ax)
# Four logarithmically spaced colorbar ticks spanning [vmin, vmax], in scientific notation.
tick_places = 10 ** np.linspace(np.log10(vmin), np.log10(vmax), num=4)
cbar.set_ticks(tick_places)
cbar.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.1e'))

if violent_saving:
    fig.savefig(f"{Path_pictures}/encoder_Frobenius_norm_{AE_setting_name}.pdf", format="pdf", bbox_inches='tight')
plt.show()

# 2. Histograms of curvature values over data and over a dense grid

In [None]:
# Histogram of the scalar curvature at the encoded data points.
# The title names the split actually used (see use_test_data_for_plots).
data_split_name = "test" if use_test_data_for_plots else "train"
plt.title(f"Histogram of curvature values over {len(curvature_array)} {data_split_name} data samples")
# square-root rule for the number of bins
plt.hist(curvature_array, bins=math.ceil(math.sqrt(curvature_array.shape[0])))
plt.show()

In [None]:
# Histogram of the scalar curvature over the nodes of the dense latent grid
# (the grid is not data, so the title no longer calls its nodes "test data samples").
plt.title("Histogram of curvature values over the nodes\n" + rf"of the ${linsize}\times{linsize}$ latent grid")
plt.hist(curv_on_the_grid, bins=200)
plt.show()

In [None]:
# Histogram of the curvature functional R^2 * sqrt(det G) at the data points; outliers
# outside the (1, 99) percentile_bounds are presumably excluded by plot_histogram.
Curvature_functional_on_data = np.square(curvature_array) * np.sqrt(metric_det_array)

data_split_name = "test" if use_test_data_for_plots else "train"
ricci_regularization.PlottingTools.plot_histogram(
    data=Curvature_functional_on_data,
    title=r"$R^2 (Y_i) \sqrt{\det(g(Y_i))}$ values on " + data_split_name + " data (without outliers)",
    percentile_bounds=(1, 99),
    Path_pictures=Path_pictures,
    filename="Curvature_functional_values_on_data_histogram.pdf",
    # honour the notebook-wide saving flag instead of always saving
    violent_saving=violent_saving,
)

In [None]:
# Histogram of the curvature functional R^2 * sqrt(det G) over the dense latent grid;
# outliers outside the (1, 99) percentile_bounds are presumably excluded by plot_histogram.
Curvature_functional_on_grid = np.square(curv_on_the_grid) * np.sqrt(metric_det_on_grid)

ricci_regularization.PlottingTools.plot_histogram(
    data=Curvature_functional_on_grid,  # the grid values (not the data values)
    # "\n" must be outside the raw string to render as a real line break
    title=r"$R^2 (Y) \sqrt{\det(g(Y))}$ values (without outliers)" + "\n"
          + rf"on grid of size ${linsize}\times{linsize}$",
    Path_pictures=Path_pictures,
    percentile_bounds=(1, 99),
    filename="Curvature_functional_values_on_grid_histogram.pdf",
    # honour the notebook-wide saving flag instead of always saving
    violent_saving=violent_saving,
)

# 3. Automatically scaled heatmaps over the latent space and scatterplots of datasets $\{X_i\}$

In [None]:
# Tick layout shared by the latent-space heatmaps.
# The shifts are added to the tick labels only; tick positions are unaffected.
xshift = 0.0
yshift = 0.0
numticks = 5
# Synthetic datasets get two decimals in the tick labels, everything else one.
tick_decimals = 2 if dataset_name == "Synthetic" else 1
plt.rcParams.update({'font.size': 16})

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15, 12), dpi=300)
fig.tight_layout(pad=2.0)

# Evenly spaced tick values across the plotting box, formatted with `tick_decimals` digits.
xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks)
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)
xtick_labels = [f"{value:.{tick_decimals}f}" for value in (xticks + xshift).tolist()]
ytick_labels = [f"{value:.{tick_decimals}f}" for value in (yticks + yshift).tolist()]

# Tick positions in pixel units of the linsize x linsize images.
ticks_places = np.linspace(0, 1, numticks)*(linsize-1)

curv_grid_image = curv_on_the_grid.reshape(linsize, linsize)
# SymLogNorm is linear for |R| below 1% of |mean R| over the grid.
linthresh_grid = abs(0.01*curv_on_the_grid.mean()).item()

# Top row: |R| (log scale) and R (symmetric log scale).
im1 = ax1.imshow(abs(curv_grid_image), origin="lower", cmap="jet",
                 norm=matplotlib.colors.LogNorm())
fig.colorbar(im1, ax=ax1, shrink=1, label="curvature abs value")
ax1.set_title("Absolute value of scalar curvature")

im2 = ax2.imshow(curv_grid_image, origin="lower", cmap="jet",
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_grid))
fig.colorbar(im2, ax=ax2, shrink=1, label="curvature")
ax2.set_title("Scalar curvature")

# Bottom row: sqrt(det G) and 0.5*tr(G), linear scales.
im3 = ax3.imshow(np.sqrt(metric_det_on_grid).reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im3, ax=ax3, shrink=1, label=r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow((0.5*metric_trace_on_grid).reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im4, ax=ax4, shrink=1, label=r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

axs = (ax1, ax2, ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/heatmaps_not_scaled.pdf', bbox_inches='tight', format='pdf')
plt.show()

Curvature heatmap only (for LaTeX)

In [None]:
# Stand-alone scalar-curvature heatmap (for the LaTeX document).
fig, ax = plt.subplots(figsize=(9, 9), dpi=200)
fig.tight_layout(pad=2.0)

# Tick values across the plotting box, their formatted labels, and pixel positions.
xticks = np.linspace(xcenter - 0.5 * xsize, xcenter + 0.5 * xsize, numticks)
yticks = np.linspace(ycenter - 0.5 * ysize, ycenter + 0.5 * ysize, numticks)
xtick_labels = [f"{value:.{tick_decimals}f}" for value in (xticks + xshift).tolist()]
ytick_labels = [f"{value:.{tick_decimals}f}" for value in (yticks + yshift).tolist()]
ticks_places = np.linspace(0, 1, numticks) * (linsize - 1)

# Symmetric log scale, linear for |R| below 1% of |mean R| over the grid.
curvature_symlog = matplotlib.colors.SymLogNorm(
    linthresh=abs(0.01 * curv_on_the_grid.mean()).item()
)
im = ax.imshow(curv_on_the_grid.reshape(linsize, linsize),
               origin="lower", cmap="jet", norm=curvature_symlog)
fig.colorbar(im, ax=ax, shrink=0.8)

ax.set_xticks(ticks_places, labels=xtick_labels)
ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/scalar_curvature_heatmap_{AE_setting_name}.pdf', bbox_inches='tight', format='pdf')
plt.show()


# 4. Scatterplots of data vs heatmaps over the whole latent space. Unique colorbar scaling

In [None]:
# Global color limits shared by the curvature scatter plots and heatmaps below.
# Every value is converted to a plain Python float with .item().
abs_curv_on_grid = abs(curv_on_the_grid)

# Extremes of the signed curvature over the grid.
max_curvature = curv_on_the_grid.max().item()
min_curvature = curv_on_the_grid.min().item()

# SymLogNorm draws |R| below this threshold (1% of |mean R|) on a linear scale.
linthresh_curvature = 0.01 * abs(curv_on_the_grid.mean()).item()

# Log-scale limits for |R|: the grid maximum, and 1% of the mean of |R| as the floor.
max_abs_curvature = abs_curv_on_grid.max().item()
min_abs_curvature = 0.01 * abs_curv_on_grid.mean().item()


In [None]:
# Scalar curvature R and |R|: data scatter plots (left) vs grid heatmaps (right).
# Each quantity uses one color scale, with limits taken from the shared-limits cell.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15, 12), dpi=300)
fig.tight_layout(pad=2.0)

# Tick values across the plotting box and their labels (shift applied to labels only).
xticks = np.linspace(xcenter - 0.5 * xsize, xcenter + 0.5 * xsize, numticks)
yticks = np.linspace(ycenter - 0.5 * ysize, ycenter + 0.5 * ysize, numticks)
xtick_labels = ['%.{0}f'.format(tick_decimals) % elem for elem in (xticks + xshift).tolist()]
ytick_labels = ['%.{0}f'.format(tick_decimals) % elem for elem in (yticks + yshift).tolist()]
# Tick positions in pixel units of the linsize x linsize heatmaps.
ticks_places = np.linspace(0, 1, numticks) * (linsize - 1)

# ax1: data colored by |R| (log scale shared with ax2).
ax1.set_title("Absolute value of scalar curvature")
p1 = ax1.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=abs(curvature_array),
                 alpha=1, s=size_of_points, marker='o', edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature, vmax=max_abs_curvature))
# ax= is explicit: with matplotlib < 3.6 fig.colorbar() without it attaches the
# colorbar to the current axes (ax4) rather than to the scatter's axes.
fig.colorbar(p1, ax=ax1, label="curvature abs value")

# ax2: |R| over the whole grid.
ax2.set_title("Absolute value of scalar curvature overall")
im1 = ax2.imshow(abs(curv_on_the_grid.reshape(linsize, linsize)), origin="lower", cmap="jet",
                 norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature, vmax=max_abs_curvature))
fig.colorbar(im1, ax=ax2, shrink=1, label="curvature abs value")

# ax3: data colored by R (symmetric log scale shared with ax4).
ax3.set_title("Scalar curvature")
p2 = ax3.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=curvature_array,
                 alpha=1, s=size_of_points, marker='o', edgecolor='none', cmap='jet',
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin=min_curvature, vmax=max_curvature))
fig.colorbar(p2, ax=ax3, label="curvature")

# ax4: R over the whole grid.
ax4.set_title("Scalar curvature overall")
im2 = ax4.imshow(curv_on_the_grid.reshape(linsize, linsize), origin="lower", cmap="jet",
                 norm=matplotlib.colors.SymLogNorm(linthresh=linthresh_curvature,
                                                   vmin=min_curvature, vmax=max_curvature))
fig.colorbar(im2, ax=ax4, shrink=1, label="curvature")

# Scatter plots: latent-coordinate limits, with ticks placed at the labelled values.
for ax in (ax1, ax3):
    ax.set_ylim(bottom, top)
    ax.set_xlim(left, right)
    ax.set_xticks(list(map(float, xtick_labels)), labels=xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels=ytick_labels)

# Heatmaps: ticks at pixel positions.
for ax in (ax2, ax4):
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/curvature_heatmaps.pdf', bbox_inches='tight', format='pdf')
plt.show()


For chapter 5

In [None]:
# ONLY for latex chapter 5
# Latent-space scatter colored by |R|, using the same log scale as the |R| heatmaps.
plt.figure(figsize=(6, 6), dpi=100)
plt.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=abs(curvature_array),
            alpha=1, s=size_of_points, marker='o',
            edgecolor='none', cmap='jet',
            norm=matplotlib.colors.LogNorm(vmin=min_abs_curvature,
                                           vmax=max_abs_curvature))
plt.xticks([])
plt.yticks([])
# Plotting box from the rescaling cell: [-1, 1]^2 when normalize_to_unit_square is True,
# [-pi, pi]^2 otherwise (a hard-coded [-1, 1]^2 would crop the unnormalized plot).
plt.xlim(left, right)
plt.ylim(bottom, top)
if violent_saving:
    plt.savefig(f'{Path_pictures}/manifold_plot_colored_by_curvature.pdf', bbox_inches='tight', format='pdf')
plt.show()


In [None]:
# ONLY for latex chapter 5
# Heatmap of |R| over the grid, with the shared logarithmic color scale and no ticks.
abs_curv_heatmap = abs(curv_on_the_grid.reshape(linsize, linsize))
abs_curv_log_norm = matplotlib.colors.LogNorm(vmin=min_abs_curvature, vmax=max_abs_curvature)

plt.figure(figsize=(7.5, 6), dpi=100)
plt.imshow(abs_curv_heatmap, origin="lower", cmap="jet", norm=abs_curv_log_norm)
plt.colorbar(label="curvature absolute value")
plt.xticks([])
plt.yticks([])
if violent_saving:
    plt.savefig(f'{Path_pictures}/curvature_heatmap.pdf', bbox_inches='tight', format='pdf')
plt.show()

## Metric losses: 

In [None]:
# sqrt(det G) and 0.5*tr(G): data scatter plots (left column) vs grid heatmaps (right).
# Tick positions/labels (ticks_places, xtick_labels, ytick_labels) come from the heatmap cells above.
fig, ((ax1, ax3), (ax2, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15, 12), dpi=300)
fig.tight_layout(pad=2.0)

# ax1: data colored by sqrt(|det G|); the upper color limit is the grid maximum.
ax1.title.set_text(r"$\sqrt{det(G)}$")
p = ax1.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                c=torch.sqrt(abs(metric_det_array)), alpha=1, s=size_of_points,
                marker='o', edgecolor='none', cmap='jet',
                vmax=np.sqrt(metric_det_on_grid.max()))
# ax= is explicit: with matplotlib < 3.6 fig.colorbar() without it attaches the
# colorbar to the current axes (ax4) rather than to the scatter's axes.
fig.colorbar(p, ax=ax1, label=r"$\sqrt{det(G)}$")

# ax2: data colored by half the trace of G; the upper color limit is the grid maximum.
ax2.title.set_text(r"0.5$\cdot$tr(G)")
q = ax2.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                c=0.5 * metric_trace_array, alpha=1, s=size_of_points,
                marker='o', edgecolor='none', cmap='jet',
                vmax=0.5 * metric_trace_on_grid.max().item())
fig.colorbar(q, ax=ax2, label=r"0.5$\cdot$tr(G)")

# ax3: sqrt(det G) over the whole grid.
im3 = ax3.imshow((np.sqrt(metric_det_on_grid)).reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None)
fig.colorbar(im3, ax=ax3, shrink=1, label=r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

# ax4: half the trace of G over the whole grid.
im4 = ax4.imshow((0.5 * metric_trace_on_grid).reshape(linsize, linsize),
                 origin="lower", cmap="jet", norm=None,
                 vmax=0.5 * metric_trace_on_grid.max().item())
fig.colorbar(im4, ax=ax4, shrink=1, label=r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

# Heatmaps: ticks at pixel positions.
for ax in (ax3, ax4):
    ax.set_xticks(ticks_places, labels=xtick_labels)
    ax.set_yticks(ticks_places, labels=ytick_labels)

# Scatter plots: latent-coordinate limits, with ticks placed at the labelled values.
for ax in (ax1, ax2):
    ax.set_ylim(bottom, top)
    ax.set_xlim(left, right)
    ax.set_xticks(list(map(float, xtick_labels)), labels=xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels=ytick_labels)

if violent_saving:
    plt.savefig(f'{Path_pictures}/metric_det_trace.pdf', bbox_inches='tight', format='pdf')
plt.show()

# 5. Merge pdfs

In [None]:
if build_report:
    # PDFs to merge, in report order. losses.pdf is not written by this notebook
    # (presumably by the training notebook).
    pdfs = [
            f'{Path_pictures}/recon_images.pdf',
            f'{Path_pictures}/losses.pdf',
            f'{Path_pictures}/init_colors_recon_loss.pdf',
            f'{Path_pictures}/curvature_heatmaps.pdf',
            f'{Path_pictures}/metric_det_trace.pdf',
            f'{Path_pictures}/encoder_Frobenius_norm_{AE_setting_name}.pdf',
            f'{Path_pictures}/Curvature_functional_values_on_data_histogram.pdf']

    # Some PDFs exist only in certain configurations (recon_images.pdf only for MNIST
    # datasets; nothing is saved when violent_saving is False). Skip missing files
    # instead of aborting the merge halfway through.
    existing_pdfs = []
    for pdf in pdfs:
        if os.path.isfile(pdf):
            existing_pdfs.append(pdf)
        else:
            print(f"Skipping missing file: {pdf}")

    if existing_pdfs:
        merger = PdfWriter()
        try:
            for pdf in existing_pdfs:
                merger.append(pdf)
            merger.write(f"{Path_pictures}/report_{AE_setting_name}.pdf")
        finally:
            # release the writer even if appending or writing fails
            merger.close()
    else:
        print("No PDFs found; report not written.")