In [None]:
import multiprocessing
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
sc.settings.n_jobs = 56
sc.settings.set_figure_params(dpi=180, dpi_save=300, frameon=False, figsize=(4, 4), fontsize=8, facecolor='white')
from tqdm import tqdm
import sys
sys.path.append('.')
from permutation import generate_cell_type_contact_count_matrices

# Count cell-cell contacts at the subclass level

In [None]:
# Region names processed by the contact-counting loop; each name is used to
# read source_data/cells_by_regions_<Atlas>/<region>.csv.
major_brain_regions = [
    'LSX_HY_MB_HB',
    'FbTrt',
    'CB_1',
    'CB_2',
    'MYdp',
    'Hbl_VS',
    'L1_HPFmo_Mngs',
    'TH',
    'CTX_1',
    'CTX_2',
    'ENTm',
    'DG',
    'HPF_CA',
    'HY',
    'OB_1',
    'STR',
    'OB_2',
]


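Unit conversion between microns and the `use_x`/`use_y` coordinate units used below: 6632.24 µm = 883.144 units, i.e. ≈ 0.13316 units per µm. Hence:

- 15 µm ≈ 1.997 units (so the 30 µm contact radius is 2 × 1.997 units)
- 100 µm ≈ 13.315 units (the local permutation jitter radius)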

In [None]:
# Count cell-cell contacts per region: observed, and under local spatial permutation.
#
# Coordinates are in image units (15 µm ≈ 1.997 units, 100 µm ≈ 13.315 units).
# For each atlas and region, files written to source_data/outputs_30um_<atlas>/:
#   <region>_no_permutation.npy                  observed counts summed over slices
#   <region>_local_permutation_count_tensor.npy  per-permutation counts summed over slices
#   <region>_local_permutation_mean.npy          mean over permutations
#   <region>_local_permutation_std.npy           population std (ddof=0) over permutations

RANDOM_SEED = 0                  # root seed: each (atlas, region) null is reproducible
N_permutations = 1000
N_groups = 16                    # maximum number of worker processes per region
cell_type_col = 'transfer_gt_cell_type_sub_STARmap'
r_radius = 1.997 * 2             # contact radius, ≈ 30 µm
r_permute_radius = 13.315        # jitter radius, ≈ 100 µm


def permute_and_count_contacts_for_slices(df_slice_list, seed=None):
    """Count cell-type contacts on randomly jittered copies of each slice.

    Every cell's (`use_x`, `use_y`) is displaced by a point drawn uniformly from a
    disc of radius `r_permute_radius`; contacts within `r_radius` are then counted.
    This is repeated `N_permutations` times per slice, and counts are summed over
    all slices in `df_slice_list`.

    Reads the notebook globals `N_permutations`, `N_cell_types`, `cell_types`,
    `cell_type_col`, `r_radius` and `r_permute_radius` at call time, so it must be
    called (or forked into a Pool) after those are set for the current region.

    Parameters
    ----------
    df_slice_list : list of pandas.DataFrame
        One frame per slice with `use_x`, `use_y` and `cell_type_col` columns.
    seed : None, int or numpy.random.SeedSequence, optional
        Seed for this call's private Generator. Parallel workers must get distinct
        seeds: forked processes inherit the parent's global NumPy RNG state, so
        drawing from `np.random.*` would give every worker identical jitters.

    Returns
    -------
    numpy.ndarray of int, shape (N_permutations, N_cell_types, N_cell_types)
        Contact counts for each permutation, summed over the given slices.
    """
    rng = np.random.default_rng(seed)
    merged_contact_counts = np.zeros((N_permutations, N_cell_types, N_cell_types), dtype=int)
    for df_slice in df_slice_list:
        n_cells = df_slice.shape[0]
        for i in tqdm(range(N_permutations)):
            # sqrt of a uniform variate makes displacements uniform over the disc area.
            r = r_permute_radius * np.sqrt(rng.uniform(size=n_cells))
            theta = rng.uniform(size=n_cells) * 2 * np.pi

            df_slice_rand = df_slice.copy()
            df_slice_rand['use_x'] += r * np.sin(theta)
            df_slice_rand['use_y'] += r * np.cos(theta)

            cell_type_contact_counts = generate_cell_type_contact_count_matrices(
                df_slice_rand, cell_type_col, ['use_x', 'use_y'], cell_types,
                permutation_method='no_permutation', contact_radius=r_radius)
            # Not `+=`: keeps the original implicit cast if counts come back as floats.
            merged_contact_counts[i] = merged_contact_counts[i] + cell_type_contact_counts
    return merged_contact_counts


for atlas_idx, focus_key in enumerate(['Atlas1', 'Atlas2', 'Atlas3']):

    output_path = f'source_data/outputs_30um_{focus_key}'
    os.makedirs(output_path, exist_ok=True)

    for region_idx, region in enumerate(major_brain_regions):
        print(f'{focus_key}: {region}')

        # Read the data and split it into per-slice frames once; both passes reuse them.
        df_ct_labels = pd.read_csv(os.path.join(f'source_data/cells_by_regions_{focus_key}', f'{region}.csv'), index_col=0)
        slice_ids = np.unique(df_ct_labels['ap_order'])
        all_df_slice_list = [df_ct_labels[df_ct_labels['ap_order'] == slice_id] for slice_id in slice_ids]

        cell_types = np.unique(df_ct_labels[cell_type_col])
        N_cell_types = len(cell_types)

        # Observed contacts (no permutation), summed over slices.
        observed_contact_counts = np.zeros((N_cell_types, N_cell_types), dtype=int)
        for df_slice in tqdm(all_df_slice_list):
            cell_type_contact_counts = generate_cell_type_contact_count_matrices(
                df_slice, cell_type_col, ['use_x', 'use_y'], cell_types,
                permutation_method='no_permutation', contact_radius=r_radius)
            observed_contact_counts = observed_contact_counts + cell_type_contact_counts
        np.save(os.path.join(output_path, f'{region}_no_permutation.npy'), observed_contact_counts)

        # Deal slices round-robin into at most N_groups balanced, non-empty groups.
        # With zero slices a single empty group still yields an all-zero tensor.
        n_groups = max(1, min(N_groups, len(all_df_slice_list)))
        grouped_slice_dfs = [all_df_slice_list[g::n_groups] for g in range(n_groups)]

        # Independent, reproducible RNG stream per worker, keyed by (atlas, region).
        group_seeds = np.random.SeedSequence([RANDOM_SEED, atlas_idx, region_idx]).spawn(n_groups)

        # 'fork' is required: the worker is defined interactively and reads the
        # per-region globals set above, which only forked children inherit.
        with multiprocessing.get_context('fork').Pool(n_groups) as p:
            contact_analysis_results = p.starmap(permute_and_count_contacts_for_slices,
                                                 zip(grouped_slice_dfs, group_seeds))
        merged_contact_counts = np.sum(contact_analysis_results, axis=0)

        means = np.mean(merged_contact_counts, axis=0)
        stds = np.std(merged_contact_counts, axis=0)

        np.save(os.path.join(output_path, f'{region}_local_permutation_count_tensor.npy'),
                merged_contact_counts)

        output_file_mean = os.path.join(output_path, f'{region}_local_permutation_mean.npy')
        np.save(output_file_mean, means)
        output_file_std = os.path.join(output_path, f'{region}_local_permutation_std.npy')
        np.save(output_file_std, stds)

        