In [None]:
%matplotlib inline
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
import sys
import os
import scanpy as sp
import anndata as an
# Apply a global seaborn theme (white style, fonts scaled by 1.25) to the figures drawn below.
sns.set(style='white',font_scale= 1.25)

In [None]:
# Per-sample doublet-score thresholds, keyed by sample ID (as a string).
# A cell is flagged as a doublet further down when its 'ds' score is >= the
# threshold of its sample; the same values are drawn as red lines on the
# per-sample histograms.
th = {
    '9915': 0.2,
    '0527': 0.2,
    '0406': 0.4,
    '1349': 0.25,
    '1369': 0.2,
    '1578': 0.35,
    '1354': 0.25,
    '1237': 0.25,
    '1750': 0.18,
    '1527': 0.32,
    '1611': 0.2,
    '1402': 0.2,
    '1630': 0.18,
    '1114': 0.25,
    '0354': 0.18,
    '1864': 0.2,
    '0444': 0.2,
    '1159': 0.2,
    '1629': 0.2,
    '1227': 0.2,
    '1372': 0.2,
    '1849': 0.3,
    '1828': 0.2,
    '1868': 0.2,
    '1342': 0.2,
    '1820': 0.2,
    '1875': 0.3,
    '1862': 0.4,
    '1244': 0.2,
    '1560': 0.2,
    '1046': 0.2,
    '0552': 0.22,
    '1420': 0.3,
    '1414': 0.25,
}

#### Reading the data

In [None]:
# Directory prefix used for every file read or written in this notebook.
path = '../ad/'
# Load the AnnData object; the cells below read the 'sample' and 'ds' columns of its .obs.
D = an.read_h5ad(path + 'ad_raw_doublet_scores.h5ad')
# Show the shape as the cell output.
D.shape

#### Plotting doublet scores

In [None]:
# One doublet-score histogram per sample, with that sample's threshold from
# `th` drawn as a red vertical line and printed next to it.
A = D.obs.copy()
samples = list(A['sample'].unique())
n_cols = 6
# Keep the 6x6 layout for up to 36 samples, but grow the number of rows when
# there are more so plt.subplot never receives an out-of-range index.
n_rows = max(6, int(np.ceil(len(samples) / n_cols)))
plt.figure(figsize=(24,16))
for ii, sample in enumerate(samples, start=1):
    plt.subplot(n_rows, n_cols, ii)
    plt.title('sample: ' + sample)
    a = A.loc[A['sample'] == sample, 'ds']
    plt.hist(a, bins=50, range=(0,1), label='doublet score hist.', color='k')
    plt.xlabel('doublet score')
    plt.ylabel('number of cells')
    # th[sample] raises KeyError for a sample without a configured threshold.
    plt.axvline(th[sample], c='r', label='doublet score threshold')
    plt.text(th[sample]*1.1, 100, str(th[sample]))
plt.tight_layout()
sns.despine()

#### Removing doublets

In [None]:
# Flag doublets: 'ts' is True when a cell's doublet score 'ds' reaches the
# threshold configured for its sample in `th`.
A = D.obs.copy()
# Vectorised threshold lookup (one dict lookup per cell via map) instead of a
# Python-level row-wise apply over the whole obs table.
sample_th = A['sample'].astype(str).map(th)
missing = sorted(set(A.loc[sample_th.isna(), 'sample'].astype(str)))
if missing:
    # Same failure mode as a plain th[sample] lookup, but naming every offender.
    raise KeyError('no doublet threshold defined in th for sample(s): ' + ', '.join(missing))
A['ts'] = (A['ds'] >= sample_th.astype(float)).astype(bool)

In [None]:
# Store the doublet flag on the AnnData object, then keep only unflagged cells.
D.obs = A.copy()
# Subsetting an AnnData returns a view; .copy() makes D a standalone object so
# the later in-place writes (D.var[...] = ..., QC metrics, gene filtering) do
# not operate on a view of the unfiltered data.
D = D[~D.obs['ts'],:].copy()

