In [1]:
import pandas as pd
import scanpy as sc
import numpy as np
import anndata
import pickle
import os

from gears import PertData, GEARS
from gears.inference import evaluate
import torch

In [None]:
# Mini-batch size shared by the training and the test dataloaders below.
batch_size = 32

# Root directory of the Norman perturbation data (kept with a trailing slash).
DATA_DIR = "/lustre/groups/ml01/workspace/leander.dony/projects/cellflow/241106_gears/reproduce_biolord_repro/data/perturbations/norman/"

# Index of the currently selected CUDA device; the model cells turn it into "cuda:<index>".
device = torch.cuda.current_device()

In [3]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    """Rank differentially expressed genes per group, separately within each covariate category.

    For every category of ``adata.obs[covariate]``, the cells of that category are
    ranked with ``sc.tl.rank_genes_groups`` against ``control_group``. The top
    ``n_genes`` gene names of every group are then collected into one dict.

    Parameters
    ----------
    adata : AnnData
        Data to analyse. Expression is taken from ``adata.X`` (``use_raw=False``).
    groupby : str
        ``adata.obs`` column whose values define the groups (e.g. conditions).
    control_group : str
        Value of ``groupby`` used as the reference group. It should be present in
        every covariate category, because scanpy needs the reference among the groups.
    covariate : str
        ``adata.obs`` column whose categories are ranked independently.
    n_genes : int, default 50
        Number of top-ranked gene names kept per group.
    rankby_abs : bool, default True
        Passed through to ``sc.tl.rank_genes_groups``.
    key_added : str, default "rank_genes_groups_cov"
        ``adata.uns`` key under which the ``{group: [gene names]}`` dict is stored.
    return_dict : bool, default False
        If True, also return that dict.

    Returns
    -------
    dict or None
        ``{group: [gene names]}`` if ``return_dict`` is True, otherwise None.

    Notes
    -----
    Keys are the bare group names. If the same group occurs in several covariate
    categories, the category processed last overwrites the earlier entries.
    """
    gene_dict = {}
    for cov_cat in adata.obs[covariate].unique():
        # Work on an explicit copy. rank_genes_groups writes into `.uns`, and doing
        # that on a view triggers anndata's implicit view-to-copy conversion (with
        # an ImplicitModificationWarning).
        adata_cov = adata[adata.obs[covariate] == cov_cat].copy()
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # One column per compared group, holding its top-ranked gene names.
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        gene_dict.update({group: de_genes[group].tolist() for group in de_genes})
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict

In [4]:
# Load the Norman dataset ("norman2019"), rank DE genes per condition against
# "ctrl", and save the annotated full AnnData. `de_dict` is reused by the
# model-evaluation cells, which attach it to every exported AnnData.
adata_all_path = "/lustre/groups/ml01/workspace/leander.dony/projects/cellflow/241106_gears/evaluate_gears/norman/output/adata_all.h5ad"

# normpath/join work whether or not DATA_DIR ends with a slash
# (the previous DATA_DIR[:-1] / DATA_DIR + "..." required exactly one).
pert_data = PertData(os.path.normpath(DATA_DIR), gene_set_path=os.path.join(DATA_DIR, "essential_norman.pkl"))
pert_data.load(data_path=os.path.join(DATA_DIR, "norman2019"))
de_dict = rank_genes_groups_by_cov(
    pert_data.adata,
    groupby="condition",
    covariate="cell_type",
    control_group="ctrl",
    n_genes=50,
    key_added="rank_genes_groups_cov_all",
    return_dict=True,
)
# Create the output directory up front so the write cannot fail on a fresh setup.
os.makedirs(os.path.dirname(adata_all_path), exist_ok=True)
pert_data.adata.write(adata_all_path, compression="gzip")

Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!


In [7]:
# Inspect the per-cell metadata of the loaded dataset
# (columns: condition, cell_type, dose_val, control, condition_name).
# NOTE(review): the saved execution count (7) shows this ran after the identity-model
# loop, which rebinds `pert_data`; on a fresh Run All it shows the object loaded above.
pert_data.adata.obs

Unnamed: 0_level_0,condition,cell_type,dose_val,control,condition_name
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
AAACCTGAGGCATGTG-1,TSC22D1+ctrl,A549,1+1,0,A549_TSC22D1+ctrl_1+1
AAACCTGAGGCCCTTG-1,KLF1+MAP2K6,A549,1+1,0,A549_KLF1+MAP2K6_1+1
AAACCTGCACGAAGCA-1,ctrl,A549,1,1,A549_ctrl_1
AAACCTGCAGACGTAG-1,CEBPE+RUNX1T1,A549,1+1,0,A549_CEBPE+RUNX1T1_1+1
AAACCTGCAGCCTTGG-1,MAML2+ctrl,A549,1+1,0,A549_MAML2+ctrl_1+1
...,...,...,...,...,...
TTTGTCAGTCAGAATA-8,ctrl,A549,1,1,A549_ctrl_1
TTTGTCATCAGTACGT-8,FOXA3+ctrl,A549,1+1,0,A549_FOXA3+ctrl_1+1
TTTGTCATCCACTCCA-8,CELF2+ctrl,A549,1+1,0,A549_CELF2+ctrl_1+1
TTTGTCATCCCAACGG-8,BCORL1+ctrl,A549,1+1,0,A549_BCORL1+ctrl_1+1


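## Identity model

Baseline: GEARS initialised with `no_perturb=True` and trained for 0 epochs, evaluated on the same simulation splits (seeds 1–5) as the GEARS model below.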

In [6]:
# Identity baseline: `no_perturb=True` is passed to model_initialize and training
# runs for 0 epochs. Presumably the predictions then ignore the perturbation;
# confirm against the GEARS source.
epoch = 0
no_perturb = True
savepath = "/lustre/groups/ml01/workspace/leander.dony/projects/cellflow/241106_gears/evaluate_gears/norman/output/adatas_identity/"
# Make sure the output directory exists before any h5ad file is written.
os.makedirs(savepath, exist_ok=True)

for seed in range(1, 6):
    # Reload per seed: prepare_split is seed-dependent, so each run gets its own
    # simulation split and dataloaders.
    pert_data = PertData(os.path.normpath(DATA_DIR), gene_set_path=os.path.join(DATA_DIR, "essential_norman.pkl"))
    pert_data.load(data_path=os.path.join(DATA_DIR, "norman2019"))
    pert_data.prepare_split(split="simulation", seed=seed)
    pert_data.get_dataloader(batch_size=batch_size, test_batch_size=batch_size)

    noperturb_model = GEARS(
        pert_data,
        device="cuda:" + str(device),
        weight_bias_track=False,
        proj_name="norman2019",
        exp_name="no_perturb_seed" + str(seed),
    )
    noperturb_model.model_initialize(
        hidden_size=64,
        no_perturb=no_perturb,
        go_path=os.path.join(DATA_DIR, "go_essential_norman.csv"),
    )
    noperturb_model.train(epochs=epoch)

    results = {
        "res_test": evaluate(
            noperturb_model.dataloader["test_loader"],
            noperturb_model.best_model,
            noperturb_model.config["uncertainty"],
            noperturb_model.device,
        ),
        "res_val": evaluate(
            noperturb_model.dataloader["val_loader"],
            noperturb_model.best_model,
            noperturb_model.config["uncertainty"],
            noperturb_model.device,
        ),
    }

    for split in ["val", "test"]:
        res = results[f"res_{split}"]
        # Invert {subgroup: [conditions]} into {condition: subgroup}; independent of `kind`.
        condition_to_subgroup = {
            cond: name
            for name, conds in pert_data.subgroup[f"{split}_subgroup"].items()
            for cond in conds
        }
        obs = pd.DataFrame({
            "condition": res["pert_cat"],
            "subgroup": pd.Series(res["pert_cat"]).map(condition_to_subgroup),
        })
        # AnnData wants a string obs index; set it explicitly instead of relying on the
        # implicit (warning-emitting) conversion.
        obs.index = obs.index.astype(str)
        # Conditions not listed in any subgroup map to NaN. Skip them instead of
        # crashing on NaN.replace when building the file name.
        subgroups = obs["subgroup"].dropna().unique()

        for kind in ["truth", "pred"]:
            adata = anndata.AnnData(
                X=res[kind],
                var=pert_data.adata.var,
                obs=obs.copy(),
                uns={"rank_genes_groups_cov_all": de_dict},
            )
            for subgroup in subgroups:
                # Single quotes inside the f-string keep this valid before Python 3.12 (PEP 701).
                fname = f"norman_identity_seed{seed}_{split}_{subgroup.replace('_', '')}_{kind}.h5ad"
                adata[adata.obs["subgroup"] == subgroup].copy().write(
                    os.path.join(savepath, fname), compression="gzip"
                )

Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:43
combo_seen2:19
unseen_single:36
Done!
Creating dataloaders....
Done!


here1


Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.4383
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.008115437
test_combo_seen0_pearson: 0.9752434404669947
test_combo_seen0_mse_de: 0.36536023
test_combo_seen0_pearson_de: 0.6961077031964022
test_combo_seen1_mse: 0.009594609
test_combo_seen1_pearson: 0.9712218925820775
test_combo_seen1_mse_de: 0.4883878
test_combo_seen1_pearson_de: 0.7296004629532744
test_combo_seen2_mse: 0.0067827976
test_combo_seen2_pearson: 0.9796019813297331
test_combo_seen2_mse_de: 0.60923195
test_combo_seen2_pearson_de: 0.7952379593142974
test_unseen_single_mse: 0.0037921676
test_unseen_single_pearson: 0.9884388360874871
test_unseen_single_mse_de: 0.3063777
test_unseen_single_pearson_de: 0.846590207130147
test_combo_seen0_pearson_delta: -0.011405931789983191
test_combo_seen0_frac_opposite_direction_top20_non_dropout: 0.5166666666666666
test_combo_seen0_frac_sigma_below_1_non_dropout: 0.533333333333

here1


Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.5159
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.0062959776
test_combo_seen0_pearson: 0.9810927714215337
test_combo_seen0_mse_de: 0.681767
test_combo_seen0_pearson_de: 0.8323177199393589
test_combo_seen1_mse: 0.008714246
test_combo_seen1_pearson: 0.9740398060755299
test_combo_seen1_mse_de: 0.55266553
test_combo_seen1_pearson_de: 0.7477024047454699
test_combo_seen2_mse: 0.005804072
test_combo_seen2_pearson: 0.9823109862747891
test_combo_seen2_mse_de: 0.5144702
test_combo_seen2_pearson_de: 0.8472784628598902
test_unseen_single_mse: 0.0049762754
test_unseen_single_pearson: 0.9851787852514509
test_unseen_single_mse_de: 0.4109658
test_unseen_single_pearson_de: 0.8594419270252787
test_combo_seen0_pearson_delta: 0.036050283206535756
test_combo_seen0_frac_opposite_direction_top20_non_dropout: 0.4791666666666667
test_combo_seen0_frac_sigma_below_1_non_dropout: 0.5375
test_comb

here1


Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.4531
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.0061482713
test_combo_seen0_pearson: 0.981118791034024
test_combo_seen0_mse_de: 0.74522185
test_combo_seen0_pearson_de: 0.752536651219998
test_combo_seen1_mse: 0.006285342
test_combo_seen1_pearson: 0.9808828788384245
test_combo_seen1_mse_de: 0.53575903
test_combo_seen1_pearson_de: 0.8217631694684608
test_combo_seen2_mse: 0.011091816
test_combo_seen2_pearson: 0.967008530808813
test_combo_seen2_mse_de: 0.5225223
test_combo_seen2_pearson_de: 0.600003044039783
test_unseen_single_mse: 0.002916057
test_unseen_single_pearson: 0.9910963879618866
test_unseen_single_mse_de: 0.26150978
test_unseen_single_pearson_de: 0.893943295459342
test_combo_seen0_pearson_delta: -0.0022238702962948595
test_combo_seen0_frac_opposite_direction_top20_non_dropout: 0.475
test_combo_seen0_frac_sigma_below_1_non_dropout: 0.48750000000000004
test_combo

here1


Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.3882
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.00519515
test_combo_seen0_pearson: 0.9840850989540574
test_combo_seen0_mse_de: 0.43063676
test_combo_seen0_pearson_de: 0.91496417298913
test_combo_seen1_mse: 0.007430646
test_combo_seen1_pearson: 0.9776223607338014
test_combo_seen1_mse_de: 0.45155925
test_combo_seen1_pearson_de: 0.7092290699055883
test_combo_seen2_mse: 0.006670755
test_combo_seen2_pearson: 0.9796676921024599
test_combo_seen2_mse_de: 0.5045317
test_combo_seen2_pearson_de: 0.8847274513329038
test_unseen_single_mse: 0.0029013075
test_unseen_single_pearson: 0.9911273280102706
test_unseen_single_mse_de: 0.23805368
test_unseen_single_pearson_de: 0.8990595415688019
test_combo_seen0_pearson_delta: -0.059500240476381853
test_combo_seen0_frac_opposite_direction_top20_non_dropout: 0.5875
test_combo_seen0_frac_sigma_below_1_non_dropout: 0.675
test_combo_seen0_mse_t

here1


Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.5410
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.019676955
test_combo_seen0_pearson: 0.9424961528663598
test_combo_seen0_mse_de: 0.75205344
test_combo_seen0_pearson_de: 0.4237033897564591
test_combo_seen1_mse: 0.008966429
test_combo_seen1_pearson: 0.9731722353195832
test_combo_seen1_mse_de: 0.63086444
test_combo_seen1_pearson_de: 0.7321368490226438
test_combo_seen2_mse: 0.00756799
test_combo_seen2_pearson: 0.9770574949375104
test_combo_seen2_mse_de: 0.5423224
test_combo_seen2_pearson_de: 0.8212438305937693
test_unseen_single_mse: 0.0046949345
test_unseen_single_pearson: 0.9860701306773928
test_unseen_single_mse_de: 0.39887267
test_unseen_single_pearson_de: 0.850946819858882
test_combo_seen0_pearson_delta: 0.0232733180718116
test_combo_seen0_frac_opposite_direction_top20_non_dropout: 0.4166666666666667
test_combo_seen0_frac_sigma_below_1_non_dropout: 0.15
test_combo_se

Start Training...
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.5410
Start doing subgroup analysis for simulation split...
test_combo_seen0_mse: 0.019676955
test_combo_seen0_pearson: 0.9424961528663598
test_combo_seen0_mse_de: 0.75205344
test_combo_seen0_pearson_de: 0.4237033897564591
test_combo_seen1_mse: 0.008966429
test_combo_seen1_pearson: 0.9731722353195832
test_combo_seen1_mse_de: 0.63086444
test_combo_seen1_pearson_de: 0.7321368490226438
test_combo_seen2_mse: 0.00756799
test_combo_seen2_pearson: 0.9770574949375104
test_combo_seen2_mse_de: 0.5423224
test_combo_seen2_pearson_de: 0.8212438305937693
test_unseen_single_mse: 0.0046949345
test_unseen_single_pearson: 0.9860701306773928
test_unseen_single_mse_de: 0.39887267
test_unseen_single_pearson_de: 0.850946819858882
test_combo_seen0_pearson_delta: 0.0232733180718116
test_combo_seen0_frac_opposite_direction_top20_non_dropout: 0.4166666666666667
test_combo_seen0_frac_sigma_below_1_non_dropout: 0.15
test_combo_se

## GEARS model

In [None]:
# Optimal hyperparameters according to Supplementary Note 22 of the GEARS publication.
epoch = 15
no_perturb = False
savepath = "/lustre/groups/ml01/workspace/leander.dony/projects/cellflow/241106_gears/evaluate_gears/norman/output/adatas_gearsoptimal/"
# Make sure the output directory exists before any h5ad file is written.
os.makedirs(savepath, exist_ok=True)

for seed in range(1, 6):
    # Reload per seed: prepare_split is seed-dependent, so each run gets its own
    # simulation split and dataloaders.
    pert_data = PertData(os.path.normpath(DATA_DIR), gene_set_path=os.path.join(DATA_DIR, "essential_norman.pkl"))
    pert_data.load(data_path=os.path.join(DATA_DIR, "norman2019"))
    pert_data.prepare_split(split="simulation", seed=seed)
    pert_data.get_dataloader(batch_size=batch_size, test_batch_size=batch_size)

    gears_model = GEARS(
        pert_data,
        device="cuda:" + str(device),
        weight_bias_track=False,
        proj_name="norman2019",
        exp_name="gears_seed" + str(seed),
    )
    gears_model.model_initialize(
        num_similar_genes_co_express_graph=5,
        no_perturb=no_perturb,
        go_path=os.path.join(DATA_DIR, "go_essential_norman.csv"),
    )
    gears_model.train(epochs=epoch)

    results = {
        "res_test": evaluate(
            gears_model.dataloader["test_loader"],
            gears_model.best_model,
            gears_model.config["uncertainty"],
            gears_model.device,
        ),
        "res_val": evaluate(
            gears_model.dataloader["val_loader"],
            gears_model.best_model,
            gears_model.config["uncertainty"],
            gears_model.device,
        ),
    }

    for split in ["val", "test"]:
        res = results[f"res_{split}"]
        # Invert {subgroup: [conditions]} into {condition: subgroup}; independent of `kind`.
        condition_to_subgroup = {
            cond: name
            for name, conds in pert_data.subgroup[f"{split}_subgroup"].items()
            for cond in conds
        }
        obs = pd.DataFrame({
            "condition": res["pert_cat"],
            "subgroup": pd.Series(res["pert_cat"]).map(condition_to_subgroup),
        })
        # AnnData wants a string obs index; set it explicitly instead of relying on the
        # implicit (warning-emitting) conversion.
        obs.index = obs.index.astype(str)
        # Conditions not listed in any subgroup map to NaN. Skip them instead of
        # crashing on NaN.replace when building the file name.
        subgroups = obs["subgroup"].dropna().unique()

        for kind in ["truth", "pred"]:
            adata = anndata.AnnData(
                X=res[kind],
                var=pert_data.adata.var,
                obs=obs.copy(),
                uns={"rank_genes_groups_cov_all": de_dict},
            )
            for subgroup in subgroups:
                # Single quotes inside the f-string keep this valid before Python 3.12 (PEP 701).
                fname = f"norman_gearsoptimal_seed{seed}_{split}_{subgroup.replace('_', '')}_{kind}.h5ad"
                adata[adata.obs["subgroup"] == subgroup].copy().write(
                    os.path.join(savepath, fname), compression="gzip"
                )