In [None]:
# Have the inline matplotlib backend emit figures in the 'retina' format.
%config InlineBackend.figure_format='retina'

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
# Set scanpy's `n_jobs` setting to 64.
# NOTE(review): the value is hard-coded — confirm it suits the host this runs on.
sc.settings.n_jobs = 64

In [None]:
def ct_aware_downsample(adata, sample_frac, ct_min_cells,
                        ct_ignore_threshold=10, random_state=None):
    """Downsample cells per cell type while protecting small populations.

    For every value of ``adata.obs['cell_type']`` (visited from the most to
    the least frequent, as ordered by ``value_counts``):

    * types with fewer than ``ct_ignore_threshold`` cells are dropped;
    * otherwise ``max(ct_min_cells, int(n_cells * sample_frac))`` cells are
      drawn without replacement; when the type has no more cells than that
      target, all of its cells are kept.

    Parameters
    ----------
    adata
        Object exposing ``obs`` (a DataFrame with a ``'cell_type'`` column),
        indexing by a list of obs names, and ``copy()``.
    sample_frac : float
        Fraction of each cell type to keep.
    ct_min_cells : int
        Lower bound on the number of cells kept per retained cell type.
    ct_ignore_threshold : int, default 10
        Cell types with fewer cells than this are discarded entirely.
    random_state : int or None, default None
        Seed for a dedicated ``np.random.RandomState``. With ``None`` the
        global NumPy random state is used, as before this parameter existed.

    Returns
    -------
    A copy of ``adata`` restricted to the selected cells, grouped by cell type.
    """
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    obs = adata.obs
    cell_ids = obs.index.to_numpy()
    # One pass over the labels instead of one boolean mask per cell type.
    positions_by_type = obs.groupby('cell_type', observed=True).indices

    selected_cells = []
    for ct in obs['cell_type'].value_counts().index:
        positions = positions_by_type.get(ct)
        # `positions` is None for categories that have no cells at all.
        if positions is None or len(positions) < ct_ignore_threshold:
            continue

        ct_cell_ids = cell_ids[positions]
        n_cells_to_keep = max(ct_min_cells, int(len(ct_cell_ids) * sample_frac))

        if len(ct_cell_ids) > n_cells_to_keep:
            ct_cell_ids = rng.choice(ct_cell_ids, n_cells_to_keep, replace=False)

        selected_cells.extend(ct_cell_ids.tolist())

    return adata[selected_cells].copy()

In [None]:
from sklearn.neighbors import NearestNeighbors

def find_neighbor_edges(adata_merge, n_neighbors=8, use_rep='X_pca_integrate',
                        random_state=None):
    """Pair every cell with one randomly chosen nearest neighbour.

    A nearest-neighbour index is fitted on ``adata_merge.obsm[use_rep]`` and,
    for each cell (the anchor), one of its ``n_neighbors`` nearest *other*
    cells is drawn uniformly at random (the positive).

    Parameters
    ----------
    adata_merge
        Object exposing ``obsm[use_rep]`` (a cells x features array) and
        ``obs.index`` (the cell identifiers, in the same row order).
    n_neighbors : int, default 8
        Number of candidate neighbours per cell, the cell itself excluded.
    use_rep : str, default 'X_pca_integrate'
        Key in ``obsm`` holding the embedding to search in.
    random_state : int or None, default None
        Seed for a dedicated ``np.random.RandomState``. With ``None`` the
        global NumPy random state is used.

    Returns
    -------
    (anchor_cells, positive_cells)
        Two arrays of obs names of equal length, one entry per cell.

    Raises
    ------
    ValueError
        If there are not more cells than ``n_neighbors``.
    """
    embedding = adata_merge.obsm[use_rep]
    n_cells = embedding.shape[0]
    if n_cells <= n_neighbors:
        raise ValueError(
            f'Need more than {n_neighbors} cells to find {n_neighbors} '
            f'neighbors per cell, got {n_cells}.')

    rng = np.random if random_state is None else np.random.RandomState(random_state)

    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embedding)
    # Querying without X returns the neighbours of the fitted points and never
    # reports a point as its own neighbour. Dropping column 0 of a regular
    # query is not safe: with duplicated embeddings the cell itself is not
    # guaranteed to come first.
    neighbor_indices = nbrs.kneighbors(return_distance=False)

    # Pick one neighbour column per row in a single vectorised draw.
    anchor_indices = np.arange(n_cells)
    picked_columns = rng.randint(0, neighbor_indices.shape[1], size=n_cells)
    positive_indices = neighbor_indices[anchor_indices, picked_columns]

    cell_ids = adata_merge.obs.index.values
    anchor_cells = cell_ids[anchor_indices]
    positive_cells = cell_ids[positive_indices]
    return anchor_cells, positive_cells

In [None]:
# Get the input files
adata_input_path = '/GPUData_xingjie/SCMG/sc_rna_data/'

# Input files are read later as `standard_adata_<name>.h5ad`, so a dataset name
# is the file name with that prefix and suffix removed. Entries that do not
# match the pattern are ignored.
adata_file_prefix = 'standard_adata_'
adata_file_suffix = '.h5ad'
dataset_names = sorted(
    f[len(adata_file_prefix):-len(adata_file_suffix)]
    for f in os.listdir(adata_input_path)
    if f.startswith(adata_file_prefix) and f.endswith(adata_file_suffix)
)

# Human gene ids of the standard gene table
standard_gene_df = pd.read_csv(
    '/GPUData_xingjie/Softwares/SCMG_dev/scmg/data/standard_genes.csv')
standard_ids = list(standard_gene_df['human_id'])

