In [None]:

import scanpy as sc
import anndata as ad

from sklearn.neighbors import NearestNeighbors
import numpy as np
import pandas as pd

import os
from pathlib import Path
import re

NN = 20  # number of nearest annotated neighbours queried per cell for the label transfer

#current_folder = Path(__file__).parent
# `_dh` is IPython's directory history; its first entry is used as the notebook folder.
# Wrapped in Path() because the entry may be a plain string depending on the IPython
# version, in which case the `/` joins below would fail.
current_folder = Path(globals()["_dh"][0])
project_root = current_folder / ".." / ".."
out_folder = project_root / "out"
# Use the public settings object; assigning to the private ScanpyConfig class would
# replace the `figdir` property itself instead of going through its setter.
sc.settings.figdir = out_folder  # where scanpy saves plots
cellbender_out = project_root / "data" / "cellbender_out"
raw_out = project_root / "data" / "raw" / "sc"

# One entry per sample folder; hidden entries (leading ".") are skipped.
# sorted() because os.listdir returns entries in arbitrary order, which would make
# "the first sample" differ between machines/runs.
samples = sorted(sample for sample in os.listdir(cellbender_out) if not sample.startswith("."))

# load the annotated object
annoated = sc.read_h5ad(project_root / "data" / "uscsc_dump" / "annotated.h5ad")
# strip a trailing "-<digits>" suffix from every barcode
annoated.obs.index = [re.sub("-[0-9]+$", "", barcode) for barcode in annoated.obs.index]
# all var_names of the annotated object (named `hvg`; presumably the highly variable genes - TODO confirm)
hvg = annoated.var_names.to_numpy()

# load raw data
raw_adata_objects = {sample: sc.read_10x_mtx(raw_out / sample / "filtered_feature_bc_matrix") for sample in samples}

# load the cellbender output
adata_objects = {sample: sc.read_h5ad(cellbender_out / sample / "cell_bender_matrix_filtered_qc.h5") for sample in samples}


In [None]:
# Work on the first loaded sample (dict insertion order) and show which one it is.
sample = list(adata_objects)[0]
print(sample)

In [None]:
# Restrict the raw counts of the current sample to the cells and genes of the annotated object.
is_current_sample = annoated.obs.sample_id == sample
# np.intersect1d returns the sorted, unique barcodes present in both objects
annotated_ids = np.intersect1d(annoated.obs.index[is_current_sample], raw_adata_objects[sample].obs.index)
# Explicit .copy(): slicing an AnnData returns a view, and assigning `.obs` on a view
# would otherwise trigger an implicit copy (with a warning).
adata_raw_annot = raw_adata_objects[sample][annotated_ids, hvg].copy()
# attach the annotated metadata, rows ordered like annotated_ids
obs = annoated.obs[is_current_sample]
adata_raw_annot.obs = obs.loc[annotated_ids, :]

In [None]:
# get the obs rows of the current sample, selected and ordered by annotated_ids
obs = annoated.obs.loc[annoated.obs.sample_id == sample].loc[annotated_ids, :]

In [None]:
# Pick the cellbender object of the current sample. This is the same object that is stored
# in adata_objects (no copy), so later modifications are visible through the dict as well.
adata_no_annot = adata_objects[sample]
# normalisation is intentionally not applied here (kept for reference)
#sc.pp.normalize_total(adata_no_annot, target_sum=1e4)
#sc.pp.log1p(adata_no_annot)

In [None]:
# Earlier variant (kept for reference): slice the annotated object directly.
#adata_annot = annoated[annoated.obs.sample_id==sample, ]
# Use the raw-count object with annotated metadata as the reference (alias, not a copy).
adata_annot = adata_raw_annot

In [None]:
# get the shared var_names, keeping the gene order of adata_annot: building the list from a
# set would give a column order that depends on string hashing and so can change between
# kernel sessions
genes_no_annot = set(adata_no_annot.var_names)
# dict.fromkeys drops duplicates while preserving first-seen order
shared_var_names = list(dict.fromkeys(gene for gene in adata_annot.var_names if gene in genes_no_annot))
adata_annot = adata_annot[:, shared_var_names]
adata_no_annot_tmp = adata_no_annot[:, shared_var_names].copy()  # copy, because I want to keep all genes for later on

In [None]:
# Report the size of both objects and how their barcode sets differ.
barcodes_no_annot = set(adata_no_annot_tmp.obs_names)
barcodes_annot = set(adata_annot.obs_names)
print(f"Number of cells in adata_no_annot: {len(adata_no_annot_tmp.obs_names)}")
print(f"Number of cells in adata_annot: {len(adata_annot.obs_names)}")

# how many cells are in adata_no_annot that are not in adata_annot?
cells_oi = barcodes_no_annot - barcodes_annot
print(f"Number of cells in adata_no_annot that are not in adata_annot: {len(cells_oi)}")

# how many cells are in adata_annot that are not in adata_no_annot?
cells_oi = barcodes_annot - barcodes_no_annot
print(f"Number of cells in adata_annot that are not in adata_no_annot: {len(cells_oi)}")

