In [1]:
import numpy as np
import pandas as pd
import pertpy
import scanpy as sc
from rdkit import Chem
import anndata
from cfp import preprocessing as cfpp

  from optuna import progress_bar as pbar_module


In [2]:
print(pertpy.__version__)

0.9.4


In [3]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex"

In [4]:
adata = pertpy.data.combosciplex()

In [5]:
adata.obs["condition"].value_counts()

condition
Dacinostat+PCI-34051         3298
SRT3025+Cediranib            3016
Givinostat+Cediranib         2783
control+SRT2104              2756
Givinostat+Curcumin          2736
Givinostat+Sorafenib         2734
Givinostat+Carmofur          2692
Givinostat+Crizotinib        2662
Givinostat+Dasatinib         2421
Givinostat+SRT2104           2353
control+Dasatinib            2343
Givinostat+SRT1720           2260
Panobinostat+Curcumin        2244
Cediranib+PCI-34051          2161
Panobinostat+Sorafenib       2013
Panobinostat+SRT2104         1971
Panobinostat+Dasatinib       1955
Dacinostat+Danusertib        1939
Panobinostat+SRT3025         1889
control+Dacinostat           1869
Panobinostat+SRT1720         1826
Panobinostat+PCI-34051       1814
control+Givinostat           1682
Panobinostat+Crizotinib      1641
control+Panobinostat         1578
control+control              1451
Givinostat+Tanespimycin      1310
Dacinostat+Dasatinib         1231
Panobinostat+Alvespimycin     996
cont

In [6]:
adata.obs["Drug2"].value_counts()

Drug2
Dasatinib       7950
PCI-34051       7273
SRT2104         7080
Cediranib       5799
Curcumin        4980
Sorafenib       4747
Crizotinib      4303
SRT1720         4086
Carmofur        2692
Alvespimycin    2274
Danusertib      1939
SRT3025         1889
Dacinostat      1869
Givinostat      1682
Panobinostat    1578
control         1451
Tanespimycin    1310
Pirarubicin      476
Name: count, dtype: int64

In [7]:
adata.obs["Drug1"].value_counts()

Drug1
Givinostat      21951
Panobinostat    16349
control         12437
Dacinostat       6468
SRT3025          3016
Cediranib        2161
SRT2104           520
Alvespimycin      476
Name: count, dtype: int64

In [8]:
adata.layers["counts"][0].data.max()

183.0

In [9]:
adata.X = adata.layers["counts"].copy()

In [10]:
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [11]:
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)

In [12]:
adata = adata[:, adata.var["highly_variable"] == True]

In [13]:
adata.X = adata.layers["counts"].copy()
sc.pp.normalize_total(adata, target_sum=10_000)
sc.pp.log1p(adata)



  view_to_actual(adata)


In [14]:
adata.obs["condition"] = adata.obs.apply(
    lambda x: "control" if x["condition"] == "control+control" else x["condition"], axis=1
)

In [15]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    gene_dict = {}
    cov_categories = adata.obs[covariate].unique()
    for cov_cat in cov_categories:
        # name of the control group in the groupby obs column
        control_group_cov = control_group  # "_".join([cov_cat, control_group])
        # subset adata to cells belonging to a covariate category
        adata_cov = adata[adata.obs[covariate] == cov_cat]
        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # add entries to dictionary of gene sets
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata):
    adata.obs.loc[:, "control"] = adata.obs.condition.apply(lambda x: 1 if x == "control" else 0)
    adata.obs = adata.obs.astype("category")
    rank_genes_groups_by_cov(
        adata,
        groupby="condition",
        covariate="cell_type",
        control_group="control",
        n_genes=50,
        key_added="rank_genes_groups_cov_all",
    )
    return adata

In [16]:
adata = get_DE_genes(adata)

  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global

In [17]:
assert "rank_genes_groups_cov_all" in adata.uns.keys()

In [18]:
adata.obs["cell_line"] = adata.obs["cell_type"]

In [19]:
set(adata.obs["Drug1"].unique()) - set(adata.obs["Drug2"].unique())

set()

In [20]:
adata_dummy = adata[adata.obs_names.isin(adata.obs.drop_duplicates(subset=["Drug2"]).index)]

In [21]:
list(adata.obs["Drug2"].unique())

