In [1]:
import functools
import os
import pickle
from typing import Literal, Optional, Iterable

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import joypy
import numpy as np
import optax
import pandas as pd
import scanpy as sc
import yaml
from flax.core.frozen_dict import FrozenDict
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm, genot
from ott.neural.networks.layers import time_encoder
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm

from ot_pert.metrics import compute_metrics, compute_mean_metrics
from ot_pert.nets.nets import CondVelocityField
from ot_pert.nets.nets import VelocityFieldWithAttention

In [2]:
# Directory holding the pre-split combosciplex AnnData files; change this one constant to
# point the notebook at another copy of the data.
DATA_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex"

adata_train_path = f"{DATA_DIR}/adata_train_300.h5ad"
adata_test_path = f"{DATA_DIR}/adata_test_300.h5ad"
adata_ood_path = f"{DATA_DIR}/adata_ood_300.h5ad"


In [3]:
# Load the three splits (train / test / out-of-distribution), in that order.
adata_train, adata_test, adata_ood = (
    sc.read_h5ad(split_path)
    for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)
   

In [4]:
# ``obsm`` key of the per-cell embedding the model operates in.
OBSM_KEY_DATA = "X_pca"
# ``obsm`` keys of the two per-cell condition representations (first and second drug slot).
OBSM_KEY_COND_1, OBSM_KEY_COND_2 = "ecfp_drug_1", "ecfp_drug_2"

In [5]:
def _to_dense(matrix):
    """Return ``matrix`` as a dense ``np.ndarray``, whether it is scipy-sparse or already dense."""
    # ``.toarray()`` works for every scipy sparse type; the ``.A`` shorthand fails on dense
    # arrays and has been removed for sparse matrices in recent SciPy releases.
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


def load_data(adata):
    """Split ``adata`` into per-condition source/target arrays for evaluation.

    Control cells (``obs["condition"] == "control"``) form the source population and are
    shared by every other condition. For each non-control category of ``obs["condition"]``
    that has at least one cell in ``adata``, this collects:

    * the target cells in the ``obsm[OBSM_KEY_DATA]`` embedding and as dense ``X`` rows,
    * a condition tensor of shape ``(n_control, 2, d)``: the condition's two
      ``obsm[OBSM_KEY_COND_1]`` / ``obsm[OBSM_KEY_COND_2]`` vectors stacked and repeated
      once per control cell.

    Args:
        adata: AnnData with a categorical ``obs["condition"]`` column, the obsm keys above
            and ``uns["rank_genes_groups_cov_all"]``.

    Returns:
        dict with keys ``"source"``, ``"target"``, ``"source_decoded"``, ``"target_decoded"``
        and ``"conditions"`` (each a dict keyed by condition name), plus ``"deg_dict"``: the
        entries of ``uns["rank_genes_groups_cov_all"]`` for the conditions that were kept.

    Raises:
        AssertionError: if the cells of one condition do not all share the same condition
            vectors.
    """
    conditions = adata.obs["condition"]

    control = adata[conditions == "control"]
    source = control.obsm[OBSM_KEY_DATA]
    # Densified like the targets so decoded source/target arrays have the same type.
    source_decoded = _to_dense(control.X)

    data_source = {}
    data_target = {}
    data_source_decoded = {}
    data_target_decoded = {}
    data_conditions = {}

    for cond in conditions.cat.categories:
        if cond == "control":
            continue
        subset = adata[conditions == cond]
        if subset.n_obs == 0:
            # Category declared but absent from this split: nothing to evaluate, and
            # ``condition_1[0]`` below would raise IndexError.
            continue

        condition_1 = subset.obsm[OBSM_KEY_COND_1]
        condition_2 = subset.obsm[OBSM_KEY_COND_2]
        # All cells of one condition must carry identical condition vectors, so the first
        # row can stand in for the whole condition.
        assert np.all(condition_1 == condition_1[0]), f"{cond}: cells disagree on {OBSM_KEY_COND_1}"
        assert np.all(condition_2 == condition_2[0]), f"{cond}: cells disagree on {OBSM_KEY_COND_2}"

        # (1, 2, d): both condition vectors stacked, then repeated for every control cell.
        stacked = np.stack([condition_1[0], condition_2[0]])[None]
        data_conditions[cond] = np.tile(stacked, (len(source), 1, 1))

        data_source[cond] = source
        data_target[cond] = subset.obsm[OBSM_KEY_DATA]
        data_source_decoded[cond] = source_decoded
        data_target_decoded[cond] = _to_dense(subset.X)

    deg_dict = {k: v for k, v in adata.uns["rank_genes_groups_cov_all"].items() if k in data_conditions}

    return {
        "source": data_source,
        "target": data_target,
        "source_decoded": data_source_decoded,
        "target_decoded": data_target_decoded,
        "conditions": data_conditions,
        "deg_dict": deg_dict,
    }
    

