In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
from morphomics.io.io import save_obj, load_obj
from utils import load_toml, run_toml

In [None]:
# TOML file holding the base pipeline configuration used by the grid search below.
parameters_filepath: str = "f_pi_vae.toml"

In [None]:
import itertools
import copy

# Load the base configuration; every grid point is derived from a deep copy of it.
parameters = load_toml(parameters_filepath=parameters_filepath)

# Replace the dimensionality-reduction method settings with the defaults for this sweep.
parameters['Dim_reductions']['dimred_method_parameters'] = {
    'pca': {"n_components": 50, "svd_solver": False, "pca_version": 'standard'},
    'vae': {'nb_epochs': 100, 'batch_size': 32, 'optimizer': "cocob"}
}

# Disable pixel filtering and all on-disk outputs while sweeping.
parameters["Dim_reductions"]["filter_pixels"] = False
parameters["Dim_reductions"]["save_dimreducer"] = False
parameters['Dim_reductions']['save_data'] = False

# Parameter grid. A plain key (e.g. "standardize") sets ``Dim_reductions[key]``;
# a dotted key "<method>.<option>" sets
# ``Dim_reductions['dimred_method_parameters'][<method>][<option>]``.
# Uncommenting an entry is enough for it to take effect.
param_grid = {
    "standardize": [True, False],
    "vae.nn_layers": [[64, 32, 16, 8], [32, 8]],
    # "vae.learning_rate": [0.01, 0.001, 0.0001],
    # "vae.scheduler": [True, False],
    "vae.batch_layer_norm": [True, False]
}


def _apply_combo(base, combo):
    """Return a deep copy of ``base`` with every grid value in ``combo`` applied.

    Raises KeyError if a dotted key names a method missing from
    ``dimred_method_parameters`` (keeps typos loud instead of silently ignored).
    """
    config = copy.deepcopy(base)
    dimred = config["Dim_reductions"]
    for key, value in combo.items():
        method, sep, option = key.partition(".")
        # Deep-copy so list values (e.g. nn_layers) are not shared between configs.
        value = copy.deepcopy(value)
        if sep:
            dimred["dimred_method_parameters"][method][option] = value
        else:
            dimred[key] = value
    return config


# Cartesian product of the grid -> one dict per combination.
keys, values = zip(*param_grid.items())
param_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

# One full pipeline configuration per combination; the base ``parameters`` is untouched.
param_list = [_apply_combo(parameters, combo) for combo in param_combinations]
print(f"Generated {len(param_list)} parameter configurations.")

# Run every configuration and record [config, mse] keyed by its index (as a string).
# A distinct loop name is used so the base ``parameters`` is not overwritten.
results_pca = {}
for i, config in enumerate(param_list):
    my_pipeline = run_toml(parameters=config)
    results_pca[str(i)] = [config, my_pipeline.metadata['mse']]

In [None]:
# Persist the grid-search results ({index: [config, mse]}) produced by the sweep cell.
# NOTE(review): the path says "adam" but the sweep sets optimizer="cocob" — confirm the intended name.
result_path1 = 'results/vae_hp/adam_pca'
save_obj(results_pca, result_path1)

In [None]:
# Rank the grid-search runs by their stored 'mse' value (lower is better) and keep the best.
# float() accepts one-element tensors, NumPy scalars and plain Python floats alike.
TOP_K = 10
sorted_items = sorted(results_pca.items(), key=lambda item: float(item[1][1]))

# Best run first.
top_10_items = dict(sorted_items[:TOP_K])

print(top_10_items)

In [None]:
# Re-run the pipeline with the single best configuration (lowest 'mse') from the sweep.
parameters_top1 = min(results_pca.values(), key=lambda entry: float(entry[1]))[0]
my_pipeline = run_toml(parameters=parameters_top1)

In [None]:
# Column names used for the 2-D projection throughout the plotting cells.
dim1, dim2 = "dim_1", "dim_2"

# Stack the entries of the morphoframe's 'vae' column into one 2-D array
# and keep its first two columns for plotting.
mf = my_pipeline.morphoframe['v1_pi_bt']
z_mean = np.vstack(mf['vae'])
df = pd.DataFrame(z_mean[:, [0, 1]], columns=[dim1, dim2])

In [None]:
# Build a "<Model>-<Sex>" condition label per sample and attach it to both frames.
conditions = mf['Model'].astype(str) + '-' + mf['Sex'].astype(str)
mf['conditions'] = conditions
labels = sorted(conditions.unique())

# ``df`` was built from a NumPy array and has a fresh RangeIndex, whereas ``mf``
# keeps the morphoframe's own index. Assign positionally so labels cannot be
# misaligned (or become NaN) when that index is not 0..n-1.
df['Label'] = conditions.to_numpy()

# One plot colour per condition label; keys must match the labels built above.
colors = {
    '1xKXA+SAFIT2_4h-F': 'rgb(100, 255, 100)',
    '1xKXA+SAFIT2_4h-M': 'rgb(0, 100, 0)',

    '1xKXA_4h-F': 'rgb(255, 50, 255)',
    '1xKXA_4h-M': 'rgb(50, 255, 255)',

    '1xSaline+SAFIT2_4h-F': 'rgb(255, 255, 100)',
    '1xSaline+SAFIT2_4h-M': 'rgb(150, 150, 0)',

    '1xSaline_4h-F': 'rgb(130, 130, 130)',
    '1xSaline_4h-M': 'rgb(20, 20, 20)',
}

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# Set True to print each condition name above its median marker.
SHOW_MEDIAN_LABELS = False

# Scatter of the first two latent columns, coloured by condition.
fig = px.scatter(df, x=dim1, y=dim2, color='Label',
                 title="VAE latent space",
                 labels={'Label': 'Condition'},
                 color_discrete_map=colors)

# Per-condition median position in the 2-D projection (one row per label).
average_vectors = df.groupby('Label')[[dim1, dim2]].median().reset_index()

fig.add_trace(
    go.Scattergl(
        x=average_vectors[dim1],
        y=average_vectors[dim2],
        # Text mode is only requested when there is text to show.
        mode='markers+text' if SHOW_MEDIAN_LABELS else 'markers',
        marker=dict(
            size=15,  # larger than the per-sample points so medians stand out
            color=average_vectors['Label'].map(colors),  # same colour as the condition
            line=dict(width=2, color='black'),  # outline to separate from the cloud
        ),
        text=average_vectors['Label'] if SHOW_MEDIAN_LABELS else None,
        textposition='top center',
        name='Median',
    )
)

fig.update_layout(showlegend=True)

fig.show()
