In [1]:
import os
import torch

import scanpy as sc
import numpy as np
import pandas as pd
import gseapy as gp

from tqdm import tqdm
from sklearn.cluster import KMeans
from contrastive_vi.model.contrastive_vi import ContrastiveVIModel
from scripts import constants
from scvi._settings import settings

Global seed set to 0
1: package ‘methods’ was built under R version 3.6.1 
2: package ‘datasets’ was built under R version 3.6.1 
3: package ‘utils’ was built under R version 3.6.1 
4: package ‘grDevices’ was built under R version 3.6.1 
5: package ‘graphics’ was built under R version 3.6.1 
6: package ‘stats’ was built under R version 3.6.1 


In [2]:
# Notebook-wide configuration: which dataset to analyse and where to map
# loaded checkpoints.
dataset = "haber_2017"
device = "cpu"
# Fix the scvi global seed so stochastic steps are repeatable.
settings.seed = 0

Global seed set to 0


In [3]:
# Per-dataset description of how cells are split into background and target.
split_config = constants.DATASET_SPLIT_LOOKUP[dataset]
split_key = split_config["split_key"]
background_value = split_config["background_value"]

# Seeds of the trained models to aggregate over, and the latent size whose
# results directory is read below.
seeds = constants.DEFAULT_SEEDS
latent_size = 10

In [4]:
# Read the preprocessed AnnData file for the chosen dataset.
adata_path = os.path.join(
    constants.DEFAULT_DATA_PATH,
    f"{dataset}/preprocessed/adata_top_2000_genes.h5ad",
)
adata = sc.read_h5ad(adata_path)

# Register the object with the model class, using the "count" layer.
ContrastiveVIModel.setup_anndata(adata, layer="count")

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"count"[0m[1m][0m                                               
[34mINFO    [0m Successfully registered anndata object containing [1;36m7721[0m cells, [1;36m2000[0m vars, [1;36m1[0m batches, 
         [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          


In [5]:
# Target cells are all cells whose split column differs from the background value.
is_target = adata.obs[split_key] != background_value
target_indices = np.where(is_target)[0]
target_adata = adata[target_indices]

In [6]:
# Load, for every seed, the saved model checkpoint and the latent
# representation stored alongside it.
model_list, latent_rep_list = [], []
for seed in tqdm(seeds):
    seed_dir = os.path.join(
        constants.DEFAULT_RESULTS_PATH,
        f"{dataset}/contrastiveVI/latent_{latent_size}",
        f"{seed}",
    )

    # NOTE(review): torch.load unpickles the checkpoint's contents; only
    # point this at checkpoints produced by a trusted training run.
    checkpoint_path = os.path.join(seed_dir, "model.ckpt")
    loaded_model = torch.load(checkpoint_path, map_location=device)
    model_list.append(loaded_model)

    latent_path = os.path.join(seed_dir, "latent_representations.npy")
    loaded_latent = np.load(latent_path)
    latent_rep_list.append(loaded_latent)

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:31<00:00,  6.30s/it]


In [7]:
# For every seed: split the target cells into two KMeans clusters of the saved
# latent representation, then run the model's differential expression test
# between the two clusters. One result table per seed is collected.
#
# NOTE(review): KMeans cluster ids are arbitrary for each fit, so cluster "0"
# of one seed need not be the same cell population as cluster "0" of another.
# Signed statistics (e.g. fold changes) may therefore point in opposite
# directions across seeds -- confirm that anything aggregated downstream is
# direction-agnostic.
de_result_list = []
for seed_index, seed in enumerate(seeds):
    model = model_list[seed_index]
    latent_rep = latent_rep_list[seed_index]

    # n_init is pinned to the historical default (10): newer scikit-learn
    # releases changed the default, which would silently alter the clustering.
    latent_clusters = (
        KMeans(n_clusters=2, n_init=10, random_state=123).fit(latent_rep).labels_
    )
    cluster_label = f"cluster_{seed}"

    # Work on a copy so the shared target_adata is not modified by the loop.
    # Assumes latent_rep has one row per target cell, in the same order --
    # TODO confirm against the script that saved latent_representations.npy.
    tmp_target_adata = target_adata.copy()
    tmp_target_adata.obs[cluster_label] = latent_clusters.astype(str)

    de_result = model.differential_expression(
        adata=tmp_target_adata,
        groupby=cluster_label,
        group1="0",
        group2="1",
        idx1=None,
        idx2=None,
        mode="change",
        delta=0.25,
        batch_size=128,
        all_stats=True,
        batch_correction=False,
        batchid1=None,
        batchid2=None,
        fdr_target=0.05,
        silent=False,
    )

    # The former `de_result.reset_index()` call discarded its return value and
    # was a no-op, so it has been removed. The index must stay intact here: it
    # is copied into the `gene_symbol` column (presumably gene names) that the
    # later per-gene aggregation groups on.
    de_result["gene_symbol"] = de_result.index
    de_result["seed"] = seed
    de_result_list.append(de_result)

DE...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [04:44<00:00, 284.01s/it]
DE...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [04:34<00:00, 274.57s/it]
DE...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [04:33<00:00, 273.41s/it]
DE...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [04:32<00:00, 272.16s/it]
DE...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [04:32<00:00, 272.71s/it]


In [8]:
# Stack the per-seed result tables row-wise into one long table.
de_result = pd.concat(de_result_list, axis=0)

In [9]:
# Average every numeric statistic per gene across seeds, then rank genes by
# their mean probability of differential expression (highest first).
#
# numeric_only=True is passed explicitly: from pandas 2.0 on, GroupBy.mean no
# longer drops non-numeric columns silently and raises TypeError if any are
# present. The explicit flag reproduces what older pandas did by default.
# Note that the numeric `seed` column is averaged too and carries no meaning
# in the result.
de_result_mean = (
    de_result.groupby("gene_symbol", as_index=False)
    .mean(numeric_only=True)
    .sort_values(by="proba_de", ascending=False)
)

In [10]:
# Genes whose seed-averaged probability of differential expression exceeds 0.95.
is_confident = de_result_mean["proba_de"] > 0.95
top_genes = de_result_mean.loc[is_confident, "gene_symbol"].tolist()

# Query Enrichr with those genes against the KEGG 2019 mouse gene sets.
enr = gp.enrichr(
    gene_list=top_genes,
    gene_sets="KEGG_2019_Mouse",
    organism="mouse",
    cutoff=0.05,
)

# Keep only terms whose adjusted p-value is below 0.05.
all_terms = enr.results
is_significant = all_terms["Adjusted P-value"] < 0.05
enr_results = all_terms[is_significant]

In [11]:
# Display only the columns needed to read the enrichment result.
cols = ["Gene_set", "Term", "Adjusted P-value", "Overlap", "Genes"]
enr_results.loc[:, cols]

Unnamed: 0,Gene_set,Term,Adjusted P-value,Overlap,Genes
0,KEGG_2019_Mouse,Fat digestion and absorption,0.023545,5/40,FABP1;FABP2;PLA2G3;APOA1;APOA4
1,KEGG_2019_Mouse,Vitamin digestion and absorption,0.023545,4/24,CUBN;RBP2;APOA1;APOA4
2,KEGG_2019_Mouse,Cholesterol metabolism,0.029304,5/49,APOH;APOC2;APOA1;APOC3;APOA4


In [12]:
# Names of the significantly enriched terms, as a plain list.
list(enr_results["Term"])

['Fat digestion and absorption',
 'Vitamin digestion and absorption',
 'Cholesterol metabolism']

In [13]:
# Number of genes that passed the proba_de threshold and were sent to Enrichr.
len(top_genes)

280