# Imports

In [None]:
# prerequisites
%matplotlib inline
import torch
from torch.func import jacrev,jacfwd
import sklearn
from sklearn import datasets
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import math
import numpy as np
from tqdm.notebook import tqdm
import json
import os
import torch.func as TF

## Hyperparameters

In [None]:
#set_name = "Swissroll"
#set_name = "Synthetic"
set_name = "MNIST"

experiment_name = f"{set_name}_torus_AE"
experiment_number = 34
violent_saving = True # if False it will not save plots
Path_experiments = "../experiments/"
Path_pictures = f"../experiments/{experiment_name}/experiment{experiment_number}"
if violent_saving == True:
    # makedirs also creates ../experiments/<experiment_name>/ (and any other missing
    # parent) and is a no-op when the folder already exists, so re-running this
    # cell never fails on an existing plot folder.
    os.makedirs(Path_pictures, exist_ok=True)
Path_weights = "../nn_weights/"

d = 2         # latent space dimension
weights_loaded = False
weights_saved = True

In [None]:
# Architecture: "TorusAE" (fully connected) or "TorusConvAE" (convolutional)
architecture_type = "TorusAE"
#architecture_type = "TorusConvAE"

# ---- Loss weights -----------------------------------------------------------
mse_w = 1.0            # reconstruction weight (tried 1e1)
unif_w = 0.5e-3        # uniformity weight (tried 1e-3, 0.5, 5e0; 4e1 for mnist)
num_moments = 4        # number of Fourier modes in the uniformity loss

# Decoder contractive penalty: lambda * max(relu(trace(g) - delta))
lambda_contractive_decoder = 0.
delta_decoder = 2

# Encoder contractive penalty (tried lambda 1e-4 / 1e-5 / 10., delta 1e+1)
lambda_contractive_encoder = 0.
delta_encoder = 0.

eps = 0.0              # regularization added to g before inversion

# ---- Curvature --------------------------------------------------------------
diagnostic_mode = True
compute_curvature = True
curvature_penalization_mode = "mean"   # "mean" or "max"
curv_w = 0.1 if compute_curvature else 0.
delta_curv = 0.1

OOD_regime = False

# ---- Optimizer (shared by encoder and decoder) -------------------------------
lr = 1e-3
num_epochs = 40

# ---- Data loaders -------------------------------------------------------------
batch_size = 128       # was 256 for MNIST
split_ratio = 0.2
weight_decay = 0.
random_shuffling = False
random_seed = 0
Force_CPU = False

# Seed the global torch RNG for reproducibility
torch.manual_seed(random_seed)

In [None]:
# OOD sampling parameters
T_ood = 20 # 100 # period of OOD penalization
n_ood = 5 # number of OOD samples per point
sigma_ood = 2e-1 # sigma of OOD Gaussian samples: 2e-1 swissroll
N_extr = 16 # 32 batch size of extremal curvature points
r_ood = 1e-3 # 1e-2 decay factor
OOD_w = curv_w
start_ood = 1

## Dataset loading

In [None]:
#import sys
#sys.path.append('../') # have to go 1 level up
import ricci_regularization 

In [None]:
# Number of workers in DataLoader
# NOTE(review): num_workers is not passed to the DataLoaders below, so batches are
# currently loaded in the main process.
num_workers = 11

# Note: `datasets` here is torchvision.datasets (the torchvision import shadows the
# sklearn one); the swiss roll uses the fully qualified sklearn.datasets.
if set_name == "MNIST":
    D = 784
    k = 10 # number of classes
    # MNIST Dataset. test_dataset (official test set) is only used for plotting;
    # the train/test split below is carved out of the training set.
    train_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)

    set_parameters = {"k" : k}
elif set_name == "Synthetic":
    D = 784       # ambient dimension
    k = 3         # num of 2d planes in dim D
    n = 6*(10**3) # num of points in each plane
    shift_class = 0.0
    var_class = 1.0
    # Random shift between classes, proportional to intercl_var;
    # this has to be greater than 0.04
    intercl_var = 0.1
    torch.manual_seed(0) # reproducibility
    my_dataset = ricci_regularization.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class, intercl_var=intercl_var, var_class=var_class)

    train_dataset = my_dataset.create
    set_parameters = {
    "k" : k,
    "n" : n,
    "shift_class" : shift_class,
    "var_class" : var_class,
    "intercl_var" : intercl_var
    }
