In [None]:
# Render inline matplotlib figures at retina (high-DPI) resolution
%config InlineBackend.figure_format='retina'

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
# Default number of parallel jobs for scanpy functions that honor `sc.settings.n_jobs`
sc.settings.n_jobs = 64

In [None]:
def generate_raw_adata_copy(adata):
    """Return a copy of `adata` whose genes are indexed by `var['human_gene_id']`.

    Duplicate IDs are de-duplicated with AnnData's `var_names_make_unique`.
    Only the copy is re-indexed; the input object is left untouched.
    """
    renamed = adata.copy()
    human_ids = list(renamed.var['human_gene_id'])
    renamed.var.index = human_ids
    renamed.var_names_make_unique()
    return renamed

def find_hv_genes(adata_origin):
    """Return the names of highly variable genes in `adata_origin`, as a list.

    Works on a copy, so the input is not modified. Pipeline:
    drop cells with fewer than 20 detected genes, normalize every cell to 1e4
    total counts, log1p-transform, then flag genes with scanpy's
    `highly_variable_genes(min_mean=0.01, max_mean=10, min_disp=0.3)`.
    """
    adata_hv = adata_origin.copy()
    sc.pp.filter_cells(adata_hv, min_genes=20)

    sc.pp.normalize_total(adata_hv, target_sum=1e4)
    sc.pp.log1p(adata_hv)
    sc.pp.highly_variable_genes(adata_hv, min_mean=0.01, max_mean=10, min_disp=0.3)

    hv_flags = adata_hv.var['highly_variable']
    return hv_flags[hv_flags].index.tolist()

from sklearn.neighbors import NearestNeighbors

def mutual_k_nearest_neighbors(arr1, arr2, k):
    """Find mutual k-nearest-neighbour pairs between two point sets.

    A pair (p1, p2) is mutual when arr2[p2] is among the k nearest points of
    arr2 to arr1[p1] AND arr1[p1] is among the k nearest points of arr1 to
    arr2[p2]. Neighbours are searched with scikit-learn's NearestNeighbors
    using the cosine metric.

    Parameters
    ----------
    arr1, arr2 : array of shape (n1, d) and (n2, d)
        Two sets of points in the same feature space.
    k : int
        Neighbours searched in each direction. Clamped to the size of the
        smaller set so that small inputs do not make `kneighbors` raise.

    Returns
    -------
    p1s : list of int
        Row indices into arr1.
    p2s : list of int
        Row indices into arr2, aligned with p1s.
    distances : list of float
        Cosine distance of each pair, taken from the arr1 -> arr2 query.
        Pairs are ordered by p1, then by neighbour rank of p2 for that p1.
        All three lists are empty when either input has no rows.
    """
    n1, n2 = arr1.shape[0], arr2.shape[0]
    if n1 == 0 or n2 == 0:
        return [], [], []
    # Cannot ask for more neighbours than there are points in the searched set
    k = min(k, n1, n2)

    nn1 = NearestNeighbors(n_neighbors=k, metric='cosine').fit(arr1)
    nn2 = NearestNeighbors(n_neighbors=k, metric='cosine').fit(arr2)

    # indices1[p1] -> neighbours of arr1[p1] inside arr2
    # indices2[p2] -> neighbours of arr2[p2] inside arr1
    distances1, indices1 = nn2.kneighbors(arr1)
    _, indices2 = nn1.kneighbors(arr2)

    # Encode every (p1, p2) pair as the single integer p1 * n2 + p2 so the
    # mutual test becomes one vectorised set-membership check. Flattening in
    # row-major order keeps the (p1, neighbour-rank) ordering of a nested loop.
    rows1 = np.repeat(np.arange(n1, dtype=np.int64), k)
    cols1 = indices1.ravel().astype(np.int64)
    forward_pairs = rows1 * n2 + cols1
    backward_pairs = (indices2.ravel().astype(np.int64) * n2
                      + np.repeat(np.arange(n2, dtype=np.int64), k))
    is_mutual = np.isin(forward_pairs, backward_pairs)

    p1s = rows1[is_mutual]
    p2s = cols1[is_mutual]
    distances = distances1.ravel()[is_mutual]
    return p1s.tolist(), p2s.tolist(), distances.tolist()

