# Subtype classification

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import remove_small_objects
from matplotlib.colors import ListedColormap
from anndata import AnnData, concat

# Customized packages
from starmap.utilities import *
from starmap.sequencing import *
from starmap.obj import STARMapDataset, load_data
# import starmap.analyze as anz
# import starmap.viz as viz
import starmap.sc_util as su

sc.logging.print_header()
# test()

In [None]:
from statannotations.Annotator import Annotator

## Input

In [None]:
# Set path
# Root folder of the analysed dataset; every output of this notebook is written beneath it.
base_path = 'Z:/Data/Analyzed/2022-01-03-Hu-AD/'
out_path = os.path.join(base_path, 'output')
fig_path = os.path.join(base_path, 'figures')

# Create both folders (including missing parents). exist_ok=True keeps the cell
# safe to re-run and removes the check-then-create race of exists() + mkdir().
for required_dir in (out_path, fig_path):
    os.makedirs(required_dir, exist_ok=True)

In [None]:
# Load new data
# The input .h5ad file is read from the output folder configured above.
adata_file = os.path.join(out_path, '2022-04-06-Hu-AD-stardist-scaled.h5ad')
adata = sc.read_h5ad(adata_file)
adata

In [None]:
# odata = adata[adata.obs['top_level'] == 'Astro', ]
# odata.write_h5ad(os.path.join(out_path, 'astro-test.h5ad'))

## Gene filtering

In [None]:
# Top-level labels, in the category order used for 'top_level_filtering'.
cell_types = [
    'CTX-Ex', 'Inh', 'CA1', 'CA2', 'CA3', 'DG',
    'Astro', 'Endo', 'Micro', 'Oligo_OPC', 'SMC', 'LHb',
]

# merge Oligo-OPC label
# Work on an object-dtype copy so the new label can be written without touching
# the categories of 'top_level'; then convert back to an ordered category set.
merged_labels = adata.obs['top_level'].astype(object)
merged_labels = merged_labels.mask(merged_labels.isin(['Oligo', 'OPC']), 'Oligo_OPC')
adata.obs['top_level_filtering'] = (
    merged_labels.astype('category').cat.reorder_categories(cell_types)
)

# compute pct matrix
# Rows are genes, columns are cell types; each entry is
# 100 - 'pct_dropout_by_counts' computed on the raw layer of that type's cells.
pct_df = pd.DataFrame(columns=cell_types, index=adata.var.index)
for current_type in cell_types:
    in_type = adata.obs['top_level_filtering'] == current_type
    # Explicit copy: the QC call below writes into the subset.
    hdata = adata[in_type, ].copy()
    hdata.X = hdata.layers['raw'].copy()
    sc.pp.calculate_qc_metrics(hdata, inplace=True)
    pct_df[current_type] = 100 - hdata.var['pct_dropout_by_counts']

In [None]:
# compute pct matrix
# Recomputed per original 'top_level' category; this replaces `cell_types` and
# `pct_df` from the previous cell, and all later cells use these versions.
cell_types = adata.obs['top_level'].cat.categories.to_list()
pct_df = pd.DataFrame(columns=cell_types, index=adata.var.index)
for current_type in cell_types:
    in_type = adata.obs['top_level'] == current_type
    # Explicit copy: the QC call below writes into the subset.
    hdata = adata[in_type, ].copy()
    hdata.X = hdata.layers['raw'].copy()
    sc.pp.calculate_qc_metrics(hdata, inplace=True)
    # 100 - dropout percentage, per gene, for the cells of this type.
    pct_df[current_type] = 100 - hdata.var['pct_dropout_by_counts']

In [None]:
# get genes
# Both thresholds are compared against the values in `pct_df` (percent scale).
ingroup_threshold = 5
outgroup_threshold = 80
filtered_markers_dict = {}

