In [1]:
import numpy as np
import pandas as pd
import pertpy
import scanpy as sc
from rdkit import Chem

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Record the pertpy version this notebook was executed with (dataset loader and
# compound annotation come from pertpy).
print(pertpy.__version__)

0.7.0


In [3]:
# Precomputed cell-line embeddings (one row per CCLE cell line).
CELL_LINE_EMB_CSV = (
    "/lustre/groups/ml01/workspace/alejandro.tejada/super_rad_project/"
    "cell_line_embeddings/cell_line_embedding_full_ccle_300.csv"
)
df_cell_line = pd.read_csv(CELL_LINE_EMB_CSV)

In [4]:
# Preview the cell-line embedding table.
df_cell_line

Unnamed: 0,stripped_cell_line_name,0,1,2,3,4,5,6,7,8,...,290,291,292,293,294,295,296,297,298,299
0,LC1SQSF,-18.392076,-4.417238,-20.063112,11.465690,-13.339292,-12.112496,35.471240,-0.270913,-9.146312,...,-1.125420,-2.111881,1.171622,0.312624,-4.815182,0.046210,-1.088446,1.115714,2.255941,0.758469
1,COGAR359,18.690002,-52.472620,-30.365618,48.330795,-1.021920,-19.068594,10.263748,18.269790,-7.957102,...,-8.222550,7.555310,-1.057298,-3.757104,-4.835751,2.893510,-2.442304,-0.242649,-2.353538,5.478947
2,COLO794,-20.296114,-40.405025,44.187954,14.443377,54.673800,19.321410,66.251015,17.002750,-7.278839,...,-1.242781,1.199354,0.235240,-2.573665,-0.173201,1.236485,0.292051,3.587131,0.715897,-2.135028
3,KKU213,-27.366976,65.398420,14.705721,-7.561102,-13.943193,-10.408412,36.653570,25.901730,-14.653162,...,1.666245,-1.350878,2.853838,-1.963892,-6.147074,1.655094,4.646910,-0.388469,-3.185148,-0.161350
4,RT4,-12.577552,53.457300,-30.118093,-8.491399,-0.911463,-0.323962,12.441322,2.255388,-20.927351,...,-0.842232,3.961688,-1.545963,-3.391504,-1.337374,0.594023,5.249398,-6.942966,0.880904,0.362034
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1401,TOLEDO,116.668144,-17.291414,-8.687904,-47.281456,-8.376022,-19.178888,14.364631,23.961452,1.910983,...,2.444952,0.755310,0.995932,-0.341458,-1.487758,-5.967835,-0.450100,-4.981550,-3.410032,1.127219
1402,KP363T,-13.191258,56.058205,18.988297,9.880114,10.289624,-20.659739,7.638542,36.926937,-4.758129,...,-0.800371,-1.895617,-5.030682,-0.387408,1.123831,1.471802,-3.277741,-0.741038,-3.711716,6.372540
1403,SSP25,-58.022533,-6.140889,37.952465,-7.738521,-9.254046,-8.542123,1.232968,23.845469,4.071693,...,-1.679806,-2.305954,-1.516615,0.437813,-0.213782,2.173585,-3.377421,-1.699599,-0.910430,-3.234072
1404,ECC2,-33.736560,68.016460,0.001811,-11.769738,2.804548,0.707982,3.129512,26.379341,-0.923006,...,-0.507274,-4.123481,0.603723,-2.813056,1.446870,-1.526545,-1.199487,3.615918,0.856891,-0.973312


In [5]:
# Embedding dimensions are the columns named "0", "1", ...; select them by name rather
# than by position so neither the cell-line name nor a leftover CSV index column
# ("Unnamed: 0") can end up inside the embedding vector.
emb_cols = [c for c in df_cell_line.columns if str(c).isdigit()]


def _lookup_cell_line_emb(name):
    """Return the embedding of cell line ``name`` as a 1-D float array.

    Uses the first matching row (as the previous ``.iloc[0]`` did) and raises a
    ValueError if the cell line is not present in ``df_cell_line``.
    """
    rows = df_cell_line.loc[df_cell_line["stripped_cell_line_name"] == name, emb_cols]
    if rows.empty:
        raise ValueError(f"no embedding row for cell line {name!r}")
    return rows.iloc[0].to_numpy(dtype=float)


mcf7_emb = _lookup_cell_line_emb("MCF7")
k562_emb = _lookup_cell_line_emb("K562")
a549_emb = _lookup_cell_line_emb("A549")

In [6]:
# Load sci-Plex 3 via pertpy's dataset loader (presumably raw counts, given the
# normalisation applied further down — confirm).
adata = pertpy.data.sciplex3_raw()

In [7]:
# Number of cells per cell line.
adata.obs["cell_type"].value_counts()

cell_type
MCF7    292010
K562    146752
A549    143015
Name: count, dtype: int64

In [8]:
adata.obs["perturbation"] = adata.obs["product_name"]
# Vectorised versions of what used to be row-wise `apply(axis=1)` calls over every cell.
# Cast to object first so a categorical column works with `.str` and NaN stays NaN
# (it is filtered out later via `drug.notna()`).
adata.obs["drug"] = adata.obs["product_name"].astype(object).str.replace(" ", "_", regex=False)
adata.obs["cell_line"] = adata.obs["cell_type"]
dose = adata.obs["dose"].astype(float)
# log10(dose) for treated cells; cells with dose <= 0 (vehicle) get 0.0.
adata.obs["logdose"] = np.log10(dose.where(dose > 0.0, 1.0))
# Condition key "<cell_type>_<drug>_<dose>" (e.g. "MCF7_Vehicle_0.0"); the DE step builds
# each cell line's control label in exactly this format.
adata.obs["condition"] = adata.obs["cell_type"].astype(str) + "_" + adata.obs["drug"] + "_" + dose.astype(str)

In [9]:
# Cell line -> float embedding vector. Extend this table to support more cell lines.
cell_line_embs = {
    "MCF7": mcf7_emb.astype("float"),
    "A549": a549_emb.astype("float"),
    "K562": k562_emb.astype("float"),
}


def get_cell_line_embedding(x):
    """Return the embedding of row ``x``'s cell line as a ``(1, d)`` float array.

    ``x`` is an obs row (anything indexable by ``"cell_line"``). Raises KeyError for a
    cell line with no embedding, instead of returning None as before.
    """
    return cell_line_embs[x["cell_line"]][None, :]


# Vectorised: index one stacked (n_cell_lines, d) matrix by each cell's cell-line code,
# instead of building one (1, d) array per cell with a row-wise apply.
_cell_line_codes = pd.Categorical(adata.obs["cell_line"], categories=list(cell_line_embs)).codes
if (_cell_line_codes < 0).any():
    _missing = sorted(set(adata.obs["cell_line"].astype(str)[_cell_line_codes < 0]))
    raise ValueError(f"no cell-line embedding for: {_missing}")
adata.obsm["cell_line_emb"] = np.stack(list(cell_line_embs.values()))[_cell_line_codes]

In [10]:
# Total number of cells before gene selection and filtering.
adata.n_obs

581777

In [11]:
# Library-size normalisation followed by log1p, both applied to adata.X in place
# (no separate counts layer is kept).
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [12]:
# Flag the top 2000 highly variable genes. NOTE(review): this runs on all cells,
# including the drugs later held out as OOD — confirm this leakage into the gene
# selection is acceptable.
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)

  disp_grouped = df.groupby("mean_bin")["dispersions"]


