In [None]:
import scanpy as sc 
import anndata as ad
import polars as pl
import os
import sys
# Add the module_folder to the sys.path list
sys.path.append('/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/analysis_functions/')
from plotting import *

In [None]:
# Absolute, machine-specific path of the project directory on the shared volume.
# NOTE(review): no cell visible in this notebook reads PROJECT_DIR — confirm it is still needed.
PROJECT_DIR = "/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/SPECS"

## E-distance

In [None]:
from scperturb import *
import pandas as pd

In [None]:
# Load the single-cell embedding (.h5ad) that every later cell works on.
# `ad.read` is a deprecated alias of `ad.read_h5ad`; the explicit reader is used so the
# cell keeps working on anndata releases that dropped the alias.
adata = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/SPECS/deepprofiler/moa/sc_embedding_BF_specs5k.h5ad")

In [None]:
# Pairwise E-distances between the groups defined by `Metadata_cmpdName`, computed on the
# embedding stored in `adata.obsm['X_pca']` with squared Euclidean distance.
# n_jobs=-1 presumably means "use all available cores" — confirm against the scperturb docs.
estats = edist(adata, obs_key='Metadata_cmpdName', obsm_key='X_pca', dist='sqeuclidean', n_jobs= -1)


In [None]:
# Persist the compound-level E-distance matrix (relative path: written to the current working directory).
estats.to_csv("edists_BF_compounds.csv")

In [None]:
# Pairwise E-distances between `moa_broad` groups.
# NOTE(review): `adata_NO_CBK041211` is not defined in any cell visible in this notebook, so this
# fails on a fresh kernel unless it is created elsewhere. Presumably it is `adata` with the
# compound CBK041211 removed — confirm and add the defining cell before this one.
estats_moa = edist(adata_NO_CBK041211, obs_key='moa_broad', obsm_key='X_pca', dist='sqeuclidean', n_jobs= -1)


In [None]:
# Pairwise E-distances between `Metadata_cmpdName` groups on the `adata_NO_CBK041211` object.
# NOTE(review): `adata_NO_CBK041211` is not defined in any cell visible in this notebook —
# confirm where it comes from; a Restart & Run All will stop here otherwise.
estats_NO_CBK041211 = edist(adata_NO_CBK041211, obs_key='Metadata_cmpdName', obsm_key='X_pca', dist='sqeuclidean', n_jobs= -1)