for current_type in cell_types:
    # Keep genes above the in-group threshold in the current type ...
    ingroup_vec = pct_df[current_type] > ingroup_threshold
    # ... and drop genes above the out-group threshold in any other type.
    other_types_pct = pct_df.drop(columns=current_type)
    outgroup_vec = (other_types_pct > outgroup_threshold).any(axis=1)
    current_final_vec = ingroup_vec & ~outgroup_vec

    filtered_markers_dict[current_type] = current_final_vec[current_final_vec].index.to_list()
    # Count with sum(): value_counts()[True] raises KeyError when no gene passes.
    print(current_type, ' - ', int(current_final_vec.sum()))

In [None]:
# Pairwise overlap between the filtered gene lists: entry (row, col) is the
# number of genes shared by the two cell types (the diagonal is the list size).
marker_sets = {name: set(filtered_markers_dict[name]) for name in cell_types}
confusion_df = pd.DataFrame(
    {
        col_type: [len(marker_sets[col_type] & marker_sets[row_type]) for row_type in cell_types]
        for col_type in cell_types
    },
    index=cell_types,
    columns=cell_types,
)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(confusion_df, annot=True, fmt='d', ax=ax)

In [None]:
# save marker dict
# One folder per threshold setting, named after the in/out-group thresholds.
current_set_id = f'in_{ingroup_threshold}_out_{outgroup_threshold}'
current_out_path = os.path.join(out_path, current_set_id)
if not os.path.exists(current_out_path):
    os.mkdir(current_out_path)

# One single-column ('Gene') CSV per cell type.
for current_type in cell_types:
    marker_table = pd.DataFrame({'Gene': filtered_markers_dict[current_type]})
    marker_csv = os.path.join(current_out_path, f'{current_type}.csv')
    marker_table.to_csv(marker_csv, index=False)

In [None]:
# Display the gene-by-cell-type matrix built in the pct-matrix cell above.
pct_df

## Subtype clustering

### Astro

In [None]:
# Subset
# Restrict to the cells labelled `sub_id` and to the genes that passed the
# marker filter for that type.
sub_id = 'Astro'
current_cells = adata.obs['top_level'] == sub_id
current_genes = filtered_markers_dict[sub_id]
sdata = adata[current_cells, current_genes]

# Alternative kept for reference: all genes instead of the filtered list.
# sdata = adata[current_cells, :]

# Quick look at the value range of the inherited X matrix.
x_max = sdata.X.max()
print(x_max)
sdata

In [None]:
# remove aqp4
# Boolean mask over the variables: everything except 'Aqp4' is kept.
keep_gene = sdata.var.index != 'Aqp4'
sdata = sdata[:, keep_gene]
sdata

In [None]:
# setup the output path
# The folder is two levels below `fig_path` ('subclustering-test'/<sub_id>).
# os.mkdir cannot create the missing intermediate level, so use makedirs;
# exist_ok=True keeps the cell safe to re-run.
sub_level_fig_path = os.path.join(fig_path, 'subclustering-test', sub_id)
os.makedirs(sub_level_fig_path, exist_ok=True)

In [None]:
# Redo preprocessing
# Reset X to the values stored in the 'raw' layer, then drop every inherited layer.
sdata.X = sdata.layers['raw'].copy()
del sdata.layers

# Re-create the 'raw' layer from X so later cells can still read layers['raw'].
sdata.layers['raw'] = sdata.X.copy()

# The normalisation / log / scaling steps below are disabled, so the batch
# correction further down operates on the values taken from the 'raw' layer.
# NOTE(review): confirm that skipping normalisation here is intentional.
# # Normalization scaling
# sc.pp.normalize_total(sdata)
# sc.pp.log1p(sdata)

# sdata.layers['norm'] = sdata.X.copy()
# sdata.raw = sdata

# # sc.pp.highly_variable_genes(sdata, min_mean=0.01, max_mean=3, min_disp=0.5)
# # sc.pl.highly_variable_genes(sdata)

# # Scale data to unit variance and zero mean
# sc.pp.scale(sdata)
# sdata.layers['scaled'] = sdata.X.copy()

# Batch correction
# ComBat with obs['batch'] as the batch key, followed by regressing out
# obs['total_counts']; the result is stored in the 'corrected' layer.
sc.pp.combat(sdata, key='batch')
sc.pp.regress_out(sdata, 'total_counts')
sdata.layers['corrected'] = sdata.X.copy()

