In [None]:
import scanpy as sc 
import anndata as ad
import polars as pl
import os
import sys
# Add the module_folder to the sys.path list
sys.path.append('/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/analysis_functions/')
from plotting import *

In [None]:
# Root directory of the SPECS analysis project.
PROJECT_DIR = (
    "/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/SPECS"
)

## Compound densities

In [None]:
# Load the MoA single-cell embedding (path relative to the working directory).
adata = ad.read_h5ad(filename="moa/sc_embedding_specs5k_undersampled_sign.h5ad")

In [None]:
# Display the loaded AnnData object to inspect its contents.
adata

In [None]:
# UMAP of all cells, coloured by compound name.
sc.pl.umap(adata=adata, color="Metadata_cmpdName")

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): assumes adata.obs already holds 'umap_density_moa_broad', presumably
# produced by sc.tl.embedding_density(adata, basis='umap', groupby='moa_broad') before
# the h5ad was written — TODO confirm.
# scanpy's `save=` only appends a suffix to its default file name inside sc.settings.figdir,
# so a "figures/..." value would point into a nested, non-existent directory.
# Save explicitly to the intended path instead.
os.makedirs("figures", exist_ok=True)
sc.pl.embedding_density(adata, basis='umap', key='umap_density_moa_broad', show=False)
plt.savefig("figures/moa_density_umap.png", bbox_inches="tight")
plt.show()

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import datetime

def make_jointplot_anndata(adata, dmso_name, colouring, cmpd, save_path=None):
    """Joint UMAP scatter plot with per-group marginal KDEs for an AnnData object.

    Parameters
    ----------
    adata : AnnData
        Needs a UMAP embedding in ``adata.obsm['X_umap']`` (the first two
        columns are plotted) and the ``colouring`` column in ``adata.obs``.
    dmso_name : str
        Label of the control group in ``colouring``. If present, it is drawn
        first (underneath the other groups), in light grey and with smaller
        markers. If absent, the plot is drawn without special-casing it.
    colouring : str
        ``adata.obs`` column used to group and colour the cells. Values are
        compared as strings.
    cmpd : str
        Title of the joint axes.
    save_path : str, optional
        If given, the figure is saved to ``f"{save_path}.png"`` at 300 dpi.

    Notes
    -----
    Sets ``plt.rcParams['figure.dpi'] = 300`` globally, which also affects
    figures created afterwards.
    """
    import numpy as np
    from matplotlib.lines import Line2D

    # Cell metadata with the UMAP coordinates appended as two columns.
    umap_coords = np.asarray(adata.obsm['X_umap'])
    embedding = pd.DataFrame(adata.obs).reset_index()
    embedding['UMAP1'] = umap_coords[:, 0]
    embedding['UMAP2'] = umap_coords[:, 1]

    # Convert to str BEFORE building the colour map so its keys match the labels
    # used for every lookup below (categorical/numeric columns would otherwise miss).
    embedding[colouring] = embedding[colouring].astype(str)
    treatments = list(embedding[colouring].unique())
    color_map = dict(zip(treatments, sns.color_palette("Set2", len(treatments))))

    if dmso_name in color_map:
        color_map[dmso_name] = 'lightgrey'
        # Draw the control first so the treatments are plotted on top of it.
        treatments.remove(dmso_name)
        treatments.insert(0, dmso_name)

    # Increase the DPI for displaying (global matplotlib setting).
    plt.rcParams['figure.dpi'] = 300

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)
    subsets = {t: embedding[embedding[colouring] == t] for t in treatments}

    # Marginal density of each group along both UMAP axes.
    for treatment, subset in subsets.items():
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=color_map[treatment], legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=color_map[treatment], legend=False)

    for treatment, subset in subsets.items():
        is_control = treatment == dmso_name
        # The control is only made translucent when colouring by compound name.
        alpha_val = 0.3 if is_control and colouring == 'Metadata_cmpdName' else 0.8
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], color=color_map[treatment],
                           s=3 if is_control else 10, label=treatment, alpha=alpha_val,
                           edgecolor='white', linewidth=0.5)

    g.ax_joint.set_title(cmpd)
    # Opaque proxy markers, so the legend stays readable despite translucent points.
    legend_elements = [
        Line2D([0], [0], marker='o', linestyle="None", color=color_map[treatment], label=treatment,
               markersize=5, markerfacecolor=color_map[treatment], alpha=1)
        for treatment in treatments
    ]
    legend = g.ax_joint.legend(handles=legend_elements, fontsize=10, title=colouring)
    legend.get_frame().set_facecolor('white')

    if save_path is not None:
        g.savefig(f"{save_path}.png", dpi=300)

    plt.show()


