In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from morphomics.io.io import load_obj, save_obj
from kxa_analysis import dimreduction_runner, bootstrap_runner
import numpy as np
from kxa_analysis import plot_2d, plot_pi, plot_dist_matrix, mask_pi, get_base, inverse_function, get_2d, plot_vae_dist
import plotly.express as px
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch as th
from morphomics.nn_models import train_test

import pandas as pd
base_path = "results/vae_analysis/"


In [None]:
# Paths of the saved artifacts loaded below: the fitted reducer pipeline and the reduced data.
dimreducer_path = "results/dim_reduction/Morphomics.PID_v1_l.pi_pca_vae_1_fitted_dimreducer"
reduced_path = "results/dim_reduction/Morphomics.PID_v1_l.pi_pca_vae_1_reduced_data"

vae_pip = load_obj(dimreducer_path)

# Turn the previous index into an ordinary column and call it 'old_idcs'.
mf = (
    load_obj(reduced_path)
    .reset_index()
    .rename(columns={'index': 'old_idcs'})
)

# The 'pi' column, plus its first entry; get_base() uses pi_example as the
# template for the full-size output.
pis = mf['pi']
pi_example = pis.iloc[0]

In [None]:
def get_base(pi, pixes_tokeep, template=None):
    """Scatter a filtered persistence image back into a full-size, zero-filled array.

    Note: this definition shadows the ``get_base`` imported from
    ``kxa_analysis`` in the import cell.

    Parameters
    ----------
    pi : array-like
        Values to write at the kept positions.
    pixes_tokeep : index array or boolean mask
        Positions of the full-size array that receive ``pi``.
    template : array-like, optional
        Array whose shape and dtype define the output. When omitted, the
        notebook-level ``pi_example`` is used (looked up at call time), which
        is the original two-argument behaviour.

    Returns
    -------
    numpy.ndarray
        Zeros shaped like ``template`` with ``pi`` written at ``pixes_tokeep``.
    """
    if template is None:
        # Fall back to the notebook global so existing two-argument calls keep working.
        template = pi_example
    pi_full = np.zeros_like(template)
    pi_full[pixes_tokeep] = pi
    return pi_full

In [None]:
# Condition label "<Model>-<Sex>", stored as a new column on mf.
mf['Condition'] = mf['Model'] + "-" + mf['Sex']

# Copy of mf ordered by condition, with a fresh 0..n-1 index.
mf_sorted = mf.sort_values('Condition', ignore_index=True)

# Apply Threshold

In [None]:
pixes_tokeep = vae_pip['pixes_tokeep']

# mask_pi returns an indexable pair per image. Evaluate it once per image and
# split the pair afterwards, instead of running it twice over the whole column.
# NOTE(review): this assumes mask_pi is deterministic and has no side effects — confirm.
pis_masked = pis.apply(lambda pi: mask_pi(pi, pixes_tokeep))
pis_threshold = pis_masked.apply(lambda pair: pair[0])
pis_filtered = pis_masked.apply(lambda pair: pair[1])

# Apply Scaler

In [None]:
# Fitted scaler stored in the pipeline.
standardizer = vae_pip['standardizer']

# One row per filtered image, then apply the fitted scaler.
pis_filtered_arr = np.vstack(list(pis_filtered))
pis_scaled = standardizer.transform(pis_filtered_arr)

# Apply PCA

In [None]:
# Entry 0 of 'fitted_pca_vae'; used below as the PCA step (transform / inverse_transform).
pca = vae_pip['fitted_pca_vae'][0]

In [None]:
# Project the scaled images with the fitted PCA.
pis_pca = pca.transform(pis_scaled)

# VAE

In [None]:
# Entry 1 of 'fitted_pca_vae'; passed as the model to vae_test and inverse_function below.
vae = vae_pip['fitted_pca_vae'][1]

In [None]:
# float32 torch tensor of the PCA-projected data; this is the input given to train_test.vae_test.
pis_pca_torch = th.tensor(pis_pca, dtype=th.float32)

In [None]:
# Run the fitted VAE on the PCA-projected images.
# NOTE(review): the meaning of the four returned values is inferred from the
# names they are bound to — confirm against train_test.vae_test.
vae_outputs = train_test.vae_test(data=pis_pca_torch, model=vae, sample_size=3)
pred, z_mean, z_log_var, mse = vae_outputs

for label, value in (
    ('mse:', mse),
    ('sample size:', z_mean.shape),
    ('pred shape:', pred.shape),
):
    print(label, value)

Map the VAE predictions back to persistence-image (PI) space

