In [1]:
import os
os.environ['TMP'] = '/home/share/huadjyin/home/zhoumin3/tmp'
os.environ['TEMP'] = '/home/share/huadjyin/home/zhoumin3/tmp'
os.environ['WANDB_DIR'] = '/home/share/huadjyin/home/zhoumin3/tmp'

In [2]:
!python /home/share/huadjyin/home/zhoumin3/zhoumin/gears/GEARS_misc/paper/CPA_reproduce/model.py

In [3]:
import os
import json
import argparse
import torch
import numpy as np
import sys

In [4]:
from collections import defaultdict
sys.path.append('/home/share/huadjyin/home/zhoumin3/zhoumin/gears/GEARS_misc/paper/CPA_reproduce')
from data import load_dataset_splits
from model import ComPert

In [5]:
from sklearn.metrics import r2_score, balanced_accuracy_score, make_scorer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
import time

In [6]:
def pjson(s):
    """
    Prints a string in JSON format and flushes stdout
    """
    print(json.dumps(s), flush=True)


def evaluate_disentanglement(autoencoder, dataset, nonlinear=False):
    """
    Given a ComPert model, this function measures the correlation between
    its latent space and 1) a dataset's drug vectors 2) a datasets covariate
    vectors.

    """
    _, latent_basal = autoencoder.predict(
        dataset.genes,
        dataset.drugs,
        dataset.cell_types,
        return_latent_basal=True)

    latent_basal = latent_basal.detach().cpu().numpy()

    if nonlinear:
        clf = KNeighborsClassifier(
            n_neighbors=int(np.sqrt(len(latent_basal))))
    else:
        clf = LogisticRegression(solver="liblinear",
                                 multi_class="auto",
                                 max_iter=10000)

    pert_scores = cross_val_score(
        clf,
        StandardScaler().fit_transform(latent_basal), dataset.drugs_names,
        scoring=make_scorer(balanced_accuracy_score), cv=5, n_jobs=-1)

    if len(np.unique(dataset.cell_types_names)) > 1:
        cov_scores = cross_val_score(
            clf,
            StandardScaler().fit_transform(latent_basal), dataset.cell_types_names,
            scoring=make_scorer(balanced_accuracy_score), cv=5, n_jobs=-1)
        return np.mean(pert_scores), np.mean(cov_scores)
    else:
        return np.mean(pert_scores), 0



def evaluate_r2(autoencoder, dataset, genes_control):
    """
    Measures different quality metrics about an ComPert `autoencoder`, when
    tasked to translate some `genes_control` into each of the drug/cell_type
    combinations described in `dataset`.

    Considered metrics are R2 score about means and variances for all genes, as
    well as R2 score about means and variances about differentially expressed
    (_de) genes.
    """

    mean_score, var_score, mean_score_de, var_score_de = [], [], [], []
    num, dim = genes_control.size(0), genes_control.size(1)

    total_cells = len(dataset)

    for pert_category in np.unique(dataset.pert_categories):
        # pert_category category contains: 'celltype_perturbation_dose' info
        de_idx = np.where(
            dataset.var_names.isin(
                np.array(dataset.de_genes[pert_category])))[0]

        idx = np.where(dataset.pert_categories == pert_category)[0]

        if len(idx) > 30:
            emb_drugs = dataset.drugs[idx][0].view(
                1, -1).repeat(num, 1).clone()
            emb_cts = dataset.cell_types[idx][0].view(
                1, -1).repeat(num, 1).clone()

            genes_predict = autoencoder.predict(
                genes_control, emb_drugs, emb_cts).detach().cpu()

            mean_predict = genes_predict[:, :dim]
            var_predict = genes_predict[:, dim:]

            # estimate metrics only for reasonably-sized drug/cell-type combos

            y_true = dataset.genes[idx, :].numpy()

            # true means and variances
            yt_m = y_true.mean(axis=0)
            yt_v = y_true.var(axis=0)
            # predicted means and variances
            yp_m = mean_predict.mean(0)
            yp_v = var_predict.mean(0)

            mean_score.append(r2_score(yt_m, yp_m))
            var_score.append(r2_score(yt_v, yp_v))

            try:
                mean_score_de.append(r2_score(yt_m[de_idx], yp_m[de_idx]))
                var_score_de.append(r2_score(yt_v[de_idx], yp_v[de_idx]))
            except:
                mean_score_de.append(0)
                var_score_de.append(0)

    return [np.mean(s) if len(s) else -1
            for s in [mean_score, mean_score_de, var_score, var_score_de]]


