# 1. Filtering Cells and Generating Input Files

2022-10-23

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import anndata as ad
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from anndata import AnnData
from natsort import natsorted

# Customized packages
import starmap.sc_util as su
# test()

In [None]:
# Embed fonts as TrueType (type 42) rather than Type 3 glyphs in saved PDFs.
import matplotlib
matplotlib.rcParams.update({'pdf.fonttype': 42})

## Set path

In [None]:
# Set path
# Inputs are read from base_path/input; outputs and figures are written alongside.
base_path = 'Z:/Data/Analyzed/2022-09-05-Hu-Tissue/'

input_path = os.path.join(base_path, 'input')

# os.makedirs(exist_ok=True) also creates missing parent directories and is a no-op
# when the directory already exists (os.mkdir would raise if base_path were missing).
out_path = os.path.join(base_path, 'output')
os.makedirs(out_path, exist_ok=True)

fig_path = os.path.join(base_path, 'figures')
os.makedirs(fig_path, exist_ok=True)

# Directory scanpy uses when its plotting functions save figures.
sc.settings.figdir = fig_path

In [None]:
# Load the raw rep2 RIBOmap section with scanpy's reader and bind it to `rdata`,
# which every later RIBOmap cell uses.
rdata = sc.read_h5ad(os.path.join(input_path, 'rep2/batch2', '2022-10-23-Brain-RIBOmap-raw.h5ad'))
rdata

In [None]:
# Load the raw rep2 STARmap section with scanpy's reader and bind it to `sdata`,
# which every later STARmap cell uses.
sdata = sc.read_h5ad(os.path.join(input_path, 'rep2/batch2', '2022-10-23-Brain-STARmap-raw.h5ad'))
sdata

## Create filtered (cell) input file (rep2)

### RIBOmap

In [None]:
# Before QC: the 20 genes with the highest mean fraction of counts per cell.
sc.pl.highest_expr_genes(adata=rdata, n_top=20)

In [None]:
# Compute scanpy QC metrics in place; later cells read obs['total_counts'] and
# obs['log1p_total_counts'] from this step.
sc.pp.calculate_qc_metrics(rdata, inplace=True)

# Maximum count of each gene across all cells. When X is sparse, X.max(axis=0)
# returns a 1 x n_genes sparse matrix, so densify and flatten it before storing
# it as a var column (dense X already gives a 1-D array and is unaffected).
from scipy import sparse
gene_max = rdata.X.max(axis=0)
if sparse.issparse(gene_max):
    gene_max = gene_max.toarray()
rdata.var['max_counts_sample'] = np.asarray(gene_max).ravel()

In [None]:
# Summary statistics of per-cell total counts (RIBOmap, before filtering).
rdata.obs.total_counts.describe()

In [None]:
# Summary statistics of the per-gene maximum counts (RIBOmap).
rdata.var.max_counts_sample.describe()

In [None]:
# MAD-based thresholds on log1p(total counts): median +/- n * MAD.
# stats.median_absolute_deviation is deprecated and removed in newer SciPy releases;
# median_abs_deviation with scale=1 returns the same unscaled MAD.
from scipy import stats
n = 3
median_log_counts = rdata.obs['log1p_total_counts'].median()
mad = stats.median_abs_deviation(rdata.obs['log1p_total_counts'], scale=1)
lower_bd = median_log_counts - n*mad
upper_bd = median_log_counts + n*mad

# Bounds on the log1p scale, then converted back to raw counts with expm1.
print(lower_bd)
print(upper_bd)
print(np.expm1(lower_bd))
print(np.expm1(upper_bd))

In [None]:
# Plot the MAD cut-offs on the raw and the log1p count distributions.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12,5))
mad_panels = (
    ('total_counts', (np.expm1(lower_bd), np.expm1(upper_bd))),
    ('log1p_total_counts', (lower_bd, upper_bd)),
)
for ax, (column, cutoffs) in zip(axs, mad_panels):
    sns.histplot(rdata.obs[column], ax=ax)
    for cutoff in cutoffs:
        ax.axvline(cutoff, c='r')

# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the MAD window and their median depth.
# NOTE: the min_genes=10 filter applied later is not part of this preview.
cell_counts = rdata.obs['total_counts']
kept_counts = cell_counts[cell_counts.between(np.expm1(lower_bd), np.expm1(upper_bd))]
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Hand-picked count cut-offs, shown for comparison with the MAD window.
lower_bd_manual = 40
upper_bd_manual = 1200

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12,5))
manual_panels = (
    ('total_counts', (lower_bd_manual, upper_bd_manual)),
    ('log1p_total_counts', (np.log1p(lower_bd_manual), np.log1p(upper_bd_manual))),
)
for ax, (column, cutoffs) in zip(axs, manual_panels):
    sns.histplot(rdata.obs[column], ax=ax)
    for cutoff in cutoffs:
        ax.axvline(cutoff, c='r')

# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the manual window and their median depth.
cell_counts = rdata.obs['total_counts']
kept_counts = cell_counts[cell_counts.between(lower_bd_manual, upper_bd_manual)]
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Flag genes whose maximum per-cell count exceeds 2. 'highly_variable_sample'
# receives the same mask; it is not a variance-based highly-variable-gene call.
rdata.var['detected_sample'] = rdata.var['max_counts_sample'].gt(2)
rdata.var['highly_variable_sample'] = rdata.var['detected_sample'].copy()
print(rdata.var['detected_sample'].sum())

In [None]:
# Cell filtering: require >= 10 detected genes, then keep cells whose total counts
# lie inside the MAD window. sc.pp.filter_cells accepts one criterion per call,
# so the criteria are applied one after another in this order.
for criterion in (
    {'min_genes': 10},
    {'min_counts': np.expm1(lower_bd)},
    {'max_counts': np.expm1(upper_bd)},
):
    sc.pp.filter_cells(rdata, **criterion)

# Keep a copy of the filtered count matrix in layers['raw'].
rdata.layers['raw'] = rdata.X.copy()
rdata.X.shape

In [None]:
# Save the filtered rep2 RIBOmap object (a single section, not a combined one),
# stamped with today's date.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')
rdata.write_h5ad(os.path.join(out_path, f'{date}-Brain-RIBOmap-rep2-3mad-filtered.h5ad'))

### STARmap

In [None]:
# Before QC: the 20 genes with the highest mean fraction of counts per cell.
sc.pl.highest_expr_genes(adata=sdata, n_top=20)

In [None]:
# Compute scanpy QC metrics in place; later cells read obs['total_counts'] and
# obs['log1p_total_counts'] from this step.
sc.pp.calculate_qc_metrics(sdata, inplace=True)

# Maximum count of each gene across all cells. When X is sparse, X.max(axis=0)
# returns a 1 x n_genes sparse matrix, so densify and flatten it before storing
# it as a var column (dense X already gives a 1-D array and is unaffected).
from scipy import sparse
gene_max = sdata.X.max(axis=0)
if sparse.issparse(gene_max):
    gene_max = gene_max.toarray()
sdata.var['max_counts_sample'] = np.asarray(gene_max).ravel()

In [None]:
# Summary statistics of per-cell total counts (STARmap, before filtering).
sdata.obs.total_counts.describe()

In [None]:
# Summary statistics of the per-gene maximum counts (STARmap).
sdata.var.max_counts_sample.describe()

In [None]:
# MAD-based thresholds on log1p(total counts): median +/- n * MAD.
# stats.median_absolute_deviation is deprecated and removed in newer SciPy releases;
# median_abs_deviation with scale=1 returns the same unscaled MAD.
from scipy import stats
n = 3
median_log_counts = sdata.obs['log1p_total_counts'].median()
mad = stats.median_abs_deviation(sdata.obs['log1p_total_counts'], scale=1)
lower_bd = median_log_counts - n*mad
upper_bd = median_log_counts + n*mad

