In [None]:
import numpy as np
import pandas as pd
import pertpy
import scanpy as sc
from rdkit import Chem

In [None]:
# Print the installed pertpy version.
print(pertpy.__version__)

In [None]:
# Read the cell line embedding table from a CSV file at an absolute cluster path.
df_cell_line = pd.read_csv(
    "/lustre/groups/ml01/workspace/alejandro.tejada/super_rad_project/cell_line_embeddings/cell_line_embedding_full_ccle_300.csv"
)

In [None]:
# Display the embedding table.
df_cell_line

In [None]:
# One embedding vector per cell line: take the first row whose
# "stripped_cell_line_name" matches and keep every column after the first one.
_cell_line_names = df_cell_line["stripped_cell_line_name"]
mcf7_emb, k562_emb, a549_emb = (
    df_cell_line[_cell_line_names == name].iloc[0, 1:].values for name in ("MCF7", "K562", "A549")
)

In [None]:
# Load the raw sciplex3 dataset through pertpy.
adata = pertpy.data.sciplex3_raw()

In [None]:
# Number of cells per value of the "cell_type" obs column.
adata.obs["cell_type"].value_counts()

In [None]:
# Derive the obs columns used throughout the notebook. Everything is computed
# column-wise; the previous row-wise `apply(axis=1)` calls ran a Python lambda
# once per cell for three separate columns.
adata.obs["perturbation"] = adata.obs["product_name"]
# Drug identifier: product name with spaces replaced by underscores.
adata.obs["drug"] = adata.obs["product_name"].str.replace(" ", "_", regex=False)
adata.obs["cell_line"] = adata.obs["cell_type"]
# log10 of the dose where it is positive, 0.0 otherwise. Non-positive doses are
# swapped for 1.0 before taking the log so np.log10 never sees them.
dose = adata.obs["dose"]
adata.obs["logdose"] = np.where(dose > 0.0, np.log10(dose.where(dose > 0.0, 1.0)), 0.0)
# Condition label of the form "<cell_type>_<drug>_<dose>".
adata.obs["condition"] = (
    adata.obs["cell_type"].astype(str) + "_" + adata.obs["drug"].astype(str) + "_" + dose.astype(str)
)

In [None]:
def get_cell_line_embedding(x):
    """Return the embedding of the cell line of one obs row as a ``(1, d)`` float array.

    Parameters
    ----------
    x
        Mapping-like row (e.g. a row of ``adata.obs``) with a ``"cell_line"`` entry.

    Returns
    -------
    numpy.ndarray
        The embedding of the row's cell line, cast to float, with a leading axis
        of length 1.

    Raises
    ------
    ValueError
        If the cell line is not one of MCF7, A549 or K562. Previously the function
        fell through and returned ``None``, which only surfaced later as an
        obscure error when the per-cell results were concatenated.
    """
    # Module-level embedding vectors, keyed by cell line name.
    embeddings = {"MCF7": mcf7_emb, "A549": a549_emb, "K562": k562_emb}
    cell_line = x["cell_line"]
    if cell_line not in embeddings:
        raise ValueError(f"No embedding available for cell line {cell_line!r}; expected one of {sorted(embeddings)}.")
    return embeddings[cell_line].astype("float")[None, :]


# Look up one (1, d) embedding per cell, then stack them along axis 0 so that
# obsm["cell_line_emb"] has one row per cell.
per_cell_embeddings = adata.obs.apply(get_cell_line_embedding, axis=1)
adata.obsm["cell_line_emb"] = np.concatenate(per_cell_embeddings.values, axis=0)

In [None]:
# Number of cells before preprocessing.
adata.n_obs

In [None]:
# Normalize total counts per cell, then apply log1p (both modify adata in place).
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [None]:
# Flag the top 2000 highly variable genes; the flag is read from
# adata.var["highly_variable"] in the next cell.
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)