In [None]:
# Report the shape of D after doublet removal.
print(D.shape)

#### Defining QC genes

In [None]:
# Boolean gene flags derived from gene-name prefixes.
# 'mt': names starting with 'MT-'.
D.var['mt'] = D.var_names.str.startswith('MT-')

# 'rb': names starting with RPL / RPS / MRPL / MRPS.
starts_rpl = D.var_names.str.startswith('RPL')
starts_rps = D.var_names.str.startswith('RPS')
starts_mrpl = D.var_names.str.startswith('MRPL')
starts_mrps = D.var_names.str.startswith('MRPS')
D.var['rb'] = starts_rpl | starts_rps | starts_mrpl | starts_mrps

# 'ercc': names starting with 'ERCC'.
D.var['ercc'] = D.var_names.str.startswith('ERCC')

# 'qc': union of the 'mt' and 'rb' flags.
D.var['qc'] = D.var['mt'] | D.var['rb']

#### Defining protein-coding (PC) genes

In [None]:
# Gene annotation table; rows whose 'gene_biotype' is 'protein_coding'
# contribute their 'gene_name' to the lookup set `pc`.
gene_db = pd.read_csv(path + 'Homo_sapiens.GRCh38.101.chr_3.csv', sep=',')
is_protein_coding = gene_db['gene_biotype'] == 'protein_coding'
pc = set(gene_db.loc[is_protein_coding, 'gene_name'])

In [None]:
# Flag protein-coding genes in D.var['pc'].
# The gene names are read straight from D.var_names, so D.var never has to be
# re-indexed (no temporary integer index on the AnnData, and no dependence on
# the index being unnamed). D.var's index is left untouched.
gene_names = pd.Series(D.var_names.astype(str))
# A gene counts as protein coding when its name, or its name without the last
# two characters, is in `pc`.
# NOTE(review): the two-character strip presumably removes a de-duplication
# suffix such as '-1' -- confirm against how var_names were made unique.
D.var['pc'] = (gene_names.isin(pc) | gene_names.str[:-2].isin(pc)).to_numpy()
print(D.var['pc'].sum())

#### Calculating and plotting QC metrics

In [None]:
# Compute QC metrics in place on D, with the 'mt', 'rb' and 'ercc' gene flags as qc_vars.
# The cells below read n_genes_by_counts, total_counts and pct_counts_{mt,rb,ercc} from D.obs.
sp.pp.calculate_qc_metrics(D, qc_vars=['mt','rb','ercc'], percent_top=None, log1p=False, inplace=True)

In [None]:
# Per-sample mean of each QC metric, drawn as one horizontal bar chart per metric.
qc_columns = ['n_genes_by_counts','total_counts','pct_counts_mt','pct_counts_rb','pct_counts_ercc']
a = D.obs.sort_values('sample')
b = a.groupby('sample')[qc_columns].mean()
bar_positions = range(len(b))
plt.figure(figsize=(24,6))
for ii, col in enumerate(b.columns, start=1):
    plt.subplot(1, 5, ii)
    plt.barh(bar_positions, b[col], color='k')
    # Label each bar with its sample ID.
    plt.yticks(bar_positions, list(b.index))
    plt.xlabel(col.replace('_',' '))
plt.tight_layout()
sns.despine()

In [None]:
# QC distributions before filtering: log10 histograms of total counts and
# detected genes, ribosomal vs. mitochondrial percentage, and a histogram of
# the mitochondrial percentage.
plt.figure(figsize=(24,6))
for panel, metric in enumerate(['total_counts', 'n_genes_by_counts'], start=1):
    plt.subplot(1,4,panel)
    values = D.obs[metric]
    # log10(0) is -inf, which makes plt.hist fail when it computes the bin
    # range, so only strictly positive entries are plotted.
    plt.hist(np.log10(values[values > 0]), bins=100)
    plt.xlabel(metric + ' (log10)')
    plt.ylabel('number of cells')
