In [1]:
import os
from gears import PertData, GEARS
import scanpy as sc
import numpy as np
import pickle
import jax
import jax.tree_util as jtu
from scipy.sparse import csr_matrix
from functools import partial
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import wandb
from omegaconf import DictConfig, OmegaConf

from typing import Dict

import numpy as np
from ott.geometry import costs, pointcloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from sklearn.metrics import pairwise_distances, r2_score
from sklearn.metrics.pairwise import rbf_kernel

In [2]:
def setup_logger(cfg):
    """Initialize and return a Weights & Biases logger."""
    wandb.login()
    return wandb.init(
        project=cfg['dataset']['wandb_project'],
        # config=OmegaConf.to_container(cfg, resolve=True), #TODO: uncomment omegaconf
        dir="/home/icb/lea.zimmermann/projects/pertot/ot_pert_reproducibility/runs_gears/bash_scripts",
        settings=wandb.Settings(start_method="thread"),
    )

In [3]:
def get_mask(x, y, var_names):
    return x[:, [gene in y for gene in var_names]]


def compute_r_squared(x: np.ndarray, y: np.ndarray) -> float:
    return r2_score(np.mean(x, axis=0), np.mean(y, axis=0))


def compute_sinkhorn_div(x: np.ndarray, y: np.ndarray, epsilon: float) -> float:
    return float(
        sinkhorn_divergence(
            pointcloud.PointCloud,
            x=x,
            y=y,
            cost_fn=costs.SqEuclidean(),
            epsilon=epsilon,
            scale_cost=1.0,
        ).divergence
    )


def compute_e_distance(x: np.ndarray, y: np.ndarray) -> float:
    sigma_X = pairwise_distances(x, x, metric="sqeuclidean").mean()
    sigma_Y = pairwise_distances(y, y, metric="sqeuclidean").mean()
    delta = pairwise_distances(x, y, metric="sqeuclidean").mean()
    return 2 * delta - sigma_X - sigma_Y


def compute_metrics(x: np.ndarray, y: np.ndarray) -> Dict[str, float]:
    metrics = {}
    metrics["r_squared"] = compute_r_squared(x, y)
    metrics["sinkhorn_div_1"] = compute_sinkhorn_div(x, y, epsilon=1.0)
    metrics["sinkhorn_div_10"] = compute_sinkhorn_div(x, y, epsilon=10.0)
    metrics["sinkhorn_div_100"] = compute_sinkhorn_div(x, y, epsilon=100.0)
    metrics["e_distance"] = compute_e_distance(x, y)
    metrics["mmd"] = compute_scalar_mmd(x, y)
    return metrics


def compute_mean_metrics(metrics: Dict[str, Dict[str, float]], prefix: str = ""):
    metric_names = list(list(metrics.values())[0].keys())
    metric_dict = {prefix + met_name: [] for met_name in metric_names}
    for met in metric_names:
        stat = 0.0
        for vals in metrics.values():
            stat += vals[met]
        metric_dict[prefix + met] = stat / len(metrics)
    return metric_dict


def mmd_distance(x, y, gamma):
    xx = rbf_kernel(x, x, gamma)
    xy = rbf_kernel(x, y, gamma)
    yy = rbf_kernel(y, y, gamma)

    return xx.mean() + yy.mean() - 2 * xy.mean()


def compute_scalar_mmd(target, transport, gammas=None):  # from CellOT repo
    if gammas is None:
        gammas = [2, 1, 0.5, 0.1, 0.01, 0.005]

    def safe_mmd(*args):
        try:
            mmd = mmd_distance(*args)
        except ValueError:
            mmd = np.nan
        return mmd

    return np.mean(list(map(lambda x: safe_mmd(target, transport, x), gammas)))