In [6]:
# Evaluation inputs for the held-out test split and the out-of-distribution split.
test_data, ood_data = (load_data(split) for split in (adata_test, adata_ood))
    

In [7]:
# Hyperparameters used to rebuild the network before loading the trained checkpoint below.
# NOTE(review): presumably these must match the training run, otherwise the saved
# parameters will not fit the network — confirm against the run's config.
# NOTE(review): epsilon / tau_a / tau_b are not read by any cell in this notebook (the
# model is built with match_fn=None); they appear to be kept for reference only.
yaml_config = """
num_heads: 4
qkv_feature_dim: 16
max_seq_length: 2
hidden_dims: [1024, 1024, 1024]
output_dims: [1024, 1024, 1024, 1024]
condition_dims: [256, 256, 256]
time_dims: [512, 512, 512]
time_n_freqs: 1024
flow_noise: 1.0
learning_rate: 0.00005
multi_steps: 20
epsilon: 0.01
tau_a: 0.999
tau_b: 0.999
dropout_rate: 0.1
"""


In [8]:
# Parse the hyperparameter YAML with the safe loader (plain dicts/lists/scalars only).
model_config = yaml.load(yaml_config, Loader=yaml.SafeLoader)

In [9]:
# Echo the parsed hyperparameters.
model_config


{'num_heads': 4,
 'qkv_feature_dim': 16,
 'max_seq_length': 2,
 'hidden_dims': [1024, 1024, 1024],
 'output_dims': [1024, 1024, 1024, 1024],
 'condition_dims': [256, 256, 256],
 'time_dims': [512, 512, 512],
 'time_n_freqs': 1024,
 'flow_noise': 1.0,
 'learning_rate': 5e-05,
 'multi_steps': 20,
 'epsilon': 0.01,
 'tau_a': 0.999,
 'tau_b': 0.999,
 'dropout_rate': 0.1}

In [10]:
# output_dim: width of the velocity field's last output layer (appended to output_dims).
# condition_dim: passed as ``condition_dim`` to OTFlowMatching.
# NOTE(review): hard-coded; presumably equal to the X_pca width and the per-slot condition
# vector width — confirm against the data.
output_dim, condition_dim = 300, 1024

In [11]:
# Velocity field: architecture sizes and dropout come from model_config; the final output
# layer is widened to ``output_dim``.
time_encoder_fn = functools.partial(
    time_encoder.cyclical_time_encoder,
    n_freqs=model_config["time_n_freqs"],
)
vf = VelocityFieldWithAttention(
    num_heads=model_config["num_heads"],
    qkv_feature_dim=model_config["qkv_feature_dim"],
    max_seq_length=model_config["max_seq_length"],
    hidden_dims=model_config["hidden_dims"],
    time_dims=model_config["time_dims"],
    output_dims=[*model_config["output_dims"], output_dim],
    condition_dims=model_config["condition_dims"],
    dropout_rate=model_config["dropout_rate"],
    time_encoder=time_encoder_fn,
)

# Adam wrapped in optax.MultiSteps (gradient accumulation over ``multi_steps`` steps).
# NOTE(review): no visible cell trains the model, so this is presumably only needed to
# construct OTFlowMatching.
optimizer = optax.MultiSteps(optax.adam(model_config["learning_rate"]), model_config["multi_steps"])