elif set_name == "Swissroll":
    D = 3
    sr_noise = 1e-6
    sr_numpoints = 18000 #k*n
    train_dataset =  sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise,random_state=random_seed)
    sr_points = torch.from_numpy(train_dataset[0]).to(torch.float32)
    # Colors = position along the roll, used as "labels" for plotting
    sr_colors = torch.from_numpy(train_dataset[1]).to(torch.float32)
    from torch.utils.data import TensorDataset
    train_dataset = TensorDataset(sr_points,sr_colors)
    set_parameters = {
    "sr_noise" : sr_noise,
    "sr_numpoints" : sr_numpoints
    }
else:
    # Fail here instead of with an obscure NameError on train_dataset below
    raise ValueError(f"Unknown set_name {set_name!r}; expected 'MNIST', 'Synthetic' or 'Swissroll'")

# Train/test split. Taking the train size as the remainder guarantees the two
# lengths sum to len(train_dataset), which random_split requires; rounding both
# int(m - m*r) and int(m*r) down could otherwise lose a sample.
m = len(train_dataset)
n_test = int(m*split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - n_test, n_test])

test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=random_shuffling) # was true
batches_per_epoch = len(train_loader)

# Architecture

In [None]:
# Re-seed so the initial weights are identical on every run
torch.manual_seed(random_seed)

# Arguments shared by both architectures
ae_kwargs = dict(x_dim=D, h_dim1=512, h_dim2=256, z_dim=d)
if architecture_type == "TorusConvAE":
    torus_ae = ricci_regularization.Architectures.TorusConvAE(**ae_kwargs, pixels=28)
else:
    torus_ae = ricci_regularization.Architectures.TorusAE(**ae_kwargs)

if torch.cuda.is_available():
    torus_ae.cuda()

### Loading the saved weights

In [None]:
if weights_loaded == True:
    # Weights of the previous experiment number
    PATH_weights_loaded = f'../nn_weights/{set_name}_exp{experiment_number-1}.pt'
    # map_location="cpu" lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one; load_state_dict then copies the tensors onto whatever device
    # torus_ae's parameters already live on.
    torus_ae.load_state_dict(torch.load(PATH_weights_loaded, map_location="cpu"))
    torus_ae.eval()

## Optimizer and loss function


In [None]:
# Pick the GPU when available, unless Force_CPU is set. The flag is applied
# before printing, so the message reports the device actually used.
if torch.cuda.is_available() and not Force_CPU:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Selected device: {device}')

# Move the AE to the selected device *before* building the optimizer, as the
# torch.optim docs recommend, so the optimizer is created over the final parameters.
torus_ae.to(device)

# One Adam optimizer for both encoder and decoder parameters
optimizer = optim.Adam(torus_ae.parameters(),lr=lr, weight_decay=weight_decay)

In [None]:
def curv_func(encoded_data, function=torus_ae.decoder_torus):
    """Batch estimate of the squared-curvature functional.

    Averages sqrt(det g) * Sc**2 over the points of ``encoded_data``. Here ``g``
    is the metric from ``ricci_regularization.metric_jacfwd_vmap`` and ``Sc`` the
    values from ``ricci_regularization.Sc_jacfwd_vmap``, both evaluated for
    ``function``.

    Not called in this notebook: loss_function computes the same weighted
    quantity inline in "mean" mode.
    """
    metric = ricci_regularization.metric_jacfwd_vmap(encoded_data,
                                                     function=function)
    volume_element = torch.sqrt(torch.det(metric))
    scalar_curv = ricci_regularization.Sc_jacfwd_vmap(encoded_data,
                                                      function=function)
    n_points = metric.shape[0]
    # Alternative (disabled): weight Sc**4 instead of Sc**2
    return (1/n_points)*(volume_element*torch.square(scalar_curv)).sum()
"""
# minimizing |g-I|_F
def curv_func(encoded_data, function=torus_ae.decoder_torus):
    metric_on_data = ricci_regularization.metric_jacfwd_vmap(encoded_data,
                                           function=function)
    N = metric_on_data.shape[0]
    func = (1/N)*(metric_on_data-torch.eye(d)).norm(dim=(1,2)).sum()
    return func
"""
    
