In [None]:
# Load IPython's autoreload extension and switch it to mode 2, which re-imports
# modules before each cell executes so edits to the local packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from morphomics.io.io import load_obj, save_obj
from kxa_analysis import dimreduction_runner, bootstrap_runner
import numpy as np
from kxa_analysis import plot_2d, plot_pi, plot_dist_matrix, mask_pi, get_base, inverse_function, get_2d, plot_vae_dist
import plotly.express as px
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch as th
from morphomics.nn_models import train_test

import pandas as pd
# NOTE(review): presumably the output directory for this analysis; it is not
# referenced by any of the cells visible in this notebook section — confirm it is still needed.
base_path = "results/vae_analysis/"

from microglia_retina.analysis import io_retina
from microglia_retina.analysis import Retina, retina

from scipy.spatial.distance import cdist


In [None]:
# --- Load the reduced data and keep only the saline / KXA 4h experiment ---------
reduced_path = "results/dim_reduction/Morphomics.PID_v1_l.pi_pca_vae_1_reduced_data"
mf = load_obj(reduced_path)

kxa_experiment_mask = mf['Model'].isin(['1xSaline_4h', '1xKXA_4h'])
mf_vae_kxa = mf.loc[kxa_experiment_mask].copy()  # copy so the new columns never touch `mf`

layers = mf_vae_kxa['Layer'].unique()
animals = mf_vae_kxa['Animal'].unique()


# --- Row-wise label builders -----------------------------------------------------
def _condition_label(row):
    """Label a cell by treatment and sex, e.g. '1xKXA_4h-F'."""
    return f"{row['Model']}-{row['Sex']}"


def _condition_layer_label(row):
    """Label a cell by layer, treatment, sex and animal."""
    return f"{row['Layer']}-{row['Model']}-{row['Sex']}-{row['Animal']}"


def _list_to_array(value):
    """Turn plain Python lists into numpy arrays; leave anything else untouched."""
    if isinstance(value, list):
        return np.array(value)
    return value


def _componentwise_median(vectors):
    """Median of every vector component over all cells of one group."""
    return np.median(vectors.tolist(), axis=0)


mf_vae_kxa['Condition'] = mf_vae_kxa.apply(_condition_label, axis=1)
mf_vae_kxa['Condition_l'] = mf_vae_kxa.apply(_condition_layer_label, axis=1)

# Make sure every 'pca_vae' entry is a numeric array before aggregating
mf_vae_kxa['pca_vae'] = mf_vae_kxa['pca_vae'].apply(_list_to_array)

# One median vector per Condition (each 'pca_vae' value is treated as a vector)
median_pi_by_condition = mf_vae_kxa.groupby('Condition')['pca_vae'].apply(_componentwise_median)


In [None]:
def rgb_str_to_mpl_tuple(rgb_str):
    """Convert a CSS-like ``'rgb(r, g, b)'`` string into a tuple of floats in [0, 1].

    The leading/trailing characters ``r``, ``g``, ``b``, ``(`` and ``)`` are stripped,
    the remainder is split on commas, and each integer channel is divided by 255.
    """
    channel_text = rgb_str.strip('rgb()')
    channels = [int(part) for part in channel_text.split(',')]
    return tuple(channel / 255 for channel in channels)
# Branch colour per (predicted) condition. Only the KXA female group is highlighted
# in magenta; the other three groups share the same near-black colour.
# Alternative palette that was tried and disabled:
#   '1xSaline_4h-F' -> 'rgb(130, 130, 130)',  '1xKXA_4h-M' -> 'rgb(50, 255, 255)'
_near_black = rgb_str_to_mpl_tuple('rgb(20, 20, 20)')
_kxa_female_magenta = rgb_str_to_mpl_tuple('rgb(255, 50, 255)')

conditions_colors = {
    '1xSaline_4h-M': _near_black,
    '1xKXA_4h-F': _kxa_female_magenta,
    '1xSaline_4h-F': _near_black,
    '1xKXA_4h-M': _near_black,
}

# Soma marker colour for each layer name (matplotlib named colours)
layer_colors = dict(
    L1='darkblue',
    L2_3='forestgreen',
    L5_6='darkorange',
    L4='purple',
)


