In [1]:
import functools
from typing import Literal, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
import scanpy as sc
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm

from ot_pert.metrics import compute_mean_metrics, compute_metrics_fast
from ot_pert.nets.nets import VelocityFieldWithAttention
from ot_pert.utils import ConditionalLoader


def reconstruct_data(embedding, projection_matrix, mean_to_add):
    """Map an embedding back to the original feature space.

    The embedding is multiplied (via :func:`numpy.matmul`) with the transposed
    ``projection_matrix`` and ``mean_to_add`` is added to the product.
    """
    loadings_t = projection_matrix.T
    projected = np.matmul(embedding, loadings_t)
    return projected + mean_to_add

In [6]:
# Keys looked up in ``adata.obsm``: the data representation and the two
# per-cell condition arrays (one per drug of the combination).
obsm_key_data = "X_pca"
obsm_key_cond_1, obsm_key_cond_2 = "ecfp_drug_1", "ecfp_drug_2"

# The train / test / ood AnnData files all live in the same directory.
_combosciplex_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex"
adata_train_path = f"{_combosciplex_dir}/adata_train_30.h5ad"
adata_test_path = f"{_combosciplex_dir}/adata_test_30.h5ad"
adata_ood_path = f"{_combosciplex_dir}/adata_ood_30.h5ad"

In [7]:
def load_data(adata, cfg, *, return_dl: bool, batch_size: int = 32):
    """Pair the control population with every perturbed condition of ``adata``.

    Cells with ``obs["condition"] == "control"`` form the source population that
    is shared by all conditions. Every other category of ``obs["condition"]``
    contributes a target population and one pair of condition vectors (taken
    from the first cell of that condition, after asserting that all cells of the
    condition carry identical vectors).

    Args:
      adata: AnnData-like object. Reads ``obs["condition"]`` (categorical),
        ``obsm[obsm_key_data]``, ``obsm[obsm_key_cond_1]``,
        ``obsm[obsm_key_cond_2]``, ``X`` and, when ``return_dl`` is false,
        ``uns["rank_genes_groups_cov_all"]``.
      cfg: Unused; kept so existing call sites keep working.
      return_dl: If true, return a loader for training; otherwise return the
        per-condition arrays for evaluation.
      batch_size: Batch size of every per-condition ``DataLoader``
        (only used when ``return_dl`` is true).

    Returns:
      ``ConditionalLoader(dls, seed=0)`` over one shuffled ``DataLoader`` per
      condition if ``return_dl`` is true. Otherwise a dict with the keys
      ``"source"``, ``"target"``, ``"source_decoded"``, ``"target_decoded"``,
      ``"conditions"`` and ``"deg_dict"``, each mapping condition name to data.

    Raises:
      AssertionError: If the condition vectors differ between cells of the same
        condition.
    """

    def _to_dense(matrix):
        # ``toarray`` covers SciPy sparse containers; ``np.asarray`` covers
        # ``np.matrix`` and arrays that are already dense.
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    dls = []
    data_source = {}
    data_target = {}
    data_source_decoded = {}
    data_target_decoded = {}
    data_conditions = {}

    condition_labels = adata.obs["condition"]

    # Subset the control cells once; they are the source for every condition.
    control = adata[condition_labels == "control"]
    source = control.obsm[obsm_key_data]
    # NOTE(review): unlike ``target_decoded`` this is kept in whatever format
    # ``adata.X`` uses (not densified) -- confirm that this is intended.
    source_decoded = None if return_dl else control.X

    for cond in condition_labels.cat.categories:
        if cond == "control":
            continue

        # Subset once per condition instead of once per accessed attribute.
        perturbed = adata[condition_labels == cond]
        target = perturbed.obsm[obsm_key_data]
        condition_1 = perturbed.obsm[obsm_key_cond_1]
        condition_2 = perturbed.obsm[obsm_key_cond_2]
        assert np.all(np.all(condition_1 == condition_1[0], axis=1))
        assert np.all(np.all(condition_2 == condition_2[0], axis=1))

        # Stack both condition vectors into one [1, 2, dim] sequence and repeat
        # it for every source cell.
        expanded_arr = np.expand_dims(
            np.concatenate((condition_1[0, :][None, :], condition_2[0, :][None, :]), axis=0), axis=0
        )
        conds = np.tile(expanded_arr, (len(source), 1, 1))

        if return_dl:
            dls.append(
                DataLoader(
                    datasets.OTDataset(
                        datasets.OTData(
                            lin=source,
                            condition=conds,
                        ),
                        datasets.OTData(lin=target),
                    ),
                    batch_size=batch_size,
                    shuffle=True,
                )
            )
            continue

        data_source[cond] = source
        data_target[cond] = target
        data_source_decoded[cond] = source_decoded
        # Densify only here: the loader branch never uses the decoded target.
        data_target_decoded[cond] = _to_dense(perturbed.X)
        data_conditions[cond] = conds

    if return_dl:
        return ConditionalLoader(dls, seed=0)

    deg_dict = {k: v for k, v in adata.uns["rank_genes_groups_cov_all"].items() if k in data_conditions.keys()}

    return {
        "source": data_source,
        "target": data_target,
        "source_decoded": data_source_decoded,
        "target_decoded": data_target_decoded,
        "conditions": data_conditions,
        "deg_dict": deg_dict,
    }