In [None]:
# Average the VAE output over its first dimension.
# NOTE(review): presumably the sampling dimension (sample_size=3 above) — confirm.
pred_processed_pi_mean = pred.mean(dim=0)
mf['pred'] = list(pred_processed_pi_mean)

# Undo the preprocessing in reverse order: PCA, then the scaler, then the pixel filter.
pred_processed_pi = pred_processed_pi_mean.detach().cpu().numpy()
pred_scaled_pi = pca.inverse_transform(pred_processed_pi)
pred_filter_pi = standardizer.inverse_transform(pred_scaled_pi)
mf['pi_filter_pred'] = list(pred_filter_pi)

# Re-insert the filtered values into full-size, zero-filled images.
mf['pi_pred'] = [get_base(pi, pixes_tokeep) for pi in mf['pi_filter_pred']]

Interpolation along a line in the VAE latent space

In [None]:
nb_points = 5 # Number of points along the line
x_values = np.linspace(-1, 1, nb_points)  # nb_points evenly spaced x values from -1 to 1 (both ends included)
y_values = (-1/3) * x_values - (1/6)  # Points on the line y = -x/3 - 1/6

# Stack into an (nb_points, 2) array of (x, y) points
interpolation = np.column_stack((x_values, y_values)) 

In [None]:
from kxa_analysis import merged_dict

# Rows of the two 4h models, reduced to the columns needed for the plot.
mf_vae_kxa = mf[mf['Model'].isin(['1xSaline_4h', '1xKXA_4h'])].copy()
mf_interpolation = mf_vae_kxa[['Condition', 'pca_vae']].copy()

# One extra row per interpolation point, all labelled 'interpolation'.
mf_interpolation_ = pd.DataFrame({
    'Condition': ['interpolation'] * nb_points,
    'pca_vae': list(interpolation),
})
mf_inter = pd.concat([mf_interpolation, mf_interpolation_], ignore_index=True)

# Existing condition colours plus a dedicated colour for the interpolation points
# ('rgb(128, 0, 0)' is the same colour as hex '#800000').
color_dict = dict(merged_dict, interpolation='rgb(128, 0, 0)')


In [None]:
# Plot the selected samples together with the interpolation points, coloured by 'Condition'.
# Note: base_path from the import cell ends with '/', so the name below contains a double slash.
plot_2d(
    mf_inter,
    feature='pca_vae',
    title='VAE Latent Space of Persistence Image',
    conditions=['Condition'],
    colors=color_dict,
    extension='html',
    ax_labels=['Dim 1', 'Dim 2'],
    name=f"{base_path}/pi_vae_kxa_interpolation_",
)


In [None]:
# Decode every interpolation point and keep both returned images, each converted with get_2d.
pred_pi_list = []
pred_pi_scaled_list = []
for pt in interpolation:
    pred_pi, pi_scaled = inverse_function(
        pt, model=vae, pca=pca, scaler=standardizer, filter=pixes_tokeep
    )
    pred_pi, pi_scaled = get_2d(pred_pi), get_2d(pi_scaled)
    pred_pi_list.append(pred_pi)
    pred_pi_scaled_list.append(pi_scaled)

In [None]:
# Output location. base_path is re-set here WITHOUT a trailing slash; the save
# names built in this cell and in the following cells use this value.
base_path = "results/vae_analysis"
save_name = f"{base_path}/reconstructed_pi_scaled_interpolation_"

# Stack the scaled images returned by inverse_function; the first axis indexes the images.
pred_pi_scaled_list = np.array(pred_pi_scaled_list)
num_images = pred_pi_scaled_list.shape[0]

# Two images per row, rounding up for an odd count.
num_rows = (num_images + 1) // 2

# squeeze=False always yields a 2-D array of Axes, so the grid can be indexed
# uniformly even when there is a single row (one or two images).
fig, ax = plt.subplots(num_rows, 2, figsize=(12, num_rows * 5), squeeze=False)

# Shared colour range over all images.
vmin, vmax = pred_pi_scaled_list.min(), pred_pi_scaled_list.max()

# Diverging colormap: purple (negative) -> white (zero) -> orange (positive).
colors_list = ["purple", "white", "orange"]
custom_cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", colors_list)

# TwoSlopeNorm raises unless vmin < vcenter < vmax. Nudge the bounds so that
# data lying entirely on one side of zero still plots; when the data straddle
# zero (the usual case) the bounds are exactly the data min and max.
tiny = np.finfo(float).eps
norm = mcolors.TwoSlopeNorm(vmin=min(vmin, -tiny), vcenter=0, vmax=max(vmax, tiny))

