In [1]:
import os
import torch

import scanpy as sc
import numpy as np
import pandas as pd
import gseapy as gp

from tqdm import tqdm
from sklearn.cluster import KMeans
from contrastive_vi.model.contrastive_vi import ContrastiveVIModel
from scripts import constants
from scvi._settings import settings

Global seed set to 0


In [2]:
settings.seed = 0
device = "cuda:1"
dataset = "zheng_2017"

Global seed set to 0


In [3]:
pathway_enr_fdr = 0.05

In [4]:
split_key = constants.DATASET_SPLIT_LOOKUP[dataset]["split_key"]
background_value = constants.DATASET_SPLIT_LOOKUP[dataset]["background_value"]
label_key = constants.DATASET_SPLIT_LOOKUP[dataset]["label_key"]
seeds = constants.DEFAULT_SEEDS
latent_size = 10

In [5]:
adata = sc.read_h5ad(
    os.path.join(
        constants.DEFAULT_DATA_PATH,
        f"{dataset}/preprocessed/adata_top_2000_genes_tc.h5ad",
    )
)
ContrastiveVIModel.setup_anndata(adata, layer="count")

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"count"[0m[1m][0m                                               
[34mINFO    [0m Successfully registered anndata object containing [1;36m16856[0m cells, [1;36m2000[0m vars, [1;36m1[0m batches,
         [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          


In [6]:
target_indices = np.where(adata.obs[split_key] != background_value)[0]
target_adata = adata[target_indices]
background_indices = np.where(adata.obs[split_key] == background_value)[0]
background_adata = adata[background_indices]

In [7]:
contrastiveVI_model_list = []
contrastiveVI_latent_rep_list = []
scVI_model_list = []
for seed in tqdm(seeds):
    contrastiveVI_result_dir = os.path.join(
        constants.DEFAULT_RESULTS_PATH,
        f"{dataset}/contrastiveVI/latent_{latent_size}",
        f"{seed}",
    )
    contrastiveVI_model_list.append(
        torch.load(
            os.path.join(contrastiveVI_result_dir, "model.ckpt"),
            map_location=device,
        ),
    )
    
    contrastiveVI_latent_rep_list.append(
        np.load(os.path.join(contrastiveVI_result_dir, "latent_representations.npy")),
    )

    scVI_result_dir = os.path.join(
        constants.DEFAULT_RESULTS_PATH,
        f"{dataset}/scVI/latent_{latent_size}",
        f"{seed}",
    )

    scVI_model_list.append(
        torch.load(
            os.path.join(scVI_result_dir, "model.ckpt"),
            map_location=device,
        ),
    )

100%|███████████████████████████████████████████████████████████████████████████████| 5/5 [00:30<00:00,  6.17s/it]


In [8]:
from scipy.stats import spearmanr

correlations = []
labels = ["post_transplant", "pre_transplant"]
condition_label = ["condition"]

for (scVI_model, contrastiveVI_model, latent_rep) in zip(
    scVI_model_list, contrastiveVI_model_list, contrastiveVI_latent_rep_list
):
    latent_clusters = KMeans(n_clusters=2, random_state=123).fit(latent_rep).labels_
    cluster_label = f"cluster"
    
    tmp_target_adata = target_adata.copy()
    tmp_target_adata.obs[cluster_label] = latent_clusters.astype(str)


    scVI_de_result = scVI_model.differential_expression(
        adata=tmp_target_adata,
        groupby=cluster_label,
        group1="0",
        group2="1",
        idx1=None,
        idx2=None,
        mode="change",
        delta=0.25,
        batch_size=128,
        all_stats=True,
        batch_correction=False,
        batchid1=None,
        batchid2=None,
        fdr_target=0.05,
        silent=False,
        n_samples=100,
    )

    contrastiveVI_de_result = contrastiveVI_model.differential_expression(
        adata=tmp_target_adata,
        groupby=cluster_label,
        group1="0",
        group2="1",
        idx1=None,
        idx2=None,
        mode="change",
        delta=0.25,
        batch_size=128,
        all_stats=True,
        batch_correction=False,
        batchid1=None,
        batchid2=None,
        fdr_target=0.05,
        silent=False,
    )

    correlations.append(spearmanr(
        scVI_de_result['bayes_factor'].values,
        contrastiveVI_de_result['bayes_factor'].values
    )[0])

DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.17it/s]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.29s/it]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.87it/s]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.39s/it]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.45it/s]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.27s/it]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.84it/s]
DE...: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.29s/it]
DE...: 100%|████████████████████████████████████████████████████████████████████

In [9]:
correlations

[0.9984884217645608,
 0.9984980532500395,
 0.9985154552884273,
 0.9987561837871823,
 0.9982017288261604]

In [10]:
np.mean(correlations)

0.9984919685832739