In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from morphomics.io.io import load_obj, save_obj
from kxa_analysis import dimreduction_runner, bootstrap_runner
import numpy as np

In [None]:
# KL-factor schedule for the VAE: a list of `epochs` values following
# 1 - exp(-x), with x evenly spaced over [2, 7].
epochs = 2000
x_values = np.linspace(2, 7, epochs)
kl_factor_list = list(1 - np.exp(-x_values))

# Parameter overrides, nested as
# Dim_reductions -> dimred_method_parameters -> vae.
vae_parameters = {
    "Dim_reductions": {
        "dimred_method_parameters": {
            "vae": {
                "nb_epochs": epochs,
                "kl_factor_function": kl_factor_list,
            },
        },
    },
}

In [None]:
# Identifier of the experiment; it is embedded in the file names used below.
experiment = "v1_l"
# Pickle file that is handed to the reduction runner as `vectors_filepath`.
pi_lm_filepath = f"results/vectorization/Morphomics.PID_{experiment}.pi_lm.pkl"

In [None]:
# Divisors applied to the KL schedule; each weight yields one VAE configuration.
weights = [1.]

# One VAE parameter set per weight. The key doubles as the save name of the run
# (it is assigned to "save_filename" when the reduction is launched).
vae_par = {
    f'pi_pca_vae_{w}': {
        # Use `epochs` instead of a literal 2000: the schedule below has
        # `epochs` entries (via `x_values`), so the two must stay in sync.
        "nb_epochs": epochs,
        "kl_factor_function": list((1 - np.exp(-x_values)) / w)
    }
    for w in weights
}


In [None]:
# Run the PCA + VAE dimensionality reduction once per configuration in `vae_par`.
# NOTE(review): no check for an already existing output file is performed here,
# so every configuration is recomputed each time this cell is executed.
input_filepath = pi_lm_filepath

for key, vae_settings in vae_par.items():
    print(f"Run {key.replace('_', ' ')}")

    # Override the VAE settings and the output name for this particular run.
    reduction_params = vae_parameters["Dim_reductions"]
    reduction_params["dimred_method_parameters"]["vae"] = vae_settings
    reduction_params["save_filename"] = key

    _ = dimreduction_runner(
        parameters_id=experiment,
        vectors_filepath=input_filepath,
        vectorization_name='pi',
        toml_filename="pca_vae.toml",
        is_bt=False,
        extra_params=vae_parameters,
    )


In [None]:
from kxa_analysis import plot_2d

# Prefix shared by the reduced-data files of the current experiment.
output_base_path = f"results/dim_reduction/Morphomics.PID_{experiment}."
feature = 'pca_vae'
title = 'VAE Latent Space of Persistence Images'
base_folder_path = "results/vae_plot/pi/"

# Produce one 2D plot (HTML, not shown inline) per VAE configuration.
for key in vae_par:
    # Derive the path from `output_base_path` so it follows `experiment`
    # instead of a hard-coded experiment id.
    df_path = f"{output_base_path}{key}_reduced_data"
    mf_ = load_obj(df_path)

    # NOTE(review): with regex=True the pattern r'\+' matches a literal '+',
    # and the replacement is '+' as well, so string values come out unchanged.
    # Confirm what substitution was intended here.
    mf_['Model'] = mf_['Model'].str.replace(r'\+', '+', regex=True)

    plot_2d(
        df=mf_,
        title=title,
        feature=feature,
        conditions=['Model', 'Sex'],
        name=base_folder_path + key + '_kxa-safit2-fkbp5ko_',
        extension='html',
        show=False,
    )