# 1. Filtering Cells and Generating Input Files

2023-05-05

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import anndata as ad
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from anndata import AnnData
from natsort import natsorted

# Customized packages
import starmap.sc_util as su
# test()

In [None]:
import matplotlib

# Embed TrueType (Type 42) fonts in saved PDFs instead of Type 3 fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42})

## Set path

In [None]:
# Dataset root; replace the placeholder with the real location before running.
base_path = 'path/to/dataset/folder'

input_path = os.path.join(base_path, 'input')
out_path = os.path.join(base_path, 'output')
fig_path = os.path.join(base_path, 'figures')

# Create the output and figure folders on first run
# (os.mkdir does not create missing parents, so base_path must already exist).
for folder in (out_path, fig_path):
    if not os.path.exists(folder):
        os.mkdir(folder)

# Directory scanpy uses for figures it saves.
sc.settings.figdir = fig_path

In [None]:
# Load the raw RIBOmap rep1 AnnData generated from ClusterMap output.
# (Assign the result: `rdata_rep1` does not exist before this cell.)
rdata_rep1 = sc.read_h5ad(os.path.join(input_path, 'RIBOmap-rep1-raw.h5ad'))
rdata_rep1

In [None]:
# Load the raw RIBOmap rep2 AnnData generated from ClusterMap output.
# (Assign the result: `rdata_rep2` does not exist before this cell.)
rdata_rep2 = sc.read_h5ad(os.path.join(input_path, 'RIBOmap-rep2-raw.h5ad'))
rdata_rep2

In [None]:
# Load the raw STARmap rep2 AnnData generated from ClusterMap output.
# (Assign the result: `sdata_rep2` does not exist before this cell.)
sdata_rep2 = sc.read_h5ad(os.path.join(input_path, 'STARmap-rep2-raw.h5ad'))
sdata_rep2

## Create cell-filtered input files

### RIBOmap-rep1

In [None]:
# Pre-QC overview: the 20 genes with the highest expression in RIBOmap rep1.
sc.pl.highest_expr_genes(adata=rdata_rep1, n_top=20)

In [None]:
# Compute QC metrics in place; later cells read obs['total_counts'] and
# obs['log1p_total_counts'] produced by this call.
sc.pp.calculate_qc_metrics(rdata_rep1, inplace=True)

# Highest count observed for each gene in any single cell of this section.
# NOTE(review): assumes X is a dense ndarray; a scipy.sparse X would return a
# 1 x n_vars matrix here rather than a 1-D vector — confirm.
counts = rdata_rep1.X
rdata_rep1.var['max_counts_sample'] = counts.max(axis=0)

In [None]:
# Summary statistics of total counts per cell.
rdata_rep1.obs.total_counts.describe()

In [None]:
# Summary statistics of the per-gene maximum count.
rdata_rep1.var.max_counts_sample.describe()

In [None]:
# MAD-based window on log1p(total counts): keep cells within n MADs of the median.
# scipy.stats.median_absolute_deviation was removed in SciPy 1.9;
# median_abs_deviation with scale=1 returns the same raw (unscaled) MAD.
from scipy import stats
n = 3
log_counts = rdata_rep1.obs['log1p_total_counts']
mad = stats.median_abs_deviation(log_counts, scale=1)
center = log_counts.median()
lower_bd = center - n * mad
upper_bd = center + n * mad

# Bounds on the log1p scale, then converted back to raw total counts.
print(lower_bd)
print(upper_bd)
print(np.expm1(lower_bd))
print(np.expm1(upper_bd))

In [None]:
# Visualise the MAD cut-offs (red lines) on the raw and log1p total-count histograms.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
panels = (
    (rdata_rep1.obs['total_counts'], np.expm1(lower_bd), np.expm1(upper_bd)),
    (rdata_rep1.obs['log1p_total_counts'], lower_bd, upper_bd),
)
for ax, (values, low, high) in zip(axs, panels):
    sns.histplot(values, ax=ax)
    ax.axvline(low, c='r')
    ax.axvline(high, c='r')

