In [None]:
import pandas as pd
import numpy as np
import matplotlib
import scanpy as sc
import seaborn as sns
import tacco as tc

In [None]:
import sys
# Make helper functions available: the notebook expects to be executed either in the
# sub-workflow directory or in the notebooks directory, so both candidate locations of the
# helper module are put on the module search path while it is imported.
# The slice assignment yields the same order as inserting '../' and then '../workflow/' at position 1.
sys.path[1:1] = ['../workflow/', '../']
import helper
# Remove the two temporary search path entries again
del sys.path[1:3]

get_path = helper.get_paths('mouse_slideseq')

# settings

## visualization settings

In [None]:
# Folder into which the figures created in this notebook are saved
figures_folder = get_path('plots')

# Load mouse commot data

In [None]:
# Read 'mouse_slideseq_pathway.h5ad' from the data folder; its .obs annotation is expected to
# provide the columns 'region', 'State' and 'SampleID' which are used further down
adata_pathway = sc.read(f'{get_path("data")}/mouse_slideseq_pathway.h5ad')

In [None]:
# Region names as they are stored in the categorical 'region' annotation
all_regions = adata_pathway.obs['region'].cat.categories
# The second space-separated token of a region name is used as lookup key for the full name
region_name_by_number = {region_name.split(' ')[1]: region_name for region_name in all_regions}
# Look up the names for the keys '0', '1', ..., i.e. arrange the region names by their number
region_numbers = pd.Series(np.arange(len(all_regions))).astype(str)
ordered_regions = region_numbers.map(region_name_by_number).to_numpy()

In [None]:
# Make 'region' an ordered categorical which follows the numeric order of the regions
ordered_region_dtype = pd.CategoricalDtype(ordered_regions, ordered=True)
adata_pathway.obs['region'] = adata_pathway.obs['region'].astype(ordered_region_dtype)

## Determine enrichments across samples using spatially separated patches

In [None]:
# Enrichment tables for all split powers, keyed by '<value_spec>_<group_key>'
enrichments = {}

for value_spec in ['pathway']:
    value_adata = adata_pathway
    
    for group_key in ['State','region']:
        value_spec_group_key = f'{value_spec}_{group_key}'
        print(value_spec_group_key)
        
        # one enrichment table per split power
        enrichment_per_split_power = {}
        
        for split_power in range(6):
            print('\t',split_power)
            # Writes the split assignment to .obs['SampleID_split']; the split scheme consists of
            # `split_power` entries of 2 (it is empty for split_power 0).
            # NOTE(review): presumably every entry bisects the patches once more - confirm in the tacco documentation
            tc.utils.split_spatial_samples(value_adata, buffer_thickness=400, split_scheme=(2,)*split_power, sample_key='SampleID', result_key='SampleID_split', check_splits=False)

            # remove all split+group with less than 100 observations
            # (split and group names are concatenated as strings to count the observations per combination)
            # NOTE(review): if 'SampleID_split' contains missing values, astype(str) turns them into a regular string which is counted like a split - confirm that this is intended
            value_adata.obs['SampleID_split+group'] = value_adata.obs['SampleID_split'].astype(str) + value_adata.obs[group_key].astype(str)
            lowly_covered = value_adata.obs['SampleID_split+group'].value_counts()<100
            lowly_covered = lowly_covered[lowly_covered].index
            # work on a copy of the group annotation in which the lowly covered observations have no group
            split_covered_group_key = f'split_covered_{group_key}'
            value_adata.obs[split_covered_group_key] = value_adata.obs[group_key]
            value_adata.obs.loc[value_adata.obs['SampleID_split+group'].isin(lowly_covered),split_covered_group_key] = None
            
            # Enrichment of the values in the groups, using the spatial splits as samples.
            # NOTE(review): value_key=None together with value_location='X' presumably tests every column of .X - confirm in the tacco documentation
            path_enr_sample_state = tc.tl.enrichments(value_adata,value_key=None,group_key=split_covered_group_key,sample_key='SampleID_split',reduction='mean',normalization=None,value_location='X',method='mwu',)
            # report the groups under the name of the original annotation column
            path_enr_sample_state.rename(columns={split_covered_group_key:group_key}, inplace=True)
            path_enr_sample_state['split_power'] = split_power
            enrichment_per_split_power[split_power] = path_enr_sample_state
        
        # stack the tables of all split powers; the column 'split_power' tells them apart
        all_enrichments_per_split_power = pd.concat(enrichment_per_split_power.values())
        # keep the category order of the annotation, restricted to the groups which appear in the results
        all_enrichments_per_split_power[group_key] = all_enrichments_per_split_power[group_key].astype(pd.CategoricalDtype([c for c in value_adata.obs[group_key].cat.categories if c in all_enrichments_per_split_power[group_key].unique()], ordered=True))
        # human readable description of every result row
        all_enrichments_per_split_power['what'] = all_enrichments_per_split_power['value'].astype(str) + " is " + all_enrichments_per_split_power['enrichment'].astype(str) + " in " + all_enrichments_per_split_power[group_key].astype(str)
        all_enrichments_per_split_power['$\\log_{10}(FDR)$'] = np.log10(all_enrichments_per_split_power['p_mwu_fdr_bh'])
        # The value names are split at the first '-': the remainder is stored in the `value_spec` column,
        # the prefix 's'/'r' is translated into the direction 'sender'/'receiver'
        all_enrichments_per_split_power[value_spec] = all_enrichments_per_split_power['value'].str.split('-',n=1).str[1]
        all_enrichments_per_split_power['direction'] = all_enrichments_per_split_power['value'].str.split('-',n=1).str[0].map({'s':'sender','r':'receiver'})
        
        enrichments[value_spec_group_key] = all_enrichments_per_split_power

