In [1]:
import os
import torch
import pandas as pd
import numpy as np
import scanpy as sc
import gseapy as gp
import seaborn as sns

from tqdm import tqdm
from functools import partial
from sklearn.cluster import KMeans
from scvi.model.base._utils import _de_core
from scvi.model._utils import _get_var_names_from_setup_anndata,  scrna_raw_counts_properties
from scripts import constants
from contrastive_vi.model.contrastive_vi import ContrastiveVIModel

from anndata import AnnData
from typing import Dict, Iterable, Optional, Sequence, Union
from scvi._compat import Literal

# Type alias for an int-or-float value; used in type hints (e.g. transform_batch).
Number = Union[int, float]

Global seed set to 0
1: package ‘methods’ was built under R version 3.6.1 
2: package ‘datasets’ was built under R version 3.6.1 
3: package ‘utils’ was built under R version 3.6.1 
4: package ‘grDevices’ was built under R version 3.6.1 
5: package ‘graphics’ was built under R version 3.6.1 
6: package ‘stats’ was built under R version 3.6.1 


In [2]:
# Dataset identifier: selects the split metadata and the on-disk data/results paths below.
dataset: str = "zheng_2017"

In [3]:
# The obs column that separates background from target cells, and the label that marks background.
split_key, background_value = (
    constants.DATASET_SPLIT_LOOKUP[dataset][field]
    for field in ("split_key", "background_value")
)

In [4]:
# Preprocessed AnnData restricted to the top 2000 genes for this dataset.
adata_path = os.path.join(
    constants.DEFAULT_DATA_PATH,
    f"{dataset}/preprocessed/adata_top_2000_genes.h5ad",
)
adata = sc.read_h5ad(adata_path)

