# VSN Pipelines: pycisTopic QC report

scATAC-seq quality control and cell calling with pycisTopic (https://github.com/aertslab/pycisTopic).

In [None]:
import pycisTopic

# Show the installed pycisTopic version; as the last expression of the cell it is echoed in the report.
pycisTopic.__version__

In [None]:
import warnings

# Hide all warnings in the rendered report. simplefilter("ignore") installs a
# catch-all "ignore" filter, so no additional filterwarnings("ignore") call is needed.
warnings.simplefilter("ignore")

In [None]:
import pybiomart as pbm
import pandas as pd
import pickle
import re
import os
import json

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import multiprocessing as mp  # for kde multithreading calculation

%matplotlib inline

In [None]:
# NOTE(review): WORKFLOW_PARAMETERS (a JSON string) and SAMPLES (a comma-separated
# string) are not defined in this notebook; presumably they are injected by the
# workflow runner — verify.
params = json.loads(WORKFLOW_PARAMETERS)
sample_ids = SAMPLES.split(",")

print("SAMPLES: " + str(sample_ids))
print("pycisTopic parameters: " + json.dumps(params, indent=4))

## QC summary

In [None]:
from pycisTopic.qc import plot_sample_metrics
from scipy.stats import gaussian_kde

### Per-sample metrics

# Plot the QC profiles of each sample separately.
# NOTE(review): `profile_data_dict` is not defined anywhere in this notebook as
# shown; presumably it is injected or loaded elsewhere — confirm, otherwise this
# cell raises NameError.
for sample_id, sample_profiles in profile_data_dict.items():
    plot_sample_metrics(
        {sample_id: sample_profiles},
        profile_list=["barcode_rank_plot", "insert_size_distribution", "profile_tss", "frip"],
        # NOTE: keyword spelling kept exactly as passed originally; presumably it
        # matches plot_sample_metrics' signature — verify before "correcting" it.
        insert_size_distriubtion_xlim=[0, 600],
        ncol=4,
        cmap="tab20",
        plot=True,
    )
plt.show()

### Combined sample metrics

# Plot the QC profiles of all samples together in one set of panels.
combined_metrics_kwargs = dict(
    profile_list=["barcode_rank_plot", "insert_size_distribution", "profile_tss", "frip"],
    # NOTE: keyword spelling kept exactly as passed originally — verify against
    # plot_sample_metrics' signature before renaming.
    insert_size_distriubtion_xlim=[0, 600],
    ncol=4,
    cmap="tab20",
    plot=True,
)
plot_sample_metrics(profile_data_dict, **combined_metrics_kwargs)
plt.show()

## Cell calling

In [None]:
# Module-level KDE worker so it can be handed to a multiprocessing pool (see plot_frag_qc).
def calc_kde(xy):
    """
    Fit a Gaussian KDE on ``xy`` and evaluate it at the same points.

    Parameters
    ---------
    xy: `class::np.array`
            Points in the layout expected by ``scipy.stats.gaussian_kde``
            (one row per dimension, one column per point).
    Return
    ---------
    `class::np.array`
            Estimated density at each point of ``xy``.
    """
    density = gaussian_kde(xy)
    return density.evaluate(xy)


def _reassemble_strided(parts, n):
    """Invert ``[arr[i::k] for i in range(k)]``: put every element back at its original position in a length-``n`` array."""
    out = np.empty(n, dtype=float)
    k = len(parts)
    for i, part in enumerate(parts):
        out[i::k] = part
    return out


def plot_frag_qc(
    x,
    y,
    ax,
    x_thr_min=None,
    x_thr_max=None,
    y_thr_min=None,
    y_thr_max=None,
    ylab=None,
    xlab="Number of (unique) fragments",
    cmap="viridis",
    density_overlay=False,
    s=10,
    marker="+",
    c="#343434",
    xlim=None,
    ylim=None,
    **kwargs
):
    """
    Scatter a per-barcode QC metric against fragment counts and filter barcodes.

    Parameters
    ---------
    x: `class::pd.Series`
            Fragment counts per barcode (drawn on a log-scaled x-axis).
    y: `class::pd.Series`
            QC metric per barcode; must have the same index, in the same order, as ``x``.
    ax: matplotlib axes
            Axes to draw on.
    x_thr_min, x_thr_max, y_thr_min, y_thr_max: float or None
            Exclusive thresholds. Each threshold that is set is drawn as a dashed
            red line and used to filter barcodes.
    density_overlay: bool
            Colour points by a Gaussian KDE of (log(x), y) instead of the fixed colour ``c``.
    Remaining arguments are forwarded to ``ax.scatter`` or the axis setters.
    Return
    ---------
    `class::np.array`
            Barcodes (index labels of ``x``) that pass every threshold that is set.
    Raises
    ---------
    ValueError
            If ``x`` and ``y`` do not share the same index.
    """
    if not x.index.equals(y.index):
        raise ValueError("x and y must have the same barcode index, in the same order.")
    barcodes = x.index.values

    if density_overlay:
        cores = 8

        x_log = np.log(x.to_numpy(dtype=float))
        y_raw = y.to_numpy(dtype=float)

        # Split the [log(x), y] points into `cores` interleaved parts
        # (part i holds points i, i + cores, i + 2 * cores, ...) and estimate
        # each part's density independently, in parallel.
        kde_parts = [np.vstack([x_log[i::cores], y_raw[i::cores]]) for i in range(cores)]

        # Get a multiprocessing context object to spawn processes.
        # NOTE(review): with "spawn", workers must be able to import calc_kde;
        # functions defined interactively in a notebook kernel's __main__ may not
        # be importable by child processes — confirm in the target runtime.
        mp_ctx = mp.get_context("spawn")

        # Calculate KDE in parallel.
        with mp_ctx.Pool(processes=cores) as pool:
            results = pool.map(calc_kde, kde_parts)

        # Undo the interleaved split so z[k] is the density of point k
        # (plain concatenation would group the results by part instead).
        z = _reassemble_strided(results, len(barcodes))

        # Sort ascending by density so the densest points are drawn last (on top).
        idx = z.argsort()
        x, y, z, barcodes = x.iloc[idx], y.iloc[idx], z[idx], barcodes[idx]
    else:
        z = c

    ax.scatter(x, y, c=z, s=s, edgecolors=None, marker=marker, cmap=cmap, **kwargs)
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])

    # Compare on plain arrays so the boolean mask lines up positionally with `barcodes`.
    x_vals = x.to_numpy()
    y_vals = y.to_numpy()

    # Start with keeping all barcodes.
    barcodes_to_keep = np.full(x_vals.shape[0], True)

    # Filter barcodes out if needed based on thresholds:
    if x_thr_min is not None:
        ax.axvline(x=x_thr_min, color="r", linestyle="--")
        barcodes_to_keep &= x_vals > x_thr_min
    if x_thr_max is not None:
        ax.axvline(x=x_thr_max, color="r", linestyle="--")
        barcodes_to_keep &= x_vals < x_thr_max
    if y_thr_min is not None:
        ax.axhline(y=y_thr_min, color="r", linestyle="--")
        barcodes_to_keep &= y_vals > y_thr_min
    if y_thr_max is not None:
        ax.axhline(y=y_thr_max, color="r", linestyle="--")
        barcodes_to_keep &= y_vals < y_thr_max

    ax.set_xscale("log")
    ax.set_xmargin(0.01)
    ax.set_ymargin(0.01)
    ax.set_xlabel(xlab, fontsize=10)
    ax.set_ylabel(ylab, fontsize=10)

    return barcodes[barcodes_to_keep]