def evaluate(autoencoder, datasets):
    """
    Measure quality metrics using `evaluate()` on the training, test, and
    out-of-distributiion (ood) splits.
    """

    autoencoder.eval()
    with torch.no_grad():
        stats_test = evaluate_r2(
            autoencoder,
            datasets["test_treated"],
            datasets["test_control"].genes)

        stats_disent_pert, stats_disent_cov = evaluate_disentanglement(
            autoencoder, datasets["test"])

        evaluation_stats = {
            "training": evaluate_r2(
                autoencoder,
                datasets["training_treated"],
                datasets["training_control"].genes),
            "test": stats_test,
            "ood": evaluate_r2(
                autoencoder,
                datasets["ood"],
                datasets["test_control"].genes),
            "perturbation disentanglement": stats_disent_pert,
            "optimal for perturbations": 1/datasets['test'].num_drugs,
            "covariate disentanglement": stats_disent_cov,
            "optimal for covariates": 1/datasets['test'].num_cell_types,
        }
    autoencoder.train()
    return evaluation_stats


def prepare_compert(args, state_dict=None):
    """
    Instantiates autoencoder and dataset to run an experiment.
    """

    device = "cuda:"+str(args['cuda']) if torch.cuda.is_available() else "cpu"
    #device = "cpu"
    print('-1')
    datasets = load_dataset_splits(
        args["dataset_path"],
        args["perturbation_key"],
        args["dose_key"],
        args["cell_type_key"],
        args["split_key"])
    print('-2')
    autoencoder = ComPert(
        datasets["training"].num_genes,
        datasets["training"].num_drugs,
        datasets["training"].num_cell_types,
        device=device,
        seed=args["seed"],
        loss_ae=args["loss_ae"],
        doser_type=args["doser_type"],
        patience=args["patience"],
        hparams=args["hparams"],
        decoder_activation=args["decoder_activation"],
        emb_kg = args["emb"],
        num_pert_in_graph = args['num_pert_in_graph']
    )
    print('-3')
    if state_dict is not None:
        autoencoder.load_state_dict(state_dict)

    return autoencoder, datasets


