In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from morphomics.io.io import load_obj, save_obj
from kxa_analysis import dimreduction_runner, bootstrap_runner
import numpy as np
from kxa_analysis import plot_2d, plot_pi, plot_dist_matrix, mask_pi
import plotly.express as px
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import pandas as pd
# Output prefix interpolated into the `name=` argument of the plot_pi / plot_2d calls in this notebook.
# NOTE(review): the value ends with "/" and the f-strings add another "/", so the
# resulting names contain "//" — presumably harmless on POSIX paths; confirm.
base_path = "results/pca/"


In [None]:
# Inputs to load: the fitted dimensionality-reduction pipeline and the reduced data frame.
dimreducer_path = "results/dim_reduction/Morphomics.PID_v1_l.pi_pca_vae_1_fitted_dimreducer"
reduced_path = "results/dim_reduction/Morphomics.PID_v1_l.pi_pca_vae_1_reduced_data"

vae_pip = load_obj(dimreducer_path)

# reset_index() turns the previous index into a regular column; that column is
# renamed from 'index' to 'old_idcs' (the rename only applies when the previous
# index was unnamed).
mf = (
    load_obj(reduced_path)
    .reset_index()
    .rename(columns={'index': 'old_idcs'})
)

# Column of persistence images and the first one, kept as a reference example.
pis = mf['pi']
pi_example = pis.iloc[0]

In [None]:
def get_base(pi, pixes_tokeep, template=None):
    """Scatter reduced values back into a zero-filled full-size array.

    Parameters
    ----------
    pi : array-like
        Values to place at the kept positions.
    pixes_tokeep : index (int array, boolean mask, ...)
        Positions of the full array that receive ``pi``; every other
        position stays 0.
    template : array-like, optional
        Array whose shape and dtype define the full-size output. When omitted,
        the notebook-level ``pi_example`` is used, which is the original behaviour.

    Returns
    -------
    numpy.ndarray
        New array shaped like ``template`` holding ``pi`` at ``pixes_tokeep``.
    """
    if template is None:
        # Fall back to the global example image so existing two-argument calls keep working.
        template = pi_example
    pi_full = np.zeros_like(template)
    pi_full[pixes_tokeep] = pi
    return pi_full

In [None]:
# Label every row with its condition, built as "<Model>-<Sex>".
mf["Condition"] = mf["Model"] + "-" + mf["Sex"]
# Copy of the frame ordered by that label, with a fresh 0..n-1 index.
mf_sorted = mf.sort_values(by="Condition", ignore_index=True)

# Apply Threshold

In [None]:
# Pixel selection stored with the fitted pipeline.
pixes_tokeep = vae_pip['pixes_tokeep']

# Run mask_pi once per image and split its result afterwards, instead of
# calling it twice per image: element 0 is kept as the "threshold" image and
# element 1 as the "filtered" values.
pis_masked = pis.apply(lambda pi: mask_pi(pi, pixes_tokeep))
pis_threshold = pis_masked.apply(lambda masked: masked[0])
pis_filtered = pis_masked.apply(lambda masked: masked[1])

In [None]:
# First thresholded image, used as the illustrative example.
pi_th_example = pis_threshold.iloc[0]

# White-to-orange colormap; it is not passed to the plot_pi call here, which uses 'hot'.
white_orange_cmap = mcolors.LinearSegmentedColormap.from_list(
    "white_orange", ["white", "orange"]
)

plot_pi(
    pi_th_example,
    title='Persistence Image Example',
    is_log=False,
    scale='Persistence Density',
    cmap='hot',
    name=f"{base_path}/pi_example",
)

# Apply Scaler

In [None]:
# Fitted scaler stored in the loaded pipeline under 'standardizer'.
standardizer = vae_pip['standardizer']

# Stack the per-sample filtered values into a single array for the scaler
# (one row per sample, assuming each entry is 1-D), then transform it.
pis_filtered_arr = np.vstack(list(pis_filtered))
pis_scaled = standardizer.transform(pis_filtered_arr)

In [None]:
# Re-embed the first standardized sample into the full-size array; positions
# outside pixes_tokeep stay 0.
pi_scaled_full_example = get_base(pis_scaled[0], pixes_tokeep)