# Loss = MSE + uniform_loss + curv_loss
#  where the uniform_loss penalizes the moduli of the first Fourier modes of the
#  empirical distribution of the latent angles.
#  This requires batch size to be in the range of CLT.
#
# Inputs:
#   recon_data: reconstructed data via decoder
#   data: original data
#   z: latent variable
def loss_function(recon_data, data, z, decoder,device=device,compute_curvature = compute_curvature,diagnostic_mode = diagnostic_mode):
    """Compute the training losses and (optionally) geometric diagnostics.

    Args:
        recon_data: decoder output, compared by MSE with ``data.view(-1, D)``.
        data: input batch, flattened to ``(-1, D)``.
        z: latent embedding; columns ``0:d`` and ``d:2*d`` are read as the two
            trigonometric components of the ``d`` latent angles.
        decoder: map from latent angles to data space, differentiated by the
            ricci_regularization helpers to get the metric and curvature.
        device: device of the identity matrix used in the ``g^{-1}`` diagnostic.
        compute_curvature: if True, add ``"curvature_loss"`` (plus
            ``"curv_squared_mean"`` / ``"curv_squared_max"`` in diagnostic mode).
        diagnostic_mode: if True, add metric, Jacobian and contractive-penalty
            entries.

    Returns:
        dict mapping loss names to scalar tensors.

    Raises:
        ValueError: if ``compute_curvature`` is True and
            ``curvature_penalization_mode`` is neither "mean" nor "max".
    """
    flat_data = data.view(-1, D)
    MSE = F.mse_loss(recon_data, flat_data, reduction='mean')

    # ---- Uniformity loss ----------------------------------------------------
    # Term k is |mean(cos k t)|^2 + |mean(sin k t)|^2 summed over the d angles,
    # i.e. |E[exp(i k t)]|^2, which is zero for the uniform law on the circle.
    # Assuming (z_cos, z_sin) = (cos t, sin t), the harmonics are Chebyshev
    # polynomials: cos(k t) = T_k(cos t) and sin(k t) = sin t * U_{k-1}(cos t).
    # NOTE(review): the commented-out circle2anglevectorized helper reads cosines
    # from columns 0:d. The loss does not depend on which half holds which,
    # because |E[exp(i k t)]| is unchanged under t -> pi/2 - t.
    z_sin = z[:, 0:d]
    z_cos = z[:, d:2*d]

    # k = 1: both halves at once
    mode1 = torch.mean(z, dim=0)
    mode1 = torch.sum(mode1*mode1)

    # k = 2: T_2(c) = 2c^2 - 1, U_1(c) = 2c
    mode2_1 = torch.mean(2*z_cos*z_cos - 1, dim=0)
    mode2_2 = torch.mean(2*z_sin*z_cos, dim=0)
    mode2 = torch.sum(mode2_1*mode2_1) + torch.sum(mode2_2*mode2_2)

    unif_loss = mode1 + mode2

    if num_moments > 2:
        # k = 3: T_3(c) = 4c^3 - 3c, U_2(c) = 4c^2 - 1.
        # The sine part previously used U_3, i.e. sin(4t).
        mode3_1 = torch.mean(4*z_cos**3 - 3*z_cos, dim=0)
        mode3_2 = torch.mean(z_sin*(4*z_cos**2 - 1), dim=0)
        unif_loss = unif_loss + torch.sum(mode3_1*mode3_1) + torch.sum(mode3_2*mode3_2)

    if num_moments > 3:
        # k = 4: T_4(c) = 8c^4 - 8c^2 + 1, U_3(c) = 8c^3 - 4c.
        # The sine part previously used U_4, i.e. sin(5t).
        mode4_1 = torch.mean(8*z_cos**4 - 8*z_cos**2 + 1, dim=0)
        mode4_2 = torch.mean(z_sin*(8*z_cos**3 - 4*z_cos), dim=0)
        unif_loss = unif_loss + torch.sum(mode4_1*mode4_1) + torch.sum(mode4_2*mode4_2)

    dict_losses = {
        "MSE": MSE,
        "uniform_loss": unif_loss,
    }

    # ---- Metric on the batch (shared by curvature and diagnostics) ---------
    if compute_curvature == True or diagnostic_mode == True:
        # Detached latent points: the geometric terms below do not back-propagate
        # into the encoder through these points.
        encoded_points_no_grad = torus_ae.encoder2lifting(flat_data).detach()
        metric_on_data = ricci_regularization.metric_jacfwd_vmap(encoded_points_no_grad,
                                           function=decoder)
        det_on_data = torch.det(metric_on_data)

    # ---- Curvature loss -----------------------------------------------------
    if compute_curvature == True:
        Sc_on_data = ricci_regularization.Sc_jacfwd_vmap(encoded_points_no_grad,
                                           function=decoder,device=device,eps=eps)
        # Sc^2 weighted by the volume element sqrt(det g)
        weighted_curv = torch.sqrt(det_on_data)*torch.square(Sc_on_data)
        if curvature_penalization_mode == "mean":
            curv_loss = weighted_curv.mean()
        elif curvature_penalization_mode == "max":
            # Only the excess over delta_curv is penalized
            curv_loss = F.relu(weighted_curv - delta_curv).max()
        else:
            raise ValueError(f"Unknown curvature_penalization_mode {curvature_penalization_mode!r}; "
                             "expected 'mean' or 'max'")
        dict_losses["curvature_loss"] = curv_loss

        if diagnostic_mode == True:
            curv_squared = torch.square(Sc_on_data)
            dict_losses["curv_squared_mean"] = curv_squared.mean()
            dict_losses["curv_squared_max"] = curv_squared.max()

    # ---- Diagnostics ----------------------------------------------------------
    if diagnostic_mode == True:
        # Norm of the (eps-regularized) inverse metric
        g_inv_train_batch = torch.linalg.inv(metric_on_data + eps*torch.eye(d).to(device))
        g_inv_norm_train_batch = torch.linalg.matrix_norm(g_inv_train_batch)
        dict_losses["g_inv_norm_mean"] = torch.mean(g_inv_norm_train_batch)
        dict_losses["g_inv_norm_max"] = torch.max(g_inv_norm_train_batch)
        dict_losses["g_det_mean"] = det_on_data.mean()
        dict_losses["g_det_max"] = det_on_data.max()
        dict_losses["g_det_min"] = det_on_data.min()

        # Decoder Jacobian size via trace(g); if g = J^T J (pull-back metric,
        # TODO confirm) this is the squared Frobenius norm of J.
        decoder_jac_norm = torch.func.vmap(torch.trace)(metric_on_data)
        dict_losses["decoder_jac_norm_mean"] = decoder_jac_norm.mean()
        dict_losses["decoder_jac_norm_max"] = decoder_jac_norm.max()
        # Contractive penalty: worst excess of trace(g) over delta_decoder
        dict_losses["decoder_contractive_loss"] = F.relu(decoder_jac_norm - delta_decoder).max()

        # Encoder Jacobian. NOTE: very suboptimal for large D.
        metric_array_encoder = ricci_regularization.metric_jacrev_vmap(flat_data,
                                                                       function=torus_ae.encoder_torus,
                                                                       latent_space_dim=D)
        encoder_jac_norm = torch.func.vmap(torch.trace)(metric_array_encoder)
        dict_losses["encoder_jac_norm_mean"] = encoder_jac_norm.mean()
        dict_losses["encoder_jac_norm_max"] = encoder_jac_norm.max()
        dict_losses["encoder_contractive_loss"] = F.relu(encoder_jac_norm - delta_encoder).max()

    return dict_losses