In [None]:
adata = adata[:, adata.var["highly_variable"] == True]

In [None]:
adata = adata[~adata.obs["drug"].isnull()]

In [None]:
# Annotate compounds with pertpy's Compound metadata; the cells below read the
# result from adata.obs["smiles"].
pertpy.md.Compound().annotate_compounds(adata)

In [None]:
# taken from pubchem
# Manually curated SMILES strings, keyed by the value in the "perturbation" obs
# column. They override obs["smiles"] for these perturbations in a later cell.
smiles_dict = {
    "Dacinostat (LAQ824)": "C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=C/C(=O)NO",
    "Glesatinib?(MGCD265)": "COCCNCC1=CN=C(C=C1)C2=CC3=NC=CC(=C3S2)OC4=C(C=C(C=C4)NC(=S)NC(=O)CC5=CC=C(C=C5)F)F",
    "MC1568": "CN1C=C(C=C1/C=C/C(=O)NO)/C=C/C(=O)C2=CC(=CC=C2)F",
    "Ivosidenib (AG-120)": "C1CC(=O)N([C@@H]1C(=O)N(C2=CC(=CN=C2)F)[C@@H](C3=CC=CC=C3Cl)C(=O)NC4CC(C4)(F)F)C5=NC=CC(=C5)C#N",
    "Bisindolylmaleimide IX (Ro 31-8220 Mesylate)": "CN1C=C(C2=CC=CC=C21)C3=C(C(=O)NC3=O)C4=CN(C5=CC=CC=C54)CCCSC(=N)N",
}

In [None]:
# SMILES counts (including missing values) before the manual overrides.
adata.obs["smiles"].value_counts(dropna=False)

In [None]:
def _resolve_smiles(row):
    """Return the curated SMILES for the row's perturbation, else its existing SMILES."""
    return smiles_dict.get(row["perturbation"], row["smiles"])


# Apply the manual overrides row by row.
adata.obs["smiles"] = adata.obs[["smiles", "perturbation"]].apply(_resolve_smiles, axis=1)

In [None]:
# SMILES counts (including missing values) after the manual overrides.
adata.obs["smiles"].value_counts(dropna=False)

In [None]:
# Dose distribution of the cells that have no SMILES.
adata[adata.obs["smiles"].isnull()].obs.dose.value_counts()

In [None]:
# Dose distribution of the cells that do have a SMILES.
adata[~adata.obs["smiles"].isnull()].obs.dose.value_counts()

In [None]:
from rdkit.Chem import AllChem


def get_fp(smiles, radius=4, nBits=1024):
    """Compute a hashed Morgan fingerprint for a SMILES string with RDKit.

    Parameters
    ----------
    smiles
        SMILES string of the molecule.
    radius
        Radius passed to ``AllChem.GetHashedMorganFingerprint``.
    nBits
        Fingerprint length passed to ``AllChem.GetHashedMorganFingerprint``.

    Returns
    -------
    The fingerprint object returned by RDKit, or the string ``"invalid"`` when
    the SMILES cannot be parsed or the molecule cannot be sanitized.
    """
    # Parse without sanitization so that sanitization failures are caught below.
    m = Chem.MolFromSmiles(smiles, sanitize=False)
    if m is None:
        return "invalid"
    try:
        Chem.SanitizeMol(m)
    except Exception:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt and SystemExit.
        return "invalid"
    return AllChem.GetHashedMorganFingerprint(m, radius=radius, nBits=nBits)

In [None]:
# Map every distinct SMILES string to its fingerprint as a numpy array.
smiles_to_fp = {}
invalid_smiles = []
for sm in adata.obs["smiles"].unique():
    if not isinstance(sm, str):
        # Missing SMILES (non-string values) have no fingerprint.
        continue
    fp = get_fp(sm)
    if isinstance(fp, str):
        # get_fp signals failure with the string "invalid". Passing that to
        # list() would store its characters as if they were a fingerprint.
        invalid_smiles.append(sm)
        continue
    smiles_to_fp[sm] = np.array(list(fp))