In [13]:
# Keep only the highly variable genes. The boolean column is used directly as the mask,
# and `.copy()` materialises the subset so later writes to obs/obsm don't act on an
# AnnData view (which triggers implicit copies and warnings).
adata = adata[:, adata.var["highly_variable"]].copy()

In [14]:
# Drop cells without a drug label; `.copy()` avoids continuing on an AnnData view.
adata = adata[adata.obs["drug"].notna()].copy()

In [15]:
# Look up PubChem metadata per compound; this adds pubchem_name, pubchem_ID and smiles
# to adata.obs. Compounds it cannot resolve get NA and are patched manually afterwards.
pertpy.md.Compound().annotate_compounds(adata)

[bold blue]There are 189 identifiers in `adata.obs`.However, 6 identifiers can't be found in the compound annotation,leading to the presence of NA values for their respective metadata.

- Please check again: 
- Vehicle
- MC1568
- Glesatinib?(MGCD265)
- Dacinostat (LAQ824)
- Bisindolylmaleimide IX (Ro 31-8220 Mesylate)
- ...


AnnData object with n_obs × n_vars = 581777 × 2000
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'perturbation', 'drug', 'cell_line', 'logdose', 'condition', 'pubchem_name', 'pubchem_ID', 'smiles'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'
    obsm: 'cell_line_emb'