In [None]:
# add prefixes to cell barcodes and concat
# The prefixes keep the barcodes of the two groups distinct in the concatenated object.
# They are applied to copies: renaming adata_annot / adata_no_annot_tmp in place would make
# this cell non-idempotent (a re-run stacks a second prefix) and would leave adata_annot with
# "annot_"-prefixed barcodes, so any later lookup of raw barcodes in adata_annot.obs.index
# would silently miss.
adata_annot_prefixed = adata_annot.copy()
adata_annot_prefixed.obs_names = ["annot_" + name for name in adata_annot.obs_names]
adata_no_annot_prefixed = adata_no_annot_tmp.copy()
adata_no_annot_prefixed.obs_names = ["not_annot_" + name for name in adata_no_annot_tmp.obs_names]
# obs["annotation"] records the origin of every cell ("annotated" / "unannotated")
adata_concat = ad.concat(
    [adata_annot_prefixed, adata_no_annot_prefixed],
    join="outer",
    label="annotation",
    keys=["annotated", "unannotated"],
)

In [None]:
# Display the concatenated object as the cell output (quick sanity check of its contents).
adata_concat

In [None]:
# Preprocess the concatenated object; the results are not assigned, i.e. adata_concat is
# modified by these calls. NOTE(review): re-running this cell would presumably normalise and
# log-transform the already transformed values again - restart from the concat cell instead.
# scale counts per cell to target_sum=1e4
sc.pp.normalize_total(adata_concat, target_sum=1e4)
# log(1 + x) transform
sc.pp.log1p(adata_concat)
# PCA with scanpy defaults; the embedding is read from obsm["X_pca"] further below
sc.pp.pca(adata_concat)

In [None]:
# make umap
# build the neighbourhood graph (15 neighbours per cell) that sc.tl.umap uses for the embedding
sc.pp.neighbors(adata_concat, n_neighbors=15)
sc.tl.umap(adata_concat)

In [None]:
# plot umap with label according to annotation, TODO: save plot for each sample (QC)
# "annotation" is the obs column created by ad.concat (annotated / unannotated)
sc.pl.umap(adata_concat, color="annotation")

In [None]:
# Batch-balanced kNN with the annotated/unannotated split as batch key.
# NOTE(review): presumably replaces the neighbour graph computed by sc.pp.neighbors above - confirm in the bbknn docs.
sc.external.pp.bbknn(adata_concat, batch_key="annotation")

In [None]:
# Recompute the UMAP after the bbknn step and plot it again, coloured by origin,
# for comparison with the plot above.
sc.tl.umap(adata_concat)
sc.pl.umap(adata_concat, color="annotation")

In [None]:


# integrate?


# joint pca
sc.pp.pca(adata_concat)

# split back into annotated and unannotated and remove prefixes
# (.copy() so that renaming obs_names acts on a real object instead of implicitly copying a view)
adata_annot_tmp = adata_concat[adata_concat.obs.index.str.startswith("annot_"), ].copy()
adata_annot_tmp.obs_names = [re.sub("^annot_", "", name) for name in adata_annot_tmp.obs_names]
adata_no_annot_tmp = adata_concat[adata_concat.obs.index.str.startswith("not_annot_"), ].copy()
adata_no_annot_tmp.obs_names = [re.sub("^not_annot_", "", name) for name in adata_no_annot_tmp.obs_names]

# Reference labels keyed by the un-prefixed barcodes.
# adata_annot.obs.index may carry the "annot_" prefix added before the concat, so testing raw
# barcodes against it would never match; the labels are therefore re-indexed with the
# cleaned barcodes of adata_annot_tmp (same cells, same row order as adata_annot).
# The index assignment raises if the two objects do not have the same number of cells.
annot_cell_types = adata_annot.obs["cell_type"].copy()
annot_cell_types.index = adata_annot_tmp.obs_names

# The neighbour indices below are used positionally on adata_no_annot, so both objects
# must list the same cells in the same order.
assert list(adata_no_annot_tmp.obs_names) == list(adata_no_annot.obs_names), \
    "adata_no_annot_tmp and adata_no_annot are not aligned"

# for each row in adata_not_annot_pca search for the NN nearest neighbors in adata_annot_pca
# (capped at the number of annotated cells, kneighbors raises when asked for more)
n_neighbors = min(NN, adata_annot_tmp.n_obs)
nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(adata_annot_tmp.obsm["X_pca"])
distances, indices = nbrs.kneighbors(adata_no_annot_tmp.obsm["X_pca"])

annotation = []
for cell_i, barcode in enumerate(adata_no_annot.obs.index):
    # if the barcode is present in the annotated object, then use the annotation
    if barcode in annot_cell_types.index:
        annotation.append(annot_cell_types.loc[barcode])
    # else, use the annotation of the nearest neighbors
    else:
        nearest_neighbors = indices[cell_i, :]
        # value_counts sorts by count, most frequent label first
        nn_annot = annot_cell_types.iloc[nearest_neighbors].value_counts()
        # if max count below 50%, then set to "unannotated"
        if nn_annot.iloc[0] < n_neighbors / 2:
            annotation.append("unannotated")
        else:
            # get label with max count
            annotation.append(nn_annot.index[0])
adata_no_annot.obs["cell_type"] = annotation

print(adata_no_annot.obs["cell_type"].value_counts())

print(adata_no_annot)

# save the annotated object
adata_no_annot.write_h5ad(cellbender_out / sample / "cell_bender_matrix_filtered_qc_annotated.h5ad")