In [None]:
# Use scVI to integrate and batch correct

# Integrate malignant cells from multiple patients and plot FOXA2+ vs FOXA2- cells


## Description

Using scVI for batch correction. 
For each cell, designate FOXA2 expression status.
Plot density embedding of FOXA2 expression status.


## Procedure

- Import libraries
- Load anndata object of malignant cells
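- Compute QC metrics, train the scVI model, and compute the latent representation and UMAP
- Designate FOXA2 expression status per cell and plot the density embedding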


## Reference

Lopez, Romain, et al. "Deep generative modeling for single-cell transcriptomics." Nature methods 15.12 (2018): 1053-1058.

# Setup environment


// Remove conda env if it exists
conda deactivate
conda remove -n scvi-env --all

// Create conda env
conda create -n scvi-env python=3.10
conda activate scvi-env

// Install Pytorch
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia 

// Install Jax
conda install jax -c conda-forge

// Install scvi-tools
// conda install scvi-tools -c conda-forge
pip install -U scvi-tools

// Check that it worked
python -c "import torch; print(torch.cuda.is_available())"
python -c "import scvi"

// Install scanpy 
pip install scanpy ipython



In [None]:
import os
import scvi
import scanpy as sc
import torch
import anndata
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import harmonypy
import pickle
import numpy as np
import matplotlib as mpl
import matplotlib.font_manager
from matplotlib import font_manager
from matplotlib.font_manager import fontManager, FontProperties
import infercnvpy as cnv


from common_utils import setup_dirs, find_arial_font, add_gene_binary_status

In [None]:
# Genes whose expression status is analysed in the cells below.
main_genes = ['FOXA2']

# Root output location for this analysis; setup_dirs returns three paths that are
# unpacked as the figures, data and tables directories.
outDir = OUTDIR_FOR_SCVI_INTEGRATION
output_dirs = setup_dirs(outDir)
figuresDir, dataDir, tablesDir = output_dirs

# Route figures saved by scanpy into the figures directory and configure saving.
sc.settings.figdir = figuresDir
sc.set_figure_params(vector_friendly=True, dpi_save=300)

# Kept after set_figure_params, matching the original call order.
find_arial_font()

In [None]:
# Load anndata
# Read the AnnData object and use the raw counts layer as the working matrix; the assert
# below guards against accidentally passing normalized/log-transformed values.
n_top_genes = 2000  # in this cell only used to tag the output file/folder names
adata_path = ADATA_PATH_RNA_19_2K_HARMONY
adata = sc.read_h5ad(adata_path)
adata.X = adata.layers["counts"].copy()
assert adata.X.max() > 100 and adata.X.min() >= 0, "Has to be counts!"

# Flag mitochondrial, ribosomal and hemoglobin genes by name and compute QC metrics;
# the pct_counts_mito / pct_counts_ribo columns are used as scVI covariates below.
adata.var["mito"] = adata.var_names.str.startswith("MT-")
adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
adata.var["hb"] = adata.var_names.str.contains("^HB[^(P)]")
sc.pp.calculate_qc_metrics(adata, qc_vars=["mito", "ribo", "hb"], inplace=True)

# Run SCVI
# Fix the seed so that training (and therefore the latent space) is reproducible.
scvi.settings.seed = 0
# The precision must be passed as ONE string ("medium" or "high"); the expression
# 'medium' | 'high' copied from the Lightning hint raises TypeError on str | str.
torch.set_float32_matmul_precision("high")
scvi.model.SCVI.setup_anndata(
    adata,
    categorical_covariate_keys=['sample'],
    continuous_covariate_keys=["pct_counts_mito", "pct_counts_ribo"],
)
model_dir = os.path.join(dataDir, f"scvi_integrated_{n_top_genes}")
model = scvi.model.SCVI(adata, n_latent=32)
model.train()
model.save(model_dir, overwrite=True)

# Load the model
# Reloading from disk right after saving confirms the saved model can be restored.
model = scvi.model.SCVI.load(model_dir, adata=adata)

# Store the scVI latent representation and the scVI-normalized expression on the object.
SCVI_LATENT_KEY = "X_scVI"
SCVI_NORMALIZED_KEY = "scvi_normalized"
adata.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()
adata.layers[SCVI_NORMALIZED_KEY] = model.get_normalized_expression(library_size=1e4)

# Build the neighbor graph on the scVI latent space, embed with UMAP and plot by sample.
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata)
sc.pl.umap(adata, color="sample", save=f'umap_with_batch_{n_top_genes}.pdf')

# Write the integrated object once, after the UMAP has been computed; the earlier
# intermediate write targeted the same path and was overwritten by this one anyway.
integrated_adata_path = os.path.join(dataDir, f"sub_adata_{n_top_genes}_scvi.h5ad")
adata.write(integrated_adata_path)


In [None]:
# Annotate expression status once: the helper receives the full gene list, so calling it
# again on every loop iteration was redundant.
# NOTE(review): assumes add_gene_binary_status adds an obs column `{gene}_is_expressed`
# for every gene in main_genes — confirm in common_utils.
adata = add_gene_binary_status(adata, main_genes, threshold=0, use_counts=True)

for gene in main_genes:
    # embedding_density groups by a categorical/string column, so cast the status to str.
    status_key = f'{gene}_is_expressed_str'
    adata.obs[status_key] = adata.obs[f'{gene}_is_expressed'].astype(str)
    sc.tl.embedding_density(adata, basis='umap', groupby=status_key)
    # The density is stored under `umap_density_<groupby>`; the save suffix now ends in
    # ".pdf" (the original "densitypdf" was missing the dot before the extension).
    sc.pl.embedding_density(
        adata,
        basis='umap',
        key=f'umap_density_{status_key}',
        save=f"{gene}_expr_umap_density.pdf",
    )