In [16]:
# taken from pubchem
# SMILES (from PubChem) for compounds annotate_compounds could not resolve.
# Keys must equal the `perturbation` values exactly, including the "?" in
# "Glesatinib?(MGCD265)", which is how that product name appears in the data.
smiles_dict = dict(
    [
        ("Dacinostat (LAQ824)", "C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=C/C(=O)NO"),
        (
            "Glesatinib?(MGCD265)",
            "COCCNCC1=CN=C(C=C1)C2=CC3=NC=CC(=C3S2)OC4=C(C=C(C=C4)NC(=S)NC(=O)CC5=CC=C(C=C5)F)F",
        ),
        ("MC1568", "CN1C=C(C=C1/C=C/C(=O)NO)/C=C/C(=O)C2=CC(=CC=C2)F"),
        (
            "Ivosidenib (AG-120)",
            "C1CC(=O)N([C@@H]1C(=O)N(C2=CC(=CN=C2)F)[C@@H](C3=CC=CC=C3Cl)C(=O)NC4CC(C4)(F)F)C5=NC=CC(=C5)C#N",
        ),
        (
            "Bisindolylmaleimide IX (Ro 31-8220 Mesylate)",
            "CN1C=C(C2=CC=CC=C21)C3=C(C(=O)NC3=O)C4=CN(C5=CC=CC=C54)CCCSC(=N)N",
        ),
    ]
)

In [17]:
# SMILES coverage after annotate_compounds; NaN = compound not found.
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                     27103
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                             6627
COC1=CC=CC(=C1N)C2=CC(=O)C3=CC=CC=C3O2                                   3763
C1C(C1N)C2=CC=CC=C2.Cl                                                   3744
CC(C1=CC=CC=C1)NC(=O)C(=CC2=NC(=CC=C2)Br)C#N                             3722
                                                                        ...  
C1CC1NC(=O)NC2=C(NN=C2)C3=NC4=C(N3)C=C(C=C4)CN5CCOCC5                    1795
CC1CCCC2(C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C)C     1481
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl               1407
CC1CCCC2C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C        1221
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=NC=CN=C4.[Br-]              770
Name: count, Length: 183, dtype: int64

In [18]:
# Patch in the manually curated SMILES (smiles_dict), keyed by perturbation name, and
# keep the annotated SMILES everywhere else. Vectorised instead of a row-wise apply.
smiles_override = adata.obs["perturbation"].astype(object).map(smiles_dict)
adata.obs["smiles"] = smiles_override.where(smiles_override.notna(), adata.obs["smiles"].astype(object))

In [19]:
# Re-check coverage; the remaining NaNs should be vehicle (dose 0) cells only.
adata.obs["smiles"].value_counts(dropna=False)

smiles
NaN                                                                     13004
C1CCC(C1)C(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3                             6627
COC1=CC=CC(=C1N)C2=CC(=O)C3=CC=CC=C3O2                                   3763
C1C(C1N)C2=CC=CC=C2.Cl                                                   3744
CC(C1=CC=CC=C1)NC(=O)C(=CC2=NC(=CC=C2)Br)C#N                             3722
                                                                        ...  
C1CC1NC(=O)NC2=C(NN=C2)C3=NC4=C(N3)C=C(C=C4)CN5CCOCC5                    1795
CC1CCCC2(C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C)C     1481
CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC=C4Cl)O)O.Cl               1407
CC1CCCC2C(O2)CC(OC(=O)CC(C(C(=O)C(C1O)C)(C)C)O)C(=CC3=CSC(=N3)C)C        1221
CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=NC=CN=C4.[Br-]              770
Name: count, Length: 188, dtype: int64

In [20]:
# Doses of cells still lacking a SMILES: expected to be 0.0 (vehicle) only.
adata[adata.obs["smiles"].isnull()].obs.dose.value_counts()

dose
0.0    13004
Name: count, dtype: int64

In [21]:
# Dose distribution of the cells that do have a SMILES (treated cells).
adata[~adata.obs["smiles"].isnull()].obs.dose.value_counts()

