In [None]:
%config InlineBackend.figure_format='retina'

In [None]:
import os
import shutil
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
# Let scanpy use up to 64 parallel jobs where it supports parallelism
# NOTE(review): hard-coded for the original machine; lower this on smaller hosts.
sc.settings.n_jobs = 64

In [None]:
def find_hv_genes(adata_origin):
    """Return the names of the highly variable genes of one dataset.

    The input is left untouched; every step runs on a private copy:
    cells with fewer than 20 detected genes are dropped, counts are
    normalized to 1e4 per cell and log1p-transformed, and genes are then
    flagged with scanpy's mean/dispersion thresholds.

    Args:
        adata_origin: AnnData object to analyse.

    Returns:
        list: var index entries flagged as highly variable.
    """
    working = adata_origin.copy()

    # Quality filter, then the normalization expected by the HVG selection.
    sc.pp.filter_cells(working, min_genes=20)
    sc.pp.normalize_total(working, target_sum=1e4)
    sc.pp.log1p(working)

    sc.pp.highly_variable_genes(
        working, min_mean=0.01, max_mean=10, min_disp=0.3)

    hv_flags = working.var['highly_variable']
    return working.var.loc[hv_flags].index.tolist()

def find_common_hv_genes(adata1, adata2):
    """Return genes that are highly variable across two datasets.

    The two datasets are concatenated (tagged 'b1' / 'b2' in
    ``obs['batch_category']``), cells with fewer than 20 detected genes are
    dropped, counts are normalized to 1e4 per cell and log1p-transformed,
    and the top 2000 highly variable genes are selected with the batch
    label as ``batch_key``.

    Args:
        adata1: First AnnData object (batch 'b1').
        adata2: Second AnnData object (batch 'b2').

    Returns:
        list: var index entries of the merged object flagged as highly variable.
    """
    batches = {'b1': adata1, 'b2': adata2}
    merged = anndata.concat(batches, label='batch_category')

    sc.pp.filter_cells(merged, min_genes=20)
    sc.pp.normalize_total(merged, target_sum=1e4)
    sc.pp.log1p(merged)

    # Batch-aware selection so that neither dataset dominates the gene set.
    sc.pp.highly_variable_genes(
        merged, n_top_genes=2000, batch_key='batch_category')

    hv_flags = merged.var['highly_variable']
    return merged.var.loc[hv_flags].index.tolist()

def get_merged_datasets(adata1, adata2):
    """Project two datasets into a shared, L2-normalized PCA space.

    Steps:
      1. Re-index the genes of both inputs by ``var['human_gene_id']`` and
         make the names unique.
      2. Restrict both datasets to the genes they share and drop cells with
         fewer than 20 detected genes.
      3. Concatenate (batches 'b1' / 'b2' in ``obs['batch_category']``),
         normalize to 1e4 counts per cell, log1p, keep the top 2000
         batch-aware highly variable genes and scale (clipped at 100).
      4. Run PCA and store the row-wise L2-normalized coordinates in
         ``obsm['X_pca_norm']``.

    Side effect: the ``var`` index of BOTH input objects is overwritten in
    place (step 1). Pass copies if the originals must stay unchanged.

    Args:
        adata1: AnnData object with a ``human_gene_id`` column in ``var``.
        adata2: AnnData object with a ``human_gene_id`` column in ``var``.

    Returns:
        tuple: ``(adata_m1, adata_m2)``, views of the merged object holding
        the cells of ``adata1`` and ``adata2`` respectively, or
        ``(None, None)`` when either dataset has fewer than 10 cells after
        filtering or when too few cells/genes remain to compute a single
        principal component.
    """
    # Use the human gene ids as the common gene vocabulary (mutates the inputs).
    adata1.var.index = list(adata1.var['human_gene_id'])
    adata2.var.index = list(adata2.var['human_gene_id'])
    adata1.var_names_make_unique()
    adata2.var_names_make_unique()

    common_genes = np.intersect1d(adata1.var_names, adata2.var_names)
    adata1 = adata1[:, common_genes].copy()
    adata2 = adata2[:, common_genes].copy()

    sc.pp.filter_cells(adata1, min_genes=20)
    sc.pp.filter_cells(adata2, min_genes=20)

    if (adata1.shape[0] < 10) or (adata2.shape[0] < 10):
        return None, None

    adata_merge = anndata.concat(
        {'b1': adata1, 'b2': adata2},
        label='batch_category',
    )

    sc.pp.normalize_total(adata_merge, target_sum=1e4)
    sc.pp.log1p(adata_merge)

    sc.pp.highly_variable_genes(adata_merge,
                    n_top_genes=2000, batch_key='batch_category')
    adata_merge = adata_merge[:, adata_merge.var['highly_variable']].copy()
    sc.pp.scale(adata_merge, max_value=100)

    # Get the significant PCs.
    # The arpack solver needs n_comps strictly below BOTH matrix dimensions,
    # so clamp against the number of retained genes as well as the cells.
    n_comps = min(100, adata_merge.n_obs - 1, adata_merge.n_vars - 1)
    if n_comps < 1:
        return None, None

    sc.tl.pca(adata_merge, svd_solver='arpack', n_comps=n_comps)

    # Row-wise L2 normalization; the epsilon guards against all-zero rows.
    pca_coords = adata_merge.obsm['X_pca']
    row_norms = np.linalg.norm(pca_coords, axis=1) + 1e-6
    adata_merge.obsm['X_pca_norm'] = pca_coords / row_norms[:, None]

    adata_m1 = adata_merge[adata_merge.obs['batch_category'] == 'b1']
    adata_m2 = adata_merge[adata_merge.obs['batch_category'] == 'b2']

    return adata_m1, adata_m2