def data_match_fn(
    src_lin: Optional[jnp.ndarray],
    tgt_lin: Optional[jnp.ndarray],
    src_quad: Optional[jnp.ndarray],
    tgt_quad: Optional[jnp.ndarray],
    *,
    typ: Literal["lin", "quad", "fused"],
    epsilon: float = 1e-2,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
) -> jnp.ndarray:
    """Dispatch to the ``solver_utils`` matcher selected by ``typ``.

    Args:
      src_lin: Source data of the linear term.
      tgt_lin: Target data of the linear term.
      src_quad: Source data of the quadratic term.
      tgt_quad: Target data of the quadratic term.
      typ: ``"lin"`` calls ``match_linear`` on the linear data, ``"quad"`` calls
        ``match_quadratic`` on the quadratic data, ``"fused"`` calls
        ``match_quadratic`` with both the quadratic and the linear data.
      epsilon: Forwarded to ``match_linear`` only.
      tau_a: Forwarded to ``match_linear`` only.
      tau_b: Forwarded to ``match_linear`` only.

    Returns:
      Whatever the selected matcher returns.

    Raises:
      NotImplementedError: If ``typ`` is none of the three supported values.
    """
    linear_terms = {"x": src_lin, "y": tgt_lin}
    quadratic_terms = {"xx": src_quad, "yy": tgt_quad}

    if typ == "fused":
        return solver_utils.match_quadratic(**quadratic_terms, **linear_terms)
    if typ == "quad":
        return solver_utils.match_quadratic(**quadratic_terms)
    if typ == "lin":
        return solver_utils.match_linear(
            **linear_terms,
            scale_cost="mean",
            epsilon=epsilon,
            tau_a=tau_a,
            tau_b=tau_b,
        )
    raise NotImplementedError(f"Unknown type: {typ}.")


def get_mask(x, y, var_names):
    """Return the columns of ``x`` whose entry in ``var_names`` is contained in ``y``."""
    keep_column = []
    for gene in var_names:
        keep_column.append(gene in y)
    return x[:, keep_column]

In [8]:
# Read the train / test / ood AnnData files (in that order).
adata_train, adata_test, adata_ood = (
    sc.read_h5ad(path) for path in (adata_train_path, adata_test_path, adata_ood_path)
)

# The train and test evaluation dictionaries are left empty here; only the
# ood split is materialised for evaluation.
train_data, test_data = {}, {}
ood_data = load_data(adata_ood, None, return_dl=False)

# Training batches come from a loader built on the training split.
dl = load_data(adata_train, None, return_dl=True)

2024-04-23 10:22:39.257463: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [19]:
def get_masks(dataset: jnp.ndarray) -> jnp.ndarray:
    """Build boolean attention masks that hide all-zero tokens.

    Args:
      dataset: Batch of token sequences. Iterating over it must yield one array
        per sample; samples are expected to be ``[seq, dim]`` (samples of lower
        rank are promoted by prepending axes).

    Returns:
      Boolean array of shape ``[batch, 1, seq, seq]``. Entry ``[b, 0, i, j]`` is
      true iff tokens ``i`` and ``j`` of sample ``b`` each contain at least one
      non-zero feature.
    """
    attention_mask = []
    for data in dataset:
        if data.ndim < 2:
            data = data[None, :]
        if data.ndim < 3:
            data = data[None, :]
        # A token counts as padding when every one of its features is zero.
        is_padding = jnp.all(data == 0.0, axis=-1)
        is_token = 1 - is_padding
        # ``outer`` flattens its inputs, giving the [seq, seq] pairwise mask.
        attention_mask.append(jnp.outer(is_token, is_token))
    return jnp.expand_dims(jnp.equal(jnp.array(attention_mask), 1.0), 1)