dose
10.0       153013
100.0      147670
1000.0     141828
10000.0    126262
Name: count, dtype: int64

In [22]:
from rdkit.Chem import AllChem


def get_fp(smiles, radius=4, nBits=1024):
    """Compute RDKit's hashed Morgan fingerprint for a SMILES string.

    Parameters
    ----------
    smiles : str
        Molecule as SMILES.
    radius : int
        Morgan radius passed to ``AllChem.GetHashedMorganFingerprint``.
    nBits : int
        Fingerprint length passed to ``AllChem.GetHashedMorganFingerprint``.

    Returns
    -------
    The fingerprint object from ``AllChem.GetHashedMorganFingerprint``, or the string
    ``"invalid"`` if the SMILES cannot be parsed or sanitised.
    """
    # Parse without sanitisation so that parse failures (None) and sanitisation
    # failures (exception) are both reported as "invalid".
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if mol is None:
        return "invalid"
    try:
        Chem.SanitizeMol(mol)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
        return "invalid"
    return AllChem.GetHashedMorganFingerprint(mol, radius=radius, nBits=nBits)

In [23]:
# Fingerprint vector per unique SMILES.
smiles_to_fp = {}
for sm in adata.obs["smiles"].unique():
    # Missing SMILES (NaN) get no fingerprint; those cells keep all-zero features later.
    if not isinstance(sm, str):
        continue
    fp = get_fp(sm)
    # get_fp signals failure with the string "invalid"; list("invalid") would silently
    # become a 7-element character array, so fail loudly instead.
    if isinstance(fp, str):
        raise ValueError(f"could not compute a fingerprint for SMILES {sm!r}")
    smiles_to_fp[sm] = np.array(list(fp))

In [24]:
# One row per SMILES, one column per fingerprint position.
features_df = pd.DataFrame.from_dict(smiles_to_fp, orient="index")

In [25]:
# Per-cell fingerprint matrix: row i holds the fingerprint of cell i's SMILES; cells whose
# SMILES has no fingerprint (e.g. NaN) keep an all-zero row. A single index lookup
# replaces one `isin` scan over all cells per SMILES.
fp_row = features_df.index.get_indexer(adata.obs["smiles"].astype(object))
has_fp = fp_row >= 0
features_cells = np.zeros((adata.shape[0], features_df.shape[1]))
features_cells[has_fp] = features_df.to_numpy()[fp_row[has_fp]]

In [26]:
# Store per-cell drug fingerprints (all-zero rows for cells without a SMILES).
adata.obsm["ecfp"] = features_cells

In [27]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    """Rank genes per covariate category against that category's control group.

    For every category of ``adata.obs[covariate]`` (e.g. each cell line), the cells of
    that category are passed to ``sc.tl.rank_genes_groups``, grouped by
    ``adata.obs[groupby]``, with reference ``"<category>_<control_group>_0.0"``.

    Parameters
    ----------
    adata : AnnData
        Data; ``adata.obs[groupby]`` labels must follow the
        ``"<category>_<drug>_<dose>"`` scheme so the reference label exists.
    groupby : str
        obs column defining the groups to rank.
    control_group : str
        Name of the control perturbation (e.g. ``"Vehicle"``).
    covariate : str
        obs column to stratify by.
    n_genes : int
        Number of top genes kept per group.
    rankby_abs : bool
        Forwarded to ``sc.tl.rank_genes_groups``.
    key_added : str
        Key in ``adata.uns`` under which the gene dictionary is stored.
    return_dict : bool
        If True, also return the dictionary.

    Returns
    -------
    dict or None
        ``{group: [gene names]}`` if ``return_dict`` is True, else None. The same
        dictionary is always written to ``adata.uns[key_added]``.
    """
    gene_dict = {}
    for cov_cat in adata.obs[covariate].unique():
        # Reference label built from the `control_group` argument (previously
        # hard-coded to "Vehicle"); controls carry dose 0.0 in the label.
        control_group_cov = "_".join([str(cov_cat), control_group, "0.0"])
        # Explicit copy: rank_genes_groups writes into `.uns`, which on an AnnData view
        # causes an implicit copy plus a warning for every category.
        adata_cov = adata[adata.obs[covariate] == cov_cat].copy()

        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # One column per group: the ranked gene names.
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata):
    """Flag control cells, categorise string obs columns and compute per-cell-type DEGs.

    Adds ``adata.obs["control"]`` (1 for ``perturbation == "Vehicle"``, else 0),
    converts string-valued obs columns to ``category`` (scanpy groups by a categorical
    ``condition``), and stores the top 50 DE genes of every condition versus its cell
    type's vehicle control in ``adata.uns["rank_genes_groups_cov_all"]``.

    Returns the same, mutated ``adata``.
    """
    adata.obs["control"] = (adata.obs["perturbation"] == "Vehicle").astype(int)
    # Only string columns become categorical. A blanket `adata.obs.astype("category")`
    # would also turn numeric columns (dose, logdose, scores, ...) into categoricals.
    for col in adata.obs.select_dtypes(include=["object", "string"]).columns:
        adata.obs[col] = adata.obs[col].astype("category")
    rank_genes_groups_by_cov(
        adata,
        groupby="condition",
        covariate="cell_type",
        control_group="Vehicle",
        n_genes=50,
        key_added="rank_genes_groups_cov_all",
    )
    return adata

