In [1]:
import functools
from typing import Literal, Optional, Iterable

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
import scanpy as sc
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm, genot
from ott.neural.networks.layers import time_encoder
from ot_pert.nets.nets import CondVelocityField
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import joypy
import pandas as pd
import pickle
import yaml
from ot_pert.nets.nets import VelocityFieldWithAttention
from ot_pert.metrics import compute_metrics, compute_mean_metrics

In [2]:
# Location of the preprocessed sci-Plex splits. SPLIT_SUFFIX is presumably the PCA
# dimensionality the split files were produced with (matches the 30-dim model below) — TODO confirm.
DATA_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex"
SPLIT_SUFFIX = "30"

adata_train_path = f"{DATA_DIR}/adata_train_{SPLIT_SUFFIX}.h5ad"
adata_test_path = f"{DATA_DIR}/adata_test_{SPLIT_SUFFIX}.h5ad"
adata_ood_path = f"{DATA_DIR}/adata_ood_{SPLIT_SUFFIX}.h5ad"


In [3]:
# Read the train / test / OOD splits (in that order) from their .h5ad files.
adata_train, adata_test, adata_ood = (
    sc.read_h5ad(split_path)
    for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)


In [4]:
# `.obsm` keys: the per-cell condition embedding fed to the model, and the embedded expression data.
obsm_key_cond, obsm_key_data = "ecfp_cell_line_logdose_more_dose", "X_pca"

In [5]:
# Dimensions of the embedded expression space (source and target cells are both read from
# `obsm[obsm_key_data]`) and of the condition embedding (`obsm[obsm_key_cond]`).
source_dim = 30
target_dim = 30
condition_dim = 1424

# Fail early if the hard-coded dims drift from the data, rather than later at checkpoint
# loading or inside `model.transport`.
_data_width = adata_train.obsm[obsm_key_data].shape[1]
_cond_width = adata_train.obsm[obsm_key_cond].shape[1]
assert _data_width == source_dim == target_dim, (
    f"obsm[{obsm_key_data!r}] has {_data_width} columns, expected {target_dim}"
)
assert _cond_width == condition_dim, (
    f"obsm[{obsm_key_cond!r}] has {_cond_width} columns, expected {condition_dim}"
)

In [6]:
# Model hyper-parameters as an inline YAML string; parsed with `yaml.safe_load` in the next cell.
# NOTE(review): learning_rate, flow_noise, epsilon, tau_a, tau_b and dropout_rate are not read
# when the model is built below — confirm they match the run that produced the checkpoint,
# or drop them.
yaml_config = """
hidden_dims: [1024, 1024, 1024]
output_dims: [1024, 1024, 1024]
condition_dims: [4096, 4096, 4096, 4096]
time_dims: [512, 512, 512]
time_n_freqs: 1024
flow_noise: 1.0
learning_rate: 0.00005
multi_steps: 20
epsilon: 0.01
tau_a: 0.999
tau_b: 0.999
dropout_rate: 0.2
"""


In [7]:
# Parse the inline YAML into a plain dict of hyper-parameters.
model_config: dict = yaml.safe_load(yaml_config)

In [8]:
# Conditional velocity field. The architecture must match the one the checkpoint was trained
# with, otherwise the shape check in the parameter-loading cell fails.
vf = CondVelocityField(
    hidden_dims=model_config["hidden_dims"],
    output_dims=model_config["output_dims"] + [target_dim],
    condition_dims=model_config["condition_dims"],
    time_dims=model_config["time_dims"],
    time_encoder=functools.partial(
        time_encoder.cyclical_time_encoder, n_freqs=model_config["time_n_freqs"]
    ),
)

# NOTE(review): `model_config` says learning_rate=5e-5 and flow_noise=1.0, but 1e-4 and 0 are
# used here. Optimizer and flow presumably only matter for training, not for `model.transport`
# as used in this notebook — confirm which values the checkpoint was trained with.
model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(0),
    match_fn=None,
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optax.MultiSteps(optax.adam(learning_rate=1e-4), model_config["multi_steps"]),
)


2024-06-18 14:51:20.352365: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [9]:
# Trained velocity-field parameters for this split.
load_path = "/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/sciplex_biolord_split/cerulean-deluge-3112_model.pkl"

# NOTE: `pickle.load` can execute arbitrary code — only load checkpoints from a trusted location.
with open(load_path, "rb") as f:
    loaded_params = pickle.load(f)

# `tree_map` raises if the two parameter trees differ in structure; each leaf of `check_load`
# is True iff the freshly initialised and the loaded array have the same shape.
check_load = jtu.tree_map(lambda x, y: x.shape == y.shape, model.vf_state.params, loaded_params)


