In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn_ann.kneighbors.annoy import AnnoyTransformer

In [None]:
# Load the annotated mouse atlas; this notebook later writes 14_mouse_final_annotation.h5ad.
adata = sc.read_h5ad(
    filename='/mnt/storage/Daniele/atlases/mouse/13_mouse_all_annotated.h5ad',
)

In [None]:
from sklearn.neighbors import KNeighborsClassifier

def apply_knn_labeling(adata, embedding_key, label_key, n_neighbors,
                       unassigned_label="Missclassified"):
    """Fill in unassigned labels by k-nearest-neighbour vote in an embedding.

    Cells whose ``adata.obs[label_key]`` equals ``unassigned_label`` or is
    missing (NaN/None) are treated as unlabeled. A ``KNeighborsClassifier`` is
    fitted on every other cell using the coordinates in
    ``adata.obsm[embedding_key]`` and predicts labels for the unlabeled ones.

    The result goes to ``adata.obs[f"{label_key}_knn"]``: a copy of
    ``label_key`` in which only the unlabeled entries are replaced by
    predictions. When nothing is unlabeled the column is a plain copy and no
    classifier is fitted (``predict`` on an empty set would raise).

    Parameters
    ----------
    adata : AnnData-like object exposing ``.obs`` (DataFrame) and ``.obsm``.
    embedding_key : str
        Key in ``adata.obsm`` holding one row of coordinates per cell.
    label_key : str
        Column in ``adata.obs`` holding the labels to complete.
    n_neighbors : int
        Number of neighbours used by the classifier.
    unassigned_label : str, default "Missclassified"
        Sentinel marking cells that should be re-labelled.

    Returns
    -------
    The same ``adata`` object, modified in place.

    Raises
    ------
    ValueError
        If some cells need a label but no labeled cell exists to train on.
    """
    labels = adata.obs[label_key]
    out_key = f'{label_key}_knn'

    # NaN labels are treated as unlabeled too: feeding them to sklearn as a
    # class alongside strings makes fit() fail.
    unlabeled = (labels.isna() | (labels == unassigned_label)).to_numpy()

    if unlabeled.any() and unlabeled.all():
        raise ValueError(
            f"All cells in adata.obs[{label_key!r}] are unassigned; "
            "no labeled cells to train the k-NN classifier on."
        )

    adata.obs[out_key] = labels.copy()
    if not unlabeled.any():
        return adata

    embedding = adata.obsm[embedding_key]
    knn = KNeighborsClassifier(n_neighbors=n_neighbors)
    knn.fit(embedding[~unlabeled], labels.to_numpy()[~unlabeled])
    adata.obs.loc[unlabeled, out_key] = knn.predict(embedding[unlabeled])
    return adata


## Annotation refinement

In [None]:
# Default score cut-off for the subset-scoring cells below: a cell whose best
# signature score falls under it is labelled "Missclassified".
threshold: int = 0

In [None]:
# Overview of the current Level_3 labels on the full-atlas UMAP.
sc.pl.umap(adata, color=['Level_3'])

In [None]:
# Display the distinct Level_3 labels present before refinement.
adata.obs['Level_3'].unique().tolist()

In [None]:
# Collects one {obs_name: refined label} dict per re-annotated subset.
dictionary_maps = list()

## Macrophages

In [None]:
# Rename the M1 TAM label. Reassign the column instead of calling
# replace(..., inplace=True) on adata.obs['Level_3']: that is chained
# assignment, which warns in pandas 2.2 and never updates the frame under
# Copy-on-Write (the pandas 3.0 default).
adata.obs['Level_3'] = adata.obs['Level_3'].replace(
    'Macrophage - M1 TAM', 'Macrophage - M1-like TAM'
)

In [None]:
# Independent copy of the M2 TAM cells for sub-clustering.
M = adata[adata.obs['Level_3'].eq("Macrophage - M2 TAM")].copy()

In [None]:
# Re-embed the M2 TAM subset: approximate (Annoy) kNN graph on the scANVI
# latent space, coarse Leiden clustering, then a fresh UMAP.
sc.pp.neighbors(M, use_rep='scANVI_emb', transformer=AnnoyTransformer(15))
sc.tl.leiden(M, resolution=0.1, flavor='igraph')
sc.tl.umap(M, spread=1.0, min_dist=0.5)

In [None]:
# Leiden clusters of the M2 TAM subset on its own UMAP.
sc.pl.umap(M, color=['leiden'])

In [None]:
# Marker genes per M2 TAM Leiden cluster (scanpy's default test).
sc.tl.rank_genes_groups(M, groupby='leiden')