In [28]:
# Drop categories no cell uses. `remove_unused_categories` returns a new Series rather
# than modifying in place, so the result must be assigned back (previously discarded).
for col in adata.obs.select_dtypes(include=["category"]).columns:
    adata.obs[col] = adata.obs[col].cat.remove_unused_categories()

In [29]:
# Adds obs["control"] and stores per-condition DE genes in
# adata.uns["rank_genes_groups_cov_all"].
adata = get_DE_genes(adata)



  adata.uns[key_added] = {}
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global



  adata.uns[key_added] = {}
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global



  adata.uns[key_added] = {}
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global_indices]
  self.stats[group_name, 'scores'] = scores[global_indices]
  self.stats[group_name, 'pvals'] = pvals[global_indices]
  self.stats[group_name, 'pvals_adj'] = pvals_adj[global_indices]
  self.stats[group_name, 'logfoldchanges'] = np.log2(
  self.stats[group_name, 'names'] = self.var_names[global

In [30]:
# First draft of the out-of-distribution (held-out) drugs, by short name. Most of these
# do not equal a `drug` label exactly; the following cells look up the full labels.
ood_drugs = (
    "Dacinostat Givinostat Belinostat Hesperadin Quisinostat "
    "Alvespimycin Tanespimycin TAK-901 Flavopiridol"
).split()

In [31]:
# Which short names already match a `drug` label exactly?
[d for d in ood_drugs if d in adata.obs["drug"].unique()]

['Hesperadin', 'TAK-901']

In [32]:
# Full `drug` label(s) for Dacinostat.
[el for el in adata.obs["drug"].unique() if "acinostat" in el]

['Dacinostat_(LAQ824)', 'Pracinostat_(SB939)']

In [33]:
# Full `drug` label for Givinostat.
[el for el in adata.obs["drug"].unique() if "ivinostat" in el]

['Givinostat_(ITF2357)']

In [34]:
# Full `drug` label for Belinostat.
[el for el in adata.obs["drug"].unique() if "elinostat" in el]

['Belinostat_(PXD101)']

In [35]:
# Full `drug` label for Quisinostat.
[el for el in adata.obs["drug"].unique() if "Quisinostat" in el]

['Quisinostat_(JNJ-26481585)_2HCl']

In [36]:
# Full `drug` label for Alvespimycin.
[el for el in adata.obs["drug"].unique() if "Alvespimycin" in el]

['Alvespimycin_(17-DMAG)_HCl']

In [37]:
# Full `drug` label for Tanespimycin.
[el for el in adata.obs["drug"].unique() if "Tanespimycin" in el]

['Tanespimycin_(17-AAG)']

In [38]:
# Full `drug` label for Flavopiridol.
[el for el in adata.obs["drug"].unique() if "Flavopiridol" in el]

['Flavopiridol_HCl']

In [39]:
# Final OOD drug list using exact `drug` labels (spaces replaced by underscores):
# the two names that matched directly, then the resolved full labels.
ood_drugs = ["Hesperadin", "TAK-901"] + [
    "Dacinostat_(LAQ824)",
    "Givinostat_(ITF2357)",
    "Belinostat_(PXD101)",
    "Quisinostat_(JNJ-26481585)_2HCl",
    "Alvespimycin_(17-DMAG)_HCl",
    "Tanespimycin_(17-AAG)",
    "Flavopiridol_HCl",
]

In [40]:
# Number of held-out (OOD) drugs.
len(ood_drugs)

9

In [41]:
# Mark OOD cells: `is_ood` is a boolean flag; `ood` holds the drug label for OOD cells
# and "not ood" otherwise. Vectorised instead of two row-wise applies. `drug` may be
# categorical, so cast to object before substituting a new label.
adata.obs["is_ood"] = adata.obs["drug"].isin(ood_drugs)
adata.obs["ood"] = adata.obs["drug"].astype(object).where(adata.obs["is_ood"], "not ood").astype("category")

In [42]:
# All cells treated with an OOD drug. Materialised with `.copy()` because obs is written
# to later (assigning into an AnnData view triggers an implicit copy and a warning).
adata_ood = adata[adata.obs["is_ood"]].copy()

In [43]:
# In-distribution cells (later split into train/test). Materialised with `.copy()`
# because obs["split"] is assigned later, which on a view triggers an implicit copy.
adata_train = adata[~adata.obs["is_ood"]].copy()

In [44]:
# Per-condition train/test split of the in-distribution cells.
N_TEST_CONTROL = 500  # vehicle cells held out per "<cell>_Vehicle_0.0" condition
N_TEST_PERTURBED = 50  # cells held out per perturbed condition ...
MIN_CELLS_FOR_TEST = 100  # ... if it has at least this many cells; otherwise all train

rng = np.random.default_rng(0)
obs_train = adata_train.obs
split_dfs = []
for cond in obs_train["condition"].unique():
    df = obs_train.loc[obs_train["condition"] == cond, ["condition"]].copy()
    n_cells = len(df)
    if "Vehicle" in cond:
        n_test = N_TEST_CONTROL
    elif n_cells >= MIN_CELLS_FOR_TEST:
        n_test = N_TEST_PERTURBED
    else:
        n_test = 0
    split = np.full(n_cells, "train", dtype=object)
    # rng.choice is only called when cells are held out, keeping the random stream
    # (and therefore the split) identical to the previous implementation.
    if n_test:
        split[rng.choice(n_cells, n_test, replace=False)] = "test"
    df["split"] = split
    split_dfs.append(df)

In [45]:
# Stack the per-condition split tables into one frame indexed by cell.
df_concat = pd.concat(split_dfs, axis="index")

In [46]:
# Sanity check: every in-distribution cell got a split label (the two numbers must match).
len(df_concat), adata_train.n_obs

(561285, 561285)

In [47]:
# Align the split labels to adata_train's cell order by obs name. Select a Series (not a
# one-column DataFrame) and check that every cell received a label.
split_labels = df_concat["split"].reindex(adata_train.obs_names)
assert split_labels.notna().all(), "some cells did not receive a train/test label"
adata_train.obs["split"] = split_labels.to_numpy()

  adata_train.obs["split"] = df_concat[["split"]]


In [48]:
# Sanity check: the OOD set contains no vehicle controls. Conditions are named
# "<cell>_Vehicle_0.0", so the old comparison with "control" could never match.
int((adata_ood.obs["perturbation"] == "Vehicle").sum())

0

In [49]:
# Every OOD-drug cell is labelled "ood" (no train/test subsplit for these drugs).
adata_ood.obs["split"] = "ood"

  adata_ood.obs["split"] = "ood"


In [50]:
import anndata

# Materialise train/test as copies (not views) so the later writes to .varm/.layers/
# .obsm/.obs don't each trigger an implicit copy with a warning.
adata_train_final = adata_train[adata_train.obs["split"] == "train"].copy()
adata_test_final = adata_train[adata_train.obs["split"] == "test"].copy()
# OOD evaluation set: all OOD-drug cells plus the held-out vehicle controls.
adata_ood_final = anndata.concat((adata_ood, adata_test_final[adata_test_final.obs["perturbation"] == "Vehicle"]))

In [51]:
# Cells per condition in the final training split.
adata_train_final.obs["condition"].value_counts()

condition
MCF7_Vehicle_0.0                                             5858
K562_Vehicle_0.0                                             2859
A549_Vehicle_0.0                                             2787
MCF7_Mesna__100.0                                             657
MCF7_Fasudil_(HA-1077)_HCl_10000.0                            543
                                                             ... 
A549_YM155_(Sepantronium_Bromide)_10000.0                      20
MCF7_YM155_(Sepantronium_Bromide)_10000.0                      18
MCF7_YM155_(Sepantronium_Bromide)_100.0                        13
MCF7_YM155_(Sepantronium_Bromide)_1000.0                       11
K562_Bisindolylmaleimide_IX_(Ro_31-8220_Mesylate)_10000.0       7
Name: count, Length: 2151, dtype: int64

In [52]:
# Per-gene mean over training cells, stored as a plain (n_vars, 1) ndarray. Sparse
# `.mean(axis=0)` returns an np.matrix; np.asarray + reshape avoids keeping that type.
adata_train_final.varm["X_train_mean"] = np.asarray(adata_train_final.X.mean(axis=0)).reshape(-1, 1)

  adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T


In [53]:
from scipy.sparse import csr_matrix

# (1, n_vars) row of training means; broadcasting subtracts it from every cell.
train_mean = adata_train_final.varm["X_train_mean"].T
# `.toarray()` replaces the sparse `.A` shorthand (deprecated in SciPy 1.11, removed in
# 1.14). NOTE(review): the centred matrix is essentially dense, so CSR storage saves no memory.
adata_train_final.layers["centered_X"] = csr_matrix(adata_train_final.X.toarray() - train_mean)

In [54]:
# Keep the log1p-normalised expression in a layer, then temporarily put the
# training-mean-centred matrix into X so PCA runs on it.
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
adata_train_final.X = adata_train_final.layers["centered_X"]

In [55]:
# 30-component PCA on the already-centred X (hence zero_center=False); the loadings
# (varm["PCs"]) are reused to project test and OOD cells.
sc.pp.pca(adata_train_final, zero_center=False, n_comps=30)

In [56]:
# Restore the log1p-normalised expression as X.
adata_train_final.X = adata_train_final.layers["X_log1p"]

In [57]:
# Test/OOD cells are centred with the *training* mean (X_train_mean only exists on
# adata_train_final), so copy it over.
adata_ood_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]
adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

  adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]