In [None]:
# `adata_filt` is not defined in any cell (it only appears in a commented-out line),
# so plot the AnnData loaded above.
make_jointplot_anndata(adata, '[DMSO]', 'Metadata_cmpdName', "")

## E-distance

In [None]:
from scperturb import *
import pandas as pd

In [None]:
# `ad.read` is a deprecated alias of `ad.read_h5ad`; call the latter directly.
# Note: this replaces the `adata` loaded in the "Compound densities" section.
adata = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/SPECS/deepprofiler/embeddings/grit_reference_locations_fixed_SPECS.h5ad")

In [None]:
# Pairwise E-distances between compounds, computed in PCA space.
estats = edist(
    adata,
    obs_key="Metadata_cmpdName",
    obsm_key="X_pca",
    dist="sqeuclidean",
    n_jobs=-1,
)


In [None]:
# Display the pairwise E-distance matrix computed above.
estats

In [None]:
def e_dist_violin(estats, ctrl, condition = None, xlim=(-1, 5.5)):
    """Violin + swarm plot of every group's E-distance to a control group.

    Parameters
    ----------
    estats : pandas.DataFrame
        Square E-distance matrix (as returned by ``edist``); its ``ctrl``
        column holds each group's distance to the control.
    ctrl : str
        Row/column label of the control. It is excluded from the plot.
    condition : str, optional
        If given, the figure is saved to ``e_dist_violin_{condition}.png``.
    xlim : tuple of float, optional
        Limits of the log10 x-axis (default ``(-1, 5.5)``, the previously
        hard-coded range).
    """
    import numpy as np
    from matplotlib import ticker as mticker

    estats_control = estats.loc[:, ctrl].to_frame('E-statistic')
    # Negative E-distance estimates are clipped to 0 before the log transform.
    # (np.inf replaces np.infty, which NumPy 2.0 removed.)
    estats_control['tmp'] = np.log10(np.clip(estats_control['E-statistic'], 0, np.inf) + 1)
    perturbed = estats_control.drop(ctrl)

    scale = 0.75
    with sns.axes_style('whitegrid'):
        fig, ax = plt.subplots(figsize=[20*scale, 5*scale], dpi=300)
    sns.violinplot(data=perturbed, x='tmp', inner=None, color=".8", width=0.8, bw=0.5, ax=ax)
    # Raw string: '\m' and '\o' are invalid escape sequences in a normal literal.
    sns.swarmplot(data=perturbed, x=perturbed['tmp'], y=['']*len(perturbed), size=10,
                  marker=r'$\mathbf{\odot}$', edgecolors='white', linewidth=0,
                  color='tab:blue', ax=ax)

    # Label the three most distant groups (the control is excluded, as it is not drawn).
    # The offsets are hand-tuned to keep the labels from overlapping.
    offsets = [(0.5, 0.3), (-0.4, 0.3), (0.2, 0.4)]
    y_offsets = [-0.02, 0.03, 0.001]
    top3 = perturbed.nlargest(3, 'tmp')
    for (dx, text_y), y_offset, (idx, row) in zip(offsets, y_offsets, top3.iterrows()):
        ax.annotate(idx, xy=(row['tmp'], y_offset), xytext=(row['tmp'] + dx, text_y),
                    arrowprops=dict(arrowstyle='-|>', color='black', lw=1.5),
                    ha='center', va='bottom', fontsize=12, color='black')

    ax.set_xlabel('E-distance+1 to unperturbed (log scale)')
    ax.axvline(0, c='grey', linestyle='--', linewidth=4)
    ax.set_xlim(xlim)
    # Log-scale ticks derived from the final x-limits: majors at each power of ten,
    # minors at the 10 linear steps within each decade.
    ax.xaxis.set_major_formatter(mticker.StrMethodFormatter("$10^{{{x:.0f}}}$"))
    tick_range = np.arange(0, xlim[1])
    ax.xaxis.set_ticks(tick_range)
    ax.xaxis.set_ticks([np.log10(x + 1) for p in tick_range for x in np.linspace(10 ** p, 10 ** (p + 1), 10)], minor=True)

    if condition is not None:
        fig.savefig(f'e_dist_violin_{condition}.png', bbox_inches='tight')
    plt.show()

