In [24]:
import functools
import os

import jax
import numpy as np
import pandas as pd
import scanpy as sc

from ot_pert.metrics import compute_metrics, compute_mean_metrics

In [3]:
# Load the AnnData that already carries CPA predictions (later cells read obsm["CPA_pred"]
# and the obs columns "split" and "condition").
# NOTE(review): loading emits a duplicate-obs-names warning (see cell output); consider
# adata.obs_names_make_unique() if cells are ever selected by name — confirm.
adata = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/models/cpa/combosciplex/adata_with_predictions.h5ad")

  utils.warn_names_duplicates("obs")


In [15]:
# Load the original training AnnData; its varm["PCs"] and varm["X_train_mean"] are used
# later to project CPA predictions into the embedding space.
adata_train_orig = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train.h5ad")



In [7]:
# CPA predictions for all cells.
# NOTE(review): `preds` is not referenced again in the visible cells — later cells read
# obsm["CPA_pred"] directly per condition; consider removing this cell.
preds = adata.obsm["CPA_pred"]

In [41]:
def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Map an embedding back to feature space.

    Computes ``matmul(embedding, projection_matrix.T) + mean_to_add``.

    Args:
        embedding: Array to back-project (presumably ``(n_cells, n_components)`` — confirm).
        projection_matrix: Matrix whose transpose is right-multiplied onto ``embedding``.
        mean_to_add: Offset added after the multiplication (broadcast by NumPy rules).

    Returns:
        The back-projected array with the offset added.
    """
    back_projection = projection_matrix.T
    uncentered = np.matmul(embedding, back_projection)
    return uncentered + mean_to_add

def project_data(data: np.ndarray, projection_matrix: np.ndarray, mean_to_subtract: np.ndarray) -> np.ndarray:
    """Center ``data`` and project it with ``projection_matrix``.

    Computes ``matmul(data - mean_to_subtract, projection_matrix)``.

    Args:
        data: Array to project (presumably ``(n_cells, n_features)`` — confirm).
        projection_matrix: Matrix right-multiplied onto the centered data.
        mean_to_subtract: Offset removed before the multiplication (broadcast by NumPy rules).

    Returns:
        The projected array.
    """
    centered = data - mean_to_subtract
    return np.matmul(centered, projection_matrix)

In [42]:
# obsm key of the ground-truth embedding that projected predictions are compared against.
OBSM_KEY_DATA_EMBEDDING = "X_pca"

In [37]:
# Partition the cells into views by their precomputed obs["split"] label.
adata_train, adata_test, adata_ood = (
    adata[adata.obs["split"] == split_label] for split_label in ("train", "test", "ood")
)

In [43]:
# Bind the projection matrix and mean stored on the original training AnnData so that
# predictions can be projected with a single-argument call.
# NOTE(review): assumes varm["X_train_mean"] is (n_genes, 1), so .T broadcasts over cells — confirm.
project_data_fn = functools.partial(
    project_data,
    projection_matrix=adata_train_orig.varm["PCs"],
    mean_to_subtract=adata_train_orig.varm["X_train_mean"].T,
)

In [44]:
# Sanity check of the projection matrix shape (presumably (n_genes, n_components) — confirm).
adata_train_orig.varm["PCs"].shape

(2000, 50)

In [48]:
def _to_dense(matrix):
    """Return ``matrix`` as a dense ndarray, whether it is sparse or already dense."""
    # `.toarray()` replaces the deprecated sparse `.A` attribute.
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def collect_condition_data(adata_split):
    """Gather ground truth and CPA predictions per non-control condition.

    Args:
        adata_split: AnnData (view) with ``obs["condition"]`` as a categorical,
            ``obsm[OBSM_KEY_DATA_EMBEDDING]``, ``obsm["CPA_pred"]`` and ``X``.

    Returns:
        Tuple ``(encoded, decoded, encoded_predicted, decoded_predicted)`` of dicts
        keyed by condition name:
          * encoded: ground-truth embedding from ``obsm``.
          * decoded: ground-truth ``X`` as a dense array.
          * encoded_predicted: ``log1p`` of the CPA prediction, projected with ``project_data_fn``.
          * decoded_predicted: ``log1p`` of the CPA prediction.
    """
    encoded, decoded, encoded_predicted, decoded_predicted = {}, {}, {}, {}
    for cond in adata_split.obs["condition"].cat.categories:
        if cond == "control":
            continue
        # Subset once per condition instead of once per accessed field.
        adata_cond = adata_split[adata_split.obs["condition"] == cond]
        encoded[cond] = adata_cond.obsm[OBSM_KEY_DATA_EMBEDDING]
        decoded[cond] = _to_dense(adata_cond.X)
        # NOTE(review): log1p presumably brings CPA output onto the scale of X — confirm.
        pred_cpa = np.log1p(adata_cond.obsm["CPA_pred"])
        decoded_predicted[cond] = pred_cpa
        encoded_predicted[cond] = project_data_fn(pred_cpa)
    return encoded, decoded, encoded_predicted, decoded_predicted


(
    train_data_target_encoded,
    train_data_target_decoded,
    train_data_target_encoded_predicted,
    train_data_target_decoded_predicted,
) = collect_condition_data(adata_train)

(
    test_data_target_encoded,
    test_data_target_decoded,
    test_data_target_encoded_predicted,
    test_data_target_decoded_predicted,
) = collect_condition_data(adata_test)

(
    ood_data_target_encoded,
    ood_data_target_decoded,
    ood_data_target_encoded_predicted,
    ood_data_target_decoded_predicted,
) = collect_condition_data(adata_ood)
    


In [50]:
# Per-condition metrics on the training split, then their mean, in the embedding space.
train_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    train_data_target_encoded,
    train_data_target_encoded_predicted,
)
mean_train_metrics_encoded = compute_mean_metrics(
    train_metrics_encoded,
    prefix="encoded_train_",
)

# Same comparison on the decoded arrays (ground-truth X vs. log1p CPA prediction).
train_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics,
    train_data_target_decoded,
    train_data_target_decoded_predicted,
)
mean_train_metrics_decoded = compute_mean_metrics(
    train_metrics_decoded,
    prefix="decoded_train_",
)

2024-04-16 18:19:54.641830: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [51]:
# Per-condition metrics on the test split, then their mean, in the embedding space.
test_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    test_data_target_encoded,
    test_data_target_encoded_predicted,
)
mean_test_metrics_encoded = compute_mean_metrics(
    test_metrics_encoded,
    prefix="encoded_test_",
)

# Same comparison on the decoded arrays (ground-truth X vs. log1p CPA prediction).
test_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics,
    test_data_target_decoded,
    test_data_target_decoded_predicted,
)
mean_test_metrics_decoded = compute_mean_metrics(
    test_metrics_decoded,
    prefix="decoded_test_",
)

In [52]:
# Per-condition metrics on the OOD split, then their mean, in the embedding space.
ood_metrics_encoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_encoded,
    ood_data_target_encoded_predicted,
)
mean_ood_metrics_encoded = compute_mean_metrics(
    ood_metrics_encoded,
    prefix="encoded_ood_",
)

# Same comparison on the decoded arrays (ground-truth X vs. log1p CPA prediction).
ood_metrics_decoded = jax.tree_util.tree_map(
    compute_metrics,
    ood_data_target_decoded,
    ood_data_target_decoded_predicted,
)
mean_ood_metrics_decoded = compute_mean_metrics(
    ood_metrics_decoded,
    prefix="decoded_ood_",
)

In [53]:
# Mean metrics over training conditions, computed on the embedding (encoded) arrays.
mean_train_metrics_encoded

{'encoded_train_r_squared': 0.09882947343231972,
 'encoded_train_sinkhorn_div_01': 27.670842895507814,
 'encoded_train_e_distance': 2.8356454493050745,
 'encoded_train_mmd': 12.157588424682617}

In [54]:
# Mean metrics over training conditions, computed on the decoded arrays.
mean_train_metrics_decoded

{'decoded_train_r_squared': 0.7729072789319338,
 'decoded_train_sinkhorn_div_01': 83.83307525634766,
 'decoded_train_e_distance': 3.982077108849602,
 'decoded_train_mmd': 20.897135314941405}

In [55]:
# Mean metrics over test conditions, computed on the embedding (encoded) arrays.
mean_test_metrics_encoded

{'encoded_test_r_squared': 0.22096380806799604,
 'encoded_test_sinkhorn_div_01': 21.72205924987793,
 'encoded_test_e_distance': 2.495100569970721,
 'encoded_test_mmd': 10.17115592956543}

In [56]:
# Mean metrics over test conditions, computed on the decoded arrays.
mean_test_metrics_decoded

{'decoded_test_r_squared': 0.7646439309457859,
 'decoded_test_sinkhorn_div_01': 87.17083251953125,
 'decoded_test_e_distance': 4.067989317550506,
 'decoded_test_mmd': 21.85130615234375}

In [57]:
# Mean metrics over OOD conditions, computed on the embedding (encoded) arrays.
mean_ood_metrics_encoded

{'encoded_ood_r_squared': 0.15826509960980395,
 'encoded_ood_sinkhorn_div_01': 21.725287119547527,
 'encoded_ood_e_distance': 2.5296382695145954,
 'encoded_ood_mmd': 10.320155461629232}

In [58]:
# Mean metrics over OOD conditions, computed on the decoded arrays.
mean_ood_metrics_decoded

{'decoded_ood_r_squared': 0.7614227489499669,
 'decoded_ood_sinkhorn_div_01': 85.6328353881836,
 'decoded_ood_e_distance': 4.12626323148969,
 'decoded_ood_mmd': 21.937901814778645}

In [62]:
# Directory the per-condition metric CSVs are written to.
# NOTE(review): hardcoded absolute cluster path — consider a configurable base directory.
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/results"

In [65]:
# Persist the per-condition metric tables, one CSV per split and representation.
# `os` and `pd` come from the notebook's top import cell.
os.makedirs(output_dir, exist_ok=True)  # avoid failing on a fresh results directory

metrics_to_save = {
    "ood_metrics_encoded": ood_metrics_encoded,
    "ood_metrics_decoded": ood_metrics_decoded,
    "test_metrics_encoded": test_metrics_encoded,
    "test_metrics_decoded": test_metrics_decoded,
    "train_metrics_encoded": train_metrics_encoded,
    "train_metrics_decoded": train_metrics_decoded,
}
for metrics_name, metrics in metrics_to_save.items():
    csv_path = os.path.join(output_dir, f"{metrics_name}.csv")
    pd.DataFrame.from_dict(metrics).to_csv(csv_path)

In [71]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    pool_doses=False,
    n_genes=50,
    rankby_abs=True,
    key_added='rank_genes_groups_cov',
    return_dict=False,
):
    """Rank genes per group against a control, separately for each covariate category.

    For every unique value of ``adata.obs[covariate]`` the data is subset to that
    category and ``sc.tl.rank_genes_groups`` is run with the reference group
    ``"<category>_<control_group>"``. The ranked gene names of every group are
    collected into one dict and stored in ``adata.uns[key_added]``.

    Args:
        adata: AnnData to analyse; ``adata.uns[key_added]`` is written in place.
        groupby: obs column holding the group labels.
        control_group: Suffix of the reference group name within each category.
        covariate: obs column whose categories are processed one at a time.
        pool_doses: Unused; kept for interface compatibility.
        n_genes: Number of genes to rank per group.
        rankby_abs: Forwarded to ``sc.tl.rank_genes_groups``.
        key_added: Key in ``adata.uns`` under which the gene dict is stored.
        return_dict: If True, also return the gene dict.

    Returns:
        Dict mapping group name to its list of ranked gene names if
        ``return_dict`` is True, otherwise None.
    """
    gene_dict = {}
    cov_categories = adata.obs[covariate].unique()
    for cov_cat in cov_categories:
        # Name of the control group in the groupby obs column.
        control_group_cov = '_'.join([cov_cat, control_group])
        # Subset to cells of this covariate category. Copy explicitly because
        # rank_genes_groups writes its results into `.uns` of the object it is given.
        adata_cov = adata[adata.obs[covariate] == cov_cat].copy()
        # Compute DEGs.
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # Add entries to the dictionary of gene sets.
        de_genes = pd.DataFrame(adata_cov.uns['rank_genes_groups']['names'])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata, skip_calc_de):
    """Annotate ``adata.obs`` with dose/control/condition-name columns and optionally rank DE genes.

    Adds to ``adata.obs`` (in place):
      * ``dose_val``: ``'1+1'`` when ``condition`` has exactly two ``'+'``-separated parts, else ``'1'``.
      * ``control``: ``0`` when ``condition`` has exactly two ``'+'``-separated parts, else ``1``.
      * ``condition_name``: ``"<cell_type>_<condition>_<dose_val>"``.
    All obs columns are then cast to ``category``.

    Args:
        adata: AnnData with ``obs["condition"]`` and ``obs["cell_type"]`` string-like columns.
        skip_calc_de: If falsy, run ``rank_genes_groups_by_cov`` over all genes and store the
            result in ``adata.uns['rank_genes_groups_cov_all']``.

    Returns:
        The same ``adata`` object, modified in place.
    """
    adata.obs['dose_val'] = adata.obs.condition.apply(lambda x: '1+1' if len(x.split('+')) == 2 else '1')
    adata.obs['control'] = adata.obs.condition.apply(lambda x: 0 if len(x.split('+')) == 2 else 1)
    adata.obs['condition_name'] = adata.obs.apply(
        lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis=1
    )
    adata.obs = adata.obs.astype('category')
    if not skip_calc_de:
        # The reference group resolves to "<cell_type>_ctrl_1", i.e. control cells must have
        # condition == 'ctrl'.
        # NOTE(review): other cells in this notebook label controls "control" — confirm the
        # control label before running DE on this dataset.
        rank_genes_groups_by_cov(
            adata,
            groupby='condition_name',
            covariate='cell_type',
            control_group='ctrl_1',
            n_genes=len(adata.var),
            key_added='rank_genes_groups_cov_all',
        )
    return adata

Unnamed: 0_level_0,id,n_cells,mt,n_cells_by_counts,mean_counts,pct_dropout_by_counts,total_counts,highly_variable,means,dispersions,dispersions_norm
gene_short_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
C1orf112,ENSG00000000460.16,9158,False,9158,0.248684,85.562037,15774.0,True,0.188306,1.054700,7.938564
CFTR,ENSG00000001626.14,1756,False,1756,0.033423,97.231594,2120.0,True,0.026559,0.367341,2.067544
KLHL13,ENSG00000003096.14,6093,False,6093,0.216207,90.394135,13714.0,True,0.162379,1.469726,11.483470
TFPI,ENSG00000003436.15,18262,False,18262,0.472458,71.209207,29968.0,True,0.343148,0.555299,1.769344
SLC7A2,ENSG00000003989.17,5449,False,5449,0.097446,91.409428,6181.0,True,0.078516,0.240577,0.984797
...,...,...,...,...,...,...,...,...,...,...,...
AC006460.2,ENSG00000284052.1,3637,False,3637,0.065017,94.266120,4124.0,True,0.054669,0.300221,1.494241
AL589669.1,ENSG00000284377.1,313,False,313,0.005676,99.506543,360.0,True,0.004905,0.411998,2.448978
AC020912.1,ENSG00000284430.1,24,False,24,0.000378,99.962163,24.0,True,0.000387,0.250882,1.072816
AL805961.1,ENSG00000284668.1,50,False,50,0.000851,99.921173,54.0,True,0.000696,0.723907,5.113125