In [58]:
# Centre test and OOD cells with the training mean. `.toarray()` replaces the sparse
# `.A` shorthand (deprecated in SciPy 1.11, removed in 1.14).
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.toarray() - train_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.toarray() - train_mean)

In [59]:
# Project test and OOD cells onto the training PCs. `.toarray()` replaces the sparse
# `.A` shorthand (deprecated in SciPy 1.11, removed in 1.14).
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].toarray(), adata_train_final.varm["PCs"])

In [60]:
# log10(dose), with dose <= 0 (vehicle) mapped to 0.0. Recomputed as a float column for
# all three splits. Previously only train was recomputed, so test/ood kept whatever dtype
# `logdose` had after earlier obs conversions (categorical after a blanket cast).
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    split_dose = split_adata.obs["dose"].astype(float)
    split_adata.obs["logdose"] = np.log10(split_dose.where(split_dose > 0.0, 1.0))

In [61]:
# Condition features = [drug fingerprint | cell-line embedding], for every split.
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    split_adata.obsm["ecfp_cell_line"] = np.hstack((split_adata.obsm["ecfp"], split_adata.obsm["cell_line_emb"]))

In [62]:
# [fingerprint | cell-line embedding | dose] for every split. `to_numpy(dtype=float)`
# also handles a categorical `dose` column, where `.values[:, None]` relies on 2-D
# indexing of a pandas Categorical.
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    dose_col = split_adata.obs["dose"].to_numpy(dtype=float)[:, None]
    split_adata.obsm["ecfp_cell_line_dose"] = np.concatenate((split_adata.obsm["ecfp_cell_line"], dose_col), axis=1)

