In [1]:
import numpy as np
import pandas as pd
import pertpy
import scanpy as sc
from rdkit import Chem
import anndata
from cfp import preprocessing as cfpp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Record the pertpy version that produced the outputs in this notebook.
print(pertpy.__version__)

0.9.4


In [3]:
# Load the table of cell-line embeddings.
# Later cells select rows via the "stripped_cell_line_name" column and treat every
# column after the first one as the embedding vector.
# NOTE(review): absolute, user-specific cluster path; the notebook only runs where it is mounted.
df_cell_line = pd.read_csv(
    "/lustre/groups/ml01/workspace/alejandro.tejada/super_rad_project/cell_line_embeddings/cell_line_embedding_full_ccle_300.csv"
)

In [4]:
# For each of the three cell lines, take the first matching row of the embedding table
# and keep every column after the first one as that cell line's embedding vector.
mcf7_emb, k562_emb, a549_emb = (
    df_cell_line.loc[df_cell_line["stripped_cell_line_name"] == line_name].iloc[0, 1:].to_numpy()
    for line_name in ("MCF7", "K562", "A549")
)

In [5]:
# Load the sci-Plex 3 dataset through pertpy's `sciplex3_raw` loader.
adata = pertpy.data.sciplex3_raw()

In [6]:
# Number of cells per cell line.
adata.obs["cell_type"].value_counts()

cell_type
MCF7    292010
K562    146752
A549    143015
Name: count, dtype: int64

In [7]:
# Derive the obs columns used by the rest of the notebook. Everything is vectorised:
# the previous version ran a Python lambda once per cell for three of these columns.
adata.obs["perturbation"] = adata.obs["product_name"]
# Drug identifier without spaces, so it can be embedded in the "_"-joined condition label.
adata.obs["drug"] = adata.obs["product_name"].astype(str).str.replace(" ", "_", regex=False)
adata.obs["cell_line"] = adata.obs["cell_type"]
# log10(dose) for dosed cells. Doses that are not > 0 are replaced by 1.0 before the log,
# which yields the same 0.0 the original conditional produced for them.
adata.obs["logdose"] = np.log10(adata.obs["dose"].where(adata.obs["dose"] > 0.0, 1.0))
# Condition label "<cell_type>_<drug>_<dose>", e.g. "MCF7_Vehicle_0.0".
adata.obs["condition"] = (
    adata.obs["cell_type"].astype(str) + "_" + adata.obs["drug"].astype(str) + "_" + adata.obs["dose"].astype(str)
)

In [8]:
# Boolean Series indexed by condition: True where the condition has at least 100 cells.
# `size()` counts rows per group directly; the former `groupby(...).apply(lambda x: len(x) >= 100)`
# produced the same Series but emitted a pandas warning and called Python once per group.
conditions_to_keep = adata.obs.groupby("condition").size() >= 100

  conditions_to_keep = adata.obs.groupby("condition").apply(lambda x: len(x) >= 100)


In [9]:
# Names of the conditions that passed the minimum-cell-count filter.
conds_to_keep = conditions_to_keep[conditions_to_keep].index.tolist()

In [10]:
# Keep only cells from sufficiently populated conditions.
# `.copy()` materialises the subset: the next steps normalise the data in place, and doing
# that on a view makes anndata copy implicitly and emit a warning.
adata = adata[adata.obs["condition"].isin(conds_to_keep)].copy()

In [11]:
# Number of cells left after dropping conditions with fewer than 100 cells.
adata.n_obs

572853

In [12]:
# Per-cell total-count normalisation followed by log(1 + x).
# No return value is captured, so both calls are relied on to modify `adata` in place.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  view_to_actual(adata)


In [13]:
# Flag highly variable genes in `adata.var["highly_variable"]`, requesting the top 2000.
# NOTE(review): the AnnData printed a few cells below reports 2001 variables after
# subsetting on this flag — confirm that the extra gene is expected.
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)

In [14]:
# Restrict the variables to the genes flagged as highly variable.
adata = adata[:, adata.var["highly_variable"]]

In [15]:
# Drop cells without a drug label.
# `.copy()` turns the (nested) view into a real AnnData: the following `filter_cells`
# call writes into `obs`, which on a view triggers an implicit copy and a warning.
adata = adata[adata.obs["drug"].notna()].copy()