In [None]:
def train(epoch=1,batch_idx = 0,dict_loss_arrays=None,diagnostic_mode = False):
    """Run one training epoch of torus_ae over train_loader.

    Args:
        epoch: epoch number, only used in the printed header.
        batch_idx: global batch counter carried over from previous epochs;
            0 starts a fresh loss history.
        dict_loss_arrays: per-batch loss history (loss name -> list of floats)
            from previous epochs; replaced by a fresh dict when batch_idx == 0.
        diagnostic_mode: forwarded to loss_function. When True, the contractive
            penalties are added to the loss and the progress bar also shows the
            encoder penalty and the curvature.

    Returns:
        Tuple (batch_idx, dict_loss_arrays) after this epoch.
    """
    # None default instead of a mutable {} shared between calls
    if batch_idx == 0 or dict_loss_arrays is None:
        dict_loss_arrays = {}
    torus_ae.train()
    print("Epoch %d"%epoch)
    t = tqdm( train_loader, desc="Train", position=0 )

    # OOD-regime state, initialized from the first batch of *this* call.
    # Previously it was only initialized when the global batch_idx was 0, so from
    # the second epoch on these locals were read before assignment.
    extreme_curv_points_tensor = None
    extreme_curv_value_tensor = None

    for (data, labels) in t:
        data = data.to(device)
        optimizer.zero_grad()

        # Forward
        recon_batch, z = torus_ae(data)
        dict_losses = loss_function(recon_batch, data, z,decoder=torus_ae.decoder_torus,diagnostic_mode=diagnostic_mode)
        loss = mse_w*dict_losses["MSE"] + unif_w*dict_losses["uniform_loss"]

        # loss_function only returns the contractive penalties in diagnostic mode
        if diagnostic_mode == True:
            loss = (loss
                    + lambda_contractive_decoder*dict_losses["decoder_contractive_loss"]
                    + lambda_contractive_encoder*dict_losses["encoder_contractive_loss"])

        if compute_curvature == True:
            loss = loss + curv_w*dict_losses["curvature_loss"]

        # OOD regime (optional)
        if OOD_regime == True:
            if extreme_curv_points_tensor is None:
                extreme_curv_points_tensor = torus_ae.encoder2lifting(data.view(-1,D)[:N_extr]).detach()
                extreme_curv_value_tensor = ricci_regularization.Sc_jacfwd_vmap(extreme_curv_points_tensor, function=torus_ae.decoder_torus,eps=eps)
            extreme_curv_points_tensor, extreme_curv_value_tensor = ricci_regularization.find_extreme_curvature_points(data_batch=data,
                                extreme_curv_points_tensor=extreme_curv_points_tensor,
                                extreme_curv_value_tensor=extreme_curv_value_tensor,batch_idx=batch_idx,
                                encoder=torus_ae.encoder2lifting,decoder=torus_ae.decoder_torus,r_ood=r_ood,N_extr=N_extr,output_dim=D)

            # Every T_ood batches (once past start_ood), penalize curvature at OOD samples
            if (batch_idx % T_ood == 0) and (batch_idx > start_ood):
                OOD_curvature_loss = ricci_regularization.OODTools.curv_loss_on_OOD_samples(extreme_curv_points_tensor=extreme_curv_points_tensor,
                                                                                            decoder=torus_ae.decoder_torus,
                                                                                            sigma_ood=sigma_ood,n_ood=n_ood,N_extr=N_extr,
                                                                                            latent_space_dim=d)
                if diagnostic_mode == True:
                    print("Curvature functional at OOD points", OOD_curvature_loss)
                # Added to (not substituted for) the batch loss, so OOD_w is a
                # relative weight like the other *_w coefficients.
                loss = loss + OOD_w * OOD_curvature_loss

        # Backpropagate
        loss.backward()
        optimizer.step()

        # Append this batch's losses to the history
        for key, value in dict_losses.items():
            dict_loss_arrays.setdefault(key, []).append(value.item())

        # Progress bar: means over (at most) the last epoch's worth of batches
        batch_idx += 1
        MSE_mean_per_epoch = np.array(dict_loss_arrays["MSE"])[-batches_per_epoch:].mean()
        uniform_loss_mean_per_epoch = np.array(dict_loss_arrays["uniform_loss"])[-batches_per_epoch:].mean()
        if compute_curvature == True:
            curvature_loss_mean_per_epoch = np.array(dict_loss_arrays["curvature_loss"])[-batches_per_epoch:].mean()
        else:
            curvature_loss_mean_per_epoch = "nan"
        if diagnostic_mode == True:
            encoder_contractive_loss_mean_per_epoch = np.array(dict_loss_arrays["encoder_contractive_loss"])[-batches_per_epoch:].mean()
            t.set_description_str(desc=f"MSE:{MSE_mean_per_epoch}, Uniform:{uniform_loss_mean_per_epoch}, Encoder_penalty:{encoder_contractive_loss_mean_per_epoch}, Curvature:{curvature_loss_mean_per_epoch}.\n")
        else:
            t.set_description_str(desc=f"MSE:{MSE_mean_per_epoch}, Uniform:{uniform_loss_mean_per_epoch}")
    return batch_idx, dict_loss_arrays

