In [None]:
import os
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import umap
import phate
import scipy.stats as stats

from utils import load_csv

## treated

In [None]:
# Root of the project tree; every input and output path in this notebook is built
# from it with f-strings.
# NOTE(review): it is empty here, so the paths below start with '/output/...'
# (filesystem root). Set it to the real project directory before running.
home_dir = ''  # main directory
# Group label and feature name used to build the saved figure file names.
# `plotting_feature` is also the column whose raw values colour the scatter points.
plotting_group = '3days_treated'
plotting_feature = 'dapi_nucleus_total_intensity'
# Despite the `_dir` suffix these are paths to pickle FILES holding the models that
# the plotting cells load and call `.transform` on.
treated_phate_model_dir = f'{home_dir}/output/model/treated_phate_model.pickle'
untreated_phate_model_dir = f'{home_dir}/output/model/untreated_phate_model.pickle'
# Feature columns used for outlier filtering, standardization and the PHATE embedding.
# The order defines the column order of the matrix passed to the model.
feature_names = [
    'cyclina_mean_ratio',
    'cyclinb_mean_ratio',
    'cyclind_mean_ratio',
    'dapi_nucleus_total_intensity',
]

# Maps a human-readable group label to the integer stored in the 'group' column
# of the CSV tables (used by get_features_df to select one group).
group_id_dict = {'14days_treated': 0, '3days_treated': 1, '3days_no_treated': 2}

In [None]:
# Spacing between colorbar ticks for the treated plot, in the units of the raw
# (pre-standardization) `plotting_feature` values: one tick is placed every `step`
# from 0 up to the feature maximum.
# NOTE(review): the untreated cell uses 0.3e6 for the same feature; confirm that a
# spacing of 1 is intended for 'dapi_nucleus_total_intensity', since a small step
# relative to the data range produces a very large number of ticks.
step = 1

def get_features_df(feature_names, dir, group_label=None, group_id_dict=group_id_dict,
                    cyclina_ratio_max=100):
    """Build one frame of cell metadata plus the requested feature columns.

    Reads ``total_intensity.csv``, ``mean_intensity.csv``, ``median_intensity.csv``
    and ``ratio.csv`` from ``{dir}/output`` through ``load_csv``, drops rows that
    contain NaN from each table, optionally keeps a single group, removes rows with
    an extreme ``cyclina_mean_ratio`` and concatenates the metadata columns with the
    selected features.

    Parameters
    ----------
    feature_names : iterable of str
        Columns to collect. The source table is picked by substring, tested in this
        order: ``'mean_intensity'``, ``'median_intensity'``, ``'ratio'``,
        ``'total_intensity'``. Names matching none of them are skipped silently.
    dir : str
        Project directory containing the ``output`` folder. (The name shadows the
        builtin ``dir``; it is kept so existing keyword callers keep working.)
    group_label : str or None, optional
        Key of ``group_id_dict``. When given, only rows whose ``'group'`` column
        equals the mapped id are kept. ``None`` keeps every group.
    group_id_dict : dict, optional
        Mapping from group label to the id stored in the ``'group'`` column.
    cyclina_ratio_max : float, optional
        Rows with ``cyclina_mean_ratio`` greater than or equal to this value are
        discarded. Defaults to 100, the previously hard-coded threshold.

    Returns
    -------
    pandas.DataFrame
        Columns ``cell_id``, ``cell_position``, ``nucleus_id``, ``bbox`` followed by
        the matched features, in the order given by ``feature_names``.

    Raises
    ------
    KeyError
        If ``group_label`` is not a key of ``group_id_dict`` or a required column
        is missing from a table.
    """
    # Substring -> table name; the tuple order is the dispatch priority.
    dispatch_order = ('mean_intensity', 'median_intensity', 'ratio', 'total_intensity')

    tables = {
        name: load_csv(f'{dir}/output/{name}.csv').dropna(axis=0)
        for name in ('total_intensity', 'mean_intensity', 'median_intensity', 'ratio')
    }

    if group_label is not None:
        group_id = group_id_dict[group_label]
        tables = {name: table[table['group'] == group_id] for name, table in tables.items()}

    # simple filter for extreme cyclina ratio values
    ratio_table = tables['ratio']
    tables['ratio'] = ratio_table[ratio_table['cyclina_mean_ratio'] < cyclina_ratio_max]

    # Each table dropped its own NaN rows, so the surviving row labels can differ
    # between tables. Keep only labels present in all four; selecting the ratio
    # labels directly from the other tables raised KeyError when one was missing.
    # Assumes, as the index-based selection always did, that a row label refers to
    # the same cell in every table.
    keep_index = tables['ratio'].index
    for name in ('total_intensity', 'mean_intensity', 'median_intensity'):
        keep_index = keep_index.intersection(tables[name].index)
    tables = {name: table.loc[keep_index] for name, table in tables.items()}

    # base metadata
    pieces = [tables['total_intensity'][['cell_id', 'cell_position', 'nucleus_id', 'bbox']].copy()]

    for feature in feature_names:
        for keyword in dispatch_order:
            if keyword in feature:
                pieces.append(tables[keyword][[feature]])
                break

    # Single concat instead of growing the frame one column at a time.
    return pd.concat(pieces, axis=1)