In [None]:
# Run PCA
# PCA is computed on the matrix stored in the 'corrected' layer.
sdata.X = sdata.layers['corrected'].copy()
sc.tl.pca(sdata, zero_center=True, use_highly_variable=True, svd_solver='full')

# Plot explained variance
sc.pl.pca_variance_ratio(sdata, log=False)

# Plot PCA, one figure per colouring (two obs columns, then two genes)
for pca_color in ('sample', 'time-group', 'Gfap', 'Vim'):
    sc.pl.pca(sdata, color=pca_color)

# Gene loadings of the first five components
sc.pl.pca_loadings(sdata, components='1,2,3,4,5')

#### test

In [None]:
# Sweep the number of PCs used for the neighbour graph. For every setting, show
# the UMAP (by group and by leiden cluster) plus the per-sample cluster counts
# with a disease-vs-control test per cluster.

# Loop-invariant settings, built once.
time_group_dict = {
    'ADmouse_9723': '8months-disease', 'ADmouse_9735': '8months-control', 'ADmouse_9494': '13months-disease', 'ADmouse_9498': '13months-control',
    'ADmouse_9723_2': '8months-disease', 'ADmouse_9707': '8months-control', 'ADmouse_11346': '13months-disease', 'ADmouse_11351': '13months-control',
}
cf_pl = sns.color_palette(['#00bfc4', '#f8766d'])
alternative = 'less'

for i in range(2, 30, 1):
    print(i)
    sc.pp.neighbors(sdata, n_neighbors=15, n_pcs=i)
    sc.tl.umap(sdata, min_dist=.1)

    # Test the largest bound first: with `i > 10` ahead of `i > 15`,
    # the 0.5 branch could never be reached.
    if i > 15:
        current_resolution = .5
    elif i > 10:
        current_resolution = .4
    elif i < 6:
        current_resolution = .2
    else:
        current_resolution = .3
    sc.tl.leiden(sdata, resolution=current_resolution, random_state=0)

    fig, axs = plt.subplots(1, 3, figsize=(15, 4))
    axs = axs.flatten()
    sc.pl.umap(sdata, color='group', title=f'#pcs: {i}', ax=axs[0], show=False)
    sc.pl.umap(sdata, color='leiden', title=f'leiden', ax=axs[1], show=False)

    # Cluster x sample count table (rows: leiden clusters, columns: samples).
    cluster_names = sdata.obs['leiden'].cat.categories.to_list()
    sample_names = sdata.obs['sample'].cat.categories.to_list()
    sub_count_sample = pd.DataFrame(index=cluster_names, columns=sample_names)

    for sample in sample_names:
        current_obs = sdata.obs.loc[sdata.obs['sample'] == sample, :]
        current_count = current_obs['leiden'].value_counts()
        sub_count_sample.loc[:, sample] = current_count

    # Long format. The value columns come from `sdata` (the table's own columns),
    # not `adata`, so the melt cannot request a sample column that is absent.
    sub_count_sample['leiden'] = sub_count_sample.index.values
    sub_count_sample_melt = pd.melt(sub_count_sample, id_vars=['leiden'], value_vars=sample_names)
    sub_count_sample_melt.columns = ['leiden', 'sample', 'value']
    sub_count_sample_melt['time-group'] = sub_count_sample_melt['sample'].map(time_group_dict)
    # `n` must be passed by keyword; pandas 2 rejects it as a positional argument.
    sub_count_sample_melt[['time', 'group']] = sub_count_sample_melt['time-group'].str.split('-', n=1, expand=True)
    sub_count_sample_melt['group'] = sub_count_sample_melt['group'].astype('category')
    sub_count_sample_melt['type-group'] = sub_count_sample_melt['leiden'].astype(str) + '-' + sub_count_sample_melt['group'].astype(str)
    sub_count_sample_melt['type-group'] = sub_count_sample_melt['type-group'].astype('category')
    sub_count_sample_melt['leiden'] = sub_count_sample_melt['leiden'].astype('category')

    # x position of the per-sample dots: cluster code shifted by +/-0.2 so the
    # dots sit on their hue bar. Cast to float first, because the integer codes
    # cannot hold the fractional offsets.
    sub_count_sample_melt['code'] = sub_count_sample_melt['leiden'].cat.codes.astype(float)
    sub_count_sample_melt.loc[sub_count_sample_melt['group'] == 'disease', 'code'] += 0.2
    sub_count_sample_melt.loc[sub_count_sample_melt['group'] == 'control', 'code'] -= 0.2

    # plot barplot
    sns.barplot(x='leiden', y='value', hue='group', data=sub_count_sample_melt, palette=cf_pl, ax=axs[2])
    sns.scatterplot(x="code", y="value", hue='group', data=sub_count_sample_melt, s=70, facecolors='white', edgecolor='black', ax=axs[2], legend=False, zorder=2)

    # One disease-vs-control comparison per cluster.
    pairs = [((current_type, 'disease'), (current_type, 'control')) for current_type in cluster_names]

    annot = Annotator(axs[2], pairs, plot='barplot', data=sub_count_sample_melt, x='leiden', y='value', hue='group')
    annot.configure(test='t-test_ind', text_format='star', loc='inside', verbose=2)
    annot.apply_test(alternative=alternative).annotate()

    plt.show()
    # Release the figure so the sweep does not accumulate open figures.
    plt.close(fig)

