In [2]:
import hydra
import wandb
from ott.neural import datasets
import sys
from omegaconf import DictConfig
import jax.numpy as jnp
from jax import random
from typing import Optional, Literal
import jax
import pathlib
import optax
import yaml
from datetime import datetime
from flax import linen as nn
import functools
from tqdm import tqdm
from flax.training import train_state

from ott.neural.networks.layers import time_encoder
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils
import jax.tree_util as jtu
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
import pandas as pd
import os

import hydra
from omegaconf import DictConfig, OmegaConf

from torch.utils.data import DataLoader
import numpy as np

import scanpy as sc
from ot_pert.metrics import compute_metrics_fast, compute_mean_metrics
from ot_pert.nets.nets import VelocityFieldWithAttention
from ot_pert.utils import ConditionalLoader


In [3]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    return np.matmul(embedding, projection_matrix.T) + mean_to_add

In [4]:
obsm_key_cond = "ecfp_logdose_cell_line"
obsm_key_data = "X_pca"

In [5]:
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_30.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_30.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_30.h5ad"

In [6]:
adata_train = sc.read(adata_train_path)



In [7]:
sc.pp.subsample(adata_train, fraction=0.5)

In [8]:
def data_match_fn(
    src_lin: Optional[jnp.ndarray], tgt_lin: Optional[jnp.ndarray],
    src_quad: Optional[jnp.ndarray], tgt_quad: Optional[jnp.ndarray], *,
    typ: Literal["lin", "quad", "fused"], epsilon: float = 1e-2, tau_a: float = 1.0,
    tau_b: float = 1.0,
) -> jnp.ndarray:
    if typ == "lin":
        return solver_utils.match_linear(x=src_lin, y=tgt_lin, scale_cost="mean", epsilon=epsilon, tau_a=tau_a, tau_b=tau_b)
    if typ == "quad":
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad)
    if typ == "fused":
        return solver_utils.match_quadratic(
            xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin
        )
    raise NotImplementedError(f"Unknown type: {typ}.")

# Load data

dls = []

train_data_source = {}
train_data_target = {}
train_data_source_decoded = {}
train_data_target_decoded = {}
train_data_conditions = {}


for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_train[adata_train.obs["condition"]==cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_train[adata_train.obs["condition"]==src_str[0]+"_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_train[adata_train.obs["condition"]==src_str[0]+"_Vehicle_0.0"].X.A
    target = adata_train[adata_train.obs["condition"]==cond].obsm[obsm_key_data]
    target_decoded = adata_train[adata_train.obs["condition"]==cond].X.A
    conds = adata_train[adata_train.obs["condition"]==cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))
    dls.append(DataLoader(datasets.OTDataset(datasets.OTData(
        lin=source,
        condition=conds,
    ), datasets.OTData(lin=target)), batch_size=10, shuffle=True))
    train_data_source[cond] = source
    train_data_target[cond] = target
    train_data_conditions[cond] = conds
    train_data_source_decoded[cond] = source_decoded
    train_data_target_decoded[cond] = target_decoded

train_loader = ConditionalLoader(dls, seed=0)

test_data_source = {}
test_data_target = {}
test_data_source_decoded = {}
test_data_target_decoded = {}
test_data_conditions = {}
adata_test = sc.read(adata_test_path)
for cond in adata_test.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_test[adata_test.obs["condition"]==cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_test[adata_test.obs["condition"]==src_str[0]+"_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_test[adata_test.obs["condition"]==src_str[0]+"_Vehicle_0.0"].X.A
    
    target = adata_test[adata_test.obs["condition"]==cond].obsm[obsm_key_data]
    target_decoded = adata_test[adata_test.obs["condition"]==cond].X.A
    conds = adata_test[adata_test.obs["condition"]==cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))
    test_data_source[cond] = source
    test_data_target[cond] = target
    test_data_source_decoded[cond] = source_decoded
    test_data_target_decoded[cond] = target_decoded
    test_data_conditions[cond] = conds

ood_data_source = {}
ood_data_target = {}
ood_data_source_decoded = {}
ood_data_target_decoded = {}
ood_data_conditions = {}
adata_ood = sc.read(adata_ood_path)
for cond in adata_ood.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_ood[adata_ood.obs["condition"]==cond].obs["cell_type"].unique())
    assert len(src_str) == 1
    source = adata_ood[adata_ood.obs["condition"]==src_str[0]+"_Vehicle_0.0"].obsm[obsm_key_data]
    source_decoded = adata_ood[adata_ood.obs["condition"]==src_str[0]+"_Vehicle_0.0"].X.A
    conds = adata_ood[adata_ood.obs["condition"]==cond].obsm[obsm_key_cond]
    assert np.all(np.all(conds == conds[0], axis=1))
    conds = np.tile(conds[0], (len(source), 1))
    ood_data_source[cond] = source
    ood_data_target[cond] = target
    ood_data_source_decoded[cond] = source_decoded
    ood_data_target_decoded[cond] = target_decoded
    ood_data_conditions[cond] = conds

reconstruct_data_fn = functools.partial(reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T)


2024-04-19 10:57:32.392060: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [9]:
reconstruct_data_fn = functools.partial(reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T)

train_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in train_data_conditions.keys()}
test_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in test_data_conditions.keys()}
ood_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in ood_data_conditions.keys()}

