# 4. Label transfer of STARmap

2023-05-05

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import anndata as ad
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from anndata import AnnData
from natsort import natsorted
from tqdm.notebook import tqdm
from statannotations.Annotator import Annotator

# Customized packages
import starmap.sc_util as su
# test()

In [None]:
import matplotlib
# 42 selects TrueType (Type 42) font embedding for PDF output (matplotlib's other option is 3, Type 3),
# so text in saved PDF figures stays editable as text
matplotlib.rcParams['pdf.fonttype'] = 42

## Set path

In [None]:
# Set path
# base_path is a placeholder: point it at the dataset folder before running
base_path = 'path/to/dataset/folder'

input_path = os.path.join(base_path, 'input')

# makedirs(exist_ok=True) creates the folder (and any missing parents) and is a no-op
# when it already exists, avoiding the check-then-create race of exists() + mkdir()
out_path = os.path.join(base_path, 'output')
os.makedirs(out_path, exist_ok=True)

fig_path = os.path.join(base_path, 'figures')
os.makedirs(fig_path, exist_ok=True)

# scanpy writes saved figures into this folder
sc.settings.figdir = fig_path

In [None]:
# load the RIBOmap object ('Brain-RIBOmap-ct-bk.h5ad') from the output folder
rdata = sc.read_h5ad(os.path.join(out_path, 'Brain-RIBOmap-ct-bk.h5ad'))
rdata

In [None]:
# load the combined object ('Brain-combined-harmony.h5ad') from the output folder
cdata = sc.read_h5ad(os.path.join(out_path, 'Brain-combined-harmony.h5ad'))
cdata

In [None]:
# copy annotations of RIBOmap cells from rdata into the combined object
# NOTE(review): `.values` drops the index, so the assignment is purely positional --
# it assumes rdata's rows are in the same order as the RIBOmap rows of cdata; nothing here checks that. TODO confirm
cdata.obs.loc[cdata.obs['protocol'] == 'RIBOmap', 'level_2'] = rdata.obs['level_2'].values
cdata.obs.loc[cdata.obs['protocol'] == 'RIBOmap', 'level_3'] = rdata.obs['level_3'].values

## Label transfer (cosine distance on Harmony PCs)

In [None]:
def label_transfer(adata, embedding='umap', field='level_2_code', metric='cosine', n_neighbors=100):
    """Predict labels for STARmap cells from RIBOmap cells with a k-NN classifier.

    Cells with ``adata.obs['protocol'] == 'RIBOmap'`` form the labelled
    reference and cells with ``adata.obs['protocol'] == 'STARmap'`` form the
    query. A ``KNeighborsClassifier`` is fitted on the reference rows of
    ``adata.obsm[f'X_{embedding}']`` and used to predict ``field`` for the
    query rows.

    Parameters
    ----------
    adata
        Object exposing ``obs`` (with a ``'protocol'`` column and the
        ``field`` column) and ``obsm`` (with the ``f'X_{embedding}'`` entry).
    embedding
        Suffix of the ``obsm`` key to use as coordinates, e.g. ``'umap'``
        reads ``obsm['X_umap']``.
    field
        ``obs`` column holding the reference labels to transfer.
    metric
        Distance metric passed to ``KNeighborsClassifier``.
    n_neighbors
        Number of neighbours passed to ``KNeighborsClassifier``.

    Returns
    -------
    Predicted labels, one per STARmap cell, in the order those cells appear
    in ``adata.obs``.

    Raises
    ------
    ValueError
        If ``adata`` contains no RIBOmap (reference) cells or no STARmap
        (query) cells.
    """
    from sklearn.neighbors import KNeighborsClassifier

    # Boolean row masks let us slice the embedding matrix directly instead of
    # building two AnnData views by obs-name just to read one obsm entry.
    protocol = adata.obs['protocol']
    ref_mask = (protocol == 'RIBOmap').to_numpy()
    query_mask = (protocol == 'STARmap').to_numpy()

    coords = adata.obsm[f'X_{embedding}']

    # Fail early with a readable message rather than an opaque estimator error.
    if not ref_mask.any():
        raise ValueError("label_transfer: no reference cells with protocol == 'RIBOmap'")
    if not query_mask.any():
        raise ValueError("label_transfer: no query cells with protocol == 'STARmap'")

    ref_cell_loc = coords[ref_mask]
    query_cell_loc = coords[query_mask]

    # reference annotation, row-aligned with ref_cell_loc
    ref_cell_annot = adata.obs.loc[ref_mask, field].values

    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric=metric)
    neigh.fit(ref_cell_loc, ref_cell_annot)
    query_cell_predicted = neigh.predict(query_cell_loc)

    return query_cell_predicted