# Fill the grid row by row; switch off any unused trailing subplot.
for index, axis in enumerate(ax.flat):
    if index >= num_images:
        axis.axis('off')
        continue
    im = axis.imshow(
        pred_pi_scaled_list[index],
        cmap=custom_cmap,
        norm=norm,
        interpolation='nearest',
        origin='lower'
    )
    axis.set_title(f'Reconstructed Scaled Persistence Image {index + 1} ({np.round(interpolation[index], 2)})')

# One colorbar for the whole grid, taken from the last drawn image (all images share `norm`).
cbar = fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.03, pad=0.04)
cbar.set_label('Scaled Persistence Density')

# Adjust layout to avoid overlap and reserve space for the colorbar
plt.tight_layout()
plt.subplots_adjust(right=0.85)

# Save the figure as a PDF
fig.savefig(save_name + '.pdf', format='pdf')

# Display the plot
plt.show()


In [None]:
save_name = f"{base_path}/reconstructed_pi_interpolation_"

# Stack the decoded images; the first axis indexes the images.
pred_pi_list = np.array(pred_pi_list)
num_images = pred_pi_list.shape[0]

# Two images per row, rounding up for an odd count.
num_rows = (num_images + 1) // 2

# squeeze=False always yields a 2-D array of Axes, so the grid can be indexed
# uniformly even when there is a single row (one or two images).
fig, ax = plt.subplots(num_rows, 2, figsize=(12, num_rows * 5), squeeze=False)

# Shared colour range over all images.
vmin = np.min(pred_pi_list)
vmax = np.max(pred_pi_list)

# Fill the grid row by row; switch off any unused trailing subplot.
for index, axis in enumerate(ax.flat):
    if index >= num_images:
        axis.axis('off')
        continue
    # Heatmap with the 'hot' colormap on the shared vmin/vmax range.
    cax = axis.imshow(pred_pi_list[index], cmap='hot', vmin=vmin, vmax=vmax, interpolation='nearest', origin='lower')
    axis.set_title(f'Reconstructed Persistence Image {index + 1} {np.round(interpolation[index], 2)}')

# One colorbar for the whole grid, taken from the last drawn image
# (all images share the same vmin/vmax).
cbar = fig.colorbar(cax, ax=ax[:, :], orientation='vertical', fraction=0.03, pad=0.04)
cbar.set_label('Persistence Density')

# Adjust layout to avoid overlap and leave room on the right for the colorbar
plt.tight_layout()
plt.subplots_adjust(right=0.85)

# Save the figure as a PDF with the specified name
fig.savefig(save_name + '.pdf', format='pdf')

# Display the plot
plt.show()


In [None]:
# Per-image summary: total and maximum pixel value, plus the rounded latent coordinates.
sum_values = [np.sum(img) for img in pred_pi_list]
max_values = [np.max(img) for img in pred_pi_list]
coordinates = [np.round(interpolation[k], 2) for k in range(len(pred_pi_list))]

# Collect the statistics in a DataFrame (sums rounded to 2 decimals, maxima in scientific notation).
data = {
    'Index': np.arange(1, len(pred_pi_list) + 1),
    'Sum': [np.round(val, 2) for val in sum_values],
    'Max Value': [f"{val:.2e}" if val != 0 else "0.00" for val in max_values],
    'Coordinates': coordinates,
}
df = pd.DataFrame(data)

# Render the DataFrame as a matplotlib table on an axis-free figure.
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('off')
ax.set_title('Reconstructed Persistence Density Statistics from Interpolation', fontsize=16, ha='center', pad=20)

table = ax.table(cellText=df.values, colLabels=df.columns, loc='center', cellLoc='center')

# Fixed font size and a slightly enlarged cell grid.
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1.2, 1.2)

# Optionally, you can save the table as a PDF
# fig.savefig('table_output.pdf', format='pdf')

plt.show()


In [None]:
# Main-diagonal profile of each reconstructed image, truncated to the first 70 entries
# (adjust the slice as needed).
projection_vectors = np.array([np.diagonal(img)[:70] for img in pred_pi_list])

# Line colours sampled from a magenta -> purple -> cyan -> grey -> black colormap.
cmap = mcolors.LinearSegmentedColormap.from_list(
    'magenta_cyan_grey_black',
    ['magenta', 'purple', 'cyan', 'grey', 'black'],
)
num_lines = len(projection_vectors)
# i / num_lines lies in [0, 1), so the last colormap entry is never reached exactly.
colors = [cmap(i / num_lines) for i in range(num_lines)]

fig, ax = plt.subplots(figsize=(10, 6))

# One line per interpolation point, labelled with its rounded latent coordinates.
for projection, color, coord in zip(projection_vectors, colors, interpolation):
    ax.plot(projection, label=f'{np.round(coord, 2)}', color=color)