In [None]:
# Top-10 markers per cluster, coloured by log fold change.
sc.pl.rank_genes_groups_dotplot(
    M,
    n_genes=10,
    cmap='coolwarm',
    values_to_plot='logfoldchanges',
)

In [None]:
# Rename the M2 TAM label. Reassign the column instead of calling
# replace(..., inplace=True) on adata.obs['Level_3']: that is chained
# assignment, which warns in pandas 2.2 and never updates the frame under
# Copy-on-Write (the pandas 3.0 default).
adata.obs['Level_3'] = adata.obs['Level_3'].replace(
    'Macrophage - M2 TAM', 'Macrophage - M2-like TAM'
)

In [None]:
# Release the M2 TAM sub-clustering copy; it is only used for inspection above.
del M

# T cells

### CD4+ T cells

In [None]:
# Independent copy of the CD4+ T cells to sub-annotate.
cd4 = adata[adata.obs['Level_3'].eq("CD4+ T Cell")].copy()

In [None]:
# Marker-gene signatures for the CD4+ sub-types, in scoring order (order also
# decides ties in the arg-max assignment).
# NOTE(review): "Ptgr2" (Th2) may be a typo for "Ptgdr2" (CRTH2) — confirm.
cd4_subsets = dict([
    ("CD4+ Th1 Cell", ["Stat4", "Cxcr3", "Ifng"]),
    ("CD4+ Th2 Cell", ["Gata3", "Ccr4", "Ptgr2"]),
    ("CD4+ Th17 Cell", ["Il17a", "Il17f", "Rora", "Klrb1", "Ccr6"]),
    ("CD4+ Th22 Cell", ["Il22", "Ccr10", "Foxo4"]),
    ("CD4+ Naive T Cell", ["Ccr7", "Sell", "Lef1", "Tcf7"]),
    ("CD4+ Memory T Cell", ["Il7r", "Gpr183", "Cd69"]),
    ("γδ T Cell (Vδ1)", ["Trdc"]),
    ("T-reg", ["Foxp3", "Il2ra", "Ctla4", "Tnfrsf4"]),
    ("Double Positive CD4+CD8+ T Cell", ["Cd4", "Cd8a", "Cd8b1"]),
])


In [None]:
# Score every CD4 signature (score_genes writes one obs column per signature,
# named after it), then label each cell with its best-scoring subset.
for cell_type, markers in cd4_subsets.items():
    sc.tl.score_genes(cd4, gene_list=markers, score_name=cell_type)

subset_names = list(cd4_subsets.keys())
scores = cd4.obs[subset_names].to_numpy()

max_indices = np.argmax(scores, axis=1)
# dtype=object: a fixed-width '<U…' array would silently truncate the
# "Missclassified" sentinel below to the length of the longest subset name.
celltypes = np.array(subset_names, dtype=object)[max_indices]

# Cells whose best score is still under the cut-off stay unassigned.
threshold = 0.0
max_scores = scores[np.arange(scores.shape[0]), max_indices]
celltypes[max_scores < threshold] = "Missclassified"

cd4.obs['celltype'] = celltypes


In [None]:
# Reassign instead of replace(..., inplace=True): the in-place call on the
# cd4.obs['celltype'] sub-Series is chained assignment, which warns in pandas
# 2.2 and is a no-op under Copy-on-Write. 'Missc' is presumably
# "Missclassified" truncated by a fixed-width numpy string array.
cd4.obs['celltype'] = cd4.obs['celltype'].replace('Missc', 'Missclassified')
sc.pl.dotplot(cd4, groupby='celltype', var_names=cd4_subsets, vmax=1, vmin=0)
sc.pl.dotplot(cd4, groupby='celltype', var_names=['Cd3e', 'Cd4', 'Cd8a', 'Cd68'], vmax=1, vmin=0)
# The signature names are also the obs score columns written by score_genes.
sc.pl.dotplot(cd4, groupby='celltype', var_names=list(cd4_subsets.keys()), vmax=1, vmin=0)

In [None]:
# This cell previously redrew the three dotplots from the cell above verbatim.
# Show the per-subset cell counts instead, as the CD8 and malignant sections do.
cd4.obs['celltype'].value_counts()

In [None]:
# obs_name -> refined CD4 label, queued for merging into Level_4.
cd4_cells_map = dict(zip(cd4.obs_names, cd4.obs['celltype']))
dictionary_maps.append(cd4_cells_map)

In [None]:
# Sanity check: on a clean top-to-bottom run this is 1 (re-running the
# append cell adds duplicates).
len(dictionary_maps)

### CD8+ T cells