['Panobinostat',
 'PCI-34051',
 'SRT1720',
 'SRT3025',
 'Dacinostat',
 'Sorafenib',
 'Cediranib',
 'Givinostat',
 'Danusertib',
 'Dasatinib',
 'Tanespimycin',
 'Carmofur',
 'SRT2104',
 'Crizotinib',
 'Pirarubicin',
 'control',
 'Alvespimycin',
 'Curcumin']

In [22]:
# taken from pubchem
drug_to_smiles = {
    "Panobinostat": "CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO",
    "PCI-34051": "COC1=CC=C(C=C1)CN2C=CC3=C2C=C(C=C3)C(=O)NO",
    "SRT1720": "C1CN(CCN1)CC2=CSC3=NC(=CN23)C4=CC=CC=C4NC(=O)C5=NC6=CC=CC=C6N=C5",
    "SRT3025": "COCCCC1=C(N=C(S1)C2=CC=CC=C2)C(=O)NC3=CC=CC=C3C4=NC5=C(S4)N=CC(=C5)CN6CCCC6",
    "Dacinostat": "C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=C/C(=O)NO",
    "Sorafenib": "CNC(=O)C1=NC=CC(=C1)OC2=CC=C(C=C2)NC(=O)NC3=CC(=C(C=C3)Cl)C(F)(F)F",
    "Cediranib": "CC1=CC2=C(N1)C=CC(=C2F)OC3=NC=NC4=CC(=C(C=C43)OC)OCCCN5CCCC5",
    "Givinostat": "CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C(C=C3)C(=O)NO",
    "Danusertib": "CN1CCN(CC1)C2=CC=C(C=C2)C(=O)NC3=NNC4=C3CN(C4)C(=O)[C@@H](C5=CC=CC=C5)OC",
    "Dasatinib": "CC1=C(C(=CC=C1)Cl)NC(=O)C2=CN=C(S2)NC3=CC(=NC(=N3)C)N4CCN(CC4)CCO",
    "Tanespimycin": r"C[C@H]1C[C@@H]([C@@H]([C@H](/C=C(/[C@@H]([C@H](/C=C\C=C(\C(=O)NC2=CC(=O)C(=C(C1)C2=O)NCC=C)/C)OC)OC(=O)N)\C)C)O)OC",
    "Carmofur": "CCCCCCNC(=O)N1C=C(C(=O)NC1=O)F",
    "SRT2104": "CC1=C(SC(=N1)C2=CN=CC=C2)C(=O)NC3=CC=CC=C3C4=CN5C(=CSC5=N4)CN6CCOCC6",
    "Crizotinib": "C[C@H](C1=C(C=CC(=C1Cl)F)Cl)OC2=C(N=CC(=C2)C3=CN(N=C3)C4CCNCC4)N",
    "Pirarubicin": "C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O[C@@H]6CCCCO6",
    "Alvespimycin": r"C[C@H]1C[C@@H]([C@@H]([C@H](/C=C(/[C@@H]([C@H](/C=C\C=C(\C(=O)NC2=CC(=O)C(=C(C1)C2=O)NCCN(C)C)/C)OC)OC(=O)N)\C)C)O)OC",
    "Curcumin": "COC1=C(C=CC(=C1)/C=C/C(=O)CC(=O)/C=C/C2=CC(=C(C=C2)O)OC)O",
    "control": None,
}

In [23]:
adata.obs["smiles_drug_1"] = adata.obs["Drug1"].map(drug_to_smiles)
adata.obs["smiles_drug_2"] = adata.obs["Drug2"].map(drug_to_smiles)

In [24]:
from rdkit.Chem import AllChem


def get_fp(smiles, radius=4, nBits=1024):
    m = Chem.MolFromSmiles(smiles, sanitize=False)
    if m is None:
        return "invalid"
    else:
        try:
            Chem.SanitizeMol(m)
        except:
            return "invalid"
    return AllChem.GetHashedMorganFingerprint(m, radius=radius, nBits=nBits)

In [25]:
drug_to_fp = {}
for drug, sm in drug_to_smiles.items():
    if not isinstance(sm, str):
        continue
    drug_to_fp[drug] = np.array(list(get_fp(sm)))



In [26]:
len(drug_to_fp), len(drug_to_smiles)