ax.set_xlabel('Persistence Image Diagonal')
ax.set_ylabel('Persistence Density')
ax.set_title('Diagonal of Reconstructed Persistence Images from Interpolation')
ax.legend()

# Save next to the other figures of this analysis.
save_name = f"{base_path}/reconstructed_pi_diagonal_"
fig.savefig(save_name + '.pdf', format='pdf')

plt.show()


In [None]:
def _perpendicular_sums(img):
    """Sum, for each main-diagonal pixel (i, i), the pixels (i + d, i - d) that stay inside the image.

    The result is a SUM along the line perpendicular to the main diagonal
    (not a variance), used below as the "thickness" at diagonal position i.
    """
    n = img.shape[0]
    sums = []
    for i in range(n):
        # Largest offset that keeps both indices within [0, n - 1].
        d_max = min(i, n - 1 - i)
        offsets = np.arange(-d_max, d_max + 1)
        sums.append(img[i + offsets, i - offsets].sum())
    return np.array(sums)


# Line colours sampled from a magenta -> purple -> cyan -> grey -> black colormap.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    'magenta_cyan_grey_black',
    ['magenta', 'purple', 'cyan', 'grey', 'black'],
)
num_images = len(pred_pi_list)
colors = [custom_cmap(i / num_images) for i in range(num_images)]

# One thickness curve per image, truncated to the first 70 diagonal positions.
thickness_curves = [_perpendicular_sums(img)[:70] for img in pred_pi_list]

fig, ax = plt.subplots(figsize=(12, 6))

for i, (curve, color) in enumerate(zip(thickness_curves, colors)):
    label_str = f"Image {i+1} {np.round(interpolation[i], 2)}"
    ax.plot(curve, label=label_str, color=color, linewidth=2)

ax.set_xlabel("Diagonal Index")
ax.set_ylabel("Thickness")
ax.set_title("Local Thickness Along the Diagonals")
ax.legend()

# Save next to the other figures of this analysis.
save_name = f"{base_path}/reconstructed_pi_thickness_"
fig.savefig(save_name + '.pdf', format='pdf')

plt.show()


Latent Space Coherence

In [None]:
# Latent-space origin used as the reference point.
point_origin = np.array([0,0])

# inverse_function returns a pair (it is unpacked into two values in the
# interpolation loop earlier in this notebook). Keep only the first element
# as the reference PI; binding the whole pair to pi_origin would corrupt
# get_2d() and the distance computations that follow.
pi_origin, pi_origin_scaled = inverse_function(point_origin, model = vae, pca = pca, scaler = standardizer, filter=pixes_tokeep)
pi_origin_2d = get_2d(pi_origin)

In [None]:
# Show pi_origin_2d (built from the latent point [0, 0]) with the 'hot' colormap.
plot_pi(pi_origin_2d, cmap='hot')

In [None]:
def _l2_distance_to(reference):
    """Return a function giving the Euclidean norm of (p - reference)."""
    return lambda p: np.linalg.norm(p - reference)


# Distance of every latent point to the latent origin.
mf.loc[:, 'dist_2d'] = mf['pca_vae'].apply(_l2_distance_to(point_origin))
# Distance of every reconstructed / original PI to pi_origin.
mf.loc[:, 'dist_pi_pred'] = mf['pi_pred'].apply(_l2_distance_to(pi_origin))
mf.loc[:, 'dist_pi'] = mf['pi'].apply(_l2_distance_to(pi_origin))

In [None]:
# Latent points with 'dist_pi_pred' (norm of reconstructed PI minus pi_origin) as the dist quantity.
# NOTE(review): presumably dist/vmin/vmax drive the colour scale — confirm in plot_vae_dist.
plot_vae_dist(mf, points = 'pca_vae', dist='dist_pi_pred', vmin=0.0, vmax=0.10)

In [None]:
# Same plot, using 'dist_pi' (norm of original PI minus pi_origin) as the dist quantity.
plot_vae_dist(mf, points = 'pca_vae', dist='dist_pi', vmin=0.0, vmax=0.1)

In [None]:
# Per-row reconstruction error between the predicted and the original PI.
# NOTE(review): the column is called 'mse', but the value is the Euclidean (L2)
# norm of the difference, not a mean squared error — rename or change if MSE is intended.
mf.loc[:, 'mse'] = [
    np.linalg.norm(pi_pred - pi)
    for pi_pred, pi in zip(mf['pi_pred'], mf['pi'])
]

In [None]:
# Latent points with the 'mse' column (norm of pi_pred minus pi) as the dist quantity; no vmax is passed.
plot_vae_dist(mf, points = 'pca_vae', dist='mse', vmin=0.0)
