In [1]:
import scanpy as sc
import numpy as np
import functools
import jax
from ot_pert.metrics import compute_metrics, compute_mean_metrics

In [2]:
adata_train_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_30.h5ad"
adata_test_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_30.h5ad"
adata_ood_path = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_30.h5ad"

In [3]:
adata_train = sc.read(adata_train_path)
adata_test = sc.read(adata_test_path)
adata_ood = sc.read(adata_ood_path)




In [4]:
adata_train.obs["split"].value_counts()

split
train    278826
Name: count, dtype: int64

In [5]:
adata_test.obs["split"].value_counts()

split
test    129000
Name: count, dtype: int64

In [6]:
adata_ood.obs["split"].value_counts()

split
ood    58731
Name: count, dtype: int64

In [7]:
adata_ood.obs["condition"].value_counts()

condition
MCF7_Mesna__100.0                                          707
MCF7_GSK-LSD1_2HCl_10.0                                    557
MCF7_Valproic_acid_sodium_salt_(Sodium_valproate)_100.0    543
MCF7_Tubastatin_A_HCl_10.0                                 543
MCF7_Tazemetostat_(EPZ-6438)_10000.0                       539
                                                          ... 
A549_Nilotinib_(AMN-107)_10000.0                           203
A549_WHI-P154_10000.0                                      202
A549_Nilotinib_(AMN-107)_100.0                             201
K562_A-366_100.0                                           201
A549_TGX-221_1000.0                                        200
Name: count, Length: 199, dtype: int64

In [8]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    return np.matmul(embedding, projection_matrix.T) + mean_to_add

def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    return np.matmul(data - mean_to_subtract, projection_matrix)

In [9]:
project_data_fn = functools.partial(project_data, projection_matrix = adata_train.varm["PCs"], mean_to_subtract=adata_train.varm["X_train_mean"].T)