plt.subplot(1,4,3)
plt.scatter(D.obs['pct_counts_rb'],D.obs['pct_counts_mt'],s=1)
plt.xlim(-1,)
plt.ylim(-1,)
plt.xlabel('pct_counts_rb')
plt.ylabel('pct_counts_mt')
plt.subplot(1,4,4)
plt.hist(D.obs['pct_counts_mt'],bins=40,range=(0,40))
plt.xlim(0,40)
plt.xlabel('pct_counts_mt')
plt.ylabel('number of cells')
plt.tight_layout()
sns.despine()

#### QC filtering

In [None]:
# Gene filter (in place): drop genes with min_cells=2 not met.
sp.pp.filter_genes(D, min_cells=2)
# Cell filter: all four criteria are evaluated on the same D.obs and combined
# into one mask, which selects the same cells as applying them one after the
# other.
keep_cells = (
    (D.obs['total_counts'] >= 1000)
    & (D.obs['n_genes_by_counts'] >= 500)
    & (D.obs['pct_counts_mt'] < 10)
    & (D.obs['pct_counts_rb'] < 10)
)
# .copy() makes D a standalone AnnData instead of a view, so the in-place
# normalisation further down does not act on a view.
D = D[keep_cells.to_numpy(), :].copy()

In [None]:
# QC distributions after filtering: log10 total counts, log10 genes per cell,
# ribosomal percentage and mitochondrial percentage.
plt.figure(figsize=(24,6))
plt.subplot(1,4,1)
plt.hist(np.log10(D.obs['total_counts']),bins=20,color='k')
plt.xlabel('log10(total counts)')
plt.ylabel('number of cells')
plt.subplot(1,4,2)
plt.hist(np.log10(D.obs['n_genes_by_counts']),bins=20,color='k')
plt.xlabel('log10(number of genes per cell)')
plt.ylabel('number of cells')
plt.subplot(1,4,3)
plt.hist(D.obs['pct_counts_rb'],bins=40,range=(0,40), color = 'k')
# The x-range stops at 10 to match the upper bound drawn for the mito. panel.
plt.xlim(0,10)
plt.ylim(0,)
plt.xlabel('pct. of ribo. counts')
plt.ylabel('number of cells')
plt.subplot(1,4,4)
plt.hist(D.obs['pct_counts_mt'],bins=40,range=(0,40), color = 'k')
plt.xlim(0,10)
plt.xlabel('pct. of mito. counts')
plt.ylabel('number of cells')
sns.despine()
plt.tight_layout()

In [None]:
# Keep an untouched copy of the filtered counts before D is normalised in place.
D_raw = D.copy()
# Scale every cell to target_sum=1e4, then apply log1p; both calls modify D in place.
sp.pp.normalize_total(D, target_sum=1e4)
sp.pp.log1p(D)
# Report the shape of D after QC filtering.
print(D.shape)

#### HVG selection

In [None]:
# Work on a copy so the highly-variable-gene annotation does not touch D.
D_ = D.copy()
# Annotate the top 5000 highly variable genes; the next cell reads D_.var.highly_variable.
sp.pp.highly_variable_genes(D_,n_top_genes=5000)
# Scanpy's diagnostic plot for the selection.
sp.pl.highly_variable_genes(D_)
sns.despine()

In [None]:
# Restrict D_ to the highly variable genes.
# Column-subsetting an AnnData returns a view; .copy() makes D_ a standalone
# object so PCA, the neighbour graph, UMAP and clustering can write their
# results into it without first having to turn the view into a real object.
D_ = D_[:, D_.var.highly_variable].copy()
D_.shape

#### PCA

In [None]:
# PCA with 48 components; the variance plots below assume the same number (48).
sp.tl.pca(D_, svd_solver='arpack',n_comps=48)

In [None]:
# Variance ratio of the 48 components, log scale.
sp.pl.pca_variance_ratio(D_, log=True,n_pcs=48)

