In [13]:
import json
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm

from utils import load_smiles, compute_pred_ctrl,load_dataset

from chemCPA.data import load_dataset_splits
from chemCPA.paths import DATA_DIR, FIGURE_DIR, PROJECT_DIR, ROOT

In [173]:
pd.set_option('display.max_columns', 200)

In [15]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
seml_collection = "baseline_comparison"

In [9]:
seml_collection = "baseline_comparison"

base_hash_dict = dict(
    baseline_A549="044c4dba0c8719985c3622834f2cbd58",
    baseline_K562="5ea85d5dd7abd5962d1d3eeff1b8c1ff",
    baseline_MCF7="4d98e7d857f497d870e19c6d12175aaa",
)

In [10]:
def load_config(seml_collection, model_hash):
    file_path = PROJECT_DIR / f"{seml_collection}.json"  # Provide path to json

    with open(file_path) as f:
        file_data = json.load(f)

    for _config in tqdm(file_data):
        if _config["config_hash"] == model_hash:
            # print(config)
            config = _config["config"]
            config["config_hash"] = _config["config_hash"]
    return config

In [100]:
config = load_config(seml_collection, base_hash_dict['baseline_A549'])

config["dataset"]["data_params"]["dataset_path"] = DATA_DIR/ "sciplex_complete_middle_subset_lincs_genes.h5ad"
config["model"]["embedding"]["directory"] = (
    ROOT / config["model"]["embedding"]["directory"]
)

dataset, key_dict = load_dataset(config)

config["dataset"]["n_vars"] = dataset.n_vars

canon_smiles_unique_sorted, smiles_to_pathway_map, smiles_to_drug_map = load_smiles(
    config, dataset, key_dict, True
)
# model_pretrained, embedding_pretrained = load_model(config, canon_smiles_unique_sorted)

  0%|          | 0/9 [00:00<?, ?it/s]

In [103]:
data_params = config["dataset"]["data_params"]
data_params['split_key'] = 'split_random'
datasets, dataset_all = load_dataset_splits(**data_params, return_dataset=True)

In [104]:
data_params

{'covariate_keys': 'cell_type',
 'dataset_path': PosixPath('/nfs/staff-ssd/hetzell/code/chemCPA_v2/project_folder/datasets/sciplex_complete_middle_subset_lincs_genes.h5ad'),
 'degs_key': 'lincs_DEGs',
 'dose_key': 'dose',
 'pert_category': 'cov_drug_dose_name',
 'perturbation_key': 'condition',
 'smiles_key': 'SMILES',
 'split_key': 'split_random',
 'use_drugs_idx': True}

In [16]:
dosages = [1e1, 1e2, 1e3, 1e4]
cell_lines = ["A549", "K562", "MCF7"]
use_DEGs = True

In [109]:
len(pd.Series(dataset_all.pert_categories).unique())

2247

In [110]:
def get_baseline_predictions(
    hash,
    seml_collection="baseline_comparison",
    smiles=None,
    dosages=[1e1, 1e2, 1e3, 1e4],
    cell_lines=["A549", "K562", "MCF7"],
    use_DEGs=False,
    verbose=False,
    name_tag=None,
):
    if smiles is None:
        smiles = canon_smiles_unique_sorted

    config = load_config(seml_collection, hash)
    config["dataset"]["n_vars"] = dataset.n_vars
    config["dataset"]["data_params"]["dataset_path"] = (
        ROOT / config["dataset"]["data_params"]["dataset_path"]
    )
    config["model"]["embedding"]["directory"] = (
        ROOT / config["model"]["embedding"]["directory"]
    )
    data_params = config["dataset"]["data_params"]
    datasets = load_dataset_splits(**data_params, return_dataset=False)

    predictions, _ = compute_pred_ctrl(
        dataset_all,
        dataset_ctrl=datasets["test_control"],
        # dataset_ctrl=datasets["training_control"],
        dosages=dosages,
        cell_lines=cell_lines,
        use_DEGs=use_DEGs,
        verbose=verbose,
    )

    predictions = pd.DataFrame.from_dict(predictions, orient="index", columns=["R2"])
    if name_tag:
        predictions["model"] = name_tag
    predictions["genes"] = "degs" if use_DEGs else "all"
    return predictions

In [111]:
# drug_r2_baseline_degs, _ = compute_pred_ctrl(
#     dataset=datasets["ood"],
#     dataset_ctrl=datasets["test_control"],
#     dosages=dosages,
#     cell_lines=cell_lines,
#     use_DEGs=True,
#     verbose=False,
# )