def test(mse_loss_array=None, uniform_loss_array=None, curvature_loss_array=None,g_inv_loss_array=None,curvature_squared_array=None):
    """Evaluate torus_ae on test_loader and print the mean test losses.

    Per-batch MSE, uniform and curvature losses are appended to the given lists
    (fresh lists when omitted). ``g_inv_loss_array`` and
    ``curvature_squared_array`` are accepted for backward compatibility but unused.

    Returns:
        Tuple (mse_loss_array, uniform_loss_array, curvature_loss_array).
    """
    # None defaults: the previous [] defaults were shared between calls, so each
    # call kept appending to (and averaging over) earlier calls' results.
    if mse_loss_array is None:
        mse_loss_array = []
    if uniform_loss_array is None:
        uniform_loss_array = []
    if curvature_loss_array is None:
        curvature_loss_array = []

    # Model and batches must be on the same device; previously the data was sent
    # to CPU regardless of where the model lived.
    torus_ae.to(device)
    torus_ae.eval()
    with torch.no_grad():
        t = tqdm( test_loader, desc="Test", position=1 )
        for data, _ in t:
            data = data.to(device)
            recon_batch, z = torus_ae(data)
            dict_losses = loss_function(recon_batch, data, z,decoder=torus_ae.decoder_torus)
            if compute_curvature == True:
                curvature_loss = dict_losses["curvature_loss"]
            else:
                curvature_loss = torch.zeros(1)

            mse_loss_array.append(dict_losses["MSE"].item())
            uniform_loss_array.append(dict_losses["uniform_loss"].item())
            curvature_loss_array.append(curvature_loss.item())
    print(f"Test losses. \nMSE:{np.array(mse_loss_array).mean()}, Uniform:{np.array(uniform_loss_array).mean()}, Curvature:{np.array(curvature_loss_array).mean()}.\n")
    return mse_loss_array, uniform_loss_array, curvature_loss_array

