# Template for GWAS pipeline with logistic regression and VariantSpark

# Introduction
- __What does this pipeline do?__
    - Load data from PLINK or Hail format
    - Quality Control
    - Principal Component Analysis
    - GWAS: logistic regression (QQ, Manhattan plot)
    - VariantSpark: random forest (Manhattan plot)
    - Scatter plot comparing SNP scores from VariantSpark and logistic regression
    - (if file is provided: overlap with GWAS catalog)
    - Polygenic Risk Score calculation
    - Prediction accuracy evaluation with random forest on top n most significant SNPs from GWAS and VS
- __What is not in here?__
    - Annotation with the Hail annotation database, since it requires a Hail version that is incompatible with VS and is only available on GCP
    - Burden testing (reliant on annotation)
- __How do I execute it?__
    - Set paths and necessary parameters in the section <a href='#param'>parameters</a>
    - Adjust parameters in each cell of the <a href='#pipeline'>pipeline</a> if needed (e.g. for QC)
    - To see implementation go to <a href='#functions'>functions</a>

__Note__: Only execute the <a href='#loading'>loading</a> of Hail, VariantSpark and other packages once

    
The GWAS pipeline is based on 
[A guide to genome‐wide association analysis and post‐analytic interrogation](https://onlinelibrary.wiley.com/doi/full/10.1002/sim.6605)
and 
[A tutorial on conducting genome‐wide association studies: Quality control and statistical analysis](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6001694/)

--------------

<a id='param'></a>

# Parameters
Set paths and choose parameters for the pipeline here   
Other parameters can be set in the <a href='#functions'>functions</a> section

In [None]:
## Loading
# set only_plink = True if the data is only available in PLINK format. If False, plink_paths is ignored
only_plink = False 
# adjust paths to the data in PLINK format (one entry each for the .bed, .bim and .fam file)
# NOTE(review): 'bed' and 'fam' point at chr1 files while 'bim' points at a chrall file - confirm they belong together
plink_paths = {
    'bed':'s3://tb-als/ukbio/ukb_efe_chr1_v1.bed',
    'bim': 's3://tb-als/ukbio/ukb_fe_exm_chrall_v1.bim',
    'fam': 's3://tb-als/ukbio/ukb27483_efe_chr1_v1_s49953.fam'
}
# reference genome name used when initialising Hail and when importing the PLINK files
ref_genome = 'GRCh38'
# path to data in Hail format (path to load from or save to)
# several functions build output file names as hl_path + "<file name>", so keep the trailing slash
hl_path = "s3://tb-als/ukbio/results/"
hl_name = "asthma.mt"


## PCA
# plot the first few PCs against each other and a scree plot
PCA_visual = True

## GWAS logistic regression and VS random forest
# perform gwas catalog overlap or not
gwas_catalog_overlap = False
# path to the tab-separated GWAS catalog file (only needed when gwas_catalog_overlap is True)
gwas_catalog_filepath = ""
# the files produced here will also be saved into the hl_path folder

--------------

<a id='loading'></a>

# Loading Hail, VariantSpark and other packages
Only execute this section once in a running notebook. Otherwise Hail fails

In [None]:
import time
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import subprocess

In [None]:
from pyspark import SparkContext
# Spark context that is handed to VariantSpark/Hail during initialisation below.
# This section is meant to be executed only once per running notebook.
sc = SparkContext()

In [None]:
import hail as hl
import varspark.hail as vshl
# initialise Hail through VariantSpark on the existing Spark context;
# the default reference genome is the ref_genome value from the parameters section
vshl.init(sc=sc, default_reference=ref_genome)

In [None]:
from hail.plot import show
from pprint import pprint
# configure Hail's plotting for notebook output before any show(...) call
hl.plot.output_notebook()

In [None]:
import varspark.hail.plot as vshlplt

In [None]:
import sys
!{sys.executable} -m pip install sklearn

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.feature_selection import SelectFromModel
from sklearn.feature_selection import VarianceThreshold
from sklearn.feature_selection import RFECV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score

In [None]:
import sys
!{sys.executable} -m pip install myvariant

In [None]:
import myvariant

In [None]:
import sys
!{sys.executable} -m pip install treeinterpreter

In [None]:
from treeinterpreter import treeinterpreter as ti

----------------

<a id='functions'></a>

# Functions


## Loading and converting data

In [None]:
# convert PLINK to hail matrix table
# this needs lots of CPUs

def plink_to_hl(plink_paths, hl_path):
    """Import a PLINK dataset and persist it as a Hail matrix table.

    Parameters
    ----------
    plink_paths : dict with the keys 'bed', 'bim' and 'fam' holding the paths of the PLINK files
    hl_path : destination path of the Hail matrix table

    Returns
    -------
    The imported matrix table (the in-memory import, not re-read from hl_path).

    The reference genome is taken from the notebook-level ref_genome parameter.
    """
    # the original message referenced an undefined name (plink_path); report the .bed file instead
    print("Reading PLINK data from " + plink_paths['bed'] + ". This might take some time.")
    start = time.time()
    mt = hl.import_plink(bed=plink_paths['bed'], bim=plink_paths['bim'], fam=plink_paths['fam'],
                         reference_genome=ref_genome)
    end = time.time()
    print("Loading as PLINK took "+ str(round(end-start)) + " seconds")
    # save as hail mt to hl_path
    print("Writing hail matrix table to " + hl_path + ".")
    start = time.time()
    mt.write(hl_path)
    end = time.time()
    print("Writing Hail matrix table took "+ str(round(end-start)) + " seconds")
    return mt

## Miscellaneous

In [None]:
# To choose how many principal components should be used as covariates in logistic regression
# plots a scree plot (eigenvalues) and PCs against each other to see how well they separate the data

def choose_pcs(mt, k=5, annotated=False):
    """Visualise principal components to help choose how many to use as covariates.

    Parameters
    ----------
    mt : matrix table with a GT entry field and an is_case column field
    k : number of principal components to compute (only used when annotated is False)
    annotated : set to True if mt already carries the PCA result in mt.scores.scores;
        in that case no PCA is run and no scree plot is drawn

    Returns
    -------
    The matrix table, annotated with the column field 'scores' if the PCA was computed here.
    """
    from itertools import combinations

    # the original always plotted PC1-PC4 against each other
    n_plot = 4
    if not annotated:
        eigenvalues, scores, _ = hl.hwe_normalized_pca(mt.GT, k=k)
        mt = mt.annotate_cols(scores = scores[mt.s])
        # never index past the number of components that were actually computed
        n_plot = min(n_plot, len(eigenvalues))

        # Scree plot: bar chart of the eigenvalues, ordered from largest to smallest.
        # Components to the left of the "elbow", where the eigenvalues level off,
        # are the ones usually retained.
        pc_names = ["PC" + str(i + 1) for i in range(len(eigenvalues))]
        plt.bar(pc_names, eigenvalues)
        plt.show()

    # pairwise scatter plots of the leading PCs, coloured by case status
    # (same order as before: PC1-PC2, PC1-PC3, PC1-PC4, PC2-PC3, PC2-PC4, PC3-PC4)
    for first, second in combinations(range(n_plot), 2):
        p = hl.plot.scatter(mt.scores.scores[first],
                            mt.scores.scores[second],
                            label=mt.is_case,
                            title='PCA',
                            xlabel='PC' + str(first + 1),
                            ylabel='PC' + str(second + 1))
        show(p)

    return mt

In [None]:
# Replace empty genotype calls with 0/0
def replace_empty_gt(mt):
    """Return mt with every missing genotype call replaced by an unphased 0/0 call."""
    hom_ref = hl.call(0, 0, phased=False)
    # hl.or_else(a, b) evaluates to b wherever a is missing
    filled_gt = hl.or_else(mt.GT, hom_ref)
    return mt.annotate_entries(GT=filled_gt)

## Variant QC

#### Call rate filter (variants)

In [None]:
# filter out VARIANTS with low call rate 
# (thresh 1 needed for it to work with vs -> to get comparable results also do before logistic regression)
def call_rate_filter_variants(mt, thresh=0.95):
    """Keep only variants whose call rate is at least thresh and report how many remain."""
    print(">>>>> call_rate_filter_variants")
    qc_mt = hl.variant_qc(mt)
    n_before = qc_mt.count_rows()
    well_called = qc_mt.variant_qc.call_rate >= thresh
    kept = qc_mt.filter_rows(well_called, keep=True)
    n_after = kept.count_rows()
    print('After call rate filter (VARIANTS) with threshold {}, {}/{} variants remain.'.format(
        thresh, n_after, n_before))
    return kept

#### Minor allele frequency filter


In [None]:
# Variants with a low minor allele frequency are filtered out

def maf_filter(mt, thresh=0.01):
    """Remove variants whose smallest allele frequency is below thresh and report how many remain."""
    print(">>>>> maf_filter with threshold {}".format(thresh))
    qc_mt = hl.variant_qc(mt)
    n_before = qc_mt.count_rows()
    # the minor allele frequency is the smallest entry of the per-allele frequency array
    too_rare = hl.min(qc_mt.variant_qc.AF) < thresh
    kept = qc_mt.filter_rows(too_rare, keep=False)
    n_after = kept.count_rows()
    print("After MAF filter with threshold {}, {}/{} variants remain.".format(thresh, n_after, n_before))
    return kept

#### Linkage disequilibrium pruning


In [None]:
# This removes variants that are in high LD to reduce redundance
def ld_prune(mt, r2=0.2, window_size=500000):
    """Keep only the variants selected by Hail's LD pruning and report timing and counts.

    r2 is the correlation threshold and window_size the base-pair window handed to hl.ld_prune.
    """
    print(">>>>> ld_prune")
    qc_mt = hl.variant_qc(mt)
    n_before = qc_mt.count_rows()

    t0 = time.time()
    independent_variants = hl.ld_prune(qc_mt.GT, r2=r2, bp_window_size=window_size)
    elapsed = round(time.time() - t0)
    print("Pruning took {} seconds".format(elapsed))

    # a row survives if its key is present in the table returned by hl.ld_prune
    is_retained = hl.is_defined(independent_variants[qc_mt.row_key])
    pruned = qc_mt.filter_rows(is_retained)
    n_after = pruned.count_rows()
    print("After LD pruning with r2={}, {}/{} SNPs left".format(r2, n_after, n_before))
    return pruned

#### Hardy Weinberg equilibrium filter 


In [None]:
# options to use the same filter for all samples (all_together) or use different thresholds for cases and controls (control_sep)
def hwe_filter(mt, case_control="all_together"):
    """Remove variants that deviate from Hardy-Weinberg equilibrium.

    Parameters
    ----------
    mt : matrix table with a GT entry field and an is_case column field
    case_control : "all_together" removes variants with HWE p-value < 1e-6 computed on all samples;
        "control_sep" computes the p-value separately per group and removes variants with
        p < 1e-10 in controls or p < 1e-6 in cases

    Returns
    -------
    The filtered matrix table. With "control_sep" the columns are re-assembled as
    cases followed by controls.

    Raises
    ------
    ValueError if case_control is not one of the two supported options
    (previously this only printed a message and returned the unfiltered data).
    """
    if case_control not in ("all_together", "control_sep"):
        raise ValueError("WRONG OPTION. use 'all_together' or 'control_sep'")

    print(">>>>> hwe_filter")
    mt = hl.variant_qc(mt)
    before_filter = mt.count_rows()
    if case_control == "control_sep":
        case = mt.filter_cols(mt.is_case, keep=True)
        ctrl = mt.filter_cols(~mt.is_case, keep=True)
        # variant QC has to be recomputed per group so that p_value_hwe reflects that group only
        ctrl = hl.variant_qc(ctrl)
        ctrl = ctrl.filter_rows(ctrl.variant_qc.p_value_hwe < 1e-10, keep=False)
        case = hl.variant_qc(case)
        case = case.filter_rows(case.variant_qc.p_value_hwe < 1e-6, keep=False)
        # union_cols joins on rows, so only variants that passed in both groups remain
        mt = case.union_cols(ctrl)
        # the row annotations now stem from the case subset; refresh them for the combined samples
        mt = hl.variant_qc(mt)
    else:
        # variant_qc was already computed above on all samples
        mt = mt.filter_rows(mt.variant_qc.p_value_hwe < 1e-6, keep=False)
    print("After HWE filter, " + str(mt.count_rows()) + "/" + str(before_filter) + " SNPs left")
    return mt

#### Autosome filter


In [None]:
# Only use autosomes (exclude X,Y, mitochondrial and other contigs)
def autosomes(mt):
    """Keep variants located on an autosome or in a pseudoautosomal region and report counts."""
    print(">>>>> only use autosomes")
    qc_mt = hl.variant_qc(mt)
    n_before = qc_mt.count_rows()
    # in_autosome_or_par() is True for loci on an autosome or in a pseudoautosomal region of X or Y
    on_autosome_or_par = qc_mt.locus.in_autosome_or_par()
    kept = qc_mt.filter_rows(on_autosome_or_par)
    n_after = kept.count_rows()
    print("After autosome filter, {}/{} SNPs left".format(n_after, n_before))
    return kept

#### Single nucleotide filter


In [None]:
# Variations involving multiple nucleotides are filtered out
def single_np(mt):
    """Keep only alternate alleles that are SNPs relative to the reference allele, then redo variant QC."""
    print(">>>>> only use single nucleotide variants (not multiple nucleotides)")
    qc_mt = hl.variant_qc(mt)
    n_before = qc_mt.count_rows()

    def keep_allele(allele, i):
        # an allele is kept when it is a single-nucleotide change from the reference allele
        return hl.is_snp(qc_mt.alleles[0], allele)

    snp_mt = hl.filter_alleles(qc_mt, keep_allele)
    # filter_alleles() only updates locus and alleles, so the variant QC annotations
    # are recomputed here to keep them meaningful
    snp_mt = hl.variant_qc(snp_mt)
    n_after = snp_mt.count_rows()
    print("After is_snp filter, {}/{} SNPs left".format(n_after, n_before))
    return snp_mt

## Sample QC

#### Call rate filter (samples)


In [None]:
def call_rate_filter_samples(mt, thresh=0.95):
    """Keep samples whose call rate is at least thresh; report remaining samples and cases."""
    print(">>>>> call rate filter (samples) with threshold {}".format(thresh))
    qc_mt = hl.sample_qc(mt)
    n_samples_before = qc_mt.count_cols()
    n_cases_before = qc_mt.filter_cols(qc_mt.is_case).count_cols()

    kept = qc_mt.filter_cols(qc_mt.sample_qc.call_rate >= thresh)

    n_cases_after = kept.filter_cols(kept.is_case).count_cols()
    n_samples_after = kept.count_cols()
    print('After call rate filter (SAMPLES) with threshold {}, {}/{} samples, including {}/{} cases, remain.'.format(
        thresh, n_samples_after, n_samples_before, n_cases_after, n_cases_before))
    return kept

#### Imputed vs reported sex discrepancy filter

In [None]:
# Remove samples where reported sex differs from imputed sex
def sex_discrepancy(mt):
    """Drop samples whose imputed sex disagrees with the reported is_female field; report counts."""
    print(">>>>> check reported vs imputed sex")
    qc_mt = hl.sample_qc(mt)
    n_samples_before = qc_mt.count_cols()
    n_cases_before = qc_mt.filter_cols(qc_mt.is_case).count_cols()

    sex_table = hl.impute_sex(qc_mt.GT)
    mismatch = sex_table[qc_mt.s].is_female != qc_mt.is_female
    kept = qc_mt.filter_cols(mismatch, keep=False)

    n_samples_after = kept.count_cols()
    n_cases_after = kept.filter_cols(kept.is_case).count_cols()
    print("After filtering for discrepancies between imputed and reported sex {}/{} samples, including {}/{} cases, remain.".format(
        n_samples_after, n_samples_before, n_cases_after, n_cases_before))
    return kept

#### Heterozygosity filter
From [Hail discussion board](https://discuss.hail.is/t/filtering-samples-with-extreme-heterozygosity-in-hail/1277/5)

In [None]:
# (This takes a lot of time and sometimes breaks, consider running this separately to check if it affects sample count at all)
# Samples with extreme heterozygosity are filtered out
# removing individuals who deviate ±3 SD from the samples' heterozygosity rate mean.
def het_filter(mt):
    """Remove samples whose heterozygosity rate deviates more than 3 SD from the sample mean.

    The matrix table is annotated with the column fields 'heterozygosity'
    (n_het / n_called from sample QC) and 'inbreeding' before filtering.

    Returns the filtered matrix table and prints how many samples and cases remain.
    """
    print(">>>>> heterozygosity filter")
    mt = hl.sample_qc(mt)
    # the inbreeding aggregator below reads variant_qc.AF, so make sure it is present and current
    mt = hl.variant_qc(mt)
    before = mt.count_cols()
    n_cases_before = mt.filter_cols(mt.is_case).count_cols()
    mt = mt.annotate_cols(
        heterozygosity=(mt.sample_qc.n_het / mt.sample_qc.n_called),
        inbreeding=hl.agg.inbreeding(mt.GT, mt.variant_qc.AF[1]))
    # heterozygosity is a per-sample value, so aggregate over columns rather than over every entry
    het_stats = mt.aggregate_cols(hl.agg.stats(mt.heterozygosity))
    lower = het_stats.mean - 3 * het_stats.stdev
    upper = het_stats.mean + 3 * het_stats.stdev
    # filter_cols returns a new matrix table; the result must be assigned for the filter to apply
    mt = mt.filter_cols((mt.heterozygosity > upper) | (mt.heterozygosity < lower), keep=False)
    after = mt.count_cols()
    n_cases_after = mt.filter_cols(mt.is_case).count_cols()
    print("After heterozygosity filtering "+str(after) + "/" + str(before) + " samples, including "+str(n_cases_after)+"/"+str(n_cases_before)+" cases, remain.")
    return mt

#### IBD pruning


In [None]:
# needed for removal of controls over cases
def tie_breaker(l, r):
    """Compare two nodes by case status: -1 if only l is a case, 1 if only r is a case, else 0."""
    only_left_is_case = l.is_case & ~r.is_case
    only_right_is_case = ~l.is_case & r.is_case
    otherwise = hl.cond(only_right_is_case, 1, 0)
    return hl.cond(only_left_is_case, -1, otherwise)

In [None]:
# Removes samples with identity by descent measure of >0.125 (third degree relatedness)
def ibd_prune(mt, removal="max_ind_set", set_min=0.125):
    """Remove related samples based on identity by descent (IBD).

    Parameters
    ----------
    mt : matrix table with a GT entry field and an is_case column field
    removal : "max_ind_set" removes the samples outside a maximal independent set of the
        relatedness graph; "cases_over_controls" does the same but uses tie_breaker so that
        cases are preferred over controls when deciding which sample to keep
    set_min : minimum IBD value for a pair of samples to count as related
        (0.125 corresponds to third degree relatedness)

    Returns
    -------
    The matrix table without the removed samples.

    Raises
    ------
    ValueError (a subclass of the previously raised Exception) for an unknown removal option.
    The option is now checked up front instead of after the IBD table has been built.
    """
    if removal not in ("max_ind_set", "cases_over_controls"):
        raise ValueError("Wrong option for 'removal'. Use max_ind_set or cases_over_controls")

    print(">>>>> IBD pruning (this should happen AFTER LD pruning!)")
    mt = hl.sample_qc(mt)
    before = mt.count_cols()
    n_cases_before = mt.filter_cols(mt.is_case).count_cols()
    # table of sample pairs whose IBD estimate is at least set_min
    start = time.time()
    ibd_matrix_min = hl.identity_by_descent(mt, min=set_min)
    if removal == "max_ind_set":
        # keep=False returns the samples that are NOT in the maximal independent set,
        # i.e. the ones that have to be removed
        related_samples_to_remove = hl.maximal_independent_set(
            ibd_matrix_min.i, ibd_matrix_min.j, keep=False)
        mt = mt.filter_cols(hl.is_defined(related_samples_to_remove[mt.col_key]), keep=False)
    else:
        samples = mt.cols()
        # attach the case status to both ends of every pair so tie_breaker can use it
        pairs_with_case = ibd_matrix_min.key_by(
            i=hl.struct(id=ibd_matrix_min.i, is_case=samples[ibd_matrix_min.i].is_case),
            j=hl.struct(id=ibd_matrix_min.j, is_case=samples[ibd_matrix_min.j].is_case))
        related_samples_to_remove = hl.maximal_independent_set(
            pairs_with_case.i, pairs_with_case.j, keep=False, tie_breaker=tie_breaker)
        # identity_by_descent reports i and j as plain sample-id strings, so node.id already is
        # the sample id (the former node.id.s lookup does not exist on a string)
        related_samples_to_remove = related_samples_to_remove.key_by(
            s=related_samples_to_remove.node.id)
        mt = mt.filter_cols(hl.is_defined(related_samples_to_remove[mt.col_key]), keep=False)
    end = time.time()
    print("IBD pruning took "+ str(round(end-start)) + " seconds")
    after = mt.count_cols()
    n_cases_after = mt.filter_cols(mt.is_case).count_cols()
    print("After IBD pruning " + str(after) + "/" + str(before) + " samples, including " + str(n_cases_after) + "/" + str(n_cases_before) + " cases, remain.")
    return mt

# Annotation with Hail annotation database
[Hail Documentation for Annotation Database](https://hail.is/docs/0.2/annotation_database_ui.html#id1 )

Only possible on Google Cloud Platform   
Only available with recent Hail versions (>0.2.34) and therefore not directly compatible with the current version of VariantSpark, which depends on Hail 0.2.16

In [None]:
# helper function
def remove_empty_unpack(mt):
    """Drop rows without gencode gene names/ids and expose the first of each as a row field.

    Adds the row fields gene_name1 and gene_id1.
    """
    has_name = hl.len(mt.gencode.gene_name) > 0
    named = mt.filter_rows(has_name)
    has_id = hl.len(named.gencode.gene_id) > 0
    identified = named.filter_rows(has_id)
    # reduce the annotation to a single gene: the first entry of each list
    with_name = identified.annotate_rows(gene_name1=identified.gencode.gene_name[0])
    return with_name.annotate_rows(gene_id1=with_name.gencode.gene_id[0])

In [None]:
# use hail annotation database to annotate genes from gencode
def annotate_genes(mt, ht):
    """Annotate rows with gencode via the Hail annotation database.

    The ht argument is not read; it is replaced by the rows table holding the
    gencode annotation. Returns (mt, ht), where mt has been reduced to rows with
    a gene annotation by remove_empty_unpack.
    """
    annotation_db = hl.experimental.DB()
    annotated = annotation_db.annotate_rows_db(mt, "gencode")
    gencode_rows = annotated.select_rows(annotated.gencode).rows()
    # filter out SNPs without annotation and reduce annotation to one gene name
    unpacked = remove_empty_unpack(annotated)
    return unpacked, gencode_rows

In [None]:
# for burden test: use only top n SNPs
def get_top_burden(mt, ht, method="GWAS", top_n=1000):
    """Restrict a matrix table to the top_n highest ranked SNPs for burden testing.

    Parameters
    ----------
    mt : matrix table to restrict
    ht : results table with the fields locus and alleles plus 'p_value' (method="GWAS")
        or 'importance' (method="VS")
    method : "GWAS" ranks by ascending p-value, "VS" by descending importance
    top_n : number of top ranked SNPs to keep

    Returns
    -------
    (mt_top, ht_top): the matrix table restricted to the top SNPs and the table of those
    SNPs keyed by locus and alleles. The original returned the unfiltered mt by mistake.

    Raises
    ------
    ValueError for an unknown method (the original `raise("...")` produced a TypeError).
    """
    if method == "GWAS":
        ranked = ht.order_by(hl.asc('p_value'))
    elif method == "VS":
        ranked = ht.order_by(hl.desc('importance'))
    else:
        raise ValueError("Use method='GWAS' or method='VS'")
    ranked = ranked.add_index(name='new_rank')
    # new_rank is zero-based, so "< top_n" keeps exactly top_n rows
    ht_top = ranked.filter(ranked.new_rank < top_n, keep=True)
    ht_top = ht_top.key_by("locus", "alleles")
    mt_top = mt.semi_join_rows(ht_top)
    return mt_top, ht_top

# GWAS
#### Logistic regression

In [None]:
# perform logistic regression on the given matrix table
# use this for only 1 control, 1 sample group, annotated with is_case
# needs previously performed PCA annotation with (at least) 5 PCs
def gwas_log_regression_1c(mt):
    """Run a Wald logistic regression per variant and annotate mt with the results.

    Covariates are an intercept, is_female and the first five entries of mt.scores.scores.
    Returns (gwas_ht, mt): the results table keyed by locus and alleles with the extra
    fields log10 and gwas_rank, and mt with the row fields log_reg_pval, log_reg_log10p
    and log_reg_rank.
    """
    print("GWAS running")
    pcs = [mt.scores.scores[i] for i in range(5)]
    results = hl.logistic_regression_rows(
        test='wald',
        y=mt.is_case,
        x=mt.GT.n_alt_alleles(),
        covariates=[1.0, mt.is_female] + pcs)
    # add -log10(p), then rank the variants by ascending p-value
    results = results.annotate(log10=-hl.log10(results.p_value))
    results = results.order_by(hl.asc('p_value'))
    results = results.add_index(name='gwas_rank')
    # re-key so the table can be joined onto mt (this drops the ordering, the rank is kept as a field)
    gwas_ht = results.key_by('locus', 'alleles')
    hit = gwas_ht[mt.locus, mt.alleles]
    mt = mt.annotate_rows(log_reg_pval=hit.p_value,
                          log_reg_log10p=hit.log10,
                          log_reg_rank=hit.gwas_rank)
    return gwas_ht, mt

In [None]:
# Prints QQ and manhattan plot
def qq_manh_plot(gwas_ht, title="Asthma", sign_threshold=5e-8):
    """Show a QQ plot and a Manhattan plot of the p-values in gwas_ht."""
    print(">>>>> qq and manhattan plot")
    # QQ plot: check whether the GWAS was well controlled (assess inflation)
    qq_figure = hl.plot.qq(gwas_ht.p_value, title=title)
    show(qq_figure)
    # Manhattan plot: visualise significant SNPs against the significance line
    manhattan_title = "Manhattan plot of " + title + ", logistic regression"
    manhattan_figure = hl.plot.manhattan(
        gwas_ht.p_value,
        title=manhattan_title,
        significance_line=sign_threshold,
        collect_all=True)
    show(manhattan_figure)

# Variant Spark
#### Random Forest

In [None]:
# Perform Variant Spark AFTER logistic regression
def vs_rf_after_gwas(mt, gwas, n_trees=5000, mTry=0.1, seed=4957,
                     max_depth=5, min_node_size=200, batch_size=100):
    """Fit a VariantSpark random forest and join its importance scores onto the GWAS results.

    Parameters
    ----------
    mt : matrix table with a GT entry field and an is_case column field
    gwas : logistic regression results table keyed by locus and alleles
    n_trees : number of trees to fit
    mTry : fraction of variables tried at each split (mtry_fraction)
    seed : random seed of the model
    max_depth, min_node_size : tree-shape settings of the model
        (defaults are the values that used to be hard-coded)
    batch_size : number of trees fitted per batch (default is the former hard-coded value)

    Returns
    -------
    (gwas_with_imp, mt): the GWAS table joined with 'importance' and the zero-based 'vs_rank',
    keyed by locus and alleles, and mt annotated with the row fields vs_imp and vs_rank.
    """
    print(">>>>> Start VariantSpark")
    # VariantSpark cannot build a model on data with missing genotypes (throws NullPointerException),
    # so only variants with a call rate of 1 are passed to the model
    mt = hl.sample_qc(mt)
    mt = hl.variant_qc(mt)
    mt_no_missing_genotypes = mt.filter_rows(mt.variant_qc.call_rate == 1, keep= True)
    before = mt.count_rows()
    after = mt_no_missing_genotypes.count_rows()
    print("Filtered " + str(before-after)+ "/" + str(before) + " variants with missing genotypes (required for VS)")

    start = time.time()
    rf_model = vshl.random_forest_model(y=mt_no_missing_genotypes.is_case,
                    x=mt_no_missing_genotypes.GT.n_alt_alleles(),
                    mtry_fraction=mTry,
                    max_depth=max_depth,
                    min_node_size=min_node_size,
                    seed=seed)
    rf_model.fit_trees(n_trees=n_trees, batch_size=batch_size)
    end = time.time()
    print("Building VS random forest model took " + str(round(end-start)) + " seconds")
    print("rf_model.oob error="+str(rf_model.oob_error()))
    impTable = rf_model.variable_importance()
    gwas_with_imp = gwas.join(impTable)
    # add a rank to make finding top SNPs easier
    gwas_with_imp = gwas_with_imp.order_by(hl.desc('importance')).add_index(name='vs_rank')
    gwas_with_imp = gwas_with_imp.key_by('locus', 'alleles')
    mt = mt.annotate_rows(vs_imp = gwas_with_imp[mt.locus, mt.alleles].importance)
    mt = mt.annotate_rows(vs_rank = gwas_with_imp[mt.locus, mt.alleles].vs_rank)
    return gwas_with_imp, mt

In [None]:
def vs_manhattan(gwas_with_imp, title="Asthma"):
    """Show a Manhattan-style plot of the random forest importance scores (no significance line)."""
    plot_title = "Manhattan plot of " +title+", random forest"
    figure = vshlplt.manhattan_imp(gwas_with_imp.importance,
                                   significance_line=None,
                                   title=plot_title)
    show(figure)

# Burden testing
[Documentation - SKAT](https://hail.is/docs/0.2/methods/genetics.html#hail.methods.skat)
[Discussion board - Burden](https://discuss.hail.is/t/logistic-regression-burden-tests/206/3)   
Requires gene annotations with the annotation function and reduction of the table with get_top_burden   
Because of compatibility issues (GCP vs AWS, required Hail version vs current VariantSpark Hail version) this is not part of the <a href='#pipeline'>pipeline</a>, but can be performed in different notebooks/other platforms

In [None]:
def skat(mt):
    """Run a logistic SKAT test per gene (row field gene_name1) and return the result table.

    Variants are weighted by dbeta(min allele frequency, 1, 25) squared; covariates are an
    intercept, is_female and the first five entries of mt.scores.scores.
    """
    qc_mt = hl.variant_qc(mt)
    variant_weight = hl.dbeta(hl.min(qc_mt.variant_qc.AF), 1.0, 25.0) ** 2
    qc_mt = qc_mt.annotate_rows(weight=variant_weight)
    pcs = [qc_mt.scores.scores[i] for i in range(5)]
    skat_table = hl.skat(
        key_expr=qc_mt.gene_name1,
        weight_expr=qc_mt.weight,
        y=qc_mt.is_case,
        x=qc_mt.GT.n_alt_alleles(),
        covariates=[1.0, qc_mt.is_female] + pcs,
        logistic=True,
        iterations=100000)
    # fault == 0 marks results that were computed without issues
    clean_fraction = skat_table.aggregate(hl.agg.fraction(skat_table.fault == 0))
    print("Fraction of results with no issues: " + str(round(clean_fraction, 2)))
    return skat_table
# skat_table.order_by(skat_table.p_value).show(10)

In [None]:
def burden(mt):
    """Burden test: logistic regression of is_case on the per-gene sum of minor allele counts.

    Rows are grouped by gene_name1; covariates are an intercept, is_female and the first
    five entries of scores.scores. Returns the regression result table.
    """
    alt_count = mt.GT.n_alt_alleles()
    # count the less frequent allele: alt alleles if AF[1] <= 0.5, otherwise reference alleles
    minor_count = hl.cond(mt.variant_qc.AF[1] <= 0.5, alt_count, 2 - alt_count)
    gene_mt = mt.group_rows_by(mt.gene_name1).aggregate(mac=hl.agg.sum(minor_count))
    pcs = [gene_mt.scores.scores[i] for i in range(5)]
    return hl.logistic_regression_rows(
        y=gene_mt.is_case,
        x=gene_mt.mac,
        covariates=[1.0, gene_mt.is_female] + pcs,
        test='wald')
# gene_mt.order_by(gene_mt.p_value).show(10)

# Evaluation of GWAS and VS Results

### Compare Variant Spark with GWAS logistic regression results
Scatter plot of sqrt(random forest importance score) vs -log10(logistic regression p-value)

In [None]:
def comparison_scatter(gwas_with_imp, title="Asthma"):
    """Show sqrt(random forest importance) against -log10(logistic regression p-value)."""
    root_importance = hl.sqrt(gwas_with_imp["importance"])
    table = gwas_with_imp.annotate(sqrt_importance=root_importance)
    plot_title = "Comparing GWAS with Variant Spark ("+ title + ")"
    figure = hl.plot.scatter(
        x=table["log10"],
        y=table["sqrt_importance"],
        xlabel="-log10(Hail P-value)",
        ylabel="sqrt(VariantSpark Importance)",
        title=plot_title)
    show(figure)

### Benjamini Hochberg 
p-value adjustment for multiple testing

In [None]:
# WARNING: can cause timeouts
# returns a pandas DataFrame with one row per variant and a boolean column 'bh_thresh'
# that tells whether the variant is significant under the Benjamini-Hochberg procedure

def bh_thresh(mt, gwas_ht, fdr=0.05):
    """Flag variants that are significant after Benjamini-Hochberg correction.

    Parameters
    ----------
    mt : matrix table whose rows are looked up in gwas_ht
    gwas_ht : logistic regression results keyed by locus and alleles,
        with the fields p_value and gwas_rank
    fdr : false discovery rate to control

    Returns
    -------
    pandas DataFrame with locus, alleles, log_reg_pval, log_reg_rank and the boolean bh_thresh.

    The table is round-tripped through hl_path + 'test.tsv' (notebook-level hl_path) because
    computing the threshold directly in Hail errored.
    """
    print(">>>>> significance with benjamini hochberg threshold")

    mt = mt.annotate_rows(log_reg_pval = gwas_ht[mt.locus, mt.alleles].p_value)
    mt = mt.filter_rows(hl.is_missing(mt.log_reg_pval), keep=False)
    mt = mt.annotate_rows(log_reg_rank = gwas_ht[mt.locus, mt.alleles].gwas_rank)
    # work around bug: instead of directly annotating bh_thresh in hail (errors), write out as tsv, read back in and calculate bh_thresh
    table = mt.rows()
    table = table.select(table.log_reg_pval, table.log_reg_rank)
    print("exporting table")
    table.export(hl_path+'test.tsv')
    print("importing tsv")
    df = pd.read_csv(hl_path+'test.tsv', sep='\t')
    print("calculate bh threshold")
    n_var = df['log_reg_pval'].count()
    # Rank the p-values of the variants that are actually tested here (1-based). gwas_rank may
    # refer to a larger set of variants than the n_var rows present, which would skew the thresholds.
    bh_rank = df['log_reg_pval'].rank(method='first')
    below_own_threshold = df['log_reg_pval'] <= (bh_rank / n_var) * fdr
    if below_own_threshold.any():
        # step-up rule: every variant ranked at or above the largest passing rank is significant,
        # even if its own p-value exceeds its own threshold
        largest_passing_rank = bh_rank[below_own_threshold].max()
        df['bh_thresh'] = bh_rank <= largest_passing_rank
    else:
        df['bh_thresh'] = False
    print("Variants passing bh threshold")
    print(df.loc[df['bh_thresh'], ['locus','alleles','log_reg_pval','log_reg_rank','bh_thresh']])
    return df

### Overlaps with SNPs associated with the phenotype from GWAS catalog

In [None]:
# Convert GWAS catalog file to interval table usable with hail
# This function is untested on all but 3 GWAS catalog files (asthma, type I diabetes, height) and prone to errors

# delta is the half-width of the window around each locus from the GWAS catalog, so 50,000 means +/- 50,000 bp around the locus
def gwas_catalog_to_intervals(gwas_catalog_filepath, delta = 50000):
    """Write locus and interval files for the SNPs of a GWAS catalog export.

    Parameters
    ----------
    gwas_catalog_filepath : path of the tab-separated GWAS catalog file; it needs the columns
        P-VALUE, CHR_ID, CHR_POS and REPORTED GENE(S)
    delta : number of base pairs added on each side of CHR_POS to form the interval

    Rows without CHR_ID or CHR_POS are skipped. Three files are written into the notebook-level
    hl_path folder, with rows ordered by ascending P-VALUE:
      gwas_catalog_loci.tsv                            - header 'locus', then chr<ID>:<pos>
      gwas_catalog_intervals_<delta>_windows.tsv       - chr<ID>:<start>-<end>
      gwas_catalog_intervals_<delta>_windows_gene.tsv  - contig, start, end, direction, reported gene(s)
    """
    unused_columns = [
        'DATE ADDED TO CATALOG', 'PUBMEDID', 'FIRST AUTHOR', 'DATE', 'JOURNAL',
        'LINK', 'STUDY', 'CONTEXT', 'INTERGENIC', 'INITIAL SAMPLE SIZE', 'REPLICATION SAMPLE SIZE',
        'UPSTREAM_GENE_ID', 'DOWNSTREAM_GENE_ID', 'UPSTREAM_GENE_DISTANCE', 'DOWNSTREAM_GENE_DISTANCE', 'MERGED', 'CNV',
        'P-VALUE (TEXT)', 'MAPPED_TRAIT_URI', 'STUDY ACCESSION', 'OR or BETA', '95% CI (TEXT)', 'PLATFORM [SNPS PASSING QC]',
        'GENOTYPING TECHNOLOGY',
    ]
    catalog = pd.read_csv(gwas_catalog_filepath, sep="\t")
    catalog = catalog.sort_values(by=['P-VALUE'])
    # catalog exports differ in their columns, so missing ones are not an error
    catalog = catalog.drop(columns=unused_columns, errors="ignore")
    catalog = catalog.dropna(subset=["CHR_ID", "CHR_POS"])
    catalog = catalog.dropna(axis=1)
    # sorting and dropping rows leaves gaps in the index; renumber so that all derived
    # columns line up with their rows (label-based lookups used to misalign them)
    catalog = catalog.reset_index(drop=True)
    catalog = catalog.astype({'CHR_POS': 'int32'})

    chrom_ids = catalog["CHR_ID"]
    if pd.api.types.is_float_dtype(chrom_ids):
        # purely numeric ids are read as floats when rows were missing; avoid contigs like "chr1.0"
        chrom_ids = chrom_ids.astype(int)
    contig = "chr" + chrom_ids.astype(str)
    position = catalog["CHR_POS"]
    # positions are 1-based, so a window must not start below 1
    interval_start = (position - delta).clip(lower=1).astype(str)
    interval_end = (position + delta).astype(str)

    catalog["locus"] = contig + ":" + position.astype(str)
    catalog.to_csv(hl_path+"gwas_catalog_loci.tsv",columns=["locus"], sep="\t", index=False)

    catalog["interval"] = contig + ":" + interval_start + "-" + interval_end
    catalog.to_csv(hl_path+"gwas_catalog_intervals_"+str(delta)+"_windows.tsv",columns=["interval"] , sep="\t", index=False, header=False)

    # format for intervals that includes annotation: contig  start  end  direction  target
    catalog["chr"] = contig
    catalog["interval_start"] = interval_start
    catalog["interval_end"] = interval_end
    catalog["dir"] = "-"
    catalog.to_csv(hl_path+"gwas_catalog_intervals_"+str(delta)+"_windows_gene.tsv",columns=["chr","interval_start", "interval_end", "dir", "REPORTED GENE(S)"] , sep="\t", index=False, header=False)

In [None]:
# check if any of the top x SNPs have overlap with GWAS catalog SNPs
def check_for_overlaps_in_top_snps(gwas_with_imp, gwas_ht, top_n=500, ref_genome='GRCh38', interval_path=hl_path+"intervals_50000.tsv", window_size=50000):
    """Count how many top ranked SNPs fall into intervals around GWAS catalog SNPs.

    Parameters
    ----------
    gwas_with_imp : GWAS results joined with VariantSpark importance (fields gwas_rank, vs_rank)
    gwas_ht : logistic regression results (field gwas_rank)
    top_n : number of top ranked SNPs per method to check (ranks are zero-based)
    ref_genome : reference genome of the interval file
    interval_path : locus interval file readable by hl.import_locus_intervals
    window_size : window size in bp that was used to build the interval file; only used in
        the printed summary (the messages used to reference an undefined name)

    Returns
    -------
    (top_gwas_snps_region, top_vs_snps_region): the top GWAS and top VariantSpark SNPs
    that lie inside one of the intervals.
    """
    print(">>>>> Check for overlaps with GWAS catalog SNPs (reported as significant by other studies) with top " + str(top_n) + " highest ranked SNPs")

    # get top_n ranked in gwas
    print("top_n = " + str(top_n))
    top_gwas_snps = gwas_ht.filter(gwas_ht.gwas_rank < top_n, keep=True)
    print("# GWAS SNPs: " + str(top_gwas_snps.count()))
    # get top_n importance scores
    top_vs_snps = gwas_with_imp.filter(gwas_with_imp.vs_rank < top_n, keep=True)
    print("# VS SNPs: " + str(top_vs_snps.count()))

    # SNPs found with both methods
    top_both_snps = top_vs_snps.filter(top_vs_snps.gwas_rank < top_n, keep=True)
    n_both = top_both_snps.count()
    print("# both SNPs: " + str(n_both))

    # check for overlaps with GWAS catalog
    intervals = hl.import_locus_intervals(interval_path, reference_genome=ref_genome)

    # annotate and filter, then count and output
    top_gwas_snps = top_gwas_snps.annotate(region = hl.is_defined(intervals[top_gwas_snps.locus]))
    top_vs_snps = top_vs_snps.annotate(region = hl.is_defined(intervals[top_vs_snps.locus]))
    top_both_snps = top_both_snps.annotate(region = hl.is_defined(intervals[top_both_snps.locus]))
    top_gwas_snps_region =  top_gwas_snps.filter(top_gwas_snps.region)
    top_vs_snps_region =  top_vs_snps.filter(top_vs_snps.region)
    top_both_snps_region =  top_both_snps.filter(top_both_snps.region)
    print("Of top " + str(top_n) + " SNPs from GWAS " + str(top_gwas_snps_region.count()) + " were within "+str(window_size)+"bp of a GWAS catalog SNP.")
    print("Of top " + str(top_n) + " SNPs from VS " + str(top_vs_snps_region.count()) + " were within "+str(window_size)+"bp of GWAS catalog SNP.")
    print("Of SNPs in top " + str(top_n) + " with both GWAS and VS (overlap: "+str(n_both)+") " + str(top_both_snps_region.count()) + " were within "+str(window_size)+"bp of a GWAS catalog SNP.")
    return top_gwas_snps_region, top_vs_snps_region

## Polygenic Risk Score
[Hail Documentation on PRS](https://hail.is/docs/0.2/guides/genetics.html#polygenic-risk-score-calculation)

In [None]:
def prs(mt, method="gwas"):
    """Annotate each sample (column) of ``mt`` with a polygenic risk score ``prs``.

    The score is the sum over variants of ``score * dosage`` where ``score`` is
    the row field ``beta`` (method="gwas", from logistic regression) or
    ``importance`` (method="vs", from VariantSpark) and ``dosage`` is the number
    of copies of the lower-frequency allele. Missing genotypes are replaced by
    the expected dosage ``2 * allele frequency``.

    NOTE: beta carries the direction (pos./neg.) in which a SNP influences the
    phenotype, the importance score does not. A PRS built from importance is
    therefore not useful (feature contribution would be the random forest
    counterpart of beta).

    Parameters:
        mt: matrix table with entry field ``GT`` and the row field matching
            ``method`` (``beta`` or ``importance``).
        method: "gwas" or "vs".

    Returns:
        The matrix table with ``variant_qc``, ``MAF``, ``allele``, ``flip`` and
        ``prior`` row fields and the ``prs`` column field added.

    Raises:
        ValueError: if ``method`` is neither "gwas" nor "vs".
    """
    # fail early: previously an unknown method only surfaced later as an
    # UnboundLocalError on `score`
    if method not in ("gwas", "vs"):
        raise ValueError("Use method='gwas' or method='vs'")

    mt = hl.variant_qc(mt)
    mt = mt.annotate_rows(MAF = hl.min(mt.variant_qc.AF))
    # assuming 'allele' is the allele with lower AF
    mt = mt.annotate_rows(allele=(hl.cond(mt.MAF == mt.variant_qc.AF[0], mt.alleles[0], mt.alleles[1])))
    # flip is True when the scored allele is the first (reference) allele
    flip = hl.case().when(mt.allele == mt.alleles[0], True).when(mt.allele == mt.alleles[1], False).or_missing()
    mt = mt.annotate_rows(flip=flip)
    # expected dosage of the scored allele, used to impute missing genotypes
    mt = mt.annotate_rows(prior=2 * hl.cond(mt.flip, mt.variant_qc.AF[0], mt.variant_qc.AF[1]))
    # the score field must be read from the fully annotated table
    if method == "gwas":
        score = mt.beta
    else:
        score = mt.importance
    mt = mt.annotate_cols(prs=hl.agg.sum(
        score *
            hl.coalesce(
                hl.cond(
                mt.flip, 2 - mt.GT.n_alt_alleles(),
                mt.GT.n_alt_alleles()),
            mt.prior)))
    return mt

In [None]:
def _prs_accuracy(mt):
    """Return the fraction of samples in ``mt`` whose call ``prs > 0`` equals ``is_case``."""
    mt = mt.annotate_cols(prs_pred = mt.prs > 0)
    true_pred = mt.filter_cols(mt.prs_pred == mt.is_case).count_cols()
    return true_pred / mt.count_cols()


def pred_accuracy_prs(mt, p=0.5):
    """Print the prediction accuracy of the PRS on a training and a test split.

    Logistic regression is run on the training set only; the resulting betas
    are used to score both the training and the test samples. A sample is
    predicted to be a case when its PRS is greater than 0.

    WARNING: have experienced timeouts in this function. If timeout occurs,
    restart kernel.

    Parameters:
        mt: matrix table with column field ``is_case``.
        p: probability with which each sample is put into the training set.
            p>0.5 means more samples in the training set, p=0.5 means an
            approximately 50/50 split.
    """
    # forward p to the split (it was previously hard-coded to 0.5)
    training_mt, test_mt = test_training_split(mt, p=p)
    gwas_ht, training_mt = gwas_log_regression_1c(training_mt)

    training_mt = training_mt.annotate_rows(beta = gwas_ht[training_mt.locus, training_mt.alleles].beta)
    training_mt = prs(training_mt, method="gwas")
    print("PRS accuracy on training set: " + str(_prs_accuracy(training_mt)))

    # score the test samples with the betas estimated on the training set
    test_mt = test_mt.annotate_rows(beta = training_mt.index_rows(test_mt.row_key).beta)
    test_mt = prs(test_mt, method="gwas")
    print("PRS accuracy on test set: " + str(_prs_accuracy(test_mt)))

## Prediction Accuracy with random forest on top SNPs

In [None]:
def noise_pred(mt, top_snps_ht, method="GWAS"):
    """Compare MAF, call rate and HWE p-value of the top SNPs between cases and controls.

    Differences between cases and controls in these metrics may confound the
    random forest results. (This method only makes sense when using only the
    TOP SNPs from VS.)

    Prints summary statistics for cases and controls and shows MAF histograms.

    Parameters:
        mt: matrix table with column field ``is_case``. The row field ``MAF``
            is read but not created here - presumably annotated during QC;
            verify against caller.
        top_snps_ht: table keyed like the rows of ``mt``; only variants present
            in it are analysed.
        method: label used in the histogram titles.
    """
    # fixed: the parameter is `top_snps_ht` (was referenced as `top_snp_ht`)
    mt_top = mt.semi_join_rows(top_snps_ht)
    # variant_qc is recomputed separately within cases and within controls
    mt_top_cases = hl.variant_qc(mt_top.filter_cols(mt_top.is_case))
    mt_top_controls = hl.variant_qc(mt_top.filter_cols(~mt_top.is_case))

    def print_stats(title, field):
        # `field` selects the expression to summarise from a matrix table
        print(title)
        print("Cases stats")
        print(mt_top_cases.aggregate_entries(hl.agg.stats(field(mt_top_cases))))
        print("Controls stats")
        print(mt_top_controls.aggregate_entries(hl.agg.stats(field(mt_top_controls))))

    print_stats("MAF", lambda t: t.MAF)
    p = hl.plot.histogram(mt_top_controls.MAF, title="MAF - " + method +" controls")
    show(p)
    p = hl.plot.histogram(mt_top_cases.MAF, title="MAF - " + method +" cases")
    show(p)
    print_stats("Call rate", lambda t: t.variant_qc.call_rate)
    print_stats("Hardy Weinberg", lambda t: t.variant_qc.p_value_hwe)

In [None]:
def test_training_split(mt, p=0.5, seed=29347):
    """Split the samples of ``mt`` into a training and a test matrix table.

    Cases and controls are sampled separately (each sample is kept for the
    training set with probability ``p``), so both sets contain cases and
    controls. Samples not drawn for training form the test set.

    Returns:
        (training_mt, test_mt)
    """
    n_samples = mt.count_cols()
    print("Total number of samples: " + str(n_samples))

    all_cases = mt.filter_cols(mt.is_case)
    n_cases = all_cases.count_cols()
    print("Number of cases: " + str(n_cases))

    all_controls = mt.filter_cols(~mt.is_case)
    n_controls = all_controls.count_cols()
    print("Number of controls: " + str(n_controls))

    print("Split with p=" + str(p))
    train_cases = all_cases.sample_cols(p=p, seed=seed)
    train_controls = all_controls.sample_cols(p=p, seed=seed)
    n_train_cases = train_cases.count_cols()
    n_train_controls = train_controls.count_cols()
    print("Training set contains " + str(n_train_cases) + " cases and " + str(n_train_controls) + " controls.")

    # everything that was not sampled for training goes into the test set
    held_out_cases = all_cases.anti_join_cols(train_cases.cols())
    held_out_controls = all_controls.anti_join_cols(train_controls.cols())

    return (
        train_cases.union_cols(train_controls),
        held_out_cases.union_cols(held_out_controls),
    )

In [None]:
def get_top(ht, method="GWAS", top_n=500):
    """Return the ``top_n`` most significant/important SNPs of ``ht``.

    Parameters:
        ht: table with fields ``locus`` and ``alleles`` plus ``p_value``
            (method="GWAS") or ``importance`` (method="VS").
        method: "GWAS" ranks by ascending p-value, "VS" by descending
            importance.
        top_n: number of SNPs to keep.

    Returns:
        Table keyed by (locus, alleles) with the additional field ``new_rank``.

    Raises:
        ValueError: if ``method`` is neither "GWAS" nor "VS".
    """
    if method == "GWAS":
        ht = ht.order_by(hl.asc('p_value')).add_index(name='new_rank')
    elif method == "VS":
        ht = ht.order_by(hl.desc('importance')).add_index(name='new_rank')
    else:
        # raising a plain string is itself a TypeError; raise a real exception
        raise ValueError("Use method='GWAS' or method='VS'")
    ht = ht.filter(ht.new_rank < top_n, keep=True)
    # order_by drops the key, so restore it for later joins
    ht = ht.key_by("locus", "alleles")
    return ht

In [None]:
def get_labels(mt):
    """Retrieve the ``is_case`` labels of ``mt`` as a numpy array (compatibility with sklearn).

    The column table is exported to ``hl_path + "tmp/table2.csv"`` and read
    back with pandas.
    """
    csv_path = hl_path+"tmp/table2.csv"
    # keep only is_case and drop the column key to preserve the sample order
    cols_ht = mt.select_cols(mt.is_case).key_cols_by().cols()
    cols_ht.export(csv_path, delimiter=',')
    return np.array(pd.read_csv(csv_path).is_case)

In [None]:
def get_features(mt):
    """Retrieve the genotypes of ``mt`` as an integer numpy array (compatibility with sklearn).

    The entries are exported to ``hl_path + "tmp/table1.csv"`` and read back
    with pandas (``Table.to_pandas`` does not work for large data frames).
    Genotype strings are coded as 0/0 -> 0, 0/1 -> 1, 1/1 -> 2; empty cells are
    treated as 0/0. The result has one row per sample and one column per
    variant.
    """
    csv_path = hl_path+"tmp/table1.csv"
    gt_codes = {'0/0': 0, '0/1': 1, '1/1': 2}

    # reduce mt to only entries, then make a table with the row key
    # (locus, alleles) and one <sample>.GT column for each sample
    entries_ht = mt.select_cols().select_rows().make_table()
    entries_ht.export(csv_path, delimiter=',')

    frame = pd.read_csv(csv_path, dtype="str")
    frame.fillna(value='0/0', inplace=True)
    frame = frame.replace(gt_codes).drop(labels=['locus', 'alleles'], axis=1)
    # variants x samples -> samples x variants
    frame = frame.applymap(int).transpose()
    return np.array(frame.values).astype(int)

In [None]:
def get_features_labels(mt, top_snps_ht, prnt = "get features and labels"):
    """Wrapper for get_labels and get_features restricted to the SNPs in ``top_snps_ht``.

    Prints ``prnt`` first and the resulting array sizes at the end.

    Returns:
        (labels, features)
    """
    print(prnt)
    selected_mt = replace_empty_gt(mt.semi_join_rows(top_snps_ht))
    y = get_labels(selected_mt)
    x = get_features(selected_mt)
    print("features.shape: " + str(x.shape))
    print("len(labels): " + str(len(y)))
    return y, x

In [None]:
def pred_acc_auc(rf, predictions, test_labels, test_features, method='GWAS'):
    """Output accuracy, share of samples predicted as case, ROC AUC and the ROC curve.

    Parameters:
        rf: fitted classifier providing ``predict_proba``.
        predictions: class predictions of ``rf`` for ``test_features``.
        test_labels: true labels (0/1 or boolean), same length as
            ``predictions``.
        test_features: feature matrix the predictions were made on.
        method: label appended to the plot title.
    """
    # Accuracy = (TP + TN)/(TP + TN + FP + FN)
    # count mismatches with != instead of subtracting the arrays: numpy does
    # not support `-` on boolean arrays, and for 0/1 labels the result is the same
    false_pred = sum(1 for label, pred in zip(test_labels, predictions) if label != pred)
    true_pred = len(predictions)-false_pred
    acc = true_pred/(true_pred+false_pred)
    print('Accuracy:', round(acc*100,2), '%')
    # Percentage predicted positive
    print("Predicted as case: ", round(100*sum(predictions)/len(predictions),2), '%')
    # auc, based on the predicted probability of the second class
    predictions_prob = rf.predict_proba(test_features)
    roc_auc = roc_auc_score(test_labels, predictions_prob[:,1])
    print('ROC AUC score: ' + str(round(100*roc_auc,2)))
    # ROC curve
    fpr, tpr, thresholds = roc_curve(y_true = test_labels, y_score = predictions_prob[:,1], pos_label=1)
    roc_auc= auc(fpr, tpr)
    plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    # diagonal = performance of a random classifier
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic ' + method)
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# Perform GWAS with logistic regression and VariantSpark random forest (n_trees trees) on a training split
# and get the prediction accuracy of new sklearn random forests trained on the top_n most
# significant (GWAS) / most important (VS) SNPs.
# adjust p for training/test set size
def prediction_accuracy(mt, top_n=200, p=0.5, seed=59172, n_trees=1000):
    """Train and evaluate sklearn random forests on the top SNPs from GWAS and from VS.

    Parameters:
        mt: matrix table with column field ``is_case``.
        top_n: number of top-ranked SNPs (per method) used as features.
        p: probability with which each sample is put into the training set.
        seed: seed for the training/test sampling.
        n_trees: forwarded to ``vs_rf_after_gwas`` only; the sklearn forests
            below always use ``n_estimators = 1000``.

    Prints sample counts, accuracies and ROC AUC, and draws ROC curves,
    probability plots and one tree of the VS forest. Returns nothing.
    """
    # step 1: separate cases and controls
    cases_mt = mt.filter_cols(mt.is_case)
    controls_mt = mt.filter_cols(~mt.is_case)
 
    # step 2: downsample with sample_cols(p) to form training set
    training_cases_mt = cases_mt.sample_cols(p=p, seed=seed) # keep each column with probability p
    training_controls_mt = controls_mt.sample_cols(p=p, seed=seed)
    
    # step 3: exclude the training samples from cases/controls to form the test set
    test_cases_mt = cases_mt.anti_join_cols(training_cases_mt.cols())
    test_controls_mt = controls_mt.anti_join_cols(training_controls_mt.cols())
    
    # step 4: perform joins to merge cases and controls of training and test set
    training_mt = training_cases_mt.union_cols(training_controls_mt)
    test_mt = test_cases_mt.union_cols(test_controls_mt)
    
    # check number of cases and controls in training and test set
    print("Number of cases in training_mt: " + str(training_mt.filter_cols(training_mt.is_case).count_cols()))
    print("Number of controls in training_mt: " + str(training_mt.filter_cols(~training_mt.is_case).count_cols()))
    print("Number of cases in test_mt: " + str(test_mt.filter_cols(test_mt.is_case).count_cols()))
    print("Number of controls in test_mt: " + str(test_mt.filter_cols(~test_mt.is_case).count_cols()))

    
    # step 5: do GWAS and VS on the training set only
    gwas_ht, training_mt = gwas_log_regression_1c(training_mt)
    gwas_with_imp, training_mt = vs_rf_after_gwas(training_mt, gwas_ht, n_trees=n_trees) # call rate == 1 filter

    # step 6: get top n SNPs to train the random forest with 
    gwas_top = get_top(gwas_ht, method="GWAS", top_n=top_n)
    vs_top = get_top(gwas_with_imp, method="VS", top_n = top_n)
    
    # step 6.1: how different are the top_n (lowest p, highest imp) SNPs
    print("Out of top_n="+str(top_n)+" SNPs, "+str(gwas_top.semi_join(vs_top).count()) + " are identical.\nNote: Discrepancy could be bc of call_rate==1 filter for VS.")
    training_mt_top_gwas = training_mt.semi_join_rows(gwas_top)
    test_mt_top_gwas = test_mt.semi_join_rows(gwas_top)
    training_mt_top_vs = training_mt.semi_join_rows(vs_top)
    test_mt_top_vs = test_mt.semi_join_rows(vs_top)
    
    # step 6.2 (disabled): check noise prediction (MAF/Call rate/HWE of the top SNPs predicted by VS in cases vs controls)
    #noise_pred(training_mt, training_mt_top_vs)
    
    # step 7: get features and labels for these SNPs in training and test dataset
    # (get_features_labels applies the same semi_join again to the already restricted tables)
    training_labels_gwas, training_features_gwas = get_features_labels(training_mt_top_gwas, gwas_top, prnt="train gwas")
    print("Number of pos. labels in training gwas: " + str(sum(training_labels_gwas)))
    training_labels_vs, training_features_vs = get_features_labels(training_mt_top_vs, vs_top, prnt="train vs")
    print("Number of pos. labels in training vs: " + str(sum(training_labels_vs)))
    test_labels_gwas, test_features_gwas = get_features_labels(test_mt_top_gwas, gwas_top, prnt="test gwas")
    print("Number of pos. labels in test gwas: " + str(sum(test_labels_gwas)))
    test_labels_vs, test_features_vs = get_features_labels(test_mt_top_vs, vs_top, prnt="test vs")
    print("Number of pos. labels in test vs: " + str(sum(test_labels_vs)))
    
    # step 8: train the random forests with top snps from GWAS and VS
    # (identical hyperparameters for both forests; n_estimators is fixed and independent of n_trees)
    rf_gwas = RandomForestClassifier(n_estimators = 1000, random_state = 472, max_features = "sqrt", max_depth=60, min_samples_split=10)
    rf_vs = RandomForestClassifier(n_estimators = 1000, random_state = 472, max_features = "sqrt", max_depth=60, min_samples_split=10)
    rf_gwas.fit(training_features_gwas, training_labels_gwas);
    rf_vs.fit(training_features_vs, training_labels_vs);
    
    # step 9: predict the test labels (and the training labels, to compare training with test accuracy)
    gwas_predictions = rf_gwas.predict(test_features_gwas)
    gwas_self_predictions = rf_gwas.predict(training_features_gwas)
    vs_predictions = rf_vs.predict(test_features_vs)
    vs_self_predictions = rf_vs.predict(training_features_vs)
    
    # step 10: accuracy, ROC AUC
    print("GWAS Accuracy - On training data")
    pred_acc_auc(rf=rf_gwas, predictions=gwas_self_predictions, test_labels=training_labels_gwas, test_features=training_features_gwas, method='GWAS')
    print("GWAS Accuracy - On test data")
    pred_acc_auc(rf=rf_gwas, predictions=gwas_predictions, test_labels=test_labels_gwas, test_features=test_features_gwas, method='GWAS')
    #pred_acc_auc(rf_gwas, gwas_predictions, test_labels_gwas, test_features_gwas, method='GWAS')
    print("VS Accuracy - On training data")
    pred_acc_auc(rf=rf_vs, predictions=vs_self_predictions, test_labels= training_labels_vs, test_features=training_features_vs, method='VS')
    print("VS Accuracy - On test data")
    pred_acc_auc(rf=rf_vs, predictions=vs_predictions, test_labels= test_labels_vs, test_features=test_features_vs, method='VS')
    
    # step 11: look at probabilities 
    print("Plot proba on GWAS")
    plot_proba(rf_gwas, test_features_gwas, training_features_gwas, test_labels_gwas, training_labels_gwas)
    print("Plot proba on VS")
    plot_proba(rf_vs, test_features_vs, training_features_vs, test_labels_vs, training_labels_vs)
    
    # step 12: plot a tree
    # Extract a single tree (index 10) from the VS forest
    estimator = rf_vs.estimators_[10]
    fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (25,4), dpi=300)
    
    tree.plot_tree(estimator,
    #           feature_names = fn, 
    #           class_names=cn,
               filled = True,
                  fontsize=5);

## Grid search for hyperparameter settings
To optimize hyperparameters of the prediction random forest, set parameters in the param_grid in my_grid_search   
To test the improvement the best_param settings bring in comparison to default parameters (or previous param settings) use evaluate_new_params    
Uses [Scikit Learn RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)

In [None]:
def my_grid_search(mt, features, labels):
    """Search random forest hyperparameters with RandomizedSearchCV.

    Candidate settings are drawn from ``param_grid`` and scored with 2-fold
    cross validation. All cross validation results are written to
    ``rf_cv_results.csv`` in the working directory.

    Parameters:
        mt: unused, kept for backward compatibility of the signature.
        features: feature matrix.
        labels: labels matching the rows of ``features``.

    Returns:
        dict with the best parameter setting found.
    """
    param_grid = {
        'bootstrap': [False],
        'max_depth': [10, 40, 90],
        'max_features': [0.3, 0.5, 0.7],
        'min_samples_split': [2, 5, 10],
        'n_estimators': [1000, 2500],
    }
    rf = RandomForestClassifier()
    # NOTE(review): the grid holds 54 combinations, fewer than n_iter = 100;
    # sklearn is expected to cap the search at the grid size - confirm.
    rf_random = RandomizedSearchCV(estimator = rf, param_distributions = param_grid, n_iter = 100, cv = 2, verbose=2, random_state=42, n_jobs = -1)
    rf_random.fit(features, labels)
    print("best param")
    best_param = rf_random.best_params_
    print(best_param)
    print("best score: " + str(round(rf_random.best_score_, 4)))
    cv_res = pd.DataFrame.from_dict(rf_random.cv_results_)
    cv_res.to_csv("rf_cv_results.csv", index=False)
    return best_param

In [None]:
def evaluate_new_params(best_params, features, labels):
    """Compare a forest built with ``best_params`` against a small baseline forest.

    The data is split 67/33 into training and test part; both models are
    fitted on the training part and scored with ``evaluate`` on the test part.
    The relative improvement over the baseline is printed.
    """
    train_x, test_x, train_y, test_y = train_test_split(features, labels, test_size=0.33, random_state=42)

    # baseline: 10 trees, otherwise default settings
    baseline = RandomForestClassifier(n_estimators = 10, random_state = 42)
    baseline.fit(train_x, train_y)
    print("base model")
    base_accuracy = evaluate(baseline, test_x, test_y)

    # candidate: forest configured with the tuned parameters
    tuned = RandomForestClassifier(**best_params)
    tuned.fit(train_x, train_y)
    print("best param")
    random_accuracy = evaluate(tuned, test_x, test_y)

    relative_gain = 100 * (random_accuracy - base_accuracy) / base_accuracy
    print('Improvement of {:0.2f}%.'.format(relative_gain))

## Get list of genes 
e.g. for [GO analysis](http://geneontology.org/)   
Circumvents need for GCP Hail annotation database   

In [None]:
# WARNING: this function is extremely slow! use low top_n
def get_gene_list(gwas_ht, top_n = 50):
    """Look up the gene symbol of the ``top_n`` top-ranked GWAS variants on myvariant.info.

    Parameters:
        gwas_ht: table with fields ``locus``, ``alleles`` and ``gwas_rank``.
        top_n: variants with ``gwas_rank < top_n`` are annotated.

    Returns:
        list with one entry per variant: the value of dbsnp.gene.symbol, or ""
        when the variant is unknown or has no single gene annotation.
    """
    gwas_ht_top = gwas_ht.filter(gwas_ht.gwas_rank < top_n)
    mv = myvariant.MyVariantInfo()
    gwas_df_top = gwas_ht_top.to_pandas()
    gwas_df_top.rename(columns={"locus.contig": "CHR", "locus.position": "BP"}, inplace=True)
    gwas_df_top["A1"] = [alleles[0] for alleles in gwas_df_top["alleles"]]
    gwas_df_top["A2"] = [alleles[1] for alleles in gwas_df_top["alleles"]]
    gwas_df_top.drop("alleles", axis=1, inplace= True)
    # HGVS id of a substitution: <chr>:g.<pos><ref>><alt>
    gwas_df_top["hgvs"] = [
        chrom + ":g." + str(pos) + ref + ">" + alt
        for chrom, pos, ref, alt in zip(
            gwas_df_top["CHR"], gwas_df_top["BP"], gwas_df_top["A1"], gwas_df_top["A2"]
        )
    ]
    start = time.time()
    genes = []
    # note: this is very slow, but the lookup with a list throws an internal server error
    # mv.getvariants(gwas_df_top["hgvs"], assembly="hg38",fields="dbsnp.rsid", as_dataframe=True) "error ID required"
    for hgvs_id in gwas_df_top['hgvs']:
        q = mv.getvariant(hgvs_id, assembly="hg38", fields="dbsnp.gene")
        if q is None:
            genes.append("")
            continue
        try:
            genes.append(q["dbsnp"]["gene"]["symbol"])
        except (KeyError, TypeError):
            # annotation missing (KeyError) or not a single mapping, e.g. a
            # list of genes (TypeError)
            genes.append("")
    gwas_df_top["gene"] = genes
    end = time.time()
    print("For " + str(top_n) + " variants annotation gene took " + str(round(end-start)) + " seconds")
    return genes

-----------

<a id='pipeline'></a>
# Pipeline


## Load data

In [None]:
# Load the data: convert from PLINK if only PLINK files are available,
# otherwise read the existing Hail matrix table at hl_path.
if only_plink:
    mt = plink_to_hl(plink_paths, hl_path)
else:
    mt = hl.read_matrix_table(hl_path)

## Quality control
Set parameters of QC functions here

In [None]:
# Variant QC 1
# generally recommended to filter variants before filtering samples
# (thresholds are passed to helpers defined in the functions section)
print("variant qc I")
mt = hl.variant_qc(mt)
mt = single_np(mt)
mt = call_rate_filter_variants(mt, thresh=0.95)
mt = maf_filter(mt, thresh=0.01)
mt = ld_prune(mt, r2=0.2, window_size=500000) 

# Sample QC
print("sample qc")
mt = hl.sample_qc(mt)
mt = sex_discrepancy(mt)
mt = call_rate_filter_samples(mt, thresh=0.95)
mt = ibd_prune(mt, removal="max_ind_set", set_min=0.125)
mt = het_filter(mt)

# Variant QC 2
print("variant qc II")
mt = autosomes(mt) # filter to only autosomes after sex_discrepancy filter (which uses the gonosomes, i.e. sex chromosomes)
mt =  hwe_filter(mt, case_control="control_sep") # HWE filter after LD pruning

In [None]:
## Write out QCed matrix table
mt.write(hl_path+name+"_qc.mt")
# To resume from this checkpoint instead of re-running QC, uncomment:
# mt = hl.read_matrix_table(hl_path+name+"_qc.mt")

## Principal component analysis
Gives insight into population structure (e.g. ancestry)   
__Note__: GWAS uses first 5 PCs, no option to alter this at a high level in this script yet

In [None]:
if PCA_visual:
    # check out the first few principal components of PCA on this data
    # this is not necessary, but useful to get an impression of the data
    mt = choose_pcs(mt)
else:
    # only annotate PCs 
    # k must be >=5 for use with current GWAS logistic regression implementation
    eigenvalues, scores, loadings = hl.hwe_normalized_pca(mt.GT, k=5)
    # attach each sample's PC scores as column field `scores` (joined on sample id s)
    mt = mt.annotate_cols(scores = scores[mt.s])

### Save

In [None]:
# save mt with PCs
mt.write(hl_path+name+"_qc_pca.mt")
# To resume from this checkpoint, uncomment:
# mt = hl.read_matrix_table(hl_path+name+"_qc_pca.mt")

## GWAS and VariantSpark
Set parameters for VS here

### GWAS: logistic regression

In [None]:
# Run logistic regression; returns the results table gwas_ht and the updated mt
gwas_ht, mt = gwas_log_regression_1c(mt)
# Set phenotype/name of dataset as title to produce manhattan plots with caption: "Manhattan plot of " +title+", logistic regression"
qq_manh_plot(gwas_ht, title="Asthma", sign_threshold=5e-8)

### Variant Spark: random forest

In [None]:
# Run the VariantSpark random forest; pass the GWAS result table `gwas_ht`
# created in the logistic regression cell (`gwas` was an undefined name)
gwas_with_imp, mt = vs_rf_after_gwas(mt, gwas_ht, n_trees=5000, mTry=0.1, seed=4957)
# Set phenotype/name of dataset as title to produce manhattan plots with caption: "Manhattan plot of " +title+", random forest"
vs_manhattan(gwas_with_imp, title="Asthma")

### Save

In [None]:
# save hail table gwas_ht from GWAS and gwas_with_imp from VS
# NOTE(review): `pheno` and `ntrees` are not defined in the pipeline cells shown here -
# confirm they are set in the Parameters section (the VS cell above uses n_trees=5000)
gwas_ht.write(hl_path+"results/"+pheno+"_gwas.ht")
# the VS result table is named gwas_with_imp (gwas_with_imp_ht was an undefined name)
gwas_with_imp.write(hl_path+"results/"+pheno+"_gwas_"+str(ntrees)+"trees.ht")
# save mt with annotations log_reg_pval, log_reg_log10p, log_reg_rank from GWAS and  vs_imp, vs_rank from VS
# note that the mt now only contains variants with call_rate==1. This means log_reg_rank is probably incorrect
mt.write(hl_path+name+"_qc_pca_gwas_vs.mt")

# Evaluate results

### Comparison scatter plot

In [None]:
# Scatter plot comparing the per-SNP results of both methods
# on x axis: -log10(logistic regression P-value)
# on y axis: sqrt(VariantSpark Importance score)
# title="Comparing GWAS with Variant Spark ("+ title + ")"
comparison_scatter(gwas_with_imp, title="Asthma")

### Benjamini Hochberg threshold

In [None]:
# prints variants that pass the Benjamini-Hochberg (BH) threshold at the given false discovery rate (fdr)
# returns pandas table with all variants and their BH adjusted p-values
df = bh_thresh(mt, gwas_ht, fdr=0.05)

### Check for differences in cases and controls of top significant SNPs predicted by VariantSpark
Plots MAF, gives stats about call rate and HWE

In [None]:
# number of top-ranked VS SNPs to inspect
top_n = 500
# keep the SNPs whose VariantSpark rank is below top_n
top_vs_snps = gwas_with_imp.filter(gwas_with_imp.vs_rank < top_n, keep=True)
noise_pred(mt, top_vs_snps, method="VS")

### GWAS catalog overlap
Optional, please provide path to file from GWAS catalog in "Parameters"   
(On [GWAS catalog](https://www.ebi.ac.uk/gwas/) search for phenotype (trait) and click Download Catalog data)   

In [None]:
# returns tables with top SNPs that overlap with a GWAS catalog SNP
# size of window around GWAS catalog SNPs to overlap with
delta = 50000
if gwas_catalog_overlap:
    gwas_catalog_to_intervals(gwas_catalog_filepath, delta = delta)
    # (a stray ':' after the closing parenthesis made this cell a syntax error)
    top_gwas_snps_region, top_vs_snps_region = check_for_overlaps_in_top_snps(
        gwas_with_imp,
        gwas_ht,
        top_n=500,
        ref_genome=ref_genome,
        interval_path=hl_path+"gwas_catalog_intervals_"+str(delta)+"_windows.tsv",
    )

## Prediction accuracy

### Polygenic Risk Score

In [None]:
# annotate each variant with its logistic regression beta and its VS importance score,
# then compute the per-sample PRS from the betas
mt = mt.annotate_rows(beta = gwas_ht[mt.locus, mt.alleles].beta)
mt = mt.annotate_rows(importance = gwas_with_imp[mt.locus, mt.alleles].importance)
mt = prs(mt, method="gwas")

In [None]:
# Scatter plot of PRS (y axis) against case status (x axis; cases=1, controls=0)
# Histograms would make more sense, but cause hail timeout exception
p = hl.plot.scatter(mt.is_case, mt.prs, xlabel='is_case', ylabel='PRS GWAS beta')
show(p)

In [None]:
# Report prediction accuracy of PRS with test and training data
# p=0.5 means an approximately 50/50 split, higher p means more samples in training set
pred_accuracy_prs(mt, p=0.5)

### Random forest trained with the most significant SNPs from logistic regression and VS

In [None]:
# top_n = number of SNPs predicted to be most significant to use for training
# n_trees is forwarded to the VariantSpark run inside prediction_accuracy;
# the sklearn forests trained there use a fixed n_estimators = 1000
prediction_accuracy(mt, top_n=200, p=0.5, n_trees=1000)

## Genes of top SNPs

In [None]:
# returns list of genes from myvariant.info annotation (one entry per variant, "" if none found)
# WARNING: SLOW!
l = get_gene_list(gwas_ht, top_n = 50)
# unique genes (displayed as the cell output)
list(set(l)) 