In [16]:
# Remove cells with fewer than 10 expressed genes among the retained variables.
# Per the AnnData printed in the next cell, this also adds an "n_genes" column to obs.
sc.pp.filter_cells(adata, min_genes=10)

  adata.obs["n_genes"] = number


In [17]:
# Annotate compounds with pertpy's Compound metadata. Per the AnnData printed below,
# obs gains "pubchem_name", "pubchem_ID" and "smiles".
# NOTE(review): the return value is only displayed, not assigned; later cells read
# adata.obs["smiles"], so this relies on `adata` being modified in place.
pertpy.md.Compound().annotate_compounds(adata)

[94m•[0m There are 189 identifiers in `adata.obs`.However, 6 identifiers can't be found in the compound annotation,leading to the presence of NA values for their respective metadata.
Please check again: *unmatched_identifiers[:verbosity]...


AnnData object with n_obs × n_vars = 571429 × 2001
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'perturbation', 'drug', 'cell_line', 'logdose', 'condition', 'n_genes', 'pubchem_name', 'pubchem_ID', 'smiles'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'

In [18]:
# Manually curated SMILES strings, taken from pubchem, keyed by the exact `perturbation`
# string. A later cell uses them to overwrite obs["smiles"] for these products.
# NOTE(review): the key "Glesatinib?(MGCD265)" contains a literal "?"; presumably this
# mirrors how the product name is stored in the dataset — confirm.
smiles_dict = {
    "Dacinostat (LAQ824)": "C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=C/C(=O)NO",
    "Glesatinib?(MGCD265)": "COCCNCC1=CN=C(C=C1)C2=CC3=NC=CC(=C3S2)OC4=C(C=C(C=C4)NC(=S)NC(=O)CC5=CC=C(C=C5)F)F",
    "MC1568": "CN1C=C(C=C1/C=C/C(=O)NO)/C=C/C(=O)C2=CC(=CC=C2)F",
    "Ivosidenib (AG-120)": "C1CC(=O)N([C@@H]1C(=O)N(C2=CC(=CN=C2)F)[C@@H](C3=CC=CC=C3Cl)C(=O)NC4CC(C4)(F)F)C5=NC=CC(=C5)C#N",
    "Bisindolylmaleimide IX (Ro 31-8220 Mesylate)": "CN1C=C(C2=CC=CC=C21)C3=C(C(=O)NC3=O)C4=CN(C5=CC=CC=C54)CCCSC(=N)N",
}

In [19]:
# Cells per SMILES string, including cells without one (NaN), before the manual fill-in.
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                     26456
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                             6530
COC1=CC=CC(=C1N)C2=CC(=O)C3=CC=CC=C3O2                                   3758
C1C(C1N)C2=CC=CC=C2.Cl                                                   3737
CC(C1=CC=CC=C1)NC(=O)C(=CC2=NC(=CC=C2)Br)C#N                             3710
                                                                        ...  
C1CC1NC(=O)NC2=C(NN=C2)C3=NC4=C(N3)C=C(C=C4)CN5CCOCC5                    1491
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl               1331
CC1CCCC2(C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C)C     1078
CC1CCCC2C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C         734
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=NC=CN=C4.[Br-]              336
Name: count, Length: 183, dtype: int64

In [20]:
# Use the manually curated SMILES whenever the perturbation has an entry in `smiles_dict`;
# every other cell keeps the SMILES it already has.
adata.obs["smiles"] = adata.obs[["smiles", "perturbation"]].apply(
    lambda row: smiles_dict.get(row["perturbation"], row["smiles"]), axis=1
)

In [21]:
# Same tally after the manual fill-in; the NaN count should have dropped.
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                     12968
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                             6530
COC1=CC=CC(=C1N)C2=CC(=O)C3=CC=CC=C3O2                                   3758
C1C(C1N)C2=CC=CC=C2.Cl                                                   3737
CC(C1=CC=CC=C1)NC(=O)C(=CC2=NC(=CC=C2)Br)C#N                             3710
                                                                        ...  
C1CC1NC(=O)NC2=C(NN=C2)C3=NC4=C(N3)C=C(C=C4)CN5CCOCC5                    1491
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl               1331
CC1CCCC2(C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C)C     1078
CC1CCCC2C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C         734
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=NC=CN=C4.[Br-]              336
Name: count, Length: 188, dtype: int64

In [22]:
# Dose distribution of the cells that still have no SMILES string.
adata.obs.loc[adata.obs["smiles"].isnull(), "dose"].value_counts()

dose
0.0    12968
Name: count, dtype: int64

In [23]:
# Dose distribution of the cells that do have a SMILES string.
adata.obs.loc[adata.obs["smiles"].notnull(), "dose"].value_counts()

dose
10.0       151665
100.0      145870
1000.0     139899
10000.0    121027
Name: count, dtype: int64

In [24]:
from rdkit.Chem import AllChem


def get_fp(smiles, radius=4, nBits=1024):
    """Compute a hashed Morgan fingerprint for a SMILES string.

    Parameters
    ----------
    smiles
        SMILES string of the molecule.
    radius
        Radius passed to ``AllChem.GetHashedMorganFingerprint``.
    nBits
        Fingerprint length passed to ``AllChem.GetHashedMorganFingerprint``.

    Returns
    -------
    The fingerprint object returned by RDKit, or the string ``"invalid"`` when the
    SMILES cannot be parsed or the parsed molecule fails sanitisation.
    """
    # Parse without sanitising so that sanitisation failures can be caught separately below.
    m = Chem.MolFromSmiles(smiles, sanitize=False)
    if m is None:
        return "invalid"
    try:
        Chem.SanitizeMol(m)
    except Exception:
        # Deliberately not a bare `except:`, which would also swallow
        # KeyboardInterrupt / SystemExit.
        return "invalid"
    return AllChem.GetHashedMorganFingerprint(m, radius=radius, nBits=nBits)

In [25]:
# Re-display the SMILES tally before computing fingerprints.
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                     12968
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                             6530
COC1=CC=CC(=C1N)C2=CC(=O)C3=CC=CC=C3O2                                   3758
C1C(C1N)C2=CC=CC=C2.Cl                                                   3737
CC(C1=CC=CC=C1)NC(=O)C(=CC2=NC(=CC=C2)Br)C#N                             3710
                                                                        ...  
C1CC1NC(=O)NC2=C(NN=C2)C3=NC4=C(N3)C=C(C=C4)CN5CCOCC5                    1491
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl               1331
CC1CCCC2(C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C)C     1078
CC1CCCC2C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C         734
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=NC=CN=C4.[Br-]              336
Name: count, Length: 188, dtype: int64

In [26]:
# Which conditions are left without a SMILES string?
adata.obs.loc[adata.obs["smiles"].isnull(), "condition"].value_counts()

condition
MCF7_Vehicle_0.0    6327
K562_Vehicle_0.0    3354
A549_Vehicle_0.0    3287
Name: count, dtype: int64

In [27]:
# Map every distinct SMILES string to its fingerprint as a numpy vector.
smiles_to_fp = {}
for sm in adata.obs["smiles"].unique():
    if not isinstance(sm, str):
        # Non-string entries (NaN) mark cells without a SMILES; they get no fingerprint.
        continue
    fp = get_fp(sm)
    if isinstance(fp, str):
        # get_fp signals failure with the string "invalid". Converting that to an array
        # would silently store a vector of characters as the fingerprint, so fail loudly.
        raise ValueError(f"Could not compute a fingerprint for SMILES {sm!r}")
    smiles_to_fp[sm] = np.array(list(fp))



In [28]:
# One row per SMILES string, one column per fingerprint position.
features_df = pd.DataFrame.from_dict(smiles_to_fp, orient="index")

In [29]:
# Per-cell fingerprint matrix: one row per observation, one column per fingerprint position.
# Cells whose SMILES matches no row of `features_df` keep an all-zero row.
features_cells = np.zeros((adata.n_obs, features_df.shape[1]))
for smiles_key, fp_row in features_df.iterrows():
    features_cells[adata.obs["smiles"] == smiles_key] = fp_row.values

In [30]:
# Store the per-cell fingerprint matrix next to the expression data.
adata.obsm["ecfp"] = features_cells

In [31]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    """Rank genes of every `groupby` group against the control, separately per covariate value.

    For each distinct value of ``adata.obs[covariate]`` the data is subset to that value
    and ``sc.tl.rank_genes_groups`` is run with the reference group
    ``"<covariate value>_<control_group>_0.0"``.

    Parameters
    ----------
    adata
        Annotated data; ``adata.uns[key_added]`` is written in place.
    groupby
        obs column holding the group labels to compare.
    control_group
        Name of the control perturbation used to build the reference label.
        (Previously this argument was ignored and "Vehicle" was hard-coded.)
    covariate
        obs column whose values define the independent subsets.
    n_genes
        Number of top-ranked genes kept per group.
    rankby_abs
        Forwarded to ``sc.tl.rank_genes_groups``.
    key_added
        Key in ``adata.uns`` under which the result dictionary is stored.
    return_dict
        If True, also return the dictionary.

    Returns
    -------
    dict mapping group label to its list of ranked gene names if ``return_dict`` is True,
    otherwise None.
    """
    gene_dict = {}
    for cov_cat in adata.obs[covariate].unique():
        # Label of the control group in the `groupby` column. The "0.0" dose suffix is
        # fixed here and must match how the group labels were built.
        control_group_cov = "_".join([cov_cat, control_group, "0.0"])
        # Explicit copy: rank_genes_groups stores its result in `.uns`, and writing to a
        # view makes anndata copy implicitly and emit a warning.
        adata_cov = adata[adata.obs[covariate] == cov_cat].copy()

        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # One column per group, rows are the ranked gene names.
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata):
    """Add a control flag, make all obs columns categorical and compute DE gene sets.

    Writes ``adata.obs["control"]`` (1 for "Vehicle" perturbations, 0 otherwise),
    converts every obs column to categorical dtype, and stores the 50 top-ranked genes
    per condition (relative to the per-cell-type Vehicle control) in
    ``adata.uns["rank_genes_groups_cov_all"]``. The same, modified ``adata`` is returned.
    """
    is_vehicle = adata.obs["perturbation"] == "Vehicle"
    adata.obs["control"] = is_vehicle.astype(int)
    # Every obs column, numeric ones included, becomes categorical here.
    adata.obs = adata.obs.astype("category")
    rank_genes_groups_by_cov(
        adata,
        groupby="condition",
        control_group="Vehicle",
        covariate="cell_type",
        n_genes=50,
        key_added="rank_genes_groups_cov_all",
    )
    return adata

In [32]:
# Drop category levels that no longer occur after the filtering steps above.
# `cat.remove_unused_categories()` returns a new Series, so the result has to be
# assigned back; the previous version discarded it and therefore did nothing.
for col in adata.obs.select_dtypes(include=["category"]):
    adata.obs[col] = adata.obs[col].cat.remove_unused_categories()

In [33]:
# Compute the DE gene sets on the full dataset, before any train/test/OOD split.
# The result is stored in adata.uns["rank_genes_groups_cov_all"].
adata = get_DE_genes(adata)

  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global

In [34]:
# Number of distinct conditions (cell line + drug + dose).
len(adata.obs["condition"].unique())

2128

In [35]:
# Number of distinct drug labels; the tally in the next cell shows "Vehicle" is one of them.
len(adata.obs["drug"].unique())

189

In [36]:
# Cells per drug label.
adata.obs["drug"].value_counts()

drug
Vehicle                              12968
PD98059                               3758
Tranylcypromine_(2-PCPA)_HCl          3737
WP1066                                3710
RG108                                 3706
                                     ...  
AT9283                                1491
Flavopiridol_HCl                      1331
Patupilone_(EPO906,_Epothilone_B)     1078
Epothilone_A                           734
YM155_(Sepantronium_Bromide)           336
Name: count, Length: 189, dtype: int64

In [37]:
# Scratch check: 188 non-Vehicle drugs split into 4 random OOD folds gives 47 drugs per fold.
188/4

47.0

In [38]:
# All drug labels except the Vehicle control, in sorted order.
# Sorting matters: the next cell picks drugs by shuffled *position* in this list, and the
# iteration order of a set of strings can change between Python processes, so an unsorted
# list would make the seeded split irreproducible.
drugs = sorted(set(adata.obs["drug"].unique()) - {"Vehicle"})

In [39]:
# Number of non-Vehicle drugs available for the OOD folds.
len(drugs)

188

In [40]:
# Randomly partition the non-Vehicle drugs into four held-out (OOD) folds.
rng = np.random.default_rng(seed=0)
# Positions into `drugs`; derived from the list itself rather than recomputing the
# drug count from obs and subtracting one for the Vehicle label.
numbers = np.arange(len(drugs))
rng.shuffle(numbers)
# Four consecutive chunks of the shuffled positions. With 188 drugs this yields the same
# 47/47/47/47 slices as the former hard-coded `47` boundaries.
subset_1, subset_2, subset_3, subset_4 = np.array_split(numbers, 4)
ood_conditions_1 = [drugs[el] for el in subset_1]
ood_conditions_2 = [drugs[el] for el in subset_2]
ood_conditions_3 = [drugs[el] for el in subset_3]
ood_conditions_4 = [drugs[el] for el in subset_4]

In [41]:
# Sanity check: sizes of the four random folds.
len(subset_1), len(subset_2), len(subset_3), len(subset_4)

(47, 47, 47, 47)

In [42]:
# Fifth OOD fold: an explicitly listed set of drugs (not drawn at random like folds 1-4).
# Names use the underscore form of obs["drug"].
# NOTE(review): the selection criterion is not recorded in this notebook — document it.
ood_conditions_5 = [
    "Hesperadin",
    "TAK-901",
    "Dacinostat_(LAQ824)",
    "Givinostat_(ITF2357)",
    "Belinostat_(PXD101)",
    "Quisinostat_(JNJ-26481585)_2HCl",
    "Alvespimycin_(17-DMAG)_HCl",
    "Tanespimycin_(17-AAG)",
    "Flavopiridol_HCl",
]

In [43]:
# For each split, add a categorical obs column "ood_<i>" holding the cell's condition label
# if its drug is held out in that split, and "not ood" otherwise.
# One vectorised pass per split replaces five copy-pasted row-wise `apply` calls.
for split_id, held_out_drugs in enumerate(
    (ood_conditions_1, ood_conditions_2, ood_conditions_3, ood_conditions_4, ood_conditions_5),
    start=1,
):
    is_held_out = adata.obs["drug"].isin(held_out_drugs)
    # Cast to object first: "condition" is categorical at this point and "not ood" is not
    # one of its categories, so `where` on the categorical itself would fail.
    adata.obs[f"ood_{split_id}"] = (
        adata.obs["condition"].astype(object).where(is_held_out, "not ood").astype("category")
    )

In [44]:
# Persist the full dataset, including the ood_1 ... ood_5 label columns added above.
# NOTE(review): absolute cluster path, duplicated from the `output_dir` defined in the final cell.
adata.write_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/full_adata_with_splits.h5ad")

In [45]:
# For each of the five splits, separate the cells by drug membership in the held-out list:
# adata_ood_<i> gets the cells of held-out drugs, adata_train_<i> gets everything else
# (Vehicle cells end up in the train part because "Vehicle" is in no held-out list).
# Each subset is copied so it is an independent AnnData rather than a view of `adata`.
adata_train_1 = adata[~adata.obs["drug"].isin(ood_conditions_1)].copy()
adata_ood_1 = adata[adata.obs["drug"].isin(ood_conditions_1)].copy()
adata_train_2 = adata[~adata.obs["drug"].isin(ood_conditions_2)].copy()
adata_ood_2 = adata[adata.obs["drug"].isin(ood_conditions_2)].copy()
adata_train_3 = adata[~adata.obs["drug"].isin(ood_conditions_3)].copy()
adata_ood_3 = adata[adata.obs["drug"].isin(ood_conditions_3)].copy()
adata_train_4 = adata[~adata.obs["drug"].isin(ood_conditions_4)].copy()
adata_ood_4 = adata[adata.obs["drug"].isin(ood_conditions_4)].copy()
adata_train_5 = adata[~adata.obs["drug"].isin(ood_conditions_5)].copy()
adata_ood_5 = adata[adata.obs["drug"].isin(ood_conditions_5)].copy()

In [46]:
def make_splits(adata_train, adata_ood, n_test_control=500, n_test_perturbed=50, seed=0):
    """Split in-distribution cells into train/test and assemble the OOD evaluation set.

    For every condition in ``adata_train`` a fixed number of cells is drawn at random
    (without replacement) as test cells; the rest are train cells. Conditions whose label
    contains "Vehicle" use ``n_test_control``, all others ``n_test_perturbed``.

    Parameters
    ----------
    adata_train
        Cells of the in-distribution drugs. ``obs["split"]`` is written in place.
    adata_ood
        Cells of the held-out drugs. ``obs["split"]`` is set to "ood" in place.
    n_test_control
        Number of test cells drawn per Vehicle condition.
    n_test_perturbed
        Number of test cells drawn per non-Vehicle condition.
    seed
        Seed of the random generator used for the draws.

    Returns
    -------
    (adata_train_final, adata_test_final, adata_ood_final)
        Train cells, test cells, and the concatenation of ``adata_ood`` with the Vehicle
        test cells. Train and test are returned as copies so that callers can modify them
        without triggering anndata's implicit copy of a view.

    Raises
    ------
    ValueError
        Raised by ``rng.choice`` if a condition has fewer cells than requested test cells.
    """
    rng = np.random.default_rng(seed)
    split_dfs = []
    for cond in adata_train.obs["condition"].unique():
        # Only the obs rows are needed, so avoid building an AnnData view per condition.
        df = adata_train.obs.loc[adata_train.obs["condition"] == cond, ["condition"]].copy()
        n_cells = len(df)
        n_test = n_test_control if "Vehicle" in cond else n_test_perturbed
        # Same call pattern and order as before, so the drawn indices are unchanged.
        idx_test = rng.choice(np.arange(n_cells), n_test, replace=False)
        # Vectorised membership test instead of `idx in ndarray` once per cell.
        df["split"] = np.where(np.isin(np.arange(n_cells), idx_test), "test", "train")
        split_dfs.append(df)
    df_concat = pd.concat(split_dfs, axis=0)
    # Assignment aligns on the obs index, restoring the original cell order.
    adata_train.obs["split"] = df_concat["split"]
    adata_ood.obs["split"] = "ood"
    adata_train_final = adata_train[adata_train.obs["split"] == "train"].copy()
    adata_test_final = adata_train[adata_train.obs["split"] == "test"].copy()
    # The OOD set also needs control cells; reuse the Vehicle cells from the test split.
    vehicle_test = adata_test_final[adata_test_final.obs["condition"].str.contains("Vehicle")]
    adata_ood_final = anndata.concat((adata_ood, vehicle_test))

    return adata_train_final, adata_test_final, adata_ood_final

In [47]:
# Build the train / test / OOD AnnData objects for each of the five splits.
# The adata_train_<i> and adata_ood_<i> names are rebound to the final objects returned by make_splits.
adata_train_1, adata_test_1, adata_ood_1 =  make_splits(adata_train_1, adata_ood_1)
adata_train_2, adata_test_2, adata_ood_2 =  make_splits(adata_train_2, adata_ood_2)
adata_train_3, adata_test_3, adata_ood_3 =  make_splits(adata_train_3, adata_ood_3)
adata_train_4, adata_test_4, adata_ood_4 =  make_splits(adata_train_4, adata_ood_4)
adata_train_5, adata_test_5, adata_ood_5 =  make_splits(adata_train_5, adata_ood_5)

In [48]:
# Per split: compute a 300-component PCA on the train cells only, then project the test
# and OOD cells using the train object as reference.
# NOTE(review): judging by the warning trace in this cell's output, centered_pca stores
# varm["X_mean"] on the train object and project_pca writes into query_adata.obsm — confirm
# the exact keys against the cfp.preprocessing documentation.
cfpp.centered_pca(adata_train_1, n_comps=300)
cfpp.project_pca(query_adata = adata_test_1, ref_adata=adata_train_1)
cfpp.project_pca(query_adata = adata_ood_1, ref_adata=adata_train_1)

cfpp.centered_pca(adata_train_2, n_comps=300)
cfpp.project_pca(query_adata = adata_test_2, ref_adata=adata_train_2)
cfpp.project_pca(query_adata = adata_ood_2, ref_adata=adata_train_2)

cfpp.centered_pca(adata_train_3, n_comps=300)
cfpp.project_pca(query_adata = adata_test_3, ref_adata=adata_train_3)
cfpp.project_pca(query_adata = adata_ood_3, ref_adata=adata_train_3)

cfpp.centered_pca(adata_train_4, n_comps=300)
cfpp.project_pca(query_adata = adata_test_4, ref_adata=adata_train_4)
cfpp.project_pca(query_adata = adata_ood_4, ref_adata=adata_train_4)

cfpp.centered_pca(adata_train_5, n_comps=300)
cfpp.project_pca(query_adata = adata_test_5, ref_adata=adata_train_5)
cfpp.project_pca(query_adata = adata_ood_5, ref_adata=adata_train_5)

  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(


In [49]:
# Attach the DE gene sets to the evaluation objects of every split.
# The gene sets were computed once on the full dataset before splitting, so each train
# object carries the same dictionary; each split takes it from its own train object.
adata_test_1.uns['rank_genes_groups_cov_all'] = adata_train_1.uns['rank_genes_groups_cov_all']
adata_test_2.uns['rank_genes_groups_cov_all'] = adata_train_2.uns['rank_genes_groups_cov_all']
adata_test_3.uns['rank_genes_groups_cov_all'] = adata_train_3.uns['rank_genes_groups_cov_all']
adata_test_4.uns['rank_genes_groups_cov_all'] = adata_train_4.uns['rank_genes_groups_cov_all']
adata_test_5.uns['rank_genes_groups_cov_all'] = adata_train_5.uns['rank_genes_groups_cov_all']

# The OOD objects were rebuilt with anndata.concat inside make_splits and need the gene
# sets as well. This second group used to be a verbatim repeat of the test assignments
# above, so the OOD objects were never given them.
adata_ood_1.uns['rank_genes_groups_cov_all'] = adata_train_1.uns['rank_genes_groups_cov_all']
adata_ood_2.uns['rank_genes_groups_cov_all'] = adata_train_2.uns['rank_genes_groups_cov_all']
adata_ood_3.uns['rank_genes_groups_cov_all'] = adata_train_3.uns['rank_genes_groups_cov_all']
adata_ood_4.uns['rank_genes_groups_cov_all'] = adata_train_4.uns['rank_genes_groups_cov_all']
adata_ood_5.uns['rank_genes_groups_cov_all'] = adata_train_5.uns['rank_genes_groups_cov_all']

In [50]:
# Re-store varm["X_mean"] of every train object as a plain ndarray.
# NOTE(review): presumably a safeguard for writing to h5ad; the warning trace of the PCA
# cell suggests centered_pca already stores an np.array, so this may be a no-op — confirm.
adata_train_1.varm["X_mean"] = np.asarray(adata_train_1.varm["X_mean"])
adata_train_2.varm["X_mean"] = np.asarray(adata_train_2.varm["X_mean"])
adata_train_3.varm["X_mean"] = np.asarray(adata_train_3.varm["X_mean"])
adata_train_4.varm["X_mean"] = np.asarray(adata_train_4.varm["X_mean"])
adata_train_5.varm["X_mean"] = np.asarray(adata_train_5.varm["X_mean"])

In [51]:
# Cell-line name -> embedding vector, converted to a float ndarray.
cell_line_dict = {
    line_name: np.asarray(embedding).astype("float")
    for line_name, embedding in (("MCF7", mcf7_emb), ("K562", k562_emb), ("A549", a549_emb))
}


In [52]:
# Attach the cell-line embeddings to every train / test / OOD object.
# All fifteen objects reference the same dictionary instance.
adata_train_1.uns['cell_line_dict'] = cell_line_dict
adata_train_2.uns['cell_line_dict'] = cell_line_dict
adata_train_3.uns['cell_line_dict'] = cell_line_dict
adata_train_4.uns['cell_line_dict'] = cell_line_dict
adata_train_5.uns['cell_line_dict'] = cell_line_dict

adata_test_1.uns['cell_line_dict'] = cell_line_dict
adata_test_2.uns['cell_line_dict'] = cell_line_dict
adata_test_3.uns['cell_line_dict'] = cell_line_dict
adata_test_4.uns['cell_line_dict'] = cell_line_dict
adata_test_5.uns['cell_line_dict'] = cell_line_dict

adata_ood_1.uns['cell_line_dict'] = cell_line_dict
adata_ood_2.uns['cell_line_dict'] = cell_line_dict
adata_ood_3.uns['cell_line_dict'] = cell_line_dict
adata_ood_4.uns['cell_line_dict'] = cell_line_dict
adata_ood_5.uns['cell_line_dict'] = cell_line_dict


In [53]:
# One representative obs row per drug, reduced to a drug -> SMILES lookup.
df_red = adata.obs.drop_duplicates(subset="drug")
drug_to_smiles = df_red.set_index("drug")["smiles"].to_dict()

In [54]:
# Drug -> fingerprint vector. "Vehicle" is skipped: it is excluded by name and,
# per the earlier checks, has no SMILES string to look up.
drug_to_ecfp = {
    drug_name: smiles_to_fp[drug_to_smiles[drug_name]]
    for drug_name in drug_to_smiles
    if drug_name != "Vehicle"
}

In [55]:
# Sanity check: expected to be -1, i.e. every drug label except one has a fingerprint.
len(drug_to_ecfp) - len(adata.obs["drug"].unique())

-1

In [56]:
# Sanity check: the only drug label without a fingerprint should be "Vehicle".
set(adata.obs["drug"].unique()) - set(drug_to_ecfp.keys())

{'Vehicle'}

In [57]:
# Attach the drug -> fingerprint lookup to every train / test / OOD object.
# All fifteen objects reference the same dictionary instance.
adata_train_1.uns['ecfp_dict'] = drug_to_ecfp
adata_train_2.uns['ecfp_dict'] = drug_to_ecfp
adata_train_3.uns['ecfp_dict'] = drug_to_ecfp
adata_train_4.uns['ecfp_dict'] = drug_to_ecfp
adata_train_5.uns['ecfp_dict'] = drug_to_ecfp

adata_test_1.uns['ecfp_dict'] = drug_to_ecfp
adata_test_2.uns['ecfp_dict'] = drug_to_ecfp
adata_test_3.uns['ecfp_dict'] = drug_to_ecfp
adata_test_4.uns['ecfp_dict'] = drug_to_ecfp
adata_test_5.uns['ecfp_dict'] = drug_to_ecfp

adata_ood_1.uns['ecfp_dict'] = drug_to_ecfp
adata_ood_2.uns['ecfp_dict'] = drug_to_ecfp
adata_ood_3.uns['ecfp_dict'] = drug_to_ecfp
adata_ood_4.uns['ecfp_dict'] = drug_to_ecfp
adata_ood_5.uns['ecfp_dict'] = drug_to_ecfp

In [58]:
# Cast the "control" flag to bool in every object. get_DE_genes created it as a 0/1 value
# and then converted all obs columns, this one included, to categorical dtype.
adata_train_1.obs["control"] = adata_train_1.obs["control"].astype("bool")
adata_train_2.obs["control"] = adata_train_2.obs["control"].astype("bool")
adata_train_3.obs["control"] = adata_train_3.obs["control"].astype("bool")
adata_train_4.obs["control"] = adata_train_4.obs["control"].astype("bool")
adata_train_5.obs["control"] = adata_train_5.obs["control"].astype("bool")

adata_test_1.obs["control"] = adata_test_1.obs["control"].astype("bool")
adata_test_2.obs["control"] = adata_test_2.obs["control"].astype("bool")
adata_test_3.obs["control"] = adata_test_3.obs["control"].astype("bool")
adata_test_4.obs["control"] = adata_test_4.obs["control"].astype("bool")
adata_test_5.obs["control"] = adata_test_5.obs["control"].astype("bool")

adata_ood_1.obs["control"] = adata_ood_1.obs["control"].astype("bool")
adata_ood_2.obs["control"] = adata_ood_2.obs["control"].astype("bool")
adata_ood_3.obs["control"] = adata_ood_3.obs["control"].astype("bool")
adata_ood_4.obs["control"] = adata_ood_4.obs["control"].astype("bool")
adata_ood_5.obs["control"] = adata_ood_5.obs["control"].astype("bool")

In [59]:
import os

# Destination directory for the per-split files.
# The cell-line embeddings are already stored in `.uns["cell_line_dict"]` by an earlier cell.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"

# Write train, ood and test objects of each split, in that order, as <name>.h5ad.
for file_stem, adata_split in (
    ("adata_train_1", adata_train_1),
    ("adata_ood_1", adata_ood_1),
    ("adata_test_1", adata_test_1),
    ("adata_train_2", adata_train_2),
    ("adata_ood_2", adata_ood_2),
    ("adata_test_2", adata_test_2),
    ("adata_train_3", adata_train_3),
    ("adata_ood_3", adata_ood_3),
    ("adata_test_3", adata_test_3),
    ("adata_train_4", adata_train_4),
    ("adata_ood_4", adata_ood_4),
    ("adata_test_4", adata_test_4),
    ("adata_train_5", adata_train_5),
    ("adata_ood_5", adata_ood_5),
    ("adata_test_5", adata_test_5),
):
    adata_split.write(os.path.join(output_dir, f"{file_stem}.h5ad"))