model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(model_config["flow_noise"]),
    match_fn=None,
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optimizer,
)


2024-06-07 09:28:59.729865: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [12]:
load_path = "/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/combosciplex/confused-feather-2848_model.pkl"

# Read the pickled parameter tree of the trained run.
# NOTE: pickle.load can execute arbitrary code — only load checkpoints from trusted locations.
with open(load_path, mode="rb") as checkpoint_file:
    loaded_params = pickle.load(checkpoint_file)

In [13]:
new_params = FrozenDict(loaded_params)

# Install the checkpoint weights into the model: ``replace`` produces an updated state
# object, which is assigned back to ``model.vf_state``.
if not hasattr(model.vf_state, "replace"):
    raise ValueError(
        "Cannot load checkpoint parameters: model.vf_state "
        f"({type(model.vf_state).__name__}) has no `replace` method."
    )
model.vf_state = model.vf_state.replace(params=new_params)

In [14]:
# Apply model.transport to each condition's control (source) cells together with that
# condition's tensor; both dicts are keyed by condition, so tree_map pairs them up.
predictions_test = jtu.tree_map(model.transport, test_data["source"], test_data["conditions"])

In [15]:
# Same as for the test split, for the out-of-distribution conditions.
predictions_ood = jtu.tree_map(model.transport, ood_data["source"], ood_data["conditions"])

In [16]:
def reconstruct_data(embedding, projection_matrix, mean_to_add):
    """Map embedded points back to feature space.

    Computes ``embedding @ projection_matrix.T + mean_to_add``.

    Args:
        embedding: array of shape ``(n, k)``.
        projection_matrix: array with ``k`` columns, one row per output feature.
        mean_to_add: offset broadcast-added to every reconstructed row.

    Returns:
        The reconstructed array, one row per row of ``embedding``.
    """
    loadings_t = projection_matrix.T
    reconstructed = np.matmul(embedding, loadings_t)
    return reconstructed + mean_to_add


In [17]:
# Decoder from the model's embedding space back to the variables of adata_train, using the
# training split's projection matrix ("PCs") and mean ("X_train_mean").
train_pcs = adata_train.varm["PCs"]
# Transposed so it broadcasts across cells; presumably stored as an (n_vars, 1) column.
train_mean = adata_train.varm["X_train_mean"].T
reconstruct_data_fn = functools.partial(reconstruct_data, projection_matrix=train_pcs, mean_to_add=train_mean)

In [18]:
# Decode both prediction sets (dicts keyed by condition) back to the full variable space.
predictions_test_decoded, predictions_ood_decoded = (
    jtu.tree_map(reconstruct_data_fn, preds) for preds in (predictions_test, predictions_ood)
)

In [19]:
# Test split: per-condition metrics in embedding space, then summarised by
# compute_mean_metrics under a split-specific key prefix.
test_metrics_encoded = jtu.tree_map(compute_metrics, test_data["target"], predictions_test)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

# Same in decoded (full variable) space.
test_metrics_decoded = jtu.tree_map(compute_metrics, test_data["target_decoded"], predictions_test_decoded)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [20]:
# OOD split: per-condition metrics in embedding space, then summarised by
# compute_mean_metrics under a split-specific key prefix.
ood_metrics_encoded = jtu.tree_map(compute_metrics, ood_data["target"], predictions_ood)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

# Same in decoded (full variable) space.
ood_metrics_decoded = jtu.tree_map(compute_metrics, ood_data["target_decoded"], predictions_ood_decoded)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [21]:
# DEG entries from the *training* AnnData, restricted to the conditions present in each split.
# NOTE(review): load_data() also returns a "deg_dict" built from each split's own ``uns``;
# confirm the train-side lists are the intended ones here.
all_degs = adata_train.uns["rank_genes_groups_cov_all"]
test_deg_dict = {cond: genes for cond, genes in all_degs.items() if cond in test_data["conditions"]}
ood_deg_dict = {cond: genes for cond, genes in all_degs.items() if cond in ood_data["conditions"]}

