In [1]:
import os
import torch
import pandas as pd
import numpy as np
import scanpy as sc
import gseapy as gp
import seaborn as sns

from tqdm import tqdm
from functools import partial
from sklearn.cluster import KMeans
from scvi.model.base._utils import _de_core
from scvi.model._utils import _get_var_names_from_setup_anndata,  scrna_raw_counts_properties
from scripts import constants
from contrastive_vi.model.contrastive_vi import ContrastiveVIModel
from scvi.model import SCVI

from anndata import AnnData
from typing import Dict, Iterable, Optional, Sequence, Union
from scvi._compat import Literal

# Scalar numeric alias used in the type hints of the expression wrapper below.
Number = Union[int, float]

Global seed set to 0
1: package ‘methods’ was built under R version 3.6.1 
2: package ‘datasets’ was built under R version 3.6.1 
3: package ‘utils’ was built under R version 3.6.1 
4: package ‘grDevices’ was built under R version 3.6.1 
5: package ‘graphics’ was built under R version 3.6.1 
6: package ‘stats’ was built under R version 3.6.1 


In [2]:
# Dataset to analyse; its split settings come from the project-wide lookup.
dataset = "zheng_2017"

# Fetch the per-dataset entry once, then pull out the obs column name and the
# value in that column that marks background cells.
split_config = constants.DATASET_SPLIT_LOOKUP[dataset]
split_key = split_config["split_key"]
background_value = split_config["background_value"]

In [3]:
# Read the preprocessed AnnData file for this dataset.
adata_file = os.path.join(
    constants.DEFAULT_DATA_PATH,
    f"{dataset}/preprocessed/adata_top_2000_genes.h5ad",
)
adata = sc.read_h5ad(adata_file)