In [None]:
# The AnnData loaded for this section labels DMSO as "DIMETHYL SULFOXIDE" (the later cells
# filter and key on that name), so "[DMSO]" is not a column of `estats`.
e_dist_violin(estats, "DIMETHYL SULFOXIDE")


In [None]:
fig, ax = plt.subplots(1, 1, figsize=[20, 15])
# NOTE(review): `order` is not used in this cell.
order = estats.sort_index().index
sns.heatmap(estats, ax=ax, cbar_kws={"label": "E-distance"})
ax.set_title("Pairwise E-distance between perturbations")
plt.show()

In [None]:
from scipy.stats import zscore
from scipy.cluster.hierarchy import distance, linkage, dendrogram
from scipy.cluster import hierarchy

def cluster_matrix(matrix, how='row', return_order=False, method='centroid'):
    '''
    Reorder a matrix/dataframe by hierarchical clustering of its rows and/or columns.

    Parameters
    ----------
    matrix : numpy.ndarray or pandas.DataFrame
        2-D matrix. Each row (or column) is treated as an observation vector and
        compared with Euclidean distance (``pdist`` default).
    how : {'row', 'col', 'both'}, default 'row'
        Axis to cluster. 'both' clusters rows first, then columns.
    return_order : bool, default False
        If True (and ``how`` is not 'both'), return the dendrogram leaf order
        as an integer array instead of the reordered matrix.
    method : str, default 'centroid'
        Linkage method passed to ``scipy.cluster.hierarchy.linkage``.

    Returns
    -------
    numpy.ndarray or pandas.DataFrame
        The reordered matrix (same type as the input), or the integer order.

    Raises
    ------
    ValueError
        If ``how`` is not one of 'row', 'col', 'both'.
    '''
    import warnings
    import numpy as np

    if how not in ['col', 'row', 'both']:
        raise ValueError('Value for "how" must be "row", "col" or "both".')
    if how == 'both':
        if return_order:
            warnings.warn('Returning order when clustering both row and col is not supported.')
        matrix_ = cluster_matrix(matrix, how='row', return_order=False, method=method)
        return cluster_matrix(matrix_, how='col', return_order=False, method=method)

    M = matrix if how == 'row' else matrix.T
    n_obs = M.shape[0]
    if n_obs < 2:
        # linkage() cannot cluster fewer than two observations; keep the input order.
        order = np.arange(n_obs, dtype=int)
    else:
        link = linkage(distance.pdist(M), method=method)
        order = np.array(dendrogram(link, no_plot=True)['leaves'], dtype=int)

    if return_order:
        return order
    if isinstance(matrix, pd.DataFrame):
        return matrix.iloc[order] if how == 'row' else matrix.iloc[:, order]
    return matrix[order] if how == 'row' else matrix[:, order]
    

In [None]:
ed = estats

with sns.axes_style('whitegrid'):
    fig, ax = plt.subplots(1, figsize=[20, 20], dpi=300)

# Rows/columns are perturbations reordered by clustering; the colour encodes the E-distance.
sns.heatmap(cluster_matrix(ed, "both"), robust=True, xticklabels=True, yticklabels=True, ax=ax,
            cbar_kws={"label": "E-distance"})
ax.set_title('E-distance between Beactica perturbations')
ax.set_xlabel('Perturbation')
ax.set_ylabel('Perturbation')
os.makedirs("figures/e_distance", exist_ok=True)
fig.savefig("figures/e_distance/e_dist_heatmap_grit_filtered_ref.png", dpi=300)


In [None]:
import numpy as np
from scipy.spatial.distance import squareform

