# Data integration

In [None]:
# Python packages
import scanpy as sc
import scvi
import bbknn
import scib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import os
import anndata as ad

# Work from the project folder so the relative "processed/..." paths used
# below resolve.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
os.chdir("/data/home/wx/scislets") 

# R interface
from rpy2.robjects import pandas2ri
from rpy2.robjects import r
import rpy2.rinterface_lib.callbacks
import anndata2ri

# Turn on the rpy2 pandas converter and the anndata2ri converter.
# NOTE(review): presumably needed for passing objects to R cells; no R cell is
# visible in this part of the notebook, so confirm they are still required.
pandas2ri.activate()
anndata2ri.activate()

%load_ext rpy2.ipython

In [None]:
# Load the three per-sample AnnData files from processed/
# (file names suggest these are the outputs of the quality-control step).
adata_CT = sc.read("processed/CT_quality_control.h5ad")
adata_ZP = sc.read("processed/ZP_quality_control.h5ad")
adata_ANS = sc.read("processed/ANS_quality_control.h5ad")

In [None]:
# Stack the three samples into a single AnnData; batch_categories supplies the
# per-sample labels (CT, ZP, ANS), given in the same order as the objects.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in
# favour of anndata.concat. Migrating would change how obs/var are merged, so
# the call is kept as is here.
adata = adata_CT.concatenate(
    adata_ZP,
    adata_ANS,
    batch_categories=["CT", "ZP", "ANS"],
)

In [None]:
# Save the concatenated object before any normalization is applied to it.
adata.write("processed/adata_raw.h5ad")

In [None]:
# A cell-type label key is not used in this notebook (kept for reference):
#label_key = "manual_celltype_annotation"
# Name of the obs column that identifies the sample; reused below for plotting,
# batch-aware gene selection and the scVI setup.
batch_key = "batch"

In [None]:
# Display the summary of the concatenated object as the cell output.
adata

In [None]:
# Put the "counts" layer (presumably the raw counts; confirm against the QC
# notebook) into X, then keep a snapshot of that state in .raw.
adata.X = adata.layers["counts"].copy()
adata.raw = adata  # save the original matrix

In [None]:
# Normalize per-cell totals (scanpy defaults), log1p-transform X in place,
# and keep a copy of the result in the "logcounts" layer for later reuse.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
adata.layers["logcounts"] = adata.X.copy()

In [None]:
# Unintegrated baseline, all with default parameters:
# highly variable genes -> PCA -> neighbor graph -> UMAP.
sc.pp.highly_variable_genes(adata)
sc.tl.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
# Color the baseline UMAP by sample to see how strongly the batches separate
# before any integration is applied.
sc.pl.umap(adata, color=[batch_key], wspace=1)

## Batch-aware feature selection

In [None]:
# Take the .raw snapshot from the "counts" layer again, then restore the
# log-normalized values as the working matrix for the steps that follow.
adata.X = adata.layers["counts"].copy()
adata.raw = adata  # .raw keeps the full gene dimension if adata is subset later
adata.X = adata.layers["logcounts"].copy()

In [None]:
# Batch-aware selection of 2000 highly variable genes; subset=True shrinks
# adata to those genes (the full gene set stays reachable through adata.raw).
# The "cell_ranger" flavor expects log-normalized input according to the scanpy
# documentation, so it reads the "logcounts" layer rather than the raw
# "counts" layer.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=2000,
    flavor="cell_ranger",
    batch_key=batch_key,
    subset=True,
    layer="logcounts",
)

In [None]:
# Peek at the top-left 10x10 corner of the matrix stored in .raw.
print(adata.raw.X[:10, :10])

## Variational autoencoder (VAE) based integration

### Data preparation

In [None]:
# Register the data with scVI: model input comes from the "counts" layer and
# the sample column named by batch_key is used as the batch covariate.
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key=batch_key)

### Building the model

In [None]:
# Build the scVI model with n_layers=2, a 30-dimensional latent space and the
# "nb" gene likelihood; the bare name shows the model summary as cell output.
model_scvi = scvi.model.SCVI(
    adata,
    n_layers=2,
    n_latent=30,
    gene_likelihood="nb",
)
model_scvi

In [None]:
model_scvi.view_anndata_setup()

### Training the model

In [None]:
# Epoch budget that shrinks as the number of cells grows: 400 epochs up to
# 20000 cells, proportionally fewer beyond that (never more than 400).
max_epochs_scvi = np.min(
    [
        round(400 * (20000 / adata.n_obs)),
        400,
    ]
)
max_epochs_scvi

In [None]:
# Train with the epoch budget computed above instead of leaving it unused.
# int() turns the NumPy scalar returned by np.min into a plain Python int.
model_scvi.train(max_epochs=int(max_epochs_scvi))

### Extracting the embedding 

In [None]:
# Store the trained model's latent representation in obsm under "X_scVI".
SCVI_LATENT_KEY = "X_scVI"
adata.obsm[SCVI_LATENT_KEY] = model_scvi.get_latent_representation()

In [None]:
# Rebuild the neighbor graph on the scVI latent space (replacing the earlier
# PCA-based graph), then cluster with Leiden using default settings.
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.leiden(adata)

In [None]:
# Compute an MDE layout from the scVI latent space and keep it in obsm under
# "X_scVI_MDE" for plotting.
SCVI_MDE_KEY = "X_scVI_MDE"
adata.obsm[SCVI_MDE_KEY] = scvi.model.utils.mde(adata.obsm[SCVI_LATENT_KEY])

### Visualize the batch-corrected embedding (MDE)

In [None]:
# Plot the MDE layout twice, stacked in one column: once colored by sample
# ("batch") and once by Leiden cluster.
sc.pl.embedding(
    adata, basis=SCVI_MDE_KEY, color=["batch", "leiden"], ncols=1, frameon=False
)

In [None]:
# Print the matrix kept in .raw to inspect its contents.
print(adata.raw.X)

In [None]:
# Save the object with the scVI latent space, MDE layout and Leiden clusters.
# NOTE(review): the file name spells "interagted"; left unchanged because other
# code may load this exact path.
adata.write("processed/adata_scvi_interagted.h5ad")