In [1]:
import functools
import os
from typing import Literal, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
import scanpy as sc
from flax import linen as nn
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm

from ot_pert.metrics import compute_mean_metrics, compute_metrics_fast
from ot_pert.utils import ConditionalLoader

In [2]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Project low-dimensional embeddings back into the original feature space.

    Args:
        embedding: Embedded observations, multiplied on the right by the transposed
            ``projection_matrix``.
        projection_matrix: Projection (loadings) matrix; its transpose maps the
            embedding back to feature space.
        mean_to_add: Offset added to the product (NumPy broadcasting rules apply).

    Returns:
        ``np.matmul(embedding, projection_matrix.T) + mean_to_add``.
    """
    back_projected = np.matmul(embedding, projection_matrix.T)
    return back_projected + mean_to_add

In [3]:
# `adata.obsm` keys read by `load_data`: the per-cell condition vectors, and the
# embedding the flow operates on (decoded back to gene space via `varm["PCs"]`).
obsm_key_cond, obsm_key_data = "ecfp_dose_cell_line", "X_pca"

In [4]:
# Directory holding the sciplex AnnData splits. Set the SCIPLEX_DATA_DIR environment
# variable to run elsewhere instead of editing the hard-coded cluster path.
SCIPLEX_DATA_DIR = os.environ.get(
    "SCIPLEX_DATA_DIR", "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
)
adata_train_path = os.path.join(SCIPLEX_DATA_DIR, "adata_train_30.h5ad")
adata_test_path = os.path.join(SCIPLEX_DATA_DIR, "adata_test_30.h5ad")
adata_ood_path = os.path.join(SCIPLEX_DATA_DIR, "adata_ood_30.h5ad")

In [5]:
import functools
import sys
import traceback

import hydra
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
import scanpy as sc
import wandb
from omegaconf import DictConfig, OmegaConf
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm

from ot_pert.metrics import compute_mean_metrics, compute_metrics, compute_metrics_fast
from ot_pert.utils import ConditionalLoader


def reconstruct_data(embedding, projection_matrix, mean_to_add):
    """Decode embeddings back to feature space.

    Computes ``np.matmul(embedding, projection_matrix.T) + mean_to_add``. That is,
    the embedding times the transposed projection matrix, plus the offset
    ``mean_to_add`` (combined under NumPy broadcasting rules).
    """
    loadings_t = projection_matrix.T
    reconstructed = np.matmul(embedding, loadings_t)
    return reconstructed + mean_to_add



def _as_dense(matrix):
    """Return ``matrix`` as a dense NumPy array.

    Sparse inputs are densified with ``.toarray()``. The ``.A`` shortcut no longer
    exists on recent SciPy sparse matrices, and dense arrays never had it. Anything
    else goes through ``np.asarray``.
    """
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


def load_data(adata, cfg, *, return_dl: bool):
    """Build per-condition source/target data from an AnnData object.

    Each perturbed condition is a category of ``adata.obs["condition"]`` whose label
    does not contain ``"Vehicle"``. For each one, the target cells are the cells of
    that condition. The source cells are the controls of the same cell type, i.e.
    cells labelled ``"<cell_type>_Vehicle_0.0"``.

    Args:
        adata: AnnData with categorical ``obs["condition"]``, ``obs["cell_type"]``,
            ``obsm[obsm_key_data]`` and ``obsm[obsm_key_cond]``. When
            ``return_dl=False`` it also needs ``uns["rank_genes_groups_cov_all"]``.
        cfg: Not used by this function; kept for interface compatibility.
        return_dl: If True, build one shuffled ``DataLoader`` (batch size 512) per
            condition and return them wrapped in a ``ConditionalLoader``. Otherwise
            return plain dictionaries.

    Returns:
        If ``return_dl``: ``ConditionalLoader(dls, seed=0)``.

        Otherwise: a dict with the keys ``source``, ``target``, ``source_decoded``,
        ``target_decoded``, ``conditions`` and ``deg_dict``. Each value except
        ``deg_dict`` is keyed by condition:

        - ``source`` / ``target``: the ``obsm[obsm_key_data]`` arrays.
        - ``source_decoded`` / ``target_decoded``: dense ``X``.
        - ``conditions``: the condition vector tiled once per source cell.
        - ``deg_dict``: the entries of ``uns["rank_genes_groups_cov_all"]`` for
          those conditions.

    Raises:
        AssertionError: If a condition spans several cell types, or its cells do
            not share one condition vector.
        ValueError: If no control cells exist for a condition's cell type.
    """
    dls = []
    data_source = {}
    data_target = {}
    data_source_decoded = {}
    data_target_decoded = {}
    data_conditions = {}
    condition_labels = adata.obs["condition"]
    for cond in condition_labels.cat.categories:
        # Control ("Vehicle") groups only ever serve as sources.
        if "Vehicle" in cond:
            continue
        adata_tgt = adata[condition_labels == cond]
        src_str_unique = list(adata_tgt.obs["cell_type"].unique())
        assert len(src_str_unique) == 1
        src_str = src_str_unique[0] + "_Vehicle_0.0"
        # Select controls by the full label. The previous code used `src_str[0]`,
        # which is only the first character of the cell-type string, so it never
        # matched any cell.
        adata_src = adata[condition_labels == src_str]
        if adata_src.n_obs == 0:
            raise ValueError(f"No control cells labelled {src_str!r} for condition {cond!r}.")
        source = adata_src.obsm[obsm_key_data]
        target = adata_tgt.obsm[obsm_key_data]
        conds = adata_tgt.obsm[obsm_key_cond]
        # Every cell of a condition must carry the same condition vector...
        assert np.all(np.all(conds == conds[0], axis=1))
        # ...which is then attached to every source (control) cell.
        conds = np.tile(conds[0], (len(source), 1))
        if return_dl:
            dls.append(
                DataLoader(
                    datasets.OTDataset(
                        datasets.OTData(
                            lin=source,
                            condition=conds,
                        ),
                        datasets.OTData(lin=target),
                    ),
                    batch_size=512,
                    shuffle=True,
                )
            )
        else:
            data_source[cond] = source
            data_target[cond] = target
            # Gene-space matrices are only returned on this path, so densify here.
            data_source_decoded[cond] = _as_dense(adata_src.X)
            data_target_decoded[cond] = _as_dense(adata_tgt.X)
            data_conditions[cond] = conds
    if return_dl:
        return ConditionalLoader(dls, seed=0)

    deg_dict = {k: v for k, v in adata.uns["rank_genes_groups_cov_all"].items() if k in data_conditions}

    return {
        "source": data_source,
        "target": data_target,
        "source_decoded": data_source_decoded,
        "target_decoded": data_target_decoded,
        "conditions": data_conditions,
        "deg_dict": deg_dict,
    }


def data_matching_function(
    src_lin, tgt_lin, src_quad=None, tgt_quad=None, typ="lin", epsilon=1e-2, tau_a=1.0, tau_b=1.0
):
    """Compute a coupling between a source and a target batch.

    Args:
        src_lin: Source points for the linear (or fused) matching term.
        tgt_lin: Target points for the linear (or fused) matching term.
        src_quad: Source points for the quadratic term. Required for
            ``typ="quad"`` and ``typ="fused"``.
        tgt_quad: Target points for the quadratic term. Required for
            ``typ="quad"`` and ``typ="fused"``.
        typ: ``"lin"``, ``"quad"`` or ``"fused"``. Any other value returns None.
        epsilon: Passed to ``solver_utils.match_linear`` only (ignored by quad/fused).
        tau_a: Passed to ``solver_utils.match_linear`` only (ignored by quad/fused).
        tau_b: Passed to ``solver_utils.match_linear`` only (ignored by quad/fused).

    The quadratic inputs default to None so the function can be called as
    ``match_fn(src, tgt)`` for linear matching. Previously they were required, so
    that two-argument call raised TypeError.

    Returns:
        The result of the selected ``solver_utils.match_*`` call, or None for an
        unrecognised ``typ``.

    Raises:
        ValueError: If ``typ`` needs quadratic inputs and either one is missing.
    """
    if typ in ("quad", "fused") and (src_quad is None or tgt_quad is None):
        raise ValueError(f"typ={typ!r} requires both src_quad and tgt_quad.")
    # Lazily evaluated so only the selected solver runs.
    match_fn = {
        "lin": lambda: solver_utils.match_linear(
            x=src_lin, y=tgt_lin, scale_cost="mean", epsilon=epsilon, tau_a=tau_a, tau_b=tau_b
        ),
        "quad": lambda: solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad),
        "fused": lambda: solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin),
    }
    return match_fn.get(typ, lambda: None)()


def get_mask(x, y, var_names):
    """Keep only the columns of ``x`` whose name appears in ``y``.

    Args:
        x: 2-D array whose columns line up with ``var_names``.
        y: Collection of names to keep; membership is tested with ``in``.
        var_names: Column names of ``x``, in order.

    Returns:
        ``x`` restricted to the selected columns, in their original order.
    """
    keep_columns = []
    for name in var_names:
        keep_columns.append(name in y)
    return x[:, keep_columns]


    

def eval_step(model, data, log_metrics, reconstruct_data_fn, comp_metrics_fn, mask_fn):
    """Evaluate ``model`` on every split in ``data`` and log the collected metrics.

    For each split, predictions are made with ``model.transport`` and compared to
    the targets in three ways:

    - in the original (embedding) space, with prefix ``"{split}_"``;
    - after decoding with ``reconstruct_data_fn``, with prefix ``"decoded_{split}_"``;
    - on the per-condition DE genes selected via ``mask_fn``, with prefix
      ``"deg_{split}_"``.

    Args:
        model: Object exposing ``transport(source, condition)``.
        data: Mapping from split name to a dict shaped like the output of
            ``load_data(..., return_dl=False)``, with keys ``source``, ``target``,
            ``target_decoded``, ``conditions`` and ``deg_dict``. A split may also
            supply precomputed ``deg`` / ``target_decoded_deg`` entries instead.
            Splits with fewer than two entries (e.g. an empty placeholder) are
            skipped.
        log_metrics: Dict updated in place with the mean metrics.
        reconstruct_data_fn: Maps a prediction back to gene space.
        comp_metrics_fn: ``(target, prediction) -> metrics`` for one condition.
        mask_fn: ``(matrix, genes) -> matrix`` restricted to ``genes``.

    Returns:
        ``log_metrics``, the same dict after updating.
    """
    for split_name, split in data.items():
        if len(split) <= 1:
            continue
        prediction = jtu.tree_map(model.transport, split["source"], split["conditions"])
        metrics = jtu.tree_map(comp_metrics_fn, split["target"], prediction)
        log_metrics.update(compute_mean_metrics(metrics, prefix=f"{split_name}_"))

        prediction_decoded = jtu.tree_map(reconstruct_data_fn, prediction)
        metrics_decoded = jtu.tree_map(comp_metrics_fn, split["target_decoded"], prediction_decoded)
        log_metrics.update(compute_mean_metrics(metrics_decoded, prefix=f"decoded_{split_name}_"))

        # `load_data` stores the DE genes under "deg_dict" and never builds
        # "target_decoded_deg". Looking up "deg" / "target_decoded_deg"
        # unconditionally raised KeyError, so accept either form.
        deg = split["deg"] if "deg" in split else split["deg_dict"]
        # Iterate over `deg` explicitly so each condition's gene collection reaches
        # `mask_fn` whole, even if it is a list (tree_map would treat a list as a
        # subtree). Conditions absent from `deg` are left out of the DEG metrics.
        prediction_decoded_deg = {cond: mask_fn(prediction_decoded[cond], genes) for cond, genes in deg.items()}
        if "target_decoded_deg" in split:
            target_decoded_deg = split["target_decoded_deg"]
        else:
            target_decoded_deg = {
                cond: mask_fn(split["target_decoded"][cond], genes) for cond, genes in deg.items()
            }
        metrics_deg = jtu.tree_map(comp_metrics_fn, target_decoded_deg, prediction_decoded_deg)
        log_metrics.update(compute_mean_metrics(metrics_deg, prefix=f"deg_{split_name}_"))

    # `wandb.log` raises when no run has been initialised (e.g. when the notebook
    # runs without `wandb.init`). Only log to an active run; callers still get the
    # metrics through the return value.
    if wandb.run is not None:
        wandb.log(log_metrics)
    return log_metrics




In [None]:
adata_train = sc.read_h5ad(adata_train_path)
# Keep 10% of the training cells (in place), with an explicit seed for reproducibility.
sc.pp.subsample(adata_train, fraction=0.1, random_state=0)
adata_test = sc.read_h5ad(adata_test_path)
adata_ood = sc.read_h5ad(adata_ood_path)

# `load_data` takes a config argument but never reads it. This notebook defines no
# `cfg` (referencing it raised NameError), so pass None explicitly.
dl = load_data(adata_train, None, return_dl=True)
# Placeholder split: `eval_step` skips splits with fewer than two entries.
train_data = {}
test_data = load_data(adata_test, None, return_dl=False)
ood_data = load_data(adata_ood, None, return_dl=False)

# Switch between the fast and the full metric implementations.
USE_FAST_METRICS = True
comp_metrics_fn = compute_metrics_fast if USE_FAST_METRICS else compute_metrics

reconstruct_data_fn = functools.partial(
    reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T
)
mask_fn = functools.partial(get_mask, var_names=adata_train.var_names)

# Peek at one batch to size the model: the last output layer must match the data
# dimension, and the model is initialised with the width of the condition vectors.
batch = next(dl)
output_dim = batch["tgt_lin"].shape[1]
condition_dim = batch["src_condition"].shape[1]

vf = VelocityField(
    hidden_dims=[512, 512],
    time_dims=[512, 512],
    # `[512, 512] + output_dim` added an int to a list (TypeError); the final layer
    # width is the data dimension.
    output_dims=[512, 512, output_dim],
    condition_dims=[512, 512],
    time_encoder=functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024),
)

model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(1.0),
    # The training loop calls `match_fn(src, tgt)`. Bind the quadratic inputs to
    # None so only the two linear arrays are passed positionally.
    match_fn=jax.jit(
        functools.partial(
            data_matching_function,
            src_quad=None,
            tgt_quad=None,
            typ="lin",
            epsilon=1.0,
            tau_a=1.0,
            tau_b=1.0,
        )
    ),
    # Width of the condition vectors (an int), not the condition-encoder layer sizes.
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    # Adam wrapped in optax.MultiSteps: gradients are accumulated over 10 calls per update.
    optimizer=optax.MultiSteps(optax.adam(1e-4), 10),
)

# Training loop: 100 iterations over batches drawn from `dl`, with periodic evaluation.
training_logs = {"loss": []}
rng = jax.random.PRNGKey(0)

for it in tqdm(range(100)):
    # Fresh keys each iteration: one for resampling the coupling, one for `step_fn`.
    rng, rng_resample, rng_step_fn = jax.random.split(rng, 3)
    batch = next(dl)
    # Convert every leaf of the batch into a JAX array.
    batch = jtu.tree_map(jnp.asarray, batch)

    src, tgt = batch["src_lin"], batch["tgt_lin"]
    src_cond = batch.get("src_condition")

    if model.match_fn is not None:
        # Re-pair source and target rows by sampling index pairs from the matrix
        # returned by `match_fn`; conditions follow their source rows.
        tmat = model.match_fn(src, tgt)
        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        src, tgt = src[src_ixs], tgt[tgt_ixs]
        src_cond = None if src_cond is None else src_cond[src_ixs]

    # One training step: returns the updated train state and the loss.
    model.vf_state, loss = model.step_fn(
        rng_step_fn,
        model.vf_state,
        src,
        tgt,
        src_cond,
    )

    training_logs["loss"].append(float(loss))
    # Every 10 iterations (it = 10, 20, ..., 90): average the last 10 losses and
    # evaluate on the train/test/ood splits. The final iteration (it = 99) is not
    # evaluated.
    if (it % 10 == 0) and (it > 0):
        train_loss = np.mean(training_logs["loss"][-10 :])
        log_metrics = {"train_loss": train_loss}
        eval_step(
            model,
            {"train": train_data, "test": test_data, "ood": ood_data},
            log_metrics,
            reconstruct_data_fn,
            comp_metrics_fn,
            mask_fn,
        )