with sns.axes_style('whitegrid'):
    fig, ax = plt.subplots(1, figsize=[20, 20], dpi=300)

# `ed` is already a square E-distance matrix. hierarchy.linkage treats a 2-D input as
# observation vectors (re-computing Euclidean distances between its rows), so pass the
# condensed form instead. Negative estimates are clipped to 0; the diagonal is ignored.
assert ed.index.equals(ed.columns), "E-distance matrix must have identical row/column order"
condensed = squareform(np.clip(ed.to_numpy(dtype=float), 0, None), checks=False)
Z = hierarchy.linkage(condensed, 'single')
# NOTE(review): the colour threshold is now in E-distance units (it was previously applied
# to row-wise Euclidean distances); 800 may need re-tuning.
dn = hierarchy.dendrogram(Z, labels=list(ed.index), color_threshold=800, ax=ax)
plt.xticks(rotation=90)
plt.grid(axis='y')
plt.ylabel('E-distance')
plt.xlabel('Perturbation')
plt.title('Perturbations hierarchy based on e dist')
os.makedirs("figures/e_distance", exist_ok=True)
plt.savefig("figures/e_distance/e_dist_dendogram_grit_all_ref.png", dpi=300)
plt.show()

## E-test

In [None]:
# Precomputed E-test results (500 samples, 10000 permutations).
etest_grit_all = pd.read_csv(filepath_or_buffer="etest_res_grit_all_500_samples_10000_perms.csv")

In [None]:
import numpy as np

# NOTE(review): `estats_all` is not defined in any visible cell (only `estats` is);
# presumably it is the full E-distance matrix — TODO define it explicitly before this cell.
estats_control = estats_all.loc[:, 'DIMETHYL SULFOXIDE'].to_frame('E-statistic')
# log10(max(E, 0) + 1); np.inf replaces np.infty, which NumPy 2.0 removed.
estats_control['tmp'] = np.log10(np.clip(estats_control['E-statistic'], 0, np.inf) + 1)
# Move the compound labels (the matrix index) into a column for the merge below.
estats_control = estats_control.reset_index()

In [None]:
# Attach the E-test results to each compound's E-distance to DMSO.
e_dist_sign = estats_control.merge(
    etest_grit_all,
    how="left",
    left_on="Metadata_cmpdName",
    right_on="Unnamed: 0",
)

In [None]:
def e_dist_violin_sign(estats, condition, sign = False, ctrl="DIMETHYL SULFOXIDE", xlim=(-1, 5.5)):
    """Violin + swarm plot of E-distances to the control, optionally coloured by significance.

    Parameters
    ----------
    estats : pandas.DataFrame
        Per-compound table with ``Metadata_cmpdName`` and ``tmp``
        (log10(E-distance + 1)) columns, plus ``significant_adj`` when ``sign``.
    condition : str
        Tag used in the output file name.
    sign : bool, default False
        Colour points by ``significant_adj`` and save as
        ``e_dist_violin_{condition}_sign_lev.png``; otherwise save as
        ``e_dist_violin_{condition}.png``.
    ctrl : str, optional
        Control compound name, excluded from the plot.
    xlim : tuple of float, optional
        Limits of the log10 x-axis (default ``(-1, 5.5)``).
    """
    import numpy as np
    from matplotlib import ticker as mticker

    perturbed = estats[estats["Metadata_cmpdName"] != ctrl]
    scale = 0.75
    with sns.axes_style('whitegrid'):
        fig, ax = plt.subplots(figsize=[20*scale, 5*scale], dpi=300)
    sns.violinplot(data=perturbed, x='tmp', inner=None, color=".8", width=0.8, bw=0.5, ax=ax)

    # Raw string: '\m' and '\o' are invalid escape sequences in a normal literal.
    swarm_kwargs = dict(data=perturbed, x=perturbed['tmp'], y=['']*len(perturbed), size=10,
                        marker=r'$\mathbf{\odot}$', edgecolors='white', linewidth=0, ax=ax)
    if sign:
        # One colour per significance level ('tab:orange' was misspelled 'tab::orange').
        colours = ['tab:blue', 'tab:red', 'tab:orange']
        n_levels = max(perturbed["significant_adj"].nunique(), 1)
        sns.swarmplot(hue="significant_adj", palette=colours[:n_levels], **swarm_kwargs)
    else:
        sns.swarmplot(color='tab:blue', **swarm_kwargs)

    ax.set_xlabel('E-distance+1 to unperturbed (log scale)')
    ax.axvline(0, c='grey', linestyle='--', linewidth=4)
    ax.set_xlim(xlim)
    # Log-scale ticks derived from the final x-limits.
    ax.xaxis.set_major_formatter(mticker.StrMethodFormatter("$10^{{{x:.0f}}}$"))
    tick_range = np.arange(0, xlim[1])
    ax.xaxis.set_ticks(tick_range)
    ax.xaxis.set_ticks([np.log10(x + 1) for p in tick_range for x in np.linspace(10 ** p, 10 ** (p + 1), 10)], minor=True)

    if sign:
        fig.savefig(f'e_dist_violin_{condition}_sign_lev.png', bbox_inches='tight')
    else:
        fig.savefig(f'e_dist_violin_{condition}.png', bbox_inches='tight')
    plt.show()