In [22]:
# Mean decoded-space metrics on the test split.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.9854142226833035,
 'decoded_test_sinkhorn_div_1': 115.47785362830528,
 'decoded_test_sinkhorn_div_10': 79.89690457857571,
 'decoded_test_sinkhorn_div_100': 3.126964862530048,
 'decoded_test_e_distance': 2.6889577248873056,
 'decoded_test_mmd': 0.016066037292163278}

In [23]:
def get_mask(x, y, var_names):
    """Keep only the columns of ``x`` whose name in ``var_names`` appears in ``y``.

    Args:
        x: 2-D array with one column per entry of ``var_names``.
        y: collection of names to keep (here, one condition's DEGs).
        var_names: column names of ``x`` (list, array or pandas Index).

    Returns:
        ``x`` restricted to the selected columns, in ``var_names`` order; shape
        ``(n, 0)`` when nothing matches.
    """
    # Vectorised membership test instead of a Python loop doing ``gene in y`` per column,
    # which is O(len(var_names) * len(y)).
    keep = np.isin(np.asarray(var_names), np.asarray(y))
    return x[:, keep]

# Bind the gene names once; each condition's DEG list then selects matching columns.
mask_fn = functools.partial(get_mask, var_names=adata_train.var_names)

# Decoded predictions / targets restricted to each condition's DEGs — test split.
prediction_decoded_test_deg = jtu.tree_map(mask_fn, predictions_test_decoded, test_deg_dict)
target_decoded_test_deg = jtu.tree_map(mask_fn, test_data["target_decoded"], test_deg_dict)
# ... and OOD split.
prediction_decoded_ood_deg = jtu.tree_map(mask_fn, predictions_ood_decoded, ood_deg_dict)
target_decoded_ood_deg = jtu.tree_map(mask_fn, ood_data["target_decoded"], ood_deg_dict)
            

In [24]:

# The previous cell already applied mask_fn to exactly these inputs; reuse those arrays
# under the names the metric cell expects instead of masking everything a second time.
predictions_test_deg = prediction_decoded_test_deg
target_test_deg = target_decoded_test_deg

predictions_ood_deg = prediction_decoded_ood_deg
target_ood_deg = target_decoded_ood_deg

In [25]:
# Metrics restricted to each condition's DEGs. The test and OOD summaries are stored in
# separate variables so neither overwrites the other.
test_metrics_deg = jtu.tree_map(compute_metrics, target_test_deg, predictions_test_deg)
mean_test_metrics_deg = compute_mean_metrics(test_metrics_deg, prefix="deg_test_")

ood_metrics_deg = jtu.tree_map(compute_metrics, target_ood_deg, predictions_ood_deg)
mean_ood_metrics_deg = compute_mean_metrics(ood_metrics_deg, prefix="deg_ood_")

# Back-compat alias: this name used to end up holding the OOD summary.
mean_metrics_deg = mean_ood_metrics_deg
            

In [26]:
# Mean decoded-space metrics on the OOD split.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.9586435646612876,
 'decoded_ood_sinkhorn_div_1': 114.75249633789062,
 'decoded_ood_sinkhorn_div_10': 69.23498077392578,
 'decoded_ood_sinkhorn_div_100': 4.369729614257812,
 'decoded_ood_e_distance': 7.926962827305201,
 'decoded_ood_mmd': 0.013214902879714496}

In [27]:
# Destination directory for the per-condition metric CSVs written in the next cell.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/otfm"

In [28]:
# Persist every per-condition metric table as CSV. pd.DataFrame.from_dict uses the outer
# keys (conditions) as columns, assuming compute_metrics returns a dict per condition.
# The directory is created first so the export does not fail on a fresh output location.
os.makedirs(output_dir, exist_ok=True)

metric_tables = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "ood_metrics_deg": ood_metrics_deg,
    "test_metrics_deg": test_metrics_deg,
}
for table_name, per_condition in metric_tables.items():
    pd.DataFrame.from_dict(per_condition).to_csv(os.path.join(output_dir, f"{table_name}.csv"))