In [1]:
import os

import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt

import scanpy as sc

import statsmodels.stats.multitest

In [2]:
def count_zero_pairs(contact_mtx):
    n_0 = 0
    for i in range(contact_mtx.shape[0]):
        for j in range(i, contact_mtx.shape[0]):
            if contact_mtx[i, j] == 0:
                n_0 += 1
    return n_0

def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Adjust the p-values in a matrix by the Benjamini/Hochberg method.
    The matrix should be symmetric.
    '''
    p_val_sequential = []
    N = p_val_mtx.shape[0]
    
    for i in range(N):
        for j in range(i, N):
            p_val_sequential.append(p_val_mtx[i, j])

    p_val_sequential_bh = statsmodels.stats.multitest.multipletests(p_val_sequential, method='fdr_bh')[1]
    
    adjusted_p_val_mtx = np.zeros((N, N))
    
    counter = 0
    for i in range(N):
        for j in range(i, N):
            adjusted_p_val_mtx[i, j] = p_val_sequential_bh[counter]
            adjusted_p_val_mtx[j, i] = p_val_sequential_bh[counter]
            counter += 1
            
    return adjusted_p_val_mtx

def get_data_frame_from_metrices(cell_types, mtx_dict):
    N = len(cell_types)
    
    serials_dict = {'cell_type1':[], 'cell_type2':[]}
    for k in mtx_dict.keys():
        serials_dict[k] = []
        
    for i in range(N):
        for j in range(i, N):
            serials_dict['cell_type1'].append(cell_types[i])
            serials_dict['cell_type2'].append(cell_types[j])
            for k in mtx_dict.keys():
                serials_dict[k].append(mtx_dict[k][i, j])
                
    return pd.DataFrame(serials_dict)
    

def sort_cell_type_contact_p_values(p_val_mtx, cell_types):
    '''Return a list of (cell_type1, cell_type2, p_value) sorted by p_values.'''
    p_val_list = []
    N = p_val_mtx.shape[0]
    for i in range(N):
        for j in range(i, N):
            p_val_list.append((cell_types[i], cell_types[j], p_val_mtx[i, j]))
    return sorted(p_val_list, key=lambda x:x[2])

In [3]:
import scipy.cluster
# from scattermap import scattermap

def get_optimal_order_of_mtx(X):
    Z = scipy.cluster.hierarchy.ward(X)
    return scipy.cluster.hierarchy.leaves_list(
        scipy.cluster.hierarchy.optimal_leaf_ordering(Z, X))

def get_ordered_tick_labels(tick_labels):
    tick_labels_with_class = [s.split(' ')[-1] + ' ' + s for s in tick_labels]
    return np.argsort(tick_labels_with_class)

def filter_pval_mtx(pval_mtx, tick_labels, allowed_pairs):
    pval_mtx_filtered = pval_mtx.copy()
    
    for i in range(pval_mtx.shape[0]):
        ct1 = tick_labels[i]
        for j in range(pval_mtx.shape[1]):
            ct2 = tick_labels[j]
            
            if ((ct1, ct2) in allowed_pairs) or ((ct2, ct1) in allowed_pairs):
                continue
            else:
                pval_mtx_filtered[i, j] = 1
            
    return pval_mtx_filtered


In [4]:
def make_dotplot(pval_mtx, fold_change_mtx, tick_labels, title='', allowed_pairs=None):

    #optimal_order = get_optimal_order_of_mtx(pval_mtx)
    optimal_order = get_ordered_tick_labels(tick_labels)
    
    pval_mtx = pval_mtx[optimal_order][:, optimal_order]
    fold_change_mtx = fold_change_mtx[optimal_order][:, optimal_order]
    tick_labels = tick_labels[optimal_order]
    
    if None is not allowed_pairs:
        pval_mtx = filter_pval_mtx(pval_mtx, tick_labels, allowed_pairs)
    
    pval_mtx[pval_mtx>0.05]=1
    mlog_pvals = - np.log10(np.maximum(pval_mtx, 1e-10))
    fold_change_mtx[mlog_pvals==0]=0

    fold_change_mtx=np.log10(fold_change_mtx+1)*100

    fig_len = len(tick_labels) * 0.1
#     fig = plt.figure(figsize=(fig_len, fig_len), dpi=300)


    fig,ax = scattermap(mlog_pvals, marker_size= fold_change_mtx,
                square=True, 
                cmap="Reds",
                linewidths=0.2 * (pval_mtx < 0.05).reshape(-1), 
                linecolor='black', xticklabels=tick_labels, yticklabels=tick_labels,
                vmin=0, vmax=np.max(mlog_pvals), 
                cbar_kws={'shrink':0.5, 'anchor':(0, 0.7)})
    
    plt.tight_layout()
    fig.savefig(f'figures_{focus_key}/{title}.png',dpi=300)#, transparent=True)
#     return fig

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

def scattermap(data_matrix, marker_size, square=True, cmap='coolwarm', 
               linewidths=0, linecolor='black', xticklabels=None, yticklabels=None, 
               vmin=None, vmax=None, cbar_kws=None):
    if vmin is None:
        vmin = data_matrix.min()
    if vmax is None:
        vmax = data_matrix.max()

    norm = Normalize(vmin=vmin, vmax=vmax)
    fig, ax = plt.subplots()
    cmap = plt.get_cmap(cmap)

    # Plot each data point individually
    n, m = data_matrix.shape
    for i in range(n):
        for j in range(m):
            color = cmap(norm(data_matrix[i, j]))
            size = marker_size[i, j] if marker_size.shape == data_matrix.shape else marker_size
            ax.scatter(j, i, color=color, s=size)

    # Customizations
    ax.set_xticks(np.arange(m))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(xticklabels if xticklabels is not None else np.arange(m), rotation=90)
    ax.set_yticklabels(yticklabels if yticklabels is not None else np.arange(n))

    ax.invert_yaxis()
    
    # Gridlines based on the data positions
    ax.set_xticks(np.arange(m+1)-.5, minor=True)
    ax.set_yticks(np.arange(n+1)-.5, minor=True)
#     ax.grid(which="minor", color="w", linestyle='-', linewidth=2)
    ax.tick_params(which="minor", size=0)

    # Colorbar
    if cbar_kws is None:
        cbar_kws = {}
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, **cbar_kws)
    cbar.set_label('log p value', rotation=270, labelpad=15)

    if square:
        plt.axis('equal')
#     plt.show()

    return fig,ax


In [5]:
for focus_key in ['Atlas1','Atlas1','Atlas1']:
    permutation_path = F'source_data/outputs_30um_{focus_key}'
    os.makedirs(permutation_path+'/result/', exist_ok=True)


    major_brain_regions = ['LSX_HY_MB_HB',
     'FbTrt',
     'CB_1',
     'CB_2',
     'MYdp',
     'Hbl_VS',
     'L1_HPFmo_Mngs',
     'TH',
     'CTX_1',
     'CTX_2',
     'ENTm',
     'DG',
     'HPF_CA',
     'HY',
     'OB_1',
     'STR',
     'OB_2']

    result_dfs = []
    
    
    for region in major_brain_regions:

        if os.path.exists(os.path.join(permutation_path, f'{region}_local_permutation_mean.npy')):
            print(region)

            # Load the cell type labels
            df_ct_labels = pd.read_csv(os.path.join(f'source_data/cells_by_regions_{focus_key}', f'{region}.csv'), index_col=0)


            subclass_types = np.unique(df_ct_labels['transfer_gt_cell_type_sub_STARmap'])

            cell_contact_counts = np.load(os.path.join(permutation_path, f'{region}_no_permutation.npy'))

            local_null_means = np.load(os.path.join(permutation_path, f'{region}_local_permutation_mean.npy'))
            local_null_stds = np.load(os.path.join(permutation_path, f'{region}_local_permutation_std.npy'))



            # Require all stds to be larger or equal to the minimal observable std value
            local_null_stds = np.maximum(local_null_stds, np.sqrt(1 / 1000))


            local_z_scores = (cell_contact_counts - local_null_means) / local_null_stds
            local_p_values = scipy.stats.norm.sf(local_z_scores)
            adjusted_local_p_values = adjust_p_value_matrix_by_BH(local_p_values)

            fold_changes = cell_contact_counts / (local_null_means + 1e-4)


            # Gather all results into a data frame
            contact_result_df = get_data_frame_from_metrices(subclass_types, 
                                                     {'pval-adjusted': adjusted_local_p_values,
                                                      'pval': local_p_values,
                                                      'z_score': local_z_scores,
                                                      'contact_count': cell_contact_counts,
                                                      'permutation_mean': local_null_means,
                                                      'permutatmerion_std': local_null_stds
                                                    }).sort_values('z_score', ascending=False)


            # Filter out pairs that don't contact
            contact_result_df = contact_result_df[contact_result_df['pval-adjusted'] < 0.05]
            contact_result_df = contact_result_df[contact_result_df['contact_count'] > 50]
            contact_result_df.to_csv(os.path.join(permutation_path+'/result/', f'{region}_close_contacts.csv'))

            contact_result_df['region']=region
            result_dfs.append(contact_result_df)

        else:
            print(region,'norun')
            
    combined_results = pd.concat(result_dfs)

    combined_results.to_csv(os.path.join(permutation_path+'/result/', 'all_close_contacts.csv'))

LSX_HY_MB_HB


  exec(code_obj, self.user_global_ns, self.user_ns)


FbTrt
CB_1
CB_2
MYdp
Hbl_VS
L1_HPFmo_Mngs
TH
CTX_1
CTX_2
ENTm
DG
HPF_CA
HY
OB_1
STR
OB_2
LSX_HY_MB_HB
FbTrt
CB_1
CB_2
MYdp
Hbl_VS
L1_HPFmo_Mngs
TH
CTX_1
CTX_2
ENTm
DG
HPF_CA
HY
OB_1
STR
OB_2
LSX_HY_MB_HB
FbTrt
CB_1
CB_2
MYdp
Hbl_VS
L1_HPFmo_Mngs
TH
CTX_1
CTX_2
ENTm
DG
HPF_CA
HY
OB_1
STR
OB_2


In [6]:
permutation_path

'source_data/outputs_30um_Atlas1'