# Torus AE training

This notebook performs the training of the autoencoder (AE). 

The AE consists of the encoder $\Phi$ and the decoder $\Psi$.
The latent space of the AE is topologically a $d$-dimensional torus $\mathcal{T}^d$, i.e., it can be viewed as the periodic box $[-\pi, \pi]^d$. We define a Riemannian metric on the latent space as the pull-back of the Euclidean metric on the output space $\mathbb{R}^D$ by the decoder $\Psi$ of the AE:
\begin{equation}
    g = \nabla \Psi ^* \nabla \Psi \ .
\end{equation}
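
To make the pull-back concrete, here is a minimal, dependency-free sketch. It does not use the trained decoder: the function `decoder` below is a hypothetical stand-in (the standard embedding of the 2-torus into $\mathbb{R}^3$), and the Jacobian is approximated by central finite differences rather than by automatic differentiation. For this toy map the metric is known in closed form, $g = \mathrm{diag}\big((R + r\cos v)^2,\ r^2\big)$, which gives a quick sanity check.

```python
import math

def decoder(u, v, R=2.0, r=1.0):
    """Hypothetical toy decoder: standard embedding of the 2-torus into R^3."""
    return [(R + r * math.cos(v)) * math.cos(u),
            (R + r * math.cos(v)) * math.sin(u),
            r * math.sin(v)]

def jacobian_fd(f, z, h=1e-5):
    """Central finite-difference Jacobian of f at z, as D rows of d entries."""
    columns = []
    for i in range(len(z)):
        z_plus = list(z)
        z_plus[i] += h
        z_minus = list(z)
        z_minus[i] -= h
        f_plus, f_minus = f(*z_plus), f(*z_minus)
        columns.append([(a - b) / (2 * h) for a, b in zip(f_plus, f_minus)])
    # columns[i][k] = d f_k / d z_i; transpose so that rows are indexed by k
    return [list(row) for row in zip(*columns)]

def pullback_metric(f, z):
    """g = J^T J: the Euclidean metric of R^D pulled back by f."""
    J = jacobian_fd(f, z)
    d = len(z)
    return [[sum(J[k][i] * J[k][j] for k in range(len(J))) for j in range(d)]
            for i in range(d)]

u, v = 0.3, 1.1
g = pullback_metric(decoder, [u, v])
print(g)  # close to diag((2 + cos(1.1))**2, 1)
```

In the notebook itself the same object is obtained from the decoder's Jacobian via `ricci_regularization.metric_jacfwd_vmap`.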


The following training modes can be switched on or off:

- "compute_curvature"
- "compute_contractive_loss"
- "OOD_regime"
- "diagnostic_mode"

If "compute_curvature"==True, the curvature functional is computed and used to regularize the latent space. The curvature functional is given by:
\begin{equation*}
\mathcal{L}_\mathrm{curv} := \int_M R^2 \mu \ .
\end{equation*}
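
Assuming that $\mu$ denotes the Riemannian volume form of $g$ (the text above does not say so explicitly, so treat this as an assumption), the functional can be written in the coordinates $z \in [-\pi, \pi]^d$ as
\begin{equation*}
\mathcal{L}_\mathrm{curv} = \int_{[-\pi, \pi]^d} R(z)^2 \sqrt{\det g(z)} \, dz \ .
\end{equation*}
If the encoded points $z_j$ are roughly uniformly distributed on the torus, the batch average $\frac{1}{N}\sum_{j=1}^{N} R(z_j)^2 \sqrt{\det g(z_j)}$ is a Monte Carlo estimate of this integral up to the constant factor $(2\pi)^d$.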

If "compute_contractive_loss"==True, a contractive loss is computed that penalizes outliers of the Frobenius norm of the encoder's Jacobian, i.e. values exceeding the threshold $\delta_\mathrm{encoder}$:
$$
 \mathcal{L}_\mathrm{contractive} = \max_{x \in \mathrm{batch}} \mathrm{ReLU}\left( \|\nabla \Phi(x)\|_F - \delta_\mathrm{encoder}\right) \ .
$$
Turning this mode off speeds up training, which is convenient during the initial tuning of the parameters.
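
The hinge structure of this penalty can be illustrated with a small, dependency-free sketch. The Jacobians below are made-up $2 \times 2$ matrices, and `contractive_loss` is a hypothetical helper, not the library routine. It mirrors the notebook's loss function, which applies a ReLU to the per-sample norms minus $\delta_\mathrm{encoder}$ and then takes the maximum over the batch.

```python
import math

def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))

def contractive_loss(jacobians, delta):
    """Max over the batch of ReLU(||J||_F - delta)."""
    return max(max(frobenius_norm(J) - delta, 0.0) for J in jacobians)

# Made-up per-sample encoder Jacobians
batch_jacobians = [
    [[0.3, 0.0], [0.0, 0.4]],  # Frobenius norm 0.5
    [[3.0, 0.0], [0.0, 4.0]],  # Frobenius norm 5.0
]

loss_active = contractive_loss(batch_jacobians, delta=2.0)     # the outlier is penalized
loss_inactive = contractive_loss(batch_jacobians, delta=10.0)  # no sample exceeds delta
print(loss_active, loss_inactive)
```

Only samples whose Jacobian norm exceeds the threshold contribute, so the term vanishes once the encoder is sufficiently contractive everywhere on the batch.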

If "OOD_regime"==True, then out-of-distribution (OOD) sampling is performed to refine the curvature regularization results.
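
The OOD term is not evaluated on every batch. In the training loop of this notebook it is triggered only when the batch index is a multiple of `T_ood` and at least `start_ood` (both read from the `OOD_settings` section of the config). A minimal sketch of that schedule, with made-up values for the two settings and a hypothetical helper name:

```python
def is_ood_step(batch_idx, T_ood, start_ood):
    """True on the batches where the OOD curvature loss is evaluated."""
    return batch_idx % T_ood == 0 and batch_idx >= start_ood

# Hypothetical settings: every 5th batch, starting from batch 10
ood_steps = [i for i in range(30) if is_ood_step(i, T_ood=5, start_ood=10)]
print(ood_steps)
```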

Any of these modes can be turned off to speed up training. This makes it faster to tune the "vanilla" AE (without regularization) first; its optimal hyperparameters then serve as the initial guess for training the AE with regularization.

If "diagnostic_mode"==True, the following quantities are plotted: MSE, $\mathcal{L}_\mathrm{unif}$, $\mathcal{L}_\mathrm{curv}$, $\det(g)$, $\|g_{reg}^{-1}\|_F$, $\|\nabla \Psi \|^2_F$, $\|\nabla \Phi \|^2_F$, where:
\begin{equation*}
\mathcal{L}_\mathrm{curv} := \int_M R^2 \mu \ ,
\end{equation*}

\begin{equation*}
\mathcal{L}_\mathrm{unif} := \sum\limits_{k=1}^{m} |\int_M z^k  \mu_N (dz) |^2 \ ,
\end{equation*}
where $R$ stands for the scalar curvature (see https://en.wikipedia.org/wiki/Scalar_curvature), $\mu_N = \Phi_\# ( \frac{1}{N}\sum\limits_{j=1}^{N} \delta_{X_j} )$ is the push-forward by the encoder $\Phi$ of the empirical measure of the dataset, so $\mu_N$ is a measure on $\mathcal{T}^d$, and $\alpha_k = \int_M z^k \mu_N (dz) = \frac{1}{N} \sum_{j=1}^{N} z_j^k$ is the empirical estimator of the $k$-th moment of the data distribution in the latent space.

If $\xi \sim \mathcal{U}[-\pi, \pi]$ and $z = e^{i \xi}$, then all moments $\mathbb{E}[z^k]$, $k \geq 1$, are zero. Conversely, if $\mathcal{L}_\mathrm{unif} \to 0$ as $m \to \infty$, the data distribution in the latent space converges weakly to the uniform distribution.
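
The moment criterion can be checked numerically in one dimension. The sketch below is an illustration only and not the implementation of `ricci_regularization.LossComputation.uniform_loss`, whose handling of $d > 1$ is not shown here. `uniform_loss_1d` is a hypothetical helper that evaluates $\sum_{k=1}^{m} |\alpha_k|^2$ with $z_j = e^{i \xi_j}$.

```python
import cmath
import math

def uniform_loss_1d(angles, num_moments):
    """Sum over k = 1..m of |alpha_k|^2, with alpha_k = mean_j exp(i * k * xi_j)."""
    n = len(angles)
    loss = 0.0
    for k in range(1, num_moments + 1):
        alpha_k = sum(cmath.exp(1j * k * xi) for xi in angles) / n
        loss += abs(alpha_k) ** 2
    return loss

N = 64
# Evenly spread latent codes on the circle vs. codes collapsed to a single point
grid = [-math.pi + 2 * math.pi * j / N for j in range(N)]
collapsed = [0.5] * N

loss_grid = uniform_loss_1d(grid, num_moments=4)
loss_collapsed = uniform_loss_1d(collapsed, num_moments=4)
print(loss_grid, loss_collapsed)
```

Evenly spread codes give a loss that is zero up to rounding, while fully collapsed codes give $|\alpha_k| = 1$ for every $k$ and hence a loss equal to the number of moments $m$.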

Finally, $g_{reg} = g + \varepsilon \cdot I$ is the regularized metric matrix, which stabilizes the computation of the inverse, and $\|\cdot\|_F$ is the Frobenius norm of a matrix.
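
A small worked example shows why the shift by $\varepsilon$ matters. Since $g = \nabla \Psi^* \nabla \Psi$ is positive semi-definite, the eigenvalues of $g_{reg}^{-1}$ are at most $1/\varepsilon$, so $\|g_{reg}^{-1}\|_F \le \sqrt{d}/\varepsilon$ even when $g$ itself is singular. The sketch below uses a made-up degenerate $2 \times 2$ metric and a hypothetical helper; it is not part of the library.

```python
import math

def inv_frobenius_norm_2x2(g, eps):
    """Frobenius norm of (g + eps * I)^{-1} for a 2x2 matrix g."""
    a, b = g[0][0] + eps, g[0][1]
    c, d = g[1][0], g[1][1] + eps
    det = a * d - b * c
    inv = [[d / det, -b / det], [-c / det, a / det]]
    return math.sqrt(sum(x * x for row in inv for x in row))

# Rank-1 metric: the decoder collapses one latent direction, so g is not invertible
g_degenerate = [[1.0, 0.0], [0.0, 0.0]]
eps = 1e-2

norm_reg = inv_frobenius_norm_2x2(g_degenerate, eps)
bound = math.sqrt(2) / eps  # sqrt(d) / eps with d = 2
print(norm_reg, bound)
```

With `eps=0.0` the same call would divide by a zero determinant, which is the instability the regularization avoids.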


The notebook consists of the following parts:

1) Imports. Setting the hyperparameters for dataset loading, training and plotting, such as the learning rate, batch size, and the weights of the MSE loss, curvature loss, etc. Automatic loading of the train and test dataloaders. Choice among the datasets "Synthetic", "Swissroll", "MNIST", "MNIST01" (any selected labels from MNIST).
2) Architecture and device. Architecture types: fully connected (TorusAE), convolutional (TorusConvAE). Device: cuda/cpu.
3) Training.
4) Report of training. Plotting the losses and saving a json file with the test results.


# 1. Imports

In [None]:
# prerequisites
import torch
import matplotlib.pyplot as plt
import math
import numpy as np
from tqdm.notebook import tqdm
import os
import ricci_regularization
import yaml

## Hyperparameters loading from YAML file
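
The cells of this notebook read a nested configuration. The sketch below lists the keys that are accessed unconditionally, as collected from the code in this notebook, together with a hypothetical helper `find_missing_keys` that reports absent entries before a long run is started. All values in `example_config` are placeholders, not the settings of any shipped experiment. Further keys are needed only in certain modes (for example `loss_settings.eps`, `loss_settings.lambda_curv`, `loss_settings.delta_encoder`, or the whole `OOD_settings` section), and the dataset section may require entries that are consumed inside `get_dataloaders` and are not visible here.

```python
# Keys read unconditionally by the notebook cells (collected from this notebook)
REQUIRED_KEYS = {
    "experiment": ["name", "weights_loaded_from"],
    "dataset": ["name"],
    "data_loader_settings": ["random_seed"],
    "architecture": ["type", "input_dim", "output_dim", "latent_dim"],
    "optimizer_settings": ["lr", "weight_decay", "num_epochs"],
    "loss_settings": ["lambda_recon", "lambda_unif", "num_moments"],
    "training_mode": ["compute_curvature", "compute_contractive_loss",
                      "OOD_regime", "diagnostic_mode"],
}

def find_missing_keys(config, required=REQUIRED_KEYS):
    """Return the missing entries as 'section' or 'section.key' strings."""
    missing = []
    for section, keys in required.items():
        if section not in config:
            missing.append(section)
            continue
        missing.extend(f"{section}.{key}" for key in keys
                       if key not in config[section])
    return missing

# Placeholder values for illustration only
example_config = {
    "experiment": {"name": "my_experiment", "weights_loaded_from": False},
    "dataset": {"name": "MNIST"},
    "data_loader_settings": {"random_seed": 0},
    "architecture": {"type": "TorusAE", "input_dim": 784,
                     "output_dim": 784, "latent_dim": 2},
    "optimizer_settings": {"lr": 1e-3, "weight_decay": 0.0, "num_epochs": 10},
    "loss_settings": {"lambda_recon": 1.0, "lambda_unif": 1.0, "num_moments": 4},
    "training_mode": {"compute_curvature": False,
                      "compute_contractive_loss": False,
                      "OOD_regime": False, "diagnostic_mode": False},
}

missing = find_missing_keys(example_config)
print(missing)
```

The same check could be applied to the dictionary returned by the YAML loader in the next cell.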

In [None]:
# Open and read the YAML configuration file
#with open('../experiments/Swissroll_exp5_config.yaml', 'r') as yaml_file: 

# Other experiments to try: uncomment exactly one "with open(...)" line and comment out the others.

#with open('../experiments/MNIST01_exp8_config.yaml', 'r') as yaml_file: # MNIST with labels 5,8 without curvature
with open('../experiments/MNIST_Setting_1_config.yaml','r') as yaml_file:
#with open('../experiments/Synthetic_uniform_config.yaml','r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Print the loaded YAML configuration
print(f"YAML configuration loaded successfully from:\n{yaml_file.name}")


In [None]:
# Print the experiment name from the config
print("Experiment Name:", yaml_config["experiment"]["name"])

# Path for saving pictures and other experiment results
Path_pictures = "../experiments/" + yaml_config["experiment"]["name"]
print(f"Path for experiment results: {Path_pictures}")

# Create the results directory if it does not exist yet
if not os.path.exists(Path_pictures):
    os.mkdir(Path_pictures)
    print(f"Created directory: {Path_pictures}")
else:
    print(f"Directory already exists: {Path_pictures}")

## Dataset loading 

In [None]:
# Set the random seed for data loader reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

try:
    dtype = getattr(torch, yaml_config["architecture"]["weights_dtype"]) # Convert the string to actual torch dtype
except KeyError:
    dtype = torch.float32

# Load data loaders based on YAML configuration
# (the result is named "dataloaders" to avoid shadowing the built-in dict)
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype = dtype )

train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
test_dataset = dataloaders.get("test_dataset")  # None if get_dataloaders does not return this key

print("Data loaders created successfully.")

# Calculate number of batches per epoch
batches_per_epoch = len(train_loader)
print(f"Number of batches per epoch: {batches_per_epoch}")

# 2. Architecture and device

In [None]:
# Set the random seed for reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

# Selecting the architecture type based on YAML configuration
if yaml_config["architecture"]["type"] == "TorusConvAE":
    torus_ae = ricci_regularization.Architectures.TorusConvAE(
        x_dim=yaml_config["architecture"]["output_dim"],
        h_dim1=512,
        h_dim2=256,
        z_dim=yaml_config["architecture"]["latent_dim"],
        pixels=28
    )
    print("Selected architecture: TorusConvAE")
else:
    torus_ae = ricci_regularization.Architectures.TorusAE(
        x_dim=yaml_config["architecture"]["output_dim"],
        h_dim1=512,
        h_dim2=256,
        z_dim=yaml_config["architecture"]["latent_dim"],
        dtype = dtype
    )
    print("Selected architecture: TorusAE")

# Check GPU availability and set device
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available! Training will use GPU.")
else:
    device = torch.device("cpu")
    print("CUDA is NOT available! Using CPU.")

# Move the AE model to the selected device
torus_ae.to(device)
print(f"Moved model to device: {device}")

### Loading the saved weights

In [None]:
if yaml_config["experiment"]["weights_loaded_from"] != False:
    PATH_weights_loaded = yaml_config["experiment"]["weights_loaded_from"]
    # map_location lets weights saved on GPU be loaded on a CPU-only machine
    torus_ae.load_state_dict(torch.load(PATH_weights_loaded, map_location=device))
    torus_ae.eval()
    print(f"Weights loaded from {PATH_weights_loaded}")
else:
    print("No pretrained weights loaded as per the config.")


## Optimizer and loss function
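
The training loop combines the individual terms into one scalar: the MSE and uniform losses are always included, while the contractive and curvature terms are added only when the corresponding training mode is on. The sketch below reproduces that weighting with plain floats. `total_loss` is a hypothetical helper and all numbers are made up; the key names follow the `loss_settings` and `training_mode` sections used in this notebook.

```python
def total_loss(losses, weights, modes):
    """Weighted sum of the loss terms that are switched on."""
    total = weights["lambda_recon"] * losses["MSE"]
    total += weights["lambda_unif"] * losses["Uniform"]
    if modes["compute_contractive_loss"]:
        total += weights["lambda_contractive_encoder"] * losses["Contractive"]
    if modes["compute_curvature"]:
        total += weights["lambda_curv"] * losses["Curvature"]
    return total

# Made-up values for one batch
losses = {"MSE": 0.2, "Uniform": 0.4, "Contractive": 0.3, "Curvature": 5.0}
weights = {"lambda_recon": 1.0, "lambda_unif": 0.5,
           "lambda_contractive_encoder": 2.0, "lambda_curv": 0.1}

vanilla = total_loss(losses, weights,
                     {"compute_contractive_loss": False, "compute_curvature": False})
regularized = total_loss(losses, weights,
                         {"compute_contractive_loss": True, "compute_curvature": True})
print(vanilla, regularized)
```

Switching both modes off recovers the "vanilla" objective, which is the cheap setting suggested above for the initial hyperparameter search.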


In [None]:
optimizer = torch.optim.Adam( torus_ae.parameters(),
        lr = yaml_config["optimizer_settings"]["lr"],
        weight_decay = yaml_config["optimizer_settings"]["weight_decay"] )
print(f"Optimizer configured with learning rate {yaml_config['optimizer_settings']['lr']} and weight decay {yaml_config['optimizer_settings']['weight_decay']}.")

In [None]:
# Loss = MSE + uniform_loss + curv_loss + contractive_loss
# The optional terms are switched on and off via config["training_mode"]
def loss_function(recon_data, data, z, torus_ae, config):
    MSE = torch.nn.functional.mse_loss(recon_data, data, reduction='mean')
    # use the config passed as an argument rather than the global yaml_config
    unif_loss = ricci_regularization.LossComputation.uniform_loss( z,
        latent_dim = config["architecture"]["latent_dim"],
        num_moments = config["loss_settings"]["num_moments"])
    
    dict_losses = {
        "MSE": MSE,
        "Uniform": unif_loss,
    }
    if config["training_mode"]["compute_contractive_loss"] == True:
        encoder_jac_norm = ricci_regularization.Jacobian_norm_jacrev_vmap( data, 
                                    function = torus_ae.encoder_torus,
                                    input_dim = config["architecture"]["input_dim"] )
        outliers_encoder_jac_norm = encoder_jac_norm - config["loss_settings"]["delta_encoder"]
        dict_losses["Contractive"] = torch.nn.ReLU()( outliers_encoder_jac_norm ).max()
    if config["training_mode"]["compute_curvature"] == True:
        encoded_points_no_grad = torus_ae.encoder2lifting(data).detach()
        if config["training_mode"]["curvature_computation_mode"] == "jacfwd":        
            #Sc_on_data, metric_on_data = ricci_regularization.Sc_g_jacfwd(encoded_points_no_grad,
            #                                function=torus_ae.decoder_torus,eps=config["loss_settings"]["eps"])
            #det_on_data = torch.det(metric_on_data)
            #dict_losses["Curvature"] = (torch.sqrt(det_on_data)*torch.square(Sc_on_data)).mean() 
            
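            # avoiding recursive hell
            dict_losses["Curvature"] = ricci_regularization.curvature_loss_jacfwd(encoded_points_no_grad, function=torus_ae.decoder_torus,eps=config["loss_settings"]["eps"])
            if config["training_mode"]["diagnostic_mode"] == True:
                # curvature_loss_jacfwd does not return the scalar curvature, so recompute it for the diagnostics
                Sc_on_data = ricci_regularization.Sc_jacfwd_vmap(encoded_points_no_grad,
                                                function=torus_ae.decoder_torus,eps=config["loss_settings"]["eps"])
                dict_losses["curv_squared_mean"] = (torch.square(Sc_on_data)).mean()
                dict_losses["curv_squared_max"] = (torch.square(Sc_on_data)).max()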
        elif config["training_mode"]["curvature_computation_mode"] == "fd":
            dict_losses["Curvature"] = ricci_regularization.curvature_loss(points=encoded_points_no_grad,
                                                        function=torus_ae.decoder_torus, h = 0.01, eps=0.)
    if config["training_mode"]["diagnostic_mode"] == True:
        # The curvature branches above do not return the metric, so it is
        # always recomputed here for the diagnostics below.
        encoded_points_no_grad = torus_ae.encoder2lifting(data).detach()
        metric_on_data = ricci_regularization.metric_jacfwd_vmap(encoded_points_no_grad,
                                                                 function = torus_ae.decoder_torus)
        det_on_data = torch.det(metric_on_data)
        # identity on the same device and dtype as the metric (no dependence on the global device)
        eye = torch.eye(config["architecture"]["latent_dim"],
                        device=metric_on_data.device, dtype=metric_on_data.dtype)
        g_inv_train_batch = torch.linalg.inv(metric_on_data + config["loss_settings"]["eps"]*eye)
        g_inv_norm_train_batch = torch.linalg.matrix_norm(g_inv_train_batch)
        dict_losses["g_inv_norm_mean"] = torch.mean(g_inv_norm_train_batch)
        dict_losses["g_inv_norm_max"] = torch.max(g_inv_norm_train_batch)
        dict_losses["g_det_mean"] = det_on_data.mean()
        dict_losses["g_det_max"] = det_on_data.max()
        dict_losses["g_det_min"] = det_on_data.min()
        decoder_jac_norm = torch.func.vmap(torch.trace)(metric_on_data)
        dict_losses["decoder_jac_norm_mean"] = decoder_jac_norm.mean()
        dict_losses["decoder_jac_norm_max"] = decoder_jac_norm.max()
        dict_losses["decoder_contractive_loss"] = (torch.nn.ReLU()(decoder_jac_norm)).max()
        if config["training_mode"]["compute_contractive_loss"] == False:
            encoder_jac_norm = ricci_regularization.Jacobian_norm_jacrev_vmap( data, 
                                        function = torus_ae.encoder_torus,
                                        input_dim = config["architecture"]["input_dim"] )
        encoder_jac_norm_mean = encoder_jac_norm.mean()
        encoder_jac_norm_max = encoder_jac_norm.max()
        dict_losses["encoder_jac_norm_mean"] = encoder_jac_norm_mean
        dict_losses["encoder_jac_norm_max"] = encoder_jac_norm_max
    return dict_losses

In [None]:
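def train( epoch=1, batch_idx = 0, dict_loss_arrays=None):
    #  creating a dict for losses shown in progress bar
    dict_loss2print = {}
    # start a fresh loss history at the very first batch (or if none was passed);
    # None replaces the mutable default argument {}
    if batch_idx == 0 or dict_loss_arrays is None:
        dict_loss_arrays = {}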
    torus_ae.train()
    print("Epoch %d"%epoch)
    t = tqdm( train_loader, desc="Train", position=0 )

    # OOD points initialization
    if yaml_config["training_mode"]["OOD_regime"] == True:
        first_batch,_ = next(iter(train_loader))
        first_batch = first_batch.to(device)
        extreme_curv_points_tensor = torus_ae.encoder2lifting(first_batch.view(-1,yaml_config["architecture"]["input_dim"])[:yaml_config["OOD_settings"]["N_extr"]]).detach()
        # .to() is not in-place for tensors, so the result must be assigned
        extreme_curv_points_tensor = extreme_curv_points_tensor.to(device)
        extreme_curv_value_tensor = ricci_regularization.Sc_jacfwd_vmap(extreme_curv_points_tensor, 
                function=torus_ae.decoder_torus,eps=yaml_config["loss_settings"]["eps"])
    
    for (data, labels) in t:
        data = data.to(device)
        data = data.reshape(-1, yaml_config["architecture"]["input_dim"])
        optimizer.zero_grad()
        # Forward
        recon_batch, z = torus_ae( data )
        # Computing necessary losses on the current batch
        dict_losses = loss_function( recon_batch, data, z,
                                     torus_ae = torus_ae,
                                     config = yaml_config )
        # appending current losses to loss history
        for key in dict_losses.keys():
            if batch_idx == 0:
                dict_loss_arrays[key] = []
            # losses to keep in memory:
            dict_loss_arrays[key].append(dict_losses[key].item())
            # losses to show on progress bar: 
            dict_loss2print[key] = f"{dict_losses[key].item():.4f}"
            # moving average (per epoch)
            #dict_loss2print[key] = f"{np.array(dict_loss_arrays[key])[-batches_per_epoch:].mean():.4f}"
        # end for 
        loss = yaml_config["loss_settings"]["lambda_recon"] * dict_losses["MSE"] 
        loss += yaml_config["loss_settings"]["lambda_unif"] * dict_losses["Uniform"] 

        if yaml_config["training_mode"]["compute_contractive_loss"] == True:
            # adding the contractive loss on the current batch to the loss function
            loss += yaml_config["loss_settings"]["lambda_contractive_encoder"] * dict_losses["Contractive"]
            
        if (yaml_config["training_mode"]["compute_curvature"] == True): 
            # adding the curvature loss on the current batch to the loss function
            loss += yaml_config["loss_settings"]["lambda_curv"] * dict_losses["Curvature"]
            
        # OOD regime (optional)
        if yaml_config["training_mode"]["OOD_regime"] == True:
            OOD_params_dict = yaml_config["OOD_settings"]
            extreme_curv_points_tensor, extreme_curv_value_tensor = ricci_regularization.find_extreme_curvature_points(
                data_batch = data,
                extreme_curv_points_tensor = extreme_curv_points_tensor,
                extreme_curv_value_tensor = extreme_curv_value_tensor,
                batch_idx = batch_idx,
                encoder=torus_ae.encoder2lifting,decoder = torus_ae.decoder_torus,
                OOD_params_dict = yaml_config["OOD_settings"],
                output_dim=yaml_config["architecture"]["input_dim"])
            
            # OOD loss is evaluated every T_ood batches, starting from batch start_ood
            if (batch_idx % OOD_params_dict["T_ood"] == 0) and (batch_idx >= OOD_params_dict["start_ood"]):
                OOD_curvature_loss = ricci_regularization.OODTools.curv_loss_on_OOD_samples(
                    extreme_curv_points_tensor = extreme_curv_points_tensor,
                    decoder = torus_ae.decoder_torus,
                    OOD_params_dict = yaml_config["OOD_settings"],
                    latent_space_dim = yaml_config["architecture"]["latent_dim"] )
                if yaml_config["training_mode"]["diagnostic_mode"] == True:
                    print("Curvature functional at OOD points", OOD_curvature_loss)
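                # NOTE: on OOD steps this assignment replaces the loss accumulated above
                # by the weighted OOD curvature loss (it is not added to it)
                loss = yaml_config["OOD_settings"]["OOD_w"] * OOD_curvature_loss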
            #end if
        #end if
        
        # Backpropagate
        loss.backward()
        optimizer.step()

        # Progress bar plotting
        t.set_postfix(dict_loss2print)
        # Switching batch index
        batch_idx += 1
    #end for
    return batch_idx, dict_loss_arrays

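def test(dict_loss_arrays = None, batch_idx = 0):
    # None replaces the mutable default argument {}
    if dict_loss_arrays is None:
        dict_loss_arrays = {}
    dict_loss2print = {}
    torus_ae.to(device)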
    t = tqdm( test_loader, desc="Test", position=1 )
    for data, _ in t:
        data = data.to(device)
        data = data.reshape(-1, yaml_config["architecture"]["input_dim"])
        recon_batch, z = torus_ae(data)
        dict_losses = loss_function(recon_batch, data, z,torus_ae = torus_ae, config=yaml_config)
        # appending current losses to loss history    
        for key in dict_losses.keys():
            if batch_idx == 0:
                dict_loss_arrays[key] = []
            dict_loss_arrays[key].append(dict_losses[key].item())
            # current losses to show on progress bar
            dict_loss2print[key] = f"{dict_losses[key].item():.4f}"
#            dict_loss2print[key] = f"{np.array(dict_loss_arrays[key]).mean():.4f}"
        t.set_postfix(dict_loss2print)
        # switch batch index
        batch_idx+=1
    #end for
    return batch_idx, dict_loss_arrays

# 3. Training

In [None]:
batch_idx = 0
test_batch_idx = 0
dict_loss_arrays = {}
test_dict_loss_arrays = {}

# Launch
for epoch in range(1, yaml_config["optimizer_settings"]["num_epochs"] + 1):
  torus_ae.to(device)
  batch_idx, dict_loss_arrays = train(epoch=epoch,batch_idx=batch_idx,dict_loss_arrays=dict_loss_arrays)
  if yaml_config["training_mode"]["diagnostic_mode"] == True :
    dict2print = ricci_regularization.PlottingTools.translate_dict(dict2print=dict_loss_arrays, 
                include_curvature_plots=yaml_config["training_mode"]["compute_curvature"],
                eps=yaml_config["loss_settings"]["eps"])
    ricci_regularization.PlottingTools.plotsmart(dict2print)
    fig = ricci_regularization.point_plot(encoder=torus_ae.encoder2lifting, data=test_dataset,
                                          dataset_name=yaml_config["dataset"]["name"], 
                                          batch_idx=batch_idx,config=yaml_config, device = device)
    #fig.show()
  #else:
  #  ricci_regularization.PlottingTools.plotfromdict(dict_of_losses=dict_loss_arrays)  
  if (yaml_config["dataset"]["name"] == "MNIST01"):
    ricci_regularization.PlottingTools.plot_ae_outputs_selected(
      test_dataset=test_dataset,
      encoder=torus_ae.cpu().encoder2lifting,
      decoder=torus_ae.cpu().decoder_torus,
      selected_labels=yaml_config["dataset"]["selected_labels"] )
  elif (yaml_config["dataset"]["name"] == "MNIST"):
    ricci_regularization.PlottingTools.plot_ae_outputs(
      test_dataset=test_dataset,
      encoder=torus_ae.cpu().encoder2lifting,
      decoder=torus_ae.cpu().decoder_torus )
  # end if
  # Test 
  torus_ae.to(device)
  test_batch_idx,test_dict_loss_arrays = test(batch_idx=test_batch_idx,dict_loss_arrays=test_dict_loss_arrays) 
  # end for

# 4. Report of training

## Saving the model state dictionary

In [None]:
torch.save(torus_ae.state_dict(), f'{Path_pictures}/ae_weights.pt')
print("AE weights saved at:", Path_pictures)

## Test losses, $R^2$
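
The coefficient of determination reported below is $R^2 = 1 - \mathrm{MSE}/\mathrm{Var}$, where the variance is taken over all entries of the test data. A value of $1$ means perfect reconstruction, while $0$ means the AE does no better than predicting the global mean. The following dependency-free sketch reproduces the arithmetic on made-up numbers; `r_squared` is a hypothetical helper, and the variance uses the unbiased $n-1$ normalization, which is also the default of `torch.var`.

```python
def r_squared(mse, values):
    """R^2 = 1 - MSE / Var, with the unbiased (n - 1) sample variance."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return 1.0 - mse / var

# Made-up flattened test data with sample variance 2.5
values = [0.0, 1.0, 2.0, 3.0, 4.0]

r2_good = r_squared(mse=0.25, values=values)  # small reconstruction error
r2_mean = r_squared(mse=2.5, values=values)   # error as large as the data variance
print(r2_good, r2_mean)
```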

In [None]:
# compute test losses
_,dict_test_losses = test()

test_mse = np.array(dict_test_losses["MSE"]).mean()
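# collect test batches in a list and then concatenate to get one tensor for test data
# (the list is named "test_batches" to avoid shadowing the built-in list)
test_batches = []
for data,_ in test_loader:
    test_batches.append(data.float())
# compute variance over all entries of the test data
var = torch.var(torch.cat(test_batches).flatten())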
# compute R^2
test_R_squared = 1 - test_mse/var
#printing

print("Test losses")
print(f"MSE: {test_mse.item():.4f}")
print(f"R²: {test_R_squared.item():.4f}")


test_unif = np.array(dict_test_losses["Uniform"]).mean()
try:
    test_curv = np.array(dict_test_losses["Curvature"]).mean()
    print(f"Curvature: {test_curv.item():.4f}")
except KeyError:
    test_curv = "not_computed"
    print(test_curv)
print(f"Uniform: {test_unif.item():.4f}")



## Losses plot

In [None]:
# loss plotting
if yaml_config["training_mode"]["diagnostic_mode"] == True:
    #fig,axes = ricci_regularization.PlottingTools.plotsmart(dict2print)
    fig,axes = ricci_regularization.PlottingTools.PlotSmartConvolve(dict2print)
else:
    fig,axes = ricci_regularization.PlottingTools.plotfromdict(dict_loss_arrays)
fig.savefig(f"{Path_pictures}/losses.pdf",bbox_inches='tight',format="pdf")

### Saving loss history

In [None]:
# check if it is optimal
# saving loss history
torch.save(dict_loss_arrays, f'{Path_pictures}/losses_history.pt')

## Torus latent space

In [None]:
torus_ae.cpu() # switch device to cpu for plotting
fig = ricci_regularization.point_plot(encoder=torus_ae.encoder2lifting, data=test_dataset,
                                      dataset_name = yaml_config["dataset"]["name"], 
                                      batch_idx=batch_idx,config=yaml_config, 
                                      show_title=False, device = "cpu", figsize=(9,9))
fig.savefig( Path_pictures + "/latent_space.pdf", bbox_inches = 'tight', format = "pdf" )
fig.show()

## Saving test losses in a json file

In [None]:
import json
json_results = {
        "R^2_test_data": test_R_squared.item(),
        "mse_loss_test_data": test_mse.item(),  
        "unif_loss_test_data": test_unif.item()  
}
if yaml_config["training_mode"]["compute_curvature"] == True:
    json_results["curv_loss_test_data"] = test_curv.item()
else:
    json_results["curv_loss_test_data"] = "Not computed"

with open(f"{Path_pictures}/results.json", 'w') as json_file:
    json.dump(json_results, json_file, indent=4)