In [1]:
import functools
from collections.abc import Iterable
from typing import Literal, Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import scanpy as sc
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm


class ConditionalLoader:
    """Dataset for OT problems with conditions.

    This data loader wraps several data loaders and samples from them.

    Args:
      datasets: Datasets to sample from.
      seed: Random seed.
    """

    def __init__(
        self,
        dataloaders: Iterable[DataLoader],
        seed: int = 0,
    ):
        self.dataloaders = tuple(dataloaders)
        self._rng = jax.random.PRNGKey(seed)

    def __next__(self):
        rng, self._rng = jax.random.split(self._rng, 2)
        idx = int(jax.random.choice(rng, len(self.dataloaders)))
        dl = self.dataloaders[idx]

        return next(iter(dl))

    def __iter__(self) -> "ConditionalLoader":
        return self

    def __len__(self) -> int:
        return 100000

In [2]:
def data_match_fn(
    src_lin: Optional[jnp.ndarray],
    tgt_lin: Optional[jnp.ndarray],
    src_quad: Optional[jnp.ndarray],
    tgt_quad: Optional[jnp.ndarray],
    *,
    typ: Literal["lin", "quad", "fused"],
) -> jnp.ndarray:
    if typ == "lin":
        return solver_utils.match_linear(x=src_lin, y=tgt_lin, scale_cost="mean")
    if typ == "quad":
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad)
    if typ == "fused":
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin)
    raise NotImplementedError(f"Unknown type: {typ}.")


adata_train = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train.h5ad")
dls = []
source = adata_train[adata_train.obs["condition"] == "control"].obsm["X_pca"]
for cond in adata_train.obs["condition"].cat.categories:
    if cond == "control":
        continue
    target = adata_train[adata_train.obs["condition"] == cond].obsm["X_pca"]
    condition_1 = adata_train[adata_train.obs["condition"] == cond].obsm["ecfp_drug_1"]
    condition_2 = adata_train[adata_train.obs["condition"] == cond].obsm["ecfp_drug_2"]
    assert np.all(np.all(condition_1 == condition_1[0], axis=1))
    assert np.all(np.all(condition_2 == condition_2[0], axis=1))
    expanded_arr = np.expand_dims(
        np.concatenate((condition_1[0, :][None, :], condition_2[0, :][None, :]), axis=0), axis=0
    )
    conds = np.tile(expanded_arr, (len(source), 1, 1))
    dls.append(
        DataLoader(
            datasets.OTDataset(
                datasets.OTData(
                    lin=source,
                    condition=conds,
                ),
                datasets.OTData(lin=target),
            )
        )
    )

train_loader = ConditionalLoader(dls, seed=0)

2024-04-24 13:19:07.184272: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
batch = next(iter(train_loader))

In [4]:
batch.keys()

dict_keys(['src_lin', 'src_condition', 'tgt_lin'])

In [5]:
from ot_pert.nets.nets import VelocityFieldWithAttention

In [6]:
source_dim = source.shape[1]
target_dim = source_dim
condition_dim = condition_1.shape[1]

vf = VelocityFieldWithAttention(
    num_heads=1,
    qkv_feature_dim=32,
    max_seq_length=2,
    hidden_dims=[512, 512, 512],
    output_dims=[512, 512, target_dim],
    condition_dims=[512, 512, 512],
    time_encoder=functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024),
    dropout_rate=0.2,
)

In [7]:
model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(0.0),
    match_fn=jax.jit(functools.partial(data_match_fn, typ="lin", src_quad=None, tgt_quad=None)),
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optax.adam(learning_rate=1e-4),
)

In [8]:
training_logs = {"loss": []}
valid_freq = 10000000

In [10]:
import jax.tree_util as jtu

