### Process outline

1. Follow the scVI tutorial: https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/harmonization.html

In [30]:
import os

import scanpy as sc
import scvi
import seaborn as sns
import torch
from rich import print
from scib_metrics.benchmark import Benchmarker

import warnings
warnings.filterwarnings("ignore")

In [31]:
scvi.settings.seed = 0 # for reproducibility
print("Last run with scvi-tools version:", scvi.__version__)

[rank: 0] Seed set to 0


In [32]:
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

In [33]:
device_count = torch.cuda.device_count()
device_count

1

In [34]:
print(torch.cuda.is_available())

In [35]:
# Load all AnnData objects into a list

from pathlib import Path
from itertools import chain

GSE132509_directory = Path('/QRISdata/Q6104/Xiaohan/2_AnnData_objs/GSE132509')
GSE236351_directory = Path('/QRISdata/Q6104/Xiaohan/2_AnnData_objs/GSE236351')
GSE148218_directory = Path('/QRISdata/Q6104/Xiaohan/2_AnnData_objs/GSE148218')

combined_dirs = chain(GSE132509_directory.iterdir(), GSE236351_directory.iterdir(), GSE148218_directory.iterdir())
adatas = []
for adata_path in combined_dirs:
    if "_uni.h5ad" in adata_path.name:
        adata = sc.read_h5ad(adata_path)
        adatas.append(adata)

print(len(adatas))
print(adatas[0])

In [36]:
# Find out common genes among all AnnData objects
common_genes = set(adatas[0].var_names)
for adata in adatas[1:]:
    common_genes.intersection_update(adata.var_names)

print(len(common_genes))

In [37]:
# Filter all AnnData objects with common genes
adatas_common_genes = []
for adata in adatas:
    adata_common_genes = adata[:, list(common_genes)]
    # print(adata_common_genes.shape)
    adatas_common_genes.append(adata_common_genes)

### <span style="color:yellow">**Preprocessing:**</span> normalization & log transformation

Follow the scanpy preprocessing tutorial: https://scanpy-tutorials.readthedocs.io/en/latest/pbmc3k.html

Use the preprocessing package from dandelion to filter out cell and gene outliers

In [17]:
from dandelion.preprocessing.external._preprocessing import recipe_scanpy_qc

adatas_filtered = [] 

for adata in adatas_common_genes:
    adata.raw = adata

    # Do QC and filtering
    recipe_scanpy_qc(adata)
    adata = adata[adata.obs.filter_rna == 'False', :]

    # Do normalization
    sc.pp.normalize_total(adata)

    # Do the log transformation
    sc.pp.log1p(adata)

    adatas_filtered.append(adata)

In [18]:
# Create a merged AnnData for all filtered Anndata objects
adatas_filtered_all = sc.AnnData.concatenate(*adatas_filtered)

In [19]:
adatas_filtered_all

AnnData object with n_obs × n_vars = 85282 × 14071
    obs: 'cancer_type', 'dataset', 'tissue', 'sample_barcode', 'uni_barcode', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'gmm_pct_count_clusters_keep', 'scrublet_score', 'is_doublet', 'filter_rna', 'batch'
    var: 'gene_ids-0', 'gene_ids-1', 'gene_ids-10', 'gene_ids-11', 'feature_types-11', 'gene_ids-12', 'feature_types-12', 'gene_ids-13', 'feature_types-13', 'gene_ids-14', 'feature_types-14', 'gene_ids-15', 'feature_types-15', 'gene_ids-16', 'feature_types-16', 'gene_ids-17', 'feature_types-17', 'gene_ids-18', 'feature_types-18', 'gene_ids-19', 'feature_types-19', 'gene_ids-2', 'gene_ids-20', 'feature_types-20', 'gene_ids-21', 'feature_types-21', 'gene_ids-22', 'feature_types-22', 'gene_ids-23', 'feature_types-23', 'gene_ids-24', 'feature_types-24', 'gene_ids-25', 'feature_types-25', 'gene_ids-3', 'gene_ids-4', 'gene_ids-5', 'gene_ids-6', 'gene_ids-7', 'gene_ids-8', 'gene_ids-9'

In [21]:
# Select highly variable genes
sc.pp.highly_variable_genes(
    adatas_filtered_all,
    flavor="seurat_v3",
    n_top_genes=2000,
    batch_key="sample_barcode",
    subset=True,
)

In [22]:
adatas_filtered_all

