In [1]:
import scanpy as sc
import pertpy
import numpy as np
import pandas as pd
from rdkit import Chem

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(pertpy.__version__)

0.7.0


In [3]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"

In [4]:
adata = pertpy.data.sciplex3_raw()

In [5]:
adata

AnnData object with n_obs × n_vars = 581777 × 58347
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1'

In [6]:
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [7]:
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)

  disp_grouped = df.groupby("mean_bin")["dispersions"]


In [8]:
adata = adata[:,adata.var["highly_variable"]==True]

In [9]:
adata

View of AnnData object with n_obs × n_vars = 581777 × 2000
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'

In [10]:
adata.obs["perturbation"] = adata.obs["product_name"]
adata.obs["cell_line"] = adata.obs["cell_type"]

  adata.obs["perturbation"] = adata.obs["product_name"]


In [11]:
adata = adata[~adata.obs["perturbation"].isnull()]

In [12]:
adata.n_obs

581777

In [13]:
adata = adata[adata.obs["dose"].isin((0, 10_000))] # use all doses


In [14]:
adata.n_obs

139266

In [15]:
pertpy.md.Compound().annotate_compounds(adata)

[bold blue]There are 189 identifiers in `adata.obs`.However, 6 identifiers can't be found in the compound annotation,leading to the presence of NA values for their respective metadata.

- Please check again: 
- Vehicle
- Glesatinib?(MGCD265)
- Ivosidenib (AG-120)
- MC1568
- Dacinostat (LAQ824)
- ...


AnnData object with n_obs × n_vars = 139266 × 2000
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'perturbation', 'cell_line', 'pubchem_name', 'pubchem_ID', 'smiles'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'

In [16]:
smiles_dict = {"Dacinostat (LAQ824)": "C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)C=CC(=O)NO",
              "Glesatinib?(MGCD265)": "COCCNCc1ccc(nc1)c2cc3c(s2)c(ccn3)Oc4ccc(cc4F)NC(=S)NC(=O)Cc5ccc(cc5)F",
              "MC1568": "CN1C=C(C=C1C=CC(=O)NO)C=CC(=O)C2=CC(=CC=C2)F",
              "Ivosidenib (AG-120)": "ClC1=C([C@H](N(C2=CC(F)=CN=C2)C([C@H]3N(C4=NC=CC(C#N)=C4)C(CC3)=O)=O)C(NC5CC(F)(F)C5)=O)C=CC=C1",
               "Bisindolylmaleimide IX (Ro 31-8220 Mesylate)": "NC(SCCCN1C=C(C2=C(C3=CN(C)C4=C3C=CC=C4)C(NC2=O)=O)C5=C1C=CC=C5)=N.CS(=O)(O)=O",
              }