rng = jax.random.PRNGKey(0)
for it in tqdm(range(50000)):
    rng, rng_resample, rng_step_fn = jax.random.split(rng, 3)
    batch = next(iter(train_loader))
    batch = jtu.tree_map(jnp.asarray, batch)

    src, tgt = batch["src_lin"], batch["tgt_lin"]
    src_cond = batch.get("src_condition")

    if model.match_fn is not None:
        tmat = model.match_fn(src, tgt)
        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        src, tgt = src[src_ixs], tgt[tgt_ixs]
        src_cond = None if src_cond is None else src_cond[src_ixs]

    model.vf_state, loss = model.step_fn(
        rng_step_fn,
        model.vf_state,
        src,
        tgt,
        src_cond,
    )

    training_logs["loss"].append(float(loss))

100%|██████████| 50000/50000 [06:08<00:00, 135.79it/s]


In [11]:
test_data_source = {}
test_data_target = {}
test_data_conditions = {}
adata_test = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test.h5ad")
source = adata_test[adata_test.obs["condition"] == "control"].obsm["X_pca"]
for cond in adata_test.obs["condition"].cat.categories:
    if cond == "control":
        continue
    target = adata_test[adata_test.obs["condition"] == cond].obsm["X_pca"]
    condition_1 = adata_test[adata_test.obs["condition"] == cond].obsm["ecfp_drug_1"]
    condition_2 = adata_test[adata_test.obs["condition"] == cond].obsm["ecfp_drug_2"]
    assert np.all(np.all(condition_1 == condition_1[0], axis=1))
    assert np.all(np.all(condition_2 == condition_2[0], axis=1))
    expanded_arr = np.expand_dims(
        np.concatenate((condition_1[0, :][None, :], condition_2[0, :][None, :]), axis=0), axis=0
    )
    conds = np.tile(expanded_arr, (len(source), 1, 1))
    test_data_source[cond] = source
    test_data_target[cond] = target
    test_data_conditions[cond] = conds



In [13]:
from ot_pert.metrics import compute_mean_metrics, compute_metrics_fast

In [18]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    return np.matmul(embedding, projection_matrix.T) + mean_to_add


reconstruct_data_fn = functools.partial(
    reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T
)

In [None]:
predicted_target_ood = jax.tree_util.tree_map(model.transport, test_data_source, test_data_conditions)
test_metrics = jax.tree_util.tree_map(compute_metrics_fast, test_data_target, predicted_target_ood)
mean_test_metrics = compute_mean_metrics(test_metrics, prefix="test_")

In [19]:
predicted_target_test_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_ood)
test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics_fast, test_data_target_decoded, predicted_target_test_decoded
)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

NameError: name 'test_data_target_decoded' is not defined

In [None]:
mean_test_metrics_decoded

In [None]:
mean_test_metrics

In [22]:
embs = jax.tree_util.tree_map(lambda x: vf.get_embedding(model.vf_state, x[0]), test_data_conditions)

In [23]:
embs

{'Alvespimycin+Pirarubicin': Array([[-0.01111262, -0.01984059,  0.00402328, ..., -0.06305397,
         -0.0452285 , -0.02433442],
        [-0.01111262, -0.01984059,  0.00402328, ..., -0.06305397,
         -0.0452285 , -0.02433442],
        [-0.01111262, -0.01984059,  0.00402328, ..., -0.06305397,
         -0.0452285 , -0.02433442]], dtype=float32),
 'Dacinostat+Danusertib': Array([[ 0.04122573,  0.00685555,  0.04499897, ..., -0.06701037,
         -0.08553509,  0.08969629],
        [ 0.04122573,  0.00685555,  0.04499897, ..., -0.06701037,
         -0.08553509,  0.08969629],
        [ 0.04122573,  0.00685555,  0.04499897, ..., -0.06701037,
         -0.08553509,  0.08969629]], dtype=float32),
 'Dacinostat+Dasatinib': Array([[ 0.02645211, -0.01343745,  0.04022329, ..., -0.10119122,
         -0.06401235,  0.13010481],
        [ 0.02645211, -0.01343745,  0.04022329, ..., -0.10119122,
         -0.06401235,  0.13010481],
        [ 0.02645211, -0.01343745,  0.04022329, ..., -0.10119122,
       

In [1]:
1

1