AnnData object with n_obs × n_vars = 85282 × 2000
    obs: 'cancer_type', 'dataset', 'tissue', 'sample_barcode', 'uni_barcode', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'gmm_pct_count_clusters_keep', 'scrublet_score', 'is_doublet', 'filter_rna', 'batch'
    var: 'gene_ids-0', 'gene_ids-1', 'gene_ids-10', 'gene_ids-11', 'feature_types-11', 'gene_ids-12', 'feature_types-12', 'gene_ids-13', 'feature_types-13', 'gene_ids-14', 'feature_types-14', 'gene_ids-15', 'feature_types-15', 'gene_ids-16', 'feature_types-16', 'gene_ids-17', 'feature_types-17', 'gene_ids-18', 'feature_types-18', 'gene_ids-19', 'feature_types-19', 'gene_ids-2', 'gene_ids-20', 'feature_types-20', 'gene_ids-21', 'feature_types-21', 'gene_ids-22', 'feature_types-22', 'gene_ids-23', 'feature_types-23', 'gene_ids-24', 'feature_types-24', 'gene_ids-25', 'feature_types-25', 'gene_ids-3', 'gene_ids-4', 'gene_ids-5', 'gene_ids-6', 'gene_ids-7', 'gene_ids-8', 'gene_ids-9', 'highly_variab

### <span style="color:yellow">**Integration with scVI**</span> 

In [23]:
scvi.model.SCVI.setup_anndata(adatas_filtered_all, batch_key="sample_barcode")

In [None]:
# Visualize the data before integration
sc.tl.pca(adatas_filtered_hvg_all) # Calculate the PCA embeddings
sc.pp.neighbors(adatas_filtered_hvg_all) # Determine the kNN graph
sc.tl.umap(adatas_filtered_hvg_all) # Calculate the UMAP

In [None]:
sc.pl.umap(adatas_filtered_hvg_all, color=['dataset'])
sc.pl.umap(adatas_filtered_hvg_all, color=['sample_barcode'])

In [24]:
model = scvi.model.SCVI(adatas_filtered_all, n_layers=2, n_latent=30, gene_likelihood="nb")

In [25]:
# Train the scVI model
model.train()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 94/94: 100%|███████████████████████| 94/94 [04:58<00:00,  3.15s/it, v_num=1, train_loss_step=320, train_loss_epoch=330]

`Trainer.fit` stopped: `max_epochs=94` reached.


Epoch 94/94: 100%|███████████████████████| 94/94 [04:58<00:00,  3.18s/it, v_num=1, train_loss_step=320, train_loss_epoch=330]


In [26]:
SCVI_LATENT_KEY = "X_scVI"
adatas_filtered_all.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()

In [27]:
sc.pp.neighbors(adatas_filtered_all, use_rep=SCVI_LATENT_KEY)
sc.tl.leiden(adatas_filtered_all)

In [29]:
SCVI_MDE_KEY = "X_scVI_MDE"
adatas_filtered_all.obsm[SCVI_MDE_KEY] = scvi.model.utils.mde(adatas_filtered_all.obsm[SCVI_LATENT_KEY])

[34mINFO    [0m Using cu[1;92mda:0[0m for `pymde.preserve_neighbors`.                                                              


SolverError: Function evaluation returned inf.

In [None]:
# Do the UMAP to visualize the integration results
sc.pp.neighbors(adata_integrated, use_rep='X_scanorama')
sc.tl.umap(adata_integrated)

In [None]:
sc.pl.umap(adata_integrated, color=['dataset'])
sc.pl.umap(adata_integrated, color=['sample_barcode'])

### <span style="color:yellow">**Attempt 3:**</span> concatenate all AnnData objects and find highly variable genes together and scale individually

In [None]:
# The preprocessing is the same as Attempt 2,
# but before integration, we scale the gene expression of the meta AnnData

sc.pp.scale(adatas_filtered_hvg_all, max_value=10)

# Visualize the data before integration
sc.tl.pca(adatas_filtered_hvg_all) # Calculate the PCA embeddings
sc.pp.neighbors(adatas_filtered_hvg_all) # Determine the kNN graph
sc.tl.umap(adatas_filtered_hvg_all) # Calculate the UMAP

In [None]:
sc.pl.umap(adatas_filtered_hvg_all, color=['dataset'])
sc.pl.umap(adatas_filtered_hvg_all, color=['sample_barcode'])

In [None]:
# Split the meta AnnData
adatas_filtered_hvg_scaled = []

for batch in adatas_filtered_hvg_all.obs['batch'].unique():
    adatas_filtered_hvg_scaled.append(adatas_filtered_hvg_all[adatas_filtered_hvg_all.obs['batch']==batch].copy())

print(len(adatas_filtered_hvg_scaled))
print(adatas_filtered_hvg_scaled[0])

In [None]:
# Now we run Scanorama on the split data.
import scanorama

corrected = scanorama.correct_scanpy(adatas_filtered_hvg_scaled, return_dimred=True)

# Concatenate the integrated AnnData objects
adata_integrated_scaled = sc.AnnData.concatenate(*corrected)
print(adata_integrated_scaled)

In [None]:
# Do the UMAP to visualize the integration results
sc.pp.neighbors(adata_integrated_scaled, use_rep='X_scanorama')
sc.tl.umap(adata_integrated_scaled)

In [None]:
sc.pl.umap(adata_integrated_scaled, color=['dataset'])
sc.pl.umap(adata_integrated_scaled, color=['sample_barcode'])