# Register the "count" layer with scVI.
SCVI.setup_anndata(adata, layer="count")

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"count"[0m[1m][0m                                               
[34mINFO    [0m Successfully registered anndata object containing [1;36m16856[0m cells, [1;36m2000[0m vars, [1;36m1[0m batches,
         [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          


In [4]:
# Target cells are all cells whose split column differs from the background value.
is_target_cell = adata.obs[split_key] != background_value
target_indices = np.where(is_target_cell)[0]
target_adata = adata[target_indices]

In [5]:
# Gene annotation table: column 0 -> Ensembl ID, column 1 -> gene symbol.
genes_file = os.path.join(
    constants.DEFAULT_DATA_PATH,
    dataset,
    "aml027_post_transplant_filtered_gene_bc_matrices",
    "filtered_matrices_mex/hg19",
    "genes.tsv",
)
genes = pd.read_table(genes_file, header=None).rename(
    columns={0: "ensembl_id", 1: "gene_symbol"}
)

# Keep only the genes that are present in the loaded AnnData's var index.
in_adata = genes["ensembl_id"].isin(adata.var.index)
genes = genes[in_adata]

In [6]:
seeds = [123, 42, 789, 46, 999]
latent_size = 10


def _seed_artifact_path(seed, file_name):
    """Return the path of one saved scVI result file for the given seed."""
    return os.path.join(
        constants.DEFAULT_RESULTS_PATH,
        f"{dataset}/scVI/latent_{latent_size}/{seed}/{file_name}",
    )


# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
model_list = [
    torch.load(_seed_artifact_path(seed, "model.ckpt"), map_location="cpu")
    for seed in tqdm(seeds)
]

# One saved latent-representation array per seed, in the same order as `seeds`.
latent_rep_list = [
    np.load(_seed_artifact_path(seed, "latent_representations.npy"))
    for seed in tqdm(seeds)
]

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:51<00:00, 10.35s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.45it/s]


In [7]:
# Differential expression between the two latent-space clusters, run once per
# trained seed. Each per-seed table is annotated with gene symbols and
# collected in `result_list` (same order as `seeds`).
result_list = []
for seed, model, latent_rep in zip(seeds, model_list, latent_rep_list):
    cluster_key = f"cluster_{seed}"

    # NOTE: `target_adata` is rebound on every iteration, so the cluster
    # columns of earlier seeds accumulate in `.obs`.
    target_adata = model._validate_anndata(target_adata.copy())
    col_names = _get_var_names_from_setup_anndata(target_adata)

    # Split the cells into two groups by clustering the latent representation.
    # NOTE(review): KMeans label ids are arbitrary, so cluster "0" is not
    # guaranteed to be the same population for every seed -- confirm before
    # comparing signed statistics (e.g. log fold changes) across seeds.
    latent_clusters = KMeans(n_clusters=2, random_state=123).fit(latent_rep).labels_
    target_adata.obs[cluster_key] = latent_clusters.astype(str)

    def get_normalized_expression(
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        transform_batch: Optional[Sequence[Union[Number, str]]] = None,
        gene_list: Optional[Sequence[str]] = None,
        library_size: Union[float, Literal["latent"]] = 1,
        n_samples: int = 1,
        n_samples_overall: int = None,
        batch_size: Optional[int] = None,
        return_mean: bool = True,
        return_numpy: Optional[bool] = None,
    ):
        """Forward every argument unchanged to the current seed's model."""
        return model.get_normalized_expression(
            adata=adata,
            indices=indices,
            transform_batch=transform_batch,
            gene_list=gene_list,
            library_size=library_size,
            n_samples=n_samples,
            n_samples_overall=n_samples_overall,
            batch_size=batch_size,
            return_mean=return_mean,
            return_numpy=return_numpy,
        )

    # Pin the sampling options used for the DE computation.
    model_fn = partial(
        get_normalized_expression,
        return_numpy=True,
        n_samples=100,
        batch_size=128,
    )
    result = _de_core(
        target_adata,
        model_fn,
        groupby=cluster_key,
        group1="0",
        group2="1",
        idx1=None,
        idx2=None,
        all_stats=True,
        all_stats_fn=scrna_raw_counts_properties,
        col_names=col_names,
        mode="change",
        batchid1=None,
        batchid2=None,
        delta=0.25,
        batch_correction=False,
        fdr=0.05,
        silent=False,
    )
    # The original bare `result.reset_index()` call discarded its return value
    # and therefore did nothing; it has been removed. The index is copied into
    # a regular column so the table can be merged with the gene annotations.
    # (presumably the index holds Ensembl IDs taken from `col_names` -- verify)
    result["ensembl_id"] = result.index
    result = result.merge(genes, on="ensembl_id")
    result_list.append(result)

DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.95s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.58s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:30<00:00, 30.49s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:31<00:00, 31.33s/it]
DE...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00, 26.15s/it]


In [8]:
# Record on every per-seed result table which seed produced it.
for seed_index in range(len(seeds)):
    seed = seeds[seed_index]
    result_list[seed_index]["seed"] = seed

In [9]:
# Stack the per-seed result tables into one long frame (row labels are not
# reset, so they may repeat across seeds).
all_result = pd.concat(result_list)

In [10]:
# Average every numeric statistic per gene symbol across all seeds, then rank
# genes by their mean probability of differential expression.
# `numeric_only=True` is required because the frame also carries string
# columns (e.g. "ensembl_id"): pandas >= 2.0 raises a TypeError when asked to
# average them instead of silently dropping them as older versions did.
# NOTE: any numeric bookkeeping column present (such as "seed") is averaged too.
all_result_mean = (
    all_result.groupby("gene_symbol", as_index=False)
    .mean(numeric_only=True)
    .sort_values(by="proba_de", ascending=False)
)

In [11]:
# Gene symbols whose seed-averaged DE probability exceeds 0.95.
is_confident_de = all_result_mean["proba_de"] > 0.95
top_genes = all_result_mean.loc[is_confident_de, "gene_symbol"].tolist()

In [12]:
# Run an Enrichr query for the selected genes against the KEGG 2016 library.
enrichr_options = dict(gene_sets="KEGG_2016", organism="human", cutoff=0.05)
enr = gp.enrichr(gene_list=top_genes, **enrichr_options)

In [13]:
# Keep only the enrichment terms whose adjusted p-value is below 0.1.
raw_enrichment = enr.results
enr_results = raw_enrichment[raw_enrichment["Adjusted P-value"] < 0.1]

In [14]:
# Columns to show in the enrichment summary table.
cols = [
    "Gene_set",
    "Term",
    "Adjusted P-value",
    "Overlap",
    "Genes",
]
enr_results.loc[:, cols]

Unnamed: 0,Gene_set,Term,Adjusted P-value,Overlap,Genes
0,KEGG_2016,Asthma Homo sapiens hsa05310,1.624066e-11,16/31,IL10;HLA-DRB5;FCER1G;PRG2;RNASE3;TNF;HLA-DMA;H...
1,KEGG_2016,Hematopoietic cell lineage Homo sapiens hsa04640,7.56816e-09,22/88,GYPA;HLA-DRB5;CSF1;DNTT;GP1BA;TNF;CD3D;CD2;FCE...
2,KEGG_2016,Type I diabetes mellitus Homo sapiens hsa04940,3.395361e-08,15/43,HLA-DRB5;GAD1;ICA1;GZMB;TNF;HLA-DMA;HLA-DMB;IL...
3,KEGG_2016,Graft-versus-host disease Homo sapiens hsa05332,1.358825e-07,14/41,HLA-DRB5;GZMB;TNF;HLA-DMA;HLA-DMB;IL1B;HLA-DPB...
4,KEGG_2016,Leishmaniasis Homo sapiens hsa05140,2.124204e-07,18/73,IL10;JUN;MARCKSL1;HLA-DRB5;NCF2;PTGS2;TNF;NFKB...
5,KEGG_2016,Allograft rejection Homo sapiens hsa05330,3.332964e-07,13/38,IL10;HLA-DRB5;GZMB;TNF;HLA-DMA;HLA-DMB;HLA-DPB...
6,KEGG_2016,Antigen processing and presentation Homo sapie...,3.78559e-07,18/77,CD74;HLA-DRB5;HSPA6;IFI30;TNF;CTSS;HLA-DMA;HLA...
7,KEGG_2016,Rheumatoid arthritis Homo sapiens hsa05323,4.26712e-06,18/90,JUN;HLA-DRB5;CSF1;CCL3L1;TNF;HLA-DMA;HLA-DMB;I...
8,KEGG_2016,Influenza A Homo sapiens hsa05164,4.26712e-06,26/175,TNF;ACTG1;PIK3R5;CASP9;SOCS3;HLA-DMA;HLA-DMB;C...
9,KEGG_2016,Staphylococcus aureus infection Homo sapiens h...,4.672185e-06,14/56,IL10;CFD;C1QA;HLA-DRB5;FPR1;HLA-DMA;HLA-DMB;HL...


In [15]:
# Number of gene symbols that passed the DE-probability threshold.
len(top_genes)

928