#### final run

In [None]:
%%time
# Embedding parameters
emb_dict = {
    sub_id: {'n_neighbors': 30, 'n_pcs':16, 'min_dist': 1, 'cluster_resolution': .4},
           }

# Computing the neighborhood graph

n_neighbors = emb_dict[sub_id]['n_neighbors']
n_pcs = emb_dict[sub_id]['n_pcs']
min_dist = emb_dict[sub_id]['min_dist']
cluster_resolution = emb_dict[sub_id]['cluster_resolution']

test_id = f'genethres_{ingroup_threshold}_{outgroup_threshold}_pc{n_pcs}_cr{cluster_resolution}'
save_embedding = True

sc.pp.neighbors(sdata, n_neighbors=n_neighbors, n_pcs=n_pcs)

# Run UMAP
sc.tl.tsne(sdata, n_pcs=n_pcs, perplexity=5)
sc.tl.umap(sdata, min_dist=min_dist)
sc.tl.diffmap(sdata, n_comps=n_pcs)

# Run leiden cluster
sc.tl.leiden(sdata, resolution = cluster_resolution, random_state=0)

In [None]:
# # merge clusters
# sdata.obs['leiden'] = sdata.obs['leiden'].astype(object)
# sdata.obs.loc[sdata.obs['leiden'] == '4', 'leiden'] = '0'
# sdata.obs.loc[sdata.obs['leiden'] == '2', 'leiden'] = '1'
# sdata.obs.loc[sdata.obs['leiden'] == '5', 'leiden'] = '3'
# sdata.obs['leiden'] = sdata.obs['leiden'].astype('category')

In [None]:
%%time
# Plot UMAP with cluster labels 
fig, axs = plt.subplots(3, 3, figsize=(15, 12))
axs = axs.flatten()
sc.pl.tsne(sdata, color='leiden', ax=axs[0], show=False, legend_loc=None)
sc.pl.umap(sdata, color='leiden', ax=axs[1], show=False, legend_loc=None)
sc.pl.diffmap(sdata, color='leiden', ax=axs[2], show=False)
sc.pl.tsne(sdata, color='Gfap', ax=axs[3], show=False)
sc.pl.umap(sdata, color='Gfap', ax=axs[4], show=False)
sc.pl.diffmap(sdata, color='Gfap', ax=axs[5], show=False)
sc.pl.tsne(sdata, color='Vim', ax=axs[6], show=False)
sc.pl.umap(sdata, color='Vim', ax=axs[7], show=False)
sc.pl.diffmap(sdata, color='Vim', ax=axs[8], show=False)
plt.show()

n_clusters = sdata.obs['leiden'].unique().shape[0]