In [None]:
def train_compert(args, return_model=False):
    """
    Trains a ComPert autoencoder
    """
    dataset = args["dataset"]
    
    import pandas as pd

    if dataset == 'AdamsonWeissman2016_GSM2406675_1':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'AdamsonWeissman2016_GSM2406677_2':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406677_2.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'AdamsonWeissman2016_GSM2406681_3':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406681_3.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2017_stimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2017_stimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2017_unstimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2017_unstimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2021_stimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2021_stimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2021_unstimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2021_unstimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Dixit_combined':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_dixit_combined.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Dixit_GSM2396858':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_dixit_gsm2396858.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Dixit_GSM2396861':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_dixit_gsm2396861.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'NormanWeissman2019':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_normanweissman2019.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'PapalexiSatija2021_eccite_arrayed_RNA':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_papalexisatija2021_eccite_arrayed_rna.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'PapalexiSatija2021_eccite_RNA':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_papalexisatija2021_eccite_rna.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Replogle_k562_essential':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_replogle_k562_essential.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Replogle_rpe1_essential':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_replogle_rpe1_essential.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'ShifrutMarson2018':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_shifrutmarson2018.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2019_day7neuron':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2019_day7neuron.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2019_iPSC':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2019_ipsc.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2021_CRISPRa':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2021_crispra.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2021_CRISPRi':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2021_crispri.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'XuCao2023':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_xucao2023.csv'
        print(f"go_path: {go_path}")
    else:
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/de_data_gears/go_essential_all/go_essential_all.csv'
        print(f"go_path: {go_path}")
        
    df = pd.read_csv(go_path)
    df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
        
    gene_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/06cpa/dataset_all_gene_names.pkl'
      
    def get_map(pert):
        tmp = np.zeros(len(gene_list))
        tmp[np.where(np.in1d(gene_list, df[df.target == pert].source.values))[0]] = 1
        return tmp    
    
    import pickle
    with open(gene_path, 'rb') as f:
        gene_list = pickle.load(f)
    args['num_pert_in_graph'] = len(gene_list)
    
    autoencoder, datasets = prepare_compert(args)
    
    pert_dict = datasets["training"].perts_dict
   
    
    pert2neighbor =  {i: get_map(i) for i in list(pert_dict.keys())}    
    pert_dict_rev = {'+'.join([str(x) for x in np.where(j == 1)[0]]): i for i,j  in pert_dict.items()}
    print(pert_dict_rev)
    print(pert2neighbor)
    autoencoder.pert2neighbor = pert2neighbor
    autoencoder.pert_dict_rev = pert_dict_rev
    
    datasets.update({
        "loader_tr": torch.utils.data.DataLoader(
                        datasets["training"],
                        batch_size=autoencoder.hparams["batch_size"],
                        shuffle=True,
                        drop_last=True)
    })
    
    print(datasets["training"].perts_dict)
    
    pjson({"training_args": args})
    pjson({"autoencoder_params": autoencoder.hparams})

    start_time = time.time()
    for epoch in range(args["max_epochs"]):
        epoch_training_stats = defaultdict(float)

        for genes, drugs, cell_types in datasets["loader_tr"]:
            minibatch_training_stats = autoencoder.update(
                genes, drugs, cell_types)

            for key, val in minibatch_training_stats.items():
                epoch_training_stats[key] += val

        for key, val in epoch_training_stats.items():
            epoch_training_stats[key] = val / len(datasets["loader_tr"])
            if not (key in autoencoder.history.keys()):
                autoencoder.history[key] = []
            autoencoder.history[key].append(val)
        autoencoder.history['epoch'].append(epoch)

        ellapsed_minutes = (time.time() - start_time) / 60
        autoencoder.history['elapsed_time_min'] = ellapsed_minutes

        # decay learning rate if necessary
        # also check stopping condition: patience ran out OR
        # time ran out OR max epochs achieved
        stop = ellapsed_minutes > args["max_minutes"] or \
            (epoch == args["max_epochs"] - 1)

        if (epoch % args["checkpoint_freq"]) == 0 or stop:
            #evaluation_stats = evaluate(autoencoder, datasets)
            #for key, val in evaluation_stats.items():
            #    if not (key in autoencoder.history.keys()):
            #        autoencoder.history[key] = []
            #    autoencoder.history[key].append(val)
            #autoencoder.history['stats_epoch'].append(epoch)

            pjson({
                "epoch": epoch,
                "training_stats": epoch_training_stats,
                #"evaluation_stats": evaluation_stats,
                "ellapsed_minutes": ellapsed_minutes
            })

            # autoencoder.state_dict(),
            if epoch == 499:
                if args['emb'] == 'kg':
                    torch.save(
                        autoencoder.state_dict(),
                        os.path.join(
                            args["save_dir"],
                            "model_kg_seed={}_epoch={}.pt".format(args["seed"], epoch)))
                else:
                    torch.save(
                        autoencoder.state_dict(),
                        os.path.join(
                            args["save_dir"],
                            "model_seed={}_epoch={}.pt".format(args["seed"], epoch)))
            """
            # Some error when I tried to include the history + args
            torch.save(
                autoencoder.state_dict(), args, autoencoder.history,
                os.path.join(
                    args["save_dir"],
                    "model_seed={}_epoch={}.pt".format(args["seed"], epoch)))
            """

            pjson({"model_saved": "model_seed={}_epoch={}.pt\n".format(
                args["seed"], epoch)})
            #stop = stop or autoencoder.early_stopping(
            #    np.mean(evaluation_stats["test"]))
            #if stop:
            #    pjson({"early_stop": epoch})
            #    break

    if return_model:
        return autoencoder, datasets