def histogram(array, nbins=100):
    """
    Bin a distribution into a histogram and return the counts and bin centers.

    Non-finite values (NaN, +/-inf, e.g. log10 of a zero count) are dropped
    before binning; ``np.histogram`` cannot derive a bin range from them.
    Parameters
    ---------
    array: array-like (`class::np.array`, `class::pd.Series`, list, ...)
            Score distribution; flattened to 1-D.
    nbins: int
            Number of bins to use in the histogram.
    Return
    ---------
    (`class::np.array`, `class::np.array`)
            Histogram counts and the corresponding bin centers.
    """
    # np.asarray accepts lists and Series alike (Series.ravel is deprecated).
    values = np.asarray(array, dtype=float).ravel()
    values = values[np.isfinite(values)]
    hist, bin_edges = np.histogram(values, bins=nbins, range=None)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    return hist, bin_centers


def threshold_otsu(array, nbins=100):
    """
    Compute Otsu's threshold for a 1-D distribution [Otsu, 1979].

    The histogram bin center that maximises the between-class variance of
    the "below" and "above" classes is returned.
    Parameters
    ---------
    array: `class::np.array`
            Values to threshold.
    nbins: int
            Number of histogram bins used to search for the threshold.
    Return
    ---------
    float
            Threshold value (a histogram bin center).
    Reference
    ---------
    Otsu, N., 1979. A threshold selection method from gray-level histograms. IEEE transactions on systems, man, and
    cybernetics, 9(1), pp.62-66.
    """
    counts, centers = histogram(array, nbins)
    counts = counts.astype(float)
    weighted = counts * centers

    # Class sizes for every candidate split, accumulated from the left (lower
    # class) and from the right (upper class).
    w_low = np.cumsum(counts)
    w_high = np.cumsum(counts[::-1])[::-1]

    # Class means for every candidate split.
    mu_low = np.cumsum(weighted) / w_low
    mu_high = (np.cumsum(weighted[::-1]) / w_high[::-1])[::-1]

    # A split after bin i pairs the lower class up to i with the upper class
    # from i + 1, hence the one-element offset between the two sides.
    between = w_low[:-1] * w_high[1:] * (mu_low[:-1] - mu_high[1:]) ** 2
    return centers[:-1][np.argmax(between)]