if save_embedding:
    
    # Save log
    with open(f'{sub_level_fig_path}/log_{test_id}.txt', 'w') as f:
        f.write(f"""Number of neighbor: {n_neighbors}
    Number of PC: {n_pcs}
    Resolution: {cluster_resolution}
    Min-distance: {min_dist}
    Number of clusters: {n_clusters}""")

    print(f"""Number of neighbor: {n_neighbors}
Number of PC: {n_pcs}
Resolution: {cluster_resolution}
Min-distance: {min_dist}
Number of clusters: {n_clusters}""")
    
    # save embeddings
    np.savetxt(f'{sub_level_fig_path}/pca_{test_id}.csv', sdata.obsm['X_pca'], delimiter=",")
    np.savetxt(f'{sub_level_fig_path}/tsne_{test_id}.csv', sdata.obsm['X_tsne'], delimiter=",")
    np.savetxt(f'{sub_level_fig_path}/umap_{test_id}.csv', sdata.obsm['X_umap'], delimiter=",")
    np.savetxt(f'{sub_level_fig_path}/diffmap_{test_id}.csv', sdata.obsm['X_diffmap'], delimiter=",")
    
# Find gene markers 
# Add log layer
sdata.layers['log_raw'] = np.log1p(sdata.layers['raw'])
sc.pp.normalize_total(sdata, layer='log_raw')

# Find gene markers for each cluster
sc.tl.rank_genes_groups(sdata, 'leiden', method='wilcoxon', layer='log_raw', pts=True, use_raw=False, n_genes=sdata.shape[1])

# Filter markers
sc.tl.filter_rank_genes_groups(sdata, min_fold_change=.1, min_in_group_fraction=0.15, max_out_group_fraction=0.85)

marker_genes_dict = {}

# Add other markers
common_markers = ['Aldoc', 'Slc1a3']
marker_genes_dict[sub_id] = common_markers

temp = pd.DataFrame(sdata.uns['rank_genes_groups_filtered']['names']).head(10)
temp_genes = []
for i in range(temp.shape[1]):
    current_genes = temp.iloc[:, i].to_list()
    current_genes = [x for x in current_genes if str(x) != 'nan']
    current_genes = [x for x in current_genes if x not in temp_genes]
    
    for j in current_genes:
        temp_genes.append(j)
        
    current_key = temp.columns[i]
    marker_genes_dict[current_key] = current_genes

sdata.obs['leiden-replicate'] = sdata.obs['leiden'].astype(str) + '-' + sdata.obs['replicate'].astype(str)
sdata.obs['leiden-replicate'] = sdata.obs['leiden-replicate'].astype('category')

sc.pl.dotplot(sdata, marker_genes_dict, 'leiden-replicate', dendrogram=False, cmap='Reds', standard_scale='group', swap_axes=True)

In [None]:
# Visual checks on the UMAP: replicate, total counts, and Luzp2 expression.
for umap_color in ('replicate', 'total_counts', 'Luzp2'):
    sc.pl.umap(sdata, color=umap_color)
# Luzp2 again, on the diffusion map.
sc.pl.diffmap(sdata, color='Luzp2')

In [None]:
# Contingency table of the existing 'cell_type' annotation (rows) against the
# new leiden clusters (columns), shown as an annotated heatmap.
type_by_cluster = pd.crosstab(sdata.obs['cell_type'], sdata.obs['leiden'])
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
sns.heatmap(type_by_cluster, annot=True, fmt='d', ax=ax)

In [None]:
# Options shared by both plots of the marker genes per cluster-replicate group.
shared_plot_kwargs = dict(dendrogram=False, cmap='Reds', standard_scale='group', swap_axes=True)

# Dot plot reads the 'raw' layer; the matrix plot uses the default matrix.
sc.pl.dotplot(sdata, marker_genes_dict, 'leiden-replicate', layer='raw', vmax=1, **shared_plot_kwargs)
sc.pl.matrixplot(sdata, marker_genes_dict, 'leiden-replicate', **shared_plot_kwargs)


