In [None]:
import pandas as pd
import scanpy as sc
import numpy as np
import scvi
import dill

import os
from os import walk
from os.path import basename
import re

In [3]:
# Working directory for this notebook: run_scANVI writes its
# "*_scANVI_prediction.pkl" outputs relative to it, and the tidy-up step
# scans it for those files.
# NOTE(review): hardcoded absolute, user-specific path — adjust per machine.
WORK_DIR = "/home/fengtang/jupyter_notebooks/working_script/label_transfer/scANVI"
os.chdir(WORK_DIR)

In [None]:
def run_scANVI(count, ref_data, test_data, filename, seed=None):
    """Transfer cell-type labels from reference cells to the remaining cells with scANVI.

    Builds an AnnData from the expression matrix, tags each cell as "ref" or
    "test", trains an unsupervised scVI model on the raw counts, warm-starts a
    scANVI model from it using the reference labels (non-reference cells get
    the unlabeled placeholder "Unknown"), predicts a label for every cell and
    pickles the annotated AnnData with dill.

    Parameters
    ----------
    count : str or path-like
        CSV expression matrix read with ``index_col=0``. Rows become AnnData
        observations (cells) and are matched against the reference cell IDs;
        columns are genes. The values are copied into the ``"counts"`` layer
        and used as raw counts by seurat_v3 HVG selection and scVI.
    ref_data : str or path-like
        CSV (``index_col=0``) whose index holds reference cell IDs and which
        has a ``"cell_type"`` column with their labels.
    test_data : str or path-like
        CSV (``index_col=0``) of test cell IDs. It is read but not otherwise
        used: every cell of ``count`` that is not in ``ref_data`` is treated
        as a test cell.
    filename : str
        Output prefix; the result is written to
        ``f"{filename}_scANVI_prediction.pkl"`` relative to the current
        working directory.
    seed : int, optional
        If given, assigned to ``scvi.settings.seed`` before anything is
        trained so the run is reproducible. ``None`` (default) leaves the
        global scvi-tools seed untouched.

    Returns
    -------
    None
        The AnnData (predictions in ``obs["C_scANVI"]``, scVI latent space in
        ``obsm["X_scVI"]``) is written to disk; a completion message is printed.
    """
    if seed is not None:
        # scvi-tools applies this seed to the RNGs used during training.
        scvi.settings.seed = seed

    exp_mat = pd.read_csv(count, index_col=0)
    ref_cell_index = pd.read_csv(ref_data, index_col=0)
    # NOTE(review): read but never used below — reference membership alone
    # decides which cells are "test".
    test_cell_index = pd.read_csv(test_data, index_col=0)

    # Build the reference-ID -> label lookup once (linear time) instead of
    # rebuilding and linearly searching a list for every cell. setdefault keeps
    # the first label of a duplicated ID, matching the old list.index() lookup.
    ref_label_of = {}
    for cell_id, label in zip(ref_cell_index.index, ref_cell_index["cell_type"]):
        ref_label_of.setdefault(cell_id, label)

    adata = sc.AnnData(exp_mat)
    # Keep an untouched copy of the input values before normalisation.
    adata.layers["counts"] = adata.X.copy()
    adata.obs["ref_or_not"] = [
        "ref" if cell_id in ref_label_of else "test" for cell_id in exp_mat.index
    ]
    # Non-reference cells get the placeholder that scANVI treats as unlabeled.
    adata.obs["cell_type"] = [
        ref_label_of.get(cell_id, "Unknown") for cell_id in exp_mat.index
    ]

    # Log-normalise adata.X in place; the "counts" layer keeps the raw values.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Snapshot the full-gene, log-normalised matrix before subsetting to HVGs.
    adata.raw = adata
    sc.pp.highly_variable_genes(
        adata,
        flavor="seurat_v3",
        n_top_genes=2000,
        layer="counts",
        batch_key="ref_or_not",
        subset=True,
    )

    # Unsupervised scVI model on the raw counts (no batch covariate registered).
    scvi.model.SCVI.setup_anndata(adata, layer="counts")
    scvi_model = scvi.model.SCVI(adata)
    scvi_model.train()

    SCVI_LATENT_KEY = "X_scVI"
    adata.obsm[SCVI_LATENT_KEY] = scvi_model.get_latent_representation()

    # Semi-supervised scANVI warm-started from the trained scVI model.
    SCANVI_CELLTYPE_KEY = "celltype_scanvi"
    adata.obs[SCANVI_CELLTYPE_KEY] = adata.obs["cell_type"].tolist()
    scanvi_model = scvi.model.SCANVI.from_scvi_model(
        scvi_model,
        adata=adata,
        unlabeled_category="Unknown",
        labels_key=SCANVI_CELLTYPE_KEY,
    )
    scanvi_model.train()

    SCANVI_PREDICTION_KEY = "C_scANVI"
    adata.obs[SCANVI_PREDICTION_KEY] = scanvi_model.predict(adata)

    with open(f"{filename}_scANVI_prediction.pkl", "wb") as fh:
        dill.dump(adata, fh)

    print(f"finish prediction on {filename}")

In [None]:
# Run scANVI label transfer for every dataset listed in the evaluation table.
# The ref/test cell-list paths are redirected from ".../代码/..." to
# ".../代码/SuperSCC/finest_cell_label_res/..." (literal, all occurrences).
eval_data_loc = pd.read_csv(
    "/home/fengtang/jupyter_notebooks/working_script/label_transfer/SingleCellNet/label_transfer_evulate_data_loc.csv",
    index_col=0,
)
for col in ("test_data", "ref_data"):
    eval_data_loc[col] = eval_data_loc[col].str.replace(
        "代码/", "代码/SuperSCC/finest_cell_label_res/", regex=False
    )

# NOTE(review): columns are consumed by position and assumed to be ordered
# (count, ref_data, test_data, filename) to match run_scANVI's signature — confirm.
# .iloc is required: integer keys in Series[...] are deprecated as positions
# in pandas 2.x and are treated as labels in pandas 3.
for _, row in eval_data_loc.iterrows():
    count_path, ref_path, test_path, out_prefix = row.iloc[:4]
    run_scANVI(count_path, ref_path, test_path, out_prefix)

In [None]:
# Tidy up the predictions: export the .obs table (input labels plus the
# "C_scANVI" prediction) of every saved scANVI result found anywhere under the
# working directory to a CSV next to its pickle.
PRED_SUFFIX = "_scANVI_prediction.pkl"

# Keep the full path: walk() recurses, so results in subdirectories must not be
# re-resolved against the current directory. Match only scANVI outputs, not any
# file whose name merely ends in "pkl".
pred_files = [
    os.path.join(root, fname)
    for root, _subdirs, fnames in walk(os.getcwd())
    for fname in fnames
    if fname.endswith(PRED_SUFFIX)
]

for pred_path in pred_files:
    # NOTE: unpickling can execute code embedded in the file — only load
    # outputs you produced yourself.
    pred_adata = pd.read_pickle(pred_path)
    pred_adata.obs.to_csv(f"{pred_path[:-len(PRED_SUFFIX)]}.csv")