def integrate_datasets(adata_ref, adata_query,
                       k=30, dist_threshold=0.4, n_comps=100):
    """Find anchor cell pairs between two AnnData objects by mutual kNN in a joint PCA.

    Steps:
      1. Re-index genes by ``var['human_gene_id']`` on copies (inputs are not modified).
      2. Keep the highly variable genes found in both datasets (``find_hv_genes``).
      3. Per dataset: drop cells with fewer than 20 detected HV genes, normalize
         each cell to ``len(hv_genes)`` total counts, log1p, and scale (max_value=10).
      4. PCA on the concatenated data, then mutual k-nearest neighbours
         (cosine distance) between reference and query cells.

    Parameters
    ----------
    adata_ref, adata_query : AnnData
        Reference and query datasets; both need a ``var['human_gene_id']`` column.
    k : int
        Neighbours searched in each direction for the mutual-kNN test.
    dist_threshold : float
        Only pairs with distance strictly below this value are kept.
    n_comps : int
        Requested number of principal components. Reduced automatically when the
        merged data have too few cells or genes for the ARPACK solver.

    Returns
    -------
    (ndarray, ndarray)
        Obs names of the reference and query anchor cells, aligned pairwise and
        ordered by increasing distance. A cell may appear in several pairs.

    Raises
    ------
    ValueError
        If the two datasets share no highly variable genes.
    """
    # Re-index on copies so the caller's objects (often AnnData views) stay untouched
    adata_ref = generate_raw_adata_copy(adata_ref)
    adata_query = generate_raw_adata_copy(adata_query)

    hvg1 = find_hv_genes(adata_ref)
    hvg2 = find_hv_genes(adata_query)

    hv_genes = np.intersect1d(hvg1, hvg2)
    print(f'Found {len(hv_genes)} highly variable genes in both datasets')
    if len(hv_genes) == 0:
        raise ValueError('The reference and query datasets share no highly variable genes.')

    adata_ref = adata_ref[:, hv_genes].copy()
    adata_query = adata_query[:, hv_genes].copy()

    # Identical preprocessing, applied to each dataset independently
    for adata in (adata_ref, adata_query):
        sc.pp.filter_cells(adata, min_genes=20)
        # target_sum = number of genes -> on average 1 count per gene per cell
        sc.pp.normalize_total(adata, target_sum=len(hv_genes))
        sc.pp.log1p(adata)
        sc.pp.scale(adata, max_value=10)

    adata_merge = anndata.concat(
        {'ref': adata_ref, 'query': adata_query},
        label='batch_category',
    )

    # ARPACK requires n_comps < min(n_obs, n_vars); shrink the request for small inputs
    n_comps = min(n_comps, min(adata_merge.shape) - 1)
    sc.tl.pca(adata_merge, svd_solver='arpack', n_comps=n_comps)

    adata_m1 = adata_merge[adata_merge.obs['batch_category'] == 'ref']
    adata_m2 = adata_merge[adata_merge.obs['batch_category'] == 'query']

    p1s, p2s, distances = mutual_k_nearest_neighbors(adata_m1.obsm['X_pca'],
                                                     adata_m2.obsm['X_pca'], k=k)
    mkn_df = pd.DataFrame({
        'cell1': adata_m1.obs.index[p1s],
        'cell2': adata_m2.obs.index[p2s],
        'dist': distances,
    }).sort_values('dist')

    mkn_df = mkn_df[mkn_df['dist'] < dist_threshold]
    print(f'Found {mkn_df.shape[0]} mutual k-nearest neighbors.')

    return mkn_df['cell1'].values, mkn_df['cell2'].values


In [None]:
# Get the input files: one .h5ad file per dataset. Only .h5ad entries are listed,
# and only the trailing extension is stripped from each name.
adata_input_path = '/GPUData_xingjie/SCMG/sc_rna_data/'
H5AD_SUFFIX = '.h5ad'
dataset_names = sorted(
    f[:-len(H5AD_SUFFIX)]
    for f in os.listdir(adata_input_path)
    if f.endswith(H5AD_SUFFIX)
)

# Standardized gene table; its 'human_id' column lists the standard gene IDs
standard_gene_df = pd.read_csv(
    '/GPUData_xingjie/Softwares/SCMG_dev/scmg/data/standard_genes.csv')
standard_ids = list(standard_gene_df['human_id'])

# Create the output folder for the intra-dataset integration edge tables
output_path = '/GPUData_xingjie/SCMG/contrastive_embedding_training/edges/intra_integration'
os.makedirs(output_path, exist_ok=True)

# Intra-dataset integration for Qiu_Organogenesis (developmental atlas)

In [None]:
qiu_h5ad_path = os.path.join(adata_input_path,
                             'standard_adata_Qiu_Organogenesis_MM_2022_all.h5ad')
adata_all = sc.read_h5ad(qiu_h5ad_path)
adata_all

In [None]:
# Cells per development stage, to guide the stage groupings integrated below
adata_all.obs['development_stage'].value_counts()

In [None]:
# Each entry: (reference stages, query stages, k for mutual kNN, distance threshold)
integration_group_pairs = [
    (['E5.25', 'E6.25'], ['E3.5', 'E4.5'], 50, 0.6),
    (['E6.75'], ['E5.25', 'E6.25'], 20, 0.4),
    (['E8.5b'], ['E8', 'E8.25', 'E8.5a'], 10, 0.4),
    (['E9.5'], ['E8.5b'], 10, 0.4),
]