In [None]:
# extract filter thresholds from Nextflow parameters
# Each dict maps sample id -> float threshold, or None when the configured value
# is missing or not numeric. The eight dicts are independent objects.
(
    filter_frags_lower,
    filter_frags_upper,
    filter_tss_lower,
    filter_tss_upper,
    filter_frip_lower,
    filter_frip_upper,
    filter_dup_rate_lower,
    filter_dup_rate_upper,
) = ({} for _ in range(8))

def float_or_none(x):
    """
    Convert ``x`` to float, or return None if it cannot be converted.

    ValueError covers non-numeric strings (e.g. "none"). TypeError covers
    values ``float()`` cannot take at all, most importantly ``None`` (a JSON
    ``null`` parameter), which previously crashed instead of disabling the filter.
    """
    try:
        return float(x)
    except (TypeError, ValueError):
        return None


def extract_sample_specific_param(filter_dict, sample_id, param_key):
    """
    Resolve one cell-calling threshold for a sample and store it in ``filter_dict``.

    ``params["call_cells"][param_key]`` is either a single value used for every
    sample, or a dict keyed by sample id with an optional "default" entry used
    for samples not listed. The value is converted with ``float_or_none``. If a
    dict has neither the sample nor "default", a warning is printed and None is
    stored.
    Return
    ---------
    dict
            The same ``filter_dict`` object, updated in place.
    """
    setting = params["call_cells"][param_key]

    if type(setting) is not dict:
        filter_dict[sample_id] = float_or_none(setting)
        return filter_dict

    if sample_id in setting:
        filter_dict[sample_id] = float_or_none(setting[sample_id])
        return filter_dict

    try:
        default_value = setting["default"]
    except KeyError:
        print(
            f"WARNING: Missing 'default' key in the sample parameters list. Filter for '{param_key}' will be missing for sample '{sample_id}'."
        )
        filter_dict[sample_id] = None
    else:
        filter_dict[sample_id] = float_or_none(default_value)
    return filter_dict


# Fill every filter dict for every sample. extract_sample_specific_param updates
# each dict in place (and returns the same object), so no rebinding is needed.
for s in sample_ids:
    for filter_dict, param_key in (
        (filter_frags_lower, "filter_frags_lower"),
        (filter_frags_upper, "filter_frags_upper"),
        (filter_tss_lower, "filter_tss_lower"),
        (filter_tss_upper, "filter_tss_upper"),
        (filter_frip_lower, "filter_frip_lower"),
        (filter_frip_upper, "filter_frip_upper"),
        (filter_dup_rate_lower, "filter_dup_rate_lower"),
        (filter_dup_rate_upper, "filter_dup_rate_upper"),
    ):
        extract_sample_specific_param(filter_dict, s, param_key)

In [None]:
# Show the per-sample cell filters resolved from the workflow parameters.
print("Filter parameters:")
for filter_name, per_sample in (
    ("filter_frags_lower", filter_frags_lower),
    ("filter_frags_upper", filter_frags_upper),
    ("filter_tss_lower", filter_tss_lower),
    ("filter_tss_upper", filter_tss_upper),
    ("filter_frip_lower", filter_frip_lower),
    ("filter_frip_upper", filter_frip_upper),
    ("filter_dup_rate_lower", filter_dup_rate_lower),
    ("filter_dup_rate_upper", filter_dup_rate_upper),
):
    print(f"{filter_name}: {json.dumps(per_sample, indent=4)}")

Cell calling using Otsu thresholding on fragment counts and TSS enrichment

