# Torus AE training

This notebook trains the autoencoder (AE).

The AE consists of an encoder $\Phi$ and a decoder $\Psi$.
The latent space of the AE is topologically a $d$-dimensional torus $\mathcal{T}^d$, i.e., it can be viewed as the periodic box $[-\pi, \pi]^d$. We define a Riemannian metric on the latent space as the pull-back of the Euclidean metric on the output space $\mathbb{R}^D$ by the decoder $\Psi$ of the AE:
\begin{equation}
    g = \nabla \Psi ^* \nabla \Psi \ .
\end{equation}
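
As a dependency-free illustration of the pull-back construction, the sketch below uses the standard embedding of a 2-torus in $\mathbb{R}^3$ as a stand-in for the decoder (an assumption; the decoder in this notebook is a neural network), approximates its Jacobian $J = \nabla \Psi$ by central finite differences, and forms $g = J^T J$. The names `toy_decoder`, `jacobian_fd` and `pullback_metric` are hypothetical and not part of `ricci_regularization`.

```python
import math

def toy_decoder(z, R=2.0, r=1.0):
    """Standard embedding of the 2-torus in R^3 (a stand-in for the decoder Psi)."""
    theta, phi = z
    return [
        (R + r * math.cos(theta)) * math.cos(phi),
        (R + r * math.cos(theta)) * math.sin(phi),
        r * math.sin(theta),
    ]

def jacobian_fd(f, z, h=1e-6):
    """Central finite-difference Jacobian of f at z, returned as D rows of d entries."""
    d = len(z)
    columns = []
    for j in range(d):
        z_plus, z_minus = list(z), list(z)
        z_plus[j] += h
        z_minus[j] -= h
        f_plus, f_minus = f(z_plus), f(z_minus)
        columns.append([(a - b) / (2 * h) for a, b in zip(f_plus, f_minus)])
    D = len(columns[0])
    # columns[j][i] holds d f_i / d z_j; transpose into rows indexed by i
    return [[columns[j][i] for j in range(d)] for i in range(D)]

def pullback_metric(f, z):
    """Pull-back metric g = J^T J at the latent point z."""
    J = jacobian_fd(f, z)
    d = len(z)
    return [[sum(row[a] * row[b] for row in J) for b in range(d)] for a in range(d)]

g = pullback_metric(toy_decoder, [0.3, 1.1])
for row in g:
    print(row)
```

For this embedding the metric is known in closed form, $g = \mathrm{diag}\left(r^2, (R + r\cos\theta)^2\right)$, which gives a quick sanity check of the numerical result.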


The following training modes can be switched on or off:

- "compute_curvature"
- "compute_contractive_loss"
- "OOD_regime"
- "diagnostic_mode"

If "compute_curvature"==True, the curvature functional is computed and used to regularize the latent space. The curvature functional is given by:
\begin{equation*}
\mathcal{L}_\mathrm{curv} := \int_M R^2 \mu \ .
\end{equation*}

If "compute_contractive_loss"==True, a contractive loss is computed and used to regularize the latent space. It penalizes outliers of the Frobenius norm of the encoder's Jacobian, i.e., values above the threshold $\delta_\mathrm{encoder}$:
$$
 \mathcal{L}_\mathrm{contractive} = \mathrm{ReLU}\left( \|\nabla \Phi\|_F - \delta_\mathrm{encoder}\right)
$$
One might want to turn it off for faster training during the initial tuning of the parameters.
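
The hinge structure of the contractive loss can be sketched without any deep learning library. The example below takes a Jacobian given as nested lists (in practice it would come from automatic differentiation of the encoder) and applies the ReLU threshold. The function names and the value of `delta_encoder` are illustrative assumptions.

```python
import math

def frobenius_norm(matrix):
    """Square root of the sum of squared entries."""
    return math.sqrt(sum(x * x for row in matrix for x in row))

def contractive_loss(jacobian, delta_encoder):
    """ReLU(||J||_F - delta): zero below the threshold, linear above it."""
    return max(0.0, frobenius_norm(jacobian) - delta_encoder)

small_jacobian = [[0.3, 0.0], [0.0, 0.4]]  # Frobenius norm 0.5
large_jacobian = [[3.0, 0.0], [0.0, 4.0]]  # Frobenius norm 5.0

loss_small = contractive_loss(small_jacobian, delta_encoder=1.0)
loss_large = contractive_loss(large_jacobian, delta_encoder=1.0)
print(loss_small, loss_large)
```

Only the Jacobian whose norm exceeds the threshold contributes to the loss, so points where the encoder is already contractive are left untouched.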

If "OOD_regime"==True, then out-of-distribution (OOD) sampling is performed to refine the curvature regularization results.

Any of the modes can be turned off to speed up training. This makes it faster to tune the "vanilla" AE (without regularization) and to obtain hyperparameters that serve as the initial guess for training the AE with regularization.

If "diagnostic_mode"==True, the following quantities are plotted: MSE, $\mathcal{L}_\mathrm{unif}$, $\mathcal{L}_\mathrm{curv}$, $\det(g)$, $\|g_{reg}^{-1}\|_F$, $\|\nabla \Psi \|^2_F$, $\|\nabla \Phi \|^2_F$, where:
\begin{equation*}
\mathcal{L}_\mathrm{curv} := \int_M R^2 \mu \ ,
\end{equation*}

\begin{equation*}
\mathcal{L}_\mathrm{unif} := \sum\limits_{k=1}^{m} \left| \int_M z^k \, \mu_N (dz) \right|^2 \ ,
\end{equation*}
where $R$ stands for the scalar curvature (see https://en.wikipedia.org/wiki/Scalar_curvature), $\mu_N = \Phi\# ( \frac{1}{N}\sum\limits_{j=1}^{N} \delta_{X_j} )$ is the push-forward by the encoder $\Phi$ of the empirical measure induced by the dataset, so that $\mu_N$ is a measure on $\mathcal{T}^d$, and $\alpha_k = \int_M z^k \, \mu_N (dz) = \frac{1}{N} \sum_{j=1}^{N} z_j^k$ is the empirical estimator of the $k$-th moment of the data distribution in the latent space.

If $\xi \sim \mathcal{U}[-\pi, \pi]$ and $z = e^{i \xi}$, then all the moments $\mathbb{E}[z^k]$, $k \geq 1$, are zero. Hence, if $\mathcal{L}_\mathrm{unif} \to 0$ as $m \to \infty$, the data distribution in the latent space converges weakly to the uniform distribution.
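
The moment argument can be checked numerically for a single latent angle. The sketch below evaluates $\sum_{k=1}^{m} |\alpha_k|^2$ with $z_j = e^{i \xi_j}$ for two hypothetical samples: angles spread evenly over $[-\pi, \pi)$ and angles clustered near zero. The function `uniformity_loss` is an illustrative name, not the implementation used by `ricci_regularization`.

```python
import cmath
import math

def uniformity_loss(angles, m):
    """Sum over k = 1..m of |mean(exp(i * k * xi))|^2 for latent angles xi."""
    n = len(angles)
    total = 0.0
    for k in range(1, m + 1):
        moment = sum(cmath.exp(1j * k * xi) for xi in angles) / n
        total += abs(moment) ** 2
    return total

n = 200
# Evenly spread angles: a discrete stand-in for the uniform distribution
spread = [-math.pi + 2 * math.pi * j / n for j in range(n)]
# Angles concentrated in a small interval around zero
clustered = [0.1 * math.sin(j) for j in range(n)]

loss_spread = uniformity_loss(spread, m=5)
loss_clustered = uniformity_loss(clustered, m=5)
print(loss_spread, loss_clustered)
```

The evenly spread sample gives a loss that vanishes up to rounding error, while the clustered sample gives a loss close to $m$, since each of its moments has modulus close to one.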

Also, $g_{reg} = g + \varepsilon \cdot I$ is the regularized metric matrix, used to stabilize the computation of the inverse matrix, and $\|\cdot\|_F$ is the Frobenius norm of a matrix.
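
The role of $\varepsilon$ can be seen on a degenerate $2 \times 2$ example. In the sketch below, $g$ has rank one and therefore no inverse, while $g + \varepsilon I$ is invertible. The helper name and the value of `eps` are assumptions made for illustration.

```python
def regularized_inverse_2x2(g, eps):
    """Invert g + eps * I for a 2x2 matrix g given as nested lists."""
    a, b = g[0][0] + eps, g[0][1]
    c, d = g[1][0], g[1][1] + eps
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

# A degenerate metric: rank one, so det(g) = 0 and g itself has no inverse.
g = [[1.0, 1.0], [1.0, 1.0]]
eps = 1e-3
g_reg = [[g[0][0] + eps, g[0][1]], [g[1][0], g[1][1] + eps]]
g_reg_inv = regularized_inverse_2x2(g, eps)

# Check: g_reg times its inverse should be close to the identity.
product = [
    [sum(g_reg[i][k] * g_reg_inv[k][j] for k in range(2)) for j in range(2)]
    for i in range(2)
]
print(product)
```

The price of the regularization is that the entries of $g_{reg}^{-1}$ grow as $\varepsilon$ shrinks, which is one reason to monitor $\|g_{reg}^{-1}\|_F$ during training.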


The notebook consists of the following parts:

1) Imports. Choice of hyperparameters for dataset loading, training and plotting, such as the learning rate, batch size, and weights of the MSE loss, curvature loss, etc. Automatic loading of train and test dataloaders. Choice among the datasets "Synthetic", "Swissroll", "MNIST", "MNIST01" (any selected labels from MNIST).
2) Architecture and device. Architecture types: fully connected (TorusAE), convolutional (TorusConvAE). Device: cuda/cpu.
3) Training.
4) Report of training. Plots of the losses and a JSON file with the training parameters and results.


# 1. Imports

In [None]:
# prerequisites
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import os
import ricci_regularization
import yaml
import time

## Hyperparameters loading from YAML file
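
The cells of this notebook read a number of keys from the loaded configuration. The sketch below lists those keys as a Python dictionary together with a small check for missing entries. All values are placeholders, the helper `missing_keys` is hypothetical, and the actual configuration file may contain further keys that are consumed inside `ricci_regularization`.

```python
# Keys read directly by the cells of this notebook; every value is a placeholder.
example_config = {
    "experiment": {"weights_loaded_from": False},
    "dataset": {"name": "MNIST", "selected_labels": [0, 1]},
    "data_loader_settings": {"random_seed": 0},
    "architecture": {
        "type": "TorusAE",
        "output_dim": 784,
        "latent_dim": 2,
        "weights_dtype": "float32",  # optional: the notebook falls back to float32
    },
    "optimizer_settings": {"num_epochs": 1, "lr": 1e-3, "weight_decay": 0.0},
    "training_mode": {"diagnostic_mode": True, "compute_curvature": False},
    "loss_settings": {"eps": 1e-5},
}

REQUIRED_KEYS = {
    "experiment": ["weights_loaded_from"],
    "dataset": ["name"],
    "data_loader_settings": ["random_seed"],
    "architecture": ["type", "output_dim", "latent_dim"],
    "optimizer_settings": ["num_epochs", "lr", "weight_decay"],
    "training_mode": ["diagnostic_mode", "compute_curvature"],
    "loss_settings": ["eps"],
}

def missing_keys(config, required):
    """Return 'section.key' strings that `required` lists but `config` lacks."""
    missing = []
    for section, keys in required.items():
        for key in keys:
            if key not in config.get(section, {}):
                missing.append(f"{section}.{key}")
    return missing

print(missing_keys(example_config, REQUIRED_KEYS))
```

Running such a check right after loading the file reports a missing key before training starts, rather than as a `KeyError` several cells later.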

In [None]:
# Choose setting of the experiment
AE_setting_name = 'MNIST_Setting_1_exp1'
# Open and read the YAML configuration file
config_path = f'../experiments/{AE_setting_name}_config.yaml'
with open(config_path, 'r') as yaml_file:
    yaml_config = yaml.safe_load(yaml_file)  # safe_load: plain data only, no arbitrary Python objects

# Report which configuration file was loaded
print(f"YAML configuration loaded successfully from: {config_path}")


In [None]:
# Report the experiment name chosen above
print("Experiment Name:", AE_setting_name)

# Directory for saving pictures and other experiment results
Path_pictures = "../experiments/" + AE_setting_name
print(f"Path for experiment results: {Path_pictures}")

# Create the results directory if it does not exist yet
if not os.path.exists(Path_pictures):
    os.makedirs(Path_pictures)  # also creates missing parent directories
    print(f"Created directory: {Path_pictures}")
else:
    print(f"Directory already exists: {Path_pictures}")

## Dataset loading 

In [None]:
# Set the random seed for data loader reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

try:
    dtype = getattr(torch, yaml_config["architecture"]["weights_dtype"]) # Convert the string to actual torch dtype
except KeyError:
    dtype = torch.float32

# Load data loaders based on YAML configuration
# (named `dataloaders` to avoid shadowing the built-in `dict`)
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=dtype)

train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
test_dataset = dataloaders.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders
test_dataset_labels = test_dataset.targets

print("Data loaders created successfully.")

# Calculate number of batches per epoch
batches_per_epoch = len(train_loader)
print(f"Number of batches per epoch: {batches_per_epoch}")

# 2. Architecture and device

In [None]:
# Set the random seed for reproducibility
torch.manual_seed(yaml_config["data_loader_settings"]["random_seed"])
print(f"Set random seed to: {yaml_config['data_loader_settings']['random_seed']}")

# Selecting the architecture type based on YAML configuration
if yaml_config["architecture"]["type"] == "TorusConvAE":
    torus_ae = ricci_regularization.Architectures.TorusConvAE(
        x_dim=yaml_config["architecture"]["output_dim"],
        h_dim1=512,
        h_dim2=256,
        z_dim=yaml_config["architecture"]["latent_dim"],
        pixels=28
    )
    print("Selected architecture: TorusConvAE")
else:
    torus_ae = ricci_regularization.Architectures.TorusAE(
        x_dim=yaml_config["architecture"]["output_dim"],
        h_dim1=512,
        h_dim2=256,
        z_dim=yaml_config["architecture"]["latent_dim"],
        dtype = dtype
    )
    print("Selected architecture: TorusAE")

# Check GPU availability and set device
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available! Training will use GPU.")
else:
    device = torch.device("cpu")
    print("CUDA is NOT available! Using CPU.")

# Move the AE model to the selected device
torus_ae.to(device)
print(f"Moved model to device: {device}")

### Loading the saved weights

In [None]:
if yaml_config["experiment"]["weights_loaded_from"] != False:
    PATH_weights_loaded = yaml_config["experiment"]["weights_loaded_from"]
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    torus_ae.load_state_dict(torch.load(PATH_weights_loaded, map_location=device))
    torus_ae.eval()
    print(f"Weights loaded from {PATH_weights_loaded}")
else:
    print("No pretrained weights loaded as per the config.")


# 3. Training

In [None]:
# losses
batch_idx = 0
test_batch_idx = 0
dict_loss_arrays = {}
test_dict_loss_arrays = {}

## Optimizer

In [None]:
num_epochs = yaml_config["optimizer_settings"]["num_epochs"]
lr = yaml_config["optimizer_settings"]["lr"]
#lr = 0.5e-3 # for manual change of the lr
optimizer = torch.optim.Adam( torus_ae.parameters(),
        lr = lr,
        weight_decay = yaml_config["optimizer_settings"]["weight_decay"] )
print(f"Optimizer configured with learning rate {lr} and weight decay {yaml_config['optimizer_settings']['weight_decay']}.")
print("Number of epochs:", num_epochs)

In [None]:
%matplotlib inline
# timing
start_time = time.time()
# Launch
for epoch in range(1, num_epochs + 1):
  torus_ae.to(device)
  #batch_idx, dict_loss_arrays = train(epoch=epoch,batch_idx=batch_idx,dict_loss_arrays=dict_loss_arrays)
  batch_idx, dict_loss_arrays = ricci_regularization.train(torus_ae,training_config=yaml_config,
      train_loader=train_loader,optimizer=optimizer, epoch=epoch,batch_idx=batch_idx,
      dict_loss_arrays=dict_loss_arrays)
  if yaml_config["training_mode"]["diagnostic_mode"] == True :
    dict_losses_to_plot = ricci_regularization.PlottingTools.translate_dict(dict_losses_to_plot=dict_loss_arrays, 
                include_curvature_plots=yaml_config["training_mode"]["compute_curvature"],
                eps=yaml_config["loss_settings"]["eps"])
    ricci_regularization.PlottingTools.plotsmart(dict_losses_to_plot)

  # Plot the encoded test points (computed on the CPU)
  torus_ae.cpu()
  with torch.no_grad():
      encoded_points = torus_ae.encoder2lifting(test_dataset.data.view(-1,yaml_config["architecture"]["output_dim"]).to(torch.float32)/255.)
  ricci_regularization.point_plot_fast(encoded_points,test_dataset_labels,batch_idx,yaml_config)
  torus_ae.to(device)  # return to the training device; also works on CPU-only machines
  plt.show()
  #fig.show()
  #else:
  #  ricci_regularization.PlottingTools.plotfromdict(dict_of_losses=dict_loss_arrays)  
  if (yaml_config["dataset"]["name"] in ["MNIST01","MNIST_subset"]):
    ricci_regularization.PlottingTools.plot_ae_outputs_selected(
      test_dataset=test_dataset,
      encoder=torus_ae.cpu().encoder2lifting,
      decoder=torus_ae.cpu().decoder_torus,
      selected_labels=yaml_config["dataset"]["selected_labels"] )
  elif (yaml_config["dataset"]["name"] == "MNIST"):
    ricci_regularization.PlottingTools.plot_ae_outputs(
      test_dataset=test_dataset,
      encoder=torus_ae.cpu().encoder2lifting,
      decoder=torus_ae.cpu().decoder_torus )
  # end if
  # Test 
  torus_ae.to(device)
  test_batch_idx,test_dict_loss_arrays = ricci_regularization.test(torus_ae, test_loader=test_loader,
      training_config=yaml_config, batch_idx=test_batch_idx,dict_loss_arrays=test_dict_loss_arrays) 
  # end for

#timing
end_time = time.time()
algorithm_execution_time = end_time - start_time

# 4. Report of training

## Saving the model state dictionary

In [None]:
torch.save(torus_ae.state_dict(), f'{Path_pictures}/ae_weights.pt')
print("AE weights saved at:", Path_pictures)

## Test losses, $R^2$
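
The coefficient of determination reported in this section is $R^2 = 1 - \mathrm{MSE} / \mathrm{Var}$, where the variance is taken over all entries of the flattened test data. The sketch below reproduces that arithmetic on a tiny hypothetical sample. The function `r_squared` is an illustrative name, and the sample variance uses the unbiased estimator, which matches the default of `torch.var`.

```python
def r_squared(mse, values):
    """R^2 = 1 - MSE / Var, with the unbiased sample variance of `values`."""
    n = len(values)
    mean = sum(values) / n
    variance = sum((v - mean) ** 2 for v in values) / (n - 1)
    return 1.0 - mse / variance

# Hypothetical flattened test data with sample variance 2.5
sample_values = [0.0, 1.0, 2.0, 3.0, 4.0]
score = r_squared(mse=0.25, values=sample_values)
print(score)
```

A reconstruction error equal to the data variance gives $R^2 = 0$, and a perfect reconstruction gives $R^2 = 1$.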

In [None]:
# compute test losses on the test set
_, dict_test_losses = ricci_regularization.test(torus_ae, test_loader=test_loader, training_config=yaml_config)

test_mse = np.array(dict_test_losses["MSE"]).mean()
# collect test batches in a list and then concatenate to get one tensor for test data
test_batches = []
for data, _ in test_loader:
    test_batches.append(data.float())
# compute the variance over all entries of the test data
var = torch.var(torch.cat(test_batches).flatten())
# compute R^2
test_R_squared = 1 - test_mse/var

# printing
print("Test losses")
print(f"MSE: {test_mse.item():.4f}")
print(f"R²: {test_R_squared.item():.4f}")


test_unif = np.array(dict_test_losses["Equidistribution"]).mean()
try:
    test_curv = np.array(dict_test_losses["Curvature"]).mean()
    print(f"Curvature: {test_curv.item():.4f}")
except KeyError:
    test_curv = "not_computed"
    print("Curvature:", test_curv)
print(f"Equidistribution: {test_unif.item():.4f}")



## Losses plot

In [None]:
#plotting smooth:
dict_losses_to_plot = ricci_regularization.PlottingTools.translate_dict(dict_losses_to_plot=dict_loss_arrays,
                eps=yaml_config["loss_settings"]["eps"])
fig,axes = ricci_regularization.PlottingTools.PlotSmartConvolve(dict_losses_to_plot,numwindows1=10,numwindows2=50)
# plot only non-smooth:
#fig,axes = ricci_regularization.PlottingTools.plotfromdict(dict_loss_arrays)
fig.savefig(f"{Path_pictures}/losses.pdf",bbox_inches='tight',format="pdf")

### Saving loss history

In [None]:
# check if it is optimal
# saving loss history
torch.save(dict_loss_arrays, f'{Path_pictures}/losses_history.pt')

## Saving test losses in a json file
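
Results are appended to a list stored in `results.json`, and the epoch count and training time are accumulated on top of the last saved entry, so that repeated runs of the notebook add up. The sketch below isolates that accumulation logic in a hypothetical helper `append_run` and exercises it in a temporary directory.

```python
import json
import os
import tempfile

def append_run(json_path, epochs, seconds):
    """Append a run to the JSON history, accumulating epochs and training time."""
    try:
        with open(json_path, "r") as f:
            history = json.load(f)
        total_epochs = history[-1]["epoch_count"] + epochs
        total_time = history[-1]["training_time"] + seconds
    except FileNotFoundError:
        # First run: nothing to accumulate yet
        history = []
        total_epochs, total_time = epochs, seconds
    history.append({"epoch_count": total_epochs, "training_time": total_time})
    with open(json_path, "w") as f:
        json.dump(history, f, indent=4)
    return history

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "results.json")
    append_run(path, epochs=5, seconds=10.0)
    history = append_run(path, epochs=3, seconds=6.5)

print(history[-1])
```

Each entry therefore records cumulative totals, and the last entry of the list describes the model as it stands after the most recent run.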

In [None]:
import json
# Define the file path to save results
json_file_path = f"{Path_pictures}/results.json"
# Check whether results of a previous run were saved; if so, accumulate time and epochs
try:
    with open(json_file_path, 'r') as json_file:
        loaded_json_results = json.load(json_file)
    total_training_time = loaded_json_results[-1]["training_time"] + algorithm_execution_time
    total_epoch_count = loaded_json_results[-1]["epoch_count"] + num_epochs
    print(f"Training time is accumulated with previously saved training time from {json_file_path}")
except FileNotFoundError:
    # First run for this experiment: start the history from scratch
    total_training_time = algorithm_execution_time
    total_epoch_count = num_epochs
    loaded_json_results = []

current_results = {
        "epoch_count": total_epoch_count,
        "learning_rate": lr,
        "R^2_test_data ": float(f"{test_R_squared.item():.6f}"),
        "MSE_loss_test_data": float(f"{test_mse.item():.6f}"),  
        "Equidistribution_loss_test_data": float(f"{test_unif.item():.6f}"),
        "training_time": float(f"{total_training_time:.3f}")
}
if yaml_config["training_mode"]["compute_curvature"] == True:
    current_results["Curvature_loss_test_data"] = test_curv.item()
else:
    current_results["Curvature_loss_test_data"] = "Not computed"
loaded_json_results.append(current_results)
# save all the results
with open(json_file_path, 'w') as json_file:
    json.dump(loaded_json_results, json_file, indent=4)

## Torus latent space
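
Since the latent space is the periodic box $[-\pi, \pi]^d$, points near opposite edges of a flat plot are close on the torus. The sketch below wraps angles into $[-\pi, \pi)$ and computes the distance with respect to the flat periodic metric. Both helpers are hypothetical, and this flat distance is not the pull-back metric $g$ defined at the top of the notebook.

```python
import math

def wrap_angle(x):
    """Map x to the equivalent angle in [-pi, pi)."""
    return (x + math.pi) % (2 * math.pi) - math.pi

def flat_torus_distance(u, v):
    """Distance between two latent points with respect to the flat periodic metric."""
    return math.sqrt(sum(wrap_angle(a - b) ** 2 for a, b in zip(u, v)))

# Two points close to opposite edges of the box in the first coordinate
u = (3.0, 0.0)
v = (-3.0, 0.0)
distance = flat_torus_distance(u, v)
print(distance)
```

The two points differ by 6 in the first coordinate of the box, yet their periodic distance is $2\pi - 6 \approx 0.28$.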

In [None]:

torus_ae.cpu() # switch device to cpu for plotting
fig = ricci_regularization.point_plot(encoder=torus_ae.encoder2lifting, data_loader=test_loader,
                                      batch_idx=batch_idx,config=yaml_config, 
                                      show_title=False, figsize=(9,9))
fig.savefig( Path_pictures + "/latent_space.pdf", bbox_inches = 'tight', format = "pdf" )
plt.show()