In [None]:
import os
import json

import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt

import scanpy as sc

import statsmodels.stats.multitest

In [None]:
def get_confusion_count_df(df, col1, col2):
    '''Count how often each (col1, col2) category pair occurs in `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing both columns.
    col1, col2 : str
        Column names. Values of `col1` become the row index of the result and
        values of `col2` become its columns. Neither may be named 'count',
        because that name is used for the temporary helper column.

    Returns
    -------
    pandas.DataFrame
        Contingency table (rows: values of `col1`, columns: values of `col2`).
        Pairs that never occur are filled with 0.
    '''
    assert col1 != 'count' and col2 != 'count', \
        "columns named 'count' would collide with the helper count column"
    count_df = df[[col1, col2]].copy()
    count_df['count'] = 1

    # Use the string alias 'sum' rather than np.sum: passing the numpy callable
    # as aggfunc is deprecated in pandas >= 2.1 (FutureWarning).
    conf_df = pd.pivot_table(count_df, index=[col1], columns=[col2],
                             values='count', aggfunc='sum').fillna(0)

    return conf_df
        
def get_expected_count_df(conf_df):
    '''Expected counts for every cell of a contingency table under independence.

    Each entry equals (row total / grand total) * (column total / grand total)
    * grand total, i.e. the count expected if row and column categories were
    independent. The returned frame keeps the index and columns of `conf_df`.
    '''
    counts = conf_df.values
    grand_total = np.sum(counts)

    row_share = np.sum(counts, axis=1) / grand_total
    col_share = np.sum(counts, axis=0) / grand_total

    # np.outer(a, b) == a[:, None] * b[None, :]
    expected = np.outer(row_share, col_share) * grand_total
    return pd.DataFrame(data=expected, index=conf_df.index, columns=conf_df.columns)

In [None]:
# Per-atlas selection of region-characteristic cell subclasses.
# For every atlas, cells are grouped by their transferred tissue region. Within each
# region, only subclasses passing the rules below are kept, and the result is written
# to source_data/cells_by_regions_<atlas>/<region>.csv.

FOCUS_KEYS = ['Atlas1', 'Atlas2', 'Atlas3']
SUBCLASS_COL = 'transfer_gt_cell_type_sub_STARmap'
REGION_COL = 'transfer_gt_tissue_region_main_STARmap'
CATEGORY_COL = 'neuron_category'

# Selection thresholds (strict '>' comparisons, values from the original analysis).
MIN_NON_NEURONAL_CELLS = 50    # non-neuronal, non-astrocyte subclasses: cells in the region
MIN_ASTROCYTE_ENRICHMENT = 1   # astrocyte subclasses ('AC' prefix): observed / expected
MIN_NEURON_ENRICHMENT = 6      # all remaining subclasses: observed / expected


def select_region_subclasses(region_labels, region, enrichment_df, subclass_category):
    '''Return the subclasses in `region_labels` that pass the selection rules for `region`.

    Rules, applied to each subclass present in the region:
    - category 'non' and name not starting with 'AC': keep if the subclass has more
      than MIN_NON_NEURONAL_CELLS cells in the region;
    - name starting with 'AC' (astrocytes): keep if enrichment_df.loc[subclass, region]
      is greater than MIN_ASTROCYTE_ENRICHMENT;
    - everything else (neurons): keep if the enrichment is greater than
      MIN_NEURON_ENRICHMENT.

    Parameters
    ----------
    region_labels : pandas.DataFrame
        Rows of the label table whose REGION_COL equals `region`.
    region : str
        Region label; must be a column of `enrichment_df`.
    enrichment_df : pandas.DataFrame
        Observed / expected counts, indexed by subclass, with one column per region.
    subclass_category : pandas.Series
        Maps each subclass to its CATEGORY_COL value.

    Returns
    -------
    list
        Selected subclass names, most abundant first.
    '''
    selected = []
    # value_counts orders by descending count (like argsort(-counts)) and skips
    # NaN subclasses, which np.unique cannot sort alongside strings.
    for subclass, n_cells in region_labels[SUBCLASS_COL].value_counts().items():
        is_astrocyte = subclass.startswith('AC')
        if subclass_category[subclass] == 'non' and not is_astrocyte:
            keep = n_cells > MIN_NON_NEURONAL_CELLS
        elif is_astrocyte:
            keep = enrichment_df.loc[subclass, region] > MIN_ASTROCYTE_ENRICHMENT
        else:
            keep = enrichment_df.loc[subclass, region] > MIN_NEURON_ENRICHMENT
        if keep:
            selected.append(subclass)
    return selected


for focus_key in FOCUS_KEYS:
    df_ct_labels = pd.read_csv(f'./source_data/df_ct_labels_{focus_key}.csv', index_col=0)
    output_path = f'source_data/cells_by_regions_{focus_key}'
    os.makedirs(output_path, exist_ok=True)

    # Enrichment = observed subclass-by-region counts / counts expected under independence.
    conf_df = get_confusion_count_df(df_ct_labels, SUBCLASS_COL, REGION_COL)
    expected_count_df = get_expected_count_df(conf_df)
    region_enrichment_df = conf_df / expected_count_df.values

    # Category of each subclass = its value on the subclass's first row. This matches
    # the previous per-subclass `.unique()[0]` lookup but is computed once, instead of
    # rescanning the whole table for every subclass in every region.
    subclass_category = (df_ct_labels
                         .drop_duplicates(SUBCLASS_COL)
                         .set_index(SUBCLASS_COL)[CATEGORY_COL])

    # Cells without a region label are skipped. They have no column in the enrichment
    # table and would otherwise produce an empty 'nan.csv'.
    for region in df_ct_labels[REGION_COL].dropna().unique():
        print(region)
        region_labels = df_ct_labels[df_ct_labels[REGION_COL] == region]
        selected_subclasses = select_region_subclasses(
            region_labels, region, region_enrichment_df, subclass_category)

        # NOTE(review): the region label is used verbatim as a file name; labels that
        # contain path separators would break this — confirm labels are filesystem-safe.
        region_labels[region_labels[SUBCLASS_COL].isin(selected_subclasses)].to_csv(
            os.path.join(output_path, f'{region}.csv'))
