In [1]:
import hydra
import wandb
from ott.neural import datasets
import sys
from omegaconf import DictConfig
import jax.numpy as jnp
from jax import random
from typing import Optional, Literal
import jax
import pathlib
import optax
import yaml
from datetime import datetime
from flax import linen as nn
import functools
from tqdm import tqdm
from flax.training import train_state

from ott.neural.networks.layers import time_encoder
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils
import jax.tree_util as jtu
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
import pandas as pd
import os

import hydra
from omegaconf import DictConfig, OmegaConf

from torch.utils.data import DataLoader
import numpy as np

import scanpy as sc
from ot_pert.metrics import compute_metrics, compute_mean_metrics
from ot_pert.nets.nets import VelocityFieldWithAttention
from ot_pert.utils import ConditionalLoader


In [2]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Project embeddings back through ``projection_matrix`` and add an offset.

    Computes ``embedding @ projection_matrix.T + mean_to_add``.

    Args:
        embedding: Array whose last axis matches the second axis of
            ``projection_matrix``.
        projection_matrix: Matrix that is transposed before multiplication.
        mean_to_add: Offset added to the product (normal NumPy broadcasting).

    Returns:
        The reconstructed array.
    """
    # np.matmul (rather than the ``@`` operator) is kept on purpose so that the
    # result type does not depend on the operand's own ``__matmul__``.
    decoded = np.matmul(embedding, projection_matrix.T)
    return decoded + mean_to_add

In [3]:
# Keys into `.obsm` of the AnnData objects used by the loading cells below:
# `obsm_key_cond` selects the per-cell condition array,
# `obsm_key_data` selects the per-cell representation used as source/target.
obsm_key_cond = "ecfp_and_dose"
obsm_key_data = "X_pca"

In [4]:
# Base directory holding the three sciplex AnnData splits. The default is the
# original cluster location; set the OT_PERT_DATA_DIR environment variable to
# run the notebook against a copy of the data elsewhere.
_sciplex_data_dir = pathlib.Path(
    os.environ.get("OT_PERT_DATA_DIR", "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex")
)

# Paths are kept as plain strings, as before.
adata_train_path = str(_sciplex_data_dir / "adata_train_30.h5ad")
adata_test_path = str(_sciplex_data_dir / "adata_test_30.h5ad")
adata_ood_path = str(_sciplex_data_dir / "adata_ood_30.h5ad")

In [5]:
# Load the training split from disk.
adata_train = sc.read(adata_train_path)



In [6]:
# Keep a 50% subsample of the training cells.
# NOTE(review): the return value is not assigned, so this relies on scanpy
# modifying `adata_train` in place; re-running this cell without reloading
# would presumably subsample the already-reduced data again — confirm.
sc.pp.subsample(adata_train, fraction = 0.5)

In [8]:
def data_match_fn(
    src_lin: Optional[jnp.ndarray], tgt_lin: Optional[jnp.ndarray],
    src_quad: Optional[jnp.ndarray], tgt_quad: Optional[jnp.ndarray], *,
    typ: Literal["lin", "quad", "fused"], epsilon: float = 1e-2, tau_a: float = 1.0,
    tau_b: float = 1.0,
) -> jnp.ndarray:
    """Dispatch to the ``solver_utils`` matching routine selected by ``typ``.

    Args:
        src_lin: Source features for the linear term (used by "lin" and "fused").
        tgt_lin: Target features for the linear term (used by "lin" and "fused").
        src_quad: Source features for the quadratic term (used by "quad" and "fused").
        tgt_quad: Target features for the quadratic term (used by "quad" and "fused").
        typ: Which matching problem to solve.
        epsilon: Forwarded to ``match_linear`` only.
        tau_a: Forwarded to ``match_linear`` only.
        tau_b: Forwarded to ``match_linear`` only.

    Returns:
        Whatever the selected ``solver_utils`` function returns.

    Raises:
        NotImplementedError: If ``typ`` is not one of the supported values.
    """
    if typ == "lin":
        return solver_utils.match_linear(
            x=src_lin,
            y=tgt_lin,
            scale_cost="mean",
            epsilon=epsilon,
            tau_a=tau_a,
            tau_b=tau_b,
        )

    if typ in ("quad", "fused"):
        # The fused problem is the quadratic one with the linear features added;
        # epsilon / tau_a / tau_b are intentionally not forwarded here.
        linear_terms = {"x": src_lin, "y": tgt_lin} if typ == "fused" else {}
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, **linear_terms)

    raise NotImplementedError(f"Unknown type: {typ}.")

# Load data


def _condition_arrays(adata, cond):
    """Collect control/perturbed arrays for one perturbation condition.

    The control population is every cell whose ``condition`` equals
    ``"<cell_type>_Vehicle_0.0"`` for the (single) cell type of ``cond``.

    Args:
        adata: AnnData object with ``obs["condition"]``, ``obs["cell_type"]``,
            ``obsm[obsm_key_data]``, ``obsm[obsm_key_cond]`` and a sparse ``X``.
        cond: Name of the perturbation condition to extract.

    Returns:
        Tuple ``(source, source_decoded, target, target_decoded, conds)`` where
        ``conds`` is the condition vector repeated once per source cell.

    Raises:
        AssertionError: If ``cond`` spans several cell types or its cells do not
            all share the same condition vector.
    """
    condition_labels = adata.obs["condition"]
    # Slice each population once instead of re-filtering for every field.
    perturbed = adata[condition_labels == cond]
    cell_types = list(perturbed.obs["cell_type"].unique())
    assert len(cell_types) == 1
    control = adata[condition_labels == cell_types[0] + "_Vehicle_0.0"]

    source = control.obsm[obsm_key_data]
    # `.toarray()` replaces the `.A` shorthand, which newer SciPy releases no
    # longer provide on sparse containers.
    source_decoded = control.X.toarray()
    target = perturbed.obsm[obsm_key_data]
    target_decoded = perturbed.X.toarray()

    conds = perturbed.obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    # One copy of the condition vector per *source* cell, so that source rows
    # and condition rows stay aligned.
    conds = np.tile(conds[0], (len(source), 1))
    return source, source_decoded, target, target_decoded, conds


# --- Training split: per-condition arrays plus one DataLoader per condition ---
dls = []

train_data_source = {}
train_data_target = {}
train_data_source_decoded = {}
train_data_target_decoded = {}
train_data_conditions = {}

for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        # Controls are the source population, not a condition to predict.
        continue
    source, source_decoded, target, target_decoded, conds = _condition_arrays(adata_train, cond)
    dls.append(DataLoader(datasets.OTDataset(datasets.OTData(
        lin=source,
        condition=conds,
    ), datasets.OTData(lin=target)), batch_size=10, shuffle=True))
    train_data_source[cond] = source
    train_data_target[cond] = target
    train_data_conditions[cond] = conds
    train_data_source_decoded[cond] = source_decoded
    train_data_target_decoded[cond] = target_decoded

train_loader = ConditionalLoader(dls, seed=0)

# --- Test split ---
test_data_source = {}
test_data_target = {}
test_data_source_decoded = {}
test_data_target_decoded = {}
test_data_conditions = {}
adata_test = sc.read(adata_test_path)
for cond in adata_test.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    source, source_decoded, target, target_decoded, conds = _condition_arrays(adata_test, cond)
    test_data_source[cond] = source
    test_data_target[cond] = target
    test_data_source_decoded[cond] = source_decoded
    test_data_target_decoded[cond] = target_decoded
    test_data_conditions[cond] = conds

# --- Out-of-distribution split ---
# Previously this loop never computed `target` / `target_decoded`, so every OOD
# entry silently stored the arrays of the last *test* condition. They are now
# taken from `adata_ood` for the condition being processed.
# NOTE(review): a stored inspection output further down this notebook shows
# `adata_ood` with zero "<cell_type>_Vehicle_0.0" rows, in which case `source`
# (and therefore `conds`) is empty here — confirm the OOD file contains controls.
ood_data_source = {}
ood_data_target = {}
ood_data_source_decoded = {}
ood_data_target_decoded = {}
ood_data_conditions = {}
adata_ood = sc.read(adata_ood_path)
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    source, source_decoded, target, target_decoded, conds = _condition_arrays(adata_ood, cond)
    ood_data_source[cond] = source
    ood_data_target[cond] = target
    ood_data_source_decoded[cond] = source_decoded
    ood_data_target_decoded[cond] = target_decoded
    ood_data_conditions[cond] = conds

# Decoder bound to the projection matrix and mean stored on the training data.
reconstruct_data_fn = functools.partial(reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T)




In [9]:
# Decoder bound to the projection matrix and mean stored on the training data.
reconstruct_data_fn = functools.partial(
    reconstruct_data,
    projection_matrix=adata_train.varm["PCs"],
    mean_to_add=adata_train.varm["X_train_mean"].T,
)

# Per-condition gene sets, all read from the training AnnData's `.uns` and
# filtered down to the conditions present in each split.
_all_deg_items = adata_train.uns['rank_genes_groups_cov_all'].items()
train_deg_dict = {name: genes for name, genes in _all_deg_items if name in train_data_conditions}
test_deg_dict = {name: genes for name, genes in _all_deg_items if name in test_data_conditions}
ood_deg_dict = {name: genes for name, genes in _all_deg_items if name in ood_data_conditions}

def get_mask(x, y):
    """Select the columns of ``x`` whose gene name is contained in ``y``.

    Columns of ``x`` are matched positionally against ``adata_train.var_names``.

    Args:
        x: Two-dimensional array with one column per entry of
            ``adata_train.var_names``.
        y: Container of gene names; membership is tested with ``in``.

    Returns:
        ``x`` restricted to the selected columns.
    """
    keep_column = []
    for gene in adata_train.var_names:
        keep_column.append(gene in y)
    return x[:, keep_column]

In [10]:
# Dimensions are read from the stored training arrays rather than from the
# loop variables `source` / `conds`, which hold whatever the most recently
# executed loading loop left behind. (The three assignments were also
# duplicated verbatim; one copy is kept.)
source_dim = next(iter(train_data_source.values())).shape[1]
target_dim = source_dim  # source and target share the same representation
condition_dim = next(iter(train_data_conditions.values())).shape[1]

In [None]:
# --- Training / evaluation settings -----------------------------------------
n_iterations = 1000      # number of optimisation steps
eval_every = 100         # evaluate every this many steps (never at step 0)
n_eval_conditions = 20   # test conditions drawn per evaluation
# Seeded generator so the evaluated subset is reproducible across runs.
eval_rng = np.random.default_rng(0)

# --- Model -------------------------------------------------------------------
vf = VelocityField(
    hidden_dims=[512, 512],
    time_dims=[512, 512],
    output_dims=[30] + [target_dim],
    condition_dims=[512, 512],
    time_encoder=functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024),
)

model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(0),
    match_fn=jax.jit(functools.partial(data_match_fn, typ="lin", src_quad=None, tgt_quad=None, epsilon=0.1, tau_a=1.0, tau_b=1.0)),
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optax.MultiSteps(optax.adam(learning_rate=1e-3), 5),
)

# "loss" holds one value per training step; "valid_loss" one mean value per
# evaluation round.
training_logs = {"loss": [], "valid_loss": []}

rng = jax.random.PRNGKey(0)
for it in tqdm(range(n_iterations)):
    rng, rng_resample, rng_step_fn = jax.random.split(rng, 3)
    batch = next(train_loader)
    batch = jtu.tree_map(jnp.asarray, batch)

    src, tgt = batch["src_lin"], batch["tgt_lin"]
    src_cond = batch.get("src_condition")

    if model.match_fn is not None:
        # Re-pair source and target samples according to the coupling.
        tmat = model.match_fn(src, tgt)
        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        src, tgt = src[src_ixs], tgt[tgt_ixs]
        src_cond = None if src_cond is None else src_cond[src_ixs]

    model.vf_state, loss = model.step_fn(
        rng_step_fn,
        model.vf_state,
        src,
        tgt,
        src_cond,
    )

    training_logs["loss"].append(float(loss))
    if (it % eval_every == 0) and (it > 0):
        # Draw distinct test conditions. The previous `np.random.choice` call
        # sampled with replacement (so it could return fewer unique conditions
        # than requested) and used the unseeded global NumPy state.
        test_condition_names = list(test_data_source.keys())
        idcs = eval_rng.choice(
            test_condition_names,
            size=min(n_eval_conditions, len(test_condition_names)),
            replace=False,
        )
        selected = set(idcs.tolist())
        test_data_source_tmp = {k: v for k, v in test_data_source.items() if k in selected}
        test_data_target_tmp = {k: v for k, v in test_data_target.items() if k in selected}
        test_data_conditions_tmp = {k: v for k, v in test_data_conditions.items() if k in selected}
        test_data_target_decoded_tmp = {k: v for k, v in test_data_target_decoded.items() if k in selected}
        test_deg_dict_tmp = {k: v for k, v in test_deg_dict.items() if k in selected}

        valid_losses = []
        for cond_ix, cond in enumerate(test_data_source_tmp.keys()):
            src = test_data_source_tmp[cond]
            tgt = test_data_target_tmp[cond]
            src_cond = test_data_conditions_tmp[cond]
            # Derive fresh keys per condition. Previously every condition
            # reused `rng_resample` and the loop-carried `rng`, i.e. the same
            # random draws for all conditions. Folding into `rng_resample`
            # leaves the training key stream untouched.
            rng_valid_resample, rng_valid_step = jax.random.split(
                jax.random.fold_in(rng_resample, cond_ix)
            )
            if model.match_fn is not None:
                tmat = model.match_fn(src, tgt)
                src_ixs, tgt_ixs = solver_utils.sample_joint(rng_valid_resample, tmat)
                src, tgt = src[src_ixs], tgt[tgt_ixs]
                src_cond = None if src_cond is None else src_cond[src_ixs]
            # Only the loss is kept; the state returned by the step is
            # discarded, so `model.vf_state` is not reassigned here.
            _, valid_loss = model.step_fn(
                rng_valid_step,
                model.vf_state,
                src,
                tgt,
                src_cond,
            )
            valid_losses.append(float(valid_loss))
        if valid_losses:
            training_logs["valid_loss"].append(float(np.mean(valid_losses)))

        # Test subset: predict, decode, and score on each condition's gene set.
        predicted_target_test = jax.tree_util.tree_map(model.transport, test_data_source_tmp, test_data_conditions_tmp)
        predicted_target_test_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_test)

        test_deg_target_decoded_predicted = jax.tree_util.tree_map(
            get_mask, predicted_target_test_decoded, test_deg_dict_tmp
        )
        test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded_tmp, test_deg_dict_tmp)
        deg_test_metrics_encoded = jax.tree_util.tree_map(
            compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted
        )
        deg_mean_test_metrics_encoded = compute_mean_metrics(deg_test_metrics_encoded, prefix="deg_test_")

        # OOD conditions: metrics in the model's representation ...
        predicted_target_ood = jax.tree_util.tree_map(model.transport, ood_data_source, ood_data_conditions)
        ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_data_target, predicted_target_ood)
        mean_ood_metrics = compute_mean_metrics(ood_metrics, prefix="ood_")

        # ... after decoding ...
        predicted_target_ood_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_ood)
        ood_metrics_decoded = jax.tree_util.tree_map(
            compute_metrics, ood_data_target_decoded, predicted_target_ood_decoded
        )
        mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

        # ... and restricted to each condition's gene set.
        ood_deg_target_decoded_predicted = jax.tree_util.tree_map(
            get_mask, predicted_target_ood_decoded, ood_deg_dict
        )
        ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)
        deg_ood_metrics_encoded = jax.tree_util.tree_map(
            compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted
        )
        deg_mean_ood_metrics_encoded = compute_mean_metrics(deg_ood_metrics_encoded, prefix="deg_ood_")


 30%|██▉       | 298/1000 [1:33:40<56:59,  4.87s/it]   2024-04-18 16:36:18.248686: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_while] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-04-18 16:36:18.753498: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m22.998164108s

********************************
[Compiling module jit_while] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


In [11]:
# Inspect the OOD condition dictionary.
# NOTE(review): the stored output maps every key to its own name (a string)
# rather than to an array — presumably from an earlier kernel state; re-run
# after the loading cell to confirm.
ood_data_conditions

{'A549_A-366_10.0': 'A549_A-366_10.0',
 'A549_A-366_100.0': 'A549_A-366_100.0',
 'A549_A-366_10000.0': 'A549_A-366_10000.0',
 'A549_Carmofur_10.0': 'A549_Carmofur_10.0',
 'A549_Carmofur_100.0': 'A549_Carmofur_100.0',
 'A549_Disulfiram__10.0': 'A549_Disulfiram__10.0',
 'A549_Disulfiram__1000.0': 'A549_Disulfiram__1000.0',
 'A549_Disulfiram__10000.0': 'A549_Disulfiram__10000.0',
 'A549_GSK-LSD1_2HCl_10.0': 'A549_GSK-LSD1_2HCl_10.0',
 'A549_GSK-LSD1_2HCl_100.0': 'A549_GSK-LSD1_2HCl_100.0',
 'A549_GSK-LSD1_2HCl_1000.0': 'A549_GSK-LSD1_2HCl_1000.0',
 'A549_GSK-LSD1_2HCl_10000.0': 'A549_GSK-LSD1_2HCl_10000.0',
 'A549_INO-1001_(3-Aminobenzamide)_100.0': 'A549_INO-1001_(3-Aminobenzamide)_100.0',
 'A549_INO-1001_(3-Aminobenzamide)_1000.0': 'A549_INO-1001_(3-Aminobenzamide)_1000.0',
 'A549_INO-1001_(3-Aminobenzamide)_10000.0': 'A549_INO-1001_(3-Aminobenzamide)_10000.0',
 'A549_Lenalidomide_(CC-5013)_10.0': 'A549_Lenalidomide_(CC-5013)_10.0',
 'A549_Lenalidomide_(CC-5013)_100.0': 'A549_Lenalidomi

In [39]:
# Shapes of whatever `src`, `tgt` and `src_cond` were last assigned to
# (they are reassigned inside the training/validation loop).
src.shape, tgt.shape, src_cond.shape

((500, 30), (500, 30), (500, 30))

In [42]:
from ott.geometry import costs, pointcloud

# Mean entry of the squared-Euclidean cost matrix between the current `src`
# and `tgt` arrays, with the geometry's `scale_cost` set to 1.0.
_src_tgt_geometry = pointcloud.PointCloud(
    x=src,
    y=tgt,
    cost_fn=costs.SqEuclidean(),
    scale_cost=1.0,
)
_src_tgt_geometry.cost_matrix.mean()

Array(17.699856, dtype=float32)

In [38]:
# Shape of the condition array stored for each sampled test condition.
jax.tree_util.tree_map(lambda x: x.shape, test_data_conditions_tmp)

{'A549_Crizotinib_(PF-02341066)_1000.0': (500, 1025),
 'A549_Ki16425_10000.0': (500, 1025),
 'A549_MK-5108_(VX-689)_100.0': (500, 1025),
 'A549_Veliparib_(ABT-888)_1000.0': (500, 1025),
 'K562_Maraviroc_1000.0': (500, 1025),
 'K562_Ofloxacin_1000.0': (500, 1025),
 'K562_SRT2104_(GSK2245840)_100.0': (500, 1025),
 'K562_UNC0379_100.0': (500, 1025),
 'MCF7_Anacardic_Acid_1000.0': (500, 1025),
 'MCF7_CUDC-101_10000.0': (500, 1025),
 'MCF7_Capecitabine__1000.0': (500, 1025),
 'MCF7_Cerdulatinib_(PRT062070,_PRT2070)_10000.0': (500, 1025),
 'MCF7_Clevudine__10000.0': (500, 1025),
 'MCF7_G007-LK_100.0': (500, 1025),
 'MCF7_Iniparib_(BSI-201)_100.0': (500, 1025),
 'MCF7_PHA-680632_10.0': (500, 1025),
 'MCF7_RG108_10000.0': (500, 1025),
 'MCF7_Ramelteon_10.0': (500, 1025),
 'MCF7_Resveratrol_10000.0': (500, 1025),
 'MCF7_Tofacitinib_(CP-690550)_Citrate_100.0': (500, 1025)}

In [26]:
# Draw 50 test condition names. NOTE: `np.random.choice` samples with
# replacement by default and uses the global (here unseeded) NumPy state, so
# the draw can contain duplicates and differs between runs.
idcs = np.random.choice(list(test_data_source.keys()), 50)

In [27]:
# Display the sampled condition names.
idcs

array(['MCF7_Vandetanib_(ZD6474)_100.0', 'K562_Ramelteon_100.0',
       'K562_ENMD-2076_L-(+)-Tartaric_acid__100.0',
       'MCF7_Navitoclax_(ABT-263)_100.0',
       'MCF7_Motesanib_Diphosphate_(AMG-706)_10000.0',
       'A549_WP1066_1000.0', 'A549_Fulvestrant_100.0',
       'MCF7_Temsirolimus_(CCI-779,_NSC_683864)_100.0',
       'MCF7_Curcumin_1000.0', 'MCF7_Alvespimycin_(17-DMAG)_HCl_10.0',
       'MCF7_Barasertib_(AZD1152-HQPA)_10000.0', 'MCF7_Quercetin_10000.0',
       'K562_RG108_10.0', 'MCF7_Ofloxacin_100.0',
       'K562_Thalidomide_100.0', 'A549_Cimetidine__10.0',
       'MCF7_MC1568_1000.0', 'MCF7_RG108_10000.0',
       'A549_Resminostat_100.0', 'MCF7_SRT2104_(GSK2245840)_1000.0',
       'K562_ABT-737_100.0', 'K562_Fluorouracil_(5-Fluoracil,_5-FU)_10.0',
       'MCF7_TMP195_10.0', 'A549_AICAR_(Acadesine)_1000.0',
       'MCF7_Decitabine_10.0', 'MCF7_AICAR_(Acadesine)_1000.0',
       'MCF7_UNC0631_100.0', 'K562_UNC0379_1000.0',
       'MCF7_Trichostatin_A_(TSA)_10.0', 'A549_And

In [31]:
# Number of distinct test conditions selected by `idcs` (duplicates in `idcs`
# collapse because the result is a dict keyed by condition).
len({k:v for k,v in test_data_source.items() if k in idcs})

50

In [25]:
# Scratch cell: build an OTDataset for the first non-control training
# condition only (the loop exits after one pass).
# NOTE: unlike the main loading code, `conds` is NOT tiled to the number of
# source cells here, so it has one row per target cell of the condition.
for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    perturbed_cells = adata_train[adata_train.obs["condition"] == cond]
    src_str = list(perturbed_cells.obs["cell_type"].unique())
    assert len(src_str) == 1
    control_cells = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"]

    source = control_cells.obsm[obsm_key_data]
    source_decoded = control_cells.X.A
    target = perturbed_cells.obsm[obsm_key_data]
    target_decoded = perturbed_cells.X.A
    conds = perturbed_cells.obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))

    src_side = datasets.OTData(lin=source, condition=conds)
    tgt_side = datasets.OTData(lin=target)
    ds = datasets.OTDataset(src_side, tgt_side)
    break

In [47]:
# Repeat the first condition row once per source cell so that conditions and
# source cells have the same number of rows.
# NOTE: this overwrites `conds`, so re-running changes nothing further only
# because every row is already identical.
conds = np.tile(conds[0], (len(source), 1))

In [48]:
# Shape of the tiled condition array.
conds.shape

(2787, 1025)

In [30]:
# Fetch the first item yielded by iterating the scratch dataset `ds`.
batch = next(iter(ds))

In [31]:
# Shape of the source sample in that item.
batch["src_lin"].shape

(30,)

In [32]:
# Shape of the target sample in that item.
batch["tgt_lin"].shape

(30,)

In [33]:
# Shape of the source condition vector in that item.
batch["src_condition"].shape

(1025,)

In [39]:
# Number of source cells and their feature dimension in the scratch dataset.
ds.src_data.lin.shape

(2787, 30)

In [40]:
# Number of target cells and their feature dimension in the scratch dataset.
ds.tgt_data.lin.shape

(114, 30)

In [44]:
# Shape of the source-side condition array in the scratch dataset.
# NOTE: in the stored outputs its row count equals the target count, not the
# source count, because `conds` was not tiled when `ds` was built.
ds.src_data.condition.shape

(114, 1025)

In [36]:
# List the attributes available on the dataset object.
dir(ds)

['SRC_PREFIX',
 'TGT_PREFIX',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_rng',
 '_sample_from_target',
 '_tgt_cond_to_ix',
 '_verify_integrity',
 'is_aligned',
 'src_conditions',
 'src_data',
 'tgt_conditions',
 'tgt_data']

In [17]:
# Scratch cell: rebuild the OOD dictionaries for the first non-control
# condition only (the loop exits after one pass). Expects `adata_ood` to be
# loaded already.
ood_data_source = {}
ood_data_target = {}
ood_data_source_decoded = {}
ood_data_target_decoded = {}
ood_data_conditions = {}
#adata_ood = sc.read(adata_ood_path)
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    perturbed_cells = adata_ood[adata_ood.obs["condition"] == cond]
    src_str = list(perturbed_cells.obs["cell_type"].unique())
    assert len(src_str) == 1
    control_cells = adata_ood[adata_ood.obs["condition"] == src_str[0] + "_Vehicle_0.0"]

    source = control_cells.obsm[obsm_key_data]
    # `.toarray()` replaces the `.A` shorthand, which newer SciPy releases no
    # longer provide on sparse containers.
    source_decoded = control_cells.X.toarray()
    # The targets were previously never computed in this loop, so stale
    # `target` / `target_decoded` values from an earlier cell were stored.
    target = perturbed_cells.obsm[obsm_key_data]
    target_decoded = perturbed_cells.X.toarray()

    conds = perturbed_cells.obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))

    ood_data_source[cond] = source
    ood_data_target[cond] = target
    ood_data_source_decoded[cond] = source_decoded
    ood_data_target_decoded[cond] = target_decoded
    ood_data_conditions[cond] = conds
    break

In [18]:
# Display the rebuilt OOD condition dictionary (single entry after the scratch loop).
ood_data_conditions

{'A549_A-366_10.0': ArrayView([[ 0.,  0.,  0., ...,  0.,  0., 10.],
            [ 0.,  0.,  0., ...,  0.,  0., 10.],
            [ 0.,  0.,  0., ...,  0.,  0., 10.],
            ...,
            [ 0.,  0.,  0., ...,  0.,  0., 10.],
            [ 0.,  0.,  0., ...,  0.,  0., 10.],
            [ 0.,  0.,  0., ...,  0.,  0., 10.]])}

In [15]:
# Display the current condition array.
conds

ArrayView([[ 0.,  0.,  0., ...,  0.,  0., 10.],
           [ 0.,  0.,  0., ...,  0.,  0., 10.],
           [ 0.,  0.,  0., ...,  0.,  0., 10.],
           ...,
           [ 0.,  0.,  0., ...,  0.,  0., 10.],
           [ 0.,  0.,  0., ...,  0.,  0., 10.],
           [ 0.,  0.,  0., ...,  0.,  0., 10.]])

In [56]:
# Unique cell types among the OOD cells of the current `cond`
# (kept as the pandas result here, not converted to a list).
src_str = adata_ood[adata_ood.obs["condition"]==cond].obs["cell_type"].unique()

In [57]:
# Display the cell type(s) found.
src_str

['A549']
Categories (1, object): ['A549']

In [58]:
# Select the matching control ("<cell_type>_Vehicle_0.0") cells in the OOD data.
# NOTE: the stored output reports 0 observations, i.e. no such control cells
# were present in `adata_ood` when this was run.
adata_ood[adata_ood.obs["condition"]==src_str[0]+"_Vehicle_0.0"]

View of AnnData object with n_obs × n_vars = 0 × 2002
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'perturbation', 'drug', 'cell_line', 'condition', 'pubchem_name', 'pubchem_ID', 'smiles', 'control', 'ood', 'is_ood', 'split'
    obsm: 'X_pca', 'X_umap', 'ecfp', 'ecfp_and_dose'
    varm: 'X_train_mean'
    layers: 'centered_X'

In [60]:
# Number of cells per condition in the test split.
adata_test.obs["condition"].value_counts()

condition
A549_Vehicle_0.0                  500
K562_Vehicle_0.0                  500
MCF7_Vehicle_0.0                  500
A549_ABT-737_100.0                100
MCF7_Dasatinib_100.0              100
                                 ... 
K562_Droxinostat_1000.0           100
K562_Divalproex_Sodium_10000.0    100
K562_Divalproex_Sodium_1000.0     100
K562_Divalproex_Sodium_10.0       100
MCF7_Zileuton_10000.0             100
Name: count, Length: 1278, dtype: int64

In [64]:
# Total number of control cells (drug == "Vehicle") in the test split.
(adata_test.obs["drug"] == "Vehicle").sum()

1500