In [None]:
%config InlineBackend.figure_format='retina'

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
sc.settings.n_jobs = 64

In [None]:
import scipy.sparse

def get_standard_exp(adata, standard_ids):
    """Project an expression matrix onto the standard gene order.

    Parameters
    ----------
    adata
        Object exposing ``X`` (dense array or SciPy sparse matrix with one
        column per gene), ``var.index`` (the gene ID of each column of ``X``)
        and ``shape``.
    standard_ids
        Ordered gene IDs that define the output columns. Any ordered iterable
        is accepted (list, tuple, NumPy array, pandas Index).

    Returns
    -------
    X_input : numpy.ndarray
        float32 array of shape ``(adata.shape[0], len(standard_ids))`` holding
        the expression values in standard column order. Columns of genes that
        are absent from ``adata`` stay zero.
    X_measure : numpy.ndarray
        bool array of the same shape, True in every column whose gene is
        present in ``adata``.
    """
    standard_ids = list(standard_ids)
    n_cells = adata.shape[0]

    # Map each gene ID in adata to its FIRST column, which is what the
    # previous list.index() based lookup returned for duplicated IDs.
    adata_col_of = {}
    for col, gene in enumerate(adata.var.index):
        adata_col_of.setdefault(gene, col)

    # Pair up standard columns with adata columns using O(1) lookups.
    # Only the first occurrence of a duplicated standard ID gets filled.
    standard_cols = []
    adata_cols = []
    seen = set()
    for col, gene in enumerate(standard_ids):
        if gene in adata_col_of and gene not in seen:
            seen.add(gene)
            standard_cols.append(col)
            adata_cols.append(adata_col_of[gene])
    standard_cols = np.asarray(standard_cols, dtype=np.intp)
    adata_cols = np.asarray(adata_cols, dtype=np.intp)

    X_input = np.zeros((n_cells, len(standard_ids)), dtype=np.float32)
    # Record which genes are measured in this dataset
    X_measure = np.zeros((n_cells, len(standard_ids)), dtype=bool)

    if adata_cols.size:
        X = adata.X
        if scipy.sparse.issparse(X):
            # Select the needed columns first so only those are densified.
            X_common = X.tocsr()[:, adata_cols].toarray()
        else:
            # Fancy indexing already returns a new array; no extra copy needed.
            X_common = np.asarray(X)[:, adata_cols]

        X_input[:, standard_cols] = X_common
        X_measure[:, standard_cols] = True

    return X_input, X_measure

In [None]:
from datasets import Dataset

def generate_dataset(dataset_path_prefix, adata_raw, standard_ids, chunk_size=50000):
    """Normalize a dataset and save it to disk in fixed-size chunks of cells.

    The genes of ``adata_raw`` are restricted to those whose
    ``var['human_gene_id']`` appears in ``standard_ids``. The counts are
    normalized to a total of 1e4 per cell and log1p-transformed. Each chunk
    is saved as a ``datasets.Dataset`` at ``f'{dataset_path_prefix}_{i}'``,
    replacing any directory that already exists at that path.

    Parameters
    ----------
    dataset_path_prefix : str
        Path prefix of the output directories; the chunk index is appended.
    adata_raw
        AnnData object with a ``var['human_gene_id']`` column and an
        ``obs['dataset_id']`` column. It is not modified; the work is done on
        a subset copy.
    standard_ids
        Ordered gene IDs that define the columns of the saved matrices.
    chunk_size : int, optional
        Maximum number of cells per saved chunk (default 50000).

    Raises
    ------
    ValueError
        If ``chunk_size`` is not a positive integer.
    """
    chunk_size = int(chunk_size)
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be a positive integer, got {chunk_size}')

    # Subset to common genes
    common_genes = np.intersect1d(standard_ids, adata_raw.var['human_gene_id'])
    adata = adata_raw[:, adata_raw.var['human_gene_id'].isin(common_genes)].copy()
    adata.var.index = adata.var['human_gene_id']

    # Normalize the expression
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Save the data by chunk
    N_cells = adata.shape[0]

    for i, start in enumerate(range(0, N_cells, chunk_size)):
        stop = min(start + chunk_size, N_cells)

        adata_local = adata[start:stop]
        X_exp, X_measure = get_standard_exp(adata_local, standard_ids)

        # Save the dataset; the measurement mask is stored as float32 (0/1)
        dataset = Dataset.from_dict({
            'X_exp': X_exp.astype(np.float32),
            'X_measure': X_measure.astype(np.float32),
            'cell_id': adata_local.obs.index.values,
            'dataset_id': adata_local.obs['dataset_id'].values
        })

        # Remove output left over from a previous run before writing
        save_path = f'{dataset_path_prefix}_{i}'
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        dataset.save_to_disk(save_path)

