In [1]:
import functools
from typing import Literal, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
import scanpy as sc
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm

from ot_pert.metrics import compute_mean_metrics, compute_metrics
from ot_pert.utils import ConditionalLoader

In [2]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    return np.matmul(embedding, projection_matrix.T) + mean_to_add

In [3]:
obsm_key_cond = "ecfp_and_dose"
obsm_key_data = "X_pca"

In [4]:
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_30.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_30.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_30.h5ad"

In [6]:
adata_train = sc.read(adata_train_path)



In [7]:
sc.pp.subsample(adata_train, fraction=0.5)

In [12]:
def data_match_fn(
    src_lin: Optional[jnp.ndarray],
    tgt_lin: Optional[jnp.ndarray],
    src_quad: Optional[jnp.ndarray],
    tgt_quad: Optional[jnp.ndarray],
    *,
    typ: Literal["lin", "quad", "fused"],
    epsilon: float = 1e-2,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
) -> jnp.ndarray:
    if typ == "lin":
        return solver_utils.match_linear(
            x=src_lin, y=tgt_lin, scale_cost="mean", epsilon=epsilon, tau_a=tau_a, tau_b=tau_b
        )
    if typ == "quad":
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad)
    if typ == "fused":
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin)
    raise NotImplementedError(f"Unknown type: {typ}.")


# Load data

dls = []

train_data_source = {}
train_data_target = {}
train_data_source_decoded = {}
train_data_target_decoded = {}
train_data_conditions = {}


for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_train[adata_train.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A
    target = adata_train[adata_train.obs["condition"] == cond].obsm[obsm_key_data]
    target_decoded = adata_train[adata_train.obs["condition"] == cond].X.A
    conds = adata_train[adata_train.obs["condition"] == cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))
    dls.append(
        DataLoader(
            datasets.OTDataset(
                datasets.OTData(
                    lin=source,
                    condition=conds,
                ),
                datasets.OTData(lin=target),
            ),
            batch_size=10,
            shuffle=True,
        )
    )
    train_data_source[cond] = source
    train_data_target[cond] = target
    train_data_conditions[cond] = conds
    train_data_source_decoded[cond] = source_decoded
    train_data_target_decoded[cond] = target_decoded

train_loader = ConditionalLoader(dls, seed=0)

test_data_source = {}
test_data_target = {}
test_data_source_decoded = {}
test_data_target_decoded = {}
test_data_conditions = {}
# adata_test = sc.read(adata_test_path)
for cond in adata_test.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_test[adata_test.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_test[adata_test.obs["condition"] == src_str[0] + "_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_test[adata_test.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A

    target = adata_test[adata_test.obs["condition"] == cond].obsm[obsm_key_data]
    target_decoded = adata_test[adata_test.obs["condition"] == cond].X.A
    conds = adata_test[adata_test.obs["condition"] == cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))
    test_data_source[cond] = source
    test_data_target[cond] = target
    test_data_source_decoded[cond] = source_decoded
    test_data_target_decoded[cond] = target_decoded
    test_data_conditions[cond] = conds

ood_data_source = {}
ood_data_target = {}
ood_data_source_decoded = {}
ood_data_target_decoded = {}
ood_data_conditions = {}
# adata_ood = sc.read(adata_ood_path)
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_ood[adata_ood.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_ood[adata_ood.obs["condition"] == src_str[0] + "_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_ood[adata_ood.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A
    conds = adata_ood[adata_ood.obs["condition"] == cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))
    ood_data_source[cond] = source
    ood_data_target[cond] = target
    ood_data_source_decoded[cond] = source_decoded
    ood_data_target_decoded[cond] = target_decoded
    ood_data_conditions[cond] = cond

reconstruct_data_fn = functools.partial(
    reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T
)

In [13]:
reconstruct_data_fn = functools.partial(
    reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T
)

train_deg_dict = {
    k: v for k, v in adata_train.uns["rank_genes_groups_cov_all"].items() if k in train_data_conditions.keys()
}
test_deg_dict = {
    k: v for k, v in adata_train.uns["rank_genes_groups_cov_all"].items() if k in test_data_conditions.keys()
}
ood_deg_dict = {
    k: v for k, v in adata_train.uns["rank_genes_groups_cov_all"].items() if k in ood_data_conditions.keys()
}


def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]

In [14]:
source_dim = source.shape[1]
target_dim = source_dim
condition_dim = conds.shape[1]

source_dim = source.shape[1]
target_dim = source_dim
condition_dim = conds.shape[1]

In [None]:
vf = VelocityField(
    hidden_dims=[512, 512],
    time_dims=[512, 512],
    output_dims=[30] + [target_dim],
    condition_dims=[512, 512],
    time_encoder=functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024),
)

model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(0),
    match_fn=jax.jit(
        functools.partial(data_match_fn, typ="lin", src_quad=None, tgt_quad=None, epsilon=0.1, tau_a=1.0, tau_b=1.0)
    ),
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optax.MultiSteps(optax.adam(learning_rate=1e-3), 5),
)

training_logs = {"loss": []}