In [None]:
# Aqp4 expression on the UMAP and on the diffusion map.
# NOTE(review): 'Aqp4' was removed from sdata's variables in the "remove aqp4"
# cell; presumably the gene is resolved through sdata.raw here — verify.
for plot_embedding in (sc.pl.umap, sc.pl.diffmap):
    plot_embedding(sdata, color='Aqp4')

In [None]:
# Hand-picked marker genes per cluster label.
customized_dict = {
    '0': ['Aldoc', 'Slc1a3', 'Ttyh1', 'Glud1'],
    '1': ['Tspan7', 'Htra1', 'Caskin1', 'Cxcl14', 'S1pr1', 'Ndrg2', 'Vegfa', 'Trim9'],
    '2': ['Gfap', 'Vim', 'Clu', 'Igfbp5', 'Cd63', 'Apoe', 'Ntrk2'],
}

sc.pl.dotplot(sdata, customized_dict, 'leiden-replicate', dendrogram=False, cmap='bwr', standard_scale='group', swap_axes=True,
              layer='raw', vmax=.25, vmin=0)
# standard_scale takes 'var' or 'group'. The previous value 'gene' is not a
# recognised option, so no per-gene scaling was applied; 'var' scales each gene.
sc.pl.matrixplot(sdata, customized_dict, 'leiden-replicate', dendrogram=False, cmap='Reds', standard_scale='var', swap_axes=True)

In [None]:
# Selected genes shown as stacked violins per leiden cluster (values from the 'raw' layer).
markers = [
    'Aldoc', 'Slc1a3', 'Gfap', 'Vim', 'Clu', 'Igfbp5',
    'Cd63', 'Apoe', 'Fxyd1', 'Ntrk2', 'Ctsb',
]
violin_kwargs = dict(groupby='leiden', dendrogram=False, swap_axes=True, layer='raw', log=False, vmax=.01, cmap='bwr')
sc.pl.stacked_violin(sdata, markers, **violin_kwargs)

In [None]:
# Cluster whose filtered marker table is inspected below.
current_cell_type = '2'

# Filtered ranking for that cluster; show the first 50 rows.
current_df = sc.get.rank_genes_groups_df(
    sdata,
    key='rank_genes_groups_filtered',
    group=current_cell_type,
)
current_df.head(50)

In [None]:
# save clustering results
# One workbook for the sub-type, one sheet per leiden cluster holding that
# cluster's filtered marker ranking.
marker_xlsx = os.path.join(fig_path, f'{sub_id}_clustering_markers.xlsx')
with pd.ExcelWriter(marker_xlsx, mode='w') as writer:
    for cluster_label in sdata.obs['leiden'].cat.categories:
        cluster_markers = sc.get.rank_genes_groups_df(sdata, group=cluster_label, key='rank_genes_groups_filtered')
        cluster_markers.to_excel(writer, sheet_name=f'{sub_id}_{cluster_label}')
    

In [None]:
# Per-sample cell counts of every leiden cluster, compared between disease and
# control samples with one test per cluster.

# Cluster x sample count table (rows: leiden clusters, columns: samples).
cluster_names = sdata.obs['leiden'].cat.categories.to_list()
sample_names = sdata.obs['sample'].cat.categories.to_list()
sub_count_sample = pd.DataFrame(index=cluster_names, columns=sample_names)

for sample in sample_names:
    print(sample)
    current_obs = sdata.obs.loc[sdata.obs['sample'] == sample, :]
    current_count = current_obs['leiden'].value_counts()
    sub_count_sample.loc[:, sample] = current_count

# Long format. The value columns come from `sdata` (the table's own columns),
# not `adata`, so the melt cannot request a sample column that is absent.
sub_count_sample['leiden'] = sub_count_sample.index.values
sub_count_sample_melt = pd.melt(sub_count_sample, id_vars=['leiden'], value_vars=sample_names)
sub_count_sample_melt.columns = ['leiden', 'sample', 'value']