In [10]:
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [11]:
train_data_target_encoded = {}
train_data_target_decoded = {}
train_data_target_encoded_predicted = {}
train_data_target_decoded_predicted = {}
for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_train[adata_train.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1  
    pred_id = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A
    train_data_target_encoded[cond] = adata_train[adata_train.obs["condition"]==cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    train_data_target_decoded[cond] = adata_train[adata_train.obs["condition"]==cond].X.A
    train_data_target_decoded_predicted[cond] = pred_id
    train_data_target_encoded_predicted[cond] = project_data_fn(pred_id)

test_data_target_encoded = {}
test_data_target_decoded = {}
test_data_target_encoded_predicted = {}
test_data_target_decoded_predicted = {}
for cond in adata_test.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_train[adata_train.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1  
    pred_id = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A

    test_data_target_encoded[cond] = adata_test[adata_test.obs["condition"]==cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    test_data_target_decoded[cond] = adata_test[adata_test.obs["condition"]==cond].X.A
    test_data_target_decoded_predicted[cond] = pred_id
    test_data_target_encoded_predicted[cond] = project_data_fn(pred_id)
    
ood_data_target_encoded = {}
ood_data_target_decoded = {}
ood_data_target_encoded_predicted = {}
ood_data_target_decoded_predicted = {}
for cond in adata_train.obs["condition"].cat.categories:
    if "Vehicle" in cond:
        continue
    src_str = list(adata_train[adata_train.obs["condition"] == cond].obs["cell_type"].unique())
    assert len(src_str) == 1  
    pred_id = adata_train[adata_train.obs["condition"] == src_str[0] + "_Vehicle_0.0"].X.A
    ood_data_target_encoded[cond] = adata_ood[adata_ood.obs["condition"]==cond].obsm[OBSM_KEY_DATA_EMBEDDING]
    ood_data_target_decoded[cond] = adata_ood[adata_ood.obs["condition"]==cond].X.A
    ood_data_target_decoded_predicted[cond] = pred_id
    ood_data_target_encoded_predicted[cond] = project_data_fn(pred_id)
    


In [18]:
is_nan = jax.tree_util.tree_map(lambda x: np.isnan(x).sum(), train_data_target_encoded_predicted)

In [21]:
train_data_target_encoded_predicted

{'A549_ABT-737_100.0': array([], shape=(0, 30), dtype=float64),
 'A549_AC480_(BMS-599626)_1000.0': array([], shape=(0, 30), dtype=float64),
 'A549_AC480_(BMS-599626)_10000.0': array([], shape=(0, 30), dtype=float64),
 'A549_AG-14361_10.0': array([], shape=(0, 30), dtype=float64),
 'A549_AG-14361_100.0': array([], shape=(0, 30), dtype=float64),
 'A549_AG-490_(Tyrphostin_B42)_100.0': array([], shape=(0, 30), dtype=float64),
 'A549_AG-490_(Tyrphostin_B42)_10000.0': array([], shape=(0, 30), dtype=float64),
 'A549_AICAR_(Acadesine)_10.0': array([], shape=(0, 30), dtype=float64),
 'A549_AICAR_(Acadesine)_1000.0': array([], shape=(0, 30), dtype=float64),
 'A549_AICAR_(Acadesine)_10000.0': array([], shape=(0, 30), dtype=float64),
 'A549_AR-42_10.0': array([], shape=(0, 30), dtype=float64),
 'A549_AR-42_100.0': array([], shape=(0, 30), dtype=float64),
 'A549_AZ_960_100.0': array([], shape=(0, 30), dtype=float64),
 'A549_Alendronate_sodium_trihydrate_10.0': array([], shape=(0, 30), dtype=float64

In [19]:
np.unique(list(is_nan.values()))

array([0])

In [20]:
train_metrics_encoded = jax.tree_util.tree_map(compute_metrics, train_data_target_encoded, train_data_target_encoded_predicted)
mean_train_metrics_encoded = compute_mean_metrics(train_metrics_encoded, prefix="encoded_train_")

#train_metrics_decoded = jax.tree_util.tree_map(compute_metrics, train_data_target_decoded, train_data_target_decoded_predicted)
#mean_train_metrics_decoded = compute_mean_metrics(train_metrics_decoded, prefix="decoded_train_")

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


ValueError: Input contains NaN.

In [None]:
test_metrics_encoded = jax.tree_util.tree_map(compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(compute_metrics, test_data_target_decoded, test_data_target_decoded_predicted)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [None]:
ood_metrics_encoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")




In [None]:
train_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in train_data_target_decoded_predicted.keys()}
test_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in test_data_target_decoded_predicted.keys()}
ood_deg_dict = {k: v for k,v in adata_train.uns['rank_genes_groups_cov_all'].items() if k in ood_data_target_decoded_predicted.keys()}

In [None]:
def get_mask(x, y):
    return x[:, [gene in y for gene in adata_train.var_names]]
    
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)

train_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, train_data_target_decoded_predicted, train_deg_dict)
train_deg_target_decoded = jax.tree_util.tree_map(get_mask, train_data_target_decoded, test_deg_dict)

In [None]:
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")

deg_train_metrics = jax.tree_util.tree_map(compute_metrics, train_deg_target_decoded, train_deg_target_decoded_predicted)
deg_mean_train_metrics = compute_mean_metrics(deg_train_metrics, prefix="deg_train_")

In [None]:
deg_mean_ood_metrics

In [None]:
deg_mean_test_metrics

In [None]:
deg_mean_train_metrics

In [None]:
mean_train_metrics_decoded

In [None]:
mean_test_metrics_encoded

In [None]:
mean_test_metrics_decoded

In [None]:
mean_ood_metrics_encoded

In [None]:
mean_ood_metrics_decoded

In [None]:
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/results/identity"

In [None]:
import os
import pandas as pd
pd.DataFrame.from_dict(ood_metrics_encoded).to_csv(os.path.join(output_dir, "ood_metrics_encoded.csv"))
pd.DataFrame.from_dict(ood_metrics_decoded).to_csv(os.path.join(output_dir, "ood_metrics_decoded.csv"))
pd.DataFrame.from_dict(test_metrics_encoded).to_csv(os.path.join(output_dir, "test_metrics_encoded.csv"))
pd.DataFrame.from_dict(test_metrics_decoded).to_csv(os.path.join(output_dir, "test_metrics_decoded.csv"))
pd.DataFrame.from_dict(train_metrics_encoded).to_csv(os.path.join(output_dir, "train_metrics_encoded.csv"))
pd.DataFrame.from_dict(train_metrics_decoded).to_csv(os.path.join(output_dir, "train_metrics_decoded.csv"))