def all_values_true(d):
    """Return True iff every leaf of the (possibly nested) pytree ``d`` is the literal ``True``.

    Leaves are collected with ``jax.tree_util.tree_leaves``, so any registered container
    (``dict``, flax ``FrozenDict``, lists, tuples) is traversed — not only plain ``dict``.
    """
    return all(leaf is True for leaf in jtu.tree_leaves(d))


assert all_values_true(check_load), "checkpoint parameter shapes do not match the model"
# Install the loaded parameters. Without this the model keeps its random initialisation and
# every prediction / metric computed below comes from an untrained velocity field.
model.vf_state = model.vf_state.replace(params=loaded_params)
all_values_true(check_load)

True

In [10]:
def _to_dense(matrix):
    """Return ``matrix`` as a dense NumPy array.

    ``adata.X`` may be a SciPy sparse matrix/array or already dense. ``toarray()`` works for
    SciPy sparse types, whereas ``.A`` is deprecated/removed in recent SciPy releases and does
    not exist on a plain ``np.ndarray``.
    """
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


def load_data(adata, obsm_key_data, obsm_key_cond):
    """Collect, per perturbation condition, the control (source) and perturbed (target) cells.

    For every category of ``adata.obs["condition"]`` whose name does not contain "Vehicle",
    the source cells are the cells of condition ``f"{cell_type}_Vehicle_0.0"`` for that
    condition's cell type. Categories without any cells (e.g. left over after subsetting)
    are skipped.

    Args:
        adata: AnnData with a categorical ``obs["condition"]`` and an ``obs["cell_type"]`` column.
        obsm_key_data: key in ``adata.obsm`` holding the embedded expression (e.g. ``"X_pca"``).
        obsm_key_cond: key in ``adata.obsm`` holding the per-cell condition embedding.

    Returns:
        Dict with keys ``"source"``, ``"target"``, ``"source_decoded"``, ``"target_decoded"``
        and ``"conditions"``, each mapping condition name -> array. The ``*_decoded`` entries
        are dense rows of ``adata.X``. ``conditions[cond]`` repeats the condition's embedding
        once per source cell, so it lines up row by row with ``source[cond]``.

    Raises:
        AssertionError: if a condition spans several cell types, or its cells do not all share
            the same condition embedding.
    """
    data_source = {}
    data_target = {}
    data_source_decoded = {}
    data_target_decoded = {}
    data_conditions = {}

    condition_labels = adata.obs["condition"]
    for cond in condition_labels.cat.categories:
        if "Vehicle" in cond:
            continue  # control conditions are only used as sources
        target_adata = adata[condition_labels == cond]
        if len(target_adata.obs) == 0:
            continue  # unused category: nothing to predict
        cell_types = list(target_adata.obs["cell_type"].unique())
        assert len(cell_types) == 1, f"condition {cond!r} spans several cell types: {cell_types}"
        src_str = cell_types[0] + "_Vehicle_0.0"
        source_adata = adata[condition_labels == src_str]

        conds = target_adata.obsm[obsm_key_cond]
        assert np.all(conds == conds[0]), (
            f"cells of condition {cond!r} do not share one condition embedding"
        )
        source = source_adata.obsm[obsm_key_data]

        data_source[cond] = source
        data_target[cond] = target_adata.obsm[obsm_key_data]
        data_source_decoded[cond] = _to_dense(source_adata.X)
        data_target_decoded[cond] = _to_dense(target_adata.X)
        # One copy of the shared condition embedding per source cell.
        data_conditions[cond] = np.tile(conds[0], (len(source), 1))
    return {
        "source": data_source,
        "target": data_target,
        "source_decoded": data_source_decoded,
        "target_decoded": data_target_decoded,
        "conditions": data_conditions,
    }


In [11]:
# Per-condition source / target / condition arrays for the test split and the OOD split.
test_data, ood_data = (
    load_data(split, obsm_key_data, obsm_key_cond) for split in (adata_test, adata_ood)
)
    

In [12]:
# Push each test condition's source cells through the learned flow (one transport call per
# condition; tree_map pairs `source[cond]` with `conditions[cond]`).
predictions_test = jtu.tree_map(
    lambda source, condition: model.transport(source, condition),
    test_data["source"],
    test_data["conditions"],
)

In [13]:
# Same as for the test split: one transport call per OOD condition.
predictions_ood = jtu.tree_map(
    lambda source, condition: model.transport(source, condition),
    ood_data["source"],
    ood_data["conditions"],
)

In [25]:
# Flatten the per-condition test predictions: one condition label per predicted cell, and the
# matching prediction arrays (as NumPy) in the same dict order.
conds = [cond for cond, pred in predictions_test.items() for _ in range(len(pred))]
preds = [np.asarray(pred) for pred in predictions_test.values()]