### level_2

In [None]:
# parameters
# number of neighbours handed to label_transfer (and from there to the k-NN classifier) for level_2
n_neighbors = 50

In [None]:
# create new label column: start from a copy of the existing level_2_code values;
# the STARmap rows are overwritten with the transferred labels in a later cell
cdata.obs['level_2_code_cdhp'] = cdata.obs['level_2_code'].values

In [None]:
# conduct label transfer: k-NN on obsm['X_pca_harmony'] with cosine distance, transferring level_2_code
predicted_label_cdhp = label_transfer(cdata, embedding='pca_harmony', field='level_2_code', metric='cosine', n_neighbors=n_neighbors)

In [None]:
# update to cdata
# same STARmap selection that label_transfer uses for its query cells, so the
# predictions line up row-for-row with query_cells
query_cells = cdata.obs.loc[cdata.obs['protocol'] == 'STARmap', :].index

cdata.obs.loc[query_cells, 'level_2_code_cdhp'] = predicted_label_cdhp

In [None]:
# update category and color map
# make the transferred labels categorical, with categories in the same order as level_2_code
# (reorder_categories requires both columns to hold exactly the same set of categories)
cdata.obs['level_2_code_cdhp'] = (
    cdata.obs['level_2_code_cdhp']
    .astype('category')
    .cat.reorder_categories(cdata.obs['level_2_code'].cat.categories)
)

# palette built from the colors stored in cdata.uns['level_2_color']
current_cpl = sns.color_palette(cdata.uns['level_2_color'])

In [None]:
# one UMAP panel per protocol-replicate sample, colored by the transferred level_2 label
samples = cdata.obs['protocol-replicate'].cat.categories
n_panels = max(len(samples), 1)  # subplots() needs at least one column
dot_size = 120000 / cdata.n_obs

# 7 x 4 inches per panel (21 x 4 for three samples); squeeze=False keeps `axs` an
# array even when there is a single panel, so flatten() always works
fig, axs = plt.subplots(nrows=1, ncols=n_panels, figsize=(7 * n_panels, 4), squeeze=False)
axs = axs.flatten()
for i, sample in enumerate(samples):
    current_data = cdata[cdata.obs["protocol-replicate"] == sample, :]
    # background layer: every cell of the combined object, drawn without a `color` key
    ax = sc.pl.umap(cdata, show=False, size=dot_size, ax=axs[i])
    # foreground layer: only this sample's cells, colored by transferred label
    sc.pl.umap(current_data, color='level_2_code_cdhp', frameon=False, ax=ax, size=dot_size, title=f"{sample}", legend_loc=None,
               palette=current_cpl, save=False, show=False)

plt.show()

### level_3

In [None]:
# parameters
# number of neighbours handed to label_transfer (and from there to the k-NN classifier) for level_3
n_neighbors = 50

In [None]:
# create new label column: start from a copy of the existing level_3 values;
# the STARmap rows are overwritten with the transferred labels in a later cell
cdata.obs['level_3_cdhp'] = cdata.obs['level_3'].values

In [None]:
# conduct label transfer: k-NN on obsm['X_pca_harmony'] with cosine distance, transferring level_3
predicted_label_cdhp = label_transfer(cdata, embedding='pca_harmony', field='level_3', metric='cosine', n_neighbors=n_neighbors)

In [None]:
# update to cdata
# relies on `query_cells` (index of STARmap cells) defined in the level_2 section above
cdata.obs.loc[query_cells, 'level_3_cdhp'] = predicted_label_cdhp

