In [None]:
# Use scVI to integrate and batch correct

# Setup environment
# NOTE: everything below is shell, not Python. Run it in a terminal before
# starting the notebook kernel.

# Remove conda env if it exists
conda deactivate
conda remove -n scvi-env --all

# Create conda env
conda create -n scvi-env python=3.10
conda activate scvi-env

# Install Pytorch
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

# Install Jax
conda install jax -c conda-forge

# Install scvi-tools
# conda install scvi-tools -c conda-forge
pip install -U scvi-tools

# Check that it worked (print the result, otherwise the CUDA check shows nothing)
python -c "import torch; print(torch.cuda.is_available())"
python -c "import scvi"

# Install scanpy (package names are space-separated; a comma makes pip fail)
pip install scanpy ipython



In [None]:
import os
import scvi
import scanpy as sc
import torch
import anndata
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import harmonypy
import pickle
import numpy as np
import matplotlib as mpl
import matplotlib.font_manager
from matplotlib import font_manager
from matplotlib.font_manager import fontManager, FontProperties
import infercnvpy as cnv


def setup_dirs(outDir):
    """Create the standard output sub-directories under ``outDir``.

    Ensures ``figures``, ``data`` and ``tables`` exist below ``outDir``
    (missing parents are created, existing directories are left alone).

    Returns:
        tuple: ``(figures_path, data_path, tables_path)``, in that order.
    """
    subdir_paths = [os.path.join(outDir, leaf) for leaf in ('figures', 'data', 'tables')]
    for path in subdir_paths:
        os.makedirs(path, exist_ok=True)
    figures_path, data_path, tables_path = subdir_paths
    return figures_path, data_path, tables_path

def force_arial(arial_font_path='/home/salehis/projects/cdm/fonts/arial.ttf'):
    """Register an Arial font file with matplotlib and make it the plotting font.

    Args:
        arial_font_path: Path to an Arial ``.ttf`` file. Defaults to the
            location the original code hard-coded.

    Side effects:
        Adds the font to matplotlib's font manager and applies it through
        ``sns.set(font=...)``, the same call ``find_arial_font`` uses.
        ``sns.set`` also re-applies seaborn's default theme, so call this
        before any other rcParams customisation that must be kept.
    """
    font_manager.fontManager.addfont(arial_font_path)
    prop = font_manager.FontProperties(fname=arial_font_path)
    # The original built `prop` but never used it, so nothing was "forced".
    sns.set(font=prop.get_name())
    print("Arial font forced")

# set the font
def find_arial_font():
    """Use Arial for plots, falling back to ``force_arial()`` when it is missing.

    Scans the font files reported by ``font_manager.findSystemFonts()`` and
    takes the first one whose path contains "arial" (case-insensitive; this is
    looser than requiring the file to be named ``arial.ttf``). If one is found
    it is applied with ``sns.set(font=...)``; otherwise ``force_arial()`` is
    called.

    Note: ``sns.set`` also re-applies seaborn's default theme, not only the
    font family.
    """
    arial_font = None
    for font in font_manager.findSystemFonts():
        if "arial" in font.lower():
            arial_font = font
            break

    if arial_font is None:
        print("Arial font not found")
        force_arial()
        return

    # This used to sit inside the loop after the `break`, where it could never
    # run, so a found font was silently ignored.
    print("Found Arial font at: ", arial_font)
    prop = font_manager.FontProperties(fname=arial_font)
    sns.set(font=prop.get_name())

def add_gene_binary_status(adata, gene, threshold=np.log1p(1), use_counts=False):
    """
    Flag each cell as expressing / not expressing ``gene``.

    A very simple cut-off: a cell is "expressed" when its value is strictly
    greater than ``threshold``. The default, ``np.log1p(1)``, corresponds to
    more than 1 count on log1p-transformed data.

    Two columns are (re)written in ``adata.obs``:
      * ``{gene}_EXPR``          - the per-cell value that was thresholded
      * ``{gene}_is_expressed``  - categorical True/False flag

    Args:
        adata: AnnData-like object; modified in place.
        gene: Gene name; must be present in ``adata.var_names``.
        threshold: Values strictly above this count as expressed.
        use_counts: If True read from ``adata.layers['counts']``, otherwise
            from ``adata.X``.

    Returns:
        The same ``adata`` object, for call chaining.

    Raises:
        AssertionError: if ``gene`` is not in ``adata.var_names``.
        ValueError: if ``gene`` does not select exactly one column.
    """
    assert gene in adata.var_names, f"Gene {gene} not in adata.var_names..."

    expr_col = f'{gene}_EXPR'
    status_col = f'{gene}_is_expressed'

    # Drop columns left over from a previous call so the result is always fresh.
    stale = [col for col in (status_col, expr_col) if col in adata.obs.columns]
    if stale:
        adata.obs.drop(columns=stale, inplace=True)

    if use_counts:
        values = adata.layers['counts'][:, (adata.var_names == gene)]
    else:
        values = adata[:, gene].X

    # The matrix may be SciPy-sparse or a dense ndarray. `.toarray()` exists only
    # on the former, and the old `.A` shortcut has been deprecated and removed in
    # recent SciPy releases, so densify explicitly.
    if hasattr(values, 'toarray'):
        values = values.toarray()
    values = np.asarray(values)
    values = values.reshape(values.shape[0], -1)
    if values.shape[1] != 1:
        raise ValueError(
            f"Expected exactly one column for gene {gene}, got {values.shape[1]}"
        )

    # Store a flat 1-D vector rather than an (n_obs, 1) array.
    adata.obs[expr_col] = values[:, 0]
    adata.obs[status_col] = (adata.obs[expr_col] > threshold).astype('category')
    return adata