In [15]:
from anndata import AnnData

In [30]:
# Stack the per-condition prediction arrays row-wise. `np.vstack(preds)` handles conditions with
# different cell counts; `np.array(preds)` on such a ragged list raises in NumPy >= 1.24.
stacked_preds_test = np.vstack(preds)
assert stacked_preds_test.shape[0] == len(conds), "one condition label per predicted cell expected"
adata_with_preds = AnnData(
    X=stacked_preds_test,
    # String obs_names up front avoid AnnData's implicit int -> str index conversion.
    obs=pd.DataFrame({"condition": conds}, index=[str(i) for i in range(len(conds))]),
)



In [32]:
# NOTE(review): these sci-Plex predictions are written under `data/combosciplex/...` — confirm
# that directory is intended.
test_predictions_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/otfm/sciplex_biolord_30_predictions_test.h5ad"
adata_with_preds.write(test_predictions_path)

In [None]:
# Flatten the OOD predictions into one AnnData: one row per predicted cell, labelled by condition.
# Distinct names (instead of reusing `conds` / `preds`) keep the test-split cell re-runnable
# without silently picking up OOD data.
ood_conds = [cond for cond, pred in predictions_ood.items() for _ in range(len(pred))]
ood_preds = [np.asarray(pred) for pred in predictions_ood.values()]

# `np.vstack` also handles conditions with different cell counts, where `np.array(ood_preds)`
# would be a ragged array (an error in NumPy >= 1.24).
stacked_preds_ood = np.vstack(ood_preds)
assert stacked_preds_ood.shape[0] == len(ood_conds), "one condition label per predicted cell expected"
adata_with_preds_ood = AnnData(
    X=stacked_preds_ood,
    obs=pd.DataFrame({"condition": ood_conds}, index=[str(i) for i in range(len(ood_conds))]),
)

In [None]:
# NOTE(review): these sci-Plex predictions are written under `data/combosciplex/...` — confirm
# that directory is intended.
ood_predictions_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/otfm/sciplex_biolord_30_predictions_ood.h5ad"
adata_with_preds_ood.write(ood_predictions_path)

In [33]:
def reconstruct_data(embedding, projection_matrix, mean_to_add):
    """Map embedded rows back to the original feature space.

    Computes ``embedding @ projection_matrix.T + mean_to_add`` via ``np.matmul``, so inputs that
    are not NumPy arrays (e.g. jax arrays) are converted and the result is a NumPy array.
    """
    back_projected = np.matmul(embedding, projection_matrix.T)
    shifted = back_projected + mean_to_add
    return shifted


In [34]:
# Decoder from the PCA embedding back to gene space. The *training* PCs and mean are used, so
# every split is decoded with the same projection.
reconstruct_data_fn = functools.partial(
    reconstruct_data,
    mean_to_add=adata_train.varm["X_train_mean"].T,
    projection_matrix=adata_train.varm["PCs"],
)

In [35]:
# Decode OOD predictions to gene space. Keys are sorted, which is the order
# `jax.tree_util.tree_map` produces for dicts.
# Test-split decoding is disabled; uncomment to evaluate the test split as well:
# predictions_test_decoded = jtu.tree_map(reconstruct_data_fn, predictions_test)
predictions_ood_decoded = {
    cond: reconstruct_data_fn(pred) for cond, pred in sorted(predictions_ood.items())
}

In [36]:
# Test-split evaluation (embedded space and decoded gene space) is disabled in this run.
# Re-enabling it also requires `predictions_test_decoded`, whose computation is commented out in
# the decoding cell.
#test_metrics_encoded = jax.tree_util.tree_map(compute_metrics, test_data["target"], predictions_test)
#mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

#test_metrics_decoded = jax.tree_util.tree_map(compute_metrics, test_data["target_decoded"], predictions_test_decoded)
#mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [None]:
# Per-condition OOD metrics in the embedded space ("encoded") and in gene space ("decoded"),
# each summarised by `compute_mean_metrics` (presumably averaged over conditions) with a key prefix.
ood_metrics_encoded = jtu.tree_map(
    compute_metrics,
    ood_data["target"],
    predictions_ood,
)
mean_ood_metrics_encoded = compute_mean_metrics(
    ood_metrics_encoded,
    prefix="encoded_ood_",
)

ood_metrics_decoded = jtu.tree_map(
    compute_metrics,
    ood_data["target_decoded"],
    predictions_ood_decoded,
)
mean_ood_metrics_decoded = compute_mean_metrics(
    ood_metrics_decoded,
    prefix="decoded_ood_",
)

In [None]:
# Display the summarised decoded (gene-space) OOD metrics as this cell's output.
mean_ood_metrics_decoded