In [None]:
# update category and color map
# make the transferred labels categorical, with categories in the same order as level_3
# (reorder_categories requires both columns to hold exactly the same set of categories)
cdata.obs['level_3_cdhp'] = (
    cdata.obs['level_3_cdhp']
    .astype('category')
    .cat.reorder_categories(cdata.obs['level_3'].cat.categories)
)

# palette built from the colors stored in cdata.uns['level_3_color']
current_cpl = sns.color_palette(cdata.uns['level_3_color'])

In [None]:
# one UMAP panel per protocol-replicate sample, colored by the transferred level_3 label
samples = cdata.obs['protocol-replicate'].cat.categories
n_panels = max(len(samples), 1)  # subplots() needs at least one column
dot_size = 120000 / cdata.n_obs

# 7 x 4 inches per panel (21 x 4 for three samples); squeeze=False keeps `axs` an
# array even when there is a single panel, so flatten() always works
fig, axs = plt.subplots(nrows=1, ncols=n_panels, figsize=(7 * n_panels, 4), squeeze=False)
axs = axs.flatten()
for i, sample in enumerate(samples):
    current_data = cdata[cdata.obs["protocol-replicate"] == sample, :]
    # background layer: every cell of the combined object, drawn without a `color` key
    ax = sc.pl.umap(cdata, show=False, size=dot_size, ax=axs[i])
    # foreground layer: only this sample's cells, colored by transferred label
    sc.pl.umap(current_data, color='level_3_cdhp', frameon=False, ax=ax, size=dot_size, title=f"{sample}", legend_loc=None,
               palette=current_cpl, save=False, show=False)

plt.show()

### Label cells with inconsistent level 2 and level 3 labels as Mix

In [None]:
# create reference dict
# h_dict maps each level_2_code category to the list of level_3 labels observed
# among RIBOmap cells of that category
ribomap_obs = cdata.obs.loc[cdata.obs['protocol'] == 'RIBOmap', :]
h_dict = {
    level_2_type: ribomap_obs.loc[ribomap_obs['level_2_code'] == level_2_type, 'level_3'].unique().to_list()
    for level_2_type in cdata.obs.level_2_code.cat.categories
}

h_dict

In [None]:
# change annotation to mix if level2 and level3 cannot match
# For every level_2_code category, STARmap cells whose level_3 is not one of the
# subtypes recorded for that category in h_dict are relabelled 'Mix' at all levels.
# NOTE(review): this reads level_2_code / level_3, not the *_cdhp columns -- confirm the
# transferred labels have been copied into these columns before this cell runs.
# NOTE(review): if any of the four target columns is categorical, 'Mix' must already be
# one of its categories or the assignment raises -- confirm.
for current_type in cdata.obs.level_2_code.cat.categories:
    # the level_2_code mask is rebuilt every pass because earlier passes rewrite that column
    is_current = (cdata.obs['protocol'] == 'STARmap') & (cdata.obs['level_2_code'] == current_type)
    current_obs = cdata.obs.loc[is_current, :]
    current_subtypes = h_dict[current_type]
    current_mix = current_obs.loc[~current_obs.level_3.isin(current_subtypes), :]
    print(current_mix.shape)

    # modify cell annotations
    for column in ('level_1', 'level_2', 'level_3', 'level_2_code'):
        cdata.obs.loc[current_mix.index, column] = 'Mix'

In [None]:
# re-check only: recount, per level_2_code category, the STARmap cells whose level_3 is
# not among that category's subtypes in h_dict. Nothing is modified here; it just prints the shapes.
# presumably run to confirm the relabelling cell above left no mismatches -- verify the expected output
for current_type in cdata.obs.level_2_code.cat.categories:
    # print(f"===={current_type}====")
    current_obs = cdata.obs.loc[(cdata.obs['protocol'] == 'STARmap') & (cdata.obs['level_2_code'] == current_type), :]
    current_subtypes = h_dict[current_type]
    current_mix = current_obs.loc[~current_obs.level_3.isin(current_subtypes), :]
    print(current_mix.shape)