if invalid_smiles:
    raise ValueError(f"Could not compute fingerprints for these SMILES: {invalid_smiles}")

In [None]:
features_df = pd.DataFrame.from_dict(smiles_to_fp).T

In [None]:
features_cells = np.zeros((adata.shape[0], features_df.shape[1]))
for mol, ecfp in features_df.iterrows():
    features_cells[adata.obs["smiles"].isin([mol])] = ecfp.values

In [None]:
adata.obsm["ecfp"] = features_cells

In [None]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    """Rank genes per group against the control, separately for each covariate value.

    For every unique value of ``adata.obs[covariate]`` the data is subset to that
    value and ``sc.tl.rank_genes_groups`` is run on ``groupby`` with the
    reference group ``"<covariate value>_<control_group>_0.0"``.

    Parameters
    ----------
    adata
        Annotated data object; ``adata.uns[key_added]`` is written.
    groupby
        Name of the obs column holding the group labels.
    control_group
        Name of the control perturbation as it appears in the group labels.
        It was previously ignored in favour of a hard-coded ``"Vehicle"``.
    covariate
        Name of the obs column whose values define the separate runs.
    n_genes
        Number of genes kept per group.
    rankby_abs
        Forwarded to ``sc.tl.rank_genes_groups``.
    key_added
        Key in ``adata.uns`` under which the result dictionary is stored.
    return_dict
        If true, also return the dictionary.

    Returns
    -------
    dict or None
        Mapping from group label to its list of ranked gene names when
        ``return_dict`` is true, otherwise ``None``.
    """
    gene_dict = {}
    cov_categories = adata.obs[covariate].unique()
    for cov_cat in cov_categories:
        # Name of the control group in the groupby obs column; the dose
        # suffix "0.0" is fixed.
        control_group_cov = "_".join([cov_cat, control_group, "0.0"])
        # Copy the subset: rank_genes_groups stores its result in .uns, and
        # writing into an AnnData view triggers an implicit copy with a warning.
        adata_cov = adata[adata.obs[covariate] == cov_cat].copy()

        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # add entries to dictionary of gene sets
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata):
    """Compute per-condition DE genes against the Vehicle control, per cell type.

    Adds a ``"control"`` obs column (1 where ``perturbation == "Vehicle"``, else
    0), casts the non-float obs columns to ``category`` and stores the ranked
    genes in ``adata.uns["rank_genes_groups_cov_all"]`` through
    ``rank_genes_groups_by_cov``.

    Floating-point obs columns (such as the dose columns) are deliberately left
    numeric. The previous blanket ``astype("category")`` turned them into
    categoricals, so later cells that read them as numeric arrays no longer got
    plain float values.

    Parameters
    ----------
    adata
        Annotated data object with ``perturbation``, ``condition`` and
        ``cell_type`` obs columns; modified in place.

    Returns
    -------
    The same ``adata`` object.
    """
    adata.obs["control"] = adata.obs["perturbation"].apply(lambda x: 1 if x == "Vehicle" else 0)
    # Cast everything except floating-point columns to category.
    categorical_cols = [col for col in adata.obs.columns if not pd.api.types.is_float_dtype(adata.obs[col])]
    adata.obs = adata.obs.astype({col: "category" for col in categorical_cols})
    rank_genes_groups_by_cov(
        adata,
        groupby="condition",
        covariate="cell_type",
        control_group="Vehicle",
        n_genes=50,
        key_added="rank_genes_groups_cov_all",
    )
    return adata

In [None]:
# Drop categories that no longer occur after the filtering above.
# remove_unused_categories() returns a new Series rather than modifying in
# place, so the result must be assigned back (it was previously discarded).
for col in adata.obs.select_dtypes(include=["category"]):
    adata.obs[col] = adata.obs[col].cat.remove_unused_categories()