# Upper bound on the number of explicitly placed colorbar ticks. `step` is given in
# raw feature units, so a step that is tiny relative to the data range would ask
# matplotlib for an enormous number of ticks.
max_colorbar_ticks = 50


def _remove_feature_outliers(group_df, feature_names, z_max=3):
    """Drop rows whose absolute z-score reaches ``z_max`` in any feature column.

    Z-scores are computed per column over ``group_df[feature_names]`` only, so the
    metadata columns never take part in the filter.

    Returns the filtered frame (index reset) and the boolean keep-mask.
    """
    z_scores = np.abs(stats.zscore(group_df[feature_names], nan_policy='omit'))
    keep_mask = (z_scores < z_max).all(axis=1)
    return group_df.loc[keep_mask].reset_index(drop=True), keep_mask


# --- untreated group: load, filter outliers, standardize ---------------------
untreated_df = get_features_df(feature_names, home_dir, group_label='3days_no_treated', group_id_dict=group_id_dict)
untreated_df['plotting_label'] = 'untreated'
untreated_df, mask_untreated = _remove_feature_outliers(untreated_df, feature_names)

# Normalize the data (each group is standardized with its own scaler)
scaler = StandardScaler()
untreated_df[feature_names] = scaler.fit_transform(untreated_df[feature_names])

# --- treated group: load, filter outliers, standardize -----------------------
treated_df = get_features_df(feature_names, home_dir, group_label='3days_treated', group_id_dict=group_id_dict)
treated_df['plotting_label'] = 'treated'
treated_df, mask_treated = _remove_feature_outliers(treated_df, feature_names)

# Raw (pre-standardization) values used to colour the treated points. Take an
# explicit copy: `to_numpy()` may return a view of the column, and the same columns
# are overwritten by the scaler just below.
colorbar_values = treated_df[plotting_feature].to_numpy(copy=True)

# Normalize the data
scaler = StandardScaler()
treated_df[feature_names] = scaler.fit_transform(treated_df[feature_names])
# NOTE(review): the two groups are standardized independently and then embedded
# together with the treated model; confirm this matches how the model was fitted.

final_df = pd.concat([untreated_df, treated_df], axis=0).reset_index(drop=True)

# Matrix for PHATE embedding (features only)
normalized_data = final_df[feature_names].to_numpy()

# load phate model
# NOTE: pickle.load can execute arbitrary code stored in the file; only load model
# files that were produced by this project.
with open(treated_phate_model_dir, 'rb') as f:
    phate_model = pickle.load(f)

data_phate = phate_model.transform(normalized_data)
print('Embedding shape:', data_phate.shape)

# Map back groups after filtering
plotting_labels = final_df['plotting_label'].to_numpy()

# Split by group for coloring
is_untreated = plotting_labels == 'untreated'

# Plot PHATE embedding with colorbar: untreated cells as a grey backdrop, treated
# cells coloured by the raw plotting feature.
fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(data_phate[is_untreated, 0], data_phate[is_untreated, 1],
           c='lightgrey', s=20, alpha=1, label='No Treated')
m = ax.scatter(data_phate[~is_untreated, 0], data_phate[~is_untreated, 1],
               c=colorbar_values,
               cmap='RdBu_r', s=20, alpha=1, label='Treated')
m.set_clim(vmin=0, vmax=colorbar_values.max())
cbar = fig.colorbar(m, ax=ax)

