In [1]:
import pandas as pd
import scanpy as sc
import numpy as np
import scvi
import dill
import anndata as ad

import os
from os import walk
from os.path import basename
import re

In [2]:
# Working directory for this notebook: run_scANVI writes its output pickles to
# paths relative to the current directory, so results land here. Defaults to the
# original location; set the SCANVI_WORKDIR environment variable to run elsewhere.
SCANVI_WORKDIR = os.environ.get(
    "SCANVI_WORKDIR",
    "/home/fengtang/jupyter_notebooks/working_script/label_transfer/scANVI",
)
os.chdir(SCANVI_WORKDIR)

In [None]:
def run_scANVI(ref_data, test_data, filename, seed=None):
    """Transfer cell-type labels from a reference dataset to a test dataset with scANVI.

    The reference and test expression matrices are concatenated on their shared
    genes. An scVI model is trained on the untransformed values (kept in
    ``layers["counts"]``). An scANVI model is then initialised from it, with
    reference cells labelled by their ``cell_type`` and test cells marked
    ``"Unknown"``. The scANVI model predicts a label for every cell.

    The resulting AnnData is pickled with dill to
    ``f"{filename}_scANVI_prediction_on_multi_datasets.pkl"``, relative to the
    current working directory. It holds predictions in ``obs["C_scANVI"]`` and
    the scVI latent space in ``obsm["X_scVI"]``.

    Parameters
    ----------
    ref_data : str or path-like
        CSV of the reference set. The first column is the row index, each row
        becomes one observation (cell), and the columns are features (genes)
        plus a ``cell_type`` column holding the labels.
    test_data : str or path-like
        CSV of the query set, laid out like ``ref_data`` but without labels.
        Its row names are replaced by ``cell_0``, ``cell_1``, ...
    filename : str
        Prefix of the output pickle file.
    seed : int, optional
        If given, assigned to ``scvi.settings.seed`` before training so the
        scVI/scANVI runs are reproducible. ``None`` (default) leaves the global
        seed untouched, matching the previous behaviour.

    Raises
    ------
    KeyError
        If the reference CSV has no ``cell_type`` column.
    ValueError
        If the two datasets share no genes.
    """
    if seed is not None:
        scvi.settings.seed = seed

    # --- Load the two expression matrices ------------------------------------
    ref_exp_mat = pd.read_csv(ref_data, index_col=0)
    ref_adata = sc.AnnData(ref_exp_mat.drop(columns="cell_type"))
    ref_labels = ref_exp_mat["cell_type"].tolist()

    test_exp_mat = pd.read_csv(test_data, index_col=0)
    # Query cells are renamed cell_0, cell_1, ...; their original names are not kept.
    test_exp_mat.index = [f"cell_{i}" for i in range(test_exp_mat.shape[0])]
    test_adata = sc.AnnData(test_exp_mat)

    # --- Combine on shared genes ---------------------------------------------
    # join="inner" keeps only genes present in both matrices. anndata.concat
    # stacks cells in argument order: all reference cells, then all test cells.
    adata = ad.concat([ref_adata, test_adata], join="inner")
    if adata.n_vars == 0:
        raise ValueError(
            f"{ref_data!r} and {test_data!r} share no genes; cannot run scANVI"
        )
    # Keep the untransformed matrix: HVG selection and scVI/scANVI read this layer.
    adata.layers["counts"] = adata.X.copy()

    n_ref, n_test = ref_adata.n_obs, test_adata.n_obs
    # Tag each cell's origin by position. Matching obs names against "^cell_\d+"
    # (the previous approach) mislabelled any reference cell whose name also
    # began with "cell_<digits>" as a test cell, corrupting the HVG batch key.
    adata.obs["ref_or_not"] = ["ref"] * n_ref + ["test"] * n_test
    # Test cells get "Unknown", the unlabeled category handed to scANVI below.
    adata.obs["cell_type"] = ref_labels + ["Unknown"] * n_test

    # --- Normalise and select highly variable genes --------------------------
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    # Snapshot the log-normalised matrix over all shared genes before subsetting.
    adata.raw = adata
    # seurat_v3 HVGs are computed from the "counts" layer, per origin group, and
    # the object is subset in place to the selected genes.
    sc.pp.highly_variable_genes(
        adata,
        flavor="seurat_v3",
        n_top_genes=2000,
        layer="counts",
        batch_key="ref_or_not",
        subset=True,
    )

    # --- Unsupervised scVI model on the counts layer --------------------------
    scvi.model.SCVI.setup_anndata(adata, layer="counts")
    scvi_model = scvi.model.SCVI(adata)
    scvi_model.train()

    SCVI_LATENT_KEY = "X_scVI"
    adata.obsm[SCVI_LATENT_KEY] = scvi_model.get_latent_representation()

    # --- Semi-supervised scANVI initialised from the scVI model ---------------
    SCANVI_CELLTYPE_KEY = "celltype_scanvi"
    adata.obs[SCANVI_CELLTYPE_KEY] = adata.obs["cell_type"].tolist()

    scanvi_model = scvi.model.SCANVI.from_scvi_model(
        scvi_model,
        adata=adata,
        unlabeled_category="Unknown",
        labels_key=SCANVI_CELLTYPE_KEY,
    )
    scanvi_model.train()

    # Predicted label for every cell, reference and test alike.
    SCANVI_PREDICTION_KEY = "C_scANVI"
    adata.obs[SCANVI_PREDICTION_KEY] = scanvi_model.predict(adata)

    # --- Save -----------------------------------------------------------------
    # NOTE: unpickling executes arbitrary code; only load this file if trusted.
    with open(f"{filename}_scANVI_prediction_on_multi_datasets.pkl", "wb") as out_fh:
        dill.dump(adata, out_fh)

    print(f"finish prediction on {filename}")

In [None]:
# Run scANVI label transfer: one shared reference dataset is used to annotate
# each selected test dataset, training a separate model per test dataset.
# The result-location table is GBK-encoded; its "数据集" ("dataset") column
# names each dataset.
file = pd.read_csv("/mnt/disk5/zhongmin/superscc/结果位置/结果位置_3.csv", encoding = "GBK")
file = file.loc[file["数据集"].isin(["Banovich_Kropski_2020", "Barbry_Leroy_2020", "Krasnow_2020", "Lafyatis_Rojas_2019", "Nawijn_2021", "Teichmann_Meyer_2019"]), :]
# NOTE(review): columns 0 and 10 are picked by position and renamed to
# "filename" and "test_data" — confirm this matches the CSV's layout.
# .copy() makes this an independent frame, so adding a column below does not
# write into a slice of the filtered table.
file = file.iloc[:, [0, 10]].copy()
file.columns = ["filename", "test_data"]
# A scalar broadcasts to every row, so this no longer assumes exactly six
# matching datasets (the old fixed-length list raised on any other count).
file["ref_data"] = "/home/fengtang/jupyter_notebooks/working_script/label_transfer/SuperSCC/train_datasets_from_different_study.csv"
file = file[["ref_data", "test_data", "filename"]]

# Access fields by name: integer keys on a string-labelled row Series rely on a
# deprecated positional fallback in pandas.
for job in file.itertuples(index=False):
    run_scANVI(job.ref_data, job.test_data, job.filename)