In [None]:
# Same plot as the previous cell, on a linear scale.
sp.pl.pca_variance_ratio(D_, log=False,n_pcs=48)
sns.despine()

In [None]:
# Explained variance (in %) per principal component, as stored by sp.tl.pca in
# D_.uns['pca']['variance_ratio'].
variance_ratio = D_.uns['pca']['variance_ratio']
plt.figure(figsize=(6,6))
# The x positions are derived from the stored ratios rather than a hard-coded
# count, so the plot stays valid if n_comps is changed in the PCA cell.
plt.scatter(np.arange(len(variance_ratio)), 100*variance_ratio, c='k')
sns.despine()
plt.ylabel('explained variance ratio (%)')
plt.xlabel('# PCA')

#### Batch correction

In [None]:
# D_bbknn is a second name for the same object as D_ (no copy is made), so
# everything written to D_bbknn below is also visible through D_.
D_bbknn = D_

In [None]:
import bbknn

#### UMAP, Leiden clustering and cell type annotation

In [None]:
# Build the BBKNN neighbour graph with 'sample' as the batch key, n_pcs=20
# and neighbors_within_batch=1.
bbknn.bbknn(D_bbknn, batch_key='sample', n_pcs=20,neighbors_within_batch=1)

In [None]:
# UMAP embedding; the plotting cell below reads it from D_bbknn.obsm['X_umap'].
sp.tl.umap(D_bbknn,min_dist=0.3)

In [None]:
# Leiden clustering writes its labels to obs['leiden'].
sp.tl.leiden(D_bbknn)
# Expose the labels as obs['cluster'] and drop the 'leiden' column.
D_bbknn.obs['cluster'] = D_bbknn.obs['leiden']
del D_bbknn.obs['leiden']

In [None]:
# One marker gene per broad cell type (cell-type label -> gene symbol).
# Each cluster is later labelled with the cell type whose marker has the
# highest mean expression in that cluster.
markers = {
    'Exci. Neurons': 'SLC17A7',
    'Inhi. Neurons': 'GAD1',
    'Oligo.': 'MOG',
    'Endo. cells': 'CLDN5',
    'Astrocytes': 'AQP4',
    'Microglia': 'CD74',
}

In [None]:
# Label every cluster with the cell type whose marker gene has the highest
# mean expression in that cluster.
A = D_bbknn.obs.copy()
for cell_type_name, gene in markers.items():
    # Expression is read from D (not D_bbknn); values are assigned positionally,
    # which relies on D and D_bbknn holding the same cells in the same order.
    A[cell_type_name] = D.obs_vector(gene)
A = A[['cluster'] + list(markers)]
M = A.groupby('cluster').mean()

labels = {}
for cluster, x in M.iterrows():
    # First marker with the highest mean wins; a cluster in which no marker has
    # a mean above 0 gets the empty label ''.
    labels[cluster] = x.idxmax() if x.max() > 0 else ''

D_bbknn.obs['cell_type'] = [labels[cluster_id] for cluster_id in D_bbknn.obs['cluster']]

In [None]:
# manual annotation based on key markers, this should be adjusted according to cluster numbering
def subcluster(d):
    """Add a 'cell_type_sub' column to ``d.obs`` and return ``d.obs``.

    Excitatory and inhibitory neurons are split into sub-types by their
    cluster number (hard-coded below); every other cell type keeps its
    'cell_type' label, lower-cased. ``d.obs`` is modified in place.
    """
    # cluster number -> sub-type; clusters not listed fall back to the default.
    excitatory_by_cluster = {17: 'en3', 23: 'en2', 19: 'en2'}
    inhibitory_by_cluster = {16: 'i1', 18: 'i1', 9: 'i2'}

    def _sub_label(cell_type, cluster):
        # The cluster label is only converted to int for the two neuron classes.
        if cell_type == 'Exci. Neurons':
            return excitatory_by_cluster.get(int(cluster), 'en1')
        if cell_type == 'Inhi. Neurons':
            return inhibitory_by_cluster.get(int(cluster), 'i3')
        return cell_type.lower()

    d.obs['cell_type_sub'] = [
        _sub_label(cell_type, cluster)
        for cell_type, cluster in zip(d.obs['cell_type'], d.obs['cluster'])
    ]
    return d.obs