In [None]:
# Get the input files
adata_input_path = '/GPUData_xingjie/SCMG/sc_rna_data/'

# Input files are named 'standard_adata_<dataset name>.h5ad'. Keep only those
# files and strip both the prefix and the extension, so that the names can be
# used to rebuild the file names when the datasets are loaded.
INPUT_PREFIX = 'standard_adata_'
INPUT_SUFFIX = '.h5ad'
dataset_names = sorted(
    f[len(INPUT_PREFIX):-len(INPUT_SUFFIX)]
    for f in os.listdir(adata_input_path)
    if f.startswith(INPUT_PREFIX) and f.endswith(INPUT_SUFFIX)
)

# Ordered list of standard gene IDs that defines the output columns
standard_gene_df = pd.read_csv(
    '/GPUData_xingjie/Softwares/SCMG_dev/scmg/data/standard_genes.csv')
standard_ids = list(standard_gene_df['human_id'])

# Create the output folder
output_path = '/GPUData_xingjie/SCMG/contrastive_embedding_training/exp_data/datasets'
os.makedirs(output_path, exist_ok=True)

In [None]:
# Display the raw contents of the input directory
os.listdir(adata_input_path)

In [None]:
# Explicit list of the datasets to process. This reassigns `dataset_names`,
# replacing the value derived from the directory listing in the earlier cell.
# Each name is used to build both the input file name and the output prefix.
dataset_names = [
    'AllenBrain_WB_MM_2023_all',
    'Allen_BrainAging_MM_2022_all',
    'Arutyunyan_Placenta_HS_2023_all',
    'Bhaduri_CtxDev_HS_2021_all',
    'Cao_dev_HS_2020_all',
    'Conde_Immune_HS_2022_all',
    'Cowan_Retina_HS_2020_fovea',
    'Cowan_Retina_HS_2020_organoid',
    'Cowan_Retina_HS_2020_periphery',
    'Deng_CarT_HS_2020_all',
    'Elmentaite_intestine_HS_2021_all',
    'Enge_Pancrea_HS_2017_all',
    'Eraslan_MultiTissue_HS_2022_all',
    'Fawkner-Corbett_IntestineDev_HS_2021_all',
    'Han_HS_2020_all',
    'He_LungDev_HS_2022_all',
    'Hrovatin_Pancrea_MM_2022_all',
    'Jardine_BloodDev_HS_2021_normal',
    'Khaled_Breast_HS_2023_all',
    'Kuppe_Heart_HS_2022_all',
    'LaManno_WBDev_MM_2021_all',
    'Lake_Kidney_HS_2023_all',
    'Lengyel_FallopianTube_HS_2022_all',
    'Litvinukova_Heart_HS_2020_all',
    'Park_Thymus_HS_2020_all',
    'Qiu_Organogenesis_MM_2022_all',
    'Qiu_whole_embryo_dev_MM_2024_all',
    'Sikkema_Lung_HS_2023_core',
    'Streets_Adipose_HS_2023_all',
    'Suo_ImmuneDev_HS_2022_all',
    'Tabula_Muris_MM_2020_10x',
    'Tabula_Muris_MM_2020_smart-seq',
    'Tabula_Sapiens_HS_2022_all',
    'Tyser_Embryo_HS_2021_all',
    'VentoTormo_Placenta_HS_2018_all',
    'Wiedemann_Skin_HS_2023_all',
    'Xu_HS_early_organogenesis_2023_all',
    'Yanagida_Blastocyst_HS_2021_all',
    'Yu_MultiTissue_HS_2021_all'
]

In [None]:
# Convert every listed dataset into chunked training datasets on disk
for ds_name in dataset_names:
    print(ds_name)

    adata = sc.read_h5ad(os.path.join(adata_input_path, f'standard_adata_{ds_name}.h5ad'))

    # Index the genes by their human gene ID and de-duplicate the index
    adata.var.index = list(adata.var['human_gene_id'])
    adata.var_names_make_unique()

    # generate_dataset() subsets and copies the data before normalizing it, so
    # the loaded object is passed as-is instead of duplicating the whole
    # dataset in memory.
    adata_raw = adata
    display(adata)

    output_prefix = os.path.join(output_path, ds_name)
    generate_dataset(output_prefix, adata_raw, standard_ids)