# drug_r2_baseline_all, _ = compute_pred_ctrl(
#     dataset=datasets["ood"],
#     dataset_ctrl=datasets["test_control"],
#     dosages=dosages,
#     cell_lines=cell_lines,
#     use_DEGs=False,
#     verbose=False,
# )
predictions = []

predictions.extend(
    [
        get_baseline_predictions(_hash, name_tag=name_tag, use_DEGs=True)
        for name_tag, _hash in base_hash_dict.items()
    ]
)

predictions.extend(
    [
        get_baseline_predictions(_hash, name_tag=name_tag, use_DEGs=False)
        for name_tag, _hash in base_hash_dict.items()
    ]
)

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

['A549', 'K562', 'MCF7']


0it [00:00, ?it/s]

In [112]:
predictions

[                                     R2          model genes
 A549_(+)-JQ1_0.001             0.977773  baseline_A549  degs
 A549_(+)-JQ1_0.01              0.961488  baseline_A549  degs
 A549_(+)-JQ1_0.1               0.877050  baseline_A549  degs
 A549_(+)-JQ1_1.0               0.819917  baseline_A549  degs
 A549_2-Methoxyestradiol_0.001  0.913421  baseline_A549  degs
 ...                                 ...            ...   ...
 MCF7_ZM_1.0                    0.966686  baseline_A549  degs
 MCF7_Zileuton_0.001            0.986553  baseline_A549  degs
 MCF7_Zileuton_0.01             0.991991  baseline_A549  degs
 MCF7_Zileuton_0.1              0.989118  baseline_A549  degs
 MCF7_Zileuton_1.0              0.993228  baseline_A549  degs
 
 [2244 rows x 3 columns],
                                      R2          model genes
 A549_(+)-JQ1_0.001             0.977773  baseline_K562  degs
 A549_(+)-JQ1_0.01              0.961488  baseline_K562  degs
 A549_(+)-JQ1_0.1               0.877050  

In [113]:
def rename_model(str):
    str_list = str.split("_")
    if len(str_list) == 2:
        return str_list[0]
    else:
        assert len(str_list) == 3
        return "_".join([str_list[0], str_list[2]])

In [114]:
predictions = pd.concat(predictions)
predictions.reset_index(inplace=True)
predictions["cell_type"] = predictions["index"].apply(lambda s: s.split("_")[0])
predictions["condition"] = predictions["index"].apply(lambda s: s.split("_")[1])
predictions["dose"] = predictions["index"].apply(lambda s: s.split("_")[2])
predictions["model_ct"] = predictions["model"]
predictions["model"] = predictions["model"].apply(rename_model)

In [118]:
predictions['dose'].unique()

array(['0.001', '0.01', '0.1', '1.0'], dtype=object)

In [120]:
cond = (predictions['genes']=='degs') & (predictions['dose']=='1.0')
predictions[cond].groupby(['condition']).mean().sort_values('R2')

Unnamed: 0_level_0,R2
condition,Unnamed: 1_level_1
Panobinostat,0.000000
Quisinostat,0.000000
AR-42,0.000000
Bisindolylmaleimide,0.004051
Dacinostat,0.013970
...,...
Aminoglutethimide,0.975252
Zileuton,0.976290
Amisulpride,0.976606
Meprednisone,0.983226


In [121]:
mean_df = predictions[cond].groupby(['condition']).mean().sort_values('R2')
mol_set = set(mean_df.index[:50])

In [123]:
std_df = predictions[cond].groupby(['condition']).std().sort_values('R2')
std_df.rename(columns={"R2": "R2_std"}, inplace=True)

In [124]:
std_df

Unnamed: 0_level_0,R2_std
condition,Unnamed: 1_level_1
Quisinostat,0.000000
Panobinostat,0.000000
AR-42,0.000000
Lenalidomide,0.002868
Bisindolylmaleimide,0.006076
...,...
Entinostat,0.377008
Fluorouracil,0.391580
Flavopiridol,0.408244
Raltitrexed,0.414432


In [125]:
mean_df

Unnamed: 0_level_0,R2
condition,Unnamed: 1_level_1
Panobinostat,0.000000
Quisinostat,0.000000
AR-42,0.000000
Bisindolylmaleimide,0.004051
Dacinostat,0.013970
...,...
Aminoglutethimide,0.975252
Zileuton,0.976290
Amisulpride,0.976606
Meprednisone,0.983226


In [126]:
df = pd.concat([mean_df, std_df], axis=1)

In [138]:
df_subset = df[(df['R2']<0.784) & (df['R2_std']<0.3)]
len(df_subset)

45

In [146]:
ood_drugs = df_subset.sample(frac=1).index.to_list()

In [203]:
import scanpy as sc
adata_sciplex = sc.read(DATA_DIR/ "sciplex_complete_middle_subset_lincs_genes.h5ad")

In [204]:
pd.crosstab(adata_sciplex.obs['split_ood_multi_task'], adata_sciplex.obs['condition'].isin(['control']))

condition,False,True
split_ood_multi_task,Unnamed: 1_level_1,Unnamed: 2_level_1
ood,11850,0
test,28338,1132
train,301448,11872


In [205]:

print(adata_sciplex.obs.dose.unique())
adata_sciplex = adata_sciplex[adata_sciplex.obs.dose.isin([0., 1e3])].copy()

[ 1000.     0.   100. 10000.    10.]


In [206]:
i = 0
n_drugs = 5
drug_sets = []
for j in range(n_drugs, len(ood_drugs), n_drugs):
    drug_set = ood_drugs[i:j+n_drugs]
    if (j+n_drugs) > len(ood_drugs):
        drug_set += ood_drugs[:2*n_drugs-len(drug_set)]
    i = j
    drug_sets.append(drug_set)

In [207]:
[print(s) for s in drug_sets]

['YM155', 'Pracinostat', 'Givinostat', 'Trichostatin', 'TAK-901', 'TMP195', 'Resminostat', 'Trametinib', '(+)-JQ1', 'JNJ-7706621']
['TMP195', 'Resminostat', 'Trametinib', '(+)-JQ1', 'JNJ-7706621', 'Quisinostat', 'FLLL32', 'Momelotinib', 'Rigosertib', 'Cyclocytidine']
['Quisinostat', 'FLLL32', 'Momelotinib', 'Rigosertib', 'Cyclocytidine', 'CUDC-101', 'Fedratinib', 'Bisindolylmaleimide', 'AT9283', 'Dasatinib']
['CUDC-101', 'Fedratinib', 'Bisindolylmaleimide', 'AT9283', 'Dasatinib', 'UNC0379', 'Triamcinolone', 'Cediranib', 'PFI-1', 'Toremifene']
['UNC0379', 'Triamcinolone', 'Cediranib', 'PFI-1', 'Toremifene', 'CUDC-907', 'Tucidinostat', 'Nintedanib', 'Dacinostat', 'Pirarubicin']
['CUDC-907', 'Tucidinostat', 'Nintedanib', 'Dacinostat', 'Pirarubicin', 'Regorafenib', 'M344', 'SGI-1776', 'PHA-680632', 'Obatoclax']
['Regorafenib', 'M344', 'SGI-1776', 'PHA-680632', 'Obatoclax', 'KW-2449', 'Panobinostat', 'Belinostat', 'AR-42', 'ENMD-2076']
['KW-2449', 'Panobinostat', 'Belinostat', 'AR-42', 'ENM

[None, None, None, None, None, None, None, None]

In [208]:
for i, drug_set in enumerate(drug_sets):
    for cell_type in adata_sciplex.obs.cell_type.unique():
        split = f'split_fold{i}_{cell_type}'
        print(split)
        adata_sciplex.obs[split] = adata_sciplex.obs['split_ood_multi_task']
        adata_sciplex.obs.loc[adata_sciplex.obs[split] == 'ood', split] = 'train'
        adata_sciplex.obs.loc[adata_sciplex.obs['condition'].isin(drug_set), split] = 'ood'

        sub_df = adata_sciplex.obs.loc[adata_sciplex.obs[split].isin(['ood']) * (adata_sciplex.obs.cell_type != cell_type)]
        train_test = sub_df.index
        test = sub_df.sample(frac=0.5).index 
        
        sub_df2 = adata_sciplex.obs.loc[adata_sciplex.obs[split].isin(['train'])]
        train_test2 = sub_df2.index
        test2 = sub_df.sample(frac=0.05).index 

        adata_sciplex.obs.loc[train_test,split] = 'train'
        adata_sciplex.obs.loc[test,split] = 'test'
        adata_sciplex.obs.loc[train_test2,split] = 'train'
        adata_sciplex.obs.loc[test2,split] = 'test'



split_fold0_A549
split_fold0_MCF7
split_fold0_K562
split_fold1_A549
split_fold1_MCF7
split_fold1_K562
split_fold2_A549
split_fold2_MCF7
split_fold2_K562
split_fold3_A549
split_fold3_MCF7
split_fold3_K562
split_fold4_A549
split_fold4_MCF7
split_fold4_K562
split_fold5_A549
split_fold5_MCF7
split_fold5_K562
split_fold6_A549
split_fold6_MCF7
split_fold6_K562
split_fold7_A549
split_fold7_MCF7
split_fold7_K562


In [209]:
# split_fold0_A549
# split_fold0_MCF7
# split_fold0_K562
# split_fold1_A549
# split_fold1_MCF7
# split_fold1_K562
# split_fold2_A549
# split_fold2_MCF7
# split_fold2_K562
# split_fold3_A549
# split_fold3_MCF7
# split_fold3_K562
# split_fold4_A549
# split_fold4_MCF7
# split_fold4_K562
# split_fold5_A549
# split_fold5_MCF7
# split_fold5_K562
# split_fold6_A549
# split_fold6_MCF7
# split_fold6_K562
# split_fold7_A549
# split_fold7_MCF7
# split_fold7_K562

'split_fold7_K562'

In [210]:
pd.crosstab(adata_sciplex.obs['split_fold7_A549'], adata_sciplex.obs['condition'])

condition,2-Methoxyestradiol,JQ1,A-366,ABT-737,AC480,AG-490,AG-14361,AICAR,AMG-900,AR-42,AT9283,AZ,AZD1480,Abexinostat,Alendronate,Alisertib,Altretamine,Alvespimycin,Aminoglutethimide,Amisulpride,Anacardic,Andarine,Aurora,Avagacestat,Azacitidine,BMS-265246,BMS-536924,BMS-754807,BMS-911543,BRD4770,Barasertib,Baricitinib,Belinostat,Bisindolylmaleimide,Bosutinib,Busulfan,CEP-33779,CUDC-101,CUDC-907,CYC116,Capecitabine,Carmofur,Cediranib,Celecoxib,Cerdulatinib,Cimetidine,Clevudine,Costunolide,Crizotinib,Curcumin,Cyclocytidine,Dacinostat,Danusertib,Daphnetin,Dasatinib,Decitabine,Disulfiram,Divalproex,Droxinostat,EED226,ENMD-2076,Ellagic,Entacapone,Entinostat,Enzastaurin,Epothilone,FLLL32,Fasudil,Fedratinib,Filgotinib,Flavopiridol,Fluorouracil,Fulvestrant,G007-LK,GSK,GSK1070916,GSK-LSD1,Gandotinib,Givinostat,Glesatinib?(MGCD265),Hesperadin,INO-1001,IOX2,ITSA-1,Iniparib,Ivosidenib,JNJ-7706621,JNJ-26854165,KW-2449,Ki8751,Ki16425,Lapatinib,Lenalidomide,Linifanib,Lomustine,Luminespib,M344,MC1568,MK-0752,MK-5108,MLN8054,Maraviroc,Meprednisone,Mercaptopurine,Mesna,Mocetinostat,Momelotinib,Motesanib,NVP-BSK805,Navitoclax,Nilotinib,Nintedanib,Obatoclax,Ofloxacin,PCI-34051,PD98059,PD173074,PF-3845,PF-573228,PFI-1,PHA-680632,PJ34,Panobinostat,Patupilone,Pelitinib,Pirarubicin,Pracinostat,Prednisone,Quercetin,Quisinostat,RG108,Raltitrexed,Ramelteon,Regorafenib,Resminostat,Resveratrol,Rigosertib,Roscovitine,Roxadustat,Rucaparib,Ruxolitinib,S3I-201,S-Ruxolitinib,SB431542,SGI-1776,SL-327,SNS-314,SRT1720,SRT2104,SRT3025,Selisistat,Sirtinol,Sodium,Sorafenib,Streptozotocin,TAK-901,TG101209,TGX-221,TMP195,Tacedinaline,Tanespimycin,Tazemetostat,Temsirolimus,Thalidomide,Thiotepa,Tie2,Tofacitinib,Toremifene,Tozasertib,Trametinib,Tranylcypromine,Triamcinolone,Trichostatin,Tubastatin,Tucidinostat,UNC0379,UNC0631,UNC1999,Valproic,Vandetanib,Veliparib,WHI-P154,WP1066,XAV-939,YM155,ZM,Zileuton,control
split_fold7_A549,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1,Unnamed: 113_level_1,Unnamed: 114_level_1,Unnamed: 115_level_1,Unnamed: 116_level_1,Unnamed: 117_level_1,Unnamed: 118_level_1,Unnamed: 119_level_1,Unnamed: 120_level_1,Unnamed: 121_level_1,Unnamed: 122_level_1,Unnamed: 123_level_1,Unnamed: 124_level_1,Unnamed: 125_level_1,Unnamed: 126_level_1,Unnamed: 127_level_1,Unnamed: 128_level_1,Unnamed: 129_level_1,Unnamed: 130_level_1,Unnamed: 131_level_1,Unnamed: 132_level_1,Unnamed: 133_level_1,Unnamed: 134_level_1,Unnamed: 135_level_1,Unnamed: 136_level_1,Unnamed: 137_level_1,Unnamed: 138_level_1,Unnamed: 139_level_1,Unnamed: 140_level_1,Unnamed: 141_level_1,Unnamed: 142_level_1,Unnamed: 143_level_1,Unnamed: 144_level_1,Unnamed: 145_level_1,Unnamed: 146_level_1,Unnamed: 147_level_1,Unnamed: 148_level_1,Unnamed: 149_level_1,Unnamed: 150_level_1,Unnamed: 151_level_1,Unnamed: 152_level_1,Unnamed: 153_level_1,Unnamed: 154_level_1,Unnamed: 155_level_1,Unnamed: 156_level_1,Unnamed: 157_level_1,Unnamed: 158_level_1,Unnamed: 159_level_1,Unnamed: 160_level_1,Unnamed: 161_level_1,Unnamed: 162_level_1,Unnamed: 163_level_1,Unnamed: 164_level_1,Unnamed: 165_level_1,Unnamed: 166_level_1,Unnamed: 167_level_1,Unnamed: 168_level_1,Unnamed: 169_level_1,Unnamed: 170_level_1,Unnamed: 171_level_1,Unnamed: 172_level_1,Unnamed: 173_level_1,Unnamed: 174_level_1,Unnamed: 175_level_1,Unnamed: 176_level_1,Unnamed: 177_level_1,Unnamed: 178_level_1,Unnamed: 179_level_1,Unnamed: 180_level_1,Unnamed: 181_level_1,Unnamed: 182_level_1,Unnamed: 183_level_1,Unnamed: 184_level_1,Unnamed: 185_level_1,Unnamed: 186_level_1,Unnamed: 187_level_1,Unnamed: 188_level_1
ood,0,0,0,0,0,0,0,0,0,107,0,0,0,113,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,147,0,156,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,87,0,0,0,0,0,0,0,210,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,97,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,75,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,0,0,0,0,0,0,87,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
test,22,232,19,23,25,25,27,20,11,171,11,24,9,180,9,164,31,0,14,16,21,13,22,23,35,20,21,229,29,28,179,20,211,22,138,29,26,156,112,16,29,26,18,16,28,25,23,20,218,22,21,0,169,18,178,15,17,18,26,24,392,28,22,252,29,76,27,24,28,22,0,29,214,12,28,178,15,18,0,27,0,17,26,23,26,26,27,29,182,26,24,17,29,19,25,132,206,27,29,25,15,22,18,26,24,210,22,19,25,21,22,255,31,27,22,15,24,19,31,29,29,18,192,131,20,150,211,16,22,0,15,121,23,28,236,28,108,18,26,13,15,24,27,28,23,22,14,21,22,20,25,23,22,222,19,0,17,23,11,256,91,16,17,30,16,18,28,21,12,172,32,21,226,23,130,33,20,26,26,22,22,30,30,24,42,13,25,1132
train,567,340,528,405,606,569,514,554,435,161,256,455,430,171,347,231,600,277,533,498,566,529,514,501,527,515,477,261,562,559,247,588,171,500,101,646,608,378,302,514,603,555,506,431,602,561,550,510,261,592,477,389,147,599,235,321,528,617,602,595,360,545,606,339,584,98,527,595,496,539,113,529,349,451,584,248,607,525,496,562,358,563,566,592,579,520,518,594,183,556,535,501,480,558,619,173,306,662,590,510,406,572,513,495,465,286,588,575,561,578,488,334,585,603,490,621,508,518,551,521,394,627,176,191,529,340,287,471,599,398,607,283,573,569,311,493,136,508,612,333,478,539,628,507,566,531,367,575,675,501,574,541,614,186,500,337,484,570,375,368,117,617,458,573,501,497,591,568,278,380,547,578,305,590,210,607,569,612,568,565,510,611,595,640,57,428,623,11872


In [211]:
pd.crosstab(adata_sciplex.obs['split_fold7_A549'], adata_sciplex.obs['cell_type'])

cell_type,A549,K562,MCF7
split_fold7_A549,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ood,1218,0,0
test,2235,2763,6046
train,24484,25866,49668


In [212]:
pd.crosstab(adata_sciplex.obs['split_fold7_A549'], adata_sciplex.obs['condition'].isin(['control']))

condition,False,True
split_fold7_A549,Unnamed: 1_level_1,Unnamed: 2_level_1
ood,1218,0
test,9912,1132
train,88146,11872


In [213]:
adata_sciplex.write(DATA_DIR/'adata_fold.h5ad', compression="gzip")