# Register the "count" layer with contrastiveVI; adata should not be modified afterwards.
ContrastiveVIModel.setup_anndata(adata, layer="count")

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"count"[0m[1m][0m                                               
[34mINFO    [0m Successfully registered anndata object containing [1;36m16856[0m cells, [1;36m2000[0m vars, [1;36m1[0m batches,
         [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          


In [5]:
# Target cells are every cell whose split label differs from the background label.
is_target = adata.obs[split_key] != background_value
target_indices = np.flatnonzero(is_target)
target_adata = adata[target_indices]

In [6]:
# Gene ID -> symbol table shipped with one 10x sample; column 0 is the ID matched
# against adata.var.index, column 1 is the gene symbol.
genes_path = os.path.join(
    constants.DEFAULT_DATA_PATH,
    dataset,
    "aml027_post_transplant_filtered_gene_bc_matrices",
    "filtered_matrices_mex/hg19",
    "genes.tsv",
)
genes = (
    pd.read_table(genes_path, header=None)
    .rename(columns={0: "ensembl_id", 1: "gene_symbol"})
    # Keep only genes present in the preprocessed AnnData.
    .loc[lambda table: table["ensembl_id"].isin(adata.var.index)]
)

In [7]:
seeds = [123, 42, 789, 46, 999]
latent_size = 10


def _run_file(seed, filename):
    """Path of an artifact saved by the contrastiveVI run trained with ``seed``."""
    return os.path.join(
        constants.DEFAULT_RESULTS_PATH,
        f"{dataset}/contrastiveVI/latent_{latent_size}/{seed}/{filename}",
    )


# NOTE(review): torch.load unpickles the whole model object — only load trusted checkpoints.
model_list = [
    torch.load(_run_file(seed, "model.ckpt"), map_location="cpu")
    for seed in tqdm(seeds)
]
latent_rep_list = [
    np.load(_run_file(seed, "latent_representations.npy"))
    for seed in tqdm(seeds)
]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.55s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 707.42it/s]


In [8]:
def get_normalized_expression(
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    transform_batch: Optional[Sequence[Union[Number, str]]] = None,
    gene_list: Optional[Sequence[str]] = None,
    library_size: Union[float, Literal["latent"]] = 1,
    n_samples: int = 1,
    n_samples_overall: Optional[int] = None,
    batch_size: Optional[int] = None,
    return_mean: bool = True,
    return_numpy: Optional[bool] = None,
):
    """Return only the ``"salient"`` entry of ``model.get_normalized_expression``.

    Used as ``model_fn`` for ``_de_core``: every argument is forwarded unchanged.

    Note: ``model`` is read from the notebook's global scope at call time (the
    loop below rebinds it per seed), so a ``model_fn`` built from this helper
    must be consumed within the same loop iteration it was created in.
    """
    exprs = model.get_normalized_expression(
        adata=adata,
        indices=indices,
        transform_batch=transform_batch,
        gene_list=gene_list,
        library_size=library_size,
        n_samples=n_samples,
        n_samples_overall=n_samples_overall,
        batch_size=batch_size,
        return_mean=return_mean,
        return_numpy=return_numpy,
    )
    return exprs["salient"]


# For each seed: split target cells into two KMeans clusters in that seed's
# latent space, run DE between the clusters on the salient expression, and
# attach gene symbols to the resulting table.
result_list = []
for seed_index, seed in enumerate(seeds):
    model = model_list[seed_index]
    latent_rep = latent_rep_list[seed_index]
    cluster_key = f"cluster_{seed}"

    # Rebinds target_adata to a validated copy; cluster columns from earlier
    # seeds therefore accumulate in target_adata.obs (each DE uses its own key).
    target_adata = model._validate_anndata(target_adata.copy())
    col_names = _get_var_names_from_setup_anndata(target_adata)

    # Fixed random_state so the 2-way split is reproducible for a given latent_rep.
    latent_clusters = KMeans(n_clusters=2, random_state=123).fit(latent_rep).labels_
    target_adata.obs[cluster_key] = latent_clusters.astype(str)

    model_fn = partial(
        get_normalized_expression,
        return_numpy=True,
        n_samples=100,
        batch_size=128,
    )
    result = _de_core(
        target_adata,
        model_fn,
        groupby=cluster_key,
        group1="0",
        group2="1",
        idx1=None,
        idx2=None,
        all_stats=True,
        all_stats_fn=scrna_raw_counts_properties,
        col_names=col_names,
        mode="change",
        batchid1=None,
        batchid2=None,
        delta=0.25,
        batch_correction=False,
        fdr=0.05,
        silent=False,
    )
    # The DE table's index holds the gene identifiers; expose it as a column
    # so it can be joined with the gene-symbol table.
    result["ensembl_id"] = result.index
    result = result.merge(genes, on="ensembl_id")
    result_list.append(result)

DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:10<00:00, 70.64s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:49<00:00, 49.16s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:52<00:00, 52.27s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:45<00:00, 45.77s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:51<00:00, 51.55s/it]


In [9]:
# Tag each per-seed DE table with the seed that produced it, then stack them into one frame.
for idx in range(len(seeds)):
    result_list[idx]["seed"] = seeds[idx]
all_result = pd.concat(result_list)

In [10]:
# Average the per-seed DE statistics for each gene and rank genes by the
# seed-averaged probability of differential expression.
# numeric_only=True: all_result contains non-numeric columns (e.g. ensembl_id);
# pandas >= 2.0 raises TypeError on them instead of silently dropping them.
# NOTE(review): KMeans cluster ids are arbitrary per seed, so cluster "0" in one
# seed may correspond to cluster "1" in another; signed statistics (e.g. LFCs)
# may cancel when averaged, and the averaged "seed" column has no meaning.
all_result_mean = (
    all_result.groupby("gene_symbol", as_index=False)
    .mean(numeric_only=True)
    .sort_values(by="proba_de", ascending=False)
)

In [11]:
# Genes whose seed-averaged probability of differential expression exceeds 0.95.
is_top = all_result_mean["proba_de"] > 0.95
top_genes = all_result_mean.loc[is_top, "gene_symbol"].tolist()

# KEGG pathway enrichment of the top genes via Enrichr (a web service, so this needs network access).
enr = gp.enrichr(
    gene_list=top_genes,
    gene_sets="KEGG_2016",
    organism="human",
    cutoff=0.05,
)

# Keep pathways with an adjusted p-value below 0.1.
enr_results = enr.results.loc[lambda table: table["Adjusted P-value"] < 0.1]

In [12]:
# Show only the columns needed to read each enriched pathway.
cols = ["Gene_set", "Term", "Adjusted P-value", "Overlap", "Genes"]
enr_results.loc[:, cols]

Unnamed: 0,Gene_set,Term,Adjusted P-value,Overlap,Genes
0,KEGG_2016,Hematopoietic cell lineage Homo sapiens hsa04640,9.638086e-10,23/88,HLA-DRB5;CSF1;FLT3;ITGA2B;DNTT;GP1BA;TNF;CD3D;...
1,KEGG_2016,Asthma Homo sapiens hsa05310,3.238552e-08,13/31,IL10;HLA-DRB5;FCER1G;PRG2;RNASE3;TNF;HLA-DMB;H...
2,KEGG_2016,Systemic lupus erythematosus Homo sapiens hsa0...,1.182463e-05,22/135,IL10;C1QA;HIST1H2BM;HLA-DRB5;HIST1H3J;HIST1H4L...
3,KEGG_2016,Antigen processing and presentation Homo sapie...,8.46894e-05,15/77,CD74;HLA-DRB5;HSPA5;HSPA6;IFI30;TNF;CTSS;HLA-D...
4,KEGG_2016,Type I diabetes mellitus Homo sapiens hsa04940,0.000104934,11/43,HLA-DRB5;HLA-DMB;GAD1;IL1B;HLA-DPB1;GZMB;HLA-D...
5,KEGG_2016,Allograft rejection Homo sapiens hsa05330,0.0001931001,10/38,IL10;HLA-DRB5;HLA-DMB;HLA-DPB1;GZMB;HLA-DRA;TN...
6,KEGG_2016,Graft-versus-host disease Homo sapiens hsa05332,0.0003196112,10/41,HLA-DRB5;HLA-DMB;IL1B;HLA-DPB1;GZMB;HLA-DRA;TN...
7,KEGG_2016,Rheumatoid arthritis Homo sapiens hsa05323,0.0003196112,15/90,HLA-DRB5;CSF1;CCL3L1;TNF;HLA-DMB;IL1B;CCL5;CCL...
8,KEGG_2016,Leishmaniasis Homo sapiens hsa05140,0.000484236,13/73,IL10;HLA-DRB5;NCF2;PTGS2;TNF;NFKBIA;HLA-DMB;IL...
9,KEGG_2016,Cell adhesion molecules (CAMs) Homo sapiens hs...,0.000484236,19/142,CD274;HLA-DRB5;CD2;CLDN10;HLA-DMB;CD8B;SELL;HL...


In [13]:
# Number of genes passing the proba_de threshold and sent to Enrichr.
n_top_genes = len(top_genes)
n_top_genes

895