def compute_metrics_fast(x: np.ndarray, y: np.ndarray) -> Dict[str, float]:
    metrics = {}
    metrics["r_squared"] = compute_r_squared(x, y)
    metrics["e_distance"] = compute_e_distance(x, y)
    metrics["mmd_distance"] = compute_scalar_mmd(x, y)
    return metrics


In [4]:
cfg = {
    'dataset': {
        'dataset_path': '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/r2',
        'base_path': '/home/icb/lea.zimmermann/projects/pertot/gears/data/satija',
        'train_data': '/home/icb/lea.zimmermann/projects/pertot/gears/data/satija/satija_ifng_bxpc3',
        'split_data': '/home/icb/lea.zimmermann/projects/pertot/gears/data/satija/satija_ifng_bxpc3/custom_split.pkl',
        'test_data': '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/r2',
        'pathway': 'IFNG',
        'cell_type': 'BXPC3',
        'obsm_key_data': 'X_pca',
        'obsm_key_cond': 'cond_emb',
        'wandb_project': 'gears_satija_test'
    },
    'training': {
        'batch_size': 1024,
        'valid_freq': 10,
        'num_iterations': 10000,
        'out_dir': '/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets',
        'n_train_samples': 0,
        'n_test_samples': 10,
        'n_ood_samples': -1,
        'fast_metrics': True,
        'learning_rate': 0.0001,
        'weight_decay': 0.5
    },
    'model': {
        'hidden_dims': [1024, 1024, 1024],
        'time_dims': [512, 512],
        'output_dims': [1024, 1024, 1024],
        'condition_dims': [1024, 1024, 512],
        'time_n_freqs': 1024,
        'flow_noise': 0.1,
        'multi_steps': 20,
        'epsilon': 0.01,
        'tau_a': 1.0,
        'tau_b': 1.0,
        'dropout_rate': 0.0
    },
    'conf': {
        'run': {
            'dir': 'checkpoint_dir'
        }
    }
}

In [5]:
def load_data(adata, cfg, *, skip_cond, deg, return_dl: bool):
    """Loads data and preprocesses it based on configuration."""
    dls = []
    data_source = {}
    data_target = {}
    data_source_decoded = {}
    data_target_decoded = {}
    data_conditions = {}
    genes = []
    for cond in adata.obs["condition"].cat.categories:
        if cond!= 'NT+ctrl' and cond not in skip_cond:
            src_str_unique = list(adata[adata.obs["condition"] == cond].obs["cell_type"].unique())
            assert len(src_str_unique) == 1
            src_str = 'ctrl'
            source = adata[adata.obs["condition"] == src_str].obsm[cfg['dataset']['obsm_key_data']]
            source_decoded = adata[adata.obs["condition"] == src_str].X.A
            target = adata[adata.obs["condition"] == cond].obsm[cfg['dataset']['obsm_key_data']]
            target_decoded = adata[adata.obs["condition"] == cond].X.A
            data_source[cond] = source
            data_target[cond] = target
            data_source_decoded[cond] = source_decoded
            data_target_decoded[cond] = target_decoded
            data_conditions[cond] = cond

    deg_dict = {k: v for k, v in deg.items() if k in data_conditions.keys()}

    return {
        "source": data_source,
        "target": data_target,
        "source_decoded": data_source_decoded,
        "target_decoded": data_target_decoded,
        "conditions": data_conditions,
        "deg_dict": deg_dict,
    }

In [6]:
custom_split_path = './data/satija/satija_ifng_bxpc3/custom_split.pkl'
with open(custom_split_path, 'rb') as file:
    custom_split_dict = pickle.load(file)

In [7]:
pert_data = PertData('./data/satija')
pert_data.load(data_path = '/home/icb/lea.zimmermann/projects/pertot/gears/data/satija/satija_ifng_bxpc3')

Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RARRES3+ctrl' 'HLA-DQB1+ctrl' 'FMNL2+ctrl' 'PLEK+ctrl' 'SRC+ctrl'
 'IFI16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [8]:
