# Torus AE training

This notebook trains the autoencoder (AE).

The AE consists of the encoder $\Phi$ and the decoder $\Psi$.
The latent space of the AE is topologically a $d$-dimensional torus $\mathcal{T}^d$, i.e., it can be viewed as the periodic box $[-\pi, \pi]^d$. We define a Riemannian metric on the latent space as the pull-back of the Euclidean metric of the output space $\mathbb{R}^D$ by the decoder $\Psi$:
\begin{equation}
    g = \nabla \Psi ^* \nabla \Psi \ .
\end{equation}
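As a quick illustration of this definition (not the notebook's implementation, which relies on `ricci_regularization.metric_jacfwd_vmap` and automatic differentiation), the sketch below approximates the Jacobian $\nabla\Psi$ by central finite differences and forms $g = \nabla\Psi^{\top}\nabla\Psi$. The decoders `flat_torus` and `scaled_torus` are hypothetical toy maps $\mathcal{T}^2 \to \mathbb{R}^4$, chosen because their metrics are known in closed form.

```python
import numpy as np

def jacobian_fd(psi, z, h=1e-5):
    """Central finite-difference Jacobian of psi: R^d -> R^D at z; returns shape (D, d)."""
    z = np.asarray(z, dtype=float)
    columns = []
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = h
        columns.append((psi(z + step) - psi(z - step)) / (2 * h))
    return np.stack(columns, axis=1)

def pullback_metric(psi, z):
    """Pull-back of the Euclidean metric: g(z) = J(z)^T J(z), a (d, d) matrix."""
    J = jacobian_fd(psi, z)
    return J.T @ J

def flat_torus(z):
    # Hypothetical toy decoder T^2 -> R^4: product of two unit circles.
    return np.array([np.cos(z[0]), np.sin(z[0]), np.cos(z[1]), np.sin(z[1])])

def scaled_torus(z):
    # Same map with the first circle stretched to radius 3.
    return np.array([3 * np.cos(z[0]), 3 * np.sin(z[0]), np.cos(z[1]), np.sin(z[1])])

z0 = np.array([0.3, -1.2])
g_flat = pullback_metric(flat_torus, z0)
g_scaled = pullback_metric(scaled_torus, z0)
```

For `flat_torus` the metric is the identity at every latent point, while stretching the first circle by a factor 3 gives $g = \mathrm{diag}(9, 1)$.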


One can switch between regimes with the flags `yaml_config["training_mode"]["diagnostic_mode"]`, `yaml_config["training_mode"]["compute_curvature"]` and `yaml_config["training_mode"]["OOD_regime"]`.

If `yaml_config["training_mode"]["diagnostic_mode"] == True`, the following quantities are plotted: MSE, $\mathcal{L}_\mathrm{unif}$, $\mathcal{L}_\mathrm{curv}$, $\det(g)$, $\|g_{reg}^{-1}\|_F$, $\|\nabla \Psi \|^2_F$, $\|\nabla \Phi \|^2_F$, where:
\begin{equation*}
\mathcal{L}_\mathrm{curv} := \int_M R^2 \mu \ ,
\end{equation*}

\begin{equation*}
\mathcal{L}_\mathrm{unif} := \sum\limits_{k=1}^{m} |\int_M z^k  \mu_N (dz) |^2 \ ,
\end{equation*}
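where $M = \mathcal{T}^d$ is the latent space, $R$ denotes the scalar curvature of $g$ (see https://en.wikipedia.org/wiki/Scalar_curvature), and $\mu_N = \Phi_{\#} \left( \frac{1}{N}\sum_{j=1}^{N} \delta_{X_j} \right)$ is the push-forward by the encoder $\Phi$ of the empirical measure of the dataset $X_1, \dots, X_N$; thus $\mu_N$ is a measure on $\mathcal{T}^d$. In the `"mean"` penalization mode, the code estimates $\mathcal{L}_\mathrm{curv}$ on each batch as the average of $\sqrt{\det g}\, R^2$ over the encoded points. For $\mathcal{L}_\mathrm{unif}$, each latent angle $\xi$ is identified with $z = e^{i \xi}$ (coordinatewise), and $\alpha_k = \frac{1}{N} \sum_{j=1}^{N} z_j^k = \int_M z^k \mu_N(dz)$ is the empirical estimator of the $k$-th moment of the data distribution in the latent space, so that $\mathcal{L}_\mathrm{unif} = \sum_{k=1}^{m} |\alpha_k|^2$, with $|\cdot|^2$ summed over the $d$ coordinates. The code always uses $k = 1, 2$ and adds $k = 3$ and $k = 4$ when `yaml_config["loss_settings"]["num_moments"]` exceeds 2 and 3, respectively.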
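A minimal NumPy sketch of $\mathcal{L}_\mathrm{unif}$, assuming the latent codes are given directly as angles $\xi_j \in [-\pi, \pi)^d$ (the notebook instead works with their sines and cosines). `unif_loss` and the toy point sets are hypothetical names used only for illustration.

```python
import numpy as np

def unif_loss(angles, m):
    """L_unif = sum_{k=1}^m sum over coordinates of |alpha_k|^2, alpha_k = mean_j exp(i k xi_j).

    angles: array of shape (N, d) with latent angles in [-pi, pi).
    """
    angles = np.asarray(angles, dtype=float)
    loss = 0.0
    for k in range(1, m + 1):
        alpha_k = np.exp(1j * k * angles).mean(axis=0)  # empirical k-th moment, shape (d,)
        loss += np.sum(np.abs(alpha_k) ** 2)
    return loss

N, d, m = 64, 2, 4
grid = -np.pi + 2 * np.pi * np.arange(N) / N                # evenly spaced angles
uniform_like = np.stack([grid, np.roll(grid, 17)], axis=1)  # both coordinates spread out
collapsed = np.full((N, d), 0.5)                            # every point at the same angle
diagonal = np.stack([grid, grid], axis=1)                   # xi_2 = xi_1: not uniform on T^2
```

Evenly spread angles give $\mathcal{L}_\mathrm{unif} \approx 0$, and collapsing all points onto one angle gives the maximal value $m d$. The `diagonal` set also gives $\mathcal{L}_\mathrm{unif} \approx 0$ although it is concentrated on the diagonal $\xi_2 = \xi_1$ of $\mathcal{T}^2$: its mixed moment $\frac{1}{N}\sum_j e^{i(\xi_{j,1} - \xi_{j,2})}$ equals 1, a quantity this loss does not penalize.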

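If $\xi \sim \mathcal{U}[-\pi, \pi]$ and $z = e^{i \xi}$, then $\mathbb{E}[z^k] = 0$ for every $k \geq 1$. Conversely, a probability measure on the circle is determined by these Fourier coefficients (those for negative $k$ are their complex conjugates), so for $d = 1$, if $\alpha_k \to 0$ for every $k \geq 1$ (i.e. $\mathcal{L}_\mathrm{unif} \to 0$ for every $m$), the data distribution in the latent space converges weakly to the uniform distribution. For $d > 1$ the loss involves powers of each latent coordinate separately, so the same argument applies to the one-dimensional marginals rather than to the joint distribution on $\mathcal{T}^d$.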
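In `loss_function` the moments are not computed with complex numbers but as polynomials in $\cos\xi$ and $\sin\xi$: $\mathrm{Re}\, z^k = \cos k\xi = T_k(\cos\xi)$ and $\mathrm{Im}\, z^k = \sin k\xi = \sin\xi \, U_{k-1}(\cos\xi)$, where $T_k$ and $U_k$ are the Chebyshev polynomials of the first and second kind. The NumPy check below (an illustration, not part of the notebook) verifies these forms for $k \le 4$, and that their averages over evenly spaced angles, a deterministic stand-in for $\mathcal{U}[-\pi, \pi]$, vanish.

```python
import numpy as np

# Re z^k and Im z^k as polynomials in c = cos(xi), s = sin(xi), for z = exp(i xi).
MOMENT_POLYS = {
    1: (lambda c, s: c,                       lambda c, s: s),
    2: (lambda c, s: 2 * c**2 - 1,            lambda c, s: 2 * s * c),
    3: (lambda c, s: 4 * c**3 - 3 * c,        lambda c, s: s * (4 * c**2 - 1)),
    4: (lambda c, s: 8 * c**4 - 8 * c**2 + 1, lambda c, s: s * (8 * c**3 - 4 * c)),
}

xi = np.linspace(-np.pi, np.pi, 1000, endpoint=False)  # evenly spaced angles
c, s = np.cos(xi), np.sin(xi)
```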

Finally, $g_{reg} = g + \varepsilon I$, with $\varepsilon$ = `yaml_config["loss_settings"]["eps"]`, is the regularized metric matrix used to keep the matrix inversion numerically stable, and $\|\cdot\|_F$ is the Frobenius norm of a matrix.
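The effect of the regularization is easiest to see on a degenerate metric. For a symmetric positive semi-definite $g$, every eigenvalue of $g_{reg}$ is at least $\varepsilon$, hence $\|g_{reg}^{-1}\|_F \le \sqrt{d}/\varepsilon$. The NumPy sketch below (hypothetical values) computes this norm for $g = \mathrm{diag}(1, 0)$, i.e. a decoder that collapses one latent direction.

```python
import numpy as np

def g_inv_frobenius(g, eps):
    """Frobenius norm of (g + eps * I)^{-1} for a symmetric positive semi-definite matrix g."""
    d = g.shape[0]
    g_reg = g + eps * np.eye(d)
    return np.linalg.norm(np.linalg.inv(g_reg), ord="fro")

eps = 1e-3
g_degenerate = np.diag([1.0, 0.0])  # the second latent direction is collapsed
value = g_inv_frobenius(g_degenerate, eps)
```

Without the $\varepsilon I$ term, `np.linalg.inv` raises `LinAlgError` on this exactly singular matrix; with it, the norm stays close to $1/\varepsilon$.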

The notebook consists of the following parts:

1) Imports. Loading of the hyperparameters for the dataset, training and plotting (learning rate, batch size, weights of the MSE loss, curvature loss, etc.) from a YAML file, and automatic construction of the train and test dataloaders. Available datasets: "Synthetic", "Swissroll", "MNIST", "MNIST01" (any selected labels from MNIST).
2) Architecture and device. Architecture types: fully connected (TorusAE) and convolutional (TorusConvAE). Device: cuda/cpu.
3) Training.
4) Report of training. Plots of the losses and a JSON file with the training parameters.


# Imports

In [None]:
# prerequisites
import torch
from sklearn import datasets
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import math
import numpy as np
from tqdm.notebook import tqdm
import json
import os
import ricci_regularization
import yaml

## Hyperparameters loading from YAML file

In [None]:
# Open and read the YAML configuration file
with open('init_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Confirm that the configuration was loaded
print("YAML Configuration loaded successfully.")
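Since every later cell indexes `yaml_config` directly, a missing key only surfaces as a `KeyError` deep inside training. A small optional check, sketched below with a hypothetical helper `missing_keys`, lists absent keys right after loading; it could be used as `missing = missing_keys(yaml_config)` followed by raising an error if the list is non-empty. The key list was collected from the cells of this notebook; dataset-specific keys such as `k` or `selected_labels` are left out.

```python
# Keys read unconditionally by the cells of this notebook (dataset-specific keys omitted).
REQUIRED_KEYS = [
    ("dataset", "name"), ("dataset", "D"),
    ("experiment", "experiment_number"), ("experiment", "violent_saving"),
    ("experiment", "weights_loaded"), ("experiment", "weights_saved"),
    ("training_mode", "diagnostic_mode"), ("training_mode", "compute_curvature"),
    ("training_mode", "OOD_regime"),
    ("data_loader_settings", "random_seed"), ("data_loader_settings", "batch_size"),
    ("data_loader_settings", "split_ratio"), ("data_loader_settings", "random_shuffling"),
    ("architecture", "type"), ("architecture", "output_dim"), ("architecture", "latent_dim"),
    ("optimizer_settings", "lr"), ("optimizer_settings", "weight_decay"),
    ("optimizer_settings", "num_epochs"),
    ("loss_settings", "lambda_recon"), ("loss_settings", "lambda_unif"),
    ("loss_settings", "lambda_curv"), ("loss_settings", "num_moments"),
    ("loss_settings", "eps"), ("loss_settings", "curvature_penalization_mode"),
    ("loss_settings", "delta_curv"),
    ("loss_settings", "lambda_contractive_encoder"), ("loss_settings", "delta_encoder"),
    ("loss_settings", "lambda_contractive_decoder"), ("loss_settings", "delta_decoder"),
]

def missing_keys(config, required=REQUIRED_KEYS):
    """Return the dotted paths of required keys that are absent from a nested config dict."""
    missing = []
    for path in required:
        node = config
        for key in path:
            if not isinstance(node, dict) or key not in node:
                missing.append(".".join(path))
                break
            node = node[key]
    return missing
```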


In [None]:
# Construct the experiment_name
experiment_name = yaml_config["dataset"]["name"] + "_torus_AE"
print(f"Experiment Name: {experiment_name}")  # Print the constructed experiment name

# Path for saving pictures
Path_pictures = f"../experiments/{experiment_name}/experiment" + str(yaml_config["experiment"]["experiment_number"])
print(f"Path for Pictures: {Path_pictures}")  # Print the path for pictures

# Create the directory for the plots only if saving is enabled
if yaml_config["experiment"]["violent_saving"]:
    print("Plots will be saved")
    if not os.path.exists(Path_pictures):
        # makedirs also creates the parent directory ../experiments/{experiment_name}/ if needed
        os.makedirs(Path_pictures)
        print(f"Created directory: {Path_pictures}")
else:
    print("Plots will not be saved")
Path_weights = "../nn_weights/"  # Path for saving neural network weights
print(f"Path for Weights: {Path_weights}")  # Print the path for neural network weights


In [None]:
# OOD sampling parameters
# These parameters are only used if yaml_config["training_mode"]["OOD_regime"] == True

T_ood = 20 # 100 # period of OOD penalization
n_ood = 5 # number of OOD samples per point
sigma_ood = 2e-1 # sigma of OOD Gaussian samples: 2e-1 swissroll
N_extr = 16 # 32 batch size of extremal curvature points
r_ood = 1e-3 # 1e-2 decay factor
OOD_w = yaml_config["loss_settings"]["lambda_curv"] # weight on curvature in OOD sampling
start_ood = 0 # OOD starting batch
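The implementation of `ricci_regularization.OODTools.curv_loss_on_OOD_samples` is not shown in this notebook. Purely as a hypothetical illustration of what `n_ood` and `sigma_ood` describe, the sketch below draws `n_ood` Gaussian perturbations with standard deviation `sigma_ood` around each of the `N_extr` extreme-curvature latent points and wraps them back onto the periodic box. It also mirrors the schedule used in `train()`, where the OOD penalty is applied every `T_ood` batches once `batch_idx >= start_ood`. The decay factor `r_ood` is not modeled.

```python
import numpy as np

def sample_ood(extreme_points, n_ood, sigma_ood, rng):
    """Hypothetical OOD sampler: n_ood wrapped Gaussian samples around each latent point.

    extreme_points: (N_extr, d) array of angles in [-pi, pi).
    Returns an (N_extr * n_ood, d) array of angles in [-pi, pi].
    """
    N_extr, d = extreme_points.shape
    noise = sigma_ood * rng.standard_normal((N_extr, n_ood, d))
    samples = extreme_points[:, None, :] + noise
    samples = np.mod(samples + np.pi, 2 * np.pi) - np.pi  # wrap back onto the periodic box
    return samples.reshape(N_extr * n_ood, d)

def is_ood_batch(batch_idx, T_ood, start_ood):
    """Mirrors the condition in train(): every T_ood-th batch, starting from start_ood."""
    return batch_idx % T_ood == 0 and batch_idx >= start_ood
```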

## Dataset loading 

In [None]:
# Set the random seed for data loader reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

# Load data loaders based on YAML configuration
# (named `loaders` to avoid shadowing the built-in `dict`)
loaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = loaders["train_loader"]
test_loader = loaders["test_loader"]
test_dataset = loaders.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders

print("Data loaders created successfully.")

# Calculate number of batches per epoch
batches_per_epoch = len(train_loader)
print(f"Number of batches per epoch: {batches_per_epoch}")

# Architecture and device

In [None]:
# Set the random seed for reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

# Selecting the architecture type based on YAML configuration
if yaml_config["architecture"]["type"] == "TorusConvAE":
    torus_ae = ricci_regularization.Architectures.TorusConvAE(
        x_dim=yaml_config["architecture"]["output_dim"],
        h_dim1=512,
        h_dim2=256,
        z_dim=yaml_config["architecture"]["latent_dim"],
        pixels=28
    )
    print("Selected architecture: TorusConvAE")
else:
    torus_ae = ricci_regularization.Architectures.TorusAE(
        x_dim=yaml_config["architecture"]["output_dim"],
        h_dim1=512,
        h_dim2=256,
        z_dim=yaml_config["architecture"]["latent_dim"]
    )
    print("Selected architecture: TorusAE")

# Check GPU availability and set device
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available! Training will use GPU.")
else:
    device = torch.device("cpu")
    print("CUDA is NOT available! Using CPU.")

# Move the AE model to the selected device
torus_ae.to(device)
print(f"Moved model to device: {device}")

### Loading the saved weights

In [None]:
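if yaml_config["experiment"]["weights_loaded"] == True:
    # Weights saved by the previous experiment (experiment_number - 1)
    PATH_weights_loaded = f'../nn_weights/{yaml_config["dataset"]["name"]}_exp{yaml_config["experiment"]["experiment_number"]-1}.pt'
    # map_location allows loading GPU-trained weights on a CPU-only machine (and vice versa)
    torus_ae.load_state_dict(torch.load(PATH_weights_loaded, map_location=device))
    torus_ae.eval()
    print(f"Weights loaded from {PATH_weights_loaded}")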
else:
    print("No weights loaded as 'weights_loaded' is set to False in the configuration.")


## Optimizer and loss function


In [None]:
optimizer = optim.Adam(torus_ae.parameters(),
        lr=yaml_config["optimizer_settings"]["lr"],
        weight_decay=yaml_config["optimizer_settings"]["weight_decay"])
print(f"Optimizer configured with learning rate {yaml_config['optimizer_settings']['lr']} and weight decay {yaml_config['optimizer_settings']['weight_decay']}.")

In [None]:
def curv_func(encoded_data, function=torus_ae.decoder_torus):
    """Estimate the curvature functional as the mean of sqrt(det g) * Sc^2 over encoded_data."""
    metric_on_data = ricci_regularization.metric_jacfwd_vmap(encoded_data,
                                           function=function)
    det_on_data = torch.det(metric_on_data)
    Sc_on_data = ricci_regularization.Sc_jacfwd_vmap(encoded_data,
                                           function=function)
    N = metric_on_data.shape[0]
    Integral_of_Sc = (1/N)*(torch.sqrt(det_on_data)*torch.square(Sc_on_data)).sum()
    #Integral_of_Sc = (1/N)*(torch.sqrt(det_on_data)*(Sc_on_data**4)).sum()
    return Integral_of_Sc

print("curv_func defined: estimates the integral of the squared scalar curvature")

# Loss = MSE + uniform_loss + curv_loss
def loss_function(recon_data, data, z, decoder):
    MSE = F.mse_loss(recon_data, data.view(-1, yaml_config["dataset"]["D"]), reduction='mean')
    z_sin = z[:, 0:yaml_config["architecture"]["latent_dim"]]
    z_cos = z[:, yaml_config["architecture"]["latent_dim"]:2*yaml_config["architecture"]["latent_dim"]]
    # k = 1: Re z = cos(xi), Im z = sin(xi)
    mode1 = torch.mean( z, dim = 0)
    mode1 = torch.sum( mode1*mode1 )
    # k = 2: Re z^2 = cos(2xi) = 2cos^2 - 1, Im z^2 = sin(2xi) = 2 sin cos
    mode2_1 = torch.mean( 2*z_cos*z_cos-1, dim = 0)
    mode2_1 = torch.sum( mode2_1*mode2_1)
    mode2_2 = torch.mean( 2*z_sin*z_cos, dim = 0)
    mode2_2 = torch.sum( mode2_2*mode2_2 )
    mode2 = mode2_1 + mode2_2
    unif_loss = mode1 + mode2
    if yaml_config["loss_settings"]["num_moments"] > 2:
        # k = 3: Re z^3 = cos(3xi) = 4cos^3 - 3cos, Im z^3 = sin(3xi) = sin*(4cos^2 - 1)
        mode3_1 = torch.mean( 4*z_cos**3-3*z_cos, dim = 0)
        mode3_1 = torch.sum( mode3_1*mode3_1)
        mode3_2 = torch.mean( z_sin*(4*z_cos**2-1), dim = 0)
        mode3_2 = torch.sum( mode3_2*mode3_2 )
        mode3 = mode3_1 + mode3_2
        unif_loss += mode3
    if yaml_config["loss_settings"]["num_moments"] > 3:
        # k = 4: Re z^4 = cos(4xi) = 8cos^4 - 8cos^2 + 1, Im z^4 = sin(4xi) = sin*(8cos^3 - 4cos)
        mode4_1 = torch.mean( 8*z_cos**4-8*z_cos**2+1, dim = 0)
        mode4_1 = torch.sum( mode4_1*mode4_1)
        mode4_2 = torch.mean( z_sin*(8*z_cos**3-4*z_cos), dim = 0)
        mode4_2 = torch.sum( mode4_2*mode4_2 )
        mode4 = mode4_1 + mode4_2
        unif_loss += mode4
    dict_losses = {
        "MSE": MSE,
        "uniform_loss": unif_loss,
    }
    
    if yaml_config["training_mode"]["compute_curvature"] == True:
        encoded_points_no_grad = torus_ae.encoder2lifting(data.view(-1, yaml_config["dataset"]["D"])).detach()
        metric_on_data = ricci_regularization.metric_jacfwd_vmap(encoded_points_no_grad,
                                           function=decoder)
        det_on_data = torch.det(metric_on_data)
        Sc_on_data = ricci_regularization.Sc_jacfwd_vmap(encoded_points_no_grad,
                                           function=decoder,eps=yaml_config["loss_settings"]["eps"])
        
        if yaml_config["loss_settings"]["curvature_penalization_mode"] == "mean":
            curv_loss = (torch.sqrt(det_on_data)*torch.square(Sc_on_data)).mean()
        elif yaml_config["loss_settings"]["curvature_penalization_mode"] == "max":
            # Hinge penalty: only values above delta_curv contribute
            curv_outliers = torch.nn.ReLU()(torch.sqrt(det_on_data)*torch.square(Sc_on_data) - yaml_config["loss_settings"]["delta_curv"])
            curv_loss = curv_outliers.max()
        else:
            raise ValueError(f"Unknown curvature_penalization_mode: {yaml_config['loss_settings']['curvature_penalization_mode']}")
        
        dict_losses["curvature_loss"] = curv_loss
        if yaml_config["training_mode"]["diagnostic_mode"] == True:
            curv_squared_mean = (torch.square(Sc_on_data)).mean()
            curv_squared_max = (torch.square(Sc_on_data)).max()
            dict_losses["curv_squared_mean"] = curv_squared_mean
            dict_losses["curv_squared_max"] = curv_squared_max
    if yaml_config["training_mode"]["diagnostic_mode"] == True:
        if yaml_config["training_mode"]["compute_curvature"] == False:
            encoded_points_no_grad = torus_ae.encoder2lifting(data.view(-1, yaml_config["dataset"]["D"])).detach()
            metric_on_data = ricci_regularization.metric_jacfwd_vmap(encoded_points_no_grad,function=decoder)
            det_on_data = torch.det(metric_on_data)    
        g_inv_train_batch = torch.linalg.inv(metric_on_data + yaml_config["loss_settings"]["eps"]*torch.eye(yaml_config["architecture"]["latent_dim"]).to(device))
        g_inv_norm_train_batch = torch.linalg.matrix_norm(g_inv_train_batch)
        g_inv_norm_mean = torch.mean(g_inv_norm_train_batch)
        g_inv_norm_max = torch.max(g_inv_norm_train_batch)
        g_det_mean = det_on_data.mean()
        g_det_max = det_on_data.max()
        g_det_min = det_on_data.min()
        decoder_jac_norm = torch.func.vmap(torch.trace)(metric_on_data)
        decoder_jac_norm_mean = decoder_jac_norm.mean()
        decoder_jac_norm_max = decoder_jac_norm.max()
        dict_losses["g_inv_norm_mean"] = g_inv_norm_mean
        dict_losses["g_inv_norm_max"] = g_inv_norm_max
        dict_losses["g_det_mean"] = g_det_mean
        dict_losses["g_det_max"] = g_det_max
        dict_losses["g_det_min"] = g_det_min
        dict_losses["decoder_jac_norm_mean"] = decoder_jac_norm_mean
        dict_losses["decoder_jac_norm_max"] = decoder_jac_norm_max
        outliers_decoder_norm = torch.nn.ReLU()(decoder_jac_norm - yaml_config["loss_settings"]["delta_decoder"])
        dict_losses["decoder_contractive_loss"] = outliers_decoder_norm.max()
        metric_array_encoder = ricci_regularization.metric_jacrev_vmap(data.view(-1, yaml_config["dataset"]["D"]),
                function=torus_ae.encoder_torus,
                latent_space_dim=yaml_config["dataset"]["D"]) # D here is not the latent space dimension (naming is counter-intuitive)!
        encoder_jac_norm = torch.func.vmap(torch.trace)(metric_array_encoder)
        encoder_jac_norm_mean = encoder_jac_norm.mean()
        encoder_jac_norm_max = encoder_jac_norm.max()
        dict_losses["encoder_jac_norm_mean"] = encoder_jac_norm_mean
        dict_losses["encoder_jac_norm_max"] = encoder_jac_norm_max
        outliers_encoder_norm = torch.nn.ReLU()(encoder_jac_norm - yaml_config["loss_settings"]["delta_encoder"])
        dict_losses["encoder_contractive_loss"] = outliers_encoder_norm.max()
    return dict_losses

print("loss_function defined: computes MSE, uniform_loss and, optionally, curvature_loss and diagnostic metrics")

# OOD initialization
extreme_curv_value_tensor = None
extreme_curv_points_tensor = None
if yaml_config["training_mode"]["OOD_regime"] == True:
    first_batch,_ = next(iter(train_loader))
    first_batch = first_batch.to(device)
    # The encoded points inherit the device of first_batch
    extreme_curv_points_tensor = torus_ae.encoder2lifting(first_batch.view(-1,yaml_config["dataset"]["D"])[:N_extr]).detach()
    extreme_curv_value_tensor = ricci_regularization.Sc_jacfwd_vmap(extreme_curv_points_tensor, 
            function=torus_ae.decoder_torus,eps=yaml_config["loss_settings"]["eps"])
    print("OOD initialization: initialized the extreme curvature points from the first batch")
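As a worked example of the curvature functional (independent of the notebook's autodiff routines), take $d = 2$ and the torus of revolution with radii $a > b > 0$. There $g = \mathrm{diag}(b^2, (a + b\cos\theta)^2)$, the Gaussian curvature is $K = \cos\theta / (b(a + b\cos\theta))$, and the scalar curvature of a surface is $R = 2K$. Averaging $\sqrt{\det g}\,R^2$ over evenly spread latent points, as the `"mean"` penalization mode does over a batch, can then be compared with the closed form $4a\left(a/\sqrt{a^2-b^2} - 1\right)/b^3$ obtained by integrating over $\theta$. The function name below is hypothetical.

```python
import numpy as np

def torus_curvature_integrand(theta, a, b):
    """sqrt(det g) * R^2 on the torus of revolution with radii a > b > 0 (independent of phi).

    Embedding: ((a + b cos theta) cos phi, (a + b cos theta) sin phi, b sin theta),
    metric g = diag(b^2, (a + b cos theta)^2), Gaussian curvature K = cos theta / (b (a + b cos theta)),
    scalar curvature R = 2K in dimension 2.
    """
    u = a + b * np.cos(theta)
    sqrt_det_g = b * u
    R = 2 * np.cos(theta) / (b * u)
    return sqrt_det_g * R**2

a, b, n = 2.0, 1.0, 256
angles = -np.pi + 2 * np.pi * np.arange(n) / n
theta, phi = np.meshgrid(angles, angles, indexing="ij")   # evenly spread latent points on T^2
estimate = torus_curvature_integrand(theta, a, b).mean()  # analogue of the batch mean in loss_function
exact = 4 * a * (a / np.sqrt(a**2 - b**2) - 1) / b**3
```

For a flat torus $R \equiv 0$ and the functional vanishes. Note that the batch mean estimates $\int_M \sqrt{\det g}\,R^2 \, d\mu_N$, which coincides with $(2\pi)^{-d}\int_M R^2 \, d\mathrm{vol}_g$ only when the latent points are uniformly distributed on $\mathcal{T}^d$.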

In [None]:
def train(epoch=1,batch_idx = 0,dict_loss_arrays={},
          extreme_curv_points_tensor = extreme_curv_points_tensor, 
          extreme_curv_value_tensor = extreme_curv_value_tensor):
    if batch_idx == 0:
        dict_loss_arrays = {}
    torus_ae.train()
    print("Epoch %d"%epoch)
    t = tqdm( train_loader, desc="Train", position=0 )
    
    for (data, labels) in t:
        data = data.to(device)
        optimizer.zero_grad()
        # Forward pass
        recon_batch, z = torus_ae(data)
        
        dict_losses = loss_function(recon_batch, data, z,
                                    decoder=torus_ae.decoder_torus)
        mse_loss = dict_losses["MSE"]
        uniform_loss = dict_losses["uniform_loss"]
        loss = yaml_config["loss_settings"]["lambda_recon"]*mse_loss + yaml_config["loss_settings"]["lambda_unif"]*uniform_loss 

        if yaml_config["training_mode"]["diagnostic_mode"] == True:
            # Contractive penalties on the encoder and decoder Jacobian norms
            encoder_contractive_loss = dict_losses["encoder_contractive_loss"]
            decoder_contractive_loss = dict_losses["decoder_contractive_loss"]
            loss = (loss
                    + yaml_config["loss_settings"]["lambda_contractive_decoder"]*decoder_contractive_loss
                    + yaml_config["loss_settings"]["lambda_contractive_encoder"]*encoder_contractive_loss)


        
        if yaml_config["training_mode"]["compute_curvature"] == True:
            curvature_loss = dict_losses["curvature_loss"]
            loss = loss + yaml_config["loss_settings"]["lambda_curv"]*curvature_loss
        
        # OOD regime (optional)
        if yaml_config["training_mode"]["OOD_regime"] == True:
            extreme_curv_points_tensor, extreme_curv_value_tensor = ricci_regularization.find_extreme_curvature_points(data_batch=data,
                                extreme_curv_points_tensor=extreme_curv_points_tensor,
                                extreme_curv_value_tensor=extreme_curv_value_tensor,batch_idx=batch_idx,
                                encoder=torus_ae.encoder2lifting,decoder=torus_ae.decoder_torus,r_ood=r_ood,
                                N_extr=N_extr,output_dim=yaml_config["dataset"]["D"])
            
            if (batch_idx % T_ood == 0) and (batch_idx >= start_ood):
                OOD_curvature_loss = ricci_regularization.OODTools.curv_loss_on_OOD_samples(extreme_curv_points_tensor=extreme_curv_points_tensor,
                                                                                            decoder=torus_ae.decoder_torus,
                                                                                            sigma_ood=sigma_ood,n_ood=n_ood,N_extr=N_extr,
                                                                                            latent_space_dim=yaml_config["architecture"]["latent_dim"])
                if yaml_config["training_mode"]["diagnostic_mode"] == True:
                    print("Curvature functional at OOD points", OOD_curvature_loss)
                # Note: on OOD batches the total loss is replaced by the weighted OOD curvature term alone
                loss = OOD_w * OOD_curvature_loss
        
        # Backpropagate
        loss.backward()
        optimizer.step()

        # Append the losses of the current batch to the loss arrays
        if batch_idx == 0:
            for key in dict_losses.keys():
                dict_loss_arrays[key] = []
        for key in dict_losses.keys():
            dict_loss_arrays[key].append(dict_losses[key].item())
        
        # Progress bar
        batch_idx += 1
        MSE_mean_per_epoch = np.array(dict_loss_arrays["MSE"])[-batches_per_epoch:].mean()
        uniform_loss_mean_per_epoch = np.array(dict_loss_arrays["uniform_loss"])[-batches_per_epoch:].mean()    
        if yaml_config["training_mode"]["compute_curvature"] == True:
            curvature_loss_mean_per_epoch = np.array(dict_loss_arrays["curvature_loss"])[-batches_per_epoch:].mean()
        else:
            curvature_loss_mean_per_epoch = "nan"

        # Losses displayed in the progress bar (diagnostic mode takes precedence)
        if yaml_config["training_mode"]["diagnostic_mode"] == True:
            encoder_contractive_loss_mean_per_epoch = np.array(dict_loss_arrays["encoder_contractive_loss"])[-batches_per_epoch:].mean()
            t.set_description_str(desc=f"MSE:{MSE_mean_per_epoch}, Uniform:{uniform_loss_mean_per_epoch}, Encoder_penalty:{encoder_contractive_loss_mean_per_epoch}, Curvature:{curvature_loss_mean_per_epoch}.\n")
        elif yaml_config["training_mode"]["compute_curvature"] == True:
            t.set_description_str(desc=f"MSE:{MSE_mean_per_epoch}, Uniform:{uniform_loss_mean_per_epoch}, Curvature:{curvature_loss_mean_per_epoch}.\n")
        else:
            t.set_description_str(desc=f"MSE:{MSE_mean_per_epoch}, Uniform:{uniform_loss_mean_per_epoch}")
        #if batch_idx > num_batches: # if one wants to stop after certain number of batches
        #    break
    #end for
    return batch_idx, dict_loss_arrays

def test(mse_loss_array=None, uniform_loss_array=None, curvature_loss_array=None):
    """Evaluate the AE on the test set; returns the per-batch MSE, uniform and curvature losses."""
    # Fresh lists on every call (mutable default arguments would accumulate across calls)
    mse_loss_array = [] if mse_loss_array is None else mse_loss_array
    uniform_loss_array = [] if uniform_loss_array is None else uniform_loss_array
    curvature_loss_array = [] if curvature_loss_array is None else curvature_loss_array
    torus_ae.to(device)
    torus_ae.eval()
    with torch.no_grad():
        t = tqdm( test_loader, desc="Test", position=1 )
        for data, _ in t:
            data = data.to(device)  # keep the data on the same device as the model
            recon_batch, z = torus_ae(data)
            dict_losses = loss_function(recon_batch, data, z,decoder=torus_ae.decoder_torus)
            if yaml_config["training_mode"]["compute_curvature"] == True:
                curvature_loss = dict_losses["curvature_loss"]
            else:
                curvature_loss = torch.zeros(1)

            mse_loss_array.append( dict_losses["MSE"].item() ) 
            uniform_loss_array.append( dict_losses["uniform_loss"].item() )
            curvature_loss_array.append(curvature_loss.item())
    print(f"Test losses. \nMSE:{np.array(mse_loss_array).mean()}, Uniform:{np.array(uniform_loss_array).mean()}, Curvature:{np.array(curvature_loss_array).mean()}.\n")
    return mse_loss_array, uniform_loss_array, curvature_loss_array

# Training

In [None]:
batch_idx=0
dict_loss_arrays = {}

# Launch
for epoch in range(1, yaml_config["optimizer_settings"]["num_epochs"] + 1):
    # The plotting helper below moves the model to the CPU, so move it back before training
    torus_ae.to(device)
    batch_idx, dict_loss_arrays = train(epoch=epoch, batch_idx=batch_idx, dict_loss_arrays=dict_loss_arrays)
    if yaml_config["training_mode"]["diagnostic_mode"] == True:
        dict2print = ricci_regularization.PlottingTools.translate_dict(dict2print=dict_loss_arrays,
                    include_curvature_plots=yaml_config["training_mode"]["compute_curvature"],
                    eps=yaml_config["loss_settings"]["eps"])
        ricci_regularization.PlottingTools.plotsmart(dict2print)
    #else:
    #    ricci_regularization.PlottingTools.plotfromdict(dict_of_losses=dict_loss_arrays)

    if yaml_config["dataset"]["name"] in ["MNIST", "MNIST01"]:
        ricci_regularization.PlottingTools.plot_ae_outputs_selected(test_dataset=test_dataset,
                                                           encoder=torus_ae.cpu().encoder2lifting,
                                                           decoder=torus_ae.cpu().decoder_torus,
                                                           selected_labels=yaml_config["dataset"]["selected_labels"])
    #test()


# Report of training

## Test losses, $R^2$

In [None]:
def compute_R_squared_losses(data_loader=test_loader):
    """Average the losses over data_loader and compute R^2 = 1 - MSE / Var(inputs)."""
    num_batches = len(data_loader)
    mse_loss = 0
    curv_loss = 0
    unif_loss = 0
    input_dataset_list = []
    torus_ae.to(device)
    for (data, _) in data_loader:
        data = data.to(device)
        input_dataset_list.append(data.cpu())
        recon, z = torus_ae(data)
        dict_losses = loss_function(recon, data, z,
                                    decoder=torus_ae.decoder_torus)
        mse_loss += dict_losses["MSE"].cpu().detach()/num_batches
        unif_loss += dict_losses["uniform_loss"].cpu().detach()/num_batches
        if "curvature_loss" in dict_losses:  # only present when compute_curvature is True
            curv_loss += dict_losses["curvature_loss"].cpu().detach()/num_batches
    input_dataset_tensor = torch.cat(input_dataset_list).view(-1,yaml_config["dataset"]["D"])
    var = torch.var(input_dataset_tensor.flatten())
    R_squared = 1 - mse_loss/var
    return mse_loss, unif_loss, curv_loss, R_squared
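`compute_R_squared_losses` uses a pooled coefficient of determination $R^2 = 1 - \mathrm{MSE}/\mathrm{Var}(X)$, where both the MSE and the variance are taken over all entries (pixels) of the inputs. The NumPy sketch below (random data as a hypothetical stand-in for images; `pooled_r_squared` is a hypothetical name) checks the two reference cases: a perfect reconstruction gives $R^2 = 1$, and predicting the global mean gives $R^2 = 0$ with the population variance. `torch.var` uses the unbiased estimator by default, which shifts the second case to $1/n$ for $n$ entries, a negligible difference for image datasets.

```python
import numpy as np

def pooled_r_squared(x, x_hat, ddof=0):
    """R^2 = 1 - MSE / Var, with MSE and variance pooled over all entries of x."""
    x = np.asarray(x, dtype=float).ravel()
    x_hat = np.asarray(x_hat, dtype=float).ravel()
    mse = np.mean((x - x_hat) ** 2)
    return 1.0 - mse / np.var(x, ddof=ddof)

rng = np.random.default_rng(0)
x = rng.uniform(size=(100, 784))  # hypothetical stand-in for a set of flattened images
mean_prediction = np.full_like(x, x.mean())
```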

In [None]:
test_mse, test_unif, test_curv, test_R_squared = compute_R_squared_losses(test_loader)
#train_mse, train_unif, train_curv, train_R_squared, train_decoder_penalty = compute_R_squared_losses(train_loader)

In [None]:
#print(f"Train losses:\nmse:{train_mse}, unif_loss:{train_unif}, decoder_penalty:{train_decoder_penalty}, curv_loss:{train_curv}")
#print(f"R_squared: {train_R_squared.item():.4f}")
#print(f"Test losses:\nmse:{test_mse}")
print(f"Test losses:\nmse:{test_mse}, unif_loss:{test_unif}, curv_loss:{test_curv}")
print(f"R_squared: {test_R_squared.item():.4f}")

In [None]:
"""
import torch
from torcheval.metrics import R2Score
R_squared = R2Score()#(multioutput="raw_values")
input = input_dataset_tensor.flatten()
target = recon_dataset_tensor.flatten()
R_squared.update(input, target)
R_squared.compute()#.shape
"""

## Saving the model

In [None]:
if yaml_config["experiment"]["weights_saved"] == True:
    PATH_ae = f'../nn_weights/{yaml_config["dataset"]["name"]}_exp{yaml_config["experiment"]["experiment_number"]}.pt'
    torch.save(torus_ae.state_dict(), PATH_ae)

## Losses plot

In [None]:
# Loss plotting
if yaml_config["training_mode"]["diagnostic_mode"] == True:
    #fig,axes = ricci_regularization.PlottingTools.plotsmart(dict2print)
    fig,axes = ricci_regularization.PlottingTools.PlotSmartConvolve(dict2print)
else:
    fig,axes = ricci_regularization.PlottingTools.plotfromdict(dict_loss_arrays)
if yaml_config["experiment"]["violent_saving"] == True:
    fig.savefig(f"{Path_pictures}/losses_exp"+str(yaml_config["experiment"]["experiment_number"])+".pdf",bbox_inches='tight',format="pdf")

In [None]:
"""
fig,axes = ricci_regularization.PlottingTools.plot9losses(mse_loss_array,curvature_loss_array,g_inv_meanperbatch_array)
if yaml_config["experiment"]["violent_saving"] == True:
    fig.savefig(f"{Path_pictures}/9losses_exp{yaml_config["experiment"]["experiment_number"]}.pdf",bbox_inches='tight',format="pdf")
"""

## Torus latent space

In [None]:
#inspiration for torus_ae.encoder2lifting
"""
def circle2anglevectorized(zLatentTensor,d = yaml_config["architecture"]["latent_dim"]):
    cosphi = zLatentTensor[:, 0:yaml_config["architecture"]["latent_dim"]]
    sinphi = zLatentTensor[:, yaml_config["architecture"]["latent_dim"]:2*yaml_config["architecture"]["latent_dim"]]
    phi = torch.acos(cosphi)*torch.sgn(torch.asin(sinphi))
    return phi
"""


In [None]:
#zlist = []
torus_ae.cpu()
colorlist = []
enc_list = []
input_dataset_list = []
recon_dataset_list = []
for (data, labels) in tqdm( train_loader, position=0 ):
#for (data, labels) in train_loader:
    input_dataset_list.append(data)
    recon_dataset_list.append(torus_ae(data)[0])
    #zlist.append(torus_ae(data)[1])
    enc_list.append(torus_ae.encoder2lifting(data.view(-1,yaml_config["dataset"]["D"])))
    colorlist.append(labels) 

In [None]:
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
encoded_points_no_grad = encoded_points.detach()/math.pi
color_array = torch.cat(colorlist).detach()
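The commented-out `circle2anglevectorized` above recovers angles with `acos` and the sign of `asin`. A more common alternative is the two-argument arctangent, sketched below in NumPy (`circle_to_angle` is a hypothetical name). The cosine and sine parts are passed explicitly because the column order of the lifted coordinates differs between that helper (cosines first) and `loss_function` (sines first), so it should be checked against the architecture before use.

```python
import numpy as np

def circle_to_angle(cos_part, sin_part):
    """Angles in (-pi, pi] from their cosines and sines, via the two-argument arctangent."""
    return np.arctan2(sin_part, cos_part)

phi = np.linspace(-np.pi + 0.01, np.pi - 0.01, 50)
recovered = circle_to_angle(np.cos(phi), np.sin(phi))
```

The two formulas agree except where $\sin\varphi$ is exactly 0 and $\cos\varphi = -1$: there the sign factor vanishes, and the `acos`-based version returns 0 instead of $\pi$.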

In [None]:
plt.figure(figsize=(8, 6))
if yaml_config["dataset"]["name"] == "Swissroll":
    my_cmap = "jet"
else:
    my_cmap = ricci_regularization.PlottingTools.discrete_cmap(yaml_config["dataset"]["k"], 'jet')
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=color_array-1, marker='o', edgecolor='none', cmap=my_cmap)

if yaml_config["dataset"]["name"] in ["Synthetic","MNIST","MNIST01"]:
    plt.colorbar(ticks=range(yaml_config["dataset"]["k"]),orientation="vertical")
plt.grid(True)
if yaml_config["experiment"]["violent_saving"] == True:
    plt.savefig(f"{Path_pictures}/latent_space_exp"+str(yaml_config["experiment"]["experiment_number"])+".pdf",bbox_inches='tight',format="pdf")
plt.show()


## Saving json report

In [None]:
json_config = {
    "experiment_name": experiment_name,
    "experiment_number": yaml_config["experiment"]["experiment_number"],
    "dataset": {
        "name": yaml_config["dataset"]["name"],
        "parameters": {
            "k": yaml_config["dataset"].get("k", None),
            "n": yaml_config["dataset"].get("n", None),
            "D": yaml_config["dataset"].get("D", None),
            "d": yaml_config["dataset"].get("d", None),
            "shift_class": yaml_config["dataset"].get("shift_class", None),
            "intercl_var": yaml_config["dataset"].get("intercl_var", None),
            "var_class": yaml_config["dataset"].get("var_class", None),
            "sr_noise": yaml_config["dataset"].get("sr_noise", None)
        },
        "selected_labels": yaml_config["dataset"].get("selected_labels", None)
    },
    "architecture": {
        "name": yaml_config["architecture"]["type"], 
        "input_dim": yaml_config["architecture"]["output_dim"],
        "latent_dim": yaml_config["architecture"]["latent_dim"]
    },
    "optimization_parameters": {
        "learning_rate": yaml_config["optimizer_settings"]["lr"],
        "batch_size": yaml_config["data_loader_settings"]["batch_size"],
        "split_ratio": yaml_config["data_loader_settings"]["split_ratio"],
        "num_epochs": yaml_config["optimizer_settings"]["num_epochs"],
        "weight_decay": yaml_config["optimizer_settings"]["weight_decay"],
        "random_shuffling": yaml_config["data_loader_settings"]["random_shuffling"],
        "random_seed": yaml_config["data_loader_settings"]["random_seed"]
    },
    "losses": {
        "lambda_recon": yaml_config["loss_settings"]["lambda_recon"],
        "lambda_unif": yaml_config["loss_settings"]["lambda_unif"],
        "Number of moments used": yaml_config["loss_settings"]["num_moments"],
        "lambda_curv": yaml_config["loss_settings"]["lambda_curv"],
        "delta_curv": yaml_config["loss_settings"]["delta_curv"],
        "curvature_penalization_mode": yaml_config["loss_settings"]["curvature_penalization_mode"],
        "g_inv regularization eps": yaml_config["loss_settings"]["eps"],
        "lambda_contractive_encoder": yaml_config["loss_settings"]["lambda_contractive_encoder"],
        "delta_encoder": yaml_config["loss_settings"]["delta_encoder"],
        "lambda_contractive_decoder": yaml_config["loss_settings"]["lambda_contractive_decoder"],
        "delta_decoder": yaml_config["loss_settings"]["delta_decoder"],
        "diagnostic_mode": yaml_config["training_mode"]["diagnostic_mode"],
        "compute_curvature": yaml_config["training_mode"]["compute_curvature"]
    },
    "OOD_parameters": {
        "OOD_regime": yaml_config["training_mode"]["OOD_regime"],
        # Values actually used during training (set in the "OOD sampling parameters" cell)
        "start_ood": start_ood,
        "T_ood": T_ood,
        "n_ood": n_ood,
        "sigma_ood": sigma_ood,
        "N_extr": N_extr,
        "r_ood": r_ood,
        "OOD_w": OOD_w
    },
    "Path_pictures": f"../experiments/{experiment_name}/experiment"+str(yaml_config['experiment']['experiment_number']),
    "Path_weights": "../nn_weights/",
    "Path_experiments": "../experiments/"
}
if yaml_config["experiment"]["weights_saved"] == True:
    json_config["weights_saved_at"] = PATH_ae
if yaml_config["experiment"]["weights_loaded"] == True:
    json_config["weights_loaded_from"] = PATH_weights_loaded
# Save dictionary to JSON file
with open(f"../experiments/{experiment_name}exp"+str(yaml_config["experiment"]["experiment_number"])+".json", 'w') as json_file:
    json.dump(json_config, json_file, indent=4)

In [None]:
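json_results = {
    "training_results_on_test_data": {
        # float() converts 0-d tensors to Python floats, which json can serialize
        "R^2": float(test_R_squared),
        "mse_loss": float(test_mse),
        "unif_loss": float(test_unif),
        "curv_loss": float(test_curv)
    }
}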
if yaml_config["experiment"]["violent_saving"] == True:
    with open(f"{Path_pictures}/exp"+str(yaml_config["experiment"]["experiment_number"])+".json", 'w') as json_file:
        json.dump(json_results, json_file, indent=4)
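`json.dump` accepts only built-in Python types, so 0-d tensors such as the values returned by `compute_R_squared_losses` must be converted to Python scalars first. A generic helper is sketched below; `to_jsonable` is a hypothetical name, and NumPy scalars stand in for tensors because both expose `.item()`.

```python
import json
import numpy as np

def to_jsonable(obj):
    """Recursively replace single-element array/tensor-like values (anything with .item()) by Python scalars."""
    if isinstance(obj, dict):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    if hasattr(obj, "item"):
        return obj.item()
    return obj

report = {"R^2": np.float32(0.5), "losses": [np.float32(1.0), 2]}
serialized = json.dumps(to_jsonable(report))
```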