In [2]:
import functools

import jax
import numpy as np
import scanpy as sc

from ot_pert.metrics import compute_mean_metrics, compute_metrics

In [3]:
# Load the AnnData object that carries the CPA predictions (read further down from obsm["CPA_pred"]).
# NOTE(review): hard-coded absolute cluster path — consider a configurable data directory.
adata = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/models/cpa/combosciplex/adata_with_predictions.h5ad")

  utils.warn_names_duplicates("obs")


In [4]:
# Load the original training AnnData. Later cells read its varm["PCs"], varm["X_train_mean"]
# and the per-condition gene lists stored in uns["rank_genes_groups_cov_all"].
# NOTE(review): hard-coded absolute cluster path — consider a configurable data directory.
adata_train_orig = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train.h5ad")



In [5]:
# List the unstructured annotations available on the training object
# ("rank_genes_groups_cov_all" is the entry used for the DEG dictionaries below).
adata_train_orig.uns.keys()

dict_keys(['Drug1_colors', 'Drug2_colors', 'Well_colors', 'condition_colors', 'dendrogram_leiden', 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pathway1_colors', 'pathway2_colors', 'pathway_colors', 'pca', 'rank_genes_groups', 'rank_genes_groups_cov_all', 'split_colors', 'umap'])

In [6]:
# Full matrix of CPA predictions for every cell in `adata`.
# NOTE(review): `preds` is not referenced again in the visible cells — candidate for removal.
preds = adata.obsm["CPA_pred"]

In [7]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Map an embedding back to the original feature space.

    Computes ``embedding @ projection_matrix.T + mean_to_add``; this is the
    inverse of ``project_data`` when ``projection_matrix`` has orthonormal
    columns.

    Parameters
    ----------
    embedding
        Array whose last dimension matches the number of columns of
        ``projection_matrix``.
    projection_matrix
        Matrix whose transpose is right-multiplied onto ``embedding``.
    mean_to_add
        Offset added (with broadcasting) after the multiplication.

    Returns
    -------
    The reconstructed array.
    """
    back_projected = np.matmul(embedding, np.transpose(projection_matrix))
    return back_projected + mean_to_add


def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    """Centre ``data`` and project it with ``projection_matrix``.

    Computes ``(data - mean_to_subtract) @ projection_matrix``.

    Parameters
    ----------
    data
        Array whose last dimension matches the number of rows of
        ``projection_matrix``.
    projection_matrix
        Matrix right-multiplied onto the centred data.
    mean_to_subtract
        Offset subtracted (with broadcasting) before the multiplication.

    Returns
    -------
    The projected array.
    """
    centered = data - mean_to_subtract
    return np.matmul(centered, projection_matrix)

In [8]:
# obsm key of the stored embedding of the observed cells; it is compared against
# the projected CPA predictions in the "encoded" metrics.
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [9]:
# Partition the cells by the value of their `split` annotation.
split_labels = adata.obs["split"]
adata_train = adata[split_labels == "train"]
adata_test = adata[split_labels == "test"]
adata_ood = adata[split_labels == "ood"]

In [10]:
# Bind the training PCA loadings and the training mean so that predictions can be
# projected with a single-argument call: project_data_fn(data).
# NOTE(review): the `.T` presumably turns a column-shaped mean into a row that
# broadcasts across cells — confirm the stored shape of varm["X_train_mean"].
project_data_fn = functools.partial(
    project_data,
    projection_matrix=adata_train_orig.varm["PCs"],
    mean_to_subtract=adata_train_orig.varm["X_train_mean"].T,
)

In [11]:
def _collect_split_targets(adata_split):
    """Gather observed and CPA-predicted data per condition for one split.

    For every category of ``adata_split.obs["condition"]`` except ``"control"``
    the cells of that condition are selected once and four entries are built:

    * encoded: the stored embedding ``obsm[OBSM_KEY_DATA_EMBEDDING]``
    * decoded: the dense expression matrix ``X``
    * decoded_predicted: ``log1p`` of ``obsm["CPA_pred"]``
    * encoded_predicted: ``project_data_fn`` applied to ``decoded_predicted``

    Parameters
    ----------
    adata_split
        AnnData (view) holding the cells of a single split.

    Returns
    -------
    Tuple ``(encoded, decoded, encoded_predicted, decoded_predicted)`` of
    dictionaries keyed by condition name.
    """
    encoded = {}
    decoded = {}
    encoded_predicted = {}
    decoded_predicted = {}
    for cond in adata_split.obs["condition"].cat.categories:
        if cond == "control":
            continue
        # Subset once per condition instead of once per extracted field.
        adata_cond = adata_split[adata_split.obs["condition"] == cond]
        encoded[cond] = adata_cond.obsm[OBSM_KEY_DATA_EMBEDDING]
        # `.toarray()` replaces the `.A` shorthand, which newer SciPy releases no longer provide.
        decoded[cond] = adata_cond.X.toarray()
        # NOTE(review): log1p presumably brings the CPA output onto the same scale as X — confirm.
        pred_cpa = np.log1p(adata_cond.obsm["CPA_pred"])
        decoded_predicted[cond] = pred_cpa
        encoded_predicted[cond] = project_data_fn(pred_cpa)
    return encoded, decoded, encoded_predicted, decoded_predicted


(
    train_data_target_encoded,
    train_data_target_decoded,
    train_data_target_encoded_predicted,
    train_data_target_decoded_predicted,
) = _collect_split_targets(adata_train)

(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = _collect_split_targets(adata_test)

(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = _collect_split_targets(adata_ood)

In [13]:
# Restrict the per-condition gene lists stored on the training object to the
# conditions for which predictions exist in each split.
deg_by_condition = adata_train_orig.uns["rank_genes_groups_cov_all"]

train_deg_dict = {
    cond: genes for cond, genes in deg_by_condition.items() if cond in train_data_target_decoded_predicted
}
test_deg_dict = {
    cond: genes for cond, genes in deg_by_condition.items() if cond in test_data_target_decoded_predicted
}
ood_deg_dict = {
    cond: genes for cond, genes in deg_by_condition.items() if cond in ood_data_target_decoded_predicted
}

In [14]:
def get_mask(x, y):
    """Keep only the columns of ``x`` whose gene name is contained in ``y``.

    The column order is that of ``adata_train.var_names``; a column is kept
    when its gene name is a member of ``y``.

    Parameters
    ----------
    x
        2-D array with one column per entry of ``adata_train.var_names``.
    y
        Collection of gene names to keep.

    Returns
    -------
    ``x`` restricted to the selected columns.
    """
    keep_column = [gene in y for gene in adata_train.var_names]
    return x[:, keep_column]


# Restrict observed and predicted expression to each condition's gene list.
# tree_map pairs the entries of both dictionaries by key, so every data
# dictionary must be combined with the gene dictionary of the SAME split.
ood_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, ood_data_target_decoded_predicted, ood_deg_dict)
ood_deg_target_decoded = jax.tree_util.tree_map(get_mask, ood_data_target_decoded, ood_deg_dict)

test_deg_target_decoded_predicted = jax.tree_util.tree_map(get_mask, test_data_target_decoded_predicted, test_deg_dict)
test_deg_target_decoded = jax.tree_util.tree_map(get_mask, test_data_target_decoded, test_deg_dict)

train_deg_target_decoded_predicted = jax.tree_util.tree_map(
    get_mask, train_data_target_decoded_predicted, train_deg_dict
)
# Fixed: the observed train data was previously masked with `test_deg_dict`.
train_deg_target_decoded = jax.tree_util.tree_map(get_mask, train_data_target_decoded, train_deg_dict)

In [15]:
# Per-condition metrics on the gene-restricted expression, followed by their
# aggregation via compute_mean_metrics. compute_metrics always receives the
# observed data first and the CPA prediction second.
deg_ood_metrics = jax.tree_util.tree_map(compute_metrics, ood_deg_target_decoded, ood_deg_target_decoded_predicted)
deg_mean_ood_metrics = compute_mean_metrics(deg_ood_metrics, prefix="deg_ood_")

deg_test_metrics = jax.tree_util.tree_map(compute_metrics, test_deg_target_decoded, test_deg_target_decoded_predicted)
deg_mean_test_metrics = compute_mean_metrics(deg_test_metrics, prefix="deg_test_")

deg_train_metrics = jax.tree_util.tree_map(
    compute_metrics, train_deg_target_decoded, train_deg_target_decoded_predicted
)
deg_mean_train_metrics = compute_mean_metrics(deg_train_metrics, prefix="deg_train_")

2024-04-17 09:40:21.651568: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [16]:
# Train split: per-condition metrics in the embedding ("encoded") space and in
# expression ("decoded") space, each aggregated with compute_mean_metrics.
# Observed data is passed first, the CPA prediction second.
train_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, train_data_target_encoded, train_data_target_encoded_predicted
)
mean_train_metrics_encoded = compute_mean_metrics(train_metrics_encoded, prefix="encoded_train_")

train_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, train_data_target_decoded, train_data_target_decoded_predicted
)
mean_train_metrics_decoded = compute_mean_metrics(train_metrics_decoded, prefix="decoded_train_")

In [17]:
# Test split: per-condition metrics in the embedding ("encoded") space and in
# expression ("decoded") space, each aggregated with compute_mean_metrics.
# Observed data is passed first, the CPA prediction second.
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_encoded, test_data_target_encoded_predicted
)
mean_test_metrics_encoded = compute_mean_metrics(test_metrics_encoded, prefix="encoded_test_")

test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, test_data_target_decoded, test_data_target_decoded_predicted
)
mean_test_metrics_decoded = compute_mean_metrics(test_metrics_decoded, prefix="decoded_test_")

In [18]:
# OOD split: per-condition metrics in the embedding ("encoded") space and in
# expression ("decoded") space, each aggregated with compute_mean_metrics.
# Observed data is passed first, the CPA prediction second.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_encoded, ood_data_target_encoded_predicted
)
mean_ood_metrics_encoded = compute_mean_metrics(ood_metrics_encoded, prefix="encoded_ood_")

ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics, ood_data_target_decoded, ood_data_target_decoded_predicted
)
mean_ood_metrics_decoded = compute_mean_metrics(ood_metrics_decoded, prefix="decoded_ood_")

In [35]:
# Display the mean OOD metrics in embedding space.
# NOTE(review): this variable is displayed several times in the notebook with
# different values and non-sequential execution counts — re-run top to bottom
# before trusting any single output.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.36148934677513483,
 'encoded_ood_sinkhorn_div_01': 22.507043075561523,
 'encoded_ood_e_distance': 2.6911058391697225,
 'encoded_ood_mmd': 10.928740692138671}

In [70]:
# Display the mean test-split metrics in embedding space.
mean_test_metrics_encoded

{'encoded_test_r_squared': 0.14233631660649743,
 'encoded_test_sinkhorn_div_01': 22.428858977097732,
 'encoded_test_e_distance': 2.607691631396518,
 'encoded_test_mmd': 10.481525054344765}

In [36]:
# Display the per-condition OOD metrics in embedding space.
ood_metrics_encoded

{'Cediranib+PCI-34051': {'r_squared': 0.38961658319299564,
  'sinkhorn_div_01': 19.501216888427734,
  'e_distance': 2.1707178259849313,
  'mmd': 8.466133117675781},
 'Givinostat+SRT1720': {'r_squared': -0.05524568663913931,
  'sinkhorn_div_01': 19.98489761352539,
  'e_distance': 2.2141817231577066,
  'mmd': 8.647775650024414},
 'Panobinostat+Crizotinib': {'r_squared': 0.49397616829387336,
  'sinkhorn_div_01': 28.331708908081055,
  'e_distance': 3.797012271117232,
  'mmd': 17.024738311767578},
 'Panobinostat+PCI-34051': {'r_squared': 0.48293589364122314,
  'sinkhorn_div_01': 25.119630813598633,
  'e_distance': 3.2346571766388164,
  'mmd': 13.55031967163086},
 'SRT2104+Alvespimycin': {'r_squared': 0.4961637753867213,
  'sinkhorn_div_01': 19.597761154174805,
  'e_distance': 2.0389601989499258,
  'mmd': 6.954736709594727}}

In [58]:
# Projected CPA prediction for one OOD condition, used for the manual metric checks below.
# NOTE(review): the value_counts output near the end of the notebook does not list
# this condition in the OOD split — verify the key still exists after a fresh run.
pred = ood_data_target_encoded_predicted["SRT2104+Alvespimycin"]

In [59]:
# Observed embedding for the same OOD condition, used as the reference in the manual checks below.
true = ood_data_target_encoded["SRT2104+Alvespimycin"]

In [60]:
from ott.geometry import costs, pointcloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from sklearn.metrics import r2_score


def compute_r_squared(x: np.ndarray, y: np.ndarray) -> float:
    """R^2 score between the per-feature means of two samples.

    Both inputs are averaged over axis 0 and the resulting mean vectors are
    handed to ``r2_score`` with ``x`` first and ``y`` second. The score is not
    symmetric in its arguments, so swapping ``x`` and ``y`` changes the result.

    Parameters
    ----------
    x
        Sample whose feature means take the first argument slot of ``r2_score``.
    y
        Sample whose feature means take the second argument slot of ``r2_score``.

    Returns
    -------
    The value returned by ``r2_score``.
    """
    mean_x = np.mean(x, axis=0)
    mean_y = np.mean(y, axis=0)
    return r2_score(mean_x, mean_y)


def compute_sinkhorn_div(x: np.ndarray, y: np.ndarray, epsilon: float) -> float:
    """Sinkhorn divergence between two point clouds under squared Euclidean cost.

    Parameters
    ----------
    x
        First point cloud.
    y
        Second point cloud.
    epsilon
        Value forwarded as ``epsilon`` to ``sinkhorn_divergence``.

    Returns
    -------
    The ``divergence`` attribute of the result, converted to a Python float.
    """
    result = sinkhorn_divergence(
        pointcloud.PointCloud,
        x=x,
        y=y,
        cost_fn=costs.SqEuclidean(),
        epsilon=epsilon,
        scale_cost=1.0,
    )
    return float(result.divergence)

In [61]:
# R^2 with the prediction in the first argument slot; compare with the swapped
# call in the next cell to see how strongly the score depends on argument order.
compute_r_squared(pred, true)

0.6092905921317875

In [62]:
# R^2 with the observed data first; the displayed value equals the 'r_squared'
# entry reported for this condition in ood_metrics_encoded.
compute_r_squared(true, pred)

0.4961637753867213

In [69]:
# Sinkhorn divergence at epsilon=0.1; the displayed value equals the
# 'sinkhorn_div_01' entry reported for this condition in ood_metrics_encoded.
compute_sinkhorn_div(true, pred, 0.1)

19.597761154174805

In [67]:
# Mean of the squared-Euclidean cost matrix between observed and predicted cells,
# shown for comparison with the Sinkhorn divergence values.
pointcloud.PointCloud(x=true, y=pred, cost_fn=costs.SqEuclidean()).cost_matrix.mean()

Array(31.116922, dtype=float32)

In [72]:
# Same divergence with the arguments swapped and a larger epsilon (1 instead of 0.1).
compute_sinkhorn_div(pred, true, 1)

19.27981185913086

In [26]:
# Display the mean OOD metrics restricted to each condition's gene list.
# The metrics cell defines this result as `deg_mean_ood_metrics`; the previous
# name `deg_mean_ood_metrics_encoded` is never assigned in the visible cells.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.5230906992746265,
 'deg_ood_sinkhorn_div_01': 23.402545547485353,
 'deg_ood_e_distance': 3.987580411351522,
 'deg_ood_mmd': 13.944179344177247}

In [21]:
# Display the mean train-split metrics in embedding space.
mean_train_metrics_encoded

{'encoded_train_r_squared': 0.1105130831981621,
 'encoded_train_sinkhorn_div_01': 25.790736491863544,
 'encoded_train_e_distance': 2.6972641844368406,
 'encoded_train_mmd': 10.962975355295034}

In [54]:
# Display the mean test-split metrics in embedding space (repeated display of the same variable).
mean_test_metrics_encoded

{'encoded_test_r_squared': 0.14233631660649743,
 'encoded_test_sinkhorn_div_01': 22.428858977097732,
 'encoded_test_e_distance': 2.607691631396518,
 'encoded_test_mmd': 10.481525054344765}

In [57]:
# Display the mean OOD metrics in embedding space.
# NOTE(review): the value shown here differs from the earlier display of the same
# variable, so the outputs stem from different executions.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.15826509960980395,
 'encoded_ood_sinkhorn_div_01': 21.725287119547527,
 'encoded_ood_e_distance': 2.5296382695145954,
 'encoded_ood_mmd': 10.320155461629232}

In [34]:
# Display the mean OOD metrics restricted to each condition's gene list.
# The metrics cell defines this result as `deg_mean_ood_metrics`; the previous
# name `deg_mean_ood_metrics_encoded` is never assigned in the visible cells.
deg_mean_ood_metrics

{'deg_ood_r_squared': 0.5230906992746265,
 'deg_ood_sinkhorn_div_01': 23.402545547485353,
 'deg_ood_e_distance': 3.987580411351522,
 'deg_ood_mmd': 13.944179344177247}

In [58]:
# Display the mean OOD metrics in expression ("decoded") space.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.7614227489499669,
 'decoded_ood_sinkhorn_div_01': 85.6328353881836,
 'decoded_ood_e_distance': 4.12626323148969,
 'decoded_ood_mmd': 21.937901814778645}

In [62]:
# Directory that receives the per-condition metric CSV files written in the next cell.
# NOTE(review): hard-coded absolute cluster path — consider a configurable results directory.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results/cpa"

In [65]:
import os

# pandas was used here without ever being imported, which raises NameError on a fresh kernel.
import pandas as pd

# Make sure the target directory exists before writing.
os.makedirs(output_dir, exist_ok=True)

# One CSV per split and space; the file name is the dictionary's variable name.
per_condition_metrics = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "train_metrics_encoded": train_metrics_encoded,
    "train_metrics_decoded": train_metrics_decoded,
}
for metrics_name, metrics in per_condition_metrics.items():
    # from_dict yields one column per condition and one row per metric.
    pd.DataFrame.from_dict(metrics).to_csv(os.path.join(output_dir, f"{metrics_name}.csv"))

In [22]:
# Display the mean OOD metrics in embedding space.
# NOTE(review): the value shown here differs from the earlier displays of the same
# variable, so the outputs stem from different executions.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.16754053021868395,
 'encoded_ood_sinkhorn_div_01': 21.63243230183919,
 'encoded_ood_e_distance': 2.5160986112775205,
 'encoded_ood_mmd': 10.119170506795248}

In [23]:
# Display the mean OOD metrics in expression ("decoded") space (repeated display).
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.761389117431608,
 'decoded_ood_sinkhorn_div_01': 85.49267069498698,
 'decoded_ood_e_distance': 4.157940662742361,
 'decoded_ood_mmd': 21.927113850911457}

In [28]:
# Number of cells per condition within the OOD split.
adata[adata.obs["split"] == "ood"].obs["condition"].value_counts()

condition
Givinostat+Crizotinib     2662
Givinostat+SRT1720        2260
Cediranib+PCI-34051       2161
Panobinostat+SRT2104      1971
Panobinostat+PCI-34051    1814
Dacinostat+Dasatinib      1231
control                    500
Name: count, dtype: int64

In [29]:
# Sanity check: the train subset should contain only cells labelled "train".
adata_train.obs["split"].value_counts()

split
train    48279
Name: count, dtype: int64

In [30]:
# Sanity check: the test subset should contain only cells labelled "test".
adata_test.obs["split"].value_counts()

split
test    3000
Name: count, dtype: int64

In [32]:
# Sanity check: the OOD subset should contain only cells labelled "ood".
adata_ood.obs["split"].value_counts()

split
ood    12599
Name: count, dtype: int64