rng = jax.random.PRNGKey(0)
for it in tqdm(range(1000)):
    rng, rng_resample, rng_step_fn = jax.random.split(rng, 3)
    batch = next(train_loader)
    batch = jtu.tree_map(jnp.asarray, batch)

    src, tgt = batch["src_lin"], batch["tgt_lin"]
    src_cond = batch.get("src_condition")

    if model.match_fn is not None:
        tmat = model.match_fn(src, tgt)
        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        src, tgt = src[src_ixs], tgt[tgt_ixs]
        src_cond = None if src_cond is None else src_cond[src_ixs]

    model.vf_state, loss = model.step_fn(
        rng_step_fn,
        model.vf_state,
        src,
        tgt,
        src_cond,
    )

    training_logs["loss"].append(float(loss))
    if (it % 100 == 0) and (it > 0):
        valid_losses = []
        for cond in test_data_source.keys():
            src = test_data_source[cond]
            tgt = test_data_target[cond]
            src_cond = test_data_conditions[cond]
            if model.match_fn is not None:
                tmat = model.match_fn(src, tgt)
                src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
                src, tgt = src[src_ixs], tgt[tgt_ixs]
                src_cond = None if src_cond is None else src_cond[src_ixs]
            _, valid_loss = model.step_fn(
                rng,
                model.vf_state,
                src,
                tgt,
                src_cond,
            )
            valid_losses.append(valid_loss)

        # predicted_target_train = jax.tree_util.tree_map(model.transport, train_data_source, train_data_conditions)
        # train_metrics = jax.tree_util.tree_map(compute_metrics, train_data_target, predicted_target_train)
        # mean_train_metrics = compute_mean_metrics(train_metrics, prefix="train_")

        # predicted_target_train_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_train)
        # train_metrics_decoded = jax.tree_util.tree_map(compute_metrics, train_data_target_decoded, predicted_target_train_decoded)
        # mean_train_metrics_decoded = compute_mean_metrics(train_metrics_decoded, prefix="decoded_train_")

        # train_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, predicted_target_train_decoded, train_deg_dict)
        # train_deg_target_decoded = jax.tree_util.tree_map(get_mask, train_data_target_decoded, test_deg_dict)

        predicted_target_test = jax.tree_util.tree_map(model.transport, test_data_source, test_data_conditions)
        test_metrics = jax.tree_util.tree_map(compute_metrics, test_data_target, predicted_target_test)
        mean_test_metrics = compute_mean_metrics(test_metrics, prefix="test_")

        predicted_target_test_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_test)
        test_metrics_decoded = jax.tree_util.tree_map(
            compute_metrics, test_data_target_decoded, predicted_target_test_decoded
        )
        mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

        test_deg_target_decoded_predicted = jax.tree_util.tree_map(
            get_mask, predicted_target_test_decoded, test_deg_dict
        )
        test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)
        deg_test_metrics_encoded = jax.tree_util.tree_map(
            compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted
        )
        deg_mean_test_metrics_encoded = compute_mean_metrics(deg_test_metrics_encoded, prefix="deg_test_")

        predicted_target_ood = jax.tree_util.tree_map(model.transport, ood_data_source, ood_data_conditions)
        ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_data_target, predicted_target_ood)
        mean_ood_metrics = compute_mean_metrics(ood_metrics, prefix="ood_")

        predicted_target_ood_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_ood)
        ood_metrics_decoded = jax.tree_util.tree_map(
            compute_metrics, ood_data_target_decoded, predicted_target_ood_decoded
        )
        mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

        ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, predicted_target_ood_decoded, ood_deg_dict)
        ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)
        deg_ood_metrics_encoded = jax.tree_util.tree_map(
            compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted
        )
        deg_mean_ood_metrics_encoded = compute_mean_metrics(deg_ood_metrics_encoded, prefix="deg_ood_")

 10%|â–‰         | 98/1000 [00:14<00:09, 99.74it/s]

In [None]:
next(iter(dls[2]))

In [25]:
for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_train[adata_train.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A
    target = adata_train[adata_train.obs["condition"] == cond].obsm[obsm_key_data]
    target_decoded = adata_train[adata_train.obs["condition"] == cond].X.A
    conds = adata_train[adata_train.obs["condition"] == cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))

    ds = datasets.OTDataset(
        datasets.OTData(
            lin=source,
            condition=conds,
        ),
        datasets.OTData(lin=target),
    )
    break

In [47]:
conds = np.tile(conds[0], (len(source), 1))

In [48]:
conds.shape

(2787, 1025)

In [30]:
batch = next(iter(ds))

In [31]:
batch["src_lin"].shape

(30,)

In [32]:
batch["tgt_lin"].shape

(30,)

In [33]:
batch["src_condition"].shape

(1025,)

In [39]:
ds.src_data.lin.shape

(2787, 30)

In [40]:
ds.tgt_data.lin.shape

(114, 30)

In [44]:
ds.src_data.condition.shape

(114, 1025)

In [36]:
dir(ds)

['SRC_PREFIX',
 'TGT_PREFIX',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_rng',
 '_sample_from_target',
 '_tgt_cond_to_ix',
 '_verify_integrity',
 'is_aligned',
 'src_conditions',
 'src_data',
 'tgt_conditions',
 'tgt_data']