## Plot results

In [None]:
# Full names of the regions number 2, 6 and 11 (the number is the second token of the region name)
number_to_region_name = {region_name.split(' ')[1]: region_name for region_name in ordered_regions}
interesting_regions = pd.Series(['2','6','11']).map(number_to_region_name).to_numpy()

In [None]:
def plot_dependency_on_split_power(all_enrichments_per_split_power, value_spec, group_key, n_pathways=20, n_pathways_per_ax=4, only_groups=None):
    """Plot the significance of the top enriched pathways as a function of the split power.

    One row of axes is drawn per group. The pathways shown for a group are the (at most)
    `n_pathways` pathways with the smallest FDR among those which are significantly
    enriched (FDR < 0.05) in that group at split power 2. They are distributed over the
    axes of the row in chunks of `n_pathways_per_ax`; axes which are not needed are
    switched off. A grey horizontal line marks FDR = 0.05.

    Parameters
    ----------
    all_enrichments_per_split_power
        Enrichment table with (at least) the columns `group_key`, `value_spec`,
        'enrichment', 'split_power', 'p_mwu_fdr_bh', 'direction', 'value' and
        '$\\log_{10}(FDR)$'.
    value_spec
        Name of the column containing the pathway names; used for the line colors.
    group_key
        Name of the column containing the groups.
    n_pathways
        Maximum number of pathways to show per group.
    n_pathways_per_ax
        Maximum number of pathways to show in a single axes.
    only_groups
        Groups to plot; if `None`, all categories of `adata_pathway.obs[group_key]`
        are used.

    Returns
    -------
    The created figure.
    """
    if only_groups is None:
        all_groups = adata_pathway.obs[group_key].cat.categories
    else:
        all_groups = only_groups

    # Ceiling division: number of axes needed to show n_pathways in chunks of n_pathways_per_ax.
    # (Plain floor division of n_pathways+1 gives too few axes, e.g. for 10 pathways with 4 per axes.)
    n_ax_per_row = -(-n_pathways // n_pathways_per_ax)
    fig,axs = tc.pl.subplots(n_ax_per_row,len(all_groups), axsize=(2,2), y_padding=1.0, x_padding=2.0, sharex=False, sharey=True, dpi=72)

    is_enriched = all_enrichments_per_split_power['enrichment']=='enriched'
    for i_group,_group in enumerate(all_groups):
        in_group = all_enrichments_per_split_power[group_key].isin([_group])
        sub_enrichments = all_enrichments_per_split_power[in_group&is_enriched].reset_index().copy()

        # select the pathways by their significance at split power 2
        is_significant_reference = (sub_enrichments['split_power'] == 2)&(sub_enrichments['p_mwu_fdr_bh']<0.05)
        top_pathways = sub_enrichments[is_significant_reference].sort_values('p_mwu_fdr_bh').drop_duplicates(value_spec)[value_spec].head(n_pathways).to_numpy()

        pathway_chunks = [top_pathways[p:p+n_pathways_per_ax] for p in range(0, len(top_pathways), n_pathways_per_ax)]
        for _i,_top_pathways in enumerate(pathway_chunks):
            _sub_enrichments = sub_enrichments[sub_enrichments[value_spec].isin(_top_pathways)].reset_index().copy()
            # the ordered categorical fixes the order of the pathways in the legend
            _sub_enrichments[value_spec] = _sub_enrichments[value_spec].astype(pd.CategoricalDtype(categories=_top_pathways, ordered=True))
            _sub_enrichments['value'] = _sub_enrichments['value'].astype(str).astype('category')
            ax = axs[i_group,_i]
            ax.set_title(_group)
            ax.hlines(np.log10(0.05),0,1,transform=ax.get_yaxis_transform(), color='#AAAAAA')
            sax = sns.lineplot(data=_sub_enrichments, x='split_power', y='$\\log_{10}(FDR)$', hue=value_spec, style='direction', ax=ax)
            sax.set_xticks([0,1,2,3,4,5])
            sax.set_xticklabels([0,1,2,3,4,5])
            if len(_sub_enrichments) > 0:
                sns.move_legend(sax, "upper left", bbox_to_anchor=(1.05, 1), ncol=1, title=None, frameon=False)

        # Switch off the axes of this row which got no pathways. Counting the chunks (instead of
        # reusing the loop index) also works for groups without any significant pathway.
        for _i2 in range(len(pathway_chunks),n_ax_per_row):
            axs[i_group,_i2].set_axis_off()

    return fig

# Create and save one figure per grouping; for the regions only the interesting ones are shown
for value_spec in ['pathway']:
    for group_key in ['State','region']:
        value_spec_group_key = f'{value_spec}_{group_key}'
        print(value_spec_group_key)

        if group_key == 'State':
            only_groups = None
        else:
            only_groups = interesting_regions

        fig = plot_dependency_on_split_power(
            enrichments[value_spec_group_key],
            value_spec=value_spec,
            group_key=group_key,
            n_pathways=5,
            n_pathways_per_ax=5,
            only_groups=only_groups,
        )

        figure_path = f'{figures_folder}/spatial_split_bidirectional_communication_{value_spec}_per_{group_key}.pdf'
        fig.savefig(figure_path, bbox_inches='tight')

In [None]:

# Maximum number of pathways shown in the heatmaps
n_pathways = 100

for value_spec in ['pathway']:
    for group_key in ['State','region']:
        value_spec_group_key = f'{value_spec}_{group_key}'
        print(value_spec_group_key)

        # only the results for split power 2 are shown
        enrichments_all_powers = enrichments[value_spec_group_key]
        sub_enrichments = enrichments_all_powers[enrichments_all_powers['split_power']==2].copy()

        # pathways with a significant result, most significant ones first
        significant_hits = sub_enrichments[sub_enrichments['p_mwu_fdr_bh']<0.05].sort_values('p_mwu_fdr_bh')
        top_pathways = significant_hits.drop_duplicates(value_spec)[value_spec].head(n_pathways).to_numpy()

        sub_enrichments = sub_enrichments[sub_enrichments[value_spec].isin(top_pathways)].reset_index().copy()
        # the categories are given in reversed order of significance
        reversed_pathway_dtype = pd.CategoricalDtype(categories=reversed(top_pathways), ordered=True)
        sub_enrichments[value_spec] = sub_enrichments[value_spec].astype(reversed_pathway_dtype)

        is_sender = sub_enrichments['direction']=='sender'
        is_receiver = sub_enrichments['direction']=='receiver'
        senders_significances = sub_enrichments[is_sender]
        receiver_significances = sub_enrichments[is_receiver]

        # the size of the axes scales with the number of groups and pathways
        panel_width = len(sub_enrichments[group_key].cat.categories)*0.7
        panel_height = len(top_pathways)*0.25
        fig,axs = tc.pl.subplots(2,1, axsize=(panel_width,panel_height), x_padding=3.0, dpi=72)

        tc.pl.significances(senders_significances, p_key='p_mwu_fdr_bh', value_key=value_spec, group_key=group_key, ax=axs[0,0])
        axs[0,0].set_title('sender')
        tc.pl.significances(receiver_significances, p_key='p_mwu_fdr_bh', value_key=value_spec, group_key=group_key, ax=axs[0,1])
        axs[0,1].set_title('receiver')

        figure_path = f'{figures_folder}/spatial_split_bidirectional_communication_{value_spec}_per_{group_key}_heatmap.pdf'
        fig.savefig(figure_path, bbox_inches='tight')


In [None]:
def merge_enrichments(significances, group_key='State', value_key='pathway'):
    """Merge enrichment and depletion FDRs into one signed log10 significance per value and group.

    Parameters
    ----------
    significances
        Long-format table with the columns `value_key`, `group_key`, 'enrichment' and
        'p_mwu_fdr_bh'. Rows with 'enrichment' equal to 'enriched' are treated as
        enrichment results, all other rows as depletion results. Every combination of
        value and group may appear at most once per kind of result.
    group_key
        Name of the column with the groups; they become the columns of the result.
    value_key
        Name of the column with the values; they become the index of the result.

    Returns
    -------
    A DataFrame (values x groups) containing `-log10(FDR)` of the enrichment where the
    enrichment FDR is smaller than the depletion FDR and `log10(FDR)` of the depletion
    otherwise, i.e. positive numbers for enrichment and negative numbers for depletion.
    """
    # lower bound for the FDRs to avoid infinities when taking the logarithm
    small_value = 1e-300

    is_enriched = significances['enrichment']=='enriched'
    # the arguments of pivot have to be given by keyword: they are keyword-only in pandas >= 2.0
    enr_e = pd.pivot(significances[is_enriched], index=value_key, columns=group_key, values='p_mwu_fdr_bh')
    enr_p = pd.pivot(significances[~is_enriched], index=value_key, columns=group_key, values='p_mwu_fdr_bh')

    enr_e = np.maximum(enr_e,small_value)
    enr_p = np.maximum(enr_p,small_value)

    # align the depletion table with the enrichment table for the elementwise comparison
    enr_p = enr_p.reindex_like(enr_e)

    enr = pd.DataFrame(np.where(enr_e < enr_p, -np.log10(enr_e), np.log10(enr_p)),index=enr_e.index,columns=enr_e.columns)

    return enr

def enrichment_scatter_plot(merged_enrichments, ax, pathways_to_annotate=None, where=None):
    """Scatter plot of the signed sender significance against the signed receiver significance.

    Every row of `merged_enrichments` is drawn as one point. Points which are significant
    (FDR < 0.05) in both directions with the same sign get the full enriched/depleted color,
    points which are significant in only one direction get a lighter version of that color,
    and all other points are light grey. Grey lines mark FDR = 0.05 on both axes, and both
    axes are made symmetric around 0.

    Parameters
    ----------
    merged_enrichments
        DataFrame with the columns 'sender' and 'receiver' containing signed log10
        significances (positive: enriched, negative: depleted). The frame is not modified.
    ax
        The matplotlib axes to draw on.
    pathways_to_annotate
        Optional DataFrame with the columns 'pathways' (index labels of
        `merged_enrichments`), 'radii' (distance of the label from the point in points)
        and 'thetas' (direction of the label as angle in radians).
    where
        Optional name of the place the enrichment refers to. If given, the axes are
        labelled '... in <where>' with the directions 'enriched'/'depleted'; otherwise
        the directions are labelled 'in premalignant'/'in normal'.
    """
    enriched_color = (1.0, 0.07058823529411765, 0.09019607843137255)
    depleted_color = (0.30196078431372547, 0.5215686274509804, 0.7098039215686275)
    null_color = (0.9,0.9,0.9)
    # the "slightly" colors are even mixtures of the full color and the null color
    slightly_weight = 0.5
    slightly_enriched_color, slightly_depleted_color = tc.pl.mix_base_colors(
        np.array([[slightly_weight,1-slightly_weight,0.0],[0.0,1-slightly_weight,slightly_weight],]),
        np.array([list(enriched_color),list(null_color),list(depleted_color)])
    )
    enriched_color, depleted_color, slightly_enriched_color, slightly_depleted_color, null_color = [matplotlib.colors.to_hex(c) for c in [enriched_color, depleted_color, slightly_enriched_color, slightly_depleted_color, null_color]]

    # work on a copy such that the helper columns do not end up in the frame of the caller
    sub = merged_enrichments.copy()
    sub['sender_enriched'] = sub['sender'] > -np.log10(0.05)
    sub['sender_depleted'] = sub['sender'] <  np.log10(0.05)
    sub['receiver_enriched'] = sub['receiver'] > -np.log10(0.05)
    sub['receiver_depleted'] = sub['receiver'] <  np.log10(0.05)
    sub['color'] = np.where(sub['sender_enriched'] & sub['receiver_enriched'], enriched_color,
                   np.where(sub['sender_depleted'] & sub['receiver_depleted'], depleted_color,
                   np.where(sub['sender_enriched'] | sub['receiver_enriched'], slightly_enriched_color,
                   np.where(sub['sender_depleted'] | sub['receiver_depleted'], slightly_depleted_color,
                            null_color
                           ))))

    # significance thresholds
    ax.hlines( np.log10(0.05),0,1,transform=ax.get_yaxis_transform(), color='#CCC', zorder=-1)
    ax.hlines(-np.log10(0.05),0,1,transform=ax.get_yaxis_transform(), color='#CCC', zorder=-1)
    ax.vlines( np.log10(0.05),0,1,transform=ax.get_xaxis_transform(), color='#CCC', zorder=-1)
    ax.vlines(-np.log10(0.05),0,1,transform=ax.get_xaxis_transform(), color='#CCC', zorder=-1)
    ax.scatter(sub['sender'], sub['receiver'], c=sub['color'])

    # use the same symmetric range on both axes
    maxlim = max([abs(i) for i in [*ax.get_xlim(),*ax.get_ylim()]])
    ax.set_xlim(-maxlim, maxlim)
    ax.set_ylim(-maxlim, maxlim)
    # the ticks are labelled with the FDR, i.e. 10^-|v| for a tick at v (|v| is truncated to an integer)
    ax.set_xticks([v for v in ax.get_xticks() if v!=0]) # remove tick at maximum insignificance
    x_tick_labels = [f'$10^{{-{int(abs(v))}}}$' for v in ax.get_xticks()]
    ax.set_xticklabels(x_tick_labels)
    ax.set_yticks([v for v in ax.get_yticks() if v!=0])
    y_tick_labels = [f'$10^{{-{int(abs(v))}}}$' for v in ax.get_yticks()]
    ax.set_yticklabels(y_tick_labels)
    ax.set_xlim(-maxlim, maxlim) # reset the range as the xticks which one sees are not identical to the ones delivered by get_xticks()...
    ax.set_ylim(-maxlim, maxlim)

    if where is None:
        main_text = 'enrichment'
        enriched_text = 'in premalignant'
        depleted_text = 'in normal'
    else:
        main_text = f'in {where}'
        enriched_text = 'enriched'
        depleted_text = 'depleted'

    # x axis: label, direction texts and arrows pointing towards the two directions
    ax.text(0.5,-0.20,f'sender {main_text} (FDR)',horizontalalignment='center',verticalalignment='top',transform=ax.transAxes)
    ax.text(1.00,-0.13,enriched_text,horizontalalignment='right',verticalalignment='top',transform=ax.transAxes)
    ax.text(0.00,-0.13,depleted_text,horizontalalignment='left',verticalalignment='top',transform=ax.transAxes)
    ax.annotate("", xy=(-0.05,-0.11), xytext=(0.40,-0.11), arrowprops=dict(arrowstyle="->"),xycoords='axes fraction',textcoords='axes fraction')
    ax.annotate("", xy=(1.05,-0.11), xytext=(0.60,-0.11), arrowprops=dict(arrowstyle="->"),xycoords='axes fraction',textcoords='axes fraction')

    # y axis: same layout as for the x axis, rotated
    ax.text(-0.25,0.5,f'receiver {main_text} (FDR)',horizontalalignment='right',verticalalignment='center',transform=ax.transAxes,rotation='vertical')
    ax.text(-0.18,1.00,enriched_text,horizontalalignment='right',verticalalignment='top',transform=ax.transAxes,rotation='vertical')
    ax.text(-0.18,0.00,depleted_text,horizontalalignment='right',verticalalignment='bottom',transform=ax.transAxes,rotation='vertical')
    ax.annotate("", xy=(-0.16,-0.05), xytext=(-0.16,0.40), arrowprops=dict(arrowstyle="->"),xycoords='axes fraction',textcoords='axes fraction')
    ax.annotate("", xy=(-0.16,1.05), xytext=(-0.16,0.60), arrowprops=dict(arrowstyle="->"),xycoords='axes fraction',textcoords='axes fraction')

    if pathways_to_annotate is not None:
        for pathway,radius,theta in zip(pathways_to_annotate['pathways'],pathways_to_annotate['radii'],pathways_to_annotate['thetas']):
            x_y_data = sub.loc[pathway,['sender','receiver']]
            # the label is placed at the distance `radius` (in points) from the data point in direction `theta`
            x_y_text_delta = radius * np.array([np.cos(theta),np.sin(theta)])
            ax.annotate(pathway,x_y_data,xytext=x_y_text_delta,textcoords='offset points',arrowprops=dict(arrowstyle="-"),horizontalalignment='center',verticalalignment='center')

In [None]:
def padN(iterable,N):
    """Return the elements of `iterable` as a list which is filled up with `None` to length `N`.

    Inputs with `N` or more elements are returned as a list without padding and
    without truncation.
    """
    padded = list(iterable)
    padded.extend(None for _ in range(len(padded), N))
    return padded

In [None]:
# Up to 5 most significantly enriched pathways at split power 2 per interesting group
top_enriched_per_key = {}
for value_spec_group_key, enr in enrichments.items():
    value_spec,group_key = value_spec_group_key.split('_')
    interesting_groups = interesting_regions if group_key == 'region' else ['normal','premalignant']

    # rows which are significantly enriched at split power 2, irrespective of the group
    is_top_candidate = (
        enr['split_power'].isin([2]) &
        enr['enrichment'].isin(['enriched']) &
        (enr['p_mwu_fdr_bh'] <= 0.05)
    )

    top_per_group = {}
    for interesting_group in interesting_groups:
        group_hits = enr[enr[group_key].isin([interesting_group]) & is_top_candidate]
        best_first = group_hits.sort_values('p_mwu_fdr_bh').drop_duplicates(value_spec)
        # pad with None to get columns of equal length
        top_per_group[interesting_group] = padN(best_first[value_spec].head(5).to_numpy(),5)
    top_enriched_per_key[value_spec_group_key] = pd.DataFrame(top_per_group)

top_enriched = pd.concat(top_enriched_per_key, axis=1)
top_enriched

In [None]:
# Up to 5 most significantly depleted ('purified') pathways at split power 2 per interesting group
top_depleted_per_key = {}
for value_spec_group_key, enr in enrichments.items():
    value_spec,group_key = value_spec_group_key.split('_')
    interesting_groups = interesting_regions if group_key == 'region' else ['normal','premalignant']

    # rows which are significantly depleted at split power 2, irrespective of the group
    is_top_candidate = (
        enr['split_power'].isin([2]) &
        enr['enrichment'].isin(['purified']) &
        (enr['p_mwu_fdr_bh'] <= 0.05)
    )

    top_per_group = {}
    for interesting_group in interesting_groups:
        group_hits = enr[enr[group_key].isin([interesting_group]) & is_top_candidate]
        best_first = group_hits.sort_values('p_mwu_fdr_bh').drop_duplicates(value_spec)
        # pad with None to get columns of equal length
        top_per_group[interesting_group] = padN(best_first[value_spec].head(5).to_numpy(),5)
    top_depleted_per_key[value_spec_group_key] = pd.DataFrame(top_per_group)

top_depleted = pd.concat(top_depleted_per_key, axis=1)
top_depleted

In [None]:

# Manually placed labels per panel: pathway names, distance of the label from the point (in points)
# and direction of the label (given in units of pi, converted to radians below)
pathway_annotations = {
    'State': pd.DataFrame({
        'pathways':['CD137','RANKL','TNF','CSF', 'MK','GCG','AGT','NPY',],
        'radii':   [     30,     30,   20,   20,   20,   18,   20,   20,],
        'thetas':  [   1.00,   1.00,-0.25,-0.25,-0.50, 1.50, 1.75, 0.75,],
    }),
    'Region 2': pd.DataFrame({
        'pathways':['ANGPTL','PTN','PSAP','TAC','FGF','GUCA','GRN','GALECTIN','NRG','CX3C',],
        'radii':   [      35,   23,    30,   18,   20,    30,   20,        23,   25,    30,],
        'thetas':  [   -1.00,-1.00,  0.00, 1.50, 0.00,  0.00,-0.50,      0.50, 0.00,  1.00,],
    }),
    'Region 6': pd.DataFrame({
        'pathways':['CD137','TNF', 'MK','GCG', 'HH','TRAIL',],
        'radii':   [     30,   23,   20,   18,   20,     30,],
        'thetas':  [   0.00, 0.00, 0.00, 1.50, 0.00,   0.00,],
    }),
    'Region 11': pd.DataFrame({
        'pathways':['OSM','VEGI', 'CD137','PTN', 'MK','GUCA','GCG','VIP',],
        'radii':   [   30,    23,      20,   23,   20,    30,   18,   18,],
        'thetas':  [ 0.00,  0.00,   -0.50, 1.00, 0.00,  1.00, 0.50, 1.50,],
    }),
}
for annotation_table in pathway_annotations.values():
    annotation_table['thetas'] *= np.pi

for value_spec in ['pathway']:
    for group_key in ['State','region']:
        value_spec_group_key = f'{value_spec}_{group_key}'
        print(value_spec_group_key)

        # only the results for split power 2 are shown, separately for senders and receivers
        enrichments_all_powers = enrichments[value_spec_group_key]
        sub_enrichments = enrichments_all_powers[enrichments_all_powers['split_power']==2].copy()
        senders_significances = sub_enrichments[sub_enrichments['direction']=='sender']
        receiver_significances = sub_enrichments[sub_enrichments['direction']=='receiver']

        if group_key == 'State':
            # a single panel showing the enrichment in the premalignant state
            merged_enrichments = pd.DataFrame({
                'sender': merge_enrichments(senders_significances, value_key=value_spec)['premalignant'],
                'receiver': merge_enrichments(receiver_significances, value_key=value_spec)['premalignant'],
            })

            pathways_to_annotate = pathway_annotations['State'] if value_spec == 'pathway' else None

            fig,axs = tc.pl.subplots(axsize=(3,3))

            enrichment_scatter_plot(merged_enrichments, ax=axs[0,0], pathways_to_annotate=pathways_to_annotate)

        else:
            # one panel per interesting region
            all_regions = interesting_regions
            fig,axs = tc.pl.subplots(len(all_regions),axsize=(3,3),x_padding=1.2, dpi=72)

            # the merged significances do not depend on the region, so they are computed only once
            merged_senders = merge_enrichments(senders_significances, group_key, value_key=value_spec)
            merged_receivers = merge_enrichments(receiver_significances, group_key, value_key=value_spec)

            for ir,region in enumerate(all_regions):
                # the short name consists of the first two tokens of the region name
                region_name_parts = region.split(' ',2)
                region_short = region_name_parts[0] + ' ' + region_name_parts[1]

                merged_enrichments = pd.DataFrame({
                    'sender': merged_senders[region],
                    'receiver': merged_receivers[region],
                })

                if value_spec == 'pathway':
                    # every region without its own table uses the table of Region 11
                    pathways_to_annotate = pathway_annotations.get(region_short, pathway_annotations['Region 11'])
                else:
                    pathways_to_annotate = None

                enrichment_scatter_plot(merged_enrichments, ax=axs[0,ir], pathways_to_annotate=pathways_to_annotate, where=region_short)

        fig.savefig(f'{figures_folder}/spatial_split_bidirectional_communication_{value_spec}_per_{group_key}_scatter.pdf',bbox_inches='tight')