def parse_arguments(dataset, split_number):
    
    if split_number not in range(1, 6):
        raise ValueError("split_number")

    base_args = {
        'dataset_path': f'{dataset}.h5ad',
        'dataset': dataset,
        'perturbation_key': "condition",
        'dose_key': "dose_val",
        'cell_type_key': "cell_type",
        'loss_ae': 'gauss',
        'doser_type': 'sigm',
        'decoder_activation': 'linear',
        'seed': 0,
        'hparams': "",
        'cuda': 0,
        'max_epochs': 500,
        'max_minutes': 400,
        'patience': 20,
        'checkpoint_freq': 20,
        'sweep_seeds': 200,
        'emb': 'one_hot'
    }
    
    split_key = f'split{split_number}'
    save_dir = f'/home/share/huadjyin/home/zhoumin3/zhoumin/model_benchmark/01_A_results/{dataset}/cpa/{split_key}/'
    os.makedirs(save_dir, exist_ok=True)
    args = base_args.copy()
    args['split_key'] = split_key
    args['save_dir'] = save_dir
    
    return args

In [None]:
import wandb
import logging
from logging import StreamHandler
dataset = "AdamsonWeissman2016_GSM2406675_1"
split_number = 1
wandb.init(project='01_dataset_all_cpa', name=f'{dataset}_split{split_number}', config={
        "seed": split_number,
        "data_name": dataset
    })
logger = logging.getLogger()
logger.setLevel(logging.INFO)

class WandbHandler(StreamHandler):
    def emit(self, record):
        log_entry = self.format(record)
        wandb.log({"log": log_entry})
wandb_handler = WandbHandler()
wandb_handler.setLevel(logging.INFO)
logger.addHandler(wandb_handler)

class WandbStream:
    def __init__(self):
        self.buffer = []

    def write(self, message):
        if message.strip() != "":  
            self.buffer.append(message)
            wandb.log({"log": message.strip()})
        sys.__stdout__.write(message)

    def flush(self):
        self.buffer.clear()

sys.stdout = WandbStream()
sys.stderr = WandbStream()