def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]

In [10]:
source_dim = source.shape[1]
target_dim = source_dim
condition_dim = conds.shape[1]

source_dim = source.shape[1]
target_dim = source_dim
condition_dim = conds.shape[1]

In [18]:
vf = VelocityField(
    hidden_dims=[1024, 1024, 1024],
    time_dims=[512,512],
    output_dims=[1024, 1024, 1024]+[target_dim],
    condition_dims=[2048, 2048],
    time_encoder = functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024),
    )

model = otfm.OTFlowMatching(vf,
    flow=dynamics.ConstantNoiseFlow(0),
    match_fn=jax.jit(functools.partial(data_match_fn, typ="lin", src_quad=None, tgt_quad=None, epsilon=0.1, tau_a=1.0, tau_b=1.0)),
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optax.MultiSteps(optax.adam(learning_rate=1e-4), 20)
)

training_logs = {"loss": []}

In [None]:


rng = jax.random.PRNGKey(0)
for it in tqdm(range(100000)):
    rng, rng_resample, rng_step_fn = jax.random.split(rng, 3)
    batch = next(train_loader)
    batch = jtu.tree_map(jnp.asarray, batch)

    src, tgt = batch["src_lin"], batch["tgt_lin"]
    src_cond = batch.get("src_condition")

    if model.match_fn is not None:
        tmat = model.match_fn(src, tgt)
        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        src, tgt = src[src_ixs], tgt[tgt_ixs]
        src_cond = None if src_cond is None else src_cond[src_ixs]

    model.vf_state, loss = model.step_fn(
        rng_step_fn,
        model.vf_state,
        src,
        tgt,
        src_cond,
    )

    training_logs["loss"].append(float(loss))
    if (it % 10000 == 0) and (it > 0):
        idcs = np.random.choice(list(test_data_source.keys()), 20)
        test_data_source_tmp = {k:v for k,v in test_data_source.items() if k in idcs}
        test_data_target_tmp = {k:v for k,v in test_data_target.items() if k in idcs}
        test_data_conditions_tmp = {k:v for k,v in test_data_conditions.items() if k in idcs}
        test_data_target_decoded_tmp = {k:v for k,v in test_data_target_decoded.items() if k in idcs}
        test_deg_dict_tmp = {k:v for k,v in test_deg_dict.items() if k in idcs}
        valid_losses = []
        #for cond in test_data_source_tmp.keys():
        #    src = test_data_source_tmp[cond]
        #    tgt = test_data_target_tmp[cond]
        #    src_cond = test_data_conditions_tmp[cond]
        #    if model.match_fn is not None:
        #        tmat = model.match_fn(src, tgt)
        #        src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
        #        src, tgt = src[src_ixs], tgt[tgt_ixs]
        #        src_cond = None if src_cond is None else src_cond[src_ixs]
        #    _, valid_loss = model.step_fn(
        #        rng,
        #        model.vf_state,
        #        src,
        #        tgt,
        #        src_cond,
        #    )
        #    valid_losses.append(valid_loss)

        # predicted_target_train = jax.tree_util.tree_map(model.transport, train_data_source, train_data_conditions)
        # train_metrics = jax.tree_util.tree_map(compute_metrics_fast, train_data_target, predicted_target_train)
        # mean_train_metrics = compute_mean_metrics(train_metrics, prefix="train_")

        # predicted_target_train_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_train)
        # train_metrics_decoded = jax.tree_util.tree_map(compute_metrics_fast, train_data_target_decoded, predicted_target_train_decoded)
        # mean_train_metrics_decoded = compute_mean_metrics(train_metrics_decoded, prefix="decoded_train_")

        # train_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, predicted_target_train_decoded, train_deg_dict)
        # train_deg_target_decoded = jax.tree_util.tree_map(get_mask, train_data_target_decoded, test_deg_dict)

        predicted_target_test = jax.tree_util.tree_map(model.transport, test_data_source_tmp, test_data_conditions_tmp)
        #test_metrics = jax.tree_util.tree_map(compute_metrics_fast, test_data_target_tmp, predicted_target_test)
        #mean_test_metrics = compute_mean_metrics(test_metrics, prefix="test_")

        predicted_target_test_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_test)
        test_metrics_decoded = jax.tree_util.tree_map(
            compute_metrics_fast, test_data_target_decoded_tmp, predicted_target_test_decoded
        )
        mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

        
        #test_deg_target_decoded_predicted = jax.tree_util.tree_map(
        #    get_mask, predicted_target_test_decoded, test_deg_dict_tmp
        #)
        #test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded_tmp, test_deg_dict_tmp)
        #deg_test_metrics_encoded = jax.tree_util.tree_map(
        #    compute_metrics_fast, test_deg_target_decoded, test_deg_target_decoded_predicted
        #)
        #deg_mean_test_metrics_encoded = compute_mean_metrics(deg_test_metrics_encoded, prefix="deg_test_")

        predicted_target_ood = jax.tree_util.tree_map(model.transport, ood_data_source, ood_data_conditions)
        ood_metrics = jax.tree_util.tree_map(compute_metrics_fast, ood_data_target, predicted_target_ood)
        mean_ood_metrics = compute_mean_metrics(ood_metrics, prefix="ood_")

        predicted_target_ood_decoded = jax.tree_util.tree_map(reconstruct_data_fn, predicted_target_ood)
        ood_metrics_decoded = jax.tree_util.tree_map(
            compute_metrics_fast, ood_data_target_decoded, predicted_target_ood_decoded
        )
        mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

        #ood_deg_target_decoded_predicted = jax.tree_util.tree_map(
        #    get_mask, predicted_target_ood_decoded, ood_deg_dict
        #)
        #ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)
        #deg_ood_metrics_encoded = jax.tree_util.tree_map(
        #    compute_metrics_fast, ood_deg_target_decoded, ood_deg_target_decoded_predicted
        #)
        #deg_mean_ood_metrics_encoded = compute_mean_metrics(deg_ood_metrics_encoded, prefix="deg_ood_")
        print(mean_test_metrics_decoded, mean_ood_metrics_decoded)


 10%|█         | 10002/100000 [07:30<220:50:56,  8.83s/it]

