In [None]:
import os

from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import scipy.sparse

import anndata
import scanpy as sc
import datasets
from datasets import Dataset

import torch
from scmg.model.contrastive_embedding import (CellEmbedder, 
                         embed_standardized_adata)

In [None]:
# NOTE(review): these two imports duplicate the first cell; harmless, they only
# let this cell run on its own.
import torch
from datasets import Dataset
# Let float32 matmuls use faster reduced-precision internal math where the
# hardware supports it (e.g. TF32), per the PyTorch docs for
# set_float32_matmul_precision('high'); presumably to speed up the embedding below.
torch.set_float32_matmul_precision('high')

In [None]:
# Load the trained contrastive embedder: the pickled module stored in model.pt,
# then overwrite its weights with the checkpoint in best_state_dict.pth.
model_ce_path = '../contrastive_embedding/trained_embedder/'
device = 'cuda:0'

model_file = os.path.join(model_ce_path, 'model.pt')
state_file = os.path.join(model_ce_path, 'best_state_dict.pth')

# NOTE(review): torch.load unpickles model.pt, which can execute arbitrary code,
# so only load trusted files. On recent torch releases the default
# weights_only=True may refuse a pickled full module; pass weights_only=False
# explicitly there if this load fails.
model_ce = torch.load(model_file)
model_ce.load_state_dict(torch.load(state_file))

# Move to the GPU and switch to evaluation mode (the module is this cell's output).
model_ce.to(device)
model_ce.eval()

In [None]:
# Get the input files: one AnnData (.h5ad) file per dataset. Only .h5ad files
# are listed (other files/folders in the directory are not datasets), and the
# suffix is stripped from the end of the name only -- str.replace would also
# drop '.h5ad' occurring elsewhere in a name.
adata_input_path = '/GPUData_xingjie/SCMG/sc_rna_data/'
dataset_names = sorted(
    f[:-len('.h5ad')]
    for f in os.listdir(adata_input_path)
    if f.endswith('.h5ad')
)

# Standard (human) gene ids; their order defines the columns of the
# standardized expression matrices built later in this notebook.
standard_gene_df = pd.read_csv(
    '/GPUData_xingjie/Softwares/SCMG_dev/scmg/data/standard_genes.csv')
standard_ids = list(standard_gene_df['human_id'])

# Create the output folder
output_path = '/GPUData_xingjie/SCMG/manifold_generator_training'
os.makedirs(os.path.join(output_path, 'datasets'), exist_ok=True)

In [None]:
# Show the dataset names discovered on disk (the next cell replaces them with a curated list).
dataset_names

In [None]:
# Curated datasets to export. This overrides the list discovered on disk above.
# In this notebook it is only used as a membership filter in the export loop,
# which iterates over the datasets that have embedder-training cell types, so a
# name listed here is processed only if it also appears there.
dataset_names = [
    'standard_adata_AllenBrain_WB_MM_2023_all',
    'standard_adata_Allen_BrainAging_MM_2022_all',
    'standard_adata_Arutyunyan_Placenta_HS_2023_all',
    'standard_adata_Bhaduri_CtxDev_HS_2021_all',
    'standard_adata_Cao_dev_HS_2020_all',
    'standard_adata_Conde_Immune_HS_2022_all',
    'standard_adata_Cowan_Retina_HS_2020_fovea',
    'standard_adata_Cowan_Retina_HS_2020_periphery',
    'standard_adata_Deng_CarT_HS_2020_all',
    'standard_adata_Elmentaite_intestine_HS_2021_all',
    'standard_adata_Enge_Pancrea_HS_2017_all',
    'standard_adata_Eraslan_MultiTissue_HS_2022_all',
    'standard_adata_Fawkner-Corbett_IntestineDev_HS_2021_all',
    'standard_adata_Han_HS_2020_all',
    'standard_adata_He_LungDev_HS_2022_all',
    'standard_adata_Hrovatin_Pancrea_MM_2022_all',
    'standard_adata_Jardine_BloodDev_HS_2021_normal',
    'standard_adata_Khaled_Breast_HS_2023_all',
    'standard_adata_Kuppe_Heart_HS_2022_all',
    'standard_adata_LaManno_WBDev_MM_2021_all',
    'standard_adata_Lake_Kidney_HS_2023_all',
    'standard_adata_Lengyel_FallopianTube_HS_2022_all',
    'standard_adata_Litvinukova_Heart_HS_2020_all',
    'standard_adata_Park_Thymus_HS_2020_all',
    'standard_adata_Qiu_Organogenesis_MM_2022_all',
    'standard_adata_Qiu_whole_embryo_dev_MM_2024_all',
    'standard_adata_Sikkema_Lung_HS_2023_core',
    'standard_adata_Streets_Adipose_HS_2023_all',
    'standard_adata_Suo_ImmuneDev_HS_2022_all',
    'standard_adata_Tabula_Muris_MM_2020_10x',
    'standard_adata_Tabula_Muris_MM_2020_smart-seq',
    'standard_adata_Tabula_Sapiens_HS_2022_all',
    'standard_adata_Tyser_Embryo_HS_2021_all',
    'standard_adata_VentoTormo_Placenta_HS_2018_all',
    'standard_adata_Wiedemann_Skin_HS_2023_all',
    'standard_adata_Yanagida_Blastocyst_HS_2021_all',
    'standard_adata_Yu_MultiTissue_HS_2021_all'
 ]

In [None]:
# Get the cell types considered for the embedder training: every
# (dataset, cell type) pair that occurs on either side (query or reference) of
# an edge used to train the contrastive embedder. Writes
# dataset_cell_types.json (dataset -> sorted cell types) and cell_types.csv
# (union of all cell types).
import json