In [24]:
from collections.abc import Sequence
from typing import Callable, Optional

import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state


class VelocityFieldWithAttention2(nn.Module):
    """Velocity field whose condition is a token sequence pooled by self-attention.

    A learned class token is prepended to the condition sequence, the sequence is
    passed through one multi-head self-attention layer (all-zero tokens are
    masked out), and the output at the class-token position is used as the
    condition embedding. Time, data and condition embeddings are concatenated
    and decoded by an MLP.

    Attributes:
      num_heads: Number of attention heads.
      qkv_feature_dim: ``qkv_features`` of the attention layer.
      max_seq_length: Not read by this class.
      hidden_dims: Layer widths of the data encoder (and of the time encoder
        when ``time_dims`` is ``None``).
      output_dims: Layer widths of the decoder; the last entry is the output
        dimension and is applied without activation.
      condition_dims: Layer widths of the MLP applied to the pooled condition
        embedding. ``None`` or an empty sequence applies no layers.
      time_dims: Layer widths of the time encoder.
      time_encoder: Function applied to ``t`` before the time MLP.
      act_fn: Activation function.
      pad_max_dim: Not read by this class.
    """

    num_heads: int
    qkv_feature_dim: int
    max_seq_length: int
    hidden_dims: Sequence[int]
    output_dims: Sequence[int]
    condition_dims: Optional[Sequence[int]] = None
    time_dims: Optional[Sequence[int]] = None
    time_encoder: Callable[[jnp.ndarray], jnp.ndarray] = time_encoder.cyclical_time_encoder
    act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu
    pad_max_dim: int = -1

    def __post_init__(self):
        # Attach the jitted mask builder before the base class finishes
        # initialising the module.
        self.get_masks = jax.jit(get_masks)
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        t: jnp.ndarray,
        x: jnp.ndarray,
        condition: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Forward pass through the neural vector field.

        Args:
          t: Time of shape ``[batch, 1]``.
          x: Data of shape ``[batch, ...]``.
          condition: Conditioning sequence of shape ``[batch, ...]``. Required.

        Returns:
          Output of the neural vector field of shape ``[batch, output_dim]``
          (squeezed when ``x`` was passed without a batch axis).

        Raises:
          AssertionError: If ``condition`` is ``None``.
        """
        # Checked first so that the unbatched branch below never indexes ``None``.
        assert condition is not None, "No condition sequence was passed."

        squeeze_output = False
        if x.ndim < 2:
            # Unbatched call: add a batch axis to every input.
            x = x[None, :]
            t = jnp.full(shape=(1, 1), fill_value=t)
            condition = condition[None, :]
            squeeze_output = True

        time_dims = self.hidden_dims if self.time_dims is None else self.time_dims
        t = self.time_encoder(t)
        for time_dim in time_dims:
            t = self.act_fn(nn.Dense(time_dim)(t))

        for hidden_dim in self.hidden_dims:
            x = self.act_fn(nn.Dense(hidden_dim)(x))

        # One learned class token per sample, prepended to the condition sequence.
        token_shape = (len(condition), 1) if condition.ndim > 2 else (1,)
        token_ids = jnp.zeros(token_shape, dtype=jnp.int32)
        class_token = nn.Embed(num_embeddings=1, features=condition.shape[-1])(token_ids)

        condition = jnp.concatenate((class_token, condition), axis=-2)
        mask = self.get_masks(condition)

        attention = nn.MultiHeadDotProductAttention(num_heads=self.num_heads, qkv_features=self.qkv_feature_dim)
        emb = attention(condition, mask=mask)
        emb = emb[:, 0, :]  # only continue with token 0

        # Chain the condition MLP: every layer consumes the previous layer's
        # output, so all configured layers contribute to the embedding.
        cond_emb = emb
        for cond_dim in self.condition_dims or ():
            cond_emb = self.act_fn(nn.Dense(cond_dim)(cond_emb))

        feats = jnp.concatenate([t, x, cond_emb], axis=1)

        for output_dim in self.output_dims[:-1]:
            feats = self.act_fn(nn.Dense(output_dim)(feats))

        # no activation function for the final layer
        out = nn.Dense(self.output_dims[-1])(feats)
        return jnp.squeeze(out) if squeeze_output else out

    def create_train_state(
        self,
        rng: jax.Array,
        optimizer: optax.GradientTransformation,
        input_dim: int,
        condition_dim: Optional[int] = None,
    ) -> train_state.TrainState:
        """Create the training state.

        Args:
          rng: Random number generator.
          optimizer: Optimizer.
          input_dim: Dimensionality of the velocity field.
          condition_dim: Dimensionality of the condition of the velocity field.

        Returns:
          The training state.

        Raises:
          AssertionError: If ``condition_dims`` is set and ``condition_dim`` is
            missing or not positive, or if no condition can be built (the
            forward pass requires one).
        """
        t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim))
        if self.condition_dims is None:
            cond = None
        else:
            assert condition_dim is not None and condition_dim > 0, "Condition dimension must be positive."
            # Dummy condition with a single token, used only to trace shapes.
            cond = jnp.ones((1, 1, condition_dim))

        params = self.init(rng, t, x, cond)["params"]
        return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=optimizer)

In [27]:
comp_metrics_fn = compute_metrics_fast

# Helpers bound to the training AnnData: decode predictions with its stored
# ``PCs`` / ``X_train_mean`` and select columns by its ``var_names``.
reconstruct_data_fn = functools.partial(
    reconstruct_data,
    projection_matrix=adata_train.varm["PCs"],
    mean_to_add=adata_train.varm["X_train_mean"].T,
)
mask_fn = functools.partial(get_mask, var_names=adata_train.var_names)

# Draw one batch to read off the data and condition dimensionalities.
batch = next(dl)
output_dim = batch["tgt_lin"].shape[1]
condition_dim = batch["src_condition"].shape[-1]

time_encoder_fn = functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024)
vf = VelocityFieldWithAttention(
    num_heads=4,
    qkv_feature_dim=32,
    max_seq_length=2,
    hidden_dims=[512, 512],
    time_dims=[512, 512],
    output_dims=[512, 512, output_dim],
    condition_dims=[512, 512],
    time_encoder=time_encoder_fn,
)

# Linear matching with the quadratic inputs fixed to ``None``.
ot_match_fn = jax.jit(
    functools.partial(
        data_match_fn,
        typ="lin",
        src_quad=None,
        tgt_quad=None,
        epsilon=1.0,
        tau_a=1.0,
        tau_b=1.0,
    )
)
# Adam wrapped in ``MultiSteps`` with an accumulation period of 20.
optimizer = optax.MultiSteps(optax.adam(1e-4), 20)

model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(0.0),
    match_fn=ot_match_fn,
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optimizer,
)

training_logs = {"loss": []}
rng = jax.random.PRNGKey(0)

In [29]:
def eval_step(cfg, model, data, log_metrics, reconstruct_data_fn, comp_metrics_fn, mask_fn):
    """Evaluate ``model`` on the configured splits and print the collected metrics.

    For every evaluated split, three groups of mean metrics are added to
    ``log_metrics``: metrics on the raw predictions (prefix ``"<split>_"``), on
    the predictions after ``reconstruct_data_fn`` (prefix ``"decoded_<split>_"``),
    and on the reconstructed data after ``mask_fn`` is applied with the split's
    ``"deg_dict"`` entries (prefix ``"deg_<split>_"``).

    Args:
      cfg: Unused; kept so existing call sites keep working.
      model: Object exposing ``transport(source, condition)``.
      data: Mapping from split name (``"train"``, ``"test"`` or ``"ood"``) to a
        dict with the keys ``"source"``, ``"target"``, ``"conditions"``,
        ``"deg_dict"`` and ``"target_decoded"``, each keyed by condition.
        Splits that are skipped are never indexed and may be empty.
      log_metrics: Dict that is updated in place and printed at the end.
      reconstruct_data_fn: Maps a prediction to decoded space.
      comp_metrics_fn: Computes metrics from ``(target, prediction)``.
      mask_fn: Called as ``mask_fn(array, deg_entry)`` for every condition.

    Raises:
      ValueError: If ``data`` contains a split name other than ``"train"``,
        ``"test"`` or ``"ood"``.
    """
    # Number of conditions evaluated per split: 0 skips the split, a negative
    # value evaluates every condition, a positive value draws that many
    # condition names at random (with replacement).
    n_samples_per_split = {"train": 0, "test": 0, "ood": -1}
    fields = ("source", "target", "conditions", "deg_dict", "target_decoded")

    for split, dat in data.items():
        if split not in n_samples_per_split:
            # Previously an unknown split silently reused the last ``n_samples``
            # (or failed with an unbound local on the first iteration).
            raise ValueError(f"Unknown split {split!r}; expected one of {sorted(n_samples_per_split)}.")

        n_samples = n_samples_per_split[split]
        if n_samples == 0:
            continue

        if n_samples > 0:
            chosen = set(np.random.choice(list(dat["source"]), n_samples).tolist())
            subset = {
                field: {cond: value for cond, value in dat[field].items() if cond in chosen} for field in fields
            }
        else:
            subset = {field: dat[field] for field in fields}

        prediction = jtu.tree_map(model.transport, subset["source"], subset["conditions"])
        metrics = jtu.tree_map(comp_metrics_fn, subset["target"], prediction)
        log_metrics.update(compute_mean_metrics(metrics, prefix=f"{split}_"))

        prediction_decoded = jtu.tree_map(reconstruct_data_fn, prediction)
        metrics_decoded = jtu.tree_map(comp_metrics_fn, subset["target_decoded"], prediction_decoded)
        log_metrics.update(compute_mean_metrics(metrics_decoded, prefix=f"decoded_{split}_"))

        prediction_decoded_deg = jtu.tree_map(mask_fn, prediction_decoded, subset["deg_dict"])
        target_decoded_deg = jtu.tree_map(mask_fn, subset["target_decoded"], subset["deg_dict"])
        metrics_deg = jtu.tree_map(comp_metrics_fn, target_decoded_deg, prediction_decoded_deg)
        log_metrics.update(compute_mean_metrics(metrics_deg, prefix=f"deg_{split}_"))

    print(log_metrics)

In [31]:
n_iterations = 100
eval_every = 10  # evaluation period; also the window of the running loss mean

for it in tqdm(range(n_iterations)):
    rng, rng_resample, rng_step_fn = jax.random.split(rng, 3)
    batch = jtu.tree_map(jnp.asarray, next(dl))

    src = batch["src_lin"]
    tgt = batch["tgt_lin"]
    src_cond = batch.get("src_condition")

    if model.match_fn is not None:
        # Re-pair source and target samples by sampling from the coupling.
        tmat = model.match_fn(src, tgt)
        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        src, tgt = src[src_ixs], tgt[tgt_ixs]
        if src_cond is not None:
            src_cond = src_cond[src_ixs]

    model.vf_state, loss = model.step_fn(rng_step_fn, model.vf_state, src, tgt, src_cond)
    training_logs["loss"].append(float(loss))

    # Evaluate on every ``eval_every``-th iteration, except the very first one.
    if it == 0 or it % eval_every != 0:
        continue

    train_loss = np.mean(training_logs["loss"][-eval_every:])
    log_metrics = {"train_loss": train_loss}
    eval_data = {"train": train_data, "test": test_data, "ood": ood_data}
    eval_step(None, model, eval_data, log_metrics, reconstruct_data_fn, comp_metrics_fn, mask_fn)

 20%|██        | 20/100 [00:25<01:24,  1.06s/it]

{'train_loss': 2.068877935409546, 'ood_r_squared': -0.267297544155283, 'ood_e_distance': 64.00366434401329, 'ood_mmd_distance': 0.10975084641701024, 'decoded_ood_r_squared': 0.6886720894982247, 'decoded_ood_e_distance': 65.1548799950073, 'decoded_ood_mmd_distance': 0.10469719692641617, 'deg_ood_r_squared': 0.39709923807133723, 'deg_ood_e_distance': 38.94894497198663, 'deg_ood_mmd_distance': 0.13494821025869946}


 29%|██▉       | 29/100 [00:48<01:45,  1.49s/it]

{'train_loss': 2.196030855178833, 'ood_r_squared': -0.22045694068340493, 'ood_e_distance': 61.8951147119141, 'ood_mmd_distance': 0.10723329663260936, 'decoded_ood_r_squared': 0.6988271385301771, 'decoded_ood_e_distance': 63.046331175908264, 'decoded_ood_mmd_distance': 0.10334919571961845, 'deg_ood_r_squared': 0.41655784480552055, 'deg_ood_e_distance': 37.71764960054539, 'deg_ood_mmd_distance': 0.13275068108877447}


 39%|███▉      | 39/100 [01:12<01:39,  1.63s/it]

{'train_loss': 1.6872590780258179, 'ood_r_squared': -0.22045694068340493, 'ood_e_distance': 61.8951147119141, 'ood_mmd_distance': 0.10723329663260936, 'decoded_ood_r_squared': 0.6988271385301771, 'decoded_ood_e_distance': 63.046331175908264, 'decoded_ood_mmd_distance': 0.10334919571961845, 'deg_ood_r_squared': 0.41655784480552055, 'deg_ood_e_distance': 37.71764960054539, 'deg_ood_mmd_distance': 0.13275068108877447}


 50%|█████     | 50/100 [01:35<01:20,  1.61s/it]

{'train_loss': 2.304449427127838, 'ood_r_squared': -0.17833956693922975, 'ood_e_distance': 59.94131206019108, 'ood_mmd_distance': 0.10491756734789348, 'decoded_ood_r_squared': 0.7081792279178719, 'decoded_ood_e_distance': 61.092529293228985, 'decoded_ood_mmd_distance': 0.10209074221028178, 'deg_ood_r_squared': 0.43524721889936496, 'deg_ood_e_distance': 36.532852764870995, 'deg_ood_mmd_distance': 0.13062581721605296}


 59%|█████▉    | 59/100 [01:59<01:11,  1.75s/it]

{'train_loss': 1.8693579137325287, 'ood_r_squared': -0.17833956693922975, 'ood_e_distance': 59.94131206019108, 'ood_mmd_distance': 0.10491756734789348, 'decoded_ood_r_squared': 0.7081792279178719, 'decoded_ood_e_distance': 61.092529293228985, 'decoded_ood_mmd_distance': 0.10209074221028178, 'deg_ood_r_squared': 0.43524721889936496, 'deg_ood_e_distance': 36.532852764870995, 'deg_ood_mmd_distance': 0.13062581721605296}


 69%|██████▉   | 69/100 [02:23<00:54,  1.76s/it]

{'train_loss': 2.015455973148346, 'ood_r_squared': -0.13918528124393015, 'ood_e_distance': 58.02488843024548, 'ood_mmd_distance': 0.10270528933883787, 'decoded_ood_r_squared': 0.7172547184235537, 'decoded_ood_e_distance': 59.17610639682092, 'decoded_ood_mmd_distance': 0.10086628547253966, 'deg_ood_r_squared': 0.4540298142172194, 'deg_ood_e_distance': 35.3378112140591, 'deg_ood_mmd_distance': 0.12850564704147158}


 79%|███████▉  | 79/100 [02:46<00:36,  1.75s/it]

{'train_loss': 2.144143211841583, 'ood_r_squared': -0.13918528124393015, 'ood_e_distance': 58.02488843024548, 'ood_mmd_distance': 0.10270528933883787, 'decoded_ood_r_squared': 0.7172547184235537, 'decoded_ood_e_distance': 59.17610639682092, 'decoded_ood_mmd_distance': 0.10086628547253966, 'deg_ood_r_squared': 0.4540298142172194, 'deg_ood_e_distance': 35.3378112140591, 'deg_ood_mmd_distance': 0.12850564704147158}


 90%|█████████ | 90/100 [03:11<00:17,  1.73s/it]

{'train_loss': 1.7133100152015686, 'ood_r_squared': -0.10348161196377445, 'ood_e_distance': 56.17502905572415, 'ood_mmd_distance': 0.10062560824916915, 'decoded_ood_r_squared': 0.7259233296118598, 'decoded_ood_e_distance': 57.3262477166294, 'decoded_ood_mmd_distance': 0.09969110626106982, 'deg_ood_r_squared': 0.4728019100609421, 'deg_ood_e_distance': 34.13936983487522, 'deg_ood_mmd_distance': 0.12638921307176632}


100%|██████████| 100/100 [03:35<00:00,  2.15s/it]

{'train_loss': 2.1012008786201477, 'ood_r_squared': -0.10348161196377445, 'ood_e_distance': 56.17502905572415, 'ood_mmd_distance': 0.10062560824916915, 'decoded_ood_r_squared': 0.7259233296118598, 'decoded_ood_e_distance': 57.3262477166294, 'decoded_ood_mmd_distance': 0.09969110626106982, 'deg_ood_r_squared': 0.4728019100609421, 'deg_ood_e_distance': 34.13936983487522, 'deg_ood_mmd_distance': 0.12638921307176632}





In [12]:
# Show which fields a batch from the loader contains.
batch.keys()

dict_keys(['src_condition', 'src_lin', 'tgt_lin'])

In [13]:
# Shape of the condition array of the current batch.
batch["src_condition"].shape

(32, 2, 1024)

In [14]:
# Shape of the source samples of the current batch.
batch["src_lin"].shape

(32, 30)

In [15]:
# Shape of the target samples of the current batch.
batch["tgt_lin"].shape

(32, 30)