In [1]:
import scanpy as sc
import numpy as np
import functools
import jax
from ot_pert.metrics import compute_metrics, compute_mean_metrics

In [2]:
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_30.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test_30.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_ood_30.h5ad"

In [3]:
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)




In [4]:
adata_train.obs["split"].value_counts()

split
train    51882
Name: count, dtype: int64

In [5]:
adata_test.obs["split"].value_counts()

split
test    3100
Name: count, dtype: int64

In [6]:
adata_ood.obs["split"].value_counts()

split
ood    8896
Name: count, dtype: int64

In [7]:
adata_ood.obs["condition"].value_counts()

condition
Givinostat+SRT1720         2260
Cediranib+PCI-34051        2161
Panobinostat+PCI-34051     1814
Panobinostat+Crizotinib    1641
SRT2104+Alvespimycin        520
control                     500
Name: count, dtype: int64

In [8]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    return np.matmul(embedding, projection_matrix.T) + mean_to_add

def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    return np.matmul(data - mean_to_subtract, projection_matrix)

In [9]:
project_data_fn = functools.partial(project_data, projection_matrix = adata_train.varm["PCs"], mean_to_subtract=adata_train.varm["X_train_mean"].T)

In [10]:
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [11]:
train_data_target_encoded = {}
train_data_target_decoded = {}
train_data_target_encoded_predicted = {}
train_data_target_decoded_predicted = {}
pred_id = adata_train[adata_train.obs["condition"]=="control"].X.A
for cond in adata_train.obs["condition"].cat.categories:
    if cond == "control":
        continue
    train_data_target_encoded[cond] = adata_train[adata_train.obs["condition"]==cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    train_data_target_decoded[cond] = adata_train[adata_train.obs["condition"]==cond].X.A
    train_data_target_decoded_predicted[cond] = pred_id
    train_data_target_encoded_predicted[cond] = project_data_fn(pred_id)

pred_id = adata_test[adata_test.obs["condition"]=="control"].X.A
test_data_target_encoded = {}
test_data_target_decoded = {}
test_data_target_encoded_predicted = {}
test_data_target_decoded_predicted = {}
for cond in adata_test.obs["condition"].cat.categories:
    if cond == "control":
        continue
    test_data_target_encoded[cond] = adata_test[adata_test.obs["condition"]==cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    test_data_target_decoded[cond] = adata_test[adata_test.obs["condition"]==cond].X.A
    test_data_target_decoded_predicted[cond] = pred_id
    test_data_target_encoded_predicted[cond] = project_data_fn(pred_id)

pred_id = adata_ood[adata_ood.obs["condition"]=="control"].X.A
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_ood.obs["condition"].cat.categories:
    if cond == "control":
        continue
    ood_data_target_encoded[cond] = adata_ood[adata_ood.obs["condition"]==cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    ood_data_target_decoded[cond] = adata_ood[adata_ood.obs["condition"]==cond].X.A
    ood_data_target_decoded_predicted[cond] = pred_id
    ood_data_target_encoded_predicted[cond] = project_data_fn(pred_id)



In [12]:
train_metrics_encoded = jax.tree_util.tree_map(compute_metrics, train_data_target_encoded, train_data_target_encoded_predicted)
mean_train_metrics_encoded = compute_mean_metrics(train_metrics_encoded, prefix="encoded_train_")

train_metrics_decoded = jax.tree_util.tree_map(compute_metrics, train_data_target_decoded, train_data_target_decoded_predicted)
mean_train_metrics_decoded = compute_mean_metrics(train_metrics_decoded, prefix="decoded_train_")

2024-04-19 08:55:17.576377: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [13]:
test_metrics_encoded = jax.tree_util.tree_map(compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(compute_metrics, test_data_target_decoded, test_data_target_decoded_predicted)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [14]:
ood_metrics_encoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")




In [15]:
train_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in train_data_target_decoded_predicted.keys()}
test_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in test_data_target_decoded_predicted.keys()}
ood_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in ood_data_target_decoded_predicted.keys()}

In [16]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]

ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)

train_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, train_data_target_decoded_predicted, train_deg_dict)
train_deg_target_decoded = jax.tree_util.tree_map(get_mask, train_data_target_decoded, test_deg_dict)

In [17]:
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")