In [None]:
# subcluster() adds 'cell_type_sub' to D_bbknn.obs in place and returns that same table.
D_bbknn.obs = subcluster(D_bbknn)

In [None]:
# Display the per-cell annotation columns available at this point.
D_bbknn.obs.columns

In [None]:
from numpy import random
def color(X,A,att,siz=10,cmap={},th=0,alpha=1,annotate=False,legend=True,sample=10000000000):
    """Scatter the rows of ``X`` on the current axes, one colour per value of ``A[att]``.

    Parameters
    ----------
    X : 2-D array; columns 0 and 1 are used as x and y coordinates.
    A : DataFrame whose rows correspond one-to-one, in order, to the rows of ``X``.
    att : name of the column of ``A`` that defines the groups.
    siz : marker size passed to ``plt.scatter``.
    cmap : mapping value -> colour. When empty, colours are assigned from a
        built-in palette (cycled when there are more values than colours).
        The default dict is never modified.
    th : groups with ``th`` or fewer points are not drawn.
    alpha : marker transparency.
    annotate : when True, write each group's value at the mean position of its points.
    legend : when True, draw a legend.
    sample : groups with at least this many points are randomly down-sampled
        (``np.random.choice`` without replacement) to ``sample`` points.
    """
    colors = ['#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080', '#ffffff', '#000000']
    if len(cmap) == 0:
        cmap = {}
        i = -1
        # Walk the distinct values in sorted order. Iterating a set instead would
        # make the value -> colour assignment (and the drawing order) change
        # between runs, because string hashing is randomised per process.
        for val in A[att].sort_values().drop_duplicates():
            if val != 'nan':
                i += 1
                cmap[val] = colors[i % len(colors)]
            else:
                # The literal string 'nan' is always drawn in black.
                cmap[val] = '#000000'
    for key in cmap:
        mask = (A[att] == key).to_numpy()
        n_points = mask.sum()
        if n_points > th:
            x = X[mask, :]
            if n_points < sample:
                plt.scatter(x[:,0], x[:,1], color=cmap[key], label=key, s=siz, alpha=alpha)
            else:
                random_indices = np.random.choice(n_points, size=sample, replace=False)
                x_ = x[random_indices, :]
                plt.scatter(x_[:,0], x_[:,1], color=cmap[key], label=key, s=siz, alpha=alpha)
            if annotate:
                # The label position uses all points of the group, not only the drawn sample.
                mx = x[:,0].mean()
                my = x[:,1].mean()
                plt.text(mx, my, key, bbox=dict(facecolor='white', alpha=0.75, boxstyle="round"))
    if legend:
        plt.legend(markerscale=6, framealpha=0.5, bbox_to_anchor=(1.05,1))

def color_real(X,A,att,siz=10,alpha=0.8):
    """Scatter the rows of ``X`` on the current axes, coloured by the numeric column ``A[att]``.

    The values are min-max scaled to [0, 1]. Points at the minimum are drawn in
    grey ('#bdc3c7'); all others get one of 10 shades of seaborn's "Reds"
    palette. Points are drawn in ascending order of value, so the highest
    values are drawn last.

    Parameters
    ----------
    X : 2-D array; columns 0 and 1 are used as x and y coordinates.
    A : DataFrame whose rows correspond one-to-one, in order, to the rows of ``X``.
    att : name of the numeric column of ``A`` used for colouring.
    siz : marker size passed to ``plt.scatter``.
    alpha : marker transparency.
    """
    palette = sns.color_palette("Reds",10)
    y = np.asarray(list(A[att]), dtype=float)
    if y.size == 0:
        # Nothing to draw; min()/max() would raise on an empty array.
        return
    span = y.max() - y.min()
    if span == 0:
        # Constant column: 0/0 would give NaN and int(NaN) raises further down,
        # so every point is treated as being at the minimum (drawn grey).
        y = np.zeros_like(y)
    else:
        y = (y - y.min()) / span
    color = []
    for y_ in y:
        if y_ == 0:
            color.append('#bdc3c7')
        else:
            # The tiny offset maps y_ == 1 to the last palette entry (index 9).
            idx = int(y_*10-0.000000000001)
            color.append(palette[idx])
    R = pd.DataFrame({'x':X[:,0],'y':X[:,1],'c':color,'z':y})
    R = R.sort_values('z').reset_index()
    plt.scatter(R['x'],R['y'],c=R['c'],s=siz,alpha=alpha)