from sklearn.neighbors import NearestNeighbors

def mutual_k_nearest_neighbors(arr1, arr2, k1, k2, metric='euclidean'):
    """Find mutual nearest-neighbour pairs between two point sets.

    A pair ``(p1, p2)`` is reported when row ``p2`` of ``arr2`` is among the
    ``k1`` nearest neighbours (in ``arr2``) of row ``p1`` of ``arr1`` AND
    ``p1`` is among the ``k2`` nearest neighbours (in ``arr1``) of ``p2``.

    Args:
        arr1: 2-D array, one point per row.
        arr2: 2-D array, one point per row, same feature dimension as arr1.
        k1: Number of neighbours in ``arr2`` searched for each ``arr1`` point.
            Clamped to ``arr2.shape[0]``.
        k2: Number of neighbours in ``arr1`` searched for each ``arr2`` point.
            Clamped to ``arr1.shape[0]``.
        metric: Distance metric forwarded to ``NearestNeighbors``.

    Returns:
        tuple: ``(p1s, p2s, distances)``, three parallel lists holding the row
        index into ``arr1``, the row index into ``arr2`` and the distance
        between the paired points. All three are empty when either input has
        no rows or an effective k is below 1.
    """
    n_points1 = arr1.shape[0]
    n_points2 = arr2.shape[0]

    # scikit-learn raises when asked for more neighbours than fitted samples,
    # so never request more than the other set can provide.
    k1 = min(k1, n_points2)
    k2 = min(k2, n_points1)
    if k1 < 1 or k2 < 1:
        return [], [], []

    # Index each array so the other one can query it.
    index_on_arr1 = NearestNeighbors(n_neighbors=k2, metric=metric).fit(arr1)
    index_on_arr2 = NearestNeighbors(n_neighbors=k1, metric=metric).fit(arr2)

    # Neighbours in arr2 of every arr1 point, and vice versa.
    distances1, indices1 = index_on_arr2.kneighbors(arr1)
    _, indices2 = index_on_arr1.kneighbors(arr2)

    # Precompute, per arr2 point, the set of arr1 points it considers
    # neighbours, for O(1) mutuality checks in the loop below.
    reverse_neighbors = [set(row) for row in indices2.tolist()]

    p1s = []
    p2s = []
    distances = []

    for p1 in range(n_points1):
        for rank in range(k1):
            p2 = indices1[p1, rank]
            if p1 in reverse_neighbors[p2]:
                p1s.append(p1)
                p2s.append(p2)
                distances.append(distances1[p1, rank])

    return p1s, p2s, distances

In [None]:
import json
# Load the integration specification. Both its keys and the entries of its
# list values are '@'-separated strings whose second field names a dataset.
with open('../integration_specs.json', 'r') as spec_file:
    i_specs_d = json.load(spec_file)

# Every dataset that appears as an integration target (reference).
all_ref_datasets = {
    target.split('@')[1]
    for targets in i_specs_d.values()
    for target in targets
}

all_ref_datasets

In [None]:
import gc

input_path = '/GPUData_xingjie/SCMG/sc_rna_data'

# Create the output folder
output_path = '/GPUData_xingjie/SCMG/contrastive_embedding_training/edges/inter_dataset/'
os.makedirs(output_path, exist_ok=True)