In [63]:
# [fingerprint | cell-line embedding | log10 dose] for every split. `to_numpy(dtype=float)`
# also handles a categorical `logdose` column.
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    logdose_col = split_adata.obs["logdose"].to_numpy(dtype=float)[:, None]
    split_adata.obsm["ecfp_cell_line_logdose"] = np.concatenate(
        (split_adata.obsm["ecfp_cell_line"], logdose_col), axis=1
    )

In [64]:
# The log dose is repeated N_DOSE_REPEATS times, presumably to up-weight the dose signal
# against the many fingerprint/embedding dimensions — TODO confirm intent.
N_DOSE_REPEATS = 100
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    logdose_col = split_adata.obs["logdose"].to_numpy(dtype=float)[:, None]
    split_adata.obsm["ecfp_cell_line_logdose_more_dose"] = np.concatenate(
        (split_adata.obsm["ecfp_cell_line"], np.repeat(logdose_col, N_DOSE_REPEATS, axis=1)),
        axis=1,
    )

In [65]:
# Same as the log-dose variant but with the raw dose repeated N_DOSE_REPEATS times.
N_DOSE_REPEATS = 100
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    dose_col = split_adata.obs["dose"].to_numpy(dtype=float)[:, None]
    split_adata.obsm["ecfp_cell_line_dose_more_dose"] = np.concatenate(
        (split_adata.obsm["ecfp_cell_line"], np.repeat(dose_col, N_DOSE_REPEATS, axis=1)),
        axis=1,
    )