In [None]:
import matplotlib.pyplot as plt
def e_dist_violin(estats, ctrl, condition = None):
    """Violin + swarm plot of the E-distance of every group to the control.

    Parameters
    ----------
    estats : pandas.DataFrame
        Pairwise E-distance matrix. The column labelled ``ctrl`` is plotted;
        the row labelled ``ctrl`` (the control against itself) is left out.
    ctrl : str
        Label of the control group; must be both a row and a column of ``estats``.
    condition : str, optional
        When given, the figure is also written to ``e_dist_violin_{condition}.png``.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    from matplotlib import ticker as mticker

    estats_control = pd.DataFrame(estats.loc[:, ctrl])
    estats_control.columns = ['E-statistic']
    # log10(E + 1), with negative estimates clipped to 0 so the log argument is >= 1.
    # `np.inf` is used because the `np.infty` alias no longer exists in NumPy 2.0.
    estats_control['tmp'] = np.log10(np.clip(estats_control['E-statistic'], 0, np.inf) + 1)
    # Everything below works on the perturbations only (control row removed once).
    perturbed = estats_control.drop(ctrl)

    scale = 0.75
    with sns.axes_style('whitegrid'):
        fig, ax = plt.subplots(figsize=[20*scale, 5*scale], dpi=300)

    sns.violinplot(data=perturbed, x='tmp', inner=None, color=".8", width=0.8, bw=0.5, ax=ax)
    # Raw string: the mathtext marker contains backslashes that are not valid Python escapes.
    sns.swarmplot(data=perturbed, x=perturbed['tmp'], y=['']*len(perturbed), size=10,
                  marker=r'$\mathbf{\odot}$', edgecolors='white', linewidth=0,
                  palette=['tab:blue', 'tab:red'], ax=ax)

    # Label the three largest distances. The (dx, text_y) offsets and arrow-tip heights are
    # hand-tuned so the labels do not overlap; zip() simply stops early when fewer than
    # three perturbations exist.
    offsets = [(0.5, 0.3), (-0.4, 0.3), (0.2, 0.4)]
    y_offsets = [-0.02, 0.03, 0.001]
    top3 = perturbed['tmp'].nlargest(3)
    for (dx, text_y), arrow_y, (label, value) in zip(offsets, y_offsets, top3.items()):
        ax.annotate(label, xy=(value, arrow_y), xytext=(value + dx, text_y),
                    arrowprops=dict(arrowstyle='-|>', color='black', lw=1.5),
                    ha='center', va='bottom', fontsize=12, color='black')

    ax.set_xlabel('E-distance+1 to unperturbed (log scale)')
    ax.axvline(0, c='grey', linestyle='--', linewidth=4)

    # The data are already log10-transformed, so ticks are placed at integer exponents and
    # rendered as powers of ten; minor ticks mark the linear subdivisions of each decade.
    ax.xaxis.set_major_formatter(mticker.StrMethodFormatter("$10^{{{x:.0f}}}$"))
    _, xmax = ax.get_xlim()
    tick_range = np.arange(0, xmax)
    ax.xaxis.set_ticks(tick_range)
    ax.xaxis.set_ticks([np.log10(x+1) for p in tick_range for x in np.linspace(10 ** p, 10 ** (p + 1), 10)], minor=True)
    ax.set_xlim([-1, 5.5])

    if condition is not None:
        fig.savefig(f'e_dist_violin_{condition}.png', bbox_inches='tight')
    plt.show()

In [None]:
# Distribution of E-distances to the "[DMSO]" control; no `condition` is passed, so the
# figure is displayed but not written to disk.
e_dist_violin(estats, "[DMSO]")


In [None]:
# Unclustered heatmap of the pairwise E-distance matrix `estats_NO_CBK041211`.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
# NOTE(review): `order` is computed here but never applied to the heatmap below.
order = estats_NO_CBK041211.sort_index().index
sns.heatmap(data=estats_NO_CBK041211, ax=ax)
plt.show()

In [None]:
from scipy.stats import zscore
from scipy.cluster.hierarchy import distance, linkage, dendrogram
from scipy.cluster import hierarchy

def cluster_matrix(matrix, how='row', return_order=False, method='centroid'):
    '''
    Reorder a matrix/dataframe by hierarchical clustering.

    Parameters
    ----------
    matrix : numpy.ndarray or pandas.DataFrame
        2-D data; rows (or columns) are treated as observation vectors.
    how : {'row', 'col', 'both'}
        Axis to cluster (default: 'row'). 'both' clusters the rows first and
        then the columns of the row-clustered result.
    return_order : bool
        If True (and `how` is 'row' or 'col'), return the integer leaf order
        instead of the reordered matrix. Ignored, with a warning, for 'both'.
    method : str
        Linkage method passed to `scipy.cluster.hierarchy.linkage`.

    Returns
    -------
    The reordered matrix (same type as the input), or a 1-D integer array of
    positions when `return_order` is True.

    Raises
    ------
    ValueError
        If `how` is not one of 'row', 'col', 'both'.
    '''
    # Local import: `warn` was referenced without being imported anywhere in this notebook.
    import warnings

    if how not in ('col', 'row', 'both'):
        raise ValueError('Value for "how" must be "row", "col" or "both".')

    if how == 'both':
        if return_order:
            warnings.warn('Returning order when clustering both row and col is not supported.')
        row_clustered = cluster_matrix(matrix, how='row', return_order=False, method=method)
        return cluster_matrix(row_clustered, how='col', return_order=False, method=method)

    M = matrix if how == 'row' else matrix.T
    if M.shape[0] < 2:
        # Nothing to cluster; linkage() rejects an empty distance vector.
        order = np.arange(M.shape[0], dtype=int)
    else:
        dist = distance.pdist(M)
        link = linkage(dist, method=method)
        # no_plot=True: only the leaf ordering is needed, nothing is drawn.
        dend = dendrogram(link, no_plot=True)
        order = np.array(dend['leaves'], dtype=int)

    if return_order:
        return order
    if isinstance(matrix, pd.DataFrame):
        return matrix.iloc[order] if how == 'row' else matrix.iloc[:, order]
    return matrix[order] if how == 'row' else matrix[:, order]
    

In [None]:
# Clustered heatmap of the pairwise compound-level E-distance matrix.
ed = estats

with sns.axes_style('whitegrid'):
    fig, ax = plt.subplots(1, figsize=[20, 20], dpi=300)

sns.heatmap(cluster_matrix(ed, "both"), robust=True, xticklabels=True, yticklabels=True,
            cbar_kws={'label': 'E-distance'}, ax=ax)
ax.set_title('E-distance between selected compounds')
# Rows and columns of the pairwise matrix are the compound groups; the E-distance is the
# colour scale, so it is labelled on the colour bar rather than on the axes.
ax.set_xlabel('Compound')
ax.set_ylabel('Compound')
# NOTE(review): the file name says "moa_noCBK041211" although `ed` is the compound-level
# `estats`; the name is kept so existing outputs are not orphaned — confirm which is intended.
fig.savefig("e_dist_heatmap_moa_noCBK041211.png", dpi=300)


In [None]:
# Single-linkage dendrogram built from the rows of the E-distance matrix `ed`.
with sns.axes_style('whitegrid'):
    fig, ax = plt.subplots(1, figsize=[20, 20], dpi=300)

# NOTE(review): `linkage` is given the square matrix, so each row is treated as an
# observation vector rather than the matrix being used as precomputed distances —
# confirm this is the intended clustering.
Z = hierarchy.linkage(ed, 'single')
# The leaves are the rows of `ed`, so the labels come from its index, passed as a plain list.
dn = hierarchy.dendrogram(Z, labels=ed.index.tolist(), color_threshold=800, ax=ax)
plt.xticks(rotation=90)
plt.grid(axis='y')
plt.ylabel('E-distance')
# The leaves are compound groups (the matrix was built with obs_key='Metadata_cmpdName').
plt.xlabel('Compound')
plt.title('Perturbations hierarchy based on e dist')
plt.savefig("e_dist_dendogram_moa.png", dpi = 300)
plt.show()

In [None]:
# Distribution of the number of cells per compound (control excluded), in ascending order.
no_dmso_df = adata.obs[adata.obs['Metadata_cmpdName'] != '[DMSO]']

# observed=True: if `Metadata_cmpdName` is categorical, the filtered-out '[DMSO]' category
# (and any other empty category) would otherwise show up as a zero-sized group and inflate
# the "below threshold" percentages computed below.
group_counts = no_dmso_df.groupby("Metadata_cmpdName", observed=True).size()
group_counts_sorted = group_counts.sort_values()

# Cell-count thresholds marked on the plot.
LOW_COUNT = 250
HIGH_COUNT = 500

# Percentage of compounds with fewer cells than each threshold.
# (`threshold_200` keeps its existing name for later cells; the threshold it refers to is 250.)
threshold_200 = (group_counts_sorted < LOW_COUNT).mean() * 100
threshold_500 = (group_counts_sorted < HIGH_COUNT).mean() * 100

fig, ax = plt.subplots()
group_counts_sorted.plot(kind='bar', ax=ax)

# One bar per compound: individual tick labels would be unreadable, so they are removed.
ax.set_xticks([])

# Horizontal guides at the two thresholds, annotated with the share of compounds below each.
ax.axhline(y=LOW_COUNT, color='r', linestyle='--')
ax.axhline(y=HIGH_COUNT, color='g', linestyle='--')
ax.text(x=group_counts_sorted.size, y=LOW_COUNT, s=f"{threshold_200:.2f}% groups < {LOW_COUNT}", color='r', va='bottom')
ax.text(x=group_counts_sorted.size, y=HIGH_COUNT, s=f"{threshold_500:.2f}% groups < {HIGH_COUNT}", color='g', va='bottom')

ax.set_xlabel('Compound')
ax.set_ylabel('Number of cells')
ax.set_title('Distribution of compounds sizes in ascending Order')
fig.tight_layout()
plt.show()

## E-test

In [None]:
# Load precomputed E-test results from a CSV in the current working directory.
# Later cells rely on its "Unnamed: 0" column (merge key) and its "significant_adj" column.
# NOTE(review): the cell that produced this file is not visible here — record its provenance.
etest_grit_all = pd.read_csv("etest_res_specs5k_200_samples_100000_perms.csv")

In [None]:
# E-distance of every group to the '[DMSO]' control, reshaped into a long table for merging.
estats_control = pd.DataFrame(estats.loc[:, '[DMSO]'])
estats_control.columns = ['E-statistic']
# log10(E + 1) with negative estimates clipped to 0. `np.inf` is used because the
# `np.infty` alias no longer exists in NumPy 2.0.
estats_control['tmp'] = np.log10(np.clip(estats_control['E-statistic'], 0, np.inf) + 1)
# Move the group labels from the index into a regular column (assignment instead of inplace).
estats_control = estats_control.reset_index()

In [None]:
# Attach the E-test results to the per-compound E-distances. Left join: every row of
# `estats_control` is kept, and compounds without a test result get NaN in the E-test columns.
# "Unnamed: 0" is the name pandas assigns to a CSV column that has no header.
e_dist_sign = pd.merge(estats_control, etest_grit_all, left_on = "Metadata_cmpdName", right_on = "Unnamed: 0", how = "left")

In [None]:
def e_dist_violin_sign(estats, condition, sign = False, ctrl = ("[DMSO]", "DIMETHYL SULFOXIDE")):
    """Violin + swarm plot of E-distances to the control, optionally coloured by significance.

    Parameters
    ----------
    estats : pandas.DataFrame
        Long table with at least the columns ``Metadata_cmpdName`` and ``tmp``
        (the log10-transformed E-distance), plus ``significant_adj`` when
        ``sign`` is True.
    condition : str
        Tag inserted into the output file name.
    sign : bool
        If True, colour the swarm points by ``significant_adj`` and save to
        ``e_dist_violin_{condition}_sign_lev.png``; otherwise save to
        ``e_dist_violin_{condition}.png``.
    ctrl : str or sequence of str
        Control label(s) whose rows are excluded from the plot. The default
        covers both "[DMSO]" (used elsewhere in this notebook) and
        "DIMETHYL SULFOXIDE" (the only label filtered previously).

    Returns
    -------
    None
        The figure is saved and then displayed with ``plt.show()``.
    """
    from matplotlib import ticker as mticker

    ctrl_labels = [ctrl] if isinstance(ctrl, str) else list(ctrl)
    estats = estats[~estats["Metadata_cmpdName"].isin(ctrl_labels)]

    scale = 0.75
    with sns.axes_style('whitegrid'):
        fig, ax = plt.subplots(figsize=[20*scale, 5*scale], dpi=300)

    sns.violinplot(data=estats, x='tmp', inner=None, color=".8", width=0.8, bw=0.5, ax=ax)

    # Arguments shared by both swarm variants. Raw string: the mathtext marker contains
    # backslashes that are not valid Python escape sequences.
    swarm_kwargs = dict(data=estats, x=estats['tmp'], y=['']*len(estats), size=10,
                        marker=r'$\mathbf{\odot}$', edgecolors='white', linewidth=0, ax=ax)
    if sign:
        # 'tab:orange' (single colon) is the valid matplotlib colour name.
        sns.swarmplot(hue="significant_adj", palette=['tab:blue', 'tab:red', 'tab:orange'], **swarm_kwargs)
    else:
        sns.swarmplot(palette=['tab:blue', 'tab:red'], **swarm_kwargs)

    ax.set_xlabel('E-distance+1 to unperturbed (log scale)')
    ax.axvline(0, c='grey', linestyle='--', linewidth=4)

    # The data are already log10-transformed, so ticks are placed at integer exponents and
    # rendered as powers of ten; minor ticks mark the linear subdivisions of each decade.
    ax.xaxis.set_major_formatter(mticker.StrMethodFormatter("$10^{{{x:.0f}}}$"))
    _, xmax = ax.get_xlim()
    tick_range = np.arange(0, xmax)
    ax.xaxis.set_ticks(tick_range)
    ax.xaxis.set_ticks([np.log10(x+1) for p in tick_range for x in np.linspace(10 ** p, 10 ** (p + 1), 10)], minor=True)
    ax.set_xlim([-1, 5.5])

    if sign:
        fig.savefig(f'e_dist_violin_{condition}_sign_lev.png', bbox_inches='tight')
    else:
        fig.savefig(f'e_dist_violin_{condition}.png', bbox_inches='tight')
    plt.show()

In [None]:
# Significance-coloured violin/swarm plot; with sign=True the figure is saved as
# "e_dist_violin_full_grit_sign_lev.png".
e_dist_violin_sign(e_dist_sign, "full_grit", sign = True)

In [None]:
# Show the rows whose `significant_adj` equals the string "False".
# NOTE(review): this only matches if the column holds strings; if it is boolean the result
# is an empty frame — confirm the dtype read from the CSV.
e_dist_sign[e_dist_sign["significant_adj"] == "False"]

In [None]:
def show_topn_comps(edist, estats, n, control):
    """Heatmap of pairwise E-distances among the control and the n most / least distant compounds.

    Parameters
    ----------
    edist : pandas.DataFrame
        Pairwise E-distance matrix, indexed by compound label on both axes.
        (The parameter name shadows the imported `edist` function inside this body.)
    estats : pandas.DataFrame
        Long table with the columns ``Metadata_cmpdName`` and ``E-statistic``.
    n : int
        Number of compounds taken from each end of the ranking.
    control : str
        Label of the control group. If it is not a row of ``edist``, the label
        "[DMSO]" is used instead, which is what this function always used before.

    Returns
    -------
    (top, bot) : tuple of list of str
        Compound names with the n largest and the n smallest E-statistics,
        the control itself excluded.
    """
    # Honour `control` when it exists in the matrix; otherwise fall back to the previously
    # hard-coded label so existing calls keep producing a plot.
    ctrl_label = control if control in edist.index else "[DMSO]"

    # Rank the perturbations only: the control's distance to itself would otherwise land in
    # the bottom-n list and appear twice in the heatmap.
    ranked = (
        estats[estats["Metadata_cmpdName"] != ctrl_label]
        .sort_values(by="E-statistic", ascending=False)
    )
    top = ranked.head(n)["Metadata_cmpdName"].tolist()
    bot = ranked.tail(n)["Metadata_cmpdName"].tolist()

    scale = 0.3
    fig, ax = plt.subplots(figsize=[13*scale, 10*scale], dpi=300)

    # dict.fromkeys drops repeats (top and bottom overlap when fewer than 2n compounds
    # exist) while keeping the control / bottom / top ordering.
    conds = list(dict.fromkeys([ctrl_label] + bot + top))
    # log10(E + 1) with negative estimates clipped to 0 — the same transform used for the
    # violin plots (the offset was previously added twice).
    sub = np.log10(np.clip(edist.loc[conds, conds], 0, np.inf) + 1)
    sns.heatmap(sub, robust=False, linewidth=3, ax=ax)
    plt.xticks(rotation=45, ha='right')
    plt.show()
    return top, bot

In [None]:
# The control group is labelled "[DMSO]" in `estats` (see the earlier cells that index
# `estats` with '[DMSO]'), so that label is passed as the control.
top_comp, bottom_comp = show_topn_comps(estats, e_dist_sign, 5, "[DMSO]")

In [None]:
def run_edist_top_n(adata, cmpd_list, group):
    """Pairwise E-distances restricted to cells treated with the listed compounds.

    Parameters
    ----------
    adata : AnnData-like object
        Must expose ``obs["Metadata_cmpdName"]`` and support boolean row indexing.
    cmpd_list : list of str
        Values of ``Metadata_cmpdName`` to keep.
    group : str
        Name of the ``obs`` column that defines the groups compared by ``edist``.

    Returns
    -------
    Whatever ``edist`` returns for the filtered object.
    """
    keep_mask = adata.obs["Metadata_cmpdName"].isin(cmpd_list)
    subset = adata[keep_mask]
    return edist(
        subset,
        obs_key=group,
        obsm_key='X_pca',
        dist='sqeuclidean',
        n_jobs=-1,
    )

In [None]:
# E-distances per compound/concentration group for the top compounds plus the control.
# The control is labelled "[DMSO]" in `adata.obs['Metadata_cmpdName']` (see the group-size
# cell above); the former "DIMETHYL SULFOXIDE" label left the control out of the subset.
top5_edist = run_edist_top_n(adata, top_comp + ["[DMSO]"], "Metadata_cmpdNameConc")