In [None]:
# Root output directory for this analysis; figures/, data/ and tables/ are created below it.
outDir = '/data1/shahs3/users/salehis/sclc/results/rebuttal/nat_comms/integration_scvi'
figuresDir, dataDir, tablesDir = setup_dirs(outDir)

# Send scanpy's saved figures to figuresDir.
sc.settings.figdir = figuresDir
sc.set_figure_params(dpi_save=300, vector_friendly=True)
find_arial_font()

# The rsync command below is shell, not Python; left bare it makes the whole cell
# fail with a SyntaxError, so it is kept here as a note only.
# It appears to be meant for a local terminal, to pull the saved figures from the
# remote host (TODO confirm):
#
# rsync -azvp --relative \
#     iris:/data1/shahs3/users/salehis/sclc/./results//rebuttal/nat_comms/integration_scvi/figures/*.p* \
#     /Users/salehis/Projects/sclc/rebuttal_code/SCLC_MET/

In [None]:
# Load anndata
n_top_genes = 2000  # in this cell only used to label the output files
adata_path = '/data1/shahs3/users/salehis/sclc/results/patient_met/foxa2_umaps_19/data/rna_19_2K_harmony.h5ad'
adata = sc.read_h5ad(adata_path)
# Put the stored counts layer back into X; the assert below guards that it is counts.
adata.X = adata.layers["counts"].copy()
assert adata.X.max() > 100 and adata.X.min() >= 0, "Has to be counts!"
# Flag gene groups by name pattern so per-cell QC percentages can be computed.
adata.var["mito"] = adata.var_names.str.startswith("MT-")
adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
adata.var["hb"] = adata.var_names.str.contains(("^HB[^(P)]"))
sc.pp.calculate_qc_metrics(adata, qc_vars=["mito", "ribo", "hb"], inplace=True)

# Run SCVI
# The setting takes ONE string. The original `'medium' | 'high'` raised a TypeError
# because `|` is not defined for two str operands.
torch.set_float32_matmul_precision('high')
scvi.model.SCVI.setup_anndata(adata, categorical_covariate_keys=['sample'], continuous_covariate_keys=["pct_counts_mito", "pct_counts_ribo"],)
model_dir = os.path.join(dataDir, f"scvi_integrated_{n_top_genes}")
model = scvi.model.SCVI(adata, n_latent=32)
model.train()
model.save(model_dir, overwrite=True)

# Load the model
# model = torch.load(model_dir)
model = scvi.model.SCVI.load(model_dir, adata=adata)

SCVI_LATENT_KEY = "X_scVI"
latent = model.get_latent_representation()
adata.obsm[SCVI_LATENT_KEY] = latent
SCVI_NORMALIZED_KEY = "scvi_normalized"

adata.layers[SCVI_NORMALIZED_KEY] = model.get_normalized_expression(library_size=1e4)
# First write: saved before the neighbors / UMAP steps below are run.
adata.write(os.path.join(dataDir, f"sub_adata_{n_top_genes}_scvi.h5ad"))
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata)
sc.pl.umap(adata, color="sample", save=f'umap_with_batch_{n_top_genes}.pdf')
# Second write: same path, now including the neighbors graph and UMAP.
adata.write(os.path.join(dataDir, f"sub_adata_{n_top_genes}_scvi.h5ad"))


In [None]:
# Genes for which an expressed / not-expressed UMAP density plot is produced.
main_genes = ['FOXA2']
for gene in main_genes:
    # Use the loop variable (the original hard-coded 'FOXA2', so adding genes to
    # main_genes would not have worked). threshold=0 on counts: any count > 0 is expressed.
    adata = add_gene_binary_status(adata, gene, threshold=0, use_counts=True)
    # String copy of the flag, used as the groupby column for the density estimate.
    adata.obs[f'{gene}_is_expressed_str'] = adata.obs[f'{gene}_is_expressed'].astype(str)
    sc.tl.embedding_density(adata, basis='umap', groupby=f'{gene}_is_expressed_str')
    # ".pdf" needs its dot: "...densitypdf" is not recognised as a file extension.
    sc.pl.embedding_density(adata, basis='umap', key=f'umap_density_{gene}_is_expressed_str', save=f"{gene}_expr_umap_density.pdf")