# Training

In [None]:
# Global batch counter and per-batch loss history, carried across epochs
batch_idx = 0
dict_loss_arrays = {}

# Launch: one train() call per epoch
for epoch in range(1, num_epochs + 1):
    torus_ae.to(device)
    batch_idx, dict_loss_arrays = train(
        epoch=epoch,
        batch_idx=batch_idx,
        dict_loss_arrays=dict_loss_arrays,
        diagnostic_mode=diagnostic_mode,
    )

    if diagnostic_mode == True:
        dict2print = ricci_regularization.PlottingTools.translate_dict(
            dict2print=dict_loss_arrays,
            include_curvature_plots=compute_curvature,
            eps=eps,
        )
        ricci_regularization.PlottingTools.plotsmart(dict2print)

    if set_name == "MNIST":
        # Plot reconstructions with the model moved to CPU; it goes back to
        # `device` at the start of the next epoch.
        torus_ae.cpu()
        ricci_regularization.PlottingTools.plot_ae_outputs(
            test_dataset=test_dataset,
            encoder=torus_ae.encoder2lifting,
            decoder=torus_ae.decoder_torus,
        )


# Test losses and $R^2$

In [None]:
def compute_R_squared_losses(data_loader=test_loader):
    """Average MSE, uniform and curvature losses over a loader and compute R^2.

    R^2 = 1 - MSE / Var(inputs), where Var is torch.var over all input values
    pooled together. The curvature loss is always computed here, independently
    of the global compute_curvature flag.

    Args:
        data_loader: loader yielding (data, labels) batches; defaults to
            test_loader. Previously this argument was ignored and test_loader
            was always used.

    Returns:
        Tuple (mse_loss, unif_loss, curv_loss, R_squared) of CPU tensors.
    """
    n_batches = len(data_loader)
    mse_loss = 0
    curv_loss = 0
    unif_loss = 0
    input_batches = []
    torus_ae.to(device)
    for data, _ in data_loader:
        data = data.to(device)
        input_batches.append(data.cpu())
        # A single forward pass, so recon and z come from the same evaluation
        recon, z = torus_ae(data)
        dict_losses = loss_function(recon, data, z,
                                    device=device,
                                    decoder=torus_ae.decoder_torus,
                                    compute_curvature=True,diagnostic_mode=False)
        mse_loss += dict_losses["MSE"].cpu().detach()/n_batches
        unif_loss += dict_losses["uniform_loss"].cpu().detach()/n_batches
        curv_loss += dict_losses["curvature_loss"].cpu().detach()/n_batches
    input_dataset_tensor = torch.cat(input_batches).view(-1,D)
    var = torch.var(input_dataset_tensor.flatten())
    R_squared = 1 - mse_loss/var
    return mse_loss, unif_loss, curv_loss, R_squared