{'decoded_test_r_squared': 0.955759099095201, 'decoded_test_e_distance': 1.952600192954521, 'decoded_test_mmd_distance': 0.04702994282136224} {'decoded_ood_r_squared': -0.7570388686841591, 'decoded_ood_e_distance': 53.50669506605293, 'decoded_ood_mmd_distance': 0.13306989534852176}


 20%|█▉        | 19998/100000 [08:57<09:00, 147.88it/s]   

In [12]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': -0.6659773320388185,
 'decoded_ood_e_distance': 50.73361927033098,
 'decoded_ood_mmd_distance': 0.12942688787855533}

In [13]:
ood_metrics_decoded

{'A549_A-366_10.0': {'r_squared': -1.892633773604254,
  'e_distance': 88.08870271998903,
  'mmd_distance': 0.22087055997840274},
 'A549_A-366_100.0': {'r_squared': -1.892685998299453,
  'e_distance': 88.0902931052567,
  'mmd_distance': 0.22086954061269737},
 'A549_A-366_10000.0': {'r_squared': -1.8928958562553757,
  'e_distance': 88.09668385847338,
  'mmd_distance': 0.2208850891576631},
 'A549_Carmofur_10.0': {'r_squared': -1.8927973642599647,
  'e_distance': 88.09368450540867,
  'mmd_distance': 0.22088357537043088},
 'A549_Carmofur_100.0': {'r_squared': -1.8930694173989373,
  'e_distance': 88.10196927567513,
  'mmd_distance': 0.22089576104956812},
 'A549_Disulfiram__10.0': {'r_squared': -1.8725561412197287,
  'e_distance': 87.47728324135454,
  'mmd_distance': 0.2199586645957762},
 'A549_Disulfiram__1000.0': {'r_squared': -1.8729995933796886,
  'e_distance': 87.49078758452193,
  'mmd_distance': 0.21997403666989754},
 'A549_Disulfiram__10000.0': {'r_squared': -1.8731945827348837,
  'e_d

In [15]:
len([el for el in test_data_source.keys() if el.startswith("K562")])

353

In [16]:
len([el for el in test_data_source.keys() if el.startswith("MCF7")])

619

In [17]:
len([el for el in test_data_source.keys() if el.startswith("A549")])

303