# NOTE(review): this file name is reused by every section of the notebook, so
# enabling it in more than one section overwrites the earlier figure.
# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the (inclusive) MAD window and their median counts.
in_window = rdata_rep1.obs['total_counts'].between(np.expm1(lower_bd), np.expm1(upper_bd))
kept_counts = rdata_rep1.obs.loc[in_window, 'total_counts']
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Flag genes whose highest per-cell count exceeds 2; the same mask is stored
# under both column names.
is_detected = rdata_rep1.var['max_counts_sample'] > 2
rdata_rep1.var['detected_sample'] = is_detected
rdata_rep1.var['highly_variable_sample'] = is_detected.copy()
print(is_detected.sum())

In [None]:
# Cell filtering: sc.pp.filter_cells accepts one criterion per call, so apply
# the gene-count floor and the MAD count window one after another.
cell_filters = (
    {'min_genes': 10},
    {'min_counts': np.expm1(lower_bd)},
    {'max_counts': np.expm1(upper_bd)},
)
for criterion in cell_filters:
    sc.pp.filter_cells(rdata_rep1, **criterion)

# Snapshot the filtered count matrix into the 'raw' layer.
rdata_rep1.layers['raw'] = rdata_rep1.X.copy()
rdata_rep1.X.shape

In [None]:
# Save the filtered RIBOmap rep1 section under a YYYY-MM-DD date prefix.
from datetime import datetime
date = datetime.today().date().isoformat()
rep1_file = f"{out_path}/{date}-RIBOmap-rep1-3mad-filtered.h5ad"
rdata_rep1.write_h5ad(rep1_file)

### RIBOmap-rep2

In [None]:
# Pre-QC overview: the 20 genes with the highest expression in RIBOmap rep2.
sc.pl.highest_expr_genes(adata=rdata_rep2, n_top=20)

In [None]:
# Compute QC metrics in place; later cells read obs['total_counts'] and
# obs['log1p_total_counts'] produced by this call.
sc.pp.calculate_qc_metrics(rdata_rep2, inplace=True)

# Highest count observed for each gene in any single cell of this section.
# NOTE(review): assumes X is a dense ndarray; a scipy.sparse X would return a
# 1 x n_vars matrix here rather than a 1-D vector — confirm.
counts = rdata_rep2.X
rdata_rep2.var['max_counts_sample'] = counts.max(axis=0)

In [None]:
# Summary statistics of total counts per cell.
rdata_rep2.obs.total_counts.describe()

In [None]:
# Summary statistics of the per-gene maximum count.
rdata_rep2.var.max_counts_sample.describe()

In [None]:
# MAD-based window on log1p(total counts): keep cells within n MADs of the median.
# scipy.stats.median_absolute_deviation was removed in SciPy 1.9;
# median_abs_deviation with scale=1 returns the same raw (unscaled) MAD.
from scipy import stats
n = 3
log_counts = rdata_rep2.obs['log1p_total_counts']
mad = stats.median_abs_deviation(log_counts, scale=1)
center = log_counts.median()
lower_bd = center - n * mad
upper_bd = center + n * mad

# Bounds on the log1p scale, then converted back to raw total counts.
print(lower_bd)
print(upper_bd)
print(np.expm1(lower_bd))
print(np.expm1(upper_bd))

In [None]:
# Visualise the MAD cut-offs (red lines) on the raw and log1p total-count histograms.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
panels = (
    (rdata_rep2.obs['total_counts'], np.expm1(lower_bd), np.expm1(upper_bd)),
    (rdata_rep2.obs['log1p_total_counts'], lower_bd, upper_bd),
)
for ax, (values, low, high) in zip(axs, panels):
    sns.histplot(values, ax=ax)
    ax.axvline(low, c='r')
    ax.axvline(high, c='r')