In [None]:
# Test-set losses and R^2 (pass train_loader instead to get the train-set values)
test_mse, test_unif, test_curv, test_R_squared = compute_R_squared_losses(data_loader=test_loader)

In [None]:
# Report the test-set losses and R^2
test_summary = f"Test losses:\nmse:{test_mse}, unif_loss:{test_unif}, curv_loss:{test_curv}"
print(test_summary)
r2_value = test_R_squared.item()
print(f"R_squared: {r2_value:.4f}")

In [None]:
# Disabled alternative R^2 computation with torcheval's R2Score, kept as a string
# literal so it never runs.
# NOTE(review): `input_dataset_tensor` only exists inside compute_R_squared_losses and
# `recon_dataset_tensor` is never defined in this notebook. Also, the original data is
# passed as `input` and the reconstruction as `target`, while torcheval's
# R2Score.update(input, target) expects predictions as `input`. Confirm the argument
# order before re-enabling.
"""
import torch
from torcheval.metrics import R2Score
R_squared = R2Score()#(multioutput="raw_values")
input = input_dataset_tensor.flatten()
target = recon_dataset_tensor.flatten()
R_squared.update(input, target)
R_squared.compute()#.shape
"""

## Saving the model

In [None]:
# Save the trained weights. PATH_vae is defined in both branches because the JSON
# config cell reads it; it is None when nothing was written (recorded as null).
if weights_saved == True:
    PATH_vae = f'../nn_weights/{set_name}_exp{experiment_number}.pt'
    torch.save(torus_ae.state_dict(), PATH_vae)
else:
    PATH_vae = None

## Losses plot

In [None]:
# Loss plotting: in diagnostic mode use the translated dictionary built during
# training (dict2print), otherwise the raw per-batch history.
plotting = ricci_regularization.PlottingTools
fig, axes = (plotting.PlotSmartConvolve(dict2print) if diagnostic_mode == True
             else plotting.plotfromdict(dict_loss_arrays))
if violent_saving == True:
    fig.savefig(f"{Path_pictures}/losses_exp{experiment_number}.pdf",bbox_inches='tight',format="pdf")

In [None]:
# Disabled 9-panel loss plot, kept as a string literal so it never runs.
# NOTE(review): mse_loss_array, curvature_loss_array and g_inv_meanperbatch_array are
# not defined at notebook level, so this would fail if re-enabled as-is.
"""
fig,axes = ricci_regularization.PlottingTools.plot9losses(mse_loss_array,curvature_loss_array,g_inv_meanperbatch_array)
if violent_saving == True:
    fig.savefig(f"{Path_pictures}/9losses_exp{experiment_number}.pdf",bbox_inches='tight',format="pdf")
"""

## Torus latent space

In [None]:
#inspiration for torus_ae.encoder2lifting
# Reference helper kept as a string literal (never executed). It recovers angles
# phi in [-pi, pi] from the embedding, reading cos(phi) from columns 0:d and
# sin(phi) from columns d:2d; sgn(asin(s)) has the same sign as s.
# NOTE: when sin(phi) == 0 the sign factor is 0, so phi = pi would be mapped to 0.
"""
def circle2anglevectorized(zLatentTensor,d = d):
    cosphi = zLatentTensor[:, 0:d]
    sinphi = zLatentTensor[:, d:2*d]
    phi = torch.acos(cosphi)*torch.sgn(torch.asin(sinphi))
    return phi
"""