In [None]:
# Compute the DE genes; get_DE_genes also adds obs["control"] and writes
# adata.uns["rank_genes_groups_cov_all"].
adata = get_DE_genes(adata)

In [None]:
# Candidate out-of-distribution drugs by short name. This list is replaced
# further down by the exact values found in obs["drug"].
ood_drugs = [
    "Dacinostat",
    "Givinostat",
    "Belinostat",
    "Hesperadin",
    "Quisinostat",
    "Alvespimycin",
    "Tanespimycin",
    "TAK-901",
    "Flavopiridol",
]

In [None]:
# Which short names match a value of obs["drug"] exactly?
[d for d in ood_drugs if d in adata.obs["drug"].unique()]

In [None]:
# The cells below search obs["drug"] by substring to find the full names of the
# remaining candidates.
[el for el in adata.obs["drug"].unique() if "acinostat" in el]

In [None]:
[el for el in adata.obs["drug"].unique() if "ivinostat" in el]

In [None]:
[el for el in adata.obs["drug"].unique() if "elinostat" in el]

In [None]:
[el for el in adata.obs["drug"].unique() if "Quisinostat" in el]

In [None]:
[el for el in adata.obs["drug"].unique() if "Alvespimycin" in el]

In [None]:
[el for el in adata.obs["drug"].unique() if "Tanespimycin" in el]

In [None]:
[el for el in adata.obs["drug"].unique() if "Flavopiridol" in el]

In [None]:
# Final out-of-distribution drug list (spaces already replaced by underscores,
# as in obs["drug"]); overrides the short-name list defined above.
ood_drugs = [
    "Hesperadin",
    "TAK-901",
    "Dacinostat_(LAQ824)",
    "Givinostat_(ITF2357)",
    "Belinostat_(PXD101)",
    "Quisinostat_(JNJ-26481585)_2HCl",
    "Alvespimycin_(17-DMAG)_HCl",
    "Tanespimycin_(17-AAG)",
    "Flavopiridol_HCl",
]

In [None]:
# Number of out-of-distribution drugs.
len(ood_drugs)

In [None]:
# Flag cells treated with an out-of-distribution drug. Vectorised membership
# test instead of two row-wise `apply(axis=1)` passes over every cell.
is_ood = adata.obs["drug"].isin(ood_drugs)
# "ood" holds the drug name for OOD cells and "not ood" for all others.
adata.obs["ood"] = adata.obs["drug"].astype(object).where(is_ood, "not ood").astype("category")
adata.obs["is_ood"] = is_ood

In [None]:
adata_ood = adata[adata.obs["is_ood"]]

In [None]:
adata_train = adata[~adata.obs["is_ood"]]

In [None]:
# Assign every cell of adata_train to "train" or "test", condition by condition:
#   * Vehicle conditions: 500 random cells go to test,
#   * other conditions with at least 100 cells: 50 random cells go to test,
#   * smaller conditions: everything stays in train.
# The RNG is called with the same arguments and in the same order as before, so
# the resulting split is unchanged. Only the bookkeeping differs: test indices
# are written through array indexing instead of a per-cell `idx in ndarray`
# scan, and the split tables are built from obs alone.
rng = np.random.default_rng(0)
train_obs = adata_train.obs
split_dfs = []
for cond in train_obs["condition"].unique():
    df = train_obs.loc[train_obs["condition"] == cond, ["condition"]].copy()
    n_cells = len(df)
    if "Vehicle" in cond:
        # NOTE: rng.choice raises ValueError if a Vehicle condition has fewer than 500 cells.
        n_test = 500
    elif n_cells >= 100:
        n_test = 50
    else:
        n_test = 0
    split = np.full(n_cells, "train", dtype=object)
    if n_test:
        idx_test = rng.choice(np.arange(n_cells), n_test, replace=False)
        split[idx_test] = "test"
    df["split"] = split
    split_dfs.append(df)