deg_train_metrics = jax.tree_util.tree_map(compute_metrics, train_deg_target_decoded, train_deg_target_decoded_predicted)
deg_mean_train_metrics = compute_mean_metrics(deg_train_metrics, prefix="deg_train_")

In [18]:
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.380471824796215,
 'deg_ood_sinkhorn_div_1': 32.49099349975586,
 'deg_ood_sinkhorn_div_10': 20.674528503417967,
 'deg_ood_sinkhorn_div_100': 20.04779815673828,
 'deg_ood_e_distance': 39.97579202417754,
 'deg_ood_mmd': 0.08912052417003327}

In [19]:
deg_mean_test_metrics

{'deg_test_r_squared': 0.42197210858369394,
 'deg_test_sinkhorn_div_1': 32.2918458535121,
 'deg_test_sinkhorn_div_10': 19.429043182959923,
 'deg_test_sinkhorn_div_100': 18.320143112769493,
 'deg_test_e_distance': 36.493777542369685,
 'deg_test_mmd': 0.09369762841545004}

In [20]:
deg_mean_train_metrics

{'deg_train_r_squared': 0.42610602479540144,
 'deg_train_sinkhorn_div_1': 30.306968743984516,
 'deg_train_sinkhorn_div_10': 18.768293600815994,
 'deg_train_sinkhorn_div_100': 18.177710973299465,
 'deg_train_e_distance': 36.24290576049266,
 'deg_train_mmd': 0.08588872974007322}

In [21]:
mean_train_metrics_decoded

{'decoded_train_r_squared': 0.6876376148189827,
 'decoded_train_sinkhorn_div_1': 175.35013404259314,
 'decoded_train_sinkhorn_div_10': 127.87665616548978,
 'decoded_train_sinkhorn_div_100': 30.37559274526743,
 'decoded_train_e_distance': 59.652706027587904,
 'decoded_train_mmd': 0.034737392420857736}

In [22]:
mean_test_metrics_encoded

{'encoded_test_r_squared': -0.8760768311432954,
 'encoded_test_sinkhorn_div_1': 45.95963635811439,
 'encoded_test_sinkhorn_div_10': 31.896249477679913,
 'encoded_test_sinkhorn_div_100': 29.60800735767071,
 'encoded_test_e_distance': 58.87440579010694,
 'encoded_test_mmd': 0.11540636236483498}

In [23]:
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.6781521682861769,
 'decoded_test_sinkhorn_div_1': 179.88092275766226,
 'decoded_test_sinkhorn_div_10': 144.52333098191482,
 'decoded_test_sinkhorn_div_100': 33.009776188777046,
 'decoded_test_e_distance': 61.68167901269651,
 'decoded_test_mmd': 0.04447886694964587}

In [24]:
mean_ood_metrics_encoded

{'encoded_ood_r_squared': -0.2991388146995564,
 'encoded_ood_sinkhorn_div_1': 46.984486961364745,
 'encoded_ood_sinkhorn_div_10': 34.5821044921875,
 'encoded_ood_sinkhorn_div_100': 32.87112197875977,
 'encoded_ood_e_distance': 65.39495653164795,
 'encoded_ood_mmd': 0.1113489583832942}

In [25]:
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.6786619478903313,
 'decoded_ood_sinkhorn_div_1': 178.64292907714844,
 'decoded_ood_sinkhorn_div_10': 133.90704956054688,
 'decoded_ood_sinkhorn_div_100': 34.36363220214844,
 'decoded_ood_e_distance': 67.16140559518308,
 'decoded_ood_mmd': 0.03772059304184831}

In [26]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/identity"

In [27]:
import os
import pandas as
pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, "ood_metrics_encoded.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, "ood_metrics_decoded.csv"))
pd.DataFrame.from_dict(test_metrics_encoded).to_csv(os.path.join(output_dir, "test_metrics_encoded.csv"))
pd.DataFrame.from_dict(test_metrics_decoded).to_csv(os.path.join(output_dir, "test_metrics_decoded.csv"))
pd.DataFrame.from_dict(train_metrics_encoded).to_csv(os.path.join(output_dir, "train_metrics_encoded.csv"))
pd.DataFrame.from_dict(train_metrics_decoded).to_csv(os.path.join(output_dir, "train_metrics_decoded.csv"))

SyntaxError: invalid syntax (2995550218.py, line 2)