(17, 18)

In [27]:
features_df = pd.DataFrame.from_dict(drug_to_fp).T

In [28]:
features_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
Panobinostat,0,0,0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
PCI-34051,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
SRT1720,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,0,0,0
SRT3025,0,0,1,0,2,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
Dacinostat,0,0,0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
Sorafenib,0,0,0,0,0,0,0,0,0,1,...,1,0,0,0,0,1,0,0,0,0
Cediranib,0,0,0,0,2,0,0,0,0,0,...,0,0,0,1,1,0,0,0,0,1
Givinostat,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Danusertib,0,1,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Dasatinib,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [29]:
features_cells_drug_1 = np.zeros((adata.shape[0], features_df.shape[1]))
features_cells_drug_2 = np.zeros((adata.shape[0], features_df.shape[1]))
for mol, ecfp in features_df.iterrows():
    print(mol)
    features_cells_drug_1[adata.obs["Drug1"].isin([mol])] = ecfp.values
    features_cells_drug_2[adata.obs["Drug2"].isin([mol])] = ecfp.values

Panobinostat
PCI-34051
SRT1720
SRT3025
Dacinostat
Sorafenib
Cediranib
Givinostat
Danusertib
Dasatinib
Tanespimycin
Carmofur
SRT2104
Crizotinib
Pirarubicin
Alvespimycin
Curcumin


In [30]:
adata.obsm["ecfp_drug_1"] = features_cells_drug_1
adata.obsm["ecfp_drug_2"] = features_cells_drug_2

In [31]:
adata_dummy = adata[adata.obs_names.isin(adata.obs.drop_duplicates(subset=["Drug2"]).index)]

In [32]:
adata.obs["condition"].value_counts()

condition
Dacinostat+PCI-34051         3298
SRT3025+Cediranib            3016
Givinostat+Cediranib         2783
control+SRT2104              2756
Givinostat+Curcumin          2736
Givinostat+Sorafenib         2734
Givinostat+Carmofur          2692
Givinostat+Crizotinib        2662
Givinostat+Dasatinib         2421
Givinostat+SRT2104           2353
control+Dasatinib            2343
Givinostat+SRT1720           2260
Panobinostat+Curcumin        2244
Cediranib+PCI-34051          2161
Panobinostat+Sorafenib       2013
Panobinostat+SRT2104         1971
Panobinostat+Dasatinib       1955
Dacinostat+Danusertib        1939
Panobinostat+SRT3025         1889
control+Dacinostat           1869
Panobinostat+SRT1720         1826
Panobinostat+PCI-34051       1814
control+Givinostat           1682
Panobinostat+Crizotinib      1641
control+Panobinostat         1578
control                      1451
Givinostat+Tanespimycin      1310
Dacinostat+Dasatinib         1231
Panobinostat+Alvespimycin     996
cont

In [33]:
df_conds = adata.obs.drop_duplicates(subset=["condition"])

In [34]:
at_least_twice = (
    set(df_conds["Drug1"].value_counts()[df_conds["Drug1"].value_counts() >= 2].index)
    | set(df_conds["Drug2"].value_counts()[df_conds["Drug2"].value_counts() >= 2].index)
    | (
        set(df_conds["Drug1"].value_counts()[df_conds["Drug1"].value_counts() >= 1].index).intersection(
            set(df_conds["Drug2"].value_counts()[df_conds["Drug2"].value_counts() >= 1].index)
        )
    )
)

In [35]:
filtered_df = df_conds[
    (df_conds["Drug1"].isin(at_least_twice)) & (df_conds["Drug2"].isin(at_least_twice))
]

In [36]:
filtered_df