In [None]:
# Four UMAP panels side by side, coloured by cluster, cell type, cell sub-type
# and sample; only the first three are annotated with text labels.
plt.figure(figsize=(32,8))
D0 = D_bbknn
umap_panels = [
    ('cluster', True),
    ('cell_type', True),
    ('cell_type_sub', True),
    ('sample', False),
]
for position, (attribute, with_labels) in enumerate(umap_panels, start=1):
    plt.subplot(1, 4, position)
    color(D0.obsm['X_umap'], D0.obs, attribute, legend=False, annotate=with_labels)
# A single despine once all four axes exist.
sns.despine()

In [None]:
# Number of cells per sample after all filtering steps.
D_bbknn.obs['sample'].value_counts()

In [None]:
# Attach per-sample metadata from meta.pkl to the per-cell annotation table.
# NOTE(review): unpickling can execute arbitrary code; only load meta.pkl from a trusted source.
meta = pd.read_pickle(path + 'meta.pkl')
A = D_bbknn.obs.copy()
# Remove leftovers of earlier reset_index() calls so the two reset_index() calls
# below can create fresh 'index' and 'level_0' columns.
if 'index' in A.columns:
    del A['index']
if 'level_0' in A.columns:
    del A['level_0']
# 1st reset_index(): the obs index becomes a column (named 'index' when the obs
#   index is unnamed -- the code below relies on that name; TODO confirm).
# 2nd reset_index(): the running row number becomes column 'level_0'.
# set_index('sample'): index by sample so the join below matches on it.
A = A.reset_index().reset_index().set_index('sample')
# Keep the row number as 'cell_id' so the original cell order can be restored after the join.
A['cell_id'] = A['level_0']
del A['level_0']
# Unlike the guarded deletes above, this raises KeyError when obs has no 'type' column.
# Presumably dropped because meta brings its own 'type' column -- verify against meta.pkl.
del A['type']
# Index-on-index join; assumes meta is indexed by sample ID -- TODO confirm.
A = A.join(meta)
# Bring 'sample' back as a column, restore the original index and the original cell order.
A = A.reset_index().set_index('index').sort_values('cell_id')

In [None]:
# Replace the per-cell annotations with the metadata-enriched table built above.
D_bbknn.obs = A.copy()

In [None]:
# Copy the annotations (obs), embeddings (obsm) and unstructured results (uns)
# from D_bbknn onto D_raw and D, then write all three objects to disk.
D_raw.obs = D_bbknn.obs.copy()
D_raw.obsm = D_bbknn.obsm.copy()
D_raw.uns = D_bbknn.uns.copy()
D.obs = D_bbknn.obs.copy()
D.obsm = D_bbknn.obsm.copy()
D.uns = D_bbknn.uns.copy()
# D: normalised + log1p values.
# NOTE(review): the file name says 'cpm' but D was normalised with target_sum=1e4 -- confirm the intended naming.
D.write(path + 'ad_cpm_annotated.h5ad')
# D_raw: copy of D taken before normalisation.
D_raw.write(path + 'ad_raw_annotated.h5ad')
# D_bbknn: the highly-variable-gene subset used for PCA, BBKNN, UMAP and clustering.
D_bbknn.write(path + 'ad_red_annotated.h5ad')