# Bounds on the log1p scale, then converted back to raw counts with expm1.
print(lower_bd)
print(upper_bd)
print(np.expm1(lower_bd))
print(np.expm1(upper_bd))

In [None]:
# Plot the MAD cut-offs on the raw and the log1p count distributions.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12,5))
mad_panels = (
    ('total_counts', (np.expm1(lower_bd), np.expm1(upper_bd))),
    ('log1p_total_counts', (lower_bd, upper_bd)),
)
for ax, (column, cutoffs) in zip(axs, mad_panels):
    sns.histplot(sdata.obs[column], ax=ax)
    for cutoff in cutoffs:
        ax.axvline(cutoff, c='r')

# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the MAD window and their median depth.
# NOTE: the min_genes=10 filter applied later is not part of this preview.
cell_counts = sdata.obs['total_counts']
kept_counts = cell_counts[cell_counts.between(np.expm1(lower_bd), np.expm1(upper_bd))]
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Hand-picked count cut-offs, shown for comparison with the MAD window.
lower_bd_manual = 40
upper_bd_manual = 1200

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12,5))
manual_panels = (
    ('total_counts', (lower_bd_manual, upper_bd_manual)),
    ('log1p_total_counts', (np.log1p(lower_bd_manual), np.log1p(upper_bd_manual))),
)
for ax, (column, cutoffs) in zip(axs, manual_panels):
    sns.histplot(sdata.obs[column], ax=ax)
    for cutoff in cutoffs:
        ax.axvline(cutoff, c='r')

# plt.savefig(os.path.join(fig_path, 'reads_filtering_threshold.pdf'))
plt.show()

In [None]:
# Preview how many cells fall inside the manual window and their median depth.
cell_counts = sdata.obs['total_counts']
kept_counts = cell_counts[cell_counts.between(lower_bd_manual, upper_bd_manual)]
ncell_left = kept_counts.shape
median_counts = kept_counts.median()