In [17]:
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                                       15474
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                                               1786
CC(=O)NC1=CC=C(C=C1)OCC(C)(C(=O)NC2=CC(=C(C=C2)[N+](=O)[O-])C(F)(F)F)O                      983
CCS(=O)(=O)N1CC(C1)(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                                       974
CCCCCOC(=O)NC1=NC(=O)N(C=C1F)C2C(C(C(O2)C)O)O                                               970
                                                                                          ...  
CCS(=O)(=O)C1=CC=CC(=C1)C2=CC(=C(C3=C2C4=C(N3)N=CC(=C4)C)C)C(=O)NC5CCN(CC5)C                229
CC1CC(C(C(C=C(C(C(C=CC=C(C(=O)NC2=CC(=O)C(=C(C1)C2=O)NCCN(C)C)C)OC)OC(=O)N)C)C)O)OC.Cl      206
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl                                  145
COC1=C(C=C(C=C1)CS(=O)(=O)C=CC2=C(C=C(C=C2OC)OC)OC)NCC(=O)[O-].[Na+]                        137
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=C

In [18]:
adata.obs["smiles"] = adata.obs[["smiles", "perturbation"]].apply(lambda x: smiles_dict[x["perturbation"]] if x["perturbation"] in smiles_dict.keys() else x["smiles"], axis=1)

In [19]:
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                                       13004
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                                               1786
CC(=O)NC1=CC=C(C=C1)OCC(C)(C(=O)NC2=CC(=C(C=C2)[N+](=O)[O-])C(F)(F)F)O                      983
CCS(=O)(=O)N1CC(C1)(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                                       974
CCCCCOC(=O)NC1=NC(=O)N(C=C1F)C2C(C(C(O2)C)O)O                                               970
                                                                                          ...  
CC1CC(C(C(C=C(C(C(C=CC=C(C(=O)NC2=CC(=O)C(=C(C1)C2=O)NCCN(C)C)C)OC)OC(=O)N)C)C)O)OC.Cl      206
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl                                  145
COC1=C(C=C(C=C1)CS(=O)(=O)C=CC2=C(C=C(C=C2OC)OC)OC)NCC(=O)[O-].[Na+]                        137
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=NC=CN=C4.[Br-]                                113
NC(SCCCN1C=C(C2=C(C3=CN(C)C4=C3C=

In [20]:
adata[adata.obs["smiles"].isnull()].obs.dose.value_counts()


dose
0.0    13004
Name: count, dtype: int64

In [21]:
adata[~adata.obs["smiles"].isnull()].obs.dose.value_counts()

dose
10000.0    126262
Name: count, dtype: int64

In [22]:
from rdkit.Chem import AllChem

def get_fp(smiles, radius = 4, nBits=1024):
    m = Chem.MolFromSmiles(smiles,sanitize=False)
    if m is None:
        return "invalid"
    else:
        try:
            Chem.SanitizeMol(m)
        except:
            return "invalid"
    return AllChem.GetHashedMorganFingerprint(m, radius=radius, nBits=nBits)

In [23]:
smiles_to_fp = {}
for sm in adata.obs["smiles"].unique():
    if not isinstance(sm, str):
        continue
    smiles_to_fp[sm] = np.array(list(get_fp(sm)))

In [24]:
features_df = pd.DataFrame.from_dict(smiles_to_fp).T

In [25]:
normalized_df = (features_df - features_df.mean()) / features_df.std()

In [26]:
features_cells = np.zeros((adata.shape[0], normalized_df.shape[1]))
for mol, ecfp in normalized_df.iterrows():
    features_cells[adata.obs["smiles"].isin([mol])] = ecfp.values

In [27]:
adata.obsm["ecfp"] = features_cells

In [28]:
test_drugs = ["Dacinostat (LAQ824)", "Givinostat (ITF2357)", "Belinostat (PXD101)", "Hesperadin", "Quisinostat (JNJ-26481585) 2HCl", "Alvespimycin (17-DMAG) HCl", "Tanespimycin (17-AAG)", "TAK-901", "Flavopiridol HCl"]

In [29]:
remaining_drugs = list(set(adata.obs["perturbation"].unique()) - set(test_drugs).union(set(["control"])))

In [30]:
rng = np.random.default_rng(0)

In [31]:
val_drugs_idx = rng.choice(len(remaining_drugs), 9, replace=False)

In [32]:
val_drugs = [remaining_drugs[el] for el in val_drugs_idx]

In [33]:
val_drugs

['AMG-900',
 'Sorafenib Tosylate',
 'Thalidomide',
 'Ki16425',
 'Meprednisone',
 'GSK1070916',
 'Capecitabine ',
 'PD98059',
 'Aminoglutethimide ']

In [34]:
adata_train = adata[~adata.obs["perturbation"].isin(test_drugs+val_drugs)].copy()
adata_val = adata[adata.obs["perturbation"].isin(val_drugs)].copy()
adata_test = adata[adata.obs["perturbation"].isin(test_drugs)].copy()

In [35]:
adata_train.n_obs, adata_val.n_obs, adata_test.n_obs

(128248, 7145, 3873)

In [36]:
adata.obs["perturbation"] = adata.obs.apply(lambda x: "control" if x["perturbation"]=="Vehicle" else x["perturbation"], axis=1)

In [37]:
adata_train.obs["perturbation"].value_counts()

perturbation
Vehicle                                         13004
Andarine                                          983
Baricitinib (LY3009104, INCB028050)               974
Fasudil (HA-1077) HCl                             966
Ruxolitinib (INCB018424)                          962
                                                ...  
Fedratinib (SAR302503, TG101348)                  314
Regorafenib (BAY 73-4506)                         283
Rigosertib (ON-01910)                             137
YM155 (Sepantronium Bromide)                      113
Bisindolylmaleimide IX (Ro 31-8220 Mesylate)       62
Name: count, Length: 171, dtype: int64

In [38]:
split_dfs = []
for drug in adata_train.obs["perturbation"].unique():
    for cell_line in adata_train.obs["cell_line"].unique():
        adata_subset = adata_train[(adata_train.obs["perturbation"]==drug) & (adata_train.obs["cell_line"]==cell_line)]
        n_cells = adata_subset.n_obs
        if drug == "control":
            assert n_cells > 3_000
            idx_test = rng.choice(np.arange(n_cells), 500, replace=False)
            remaining_idcs = set(np.arange(n_cells)) - set(idx_test)
            idx_valid = rng.choice(list(remaining_idcs), 500, replace=False)
            split = ["test" if idx in idx_test else ("valid" if idx in idx_valid else "train") for idx in range(n_cells)]
            df = adata_subset.obs[["perturbation", "cell_line"]].copy()
            df["split"] = split
            split_dfs.append(df)
        elif n_cells>300:
            idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
            split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
            df = adata_subset.obs[["perturbation", "cell_line"]].copy()
            df["split"] = split
            split_dfs.append(df)
        else:
            df = adata_subset.obs[["perturbation", "cell_line"]].copy()
            df["split"] = "train"
            split_dfs.append(df)

In [39]:
df_concat = pd.concat(split_dfs, axis=0)

In [40]:
len(df_concat), adata_train.n_obs

(128248, 128248)

In [41]:
df_concat["split"].value_counts()

split
train    117148
test      11100
Name: count, dtype: int64

In [42]:
adata_train.obs["split"] = df_concat["split"]

In [43]:
import anndata
adata_train_final = adata_train[adata_train.obs["split"]=="train"]
adata_valid_final = anndata.concat((adata_val, adata_train[adata_train.obs["split"]=="valid"]))
adata_test_final = anndata.concat((adata_test, adata_train[adata_train.obs["split"]=="test"]))

In [44]:
adata_valid_final.obs["perturbation"].value_counts()

perturbation
Capecitabine          970
Meprednisone          958
Aminoglutethimide     938
PD98059               934
Ki16425               880
Thalidomide           818
AMG-900               696
GSK1070916            496
Sorafenib Tosylate    455
Name: count, dtype: int64

In [45]:
adata_train.obs["perturbation"].value_counts()

perturbation
Vehicle                                         13004
Andarine                                          983
Baricitinib (LY3009104, INCB028050)               974
Fasudil (HA-1077) HCl                             966
Ruxolitinib (INCB018424)                          962
                                                ...  
Fedratinib (SAR302503, TG101348)                  314
Regorafenib (BAY 73-4506)                         283
Rigosertib (ON-01910)                             137
YM155 (Sepantronium Bromide)                      113
Bisindolylmaleimide IX (Ro 31-8220 Mesylate)       62
Name: count, Length: 171, dtype: int64

In [46]:
adata_train.obs["cell_line"].value_counts()

cell_line
MCF7    64627
A549    32170
K562    31451
Name: count, dtype: int64

In [47]:
adata_val.obs["perturbation"].value_counts()

perturbation
Capecitabine          970
Meprednisone          958
Aminoglutethimide     938
PD98059               934
Ki16425               880
Thalidomide           818
AMG-900               696
GSK1070916            496
Sorafenib Tosylate    455
Name: count, dtype: int64

In [48]:
adata_val.obs["cell_line"].value_counts()

cell_line
MCF7    3615
K562    1837
A549    1693
Name: count, dtype: int64

In [49]:
adata_test_final.obs["perturbation"].value_counts()

perturbation
Belinostat (PXD101)                630
Givinostat (ITF2357)               623
Quisinostat (JNJ-26481585) 2HCl    618
Dacinostat (LAQ824)                608
Hesperadin                         438
                                  ... 
Celecoxib                          100
WP1066                             100
SL-327                             100
BRD4770                            100
PFI-1 (PF-6405761)                 100
Name: count, Length: 118, dtype: int64

In [50]:
adata_test_final.obs["cell_line"].value_counts()

cell_line
MCF7    13209
A549     1024
K562      740
Name: count, dtype: int64

In [51]:
assert len(set(test_drugs).intersection(adata_train.obs["perturbation"])) == 0
assert len(set(val_drugs).intersection(adata_train.obs["perturbation"])) == 0
assert len(set(val_drugs).intersection(adata_test.obs["perturbation"])) == 0

In [52]:
adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T

  adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T


In [53]:
from scipy.sparse import csr_matrix, vstack
mean = csr_matrix(adata_train_final.varm["X_train_mean"].T)
adata_train_final.X = adata_train_final.X - vstack([mean]*len(adata_train_final))

In [54]:
sc.pp.pca(adata_train_final)

In [55]:
adata_valid_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

In [56]:
adata_valid_final.X = adata_valid_final.X - vstack([mean]*len(adata_valid_final))

In [57]:
adata_valid_final.obsm["X_pca"] = np.matmul(adata_valid_final.X.A, adata_train_final.varm["PCs"])

In [58]:
adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

In [59]:
adata_test_final.X = adata_test_final.X - vstack([mean]*len(adata_test_final))

In [60]:
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.X.A, adata_train_final.varm["PCs"])

In [61]:
adata_test_final.obs["perturbation"].value_counts()

perturbation
Belinostat (PXD101)                630
Givinostat (ITF2357)               623
Quisinostat (JNJ-26481585) 2HCl    618
Dacinostat (LAQ824)                608
Hesperadin                         438
                                  ... 
Celecoxib                          100
WP1066                             100
SL-327                             100
BRD4770                            100
PFI-1 (PF-6405761)                 100
Name: count, Length: 118, dtype: int64

In [71]:
adata_train_final.varm["X_train_mean"] = np.asarray(adata_train_final.varm["X_train_mean"])
adata_valid_final.varm["X_train_mean"] = np.asarray(adata_valid_final.varm["X_train_mean"])
adata_test_final.varm["X_train_mean"] = np.asarray(adata_test_final.varm["X_train_mean"])

In [72]:
import os
adata_train_final.write(os.path.join(output_dir, "adata_train.h5ad"))
adata_valid_final.write(os.path.join(output_dir, "adata_valid.h5ad"))
adata_test_final.write(os.path.join(output_dir, "adata_test.h5ad"))