In [1]:
import os
import hail as hl
import pyspark
import bokeh
import logging
import random
import pandas as pd
import numpy as np
from scipy import stats
import pickle 
from matplotlib import pyplot as plt
from typing import Any, Counter, List, Optional, Tuple, Union,Dict,Set, Iterable
from hail.plot import show, output_notebook
from bokeh.palettes import d3  # pylint: disable=no-name-in-module
from bokeh.models import Plot, Row, Span, NumeralTickFormatter, LabelSet
from gnomad.utils.plotting import *
from typing import Set, Tuple

tmp_dir = "hdfs://spark-master:9820/"
temp_dir = "file:///home/ubuntu/data/tmp"
plot_dir = "/home/ubuntu/data/tmp"

sc = pyspark.SparkContext()
hadoop_config = sc._jsc.hadoopConfiguration()
hadoop_config.set("fs.s3a.access.key", "8YY584J59H7Q6AVKHSU8")
hadoop_config.set("fs.s3a.secret.key", "P8vePa7JUvxKXX2me9ti1cGujgYWMoimAwx4mMlM")
hadoop_config.set("fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem")
hadoop_config.set("fs.AbstractFileSystem.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS")
hl.init(sc=sc, tmp_dir=tmp_dir, default_reference='GRCh38')
output_notebook()
logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

pip-installed Hail requires additional configuration options in Spark referring
  to the path to the Hail Python module directory HAIL_DIR,
  e.g. /path/to/python/site-packages/hail:
    spark.jars=HAIL_DIR/hail-all-spark.jar
    spark.driver.extraClassPath=HAIL_DIR/hail-all-spark.jar
    spark.executor.extraClassPath=./hail-all-spark.jarRunning on Apache Spark version 2.4.5
SparkUI available at http://spark-master:4040
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.41-b8144dba46e6
LOGGING: writing to /home/ubuntu/data/tmp/scripts/sanger_gnomad_hail_qc/notebooks/hail-20201111-1403-0.2.41-b8144dba46e6.log


# Variant and Sample QC with hail - to try and recreate with opencga

Read chr19 from the exome WES cohort with 93674 samples. 

In [114]:
mt=hl.read_matrix_table(f"{temp_dir}/ddd-elgh-ukbb/chr19.mt")

In [115]:
mt.describe()

----------------------------------------
Global fields:
    None
----------------------------------------
Column fields:
    's': str