# Generate the training datasets: for every reference dataset, pair it with
# each query dataset listed in the integration specs and store the mutual
# nearest-neighbour cell pairs as one parquet file per dataset pair.
for ds_ref in sorted(all_ref_datasets):
    ds_ref_name = ds_ref.replace(':', '_')

    # Define all the query datasets that still need to be processed
    query_datasets = set()
    for k, v in i_specs_d.items():
        ds_query = k.split('@')[1]

        # A dataset is never integrated with itself; excluding it here avoids
        # loading the reference file when nothing else is left to do.
        if ds_query == ds_ref:
            continue

        if not any(d.split('@')[1] == ds_ref for d in v):
            continue

        # Do not overwrite
        ds_query_name = ds_query.replace(':', '_')
        if os.path.exists(os.path.join(output_path,
                f'{ds_ref_name}_AND_{ds_query_name}.parquet')):
            continue

        query_datasets.add(ds_query)

    if len(query_datasets) == 0:
        continue

    adata_ref_all = sc.read_h5ad(
        os.path.join(input_path, f'standard_adata_{ds_ref_name}.h5ad'))

    # Pairwise integration
    for ds_query in sorted(query_datasets):
        ds_query_name = ds_query.replace(':', '_')

        print(f'\nIntegrate {ds_ref} and {ds_query}')

        adata_query_all = sc.read_h5ad(
            os.path.join(input_path, f'standard_adata_{ds_query_name}.h5ad'))

        # Define the integratable ref and query cell types:
        # query cell type -> list of reference cell types it may pair with
        query_ct_map = {}
        for k, v in i_specs_d.items():
            k_parts = k.split('@')
            if k_parts[1] != ds_query:
                continue
            for d in v:
                d_parts = d.split('@')
                if d_parts[1] == ds_ref:
                    query_ct_map.setdefault(k_parts[0], []).append(d_parts[0])

        selected_query_cts = np.array(list(query_ct_map.keys()))
        selected_ref_cts = np.array(list(set(v for l in query_ct_map.values()
                                    for v in l)))
        print(f'Selected query cell types: {selected_query_cts}')
        print(f'Selected ref cell types: {selected_ref_cts}')

        # Copies are required: get_merged_datasets re-indexes var in place.
        adata_ref = adata_ref_all[
            adata_ref_all.obs['cell_type'].isin(selected_ref_cts)].copy()
        adata_query = adata_query_all[
            adata_query_all.obs['cell_type'].isin(selected_query_cts)].copy()

        if adata_ref.shape[0] < 5 or adata_query.shape[0] < 5:
            continue

        # Merge the datasets
        adata_m_ref, adata_m_query = get_merged_datasets(adata_ref, adata_query)

        if adata_m_ref is None or adata_m_query is None:
            continue
        if adata_m_ref.shape[0] < 1 or adata_m_query.shape[0] < 1:
            continue

        print(f'Query dataset has {adata_m_query.shape[0]} cells.')
        print(f'Reference dataset has {adata_m_ref.shape[0]} cells.')

        # Integrated for two datasets
        # k1: neighbours searched in the reference for each query cell,
        # k2: neighbours searched in the query for each reference cell.
        k1 = min(10, adata_m_ref.shape[0])
        k2 = min(10, adata_m_query.shape[0])

        p1s, p2s, distances = mutual_k_nearest_neighbors(
                        adata_m_query.obsm['X_pca_norm'],
                        adata_m_ref.obsm['X_pca_norm'],
                        k1=k1, k2=k2, metric='euclidean')

        # p1s index query cells, p2s index reference cells
        mkn_df = pd.DataFrame({
            'cell_ref': list(adata_m_ref.obs.index[p2s]),
            'cell_query': list(adata_m_query.obs.index[p1s]),
            'dataset_ref': list(adata_m_ref.obs['dataset_id'].values[p2s]),
            'dataset_query': list(adata_m_query.obs['dataset_id'].values[p1s]),
            'cell_type_ref': list(adata_m_ref.obs['cell_type'].values[p2s]),
            'cell_type_query': list(adata_m_query.obs['cell_type'].values[p1s]),
        })

        # Filter out cell pairs that are not allowed.
        # The mask must be a boolean array: an empty Python list would be
        # interpreted by pandas as "select no columns" and the saved file
        # would lose its schema whenever no pairs were found.
        allowed_mask = np.array(
            [ct_ref in query_ct_map.get(ct_query, ())
             for ct_query, ct_ref in zip(mkn_df['cell_type_query'],
                                         mkn_df['cell_type_ref'])],
            dtype=bool)
        mkn_df = mkn_df.loc[allowed_mask]

        print(f'Found {mkn_df.shape[0]} mutual k-nearest neighbors.')
        mkn_df.to_parquet(os.path.join(output_path,
                        f'{ds_ref_name}_AND_{ds_query_name}.parquet'))

        gc.collect()