# NOTE(review): this file name is reused by every section of the notebook, so
# enabling it in more than one section overwrites the earlier figure.
# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the (inclusive) MAD window and their median counts.
in_window = rdata_rep2.obs['total_counts'].between(np.expm1(lower_bd), np.expm1(upper_bd))
kept_counts = rdata_rep2.obs.loc[in_window, 'total_counts']
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Flag genes whose highest per-cell count exceeds 2; the same mask is stored
# under both column names.
is_detected = rdata_rep2.var['max_counts_sample'] > 2
rdata_rep2.var['detected_sample'] = is_detected
rdata_rep2.var['highly_variable_sample'] = is_detected.copy()
print(is_detected.sum())

In [None]:
# Cell filtering: sc.pp.filter_cells accepts one criterion per call, so apply
# the gene-count floor and the MAD count window one after another.
cell_filters = (
    {'min_genes': 10},
    {'min_counts': np.expm1(lower_bd)},
    {'max_counts': np.expm1(upper_bd)},
)
for criterion in cell_filters:
    sc.pp.filter_cells(rdata_rep2, **criterion)

# Snapshot the filtered count matrix into the 'raw' layer.
rdata_rep2.layers['raw'] = rdata_rep2.X.copy()
rdata_rep2.X.shape

In [None]:
# Save the filtered RIBOmap rep2 section under a YYYY-MM-DD date prefix.
from datetime import datetime
date = datetime.today().date().isoformat()
rep2_file = f"{out_path}/{date}-RIBOmap-rep2-3mad-filtered.h5ad"
rdata_rep2.write_h5ad(rep2_file)

### STARmap-rep2

In [None]:
# Pre-QC overview: the 20 genes with the highest expression in STARmap rep2.
sc.pl.highest_expr_genes(adata=sdata_rep2, n_top=20)

In [None]:
# Compute QC metrics in place; later cells read obs['total_counts'] and
# obs['log1p_total_counts'] produced by this call.
sc.pp.calculate_qc_metrics(sdata_rep2, inplace=True)

# Highest count observed for each gene in any single cell of this section.
# NOTE(review): assumes X is a dense ndarray; a scipy.sparse X would return a
# 1 x n_vars matrix here rather than a 1-D vector — confirm.
counts = sdata_rep2.X
sdata_rep2.var['max_counts_sample'] = counts.max(axis=0)

In [None]:
# Summary statistics of total counts per cell.
sdata_rep2.obs.total_counts.describe()

In [None]:
# Summary statistics of the per-gene maximum count.
sdata_rep2.var.max_counts_sample.describe()

In [None]:
# MAD-based window on log1p(total counts): keep cells within n MADs of the median.
# scipy.stats.median_absolute_deviation was removed in SciPy 1.9;
# median_abs_deviation with scale=1 returns the same raw (unscaled) MAD.
from scipy import stats
n = 3
log_counts = sdata_rep2.obs['log1p_total_counts']
mad = stats.median_abs_deviation(log_counts, scale=1)
center = log_counts.median()
lower_bd = center - n * mad
upper_bd = center + n * mad

# Bounds on the log1p scale, then converted back to raw total counts.
print(lower_bd)
print(upper_bd)
print(np.expm1(lower_bd))
print(np.expm1(upper_bd))

In [None]:
# Visualise the MAD cut-offs (red lines) on the raw and log1p total-count histograms.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
panels = (
    (sdata_rep2.obs['total_counts'], np.expm1(lower_bd), np.expm1(upper_bd)),
    (sdata_rep2.obs['log1p_total_counts'], lower_bd, upper_bd),
)
for ax, (values, low, high) in zip(axs, panels):
    sns.histplot(values, ax=ax)
    ax.axvline(low, c='r')
    ax.axvline(high, c='r')

# NOTE(review): this file name is reused by every section of the notebook, so
# enabling it in more than one section overwrites the earlier figure.
# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the (inclusive) MAD window and their median counts.
in_window = sdata_rep2.obs['total_counts'].between(np.expm1(lower_bd), np.expm1(upper_bd))
kept_counts = sdata_rep2.obs.loc[in_window, 'total_counts']
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Flag genes whose highest per-cell count exceeds 2; the same mask is stored
# under both column names.
is_detected = sdata_rep2.var['max_counts_sample'] > 2
sdata_rep2.var['detected_sample'] = is_detected
sdata_rep2.var['highly_variable_sample'] = is_detected.copy()
print(is_detected.sum())