# Sample -> '<time>-<group>' lookup.
time_group_dict = {
    'ADmouse_9723': '8months-disease', 'ADmouse_9735': '8months-control', 'ADmouse_9494': '13months-disease', 'ADmouse_9498': '13months-control',
    'ADmouse_9723_2': '8months-disease', 'ADmouse_9707': '8months-control', 'ADmouse_11346': '13months-disease', 'ADmouse_11351': '13months-control',
}
sub_count_sample_melt['time-group'] = sub_count_sample_melt['sample'].map(time_group_dict)
# `n` must be passed by keyword; pandas 2 rejects it as a positional argument.
sub_count_sample_melt[['time', 'group']] = sub_count_sample_melt['time-group'].str.split('-', n=1, expand=True)
sub_count_sample_melt['group'] = sub_count_sample_melt['group'].astype('category')
sub_count_sample_melt['type-group'] = sub_count_sample_melt['leiden'].astype(str) + '-' + sub_count_sample_melt['group'].astype(str)
sub_count_sample_melt['type-group'] = sub_count_sample_melt['type-group'].astype('category')

# x position of the per-sample dots: cluster code shifted by +/-0.2 so the dots
# sit on their hue bar. Cast to float first, because the integer codes cannot
# hold the fractional offsets.
sub_count_sample_melt['leiden'] = sub_count_sample_melt['leiden'].astype('category')
sub_count_sample_melt['code'] = sub_count_sample_melt['leiden'].cat.codes.astype(float)
sub_count_sample_melt.loc[sub_count_sample_melt['group'] == 'disease', 'code'] += 0.2
sub_count_sample_melt.loc[sub_count_sample_melt['group'] == 'control', 'code'] -= 0.2

# plot barplot

sns.reset_orig()
cf_pl = sns.color_palette(['#00bfc4', '#f8766d'])
fig, ax = plt.subplots(figsize=(10, 7))
alternative = 'less'

sns.barplot(x='leiden', y='value', hue='group', data=sub_count_sample_melt, palette=cf_pl, ax=ax)
sns.scatterplot(x="code", y="value", hue='group', data=sub_count_sample_melt, s=70, facecolors='white', edgecolor='black', ax=ax, legend=False, zorder=2)

# One disease-vs-control comparison per cluster.
pairs = [((current_type, 'disease'), (current_type, 'control')) for current_type in cluster_names]

annot = Annotator(ax, pairs, plot='barplot', data=sub_count_sample_melt, x='leiden', y='value', hue='group')
annot.configure(test='t-test_ind', text_format='star', loc='inside', verbose=2)
annot.apply_test(alternative=alternative).annotate()

# plt.savefig(os.path.join(fig_path, f'cluster_freq_{sub_id}_group.pdf'))

plt.show()

In [None]:
# Display the cluster-by-sample count table built in the previous cell.
sub_count_sample

In [None]:
# Scatter of the 'x'/'y' obs columns for one sample, coloured by top-level type.
# x and y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.scatterplot(x='x', y='y', hue='top_level', data=adata.obs.loc[adata.obs['sample'] == 'ADmouse_9494', :])

In [None]:
# Scatter of the 'x'/'y' obs columns for one sample, coloured by leiden cluster.
# x and y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.scatterplot(x='x', y='y', hue='leiden', data=sdata.obs.loc[sdata.obs['sample'] == 'ADmouse_9494', :])

In [None]:
# Initialise the sub-type label column on the full dataset with the placeholder 'NA'.
adata.obs['cell_type_test'] = 'NA'

In [None]:
# save current results 
sdata.obs['cell_type_test'] = sdata.obs['leiden'].values
sdata.obs['cell_type_test'] = sdata.obs['cell_type_test'].map({'0': 'Astro1', '1': 'Astro2', '2': 'Astro3'})
adata.obs.loc[adata.obs['top_level'] == sub_id, 'cell_type_test'] = sdata.obs['cell_type_test'].values

In [None]:
# Number of cells per label in the 'cell_type_test' column.
adata.obs['cell_type_test'].value_counts()

In [None]:
# Show the output folder that the next cell writes to.
out_path

In [None]:
# Write the full dataset, including the 'cell_type_test' column added above, to a new .h5ad file.
adata.write_h5ad(os.path.join(out_path, '2022-08-28-Hu-AD-stardist-scaled.h5ad'))