In [None]:
# Stack the per-condition split tables into a single frame.
df_concat = pd.concat(split_dfs, axis=0)

In [None]:
# Row count of the split table next to the number of cells in adata_train.
len(df_concat), adata_train.n_obs

In [None]:
# Attach the split label to adata_train.obs.
adata_train.obs["split"] = df_concat[["split"]]

In [None]:
# Number of OOD cells whose condition label is exactly "control".
adata_ood[adata_ood.obs["condition"] == "control"].n_obs

In [None]:
# Every cell of adata_ood gets the split label "ood".
adata_ood.obs["split"] = "ood"

In [None]:
import anndata

adata_train_final = adata_train[adata_train.obs["split"] == "train"]
adata_test_final = adata_train[adata_train.obs["split"] == "test"]
adata_ood_final = anndata.concat((adata_ood, adata_test_final[adata_test_final.obs["perturbation"] == "Vehicle"]))

In [None]:
adata_train_final.obs["condition"].value_counts()

In [None]:
adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T

In [None]:
from scipy.sparse import csr_matrix

train_mean = adata_train_final.varm["X_train_mean"].T
adata_train_final.layers["centered_X"] = csr_matrix(adata_train_final.X.A - train_mean)

In [None]:
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
adata_train_final.X = adata_train_final.layers["centered_X"]

In [None]:
sc.pp.pca(adata_train_final, zero_center=False, n_comps=300)

In [None]:
adata_train_final.X = adata_train_final.layers["X_log1p"]

In [None]:
adata_ood_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]
adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

In [None]:
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.A - train_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.A - train_mean)

In [None]:
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].A, adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].A, adata_train_final.varm["PCs"])

In [None]:
# Recompute logdose as a plain float column: log10(dose) where the dose is
# positive, 0.0 otherwise. This is now done for all three splits; previously
# only the training split was refreshed, leaving test and OOD with whatever
# dtype the column had picked up earlier. The cast to float also covers the
# case where obs["dose"] is categorical.
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    split_dose = split_adata.obs["dose"].astype(float)
    # Non-positive doses are swapped for 1.0 so np.log10 never sees them.
    split_adata.obs["logdose"] = np.where(
        split_dose > 0.0, np.log10(split_dose.where(split_dose > 0.0, 1.0)), 0.0
    )

In [None]:
adata_train_final.obsm["ecfp_cell_line"] = np.concatenate(
    (adata_train_final.obsm["ecfp"], adata_train_final.obsm["cell_line_emb"]), axis=1
)
adata_test_final.obsm["ecfp_cell_line"] = np.concatenate(
    (adata_test_final.obsm["ecfp"], adata_test_final.obsm["cell_line_emb"]), axis=1
)
adata_ood_final.obsm["ecfp_cell_line"] = np.concatenate(
    (adata_ood_final.obsm["ecfp"], adata_ood_final.obsm["cell_line_emb"]), axis=1
)

In [None]:
adata_train_final.obsm["ecfp_cell_line_dose"] = np.concatenate(
    (adata_train_final.obsm["ecfp_cell_line"], adata_train_final.obs["dose"].values[:, None]), axis=1
)
adata_test_final.obsm["ecfp_cell_line_dose"] = np.concatenate(
    (adata_test_final.obsm["ecfp_cell_line"], adata_test_final.obs["dose"].values[:, None]), axis=1
)
adata_ood_final.obsm["ecfp_cell_line_dose"] = np.concatenate(
    (adata_ood_final.obsm["ecfp_cell_line"], adata_ood_final.obs["dose"].values[:, None]), axis=1
)