# Place ticks every `step` from the lower colour limit up to the largest multiple
# of `step` that does not exceed the upper limit (1e-9 absorbs float round-off).
vmin, vmax = m.get_clim()
start = vmin
end = np.floor(vmax / step + 1e-9) * step
n = int(round((end - start) / step))
if n + 1 <= max_colorbar_ticks:
    ticks = start + step * np.arange(n + 1)
    cbar.set_ticks(ticks)
else:
    # Keep matplotlib's automatic ticks rather than generating thousands of them.
    print(f'step={step} would create {n + 1} colorbar ticks '
          f'(limit {max_colorbar_ticks}); keeping automatic ticks')
    ticks = cbar.get_ticks()

# Remove axes, ticks, spines
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title('')
fig.tight_layout()
fig.savefig(f'{home_dir}/output/figure/paper/{plotting_feature}_{plotting_group}_rdbu.png', dpi=300, bbox_inches='tight')
plt.show()

## untreated

In [None]:
# tick step (raw feature units): one colorbar tick every 0.3 million
step = 0.3e6

# Group plotted by this cell. It is used both to select the rows and to name the
# saved figure. The shared `plotting_group` constant names the treated group, so
# using it here would overwrite the treated figure with the untreated one.
untreated_plotting_group = '3days_no_treated'

# Upper bound on explicitly placed colorbar ticks; beyond it matplotlib's automatic
# ticks are kept instead of generating one tick per `step`.
untreated_max_colorbar_ticks = 50

untreated_df = get_features_df(feature_names, home_dir, group_label=untreated_plotting_group, group_id_dict=group_id_dict)
untreated_df['plotting_label'] = 'untreated'
# Outlier removal only on feature columns (avoid metadata columns)
feature_matrix_untreated = untreated_df[feature_names]
z_scores = np.abs(stats.zscore(feature_matrix_untreated, nan_policy='omit'))
# keep rows where all feature z-scores < 3
mask_untreated = (z_scores < 3).all(axis=1)
untreated_df = untreated_df.loc[mask_untreated].reset_index(drop=True)

# Raw (pre-standardization) values used to colour the points. Take an explicit
# copy: `to_numpy()` may return a view of the column, and the same columns are
# overwritten by the scaler just below.
colorbar_values = untreated_df[plotting_feature].to_numpy(copy=True)

# Normalize the data
scaler = StandardScaler()
untreated_df[feature_names] = scaler.fit_transform(untreated_df[feature_names])

# Matrix for PHATE embedding (features only)
normalized_data = untreated_df[feature_names].to_numpy()

# load phate model
# NOTE: pickle.load can execute arbitrary code stored in the file; only load model
# files that were produced by this project.
with open(untreated_phate_model_dir, 'rb') as f:
    phate_model = pickle.load(f)

data_phate = phate_model.transform(normalized_data)
print('Embedding shape:', data_phate.shape)

# Plot PHATE embedding with colorbar
fig, ax = plt.subplots(figsize=(6, 4))
m = ax.scatter(data_phate[:, 0], data_phate[:, 1],
               c=colorbar_values,
               cmap='RdBu_r', s=20, alpha=1, label='NoTreated')

m.set_clim(vmin=0, vmax=colorbar_values.max())
cbar = fig.colorbar(m, ax=ax)

# Place ticks every `step` from the lower colour limit up to the largest multiple
# of `step` that does not exceed the upper limit (1e-9 absorbs float round-off).
vmin, vmax = m.get_clim()
start = vmin
end = np.floor(vmax / step + 1e-9) * step
n = int(round((end - start) / step))
if n + 1 <= untreated_max_colorbar_ticks:
    ticks = start + step * np.arange(n + 1)
    cbar.set_ticks(ticks)
else:
    # Keep matplotlib's automatic ticks rather than generating thousands of them.
    print(f'step={step} would create {n + 1} colorbar ticks '
          f'(limit {untreated_max_colorbar_ticks}); keeping automatic ticks')
    ticks = cbar.get_ticks()

# Remove axes, ticks, spines
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title('')
fig.tight_layout()
fig.savefig(f'{home_dir}/output/figure/paper/{plotting_feature}_{untreated_plotting_group}_rdbu.png', dpi=300, bbox_inches='tight')
plt.show()