Unnamed: 0_level_0,sample,Size_Factor,n.umi,RT_well,Drug1,Drug2,Well,n_genes,n_genes_by_counts,total_counts,...,leiden,condition,pathway1,pathway2,split,control,cell_type,cell_line,smiles_drug_1,smiles_drug_2
Cell,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A01_A02_RT_BC_10_Lig_BC_18,sciPlex_theis,0.533816,1433,RT_10,control,Panobinostat,A10,1004,1004,1433.0,...,1,control+Panobinostat,Vehicle,HDAC inhibitor,train,0,A549,A549,,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO
A01_A02_RT_BC_13_Lig_BC_1,sciPlex_theis,1.083277,2908,RT_13,Cediranib,PCI-34051,B1,1896,1895,2907.0,...,0,Cediranib+PCI-34051,EGFR inhibitor,HDAC inhibitor,test,0,A549,A549,CC1=CC2=C(N1)C=CC(=C2F)OC3=NC=NC4=CC(=C(C=C43)...,COC1=CC=C(C=C1)CN2C=CC3=C2C=C(C=C3)C(=O)NO
A01_A02_RT_BC_16_Lig_BC_2,sciPlex_theis,1.298964,3487,RT_16,Givinostat,SRT1720,B4,2154,2152,3485.0,...,0,Givinostat+SRT1720,HDAC inhibitor,Sirtuin inhibitor,train,0,A549,A549,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...,C1CN(CCN1)CC2=CSC3=NC(=CN23)C4=CC=CC=C4NC(=O)C...
A01_A02_RT_BC_19_Lig_BC_11,sciPlex_theis,0.587086,1576,RT_19,Panobinostat,SRT3025,B7,1131,1129,1574.0,...,1,Panobinostat+SRT3025,HDAC inhibitor,Sirtuin inhibitor,train,0,A549,A549,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO,COCCCC1=C(N=C(S1)C2=CC=CC=C2)C(=O)NC3=CC=CC=C3...
A01_A02_RT_BC_1_Lig_BC_3,sciPlex_theis,2.151281,5775,RT_1,Panobinostat,PCI-34051,A1,3139,3138,5774.0,...,1,Panobinostat+PCI-34051,HDAC inhibitor,HDAC inhibitor,train,0,A549,A549,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO,COC1=CC=C(C=C1)CN2C=CC3=C2C=C(C=C3)C(=O)NO
A01_A02_RT_BC_22_Lig_BC_17,sciPlex_theis,0.899999,2416,RT_22,control,Dacinostat,B10,1572,1569,2412.0,...,1,control+Dacinostat,Vehicle,HDAC inhibitor,test,0,A549,A549,,C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=...
A01_A02_RT_BC_25_Lig_BC_12,sciPlex_theis,0.534188,1434,RT_25,Dacinostat,PCI-34051,C1,1015,1014,1433.0,...,1,Dacinostat+PCI-34051,HDAC inhibitor,HDAC inhibitor,train,0,A549,A549,C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=...,COC1=CC=C(C=C1)CN2C=CC3=C2C=C(C=C3)C(=O)NO
A01_A02_RT_BC_28_Lig_BC_12,sciPlex_theis,1.136175,3050,RT_28,Panobinostat,Sorafenib,C4,2015,2015,3050.0,...,1,Panobinostat+Sorafenib,HDAC inhibitor,EGFR inhibitor,train,0,A549,A549,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO,CNC(=O)C1=NC=CC(=C1)OC2=CC=C(C=C2)NC(=O)NC3=CC...
A01_A02_RT_BC_31_Lig_BC_2,sciPlex_theis,2.073798,5567,RT_31,Givinostat,Cediranib,C7,3103,3100,5564.0,...,0,Givinostat+Cediranib,HDAC inhibitor,EGFR inhibitor,ood,0,A549,A549,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...,CC1=CC2=C(N1)C=CC(=C2F)OC3=NC=NC4=CC(=C(C=C43)...
A01_A02_RT_BC_34_Lig_BC_21,sciPlex_theis,1.1548,3100,RT_34,control,Givinostat,C10,1987,1985,3098.0,...,0,control+Givinostat,Vehicle,HDAC inhibitor,train,0,A549,A549,,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...


In [37]:
filtered_df = filtered_df[~((filtered_df["Drug1"]=="control") & (filtered_df["Drug2"]=="control"))]

In [38]:
filtered_df.shape

(27, 22)

In [39]:
condition_not_satisfied = True
rng_seed = 0

def check_condition(df_ood, df_train):
    ood_drugs_1 = set(df_ood["Drug1"])
    ood_drugs_2 = set(df_ood["Drug2"])
    ood_drugs = ood_drugs_1 | ood_drugs_2

    train_drugs_1 = set(df_ood["Drug1"])
    train_drugs_2 = set(df_ood["Drug2"])
    train_drugs = train_drugs_1 | train_drugs_2

    if ood_drugs.issubset(train_drugs):
        return True
    return False
    
while True:
    rng = np.random.default_rng(rng_seed)
    numbers = np.arange(27)
    rng.shuffle(numbers)
    subset_1 = numbers[:7]
    subset_2 = numbers[7:14]
    subset_3 = numbers[14:21]
    subset_4 = numbers[21:]
    df_1 = filtered_df.iloc[subset_1, :]
    df_2 = filtered_df.iloc[subset_2, :]
    df_3 = filtered_df.iloc[subset_3, :]
    df_4 = filtered_df.iloc[subset_4, :]
    cond_1 = check_condition(df_1, pd.concat((df_2, df_3, df_4)))
    cond_2 = check_condition(df_2, pd.concat((df_1, df_3, df_4)))
    cond_3 = check_condition(df_3, pd.concat((df_1, df_2, df_4)))
    cond_4 = check_condition(df_4, pd.concat((df_1, df_2, df_3)))
    if cond_1+cond_2+cond_3+cond_4 == 4:
        break
    rng_seed +=1
    

In [40]:
ood_conditions_1 = df_1.condition.values
ood_conditions_2 = df_2.condition.values
ood_conditions_3 = df_3.condition.values
ood_conditions_4 = df_4.condition.values

In [41]:
adata.obs["ood_1"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions_1 else "not ood", axis=1)
adata.obs["ood_1"] = adata.obs["ood_1"].astype("category")
adata.obs["ood_2"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions_2 else "not ood", axis=1)
adata.obs["ood_2"] = adata.obs["ood_2"].astype("category")
adata.obs["ood_3"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions_3 else "not ood", axis=1)
adata.obs["ood_3"] = adata.obs["ood_3"].astype("category")
adata.obs["ood_4"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions_4 else "not ood", axis=1)
adata.obs["ood_4"] = adata.obs["ood_4"].astype("category")

In [42]:
sc.pl.umap(adata, color=["ood_1", "ood_2", "ood_3", "ood_4"])

In [43]:
adata_train_1 = adata[~adata.obs["condition"].isin(ood_conditions_1)].copy()
adata_ood_1 = adata[adata.obs["condition"].isin(ood_conditions_1)].copy()

In [44]:
adata_train_2 = adata[~adata.obs["condition"].isin(ood_conditions_2)].copy()
adata_ood_2 = adata[adata.obs["condition"].isin(ood_conditions_2)].copy()

In [45]:
adata_train_3 = adata[~adata.obs["condition"].isin(ood_conditions_3)].copy()
adata_ood_3 = adata[adata.obs["condition"].isin(ood_conditions_3)].copy()

In [46]:
adata_train_4 = adata[~adata.obs["condition"].isin(ood_conditions_4)].copy()
adata_ood_4 = adata[adata.obs["condition"].isin(ood_conditions_4)].copy()

In [47]:
adata_train_1.n_obs, adata_train_2.n_obs, adata_train_3.n_obs, adata_train_4.n_obs

(47598, 48867, 52507, 49030)

In [48]:
adata_ood_1.n_obs, adata_ood_2.n_obs, adata_ood_3.n_obs, adata_ood_4.n_obs

(15780, 14511, 10871, 14348)

In [49]:
def make_splits(adata_train, adata_ood):
    rng = np.random.default_rng(0)
    split_dfs = []
    for drug in adata_train.obs["condition"].unique():
        adata_subset = adata_train[(adata_train.obs["condition"] == drug)]
        n_cells = adata_subset.n_obs
        if drug == "control":
            idx_test = rng.choice(np.arange(n_cells), 500, replace=False)
            remaining_idcs = set(np.arange(n_cells)) - set(idx_test)
            split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
            df = adata_subset.obs[["condition"]].copy()
            df["split"] = split
            split_dfs.append(df)
        elif n_cells > 300:
            idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
            split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
            df = adata_subset.obs[["condition"]].copy()
            df["split"] = split
            split_dfs.append(df)
    df_concat = pd.concat(split_dfs, axis=0)
    adata_train.obs["split"] = df_concat[["split"]]
    adata_ood.obs["split"] = "ood"
    adata_train_final = adata_train[adata_train.obs["split"] == "train"]
    adata_test_final = adata_train[adata_train.obs["split"] == "test"]
    adata_ood_final = anndata.concat((adata_ood, adata_test_final[adata_test_final.obs["condition"] == "control"]))

    return adata_train_final, adata_test_final, adata_ood_final

In [50]:
adata_train_1, adata_test_1, adata_ood_1 =  make_splits(adata_train_1, adata_ood_1)
adata_train_2, adata_test_2, adata_ood_2 =  make_splits(adata_train_2, adata_ood_2)
adata_train_3, adata_test_3, adata_ood_3 =  make_splits(adata_train_3, adata_ood_3)
adata_train_4, adata_test_4, adata_ood_4 =  make_splits(adata_train_4, adata_ood_4)

In [51]:
cfpp.centered_pca(adata_train_1, n_comps=100)
cfpp.project_pca(query_adata = adata_test_1, ref_adata=adata_train_1)
cfpp.project_pca(query_adata = adata_ood_1, ref_adata=adata_train_1)

cfpp.centered_pca(adata_train_2, n_comps=100)
cfpp.project_pca(query_adata = adata_test_2, ref_adata=adata_train_2)
cfpp.project_pca(query_adata = adata_ood_2, ref_adata=adata_train_2)

cfpp.centered_pca(adata_train_3, n_comps=100)
cfpp.project_pca(query_adata = adata_test_3, ref_adata=adata_train_3)
cfpp.project_pca(query_adata = adata_ood_3, ref_adata=adata_train_3)

cfpp.centered_pca(adata_train_4, n_comps=100)
cfpp.project_pca(query_adata = adata_test_4, ref_adata=adata_train_4)
cfpp.project_pca(query_adata = adata_ood_4, ref_adata=adata_train_4)

  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(
  adata.varm["X_mean"] = np.array(X.mean(axis=0).T)
  query_adata.obsm[obsm_key_added] = np.array(


In [52]:
'rank_genes_groups_cov_all' in adata_train_1.uns.keys(), 'rank_genes_groups_cov_all' in adata_test_1.uns.keys(), 'rank_genes_groups_cov_all' in adata_ood_1.uns.keys(),

(True, True, False)

In [53]:
adata_ood_1.uns['rank_genes_groups_cov_all'] = adata_train_1.uns['rank_genes_groups_cov_all']
adata_ood_2.uns['rank_genes_groups_cov_all'] = adata_train_1.uns['rank_genes_groups_cov_all']
adata_ood_3.uns['rank_genes_groups_cov_all'] = adata_train_1.uns['rank_genes_groups_cov_all']
adata_ood_4.uns['rank_genes_groups_cov_all'] = adata_train_1.uns['rank_genes_groups_cov_all']

In [54]:
adata_train_1.varm["X_mean"] = np.asarray(adata_train_1.varm["X_mean"])
adata_train_2.varm["X_mean"] = np.asarray(adata_train_2.varm["X_mean"])
adata_train_3.varm["X_mean"] = np.asarray(adata_train_3.varm["X_mean"])
adata_train_4.varm["X_mean"] = np.asarray(adata_train_4.varm["X_mean"])

In [55]:
import os

adata_train_1.write(os.path.join(output_dir, "adata_train_1.h5ad"))
adata_test_1.write(os.path.join(output_dir, "adata_test_1.h5ad"))
adata_ood_1.write(os.path.join(output_dir, "adata_ood_1.h5ad"))

adata_train_2.write(os.path.join(output_dir, "adata_train_2.h5ad"))
adata_test_2.write(os.path.join(output_dir, "adata_test_2.h5ad"))
adata_ood_2.write(os.path.join(output_dir, "adata_ood_2.h5ad"))

adata_train_3.write(os.path.join(output_dir, "adata_train_3.h5ad"))
adata_test_3.write(os.path.join(output_dir, "adata_test_3.h5ad"))
adata_ood_3.write(os.path.join(output_dir, "adata_ood_3.h5ad"))

adata_train_4.write(os.path.join(output_dir, "adata_train_4.h5ad"))
adata_test_4.write(os.path.join(output_dir, "adata_test_4.h5ad"))
adata_ood_4.write(os.path.join(output_dir, "adata_ood_4.h5ad"))