In [None]:
# Cell filtering: sc.pp.filter_cells accepts one criterion per call, so apply
# the gene-count floor and the MAD count window one after another.
cell_filters = (
    {'min_genes': 10},
    {'min_counts': np.expm1(lower_bd)},
    {'max_counts': np.expm1(upper_bd)},
)
for criterion in cell_filters:
    sc.pp.filter_cells(sdata_rep2, **criterion)

# Snapshot the filtered count matrix into the 'raw' layer.
sdata_rep2.layers['raw'] = sdata_rep2.X.copy()
sdata_rep2.X.shape

In [None]:
# Save the filtered STARmap rep2 section under a YYYY-MM-DD date prefix.
from datetime import datetime
date = datetime.today().date().isoformat()
starmap_file = f"{out_path}/{date}-STARmap-rep2-3mad-filtered.h5ad"
sdata_rep2.write_h5ad(starmap_file)

### Correlation between RIBOmap and STARmap (rep2)

In [None]:
# Correlation of per-gene total counts (log2) between the two protocols (rep2 sections).
from scipy import stats

# Per-gene totals; np.asarray(...).ravel() yields a flat vector for dense or sparse X.
# Genes are paired by name (inner join) so the comparison does not depend on both
# objects listing their genes in the same order.
totals = pd.concat(
    {
        'RIBOmap': pd.Series(np.asarray(rdata_rep2.X.sum(axis=0)).ravel(), index=rdata_rep2.var_names),
        'STARmap': pd.Series(np.asarray(sdata_rep2.X.sum(axis=0)).ravel(), index=sdata_rep2.var_names),
    },
    axis=1,
    join='inner',
)
# log2(0) is -inf, which makes pearsonr and the regression fit undefined,
# so genes with zero counts in either section are left out.
corre_df = np.log2(totals[(totals > 0).all(axis=1)])
rdata_vector = corre_df['RIBOmap'].to_numpy()
sdata_vector = corre_df['STARmap'].to_numpy()

p_corr = stats.pearsonr(rdata_vector, sdata_vector)

g = sns.lmplot(x='RIBOmap', y='STARmap', data=corre_df, scatter_kws={'s': 1}, line_kws={'color': 'r'})
g.set_axis_labels('RIBOmap - log2(total counts)', 'STARmap - log2(total counts)')
plt.title(f"Pearson's correlation coefficient: {round(p_corr[0], 3)}")
plt.savefig(os.path.join(fig_path, 'correlation_ribomap_starmap_rep2_3mad.pdf'))
# plt.show()

### Correlation between the two RIBOmap replicates

In [None]:
# Correlation of per-gene total counts (log2) between the two RIBOmap replicates.
# Previously the rep1 vector was stored under 'RIBOmap_rep2' (and vice versa), so
# the axis labels were swapped; each column now holds its own replicate's data.
from scipy import stats

# Per-gene totals; np.asarray(...).ravel() yields a flat vector for dense or sparse X.
# Genes are paired by name (inner join) rather than by position.
totals = pd.concat(
    {
        'RIBOmap_rep2': pd.Series(np.asarray(rdata_rep2.X.sum(axis=0)).ravel(), index=rdata_rep2.var_names),
        'RIBOmap_rep1': pd.Series(np.asarray(rdata_rep1.X.sum(axis=0)).ravel(), index=rdata_rep1.var_names),
    },
    axis=1,
    join='inner',
)
# log2(0) is -inf, which makes pearsonr and the regression fit undefined,
# so genes with zero counts in either replicate are left out.
corre_df = np.log2(totals[(totals > 0).all(axis=1)])
rdata_vector = corre_df['RIBOmap_rep1'].to_numpy()
rdata_vector_2 = corre_df['RIBOmap_rep2'].to_numpy()

p_corr = stats.pearsonr(rdata_vector, rdata_vector_2)