edge_file_paths = [
    '/GPUData_xingjie/SCMG/contrastive_embedding_training/edges/inter_dataset',
    '/GPUData_xingjie/SCMG/contrastive_embedding_training/edges/intra_dataset_core',
    '/GPUData_xingjie/SCMG/contrastive_embedding_training/edges/intra_integration'
]

# sorted() makes the concatenation order independent of os.listdir order.
edge_df_list = [
    pd.read_parquet(os.path.join(edge_file_path, f))
    for edge_file_path in edge_file_paths
    for f in sorted(os.listdir(edge_file_path))
    if f.endswith('.parquet')
]
if not edge_df_list:
    raise FileNotFoundError(f'No .parquet edge files found under {edge_file_paths}')
edge_df = pd.concat(edge_df_list, axis=0)

# Stack query-side and reference-side pairs into one long (dataset, cell_type)
# table. Keying the dict only on dataset_query would raise KeyError for any
# dataset that appears solely as a reference; the vectorized groupby also
# replaces a slow row-by-row iterrows() loop.
query_pairs = edge_df[['dataset_query', 'cell_type_query']].rename(
    columns={'dataset_query': 'dataset', 'cell_type_query': 'cell_type'})
ref_pairs = edge_df[['dataset_ref', 'cell_type_ref']].rename(
    columns={'dataset_ref': 'dataset', 'cell_type_ref': 'cell_type'})
ds_ct_pairs = pd.concat([query_pairs, ref_pairs], axis=0, ignore_index=True)

# dataset -> sorted list of its unique cell types; keys come out sorted.
# observed=True avoids empty groups if the dataset column is categorical.
ds_ct_dict = {
    ds: sorted(set(cts))
    for ds, cts in ds_ct_pairs.groupby('dataset', sort=True, observed=True)['cell_type']
}
with open('dataset_cell_types.json', 'w') as f:
    json.dump(ds_ct_dict, f)

# Union of the cell types of all datasets.
all_cell_types = set().union(*ds_ct_dict.values())

pd.DataFrame({
    'cell_type' : sorted(all_cell_types)
}).to_csv('cell_types.csv', index=False)

In [None]:
# Export each selected dataset as chunked Hugging Face `datasets` on disk,
# holding the contrastive-embedder latent ('X_ce_latent') and the cell type of
# every kept cell.
chunk_size = 50000

# Column of each standard gene id. setdefault keeps the first occurrence, which
# matches what standard_ids.index() returns, but with O(1) lookups instead of
# an O(n_genes) scan per gene.
standard_id_pos = {}
for pos, gene_id in enumerate(standard_ids):
    standard_id_pos.setdefault(gene_id, pos)

for ds in ds_ct_dict:
    dataset_name = f'standard_adata_{ds}'.replace(':', '_')
    if dataset_name not in dataset_names:
        continue

    print(dataset_name)
    adata = sc.read_h5ad(os.path.join(adata_input_path, f'{dataset_name}.h5ad'))
    display(adata)

    # Keep only the cell types considered for the embedder training
    adata = adata[adata.obs['cell_type'].isin(ds_ct_dict[ds])]

    # Subset to the standard genes and index genes by their human gene id
    common_genes = np.intersect1d(standard_ids, adata.var['human_gene_id'])
    adata = adata[:, adata.var['human_gene_id'].isin(common_genes)].copy()
    adata.var.index = adata.var['human_gene_id']
    sc.pp.filter_cells(adata, min_genes=100)
    display(adata)

    # Map every shared gene to its column in the standard space and in adata.
    # Dict lookups replace quadratic list.index() scans; first occurrence wins,
    # exactly as with list.index().
    adata_var_ids = list(adata.var.index)
    adata_id_pos = {}
    for pos, gene_id in enumerate(adata_var_ids):
        adata_id_pos.setdefault(gene_id, pos)
    common_ids = np.intersect1d(standard_ids, adata_var_ids)
    common_in_standard_indices = [standard_id_pos[g] for g in common_ids]
    common_in_adata_indices = [adata_id_pos[g] for g in common_ids]

    # NOTE(review): which standard genes were measured in this dataset
    # (common_in_standard_indices) is not saved with the chunks; add a mask
    # column to the Dataset below if downstream training needs it.

    # Save the data by chunk
    N_cells = adata.shape[0]
    N_chunks = int(np.ceil(N_cells / chunk_size))

    for i in tqdm(range(N_chunks), desc=dataset_name):
        start = i * chunk_size
        stop = min((i + 1) * chunk_size, N_cells)

        # Materialize the chunk's cells
        adata_chunk = adata[start:stop].copy()

        # Dense expressions of the chunk
        if scipy.sparse.issparse(adata_chunk.X):
            X_chunk = adata_chunk.X.toarray()
        else:
            X_chunk = adata_chunk.X

        # Scatter into the standard gene space; standard genes absent from
        # this dataset stay 0.
        X_chunk_standard = np.zeros((X_chunk.shape[0], len(standard_ids)), dtype=np.float32)
        X_chunk_standard[:, common_in_standard_indices] = X_chunk[:, common_in_adata_indices]

        adata_chunk_standard = anndata.AnnData(X=X_chunk_standard,
                                               obs=adata_chunk.obs.copy(),
                                               var=pd.DataFrame(index=standard_ids))

        # Embed the standardized expressions (result is read from
        # obsm['X_ce_latent'] below)
        embed_standardized_adata(model_ce, adata_chunk_standard, inplace=True)

        # Save the dataset
        dataset = Dataset.from_dict({
            'X_ce_latent': adata_chunk_standard.obsm['X_ce_latent'],
            'cell_type': adata_chunk.obs['cell_type'].values,
        })
        dataset.save_to_disk(os.path.join(output_path, 'datasets', f'{dataset_name}_{i}'))
