In [None]:
import scanpy as sc
import numpy as np
import anndata as ad
import pandas as pd
from scipy.sparse import csr_matrix
from tqdm import tqdm 
from collections import defaultdict
import itertools
import argparse

import sys 
sys.path.insert(0, "..")

import os
import pickle as pkl
from cfp import preprocessing as cfpp

ModuleNotFoundError: No module named 'cfp'

In [1]:
ood_condition = 'K562_INS'
ood_pathway = ood_condition.split('_')[1]
ood_cell_type = ood_condition.split('_')[0]

ood_path = f'/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_pathway_cell_type/{ood_cell_type}_{ood_pathway}/adata_ood_{ood_cell_type}_{ood_pathway}.h5ad'
train_path = f'/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_pathway_cell_type/{ood_cell_type}_{ood_pathway}/adata_train_{ood_cell_type}_{ood_pathway}.h5ad'

ood = sc.read_h5ad(ood_path)
train = sc.read_h5ad(train_path)

ood_genes = ood.var_names

NameError: name 'sc' is not defined

In [None]:
ms = 0.5
# Pathway string use to parse .h5ad
pathway = 'IFNG_IFNB_TNFA_TGFB_INS'

# The final output dir
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_pathway_cell_type_full_source/" + ood_cell_type + "_" + ood_pathway 
os.makedirs(output_dir, exist_ok=True)

genes_from_paper = [
    "AHNAK", "RNF213", "APOL6", "ASTN2", "B2M", "CFH", "CXCL9", "DENND4A", 
    "DOCK9", "EFNA5", "ERAP2", "FAT1", "GBP1", "GBP4", "HAPLN3", "HSPG2", 
    "IDO1", "IFI6", "IRF1", "LAP3", "LI", "LINC02328", "MAGI1", "MUC4", 
    "NLRC5", "NUB1", "PARP14", "PARP9", "RARRES1", "RNF213", "ROR1", "SCN9A", 
    "SERPING1", "ST5", "STAT1", "TAP1", "TAP2", "THBS1", "THSD4", "TPM1", "VCL", 
    "WARS", "XRN1"
]

# Read the data 
datasets = []
for pw in pathway.split('_'):
    if ms == None:
        data_path = '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/' + pw + '_Perturb_seq.h5ad' # '_Perturb_seq_ms_0.5.h5ad'
    else:
        data_path = '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/' + pw + '_Perturb_seq_ms_' + str(ms) + '.h5ad'
    print('Loading dataset from ' + data_path)
    dataset = sc.read_h5ad(data_path)
    dataset.obs['pathway'] = pw
    datasets.append(dataset)

# Create common anndata 
adata = ad.concat(datasets, join='outer')
print('Datasets concatenated')

# Make the variable names unique
adata.obs_names_make_unique()

# Drop unused columns 
columns_to_drop = ['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'mixscale_score', 'RNA_snn_res.0.9', 'seurat_clusters']
adata.obs.drop(columns=columns_to_drop, inplace=True)
print('Unnecessary columns dropped')
            
# Add specific columns to adata.obs 
adata.obs['condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.pathway, x.gene]), axis=1)
adata.obs['background'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.pathway]), axis=1)

# Filter very rare perturbation classes 
condition_counts = adata.obs['condition'].value_counts()
filtered_conditions = condition_counts[condition_counts >= 100]  # Keep only some conditions
adata = adata[adata.obs['condition'].isin(filtered_conditions.index)]
print(f"Filtered adata for perturbation count: {adata.shape[0]} observations remaining")

adata.layers["counts"] = adata.layers["counts"].astype(np.float32)
adata.X = csr_matrix(adata.layers["counts"])
del adata.layers['counts']
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNG_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/IFNB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TNFA_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/TGFB_Perturb_seq_ms_0.5.h5ad
Loading dataset from /lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/satija_merged/INS_Perturb_seq_ms_0.5.h5ad


  utils.warn_names_duplicates("obs")


Datasets concatenated
Unnecessary columns dropped
Filtered adata for perturbation count: 618023 observations remaining


  adata.layers["counts"] = adata.layers["counts"].astype(np.float32)


In [None]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_pathway_cell_type_full_source/" + ood_cell_type + "_" + ood_pathway 


In [None]:
adata = adata[:, ood_genes]

In [None]:
adata.obs['is_ood'] = adata.obs.apply(lambda x: x['pathway'] == ood_pathway and x['cell_type'] == ood_cell_type, axis=1)

  adata.obs['is_ood'] = adata.obs.apply(lambda x: x['pathway'] == ood_pathway and x['cell_type'] == ood_cell_type, axis=1)


In [None]:
adata_ood = adata[adata.obs['is_ood'],:]

In [None]:
adata_ood.uns = train.uns
adata_ood.varm['X_mean'] = train.varm['X_mean']

adata_train_final_mean = train.varm["X_mean"].flatten()
adata_ood.layers["centered_X"] = csr_matrix(adata_ood.X.toarray() - adata_train_final_mean)
adata_ood.obsm["X_pca"] = np.matmul(adata_ood.layers["centered_X"].toarray(), train.varm["PCs"])
adata_ood.obs['control'] = adata_ood.obs.apply(lambda x: x['gene'] == 'NT', axis=1)

In [None]:
adata_ood.uns['gene_emb'] = train.uns['gene_emb']
adata_ood.uns['cell_type_emb'] = train.uns['cell_type_emb']
adata_ood.uns['pathway_emb'] = train.uns['pathway_emb']

In [None]:
output_dir

'/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/adata_ood_final_pathway_cell_type_full_source/HT29_TNFA'

In [None]:
adata_ood.write(os.path.join(output_dir, "adata_ood_" + ood_condition + ".h5ad"))