# Colour limits are taken from the example itself.
vmin = pi_scaled_full_example.min()
vmax = pi_scaled_full_example.max()

# Diverging colormap: purple (low end) -> white (middle) -> orange (high end).
colors = ["purple", "white", "orange"]  # Change colors here if needed
custom_cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", colors)

# Centre the colour scale on 0. TwoSlopeNorm expects vmin < vcenter < vmax.
norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

plot_pi(
    pi_scaled_full_example,
    cmap=custom_cmap,
    title='Rescaled Persistence Image Example',
    is_log=False,
    norm=norm,
    scale='Rescaled Persistence Density',
    name=f"{base_path}/rescaled_pi_example",
)

In [None]:
# Element 0 of the stored 'fitted_pca_vae' entry; this notebook uses it as the
# fitted PCA step (via .transform and .components_).
pca = vae_pip['fitted_pca_vae'][0]

In [None]:
# Apply the fitted PCA to the standardized values; columns are indexed below as components (0 -> PC 1, ...).
pis_pca = pca.transform(pis_scaled)

In [None]:
# Metadata columns used for the PCA plots. Take an explicit copy: a column
# ('pi_pca_2d') is added to this frame later, and writing into a bare column
# subset of `mf` triggers pandas' SettingWithCopyWarning.
mf_pca = mf[['Layer', 'Model', 'Sex']].copy()

In [None]:
# Keep components 0 and 1 (PC 1 vs PC 2) for the 2-D scatter.
pis_pca_2d = pis_pca[:, [0, 1]]
# assign() builds a new frame with the column set, so nothing is written into a
# possible view of `mf` (avoids SettingWithCopyWarning); one 2-vector per row.
mf_pca = mf_pca.assign(pi_pca_2d=list(pis_pca_2d))
plot_2d(
    mf_pca,
    'pi_pca_2d',
    title='PCA of Persistence Image',  # fixed typo: was "Peristence"
    conditions=['Model', 'Sex'],
    show=True,
    ax_labels=['PC 1', 'PC 2'],
    extension='html',
    name=f"{base_path}/PC1_PC2_",
)
# def plot_2d(df, feature, title, conditions = ['Model', 'Sex'], colors= merged_dict, name = None, extension = 'pdf', show = True):


In [None]:
# Keep components 0 and 2 (PC 1 vs PC 3) for the 2-D scatter.
pis_pca_2d = pis_pca[:, [0, 2]]
# assign() builds a new frame with the column set (replacing any existing
# 'pi_pca_2d'), so nothing is written into a possible view of `mf`
# (avoids SettingWithCopyWarning); one 2-vector per row.
mf_pca = mf_pca.assign(pi_pca_2d=list(pis_pca_2d))
plot_2d(
    mf_pca,
    'pi_pca_2d',
    title='PCA of Persistence Image',  # fixed typo: was "Peristence"
    conditions=['Model', 'Sex'],
    show=True,
    ax_labels=['PC 1', 'PC 3'],
    extension='html',
    name=f"{base_path}/PC1_PC3_",
)

In [None]:
###  rf_rfe_selected or rf_sorted_idx or svm_rfe_selected

# Component matrix of the fitted PCA, shape (n_components, n_features).
loadings = pca.components_

# Placeholder column labels 'Feature1' ... 'FeatureN', one per input feature.
feature_names = [f'Feature{k}' for k in range(1, loadings.shape[1] + 1)]

# Tabular view: one row per component ('PC1', 'PC2', ...), one column per feature.
pc_load_df = pd.DataFrame(
    loadings,
    index=[f'PC{k}' for k in range(1, loadings.shape[0] + 1)],
    columns=feature_names,
)

In [None]:
# Draw one loading map for each of the first three components (PC1-PC3).
for i in range(3):
    # Place the component's loadings back at the kept positions of a full-size array.
    pc_load_i_full = get_base(pc_load_df.iloc[i], pixes_tokeep)

    title = f"Loading of Principal Component {i+1}"
    save_name = f"{base_path}/loading_pc{i+1}"
    plot_pi(pc_load_i_full, title=title, name=save_name, is_log=False)