In [None]:
# Independent copy of the CD8+ T cells to sub-annotate.
cd8 = adata[adata.obs['Level_3'].isin(["CD8+ T Cells"])].copy()
# NOTE(review): the CD4 section matches the singular "CD4+ T Cell"; confirm the
# plural "CD8+ T Cells" is the exact Level_3 spelling. Fail loudly on a mismatch
# instead of later scoring an empty subset.
assert cd8.n_obs > 0, "No cells matched Level_3 'CD8+ T Cells'; check the label spelling."

In [None]:
# Marker-gene signatures for the CD8+ sub-types, in scoring order.
# Gzmb and Ifng appear in both the Effector and Terminal Effector signatures.
cd8_subsets = dict([
    ("CD8+ Effector T Cell", ["Gzmb", "Gzmk", "Prf1", "Ifng"]),
    ("CD8+ Exhausted T Cell", ["Pdcd1", "Havcr2", "Lag3", "Tox"]),
    ("CD8+ Memory T Cell", ["Itga1", "Dkk3", "Ccr4", "Klrb1"]),
    ("CD8+ Naive T Cell", ["Sell", "Ccr7", "Tcf7"]),
    ("CD8+ Terminal Effector T Cell", ["Zeb2", "Gzmb", "Ifng", "Tbx21"]),
    ("CD8+ Tissue-Resident Memory T Cell", ["Cd69", "Itgae", "Runx3", "Cxcr6"]),
])


In [None]:
# Score every CD8 signature (score_genes writes one obs column per signature,
# named after it), then label each cell with its best-scoring subset.
for cell_type, markers in cd8_subsets.items():
    sc.tl.score_genes(cd8, gene_list=markers, score_name=cell_type)

subset_names = list(cd8_subsets.keys())
scores = cd8.obs[subset_names].to_numpy()

max_indices = np.argmax(scores, axis=1)
# dtype=object: a fixed-width '<U…' array would silently truncate the
# "Missclassified" sentinel below to the length of the longest subset name.
celltypes = np.array(subset_names, dtype=object)[max_indices]

# Cells whose best score is still under the cut-off stay unassigned.
threshold = 0.0
max_scores = scores[np.arange(scores.shape[0]), max_indices]
celltypes[max_scores < threshold] = "Missclassified"

cd8.obs['celltype'] = celltypes


In [None]:
# Number of CD8+ cells assigned to each subset (including "Missclassified").
cd8.obs['celltype'].value_counts()


In [None]:
# Reassign instead of replace(..., inplace=True): the in-place call on the
# cd8.obs['celltype'] sub-Series is chained assignment, which warns in pandas
# 2.2 and is a no-op under Copy-on-Write. 'Missc' is presumably
# "Missclassified" truncated by a fixed-width numpy string array.
cd8.obs['celltype'] = cd8.obs['celltype'].replace('Missc', 'Missclassified')
sc.pl.dotplot(cd8, groupby='celltype', var_names=cd8_subsets, vmax=1, vmin=0)
sc.pl.dotplot(cd8, groupby='celltype', var_names=['Cd4', 'Cd8a'], vmax=1, vmin=0)
# The signature names are also the obs score columns written by score_genes.
sc.pl.dotplot(cd8, groupby='celltype', var_names=list(cd8_subsets.keys()), vmax=1, vmin=0)

In [None]:
# obs_name -> refined CD8 label, queued for merging into Level_4.
cd8_cells_map = dict(zip(cd8.obs_names, cd8.obs['celltype']))
dictionary_maps.append(cd8_cells_map)

In [None]:
# Sanity check: on a clean top-to-bottom run this is 2 (CD4 + CD8).
len(dictionary_maps)

# Malignant cells (epithelial and mesenchymal)

In [None]:
# Independent copy of all malignant cells (epithelial and mesenchymal).
ME = adata[
    adata.obs['Level_3'].isin(('Malignant Cell - Epithelial', 'Malignant Cell - Mesenchymal'))
].copy()

In [None]:
# Marker-gene signatures for malignant-cell states, in scoring order. The keys
# become the final Level_4 labels, so the "Highly Invasive" spelling is fixed
# here (it was "Hihgly"). Cdh2 contributes to both the EMT and Mesenchymal scores.
me_markers = {
    "Malignant Cell - Epithelial": ["Epcam", "Cldn4", "Cldn7"],
    "Malignant Cell - Pit Like": ["Gkn1", "Gkn2", "Gkn3", "Cldn18"],
    "Malignant Cell - Hypoxia": ["Hif1a", "Vegfa"],
    "Malignant Cell - Highly Proliferative": ["Mki67", "Cenpf", "Top2a"],
    "Malignant Cell - EMT": ["Zeb1", "Twist1", "Cdh2"],
    "Malignant Cell - Acinar-like": ["Reg3b", "Reg3g", "Cpa1"],
    "Malignant Cell - Highly Invasive": ["Mmp9", "Mmp2", "Mmp14"],
    "Malignant Cell - Senescence": ["Cdkn1a", "Cdkn2a", "Trp53"],
    "Malignant Cell - Apoptotic": ["Bax", "Bcl2", "Fas"],
    "Malignant Cell - Mesenchymal": ["Cdh2", "Col3a1"],
}