In [None]:
def plot_qc(
    sample,
    metadata_bc_df,
    include_kde=True,
    detailed_title=params["call_cells"]["use_detailed_title_on_scatterplot"],
    s=4,
):
    """
    Call cells for one sample with Otsu thresholds and plot the QC scatterplots.

    Parameters
    ---------
    sample: str
            Sample name, used in the figure title and the output file name.
    metadata_bc_df: `class::pd.DataFrame`
            Per-barcode QC metadata, indexed by barcode. Must contain the columns
            'Log_total_nr_frag', 'Unique_nr_frag', 'TSS_enrichment', 'FRIP'
            and 'Dupl_rate'.
    include_kde: bool
            Colour points by estimated point density.
    detailed_title: bool
            Put summary statistics and thresholds in the figure title. The default
            is read from ``params`` once, when this function is defined.
    s: int
            Marker size.
    Return
    ---------
    (list, matplotlib figure)
            Barcodes passing all filters (in ``metadata_bc_df`` index order) and
            the figure, which is also saved to ``plots_qc/qc_otsu_{sample}.png``.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4), dpi=150)

    # Fragment threshold: Otsu on the log fragment counts, back-transformed with 10**.
    # NOTE(review): the threshold is derived from *total* fragments but applied to
    # the *unique* fragment axis below — confirm this mix is intended.
    x_arr = metadata_bc_df["Log_total_nr_frag"]
    x_threshold = 10 ** threshold_otsu(x_arr, 5000)

    # TSS threshold: Otsu on the TSS enrichment of barcodes that already pass the
    # fragment threshold, so low-fragment background barcodes do not dominate.
    # NOTE(review): assumed intent of the `y_arr_cutoff` step (re-thresholding the
    # scalar cutoff itself is degenerate) — confirm.
    y_arr = metadata_bc_df["TSS_enrichment"]
    y_arr_cutoff = y_arr[metadata_bc_df["Unique_nr_frag"] > x_threshold]
    if y_arr_cutoff.empty:
        # No barcode passes the fragment filter: fall back to all barcodes.
        y_arr_cutoff = y_arr
    y_threshold = threshold_otsu(y_arr_cutoff, 5000)

    # plot everything
    p1_cells = plot_frag_qc(
        x=metadata_bc_df["Unique_nr_frag"],
        y=metadata_bc_df["TSS_enrichment"],
        ylab="TSS Enrichment",
        s=s,
        x_thr_min=x_threshold,
        y_thr_min=y_threshold,
        density_overlay=include_kde,
        ax=ax1,
    )
    p2_cells = plot_frag_qc(
        x=metadata_bc_df["Unique_nr_frag"],
        y=metadata_bc_df["FRIP"],
        x_thr_min=x_threshold,
        ylab="FRIP",
        s=s,
        ylim=[0, 1],
        density_overlay=include_kde,
        ax=ax2,
    )
    p3_cells = plot_frag_qc(
        x=metadata_bc_df["Unique_nr_frag"],
        y=metadata_bc_df["Dupl_rate"],
        x_thr_min=x_threshold,
        ylab="Duplicate rate per cell",
        s=s,
        ylim=[0, 1],
        density_overlay=include_kde,
        ax=ax3,
    )

    # Keep barcodes passing every panel's filters, in metadata order, so the
    # saved list is deterministic (set iteration order is not).
    passing = set(p1_cells) & set(p2_cells) & set(p3_cells)
    bc_passing_filters = [bc for bc in metadata_bc_df.index if bc in passing]

    if detailed_title:
        med_nf = metadata_bc_df.loc[bc_passing_filters, "Unique_nr_frag"].median()
        med_tss = metadata_bc_df.loc[bc_passing_filters, "TSS_enrichment"].median()
        med_frip = metadata_bc_df.loc[bc_passing_filters, "FRIP"].median()
        title = f"{sample}: Kept {len(bc_passing_filters)} cells using Otsu filtering. Median(fragments): {med_nf:.0f}. Median(TSS Enrichment): {med_tss:.2f}. Median FRIP: {med_frip:.2f}\nUsed a minimum of {x_threshold:.0f} fragments and TSS enrichment of {y_threshold:.2f}"
    else:
        title = sample

    fig.suptitle(title, x=0.5, y=0.95, fontsize=10)
    plt.tight_layout()
    plt.savefig(f"plots_qc/qc_otsu_{sample}.png", dpi=300, facecolor="white")
    plt.show()
    return bc_passing_filters, fig

In [None]:
# Create the output directories; exist_ok avoids a check-then-create race.
os.makedirs("selected_barcodes", exist_ok=True)
os.makedirs("plots_qc", exist_ok=True)

for sample in sample_ids:
    # Per-barcode QC metadata, presumably written by an earlier step of this workflow.
    # NOTE: pickle.load executes arbitrary code; only load files produced by this pipeline.
    with open(f"metadata/{sample}.metadata.pkl", "rb") as fh:
        metadata_bc_df = pickle.load(fh)

    bc_passing_filters, fig = plot_qc(sample, metadata_bc_df)
    # plot_qc has already saved and shown the figure; free it so figures do not
    # accumulate across samples.
    plt.close(fig)

    with open(f"selected_barcodes/{sample}_bc_passing_filters_otsu.pkl", "wb") as fh:
        pickle.dump(bc_passing_filters, fh)