In [None]:
# Violin plot of E-distances to DMSO, coloured by adjusted significance.
e_dist_violin_sign(e_dist_sign, condition="full_grit", sign=True)

In [None]:
# `significant_adj` is presumably parsed by read_csv as booleans (object dtype once the left
# merge introduces NaN), so comparing it with the string "False" never matched anything.
# Comparing the string form works for both boolean and string columns.
e_dist_sign[e_dist_sign["significant_adj"].astype(str) == "False"]

In [None]:
def show_topn_comps(edist, estats, n, control, save_path=None):
    """Heatmap of log-scaled pairwise E-distances among the control and the n most/least distant compounds.

    Parameters
    ----------
    edist : pandas.DataFrame
        Square E-distance matrix, rows and columns labelled by compound name.
    estats : pandas.DataFrame
        Per-compound table with ``Metadata_cmpdName`` and ``E-statistic``
        (distance to the control) columns.
    n : int
        Number of most- and least-distant compounds to show.
    control : str
        Control compound: shown first and excluded from the ranking.
    save_path : str, optional
        If given, the figure is saved there.

    Returns
    -------
    top, bot : list of str
        The ``n`` compounds with the largest and the smallest E-statistic,
        both ordered by descending E-statistic.
    """
    import numpy as np

    ranked = (estats[estats["Metadata_cmpdName"] != control]
              .sort_values(by="E-statistic", ascending=False))
    top = ranked.head(n)["Metadata_cmpdName"].tolist()
    bot = ranked.tail(n)["Metadata_cmpdName"].tolist()

    # dict.fromkeys keeps first-seen order and drops duplicates (top and bot overlap when
    # 2*n exceeds the number of compounds).
    conds = list(dict.fromkeys([control] + bot + top))
    # Same transform as the violin plots: log10(max(E, 0) + 1). The original added 1 twice.
    sub = np.log10(edist.loc[conds, conds].clip(lower=0) + 1)

    scale = 0.3
    fig, ax = plt.subplots(figsize=[13*scale, 10*scale], dpi=300)
    sns.heatmap(sub, robust=False, linewidth=3, ax=ax, cbar_kws={"label": "log10(E-distance + 1)"})
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    return top, bot

In [None]:
# NOTE(review): `estats_all` must be defined before this cell; it is not set in any visible cell.
top_comp, bottom_comp = show_topn_comps(
    edist=estats_all,
    estats=e_dist_sign,
    n=5,
    control="DIMETHYL SULFOXIDE",
)