In [None]:
# Score every malignant-state signature (score_genes writes one obs column per
# signature, named after it), then label each cell with its best-scoring state.
for cell_type, markers in me_markers.items():
    sc.tl.score_genes(ME, gene_list=markers, score_name=cell_type)

state_names = list(me_markers.keys())
scores = ME.obs[state_names].to_numpy()

max_indices = np.argmax(scores, axis=1)
# dtype=object: a fixed-width '<U…' array would silently truncate the
# "Missclassified" sentinel below to the length of the longest state name.
celltypes = np.array(state_names, dtype=object)[max_indices]

# Set the cut-off explicitly rather than relying on the `threshold` global left
# by earlier cells; same value (0.0) as the CD4/CD8 sections.
threshold = 0.0
max_scores = scores[np.arange(scores.shape[0]), max_indices]
celltypes[max_scores < threshold] = "Missclassified"

ME.obs['celltype'] = celltypes


In [None]:
# Number of malignant cells assigned to each state (including "Missclassified").
ME.obs['celltype'].value_counts()

In [None]:
# Reassign instead of replace(..., inplace=True): the in-place call on the
# ME.obs['celltype'] sub-Series is chained assignment, which warns in pandas
# 2.2 and is a no-op under Copy-on-Write. 'Missclass' is presumably
# "Missclassified" truncated by a fixed-width numpy string array.
ME.obs['celltype'] = ME.obs['celltype'].replace({'Missclass': 'Missclassified'})
# The state names are also the obs score columns written by score_genes.
sc.pl.dotplot(ME, groupby='celltype', var_names=list(me_markers.keys()), vmax=1, vmin=0)
sc.pl.dotplot(ME, groupby='celltype', var_names=me_markers, vmax=1, vmin=0)

In [None]:
# obs_name -> refined malignant-state label, queued for merging into Level_4.
me_cells_map = dict(zip(ME.obs_names, ME.obs['celltype']))
dictionary_maps.append(me_cells_map)

In [None]:
# Sanity check: on a clean top-to-bottom run this is 3 (CD4 + CD8 + malignant).
len (dictionary_maps)

## Add the refined Level_4 annotation to AnnData

In [None]:
# Flatten the per-subset maps into one obs_name -> label dict; for a repeated
# obs_name the map appended last wins.
dictionary_maps_ = dict(
    item for cell_map in dictionary_maps for item in cell_map.items()
)


In [None]:
# One entry per re-annotated cell, indexed by obs_name.
anno_map = pd.Series(
    list(dictionary_maps_.values()),
    index=list(dictionary_maps_.keys()),
)

In [None]:
# Distribution of the refined labels before they are merged into adata.
anno_map.value_counts()

In [None]:
# Level_4 = refined label where one exists, otherwise the cell's Level_3 label.
adata.obs['Level_4'] = (
    anno_map
    .reindex(adata.obs_names)
    .fillna(adata.obs['Level_3'])
)

In [None]:
# Level_4 label distribution before k-NN re-labelling.
adata.obs['Level_4'].value_counts()

In [None]:
# Display the AnnData summary to confirm the new Level_4 column is present.
adata

In [None]:
# Re-label "Missclassified" Level_4 cells from their 15 nearest neighbours in
# the scANVI latent space; the result is stored in adata.obs['Level_4_knn'].
adata = apply_knn_labeling(
    adata,
    embedding_key="scANVI_emb",
    label_key="Level_4",
    n_neighbors=15,
)

In [None]:
# Label distribution after k-NN re-labelling.
adata.obs['Level_4_knn'].value_counts()

In [None]:
# Harmonise the 'CD4+ Naive Cell' spelling with 'CD4+ Naive T Cell' in both
# annotation columns. Reassign each column instead of replace(..., inplace=True),
# which is chained assignment and is a no-op under pandas Copy-on-Write.
for level_col in ('Level_4', 'Level_4_knn'):
    adata.obs[level_col] = adata.obs[level_col].replace('CD4+ Naive Cell', 'CD4+ Naive T Cell')

In [None]:
# Persist the atlas together with the Level_4 / Level_4_knn annotations.
adata.write_h5ad(
    filename='/mnt/storage/Daniele/atlases/mouse/14_mouse_final_annotation.h5ad',
)