# Downsample each stage group to DOWNSAMPLE_FRAC of its cells, but keep at least
# MIN_CELL_NUMBER cells (groups with fewer cells are kept whole).
DOWNSAMPLE_FRAC = 0.2
MIN_CELL_NUMBER = 1000
SUBSAMPLE_SEED = 0  # scanpy's default random_state, made explicit for reproducibility


def select_stages(adata, stages, frac=DOWNSAMPLE_FRAC, min_cells=MIN_CELL_NUMBER,
                  seed=SUBSAMPLE_SEED):
    """Return cells of `adata` whose 'development_stage' is in `stages`, downsampled if large."""
    adata_sel = adata[adata.obs['development_stage'].isin(stages)]
    n_target = max(int(adata_sel.n_obs * frac), min_cells)
    if adata_sel.n_obs > n_target:
        adata_sel = sc.pp.subsample(adata_sel, n_obs=n_target,
                                    random_state=seed, copy=True)
    return adata_sel


for i, (stages_ref, stages_query, k, dist_threshold) in enumerate(integration_group_pairs):

    adata1 = select_stages(adata_all, stages_ref)
    adata2 = select_stages(adata_all, stages_query)
    display(adata1)
    display(adata2)

    # Integration
    anchor_cells1, anchor_cells2 = integrate_datasets(
        adata_ref=adata1, adata_query=adata2,
        k=k, dist_threshold=dist_threshold)

    # Generate contrastive datasets: annotate each anchor pair with metadata looked
    # up by obs name (a cell can take part in several pairs, so names may repeat).
    obs_ref = adata1.obs.loc[anchor_cells1]
    obs_query = adata2.obs.loc[anchor_cells2]
    edges_df = pd.DataFrame({
        'cell_ref': anchor_cells1,
        'cell_query': anchor_cells2,
        'dataset_ref': obs_ref['dataset_id'].values,
        'dataset_query': obs_query['dataset_id'].values,
        'cell_type_ref': obs_ref['cell_type'].values,
        'cell_type_query': obs_query['cell_type'].values,
    })

    output_prefix = os.path.join(output_path,
                                 f'Qiu_Organogenesis_MM_2022_all_intra_{i}')
    edges_df.to_parquet(f'{output_prefix}.parquet')


# Intra-dataset integration for Suo_ImmuneDev

In [None]:
suo_h5ad_path = os.path.join(adata_input_path,
                             'standard_adata_Suo_ImmuneDev_HS_2022_all.h5ad')
adata_all = sc.read_h5ad(suo_h5ad_path)
adata_all

In [None]:
# All annotated cell types, to choose the cell-type pairs integrated below
np.unique(adata_all.obs['cell_type'])

In [None]:
# Each entry: (reference cell types, query cell types, k for mutual kNN, distance threshold).
# NOTE(review): the neighbour search uses cosine distance, which lies in [0, 2], so a
# threshold of 2 keeps (almost) every mutual pair, i.e. effectively no distance
# filtering — confirm this is intended.
integration_group_pairs = [
    (['double negative thymocyte'], ['early lymphoid progenitor'], 5, 2),
    (['late pro-B cell'], ['pro-B cell'], 5, 2),
    (['small pre-B-II cell', 'large pre-B-II cell'], ['late pro-B cell'], 5, 2),
]

for i, (types_ref, types_query, k, dist_threshold) in enumerate(integration_group_pairs):

    # Cell-type groups are used in full (no downsampling for this dataset)
    adata1 = adata_all[adata_all.obs['cell_type'].isin(types_ref)]
    adata2 = adata_all[adata_all.obs['cell_type'].isin(types_query)]
    display(adata1)
    display(adata2)

    # Integration
    anchor_cells1, anchor_cells2 = integrate_datasets(
        adata_ref=adata1, adata_query=adata2,
        k=k, dist_threshold=dist_threshold)

    # Generate contrastive datasets: annotate each anchor pair with metadata looked
    # up by obs name (a cell can take part in several pairs, so names may repeat).
    obs_ref = adata1.obs.loc[anchor_cells1]
    obs_query = adata2.obs.loc[anchor_cells2]
    edges_df = pd.DataFrame({
        'cell_ref': anchor_cells1,
        'cell_query': anchor_cells2,
        'dataset_ref': obs_ref['dataset_id'].values,
        'dataset_query': obs_query['dataset_id'].values,
        'cell_type_ref': obs_ref['cell_type'].values,
        'cell_type_query': obs_query['cell_type'].values,
    })

    output_prefix = os.path.join(output_path,
                                 f'Suo_ImmuneDev_HS_2022_all_intra_{i}')
    edges_df.to_parquet(f'{output_prefix}.parquet')