In [None]:
def run_edist_top_n(adata, cmpd_list, group, cmpd_key="Metadata_cmpdName",
                    obsm_key='X_pca', dist='sqeuclidean', n_jobs=-1):
    """Pairwise E-distances between ``group`` labels, restricted to cells of selected compounds.

    Parameters
    ----------
    adata : AnnData
        Cells with compound labels in ``adata.obs[cmpd_key]`` and an embedding
        in ``adata.obsm[obsm_key]``.
    cmpd_list : list of str
        Compounds whose cells are kept.
    group : str
        ``adata.obs`` column whose labels define the groups compared by ``edist``.
    cmpd_key, obsm_key, dist, n_jobs : optional
        Compound column, embedding key, distance metric and job count
        (defaults match the previously hard-coded values).

    Returns
    -------
    pandas.DataFrame
        The E-distance matrix returned by ``edist``.

    Raises
    ------
    ValueError
        If no cell belongs to any compound in ``cmpd_list``.
    """
    mask = adata.obs[cmpd_key].isin(cmpd_list)
    if not mask.any():
        raise ValueError(f"No cells in adata.obs[{cmpd_key!r}] match any of {list(cmpd_list)!r}")
    return edist(adata[mask], obs_key=group, obsm_key=obsm_key, dist=dist, n_jobs=n_jobs)

In [None]:
# E-distances between compound/concentration groups for the top compounds plus the DMSO control.
top5_edist = run_edist_top_n(adata, [*top_comp, "DIMETHYL SULFOXIDE"], group="Metadata_cmpdNameConc")

In [None]:
# Display the E-distances between the selected compound/concentration groups.
top5_edist

In [None]:
# Distances of each top-compound concentration to DMSO at 0.1 (no file is saved).
e_dist_violin(top5_edist, "DIMETHYL SULFOXIDE_0.1")

In [None]:
with sns.axes_style('whitegrid'):
    fig, ax = plt.subplots(1, figsize=[20, 20], dpi=300)

# Rows/columns are compound_concentration groups reordered by clustering; colour is the E-distance.
sns.heatmap(cluster_matrix(top5_edist, "both"), robust=True, xticklabels=True, yticklabels=True, ax=ax,
            cbar_kws={"label": "E-distance"})
ax.set_title('E-distance between Beactica perturbations')
ax.set_xlabel('Perturbation (compound_concentration)')
ax.set_ylabel('Perturbation (compound_concentration)')
os.makedirs("figures/e_distance", exist_ok=True)
fig.savefig("figures/e_distance/e_dist_heatmap_top5_conc.png", dpi=300)


In [None]:
def edist_dose_response_curve(dist, ctrl='DIMETHYL SULFOXIDE_0.1',
                              save_path="figures/e_distance/top5_edist_dose_resp.png",
                              title='E distance to DMSO'):
    """Plot each treatment's E-distance to the control as a function of concentration.

    Parameters
    ----------
    dist : pandas.DataFrame
        Square E-distance matrix whose labels have the form
        ``"<treatment>_<concentration>"`` (e.g. ``"DIMETHYL SULFOXIDE_0.1"``).
        Its ``ctrl`` column is used as the distance to the control.
    ctrl : str, optional
        Control label. Rows whose treatment part equals the control's are not plotted.
    save_path : str or None, optional
        Output file. Parent directories are created; ``None`` disables saving.
    title : str, optional
        Plot title.
    """
    import os

    to_ctrl = dist[ctrl].rename("E-Distance").reset_index()
    label_col = to_ctrl.columns[0]
    # Split on the LAST underscore only, so treatment names containing '_' stay whole.
    # reindex guarantees both columns exist even if no label contains '_'.
    parts = (to_ctrl[label_col].astype(str)
             .str.rsplit('_', n=1, expand=True)
             .reindex(columns=[0, 1]))
    ctrl_treatment = ctrl.rsplit('_', 1)[0]

    data_for_plotting = to_ctrl.assign(Treatment=parts[0], Concentration=parts[1])
    # .copy() so the numeric conversion below does not act on a slice (SettingWithCopyWarning).
    data_for_plotting = data_for_plotting[data_for_plotting["Treatment"] != ctrl_treatment].copy()
    data_for_plotting['Concentration'] = pd.to_numeric(data_for_plotting['Concentration'])

    fig, ax = plt.subplots(figsize=(15, 10), dpi=300)
    sns.lineplot(data=data_for_plotting, x='Concentration', y='E-Distance', hue='Treatment', marker='o', ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Concentration')
    ax.set_ylabel('E distance')
    ax.legend(title='Treatment', bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path)
    plt.show()

In [None]:
# Dose-response of the top compounds' E-distance to DMSO.
edist_dose_response_curve(dist=top5_edist)