In [None]:
# Encode the whole training set on CPU for the latent-space plot.
# torch.no_grad(): these outputs are only used for plotting. Without it, an
# autograd graph was kept alive for every batch of the training set.
torus_ae.cpu()
colorlist = []
enc_list = []
input_dataset_list = []
recon_dataset_list = []
with torch.no_grad():
    for (data, labels) in tqdm( train_loader, position=0 ):
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        enc_list.append(torus_ae.encoder2lifting(data.view(-1,D)))
        colorlist.append(labels)

In [None]:
# Stack the per-batch lists into single tensors
input_dataset, recon_dataset, encoded_points, color_array = (
    torch.cat(chunks)
    for chunks in (input_dataset_list, recon_dataset_list, enc_list, colorlist)
)
color_array = color_array.detach()
# Latent angles divided by pi for plotting (presumably into [-1, 1] — TODO confirm
# the range returned by encoder2lifting)
encoded_points_no_grad = encoded_points.detach()/math.pi

In [None]:
# Scatter of the latent angles colored by label (class, or roll position for the swiss roll)
fig, ax = plt.subplots(figsize=(8, 6))
my_cmap = "jet" if set_name == "Swissroll" else ricci_regularization.PlottingTools.discrete_cmap(k, 'jet')
latent_scatter = ax.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1],
                            c=color_array, marker='o', edgecolor='none', cmap=my_cmap)

# Discrete class colorbar only for labelled datasets
if set_name in ["Synthetic", "MNIST"]:
    fig.colorbar(latent_scatter, ticks=range(k))
ax.grid(True)
if violent_saving == True:
    fig.savefig(f"{Path_pictures}/latent_space_exp{experiment_number}.pdf",bbox_inches='tight',format="pdf")
plt.show()

In [None]:
import json
# Collect the full experiment configuration and test results into one JSON file
json_config = {
    "experiment_name": experiment_name,
    "experiment_number": experiment_number,
    "dataset":
    {
        "name": set_name,
        "parameters": set_parameters,
    },
    "architecture" :
    {
        "name":architecture_type,
        "input_dim": D,
        "latent_dim": d
    },
    "optimization_parameters":
    {
        "learning_rate": lr,
        "batch_size": batch_size,
        "split_ratio": split_ratio,
        "num_epochs": num_epochs,
        "weight_decay": weight_decay,
        "random_shuffling":random_shuffling,
        "random_seed": random_seed,
        "device": device.type
    },
    "losses":
    {
        "mse_w": mse_w,
        "unif_w": unif_w,
        "Number of moments used": num_moments,
        "curv_w": curv_w,
        "delta_curv": delta_curv,
        "curvature_penalization_mode": curvature_penalization_mode,
        "g_inv regularization eps": eps,
        "lambda_contractive_encoder": lambda_contractive_encoder,
        "delta_encoder" : delta_encoder,
        "lambda_contractive_decoder" : lambda_contractive_decoder,
        "delta_decoder" : delta_decoder,
        "diagnostic_mode": diagnostic_mode,
        "compute_curvature": compute_curvature
    },
    "OOD_parameters":
    {
        "OOD_regime": OOD_regime,
        "start_ood":start_ood,
        "T_ood":T_ood,
        "n_ood":n_ood,
        "sigma_ood":sigma_ood,
        "N_extr":N_extr,
        "r_ood": r_ood,
        "OOD_w":OOD_w
    },
    "training_results_on_test_data":
    {
        "R^2": test_R_squared.item(),
        "mse_loss": test_mse.item(),
        "unif_loss": test_unif.item(),
        "curv_loss": test_curv.item()
    },
    "Path_pictures": Path_pictures,
    "Path_weights": Path_weights,
    "Path_experiments": Path_experiments,
    # PATH_vae only exists when the weights were saved; this conditional avoids a NameError otherwise
    "weights_saved_at": PATH_vae if weights_saved == True else None
}
if weights_loaded == True:
    json_config["weights_loaded_from"] = PATH_weights_loaded
# Save dictionary to JSON file
with open(f'{Path_experiments}/{experiment_name}exp{experiment_number}.json', 'w') as json_file:
    json.dump(json_config, json_file, indent=4)