[34m[1mwandb[0m: Currently logged in as: [33mzhoumin1130[0m. Use [1m`wandb login --relogin`[0m to force relogin


go_path: /home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv
  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
-1
--1
---1
---2
---3
---4
---5
  self.drugs = torch.Tensor(drugs)
---6
---7
--2
-2
-3
{'0': 'BHLHE40', '2': 'DDIT3', '5': 'SPI1', '7': 'ctrl', '1': 'CREB1', '6': 'ZNF326', '4': 'SNAI1', '3': 'EP300'}
{'BHLHE40': array([0., 0., 0., ..., 0., 0., 0.]), 'DDIT3': array([0., 0., 0., ..., 0., 0., 0.]), 'SPI1': array([0., 0., 0., ..., 0., 0., 0.]), 'ctrl': array([0., 0., 0., ..., 0., 0., 0.]), 'CREB1': array([0., 0., 0., ..., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., ..., 0., 0., 0.]), 'SNAI1': array([0., 0., 0., ..., 0., 0., 0.]), 'EP300': array([0., 0., 0., ..., 0., 0., 0.])}
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0.,

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'0': 'BHLHE40', '2': 'DDIT3', '5': 'SPI1', '7': 'ctrl', '1': 'CREB1', '6': 'ZNF326', '4': 'SNAI1', '3': 'EP300'}
{'BHLHE40': array([0., 0., 0., ..., 0., 0., 0.]), 'DDIT3': array([0., 0., 0., ..., 0., 0., 0.]), 'SPI1': array([0., 0., 0., ..., 0., 0., 0.]), 'ctrl': array([0., 0., 0., ..., 0., 0., 0.]), 'CREB1': array([0., 0., 0., ..., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., ..., 0., 0., 0.]), 'SNAI1': array([0., 0., 0., ..., 0., 0., 0.]), 'EP300': array([0., 0., 0., ..., 0., 0., 0.])}
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}
{"training_args": {"dataset_path": "AdamsonWeissman201

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'0': 'BHLHE40', '2': 'DDIT3', '5': 'SPI1', '7': 'ctrl', '1': 'CREB1', '6': 'ZNF326', '4': 'SNAI1', '3': 'EP300'}
{'BHLHE40': array([0., 0., 0., ..., 0., 0., 0.]), 'DDIT3': array([0., 0., 0., ..., 0., 0., 0.]), 'SPI1': array([0., 0., 0., ..., 0., 0., 0.]), 'ctrl': array([0., 0., 0., ..., 0., 0., 0.]), 'CREB1': array([0., 0., 0., ..., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., ..., 0., 0., 0.]), 'SNAI1': array([0., 0., 0., ..., 0., 0., 0.]), 'EP300': array([0., 0., 0., ..., 0., 0., 0.])}
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}
{"training_args": {"dataset_path": "AdamsonWeissman201

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'0': 'BHLHE40', '2': 'DDIT3', '5': 'SPI1', '7': 'ctrl', '1': 'CREB1', '6': 'ZNF326', '4': 'SNAI1', '3': 'EP300'}
{'BHLHE40': array([0., 0., 0., ..., 0., 0., 0.]), 'DDIT3': array([0., 0., 0., ..., 0., 0., 0.]), 'SPI1': array([0., 0., 0., ..., 0., 0., 0.]), 'ctrl': array([0., 0., 0., ..., 0., 0., 0.]), 'CREB1': array([0., 0., 0., ..., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., ..., 0., 0., 0.]), 'SNAI1': array([0., 0., 0., ..., 0., 0., 0.]), 'EP300': array([0., 0., 0., ..., 0., 0., 0.])}
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}
{"training_args": {"dataset_path": "AdamsonWeissman201

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'0': 'BHLHE40', '2': 'DDIT3', '5': 'SPI1', '7': 'ctrl', '1': 'CREB1', '6': 'ZNF326', '4': 'SNAI1', '3': 'EP300'}
{'BHLHE40': array([0., 0., 0., ..., 0., 0., 0.]), 'DDIT3': array([0., 0., 0., ..., 0., 0., 0.]), 'SPI1': array([0., 0., 0., ..., 0., 0., 0.]), 'ctrl': array([0., 0., 0., ..., 0., 0., 0.]), 'CREB1': array([0., 0., 0., ..., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., ..., 0., 0., 0.]), 'SNAI1': array([0., 0., 0., ..., 0., 0., 0.]), 'EP300': array([0., 0., 0., ..., 0., 0., 0.])}
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}
{"training_args": {"dataset_path": "AdamsonWeissman201

  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}


100%|##########| 2/2 [00:00<00:00, 47.32it/s]
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


split1 computation completed
go_path: /home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv


  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}


100%|##########| 2/2 [00:00<00:00, 45.42it/s]
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


split2 computation completed
go_path: /home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv


  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}


100%|##########| 2/2 [00:00<00:00, 52.73it/s]
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


split3 computation completed
go_path: /home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv


  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}


100%|##########| 2/2 [00:00<00:00, 41.12it/s]
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


split4 computation completed
go_path: /home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv


  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)


-1
--1
---1
---2
---3
---4
---5
---6
---7
--2
-2
-3
{'BHLHE40': array([1., 0., 0., 0., 0., 0., 0., 0.]), 'DDIT3': array([0., 0., 1., 0., 0., 0., 0., 0.]), 'SPI1': array([0., 0., 0., 0., 0., 1., 0., 0.]), 'ctrl': array([0., 0., 0., 0., 0., 0., 0., 1.]), 'CREB1': array([0., 1., 0., 0., 0., 0., 0., 0.]), 'ZNF326': array([0., 0., 0., 0., 0., 0., 1., 0.]), 'SNAI1': array([0., 0., 0., 0., 1., 0., 0., 0.]), 'EP300': array([0., 0., 0., 1., 0., 0., 0., 0.])}


100%|##########| 2/2 [00:00<00:00, 45.18it/s]
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl
  fold_change = pert_mean/ctrl


split5 computation completed


In [9]:
train_compert(parse_arguments(dataset, split_number))

In [10]:
wandb.finish()

0,1
log,"{""model_saved"": ""mod..."


In [11]:
for split_number in range(2, 6):
    args = parse_arguments(dataset, split_number)
    wandb.init(project='01_dataset_all_cpa', name=f'{dataset}_split{split_number}', config=args)
    print(f"Training for split {split_number}...")
    train_compert(args)

    wandb.finish()

0,1
log,Retrying (Retry(tota...


0,1
log,Retrying (Retry(tota...


0,1
log,Retrying (Retry(tota...


In [None]:
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__

In [12]:
from pprint import pprint
from train import train_compert
from data import load_dataset_splits
from api import ComPertAPI
import pandas as pd
import scanpy as sc
import numpy as np
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import mean_absolute_error as mae
import pandas as pd
import numpy as np
import networkx as nx
import statsmodels.stats.api as sms
from scipy.stats import ncx2
from os.path import isfile
import scanpy as sc
from tqdm import tqdm
import sys
sys.path.append('/home/share/huadjyin/home/fengtiannan/zhoumin/gears/GEARS_misc/legacy/')
from flow import get_graph, get_expression_data,\
            add_weight, I_TF, get_TFs, solve,\
            solve_parallel, get_expression_lambda


In [None]:
dataset_name = 'AdamsonWeissman2016_GSM2406675_1'
kg_mode = False
if kg_mode:
    kg_str = '_kg'
    emb_model = 'kg'
else:
    kg_str = ''
    emb_model = 'one_hot'
for split_num in range(1, 6):
    import pickle
    import torch
    from train import prepare_compert
    model_name = f'/home/share/huadjyin/home/zhoumin3/zhoumin/model_benchmark/01_A_results/{dataset_name}/cpa/split' + str(split_num) + '/model' + kg_str + '_seed=0_epoch=499.pt'
    #model_name ='./model_seed=0_epoch=499.pt'
    #state, args, history = torch.load(model_name, map_location=torch.device('cpu'))
    state  = torch.load(model_name, map_location=torch.device('cpu'))
    fname = f'/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/06cpa/{dataset_name}.h5ad'
    autoencoder_params = { 'adversary_depth': 4,
                              'adversary_lr': 0.0001875455179637405,
                              'adversary_steps': 3,
                              'adversary_wd': 0.00019718137187038062,
                              'adversary_width': 128,
                              'autoencoder_depth': 4,
                              'autoencoder_lr': 0.0011021870411382655,
                              'autoencoder_wd': 1.1455862519513426e-05,
                              'autoencoder_width': 512,
                              'batch_size': 128,
                              'dim': 256,
                              'dosers_depth': 2,
                              'dosers_lr': 0.00026396192072937485,
                              'dosers_wd': 7.165810318386074e-07,
                              'dosers_width': 64,
                              'penalty_adversary': 8.735507132389051,
                              'reg_adversary': 69.6011204833175,
                              'step_size_lr': 25}
    
    args = {'dataset_path': fname, # full path to the anndata dataset 
            'cell_type_key': 'cell_type', # necessary field for cell types. Fill it with a dummy variable if no celltypes present.
            'split_key': 'split' + str(split_num), # necessary field for train, test, ood splits.
            'perturbation_key': 'condition', # necessary field for perturbations
            'dose_key': 'dose_val', # necessary field for dose. Fill in with dummy variable if dose is the same. 
            'checkpoint_freq': 20, # checkoint frequencty to save intermediate results
            'hparams': "", #autoencoder_params, # autoencoder architecture
            'max_epochs': 500, # maximum epochs for training
            'max_minutes': 400, # maximum computation time
            'patience': 20, # patience for early stopping
            'loss_ae': 'gauss', # loss (currently only gaussian loss is supported)
            'doser_type': 'sigm', # non-linearity for doser function
            'save_dir': f'/home/share/huadjyin/home/zhoumin3/zhoumin/model_benchmark/01_A_results/{dataset_name}/cpa/split{split_num}/', # directory to save the model
            'decoder_activation': 'linear', # last layer of the decoder
            'seed': 0, # random seed
            'sweep_seeds': 0,
            'emb': emb_model,
            'dataset': dataset_name}
    args['cuda'] = 0
    
    dataset = args["dataset"]
    save_dir = args['save_dir'] 
    import pandas as pd
    if dataset == 'AdamsonWeissman2016_GSM2406675_1':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406675_1.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'AdamsonWeissman2016_GSM2406677_2':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406677_2.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'AdamsonWeissman2016_GSM2406681_3':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_adamsonweissman2016_gsm2406681_3.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2017_stimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2017_stimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2017_unstimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2017_unstimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2021_stimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2021_stimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'DatlingerBock2021_unstimulated':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_datlingerbock2021_unstimulated.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Dixit_combined':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_dixit_combined.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Dixit_GSM2396858':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_dixit_gsm2396858.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Dixit_GSM2396861':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_dixit_gsm2396861.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'NormanWeissman2019':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_normanweissman2019.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'PapalexiSatija2021_eccite_arrayed_RNA':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_papalexisatija2021_eccite_arrayed_rna.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'PapalexiSatija2021_eccite_RNA':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_papalexisatija2021_eccite_rna.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Replogle_k562_essential':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_replogle_k562_essential.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'Replogle_rpe1_essential':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_replogle_rpe1_essential.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'ShifrutMarson2018':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_shifrutmarson2018.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2019_day7neuron':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2019_day7neuron.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2019_iPSC':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2019_ipsc.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2021_CRISPRa':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2021_crispra.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'TianKampmann2021_CRISPRi':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_tiankampmann2021_crispri.csv'
        print(f"go_path: {go_path}")
    elif dataset == 'XuCao2023':
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/data/go_essential_xucao2023.csv'
        print(f"go_path: {go_path}")
    else:
        go_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/de_data_gears/go_essential_all/go_essential_all.csv'
        print(f"go_path: {go_path}")
        
    df = pd.read_csv(go_path)
    df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1,['importance'])).reset_index(drop = True)
    gene_path = '/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/06cpa/dataset_all_gene_names.pkl'
    def get_map(pert):
        tmp = np.zeros(len(gene_list))
        tmp[np.where(np.in1d(gene_list, df[df.target == pert].source.values))[0]] = 1
        return tmp    
    
    import pickle
    with open(gene_path, 'rb') as f:
        gene_list = pickle.load(f)
    args['num_pert_in_graph'] = len(gene_list)
    
    
    # load the dataset and model pre-trained weights
    autoencoder, datasets = prepare_compert(args, state_dict=state)
    autoencoder.load_state_dict(state)
    pert_dict = datasets["training"].perts_dict
    
    pert2neighbor =  {i: get_map(i) for i in list(pert_dict.keys())}    
    pert_dict_rev = {'+'.join([str(x) for x in np.where(j == 1)[0]]): i for i,j  in pert_dict.items()}
    autoencoder.pert2neighbor = pert2neighbor
    autoencoder.pert_dict_rev = pert_dict_rev
    
    
    from inference import evaluate, compute_metrics, deeper_analysis, GI_subgroup, non_dropout_analysis, non_zero_analysis
    compert_api = ComPertAPI(datasets, autoencoder)
    adata = sc.read(fname)
    condition2num_of_cells = dict(adata.obs.cov_drug_dose_name.value_counts())
    name_map = dict(adata.obs[['cov_drug_dose_name', 'condition']].drop_duplicates().values)

    dataset = datasets['ood']
    compert_api.model.eval()
    scores = pd.DataFrame(columns=[compert_api.covars_key,
                                    compert_api.perturbation_key,
                                    compert_api.dose_key,
                                    'R2_mean', 'R2_mean_DE', 'R2_var',
                                    'R2_var_DE', 'num_cells'])

    total_cells = len(dataset)

    icond = 0

    pert_cat = []
    pred = []
    truth = []
    pred_de = []
    truth_de = []
    results = {}
    
    for pert_category in tqdm(np.unique(dataset.pert_categories)):
        # pert_category category contains: 'celltype_perturbation_dose' info
        de_idx = np.where(
            dataset.var_names.isin(
                np.array(dataset.de_genes[pert_category])))[0]
        idx = np.where(dataset.pert_categories == pert_category)[0]

        genes_control = datasets['training_control'].genes[np.random.randint(0,
                                            len(datasets['training_control'].genes), condition2num_of_cells[pert_category]), :]
        num, dim = genes_control.size(0), genes_control.size(1)

        pert_cat.extend([name_map[pert_category]] * condition2num_of_cells[pert_category])

        if len(idx) > 0:
            emb_drugs = dataset.drugs[idx][0].view(1, -1).repeat(num, 1).clone()
            emb_cts = dataset.cell_types[idx][0].view(1, -1).repeat(num, 1).clone()

            genes_predict = compert_api.model.predict(
                genes_control, emb_drugs, emb_cts).detach().cpu()
            y_pred = genes_predict[:, :dim]
            y_true = dataset.genes[idx, :].numpy()

            y_true_de = y_true[:, de_idx]
            y_pred_de = y_pred[:, de_idx]

            
            pred.extend(y_pred)
            truth.extend(y_true)
            pred_de.extend(y_pred_de)
            truth_de.extend(y_true_de)
    
    results['pert_cat'] = np.array(pert_cat)
    results['pred'] = torch.stack(pred).detach().numpy()
    results['pred_de'] = torch.stack(pred_de).detach().numpy()
    results['truth'] = np.stack(truth)
    results['truth_de'] = np.stack(truth_de)
    results['pred'] = results['pred']
    results['pred_de'] = results['pred_de']

   
    results_to_save = [
        (results, f"{dataset_name}_split{split_num}_test_res"),
        (compute_metrics(results), f"{dataset_name}_split{split_num}_test_metrics"),
        (deeper_analysis(adata, results, de_column_prefix='top_non_dropout_de_20'), f"{dataset_name}_split{split_num}_deeper_res"),
        (non_dropout_analysis(adata, results), f"{dataset_name}_split{split_num}_non_dropout_res"),
        (non_zero_analysis(adata, results), f"{dataset_name}_split{split_num}_non_zero_res")
    ]

    for result, file_prefix in results_to_save:
        file_path = f"{save_dir}/{file_prefix}.pkl"
        with open(file_path, 'wb') as f:
            pickle.dump(result, f)

    print(f"split{split_num} computation completed")