# Create the output folder
output_path = '/GPUData_xingjie/SCMG/contrastive_embedding_training/edges/intra_dataset'
os.makedirs(output_path, exist_ok=True)

In [None]:
# Show the raw content of the input folder; sorted so the listing is stable
# between runs (os.listdir returns entries in arbitrary order).
sorted(os.listdir(adata_input_path))

In [None]:
# Explicit list of the datasets to process. This assignment replaces any
# `dataset_names` value computed earlier in the notebook.
# Each entry is a file stem without the `standard_adata_` prefix and the
# `.h5ad` extension; the processing loop adds both when building the path.
dataset_names = [
    'AllenBrain_WB_MM_2023_all',
    'Allen_BrainAging_MM_2022_all',
    'Arutyunyan_Placenta_HS_2023_all',
    'Bhaduri_CtxDev_HS_2021_all',
    'Cao_dev_HS_2020_all',
    'Conde_Immune_HS_2022_all',
    'Cowan_Retina_HS_2020_fovea',
    'Cowan_Retina_HS_2020_organoid',
    'Cowan_Retina_HS_2020_periphery',
    'Deng_CarT_HS_2020_all',
    'Elmentaite_intestine_HS_2021_all',
    'Enge_Pancrea_HS_2017_all',
    'Eraslan_MultiTissue_HS_2022_all',
    'Fawkner-Corbett_IntestineDev_HS_2021_all',
    'Han_HS_2020_all',
    'He_LungDev_HS_2022_all',
    'Hrovatin_Pancrea_MM_2022_all',
    'Jardine_BloodDev_HS_2021_normal',
    'Khaled_Breast_HS_2023_all',
    'Kuppe_Heart_HS_2022_all',
    'LaManno_WBDev_MM_2021_all',
    'Lake_Kidney_HS_2023_all',
    'Lengyel_FallopianTube_HS_2022_all',
    'Litvinukova_Heart_HS_2020_all',
    'Park_Thymus_HS_2020_all',
    'Qiu_Organogenesis_MM_2022_all',
    'Qiu_whole_embryo_dev_MM_2024_all',
    'Sikkema_Lung_HS_2023_core',
    'Streets_Adipose_HS_2023_all',
    'Suo_ImmuneDev_HS_2022_all',
    'Tabula_Muris_MM_2020_10x',
    'Tabula_Muris_MM_2020_smart-seq',
    'Tabula_Sapiens_HS_2022_all',
    'Tyser_Embryo_HS_2021_all',
    'VentoTormo_Placenta_HS_2018_all',
    'Wiedemann_Skin_HS_2023_all',
    'Xu_HS_early_organogenesis_2023_all',
    'Yanagida_Blastocyst_HS_2021_all',
    'Yu_MultiTissue_HS_2021_all'
]

In [None]:
# Seed applied to NumPy's global random state at the start of every dataset,
# so the NumPy-based sampling in ct_aware_downsample and find_neighbor_edges
# is reproducible and does not depend on the order datasets are processed in.
RANDOM_SEED = 0

# For each dataset: downsample per cell type, embed with PCA on highly
# variable genes, plot a UMAP for inspection, then write one random
# nearest-neighbour edge per cell to `<output_path>/<ds_name>.parquet`.
for ds_name in dataset_names:
    print(ds_name)
    np.random.seed(RANDOM_SEED)

    adata = sc.read_h5ad(os.path.join(adata_input_path, f'standard_adata_{ds_name}.h5ad'))

    # Downsample the dataset
    adata = ct_aware_downsample(adata, sample_frac=0.1, ct_min_cells=100)

    if adata.shape[0] < 200:
        # Too few cells left to build a useful neighbour graph; no file is written.
        print(f'  skipped: only {adata.shape[0]} cells after downsampling')
        continue

    # Index genes by their human gene id
    adata.var.index = list(adata.var['human_gene_id'])
    adata.var_names_make_unique()
    # NOTE(review): adata_raw is not used below in this cell; kept in case a
    # later cell inspects the un-normalised data.
    adata_raw = adata.copy()
    display(adata)

    # Dimension reduction
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    adata = adata[:, adata.var['highly_variable']].copy()
    sc.pp.filter_cells(adata, min_genes=20)

    sc.pp.scale(adata, max_value=10)

    sc.tl.pca(adata, svd_solver='arpack')
    # find_neighbor_edges reads the embedding from this key
    adata.obsm['X_pca_integrate'] = adata.obsm['X_pca']

    # Visualize the dataset (on at most 10000 cells)
    if adata.shape[0] > 10000:
        adata_display = sc.pp.subsample(adata, n_obs=10000, copy=True)
    else:
        adata_display = adata.copy()

    sc.pp.neighbors(adata_display, use_rep='X_pca_integrate')
    sc.tl.umap(adata_display)
    sc.pl.umap(adata_display, color='cell_type', legend_loc='on data', palette='tab20',
               legend_fontsize=5)

    # Generate the cell pair dataset
    anchor_cells, positive_cells = find_neighbor_edges(adata)

    # Look up the metadata once per side instead of once per column
    anchor_obs = adata[anchor_cells].obs
    positive_obs = adata[positive_cells].obs

    edges_df = pd.DataFrame({
        'cell_ref': anchor_cells,
        'cell_query': positive_cells,
        'dataset_ref': anchor_obs['dataset_id'].values,
        'dataset_query': positive_obs['dataset_id'].values,
        'cell_type_ref': anchor_obs['cell_type'].values,
        'cell_type_query': positive_obs['cell_type'].values
    })

    edges_df.to_parquet(os.path.join(output_path, f'{ds_name}.parquet'))