In [66]:
# Store the training mean as a plain ndarray in each split (it may have been an
# np.matrix, presumably from the sparse `.mean`), so it serialises cleanly.
for split_adata in (adata_train_final, adata_test_final, adata_ood_final):
    split_adata.varm["X_train_mean"] = np.asarray(split_adata.varm["X_train_mean"])

In [67]:
# Give the OOD set the test set's uns (including the DE-gene dict from get_DE_genes),
# presumably because anndata.concat did not carry uns over. NOTE(review): `.copy()` may
# be shallow, so nested dicts could be shared between the two objects — confirm.
adata_ood_final.uns = adata_test_final.uns.copy()

In [68]:
import os

output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
# One h5ad per split; the "_30" suffix matches n_comps=30 in the PCA step.
for split_name, split_adata in (
    ("train", adata_train_final),
    ("ood", adata_ood_final),
    ("test", adata_test_final),
):
    split_adata.write(os.path.join(output_dir, f"adata_{split_name}_biolord_split_all_30.h5ad"))

In [69]:
from sklearn.metrics import r2_score


def compute_r_squared(x: np.ndarray, y: np.ndarray) -> float:
    """R^2 between the column means of two (cells x features) matrices.

    Rows are averaged first, so this scores how well the mean profile of ``y``
    reproduces that of ``x``. ``x`` is passed to ``r2_score`` as ground truth and
    ``y`` as prediction.
    """
    mean_true = np.mean(x, axis=0)
    mean_pred = np.mean(y, axis=0)
    return r2_score(mean_true, mean_pred)

In [70]:
# Sanity check: decode test PCA scores back to gene space (scores @ PCs.T + training mean)
# and compare mean profiles with the data. `.toarray()` replaces the sparse `.A`
# shorthand (deprecated in SciPy 1.11, removed in 1.14).
decoded_test = np.matmul(adata_test_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)
compute_r_squared(adata_test_final.X.toarray(), np.asarray(decoded_test + adata_test_final.varm["X_train_mean"].T))

0.9999532075726097

In [71]:
# Same reconstruction check for the OOD set, still decoded with the training PCs and mean.
# `.toarray()` replaces the sparse `.A` shorthand (deprecated in SciPy 1.11, removed in 1.14).
decoded_ood = np.matmul(adata_ood_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)
compute_r_squared(adata_ood_final.X.toarray(), np.asarray(decoded_ood + adata_ood_final.varm["X_train_mean"].T))

0.9955817839655192

In [72]:
# Cells per condition in the OOD evaluation set (OOD drugs plus held-out vehicle cells).
adata_ood_final.obs["condition"].value_counts()

condition
MCF7_Givinostat_(ITF2357)_10.0             501
MCF7_Vehicle_0.0                           500
A549_Vehicle_0.0                           500
K562_Vehicle_0.0                           500
MCF7_Givinostat_(ITF2357)_100.0            491
                                          ... 
MCF7_Flavopiridol_HCl_1000.0                31
MCF7_Flavopiridol_HCl_10000.0               19
MCF7_Alvespimycin_(17-DMAG)_HCl_10000.0     16
K562_Flavopiridol_HCl_1000.0                15
K562_Flavopiridol_HCl_10000.0                6
Name: count, Length: 111, dtype: int64