g = sns.lmplot(x='RIBOmap_rep2', y='RIBOmap_rep1', data=corre_df, scatter_kws={'s': 1}, line_kws={'color': 'r'})
g.set_axis_labels('RIBOmap_rep2 - log2(total counts)', 'RIBOmap_rep1 - log2(total counts)')
plt.title(f"Pearson's correlation coefficient: {round(p_corr[0], 3)}")
plt.savefig(os.path.join(fig_path, 'correlation_ribomap_rep1_rep2_3mad.pdf'))
# plt.show()

## Combine datasets (n=3)

In [None]:
# Stack the cells of STARmap rep2, RIBOmap rep2 and RIBOmap rep1 (in that order).
adata = ad.concat([sdata_rep2, rdata_rep2, rdata_rep1])
# Replace the per-section cell names with one running index '0', '1', ...
# (presumably to avoid duplicate names across sections — confirm).
adata.obs.index = pd.Index([str(i) for i in range(adata.n_obs)])
adata

In [None]:
# Carry each section's per-gene max count over to the combined object.
# NOTE(review): `.values` assigns by position, which assumes every section lists the
# same genes in the same order as adata.var — confirm. ad.concat (default
# merge=None) presumably drops these per-section var columns, hence the copy.
adata.var['max_counts_rep1_RIBOmap'] = rdata_rep1.var['max_counts_sample'].values
adata.var['max_counts_rep2_RIBOmap'] = rdata_rep2.var['max_counts_sample'].values
# STARmap was only loaded for rep2 (sdata_rep2); there is no sdata_rep1.
adata.var['max_counts_rep2_STARmap'] = sdata_rep2.var['max_counts_sample'].values

In [None]:
# Label every cell as '<protocol>-<replicate>' and store it as a categorical column.
protocol_replicate = adata.obs['protocol'].astype(str) + '-' + adata.obs['replicate'].astype(str)
adata.obs['protocol-replicate'] = protocol_replicate.astype('category')

In [None]:
# Write the combined three-section object to the output folder (date-prefixed).
combined_file = os.path.join(out_path, f'{date}-Brain-combined-3mad-filtered.h5ad')
adata.write_h5ad(combined_file)

In [None]:
# Compare the distributions of per-gene max counts across the three sections.
max_count_cols = ["max_counts_rep1_RIBOmap", "max_counts_rep2_RIBOmap", "max_counts_rep2_STARmap"]
sns.boxplot(data=adata.var[max_count_cols])
# Optional zoom on the bulk of the distribution: plt.gca().set_ylim([0, 10])
plt.xticks(rotation=45)
plt.show()

In [None]:
group_labels = adata.obs['protocol-replicate']
group_labels.value_counts()

## Combine datasets (n=2, two RIBOmap sections)

In [None]:
# Stack the cells of the two RIBOmap sections (rep2 first, then rep1).
adata = ad.concat([rdata_rep2, rdata_rep1])
# Replace the per-section cell names with one running index '0', '1', ...
adata.obs.index = pd.Index([str(i) for i in range(adata.n_obs)])
adata

In [None]:
# Copy each RIBOmap section's per-gene max count onto the combined object.
# NOTE(review): `.values` assigns by position, which assumes both sections list
# their genes in the same order as adata.var — confirm.
for column, section in (('max_counts_rep1_RIBOmap', rdata_rep1), ('max_counts_rep2_RIBOmap', rdata_rep2)):
    adata.var[column] = section.var['max_counts_sample'].values

In [None]:
# Label every cell as '<protocol>-<replicate>' and store it as a categorical column.
protocol_replicate = adata.obs['protocol'].astype(str) + '-' + adata.obs['replicate'].astype(str)
adata.obs['protocol-replicate'] = protocol_replicate.astype('category')

In [None]:
# Write the combined RIBOmap-only object to the output folder (date-prefixed).
ribomap_combined_file = os.path.join(out_path, f'{date}-Brain-RIBOmap-combined-3mad-filtered.h5ad')
adata.write_h5ad(ribomap_combined_file)

In [None]:
group_labels = adata.obs['protocol-replicate']
group_labels.value_counts()