In [None]:
def retina(retina, is_voronoi=True, is_hull=True, is_retina_hull=False, title=None, subtitle=None, show=True):
    """Plot every microglia of a retina: branch points coloured by predicted condition, somas by layer.

    NOTE(review): this definition shadows the ``retina`` name imported from
    ``microglia_retina.analysis`` earlier in the notebook, and the first parameter
    shadows the function itself inside the body.

    Parameters
    ----------
    retina : object
        Must expose ``microglias`` (a DataFrame with the columns ``pred_Condition``,
        ``branches_coord``, ``Layer`` and ``soma_coord``) and ``retina_path``;
        ``conditions`` (a dict) is read only when no title and no path are available.
        ``branches_coord`` is indexed as a 2-D array and ``soma_coord`` as a sequence;
        only the first two coordinates of each are drawn.
    is_voronoi, is_hull, is_retina_hull : bool
        Accepted for call compatibility; this implementation does not use them.
    title : str or None
        Main title. When None it is derived from ``retina.retina_path`` (directory part,
        slashes replaced by spaces) or, failing that, from ``retina.conditions``.
    subtitle : str or None
        Optional grey subtitle drawn under the main title.
    show : bool
        Call ``plt.show()`` before returning.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes, for further customisation.
    """
    fig, ax = plt.subplots(figsize=(12, 12))

    # Coordinates of everything drawn, gathered so the axis limits cover ALL somas
    # and branches (not only the soma of the last row).
    drawn_x, drawn_y = [], []

    for _, row in retina.microglias.iterrows():
        branch_xy = row['branches_coord']
        soma_xy = row['soma_coord']
        branch_x, branch_y = np.array(branch_xy[:, 0]), np.array(branch_xy[:, 1])
        soma_x, soma_y = soma_xy[0], soma_xy[1]

        # Unknown conditions / layers fall back to black
        cond_color = conditions_colors.get(row['pred_Condition'], 'black')
        lay_color = layer_colors.get(row['Layer'], 'black')

        ax.plot(branch_x, branch_y, 'o', color=cond_color, markersize=0.2)
        ax.plot(soma_x, soma_y, 'o', color=lay_color, markersize=3)

        drawn_x.extend((branch_x, np.atleast_1d(soma_x)))
        drawn_y.extend((branch_y, np.atleast_1d(soma_y)))

    # Fit the view to the bounding box of all plotted points, with a small margin.
    # Skipped when the retina has no microglia, so an empty retina still yields a figure.
    if drawn_x:
        all_x = np.concatenate(drawn_x)
        all_y = np.concatenate(drawn_y)
        x_min, x_max = np.min(all_x), np.max(all_x)
        y_min, y_max = np.min(all_y), np.max(all_y)

        padding = 0.05  # fraction of the data range added on each side
        ax.set_xlim(x_min - padding * (x_max - x_min), x_max + padding * (x_max - x_min))
        ax.set_ylim(y_min - padding * (y_max - y_min), y_max + padding * (y_max - y_min))

    ax.set_aspect('equal')

    if title is None:
        if retina.retina_path is not None:
            path = retina.retina_path
            last_slash = path.rfind('/')
            # Drop the final path component; a path without any slash is kept whole
            # (slicing with -1 would silently cut its last character).
            directory = path[:last_slash] if last_slash != -1 else path
            # str.replace returns a new string, so the result must be assigned
            title = directory.replace('/', ' ')
        else:
            title = ', '.join([f"{key}: {value}" for key, value in retina.conditions.items()])
    ax.set_title(title, fontsize=14)  # Main title

    if subtitle is not None:
        fig.text(0.5, 0.92, subtitle, ha='center', va='center', fontsize=10, color='gray')

    if show:
        plt.show()

    return fig, ax


In [None]:
# Class labels and their median vectors, kept in matching order for the lookup below
conditions = list(median_pi_by_condition.index)
medians = np.stack(median_pi_by_condition.to_numpy())


def predict_condition(vec):
    """Return the condition whose median vector is nearest to ``vec`` (``cdist`` default metric)."""
    distances_to_medians = cdist([vec], medians)[0]
    nearest = int(np.argmin(distances_to_medians))
    return conditions[nearest]


# Nearest-median prediction for every cell
mf_vae_kxa['pred_Condition'] = mf_vae_kxa['pca_vae'].map(predict_condition)