----------------------------------------
Row fields:
    'locus': locus<GRCh38>
    'alleles': array<str>
    'rsid': str
    'qual': float64
    'filters': set<str>
    'info': struct {
        AC: array<int32>, 
        AF: array<float64>, 
        AN: int32, 
        AS_BaseQRankSum: array<float64>, 
        AS_FS: array<float64>, 
        AS_InbreedingCoeff: array<float64>, 
        AS_MQ: array<float64>, 
        AS_MQRankSum: array<float64>, 
        AS_QD: array<float64>, 
        AS_ReadPosRankSum: array<float64>, 
        AS_SOR: array<float64>, 
        BaseQRankSum: float64, 
        DB: bool, 
        DP: int32, 
        DS: bool, 
        END: int32, 
        ExcessHet: float64, 
        FS: float64, 
        InbreedingCoeff: float64, 
        MLEAC: array<int32>, 
        MLEAF: array<float64>, 
        MQ: float64, 
        MQRankSum: floa

In [116]:
mt.count()

(1663562, 93674)

# Variant QC with hail


In [117]:
mt_vqc=hl.variant_qc(mt)

In [118]:
mt_vqc.describe()

----------------------------------------
Global fields:
    None
----------------------------------------
Column fields:
    's': str
----------------------------------------
Row fields:
    'locus': locus<GRCh38>
    'alleles': array<str>
    'rsid': str
    'qual': float64
    'filters': set<str>
    'info': struct {
        AC: array<int32>, 
        AF: array<float64>, 
        AN: int32, 
        AS_BaseQRankSum: array<float64>, 
        AS_FS: array<float64>, 
        AS_InbreedingCoeff: array<float64>, 
        AS_MQ: array<float64>, 
        AS_MQRankSum: array<float64>, 
        AS_QD: array<float64>, 
        AS_ReadPosRankSum: array<float64>, 
        AS_SOR: array<float64>, 
        BaseQRankSum: float64, 
        DB: bool, 
        DP: int32, 
        DS: bool, 
        END: int32, 
        ExcessHet: float64, 
        FS: float64, 
        InbreedingCoeff: float64, 
        MLEAC: array<int32>, 
        MLEAF: array<float64>, 
        MQ: float64, 
        MQRankSum: floa

## Hard Filters

In [121]:
mt=mt.annotate_rows(fail_hard_filters= (mt.info.QD < 2) | (mt.info.FS > 60) | (mt.info.MQ < 30))

### How many variants fail these filters and what percent of total variants is it

In [124]:
fail_hard_filters_count=mt.aggregate_rows(hl.agg.count_where(mt.fail_hard_filters==True))

In [126]:
hl.eval(fail_hard_filters_count)

81127

In [127]:
81127/1663562*100

4.876704324816268

### Filter out these variants

In [130]:
mt_filtered=mt.filter_rows(mt.fail_hard_filters==True,keep=False)

In [129]:
mt_filtered.count()

(1582219, 93674)

## Re-run variant_qc for new subset of variants

In [131]:
mt_filtered_vqc=hl.variant_qc(mt_filtered)

## Sample QC with hail

In [5]:
mt_sqc=hl.sample_qc(mt)

### Automatic calculations of sample qc metrics:

In [6]:
mt_sqc.describe()

----------------------------------------
Global fields:
    None
----------------------------------------
Column fields:
    's': str
    'sample_qc': struct {
        dp_stats: struct {
            mean: float64, 
            stdev: float64, 
            min: float64, 
            max: float64
        }, 
        gq_stats: struct {
            mean: float64, 
            stdev: float64, 
            min: float64, 
            max: float64
        }, 
        call_rate: float64, 
        n_called: int64, 
        n_not_called: int64, 
        n_filtered: int64, 
        n_hom_ref: int64, 
        n_het: int64, 
        n_hom_var: int64, 
        n_non_ref: int64, 
        n_singleton: int64, 
        n_snp: int64, 
        n_insertion: int64, 
        n_deletion: int64, 
        n_transition: int64, 
        n_transversion: int64, 
        n_star: int64, 
        r_ti_tv: float64, 
        r_het_hom_var: float64, 
        r_insertion_deletion: float64
    }
----------------------------

## Annotate samples with their population assignments. 
### These assignments are the result of a previous investigation within the ongoing gnomad_qc project. 

In [7]:
populations=hl.import_table(f"{temp_dir}/ddd-elgh-ukbb/pop_assignments.tsv").key_by('s')

2020-11-11 14:03:46 Hail: INFO: Reading table with no type imputation
  Loading column 's' as type 'str' (type not specified)
  Loading column 'known_pop' as type 'str' (type not specified)
  Loading column 'pca_scores' as type 'str' (type not specified)
  Loading column 'pop' as type 'str' (type not specified)
  Loading column 'prob_African' as type 'str' (type not specified)
  Loading column 'prob_Any other Asian background' as type 'str' (type not specified)
  Loading column 'prob_Any other Black background' as type 'str' (type not specified)
  Loading column 'prob_Any other mixed background' as type 'str' (type not specified)
  Loading column 'prob_Any other white background' as type 'str' (type not specified)
  Loading column 'prob_Asian or Asian British' as type 'str' (type not specified)
  Loading column 'prob_Bangladeshi' as type 'str' (type not specified)
  Loading column 'prob_Black or Black British' as type 'str' (type not specified)
  Loading column 'prob_British' as type '

In [8]:
df=pd.read_csv(f"{temp_dir}/ddd-elgh-ukbb/pop_assignments.tsv", sep="\t")

  interactivity=interactivity, compiler=compiler, result=result)


In [9]:
df.columns

Index(['s', 'known_pop', 'pca_scores', 'pop', 'prob_African',
       'prob_Any other Asian background', 'prob_Any other Black background',
       'prob_Any other mixed background', 'prob_Any other white background',
       'prob_Asian or Asian British', 'prob_Bangladeshi',
       'prob_Black or Black British', 'prob_British', 'prob_Caribbean',
       'prob_Chinese', 'prob_Do not know', 'prob_Indian', 'prob_Irish',
       'prob_Mixed', 'prob_Other ethnic group', 'prob_Pakistani',
       'prob_Prefer not to answer', 'prob_White', 'prob_White and Asian',
       'prob_White and Black African', 'prob_White and Black Caribbean'],
      dtype='object')

In [10]:
df["pop"].value_counts()

oth                           52358
British                       40806
Caribbean                       243
Indian                          176
Chinese                          75
White and Black Caribbean         7
African                           7
Any other white background        1
Other ethnic group                1
Name: pop, dtype: int64

In [11]:
df["pop"].describe()

count     93674
unique        9
top         oth
freq      52358
Name: pop, dtype: object

 ## Annotate samples with their population information

In [12]:
mt_sqc=mt_sqc.annotate_cols(pop=populations[mt_sqc.s].pop)

In [13]:
mt_sqc.cols().describe()

----------------------------------------
Global fields:
    None
----------------------------------------
Row fields:
    's': str 
    'sample_qc': struct {
        dp_stats: struct {
            mean: float64, 
            stdev: float64, 
            min: float64, 
            max: float64
        }, 
        gq_stats: struct {
            mean: float64, 
            stdev: float64, 
            min: float64, 
            max: float64
        }, 
        call_rate: float64, 
        n_called: int64, 
        n_not_called: int64, 
        n_filtered: int64, 
        n_hom_ref: int64, 
        n_het: int64, 
        n_hom_var: int64, 
        n_non_ref: int64, 
        n_singleton: int64, 
        n_snp: int64, 
        n_insertion: int64, 
        n_deletion: int64, 
        n_transition: int64, 
        n_transversion: int64, 
        n_star: int64, 
        r_ti_tv: float64, 
        r_het_hom_var: float64, 
        r_insertion_deletion: float64
    } 
    'pop': str 
-------------

2020-11-11 14:04:03 Hail: WARN: cols(): Resulting column table is sorted by 'col_key'.
    To preserve matrix table column order, first unkey columns with 'key_cols_by()'


## Create a hail table with only the columns from the matrixtable

In [24]:
ht=mt_sqc.cols()

# Calculate median scores per population and mark as failed the samples that are outside 4 MAD from median for a list of metrics

In [35]:
def get_median_and_mad_expr(
    metric_expr: hl.expr.ArrayNumericExpression, k: float = 1.4826
) -> hl.expr.StructExpression:
    """
    Computes the median and median absolute deviation (MAD) for the given expression.
    Note that the default value of k assumes normally distributed data.

    :param metric_expr: Expression to compute median and MAD for
    :param k: The scaling factor for MAD calculation. Default assumes normally distributed data.
    :return: Struct with median and MAD
    """
    return hl.bind(
        lambda x: hl.struct(median=x[1], mad=k *
                            hl.median(hl.abs(x[0] - x[1]))),
        hl.bind(lambda x: hl.tuple([x, hl.median(x)]),
                hl.agg.collect(metric_expr)),
    )

def compute_stratified_metrics_filter(
    ht: hl.Table,
    qc_metrics: Dict[str, hl.expr.NumericExpression],
    strata: Optional[Dict[str, hl.expr.Expression]] = None,
    lower_threshold: float = 4.0,
    upper_threshold: float = 4.0,
    metric_threshold: Optional[Dict[str, Tuple[float, float]]] = None,
    filter_name: str = "qc_metrics_filters",
) -> hl.Table:
    """
    Compute median, MAD, and upper and lower thresholds for each metric used in outlier filtering

    :param ht: HT containing relevant sample QC metric annotations
    :param qc_metrics: list of metrics (name and expr) for which to compute the critical values for filtering outliers
    :param strata: List of annotations used for stratification. These metrics should be discrete types!
    :param lower_threshold: Lower MAD threshold
    :param upper_threshold: Upper MAD threshold
    :param metric_threshold: Can be used to specify different (lower, upper) thresholds for one or more metrics
    :param filter_name: Name of resulting filters annotation
    :return: Table grouped by strata, with upper and lower threshold values computed for each sample QC metric
    """

    _metric_threshold = {
        metric: (lower_threshold, upper_threshold) for metric in qc_metrics
    }
    if metric_threshold is not None:
        _metric_threshold.update(metric_threshold)

    def make_filters_expr(
        ht: hl.Table, qc_metrics: Iterable[str]
    ) -> hl.expr.SetExpression:
        return hl.set(
            hl.filter(
                lambda x: hl.is_defined(x),
                [hl.or_missing(ht[f"fail_{metric}"], metric) for metric in qc_metrics],
            )
        )

    if strata is None:
        strata = {}

    ht = ht.select(**qc_metrics, **strata).key_by("s").persist()

    agg_expr = hl.struct(
        **{
            metric: hl.bind(
                lambda x: x.annotate(
                    lower=x.median - _metric_threshold[metric][0] * x.mad,
                    upper=x.median + _metric_threshold[metric][1] * x.mad,
                ),
                get_median_and_mad_expr(ht[metric]),
            )
            for metric in qc_metrics
        }
    )

    if strata:
        ht = ht.annotate_globals(
            qc_metrics_stats=ht.aggregate(
                hl.agg.group_by(hl.tuple([ht[x] for x in strata]), agg_expr),
                _localize=False,
            )
        )
        metrics_stats_expr = ht.qc_metrics_stats[hl.tuple([ht[x] for x in strata])]
    else:
        ht = ht.annotate_globals(
            qc_metrics_stats=ht.aggregate(agg_expr, _localize=False)
        )
        metrics_stats_expr = ht.qc_metrics_stats

    fail_exprs = {
        f"fail_{metric}": (ht[metric] <= metrics_stats_expr[metric].lower)
        | (ht[metric] >= metrics_stats_expr[metric].upper)
        for metric in qc_metrics
    }
    ht = ht.transmute(**fail_exprs)
    stratified_filters = make_filters_expr(ht, qc_metrics)
    return ht.annotate(**{filter_name: stratified_filters})


In [36]:
strata = {}
qc_metrics = {
        'sample_qc.n_snp': ht.sample_qc.n_snp,
        'sample_qc.r_ti_tv': ht.sample_qc.r_ti_tv,
        'sample_qc.r_insertion_deletion': ht.sample_qc.r_insertion_deletion,
        'sampleqc.n_insertion': ht.sample_qc.n_insertion,
        'sampleqc.n_deletion': ht.sample_qc.n_deletion,
        'sample_qc.r_het_hom_var': ht.sample_qc.r_het_hom_var
    }
strata['pop'] = ht.pop

In [37]:
pop_filter_ht = compute_stratified_metrics_filter(
        ht, qc_metrics, strata)

2020-11-11 14:34:14 Hail: INFO: Coerced sorted dataset


In [38]:
pop_filter_ht.write(f"{tmp_dir}/chr19_pop.ht")

2020-11-11 14:35:26 Hail: INFO: wrote table with 93674 rows in 16 partitions to hdfs://spark-master:9820//chr19_pop.ht


In [39]:
pop_filter_ht=hl.read_table(f"{temp_dir}/ddd-elgh-ukbb/chr19_pop.ht")

In [48]:
pop_filter_ht.globals.describe()

--------------------------------------------------------
Type:
        struct {
        qc_metrics_stats: dict<tuple (
            str
        ), struct {
            `sample_qc.n_snp`: struct {
                median: int64, 
                mad: float64, 
                lower: float64, 
                upper: float64
            }, 
            `sample_qc.r_ti_tv`: struct {
                median: float64, 
                mad: float64, 
                lower: float64, 
                upper: float64
            }, 
            `sample_qc.r_insertion_deletion`: struct {
                median: float64, 
                mad: float64, 
                lower: float64, 
                upper: float64
            }, 
            `sampleqc.n_insertion`: struct {
                median: int64, 
                mad: float64, 
                lower: float64, 
                upper: float64
            }, 
            `sampleqc.n_deletion`: struct {
                median: int64, 
           

In [74]:
hl.eval(pop_filter_ht.globals['qc_metrics_stats'].keys())

[('African',),
 ('Any other white background',),
 ('British',),
 ('Caribbean',),
 ('Chinese',),
 ('Indian',),
 ('Other ethnic group',),
 ('White and Black Caribbean',),
 ('oth',)]

In [75]:
hl.eval(pop_filter_ht.globals['qc_metrics_stats'][('African',)])

Struct(sample_qc.n_snp=Struct(median=10530, mad=397.3368, lower=8940.6528, upper=12119.3472), sample_qc.r_ti_tv=Struct(median=2.3009739239710965, mad=0.03810543268110271, lower=2.148552193246686, upper=2.453395654695507), sample_qc.r_insertion_deletion=Struct(median=0.8488210818307905, mad=0.021231487716861272, lower=0.7638951309633455, upper=0.9337470326982356), sampleqc.n_insertion=Struct(median=582, mad=50.4084, lower=380.3664, upper=783.6336), sampleqc.n_deletion=Struct(median=713, mad=103.782, lower=297.872, upper=1128.128), sample_qc.r_het_hom_var=Struct(median=1.7222997488138432, mad=0.01941424357466751, lower=1.644642774515173, upper=1.7999567231125133))

In [76]:
hl.eval(pop_filter_ht.globals['qc_metrics_stats'][('African',)]['sample_qc.n_snp'])

Struct(median=10530, mad=397.3368, lower=8940.6528, upper=12119.3472)

In [78]:
pop_filter_ht.show()

s,fail_sample_qc.n_snp,fail_sample_qc.r_ti_tv,fail_sample_qc.r_insertion_deletion,fail_sampleqc.n_insertion,fail_sampleqc.n_deletion,fail_sample_qc.r_het_hom_var,qc_metrics_filters
str,bool,bool,bool,bool,bool,bool,set<str>
"""EGAN00001006259""",False,False,False,False,False,True,"{""sample_qc.r_het_hom_var""}"
"""EGAN00001006260""",False,False,False,False,False,False,{}
"""EGAN00001006261""",False,False,False,False,False,True,"{""sample_qc.r_het_hom_var""}"
"""EGAN00001006263""",False,False,False,False,False,True,"{""sample_qc.r_het_hom_var""}"
"""EGAN00001006264""",False,False,False,False,False,True,"{""sample_qc.r_het_hom_var""}"
"""EGAN00001006265""",False,False,False,False,False,False,{}
"""EGAN00001006266""",False,False,False,False,False,False,{}
"""EGAN00001006267""",False,False,False,False,False,False,{}
"""EGAN00001006268""",False,False,False,False,False,False,{}
"""EGAN00001006269""",False,False,False,False,False,False,{}


## How many samples are outside the 4 MAD from median and failed this QC?


In [84]:
 failed_samples_count=pop_filter_ht.aggregate(hl.agg.count_where(hl.len(pop_filter_ht.qc_metrics_filters) != 0))

In [85]:
hl.eval(failed_samples_count)

1349

 ## Find and remove these failed samples

In [97]:
failed_samples=pop_filter_ht.filter((hl.len(pop_filter_ht.qc_metrics_filters) != 0), keep=True)
passed_samples=pop_filter_ht.filter((hl.len(pop_filter_ht.qc_metrics_filters) == 0), keep=True)

In [98]:
failed_samples.count()

1349

In [99]:
passed_samples.count()

92325

In [102]:
passed_samples=passed_samples.annotate(PASS="PASS")

## Filter original matrixtable

In [105]:
mt=mt.annotate_cols(passed_population_filters=passed_samples[mt.s].PASS)

In [107]:
mt=mt.filter_cols(mt.passed_population_filters =="PASS")

In [108]:
mt.count()

(1663562, 92325)

## Calculate new sample QC for new set of samples

In [109]:
mt_sqc_new=hl.sample_qc(mt)

In [111]:
mt_sqc_new.describe()

----------------------------------------
Global fields:
    None
----------------------------------------
Column fields:
    's': str
    'passed_population_filters': str
    'sample_qc': struct {
        dp_stats: struct {
            mean: float64, 
            stdev: float64, 
            min: float64, 
            max: float64
        }, 
        gq_stats: struct {
            mean: float64, 
            stdev: float64, 
            min: float64, 
            max: float64
        }, 
        call_rate: float64, 
        n_called: int64, 
        n_not_called: int64, 
        n_filtered: int64, 
        n_hom_ref: int64, 
        n_het: int64, 
        n_hom_var: int64, 
        n_non_ref: int64, 
        n_singleton: int64, 
        n_snp: int64, 
        n_insertion: int64, 
        n_deletion: int64, 
        n_transition: int64, 
        n_transversion: int64, 
        n_star: int64, 
        r_ti_tv: float64, 
        r_het_hom_var: float64, 
        r_insertion_deletion: float