In [None]:
adata_train_final.obsm["ecfp_cell_line_logdose"] = np.concatenate(
    (adata_train_final.obsm["ecfp_cell_line"], adata_train_final.obs["logdose"].values[:, None]), axis=1
)
adata_test_final.obsm["ecfp_cell_line_logdose"] = np.concatenate(
    (adata_test_final.obsm["ecfp_cell_line"], adata_test_final.obs["logdose"].values[:, None]), axis=1
)
adata_ood_final.obsm["ecfp_cell_line_logdose"] = np.concatenate(
    (adata_ood_final.obsm["ecfp_cell_line"], adata_ood_final.obs["logdose"].values[:, None]), axis=1
)

In [None]:
adata_train_final.obsm["ecfp_cell_line_logdose_more_dose"] = np.concatenate(
    (adata_train_final.obsm["ecfp_cell_line"], np.tile(adata_train_final.obs["logdose"].values[:, None], (1, 100))),
    axis=1,
)
adata_test_final.obsm["ecfp_cell_line_logdose_more_dose"] = np.concatenate(
    (adata_test_final.obsm["ecfp_cell_line"], np.tile(adata_test_final.obs["logdose"].values[:, None], (1, 100))),
    axis=1,
)
adata_ood_final.obsm["ecfp_cell_line_logdose_more_dose"] = np.concatenate(
    (adata_ood_final.obsm["ecfp_cell_line"], np.tile(adata_ood_final.obs["logdose"].values[:, None], (1, 100))), axis=1
)

In [None]:
adata_train_final.obsm["ecfp_cell_line_dose_more_dose"] = np.concatenate(
    (adata_train_final.obsm["ecfp_cell_line"], np.tile(adata_train_final.obs["dose"].values[:, None], (1, 100))), axis=1
)
adata_test_final.obsm["ecfp_cell_line_dose_more_dose"] = np.concatenate(
    (adata_test_final.obsm["ecfp_cell_line"], np.tile(adata_test_final.obs["dose"].values[:, None], (1, 100))), axis=1
)
adata_ood_final.obsm["ecfp_cell_line_dose_more_dose"] = np.concatenate(
    (adata_ood_final.obsm["ecfp_cell_line"], np.tile(adata_ood_final.obs["dose"].values[:, None], (1, 100))), axis=1
)

In [None]:
adata_train_final.varm["X_train_mean"] = np.asarray(adata_train_final.varm["X_train_mean"])
adata_test_final.varm["X_train_mean"] = np.asarray(adata_test_final.varm["X_train_mean"])
adata_ood_final.varm["X_train_mean"] = np.asarray(adata_ood_final.varm["X_train_mean"])

In [None]:
adata_ood_final.uns = adata_test_final.uns.copy()

In [None]:
import os

# Write the three splits as h5ad files into output_dir (order: train, ood, test).
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
for split_adata, file_name in (
    (adata_train_final, "adata_train_biolord_split_all_300.h5ad"),
    (adata_ood_final, "adata_ood_biolord_split_all_300.h5ad"),
    (adata_test_final, "adata_test_biolord_split_all_300.h5ad"),
):
    split_adata.write(os.path.join(output_dir, file_name))

In [None]:
from sklearn.metrics import r2_score


def compute_r_squared(x: np.ndarray, y: np.ndarray) -> float:
    """Return ``r2_score`` between the column-wise means of ``x`` and ``y``.

    Both inputs are averaged over axis 0 first; the mean of ``x`` is passed as
    the first argument of ``r2_score`` and the mean of ``y`` as the second.
    """
    mean_x = np.mean(x, axis=0)
    mean_y = np.mean(y, axis=0)
    return r2_score(mean_x, mean_y)

In [None]:
decoded_test = np.matmul(adata_test_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)
compute_r_squared(adata_test_final.X.A, np.asarray(decoded_test + adata_test_final.varm["X_train_mean"].T))

In [None]:
decoded_ood = np.matmul(adata_ood_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)
compute_r_squared(adata_ood_final.X.A, np.asarray(decoded_ood + adata_ood_final.varm["X_train_mean"].T))

In [None]:
adata_ood_final.obs["condition"].value_counts()