In [None]:
# One retina figure per animal
for animal in animals:
    animal_cells = mf_vae_kxa.loc[mf_vae_kxa['Animal'] == animal]
    # Model and sex are taken from the animal's first row
    animal_conditions = {key: animal_cells[key].iloc[0] for key in ('Model', 'Sex')}

    my_retina = Retina(info_frame=animal_cells, conditions=animal_conditions)
    my_retina.set_soma(dim=3)
    my_retina.set_branches(dim=3)

    retina(
        my_retina,
        is_voronoi=False,
        is_hull=False,
        is_retina_hull=False,
        title=None,
        subtitle=None,
        show=True,
    )


In [None]:
# `df` is a second name for the same DataFrame object (no copy is made), so any
# column later added through `df` also appears on `mf_vae_kxa`.
df = mf_vae_kxa

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Fixed class order, shared by the matrix rows/columns and the tick labels
labels = ['1xSaline_4h-F', '1xSaline_4h-M', '1xKXA_4h-F', '1xKXA_4h-M']
cm = confusion_matrix(y_true=df['Condition'], y_pred=df['pred_Condition'], labels=labels)

fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(cm, display_labels=labels)
disp.plot(ax=ax, cmap='Blues', colorbar=False)

ax.invert_yaxis()  # flip the direction of the y-axis
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()


In [None]:
from sklearn.metrics import accuracy_score

# Fraction of cells whose nearest-median prediction equals the true condition
accuracy = accuracy_score(y_true=df['Condition'], y_pred=df['pred_Condition'])
print(f'Accuracy: {accuracy:.2f}')


In [None]:
from sklearn.metrics import classification_report

# `target_names` are matched positionally to `labels`. Without an explicit `labels`
# argument scikit-learn uses the sorted class values ('1xKXA_4h-*' before
# '1xSaline_4h-*'), which would print every row under the wrong condition name.
# Passing the same list for both keeps the names attached to the right classes.
report_labels = ['1xSaline_4h-F', '1xSaline_4h-M', '1xKXA_4h-F', '1xKXA_4h-M']
report = classification_report(
    df['Condition'],
    df['pred_Condition'],
    labels=report_labels,
    target_names=report_labels,
)
print(report)


In [None]:
import pandas as pd
import numpy as np
from scipy.stats import chi2_contingency

def cramers_v(x, y):
    """Cramér's V association between two categorical sequences.

    Builds the contingency table of ``x`` against ``y``, takes the chi-square
    statistic returned by ``scipy.stats.chi2_contingency`` (default settings),
    and returns ``sqrt((chi2 / n) / min(rows - 1, cols - 1))``.
    """
    table = pd.crosstab(x, y)
    chi2_stat = chi2_contingency(table)[0]
    total = table.sum().sum()
    n_rows, n_cols = table.shape
    smaller_dim = min(n_cols - 1, n_rows - 1)
    phi_squared = chi2_stat / total
    return np.sqrt(phi_squared / smaller_dim)

# Association strength between true and predicted condition
cramers = cramers_v(x=df['Condition'], y=df['pred_Condition'])
print(f"Cramér's V: {cramers:.2f}")


In [None]:
def _kxa_female_or_ctrl(label):
    """Keep the '1xKXA_4h-F' label as is; collapse every other label into 'Ctrl'."""
    if label == '1xKXA_4h-F':
        return label
    return 'Ctrl'


# Two-class versions of the true and the predicted labels
df['Condition_bin'] = df['Condition'].map(_kxa_female_or_ctrl)
df['pred_Condition_bin'] = df['pred_Condition'].map(_kxa_female_or_ctrl)


In [None]:
# Two-class confusion matrix: KXA females versus everything else
labels = ['Ctrl', '1xKXA_4h-F']
cm = confusion_matrix(y_true=df['Condition_bin'], y_pred=df['pred_Condition_bin'], labels=labels)

fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(cm, display_labels=labels)
disp.plot(ax=ax, cmap='Blues', colorbar=False)

ax.invert_yaxis()  # flip the direction of the y-axis
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()


In [None]:
# Accuracy of the two-class (KXA female vs. Ctrl) prediction
accuracy = accuracy_score(y_true=df['Condition_bin'], y_pred=df['pred_Condition_bin'])
print(f'Accuracy: {accuracy:.2f}')