genes_not_in_GO_graph = ['RARRES3+ctrl','HLA-DQB1+ctrl','FMNL2+ctrl','PLEK+ctrl','SRC+ctrl', 'IFI16+ctrl']

In [9]:
# custom_split_dict['train'] = [gene for gene in custom_split_dict['train'] if gene not in genes_not_in_GO_graph]
# custom_split_dict['test'] = [gene for gene in custom_split_dict['test'] if gene not in genes_not_in_GO_graph]

In [10]:
pert_data.prepare_split(split = 'custom', split_dict_path=custom_split_path, seed = 1) # get data split with seed
pert_data.get_dataloader(batch_size = 512, test_batch_size = 512) # prepare data loader

Creating dataloaders....
Done!


In [11]:
logger = setup_logger(cfg)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlea-zimmermann[0m ([33mmodality_translation[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
gears_model = GEARS(pert_data, device='cuda', 
                    weight_bias_track = True,
                    wandb_logger=logger,
                    proj_name = 'gears_test_satija', 
                    exp_name = 'gears_test_satija')

In [13]:
gears_model.model_initialize(hidden_size = 64)
#gears_model.load_pretrained('test_model_final_split')

Found local copy...


In [14]:
gears_model.tunable_parameters()

{'hidden_size': 'hidden dimension, default 64',
 'num_go_gnn_layers': 'number of GNN layers for GO graph, default 1',
 'num_gene_gnn_layers': 'number of GNN layers for co-expression gene graph, default 1',
 'decoder_hidden_size': 'hidden dimension for gene-specific decoder, default 16',
 'num_similar_genes_go_graph': 'number of maximum similar K genes in the GO graph, default 20',
 'num_similar_genes_co_express_graph': 'number of maximum similar K genes in the co expression graph, default 20',
 'coexpress_threshold': 'pearson correlation threshold when constructing coexpression graph, default 0.4',
 'uncertainty': 'whether or not to turn on uncertainty mode, default False',
 'uncertainty_reg': 'regularization term to balance uncertainty loss and prediction loss, default 1',
 'direction_lambda': 'regularization term to balance direction loss and prediction loss, default 1'}

In [15]:
for i in range(1):
    optimizer = optim.Adam(gears_model.model.parameters(), lr=cfg['training']['learning_rate'], weight_decay = cfg['training']['weight_decay'])
    scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
    gears_model.train(epochs = 1, lr = 1e-4, optimizer=optimizer, scheduler=scheduler)
    scheduler.step()

Start Training...
Epoch 1 Step 1 Train Loss: 2.8046
Epoch 1 Step 51 Train Loss: 2.7799
Epoch 1: Train Overall MSE: 0.0030 Validation Overall MSE: 0.0018. 
Train Top 20 DE MSE: 0.1014 Validation Top 20 DE MSE: 0.0427. 
Done!
Start Testing...
Best performing model: Test Top 20 DE MSE: 0.1927
Done!


In [40]:
#gears_model.save_model('test_model_final_split')

In [16]:
def eval_step(cfg, model, data, log_metrics, comp_metrics_fn, mask_fn, PCs, train_mean):
    for split, dat in data.items():
        if split == "test":
            n_samples = cfg['training']['n_test_samples']
        if split == "ood":
            n_samples = cfg['training']['n_ood_samples']

        if n_samples != 0:
            if n_samples > 0:
                idcs = np.random.choice(list(list(dat.values())[0]), n_samples)
                dat_conditions = {k: v for k, v in dat["conditions"].items() if k in idcs}
                dat_deg_dict = {k: v for k, v in dat["deg_dict"].items() if k in idcs}
                dat_target = {k: v for k, v in dat["target"].items() if k in idcs}
                dat_target_decoded = {k: v for k, v in dat["target_decoded"].items() if k in idcs}
            else:
                dat_conditions = dat["conditions"]
                dat_deg_dict = dat["deg_dict"]
                dat_target = dat["target"]
                dat_target_decoded = dat["target_decoded"]
            
            predictions = {}
            predictions_pca = {}
            for k, v in dat_target_decoded.items():
                cond = dat_conditions[k]
                gene = cond.split('+')[0]
                samples = np.zeros((v.shape[0], v.shape[1]))
                for i in range(samples.shape[0]):
                    samples[i] = model.predict([[gene]])[gene]
                predictions[k] = samples
                samples_centered = csr_matrix(samples - train_mean)
                predictions_pca[k] = np.matmul(samples_centered.A, PCs)

            metrics = jtu.tree_map(comp_metrics_fn, dat_target, predictions_pca)
            mean_metrics = compute_mean_metrics(metrics, prefix=f"{split}_")
            log_metrics.update(mean_metrics)

            metrics_decoded = jtu.tree_map(comp_metrics_fn, dat_target_decoded, predictions) # TODO: besser verstehen
            mean_metrics_decoded = compute_mean_metrics(metrics_decoded, prefix=f"decoded_{split}_")
            log_metrics.update(mean_metrics_decoded)

            prediction_decoded_deg = jtu.tree_map(mask_fn, predictions, dat_deg_dict)
            target_decoded_deg = jax.tree_util.tree_map(mask_fn, dat_target_decoded, dat_deg_dict)
            metrics_deg = jtu.tree_map(comp_metrics_fn, target_decoded_deg, prediction_decoded_deg)
            mean_metrics_deg = compute_mean_metrics(metrics_deg, prefix=f"deg_{split}_")
            log_metrics.update(mean_metrics_deg)
            wandb.log(log_metrics)
    return log_metrics

In [17]:
pathway = 'IFNG'
cell_type = 'BXPC3'
data_path = cfg['dataset']['test_data']
adata_train = sc.read_h5ad(os.path.join(data_path, "adata_train_" + pathway + "_" + cell_type + ".h5ad"))
train_mean = adata_train.varm["X_train_mean"].T
PCs = adata_train.varm["PCs"]
mask_fn = partial(get_mask, var_names=adata_train.var_names)
adata_test = sc.read_h5ad(os.path.join(data_path, "adata_test_" + pathway + "_" + cell_type + ".h5ad"))
adata_test.obs['condition'] = adata_test.obs['condition'].apply(lambda x: x.split('_')[-1] + '+ctrl')
adata_ood = sc.read_h5ad(os.path.join(data_path, "adata_ood_" + pathway + "_" + cell_type + ".h5ad"))
adata_ood.obs['condition'] = adata_ood.obs['condition'].apply(lambda x: x.split('_')[-1] + '+ctrl')

deg = {k.split('_')[1]+'+ctrl': v  for k, v in adata_train.uns["rank_genes_groups_cov_all"].items()}
testset = load_data(adata_test, cfg, skip_cond=genes_not_in_GO_graph, deg=deg, return_dl=False)
oodset = load_data(adata_ood, cfg, skip_cond=genes_not_in_GO_graph, deg=deg, return_dl=False)
test_data = {
    'test': testset,
    'ood': oodset
}

In [18]:
dat_conditions = oodset["conditions"]
dat_deg_dict = oodset["deg_dict"]
dat_target = oodset["target"]
dat_target_decoded = oodset["target_decoded"]

In [20]:
predictions = {}
predictions_pca = {}
for k, v in dat_target_decoded.items():
    cond = dat_conditions[k]
    gene = cond.split('+')[0]
    samples = np.zeros((v.shape[0], v.shape[1]))
    for i in tqdm(range(samples.shape[0])):
        samples[i] = gears_model.predict([[gene]])[gene]
        gears_model.saved_pred = {}
    predictions[k] = samples
    samples_centered = csr_matrix(samples - train_mean)
    predictions_pca[k] = np.matmul(samples_centered.A, PCs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1762/1762 [01:57<00:00, 14.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3002/3002 [03:21<00:00, 14.87it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1055/1055 [01:10<00:00, 14.87it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1012/1012 [01:08<00:00, 14.88it/s]
100%|███████████████████████

In [21]:
predictions

{'ATF5+ctrl': array([[ 1.02094603,  0.03732041,  0.24941753, ..., -0.01482642,
          0.00944634, -0.00974088],
        [ 0.99692935,  0.09204905,  0.29063824, ..., -0.01482642,
          0.00944634, -0.00974088],
        [ 0.93782216,  0.05852691,  0.2642433 , ..., -0.01482642,
          0.00944634, -0.00974088],
        ...,
        [ 0.99005729,  0.05048174,  0.2993288 , ..., -0.01482642,
          0.00944634, -0.00974088],
        [ 1.01992083,  0.05399077,  0.31296879, ..., -0.01482642,
          0.00944634, -0.00974088],
        [ 1.04258752,  0.04126821,  0.26125166, ..., -0.01482642,
          0.00944634, -0.00974088]]),
 'EHF+ctrl': array([[ 1.06576431,  0.03177255,  0.21048012, ..., -0.01465773,
          0.01149083, -0.0070027 ],
        [ 1.02943575,  0.04352525,  0.24986631, ..., -0.01465773,
          0.01149083, -0.0070027 ],
        [ 1.04573739,  0.05694483,  0.26531848, ..., -0.01465773,
          0.01149083, -0.0070027 ],
        ...,
        [ 1.12565029,  0.0467

In [22]:
comp_metrics_fn = compute_metrics_fast if cfg['training']['fast_metrics'] == True else compute_metrics
logger = {}

In [23]:
eval_step(cfg, gears_model, test_data, logger, compute_metrics_fast, mask_fn, PCs, train_mean)

prediction keys:  dict_keys(['ATF3+ctrl', 'CEBPE+ctrl', 'ETV7+ctrl', 'IFNGR1+ctrl', 'IFNGR2+ctrl', 'IRF9+ctrl', 'JUN+ctrl', 'PRDM1+ctrl', 'STAT3+ctrl'])
dict_keys(['ATF3+ctrl', 'CEBPE+ctrl', 'ETV7+ctrl', 'IFNGR1+ctrl', 'IFNGR2+ctrl', 'IRF9+ctrl', 'JUN+ctrl', 'PRDM1+ctrl', 'STAT3+ctrl'])
prediction keys:  dict_keys(['ATF5+ctrl', 'EHF+ctrl', 'IRF7+ctrl', 'MAFB+ctrl', 'STAT1+ctrl', 'TRAFD1+ctrl'])
dict_keys(['ATF5+ctrl', 'EHF+ctrl', 'IRF7+ctrl', 'MAFB+ctrl', 'STAT1+ctrl', 'TRAFD1+ctrl'])


{'test_r_squared': -0.13263206607339656,
 'test_e_distance': 9.333805714100865,
 'test_mmd_distance': 0.7411452757046997,
 'decoded_test_r_squared': 0.9564043227957176,
 'decoded_test_e_distance': 20.553277096523377,
 'decoded_test_mmd_distance': 0.9352270529716594,
 'deg_test_r_squared': 0.711557366690024,
 'deg_test_e_distance': 7.38784069975603,
 'deg_test_mmd_distance': 0.6890335084127872,
 'ood_r_squared': -1.3054404151460541,
 'ood_e_distance': 11.966043398017737,
 'ood_mmd_distance': 0.7372170868247233,
 'decoded_ood_r_squared': 0.9585772880291242,
 'decoded_ood_e_distance': 19.234332002478236,
 'decoded_ood_mmd_distance': 0.926436611801701,
 'deg_ood_r_squared': 0.5104898641044103,
 'deg_ood_e_distance': 10.102668101192448,
 'deg_ood_mmd_distance': 0.6970982322128306}