print(f'With current threshold, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

In [None]:
# Flag genes whose maximum per-cell count exceeds 2. 'highly_variable_sample'
# receives the same mask; it is not a variance-based highly-variable-gene call.
sdata.var['detected_sample'] = sdata.var['max_counts_sample'].gt(2)
sdata.var['highly_variable_sample'] = sdata.var['detected_sample'].copy()
print(sdata.var['detected_sample'].sum())

In [None]:
# Cell filtering: require >= 10 detected genes, then keep cells whose total counts
# lie inside the MAD window. sc.pp.filter_cells accepts one criterion per call,
# so the criteria are applied one after another in this order.
for criterion in (
    {'min_genes': 10},
    {'min_counts': np.expm1(lower_bd)},
    {'max_counts': np.expm1(upper_bd)},
):
    sc.pp.filter_cells(sdata, **criterion)

# Keep a copy of the filtered count matrix in layers['raw'].
sdata.layers['raw'] = sdata.X.copy()
sdata.X.shape

In [None]:
# Save the filtered rep2 STARmap object (a single section, not a combined one),
# stamped with today's date.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')
sdata.write_h5ad(os.path.join(out_path, f'{date}-Brain-STARmap-rep2-3mad-filtered.h5ad'))

### Check filtered data: RIBOmap vs. STARmap (rep2)

In [None]:
# Correlate per-gene total counts (log2) between RIBOmap and STARmap in rep2.
# Genes are paired by position, so both objects must list the same genes in the same order.
if not rdata.var_names.equals(sdata.var_names):
    raise ValueError('RIBOmap and STARmap gene panels differ; align var_names before correlating.')

# np.asarray(...).ravel() flattens the 1 x n_genes matrix that .sum(axis=0) returns
# for sparse X. Genes with zero counts in either dataset are dropped, because
# log2(0) = -inf would make the correlation and the regression fit undefined.
rdata_totals = np.asarray(rdata.X.sum(axis=0)).ravel()
sdata_totals = np.asarray(sdata.X.sum(axis=0)).ravel()
expressed = (rdata_totals > 0) & (sdata_totals > 0)
rdata_vector = np.log2(rdata_totals[expressed])
sdata_vector = np.log2(sdata_totals[expressed])

from scipy import stats
p_corr = stats.pearsonr(rdata_vector, sdata_vector)

corre_df = pd.DataFrame({'RIBOmap': rdata_vector, 'STARmap': sdata_vector})
g = sns.lmplot(x='RIBOmap', y='STARmap', data=corre_df, scatter_kws={'s': 1}, line_kws={'color': 'r'})
g.set_axis_labels('RIBOmap - log2(total counts)', 'STARmap - log2(total counts)')
plt.title(f"Pearson's correlation coefficient: {round(p_corr[0], 3)}")
plt.savefig(os.path.join(fig_path, 'correlation_ribomap_starmap_rep2_3mad.pdf'))
# plt.show()

## Load filtered (cell) input file (rep1)

### RIBOmap

In [None]:
# Load the already-filtered rep1 RIBOmap section from the earlier analysis folder.
rdata_2 = sc.read_h5ad(
    os.path.join(
        'Z:/Data/Analyzed/2021-11-23-Hu-MouseBrain/output',
        '2022-10-23-Hu-TissueRIBOmap-3mad-filtered.h5ad',
    )
)
rdata_2

In [None]:
# Rename rep1 var columns to the *_sample names used by the rep2 objects.
rdata_2.var.rename(
    columns={
        'max_counts': 'max_counts_sample',
        'detected': 'detected_sample',
        'highly_variable': 'highly_variable_sample',
    },
    inplace=True,
)

In [None]:
# Tag every rep1 cell with its replicate and keep its original cell label as a
# column, since cell labels are replaced with a running index after concatenation.
rdata_2.obs['replicate'] = 'rep1'
rdata_2.obs['orig_index'] = rdata_2.obs_names.tolist()

In [None]:
# Cell count and median depth of the (already filtered) rep1 RIBOmap section.
ncell_left = rdata_2.obs.shape
median_counts = rdata_2.obs.total_counts.median()

print(f'In rep1 RIBOmap, there are {ncell_left[0]} cells left and median counts per cell is {median_counts}')

### Check filtered data: RIBOmap rep1 vs. rep2

In [None]:
# Correlate per-gene total counts (log2) between the rep2 and rep1 RIBOmap sections.
# Genes are paired by position, so both objects must list the same genes in the same order.
if not rdata.var_names.equals(rdata_2.var_names):
    raise ValueError('RIBOmap rep1 and rep2 gene panels differ; align var_names before correlating.')

# np.asarray(...).ravel() flattens the 1 x n_genes matrix that .sum(axis=0) returns
# for sparse X. Genes with zero counts in either replicate are dropped, because
# log2(0) = -inf would make the correlation and the regression fit undefined.
rep2_totals = np.asarray(rdata.X.sum(axis=0)).ravel()
rep1_totals = np.asarray(rdata_2.X.sum(axis=0)).ravel()
expressed = (rep2_totals > 0) & (rep1_totals > 0)
rdata_vector = np.log2(rep2_totals[expressed])
rdata_vector_2 = np.log2(rep1_totals[expressed])

from scipy import stats
p_corr = stats.pearsonr(rdata_vector, rdata_vector_2)

corre_df = pd.DataFrame({'RIBOmap_rep2': rdata_vector, 'RIBOmap_rep1': rdata_vector_2})
g = sns.lmplot(x='RIBOmap_rep2', y='RIBOmap_rep1', data=corre_df, scatter_kws={'s': 1}, line_kws={'color': 'r'})
g.set_axis_labels('RIBOmap_rep2 - log2(total counts)', 'RIBOmap_rep1 - log2(total counts)')
plt.title(f"Pearson's correlation coefficient: {round(p_corr[0], 3)}")
plt.savefig(os.path.join(fig_path, 'correlation_ribomap_rep1_rep2_3mad.pdf'))
# plt.show()

## Combine datasets (n=3)

In [None]:
# Stack rep2 STARmap, rep2 RIBOmap and rep1 RIBOmap into one AnnData, then replace
# the (possibly clashing) cell labels with a running integer index stored as strings.
adata = ad.concat([sdata, rdata, rdata_2])
adata.obs.index = list(map(str, range(adata.n_obs)))
adata

In [None]:
# Attach each section's per-gene max counts. Values are aligned by gene name:
# ad.concat keeps only genes shared by all inputs (inner join by default), so a
# positional `.values` copy would fail or silently misassign if the panels differed.
adata.var['max_counts_rep1_RIBOmap'] = rdata_2.var['max_counts_sample'].reindex(adata.var_names).to_numpy()
adata.var['max_counts_rep2_RIBOmap'] = rdata.var['max_counts_sample'].reindex(adata.var_names).to_numpy()
adata.var['max_counts_rep2_STARmap'] = sdata.var['max_counts_sample'].reindex(adata.var_names).to_numpy()

In [None]:
# Categorical '<protocol>-<replicate>' label for grouping cells by source section.
adata.obs['protocol-replicate'] = (
    adata.obs['protocol'].astype(str)
    .str.cat(adata.obs['replicate'].astype(str), sep='-')
    .astype('category')
)

In [None]:
# Write a backup of the combined three-section object.
adata.write_h5ad(
    os.path.join(out_path, '2022-10-23-Brain-combined-3mad-filtered.h5ad')
)

In [None]:
# Compare per-gene max-count distributions across the three sections.
sns.boxplot(
    data=adata.var[['max_counts_rep1_RIBOmap', 'max_counts_rep2_RIBOmap', 'max_counts_rep2_STARmap']]
)
# To zoom in, uncomment: plt.gca().set_ylim([0, 10])
plt.xticks(rotation=45)
plt.show()

In [None]:
# Number of cells contributed by each protocol/replicate.
adata.obs.loc[:, 'protocol-replicate'].value_counts()

## Combine datasets (n=2, two RIBOmap sections)

In [None]:
# Combine only the two RIBOmap sections (rep2, rep1); STARmap is not included here.
# Cell labels are replaced by a running integer index stored as strings.
adata = ad.concat([rdata, rdata_2])
adata.obs.index = [str(s) for s in range(adata.obs.shape[0])]
adata

In [None]:
# Attach each RIBOmap section's per-gene max counts, aligned by gene name rather than
# position: ad.concat keeps only shared genes (inner join by default), so a positional
# `.values` copy would fail or silently misassign if the panels differed.
adata.var['max_counts_rep1_RIBOmap'] = rdata_2.var['max_counts_sample'].reindex(adata.var_names).to_numpy()
adata.var['max_counts_rep2_RIBOmap'] = rdata.var['max_counts_sample'].reindex(adata.var_names).to_numpy()

In [None]:
# Categorical '<protocol>-<replicate>' label for grouping cells by source section.
adata.obs['protocol-replicate'] = (
    adata.obs['protocol'].astype(str)
    .str.cat(adata.obs['replicate'].astype(str), sep='-')
    .astype('category')
)

In [None]:
# Write a backup of the combined two-section RIBOmap object.
adata.write_h5ad(
    os.path.join(out_path, '2022-10-23-Brain-RIBOmap-combined-3mad-filtered.h5ad')
)

In [None]:
# Number of cells contributed by each RIBOmap replicate.
adata.obs.loc[:, 'protocol-replicate'].value_counts()