In [None]:
import hail as hl
from hail.plot import show
from pprint import pprint

# Start the Hail session on a local Spark master, with GRCh38 as the
# default reference genome.
hl.init(
    spark_conf={
        'spark.master': 'local[20,20]',
        'spark.driver.memory': '40g',
        'spark.task.maxFailures': '20',
    },
    min_block_size=128,
    default_reference="GRCh38",
)

In [None]:
import os

# Root of the project tree. '~' is resolved to the home directory once, here,
# because glob.glob() (used later on data_dir to list the VCF chunks) performs
# no tilde expansion. Every path below is therefore absolute.
project_dir = os.path.expanduser('~/WGS/BRAVA/')

# Working folders. Each keeps its trailing '/' because later cells build file
# names by plain string concatenation.
data_dir = project_dir + 'raw_data_exome_f3/'  # raw VCFs
mt_dir = project_dir + 'mt/'  # MTs converted from VCF
checkpoint_dir = project_dir + 'checkpoint/'  # intermediate files (filtered by GT and variant)
filtered_dir = project_dir + 'filtered/'  # final filtered QC MT

# Loading resources
resources_dir = project_dir + 'Exome_QC_Hail/resources/'

# Interval list for the target intervals.
target_intervals = hl.import_locus_intervals(
    resources_dir + 'twist_comprehensive_exome_hg38_one_based.txt',
    reference_genome='GRCh38')

# Interval list for the padded target intervals.
padded_target_intervals = hl.import_locus_intervals(
    resources_dir + 'twist_comprehensive_exome_hg38__one_based_50bp_padded.txt',
    reference_genome='GRCh38')

# Interval list for the low-complexity regions (LCRs).
# NOTE(review): this file has a .bed extension but is read with
# import_locus_intervals, like the "one_based" files above. If it is a true
# 0-based BED, hl.import_bed would be the matching reader - confirm.
LCR_intervals = hl.import_locus_intervals(
    resources_dir + 'LCR-hs38.bed',
    reference_genome='GRCh38')


In [None]:
# Chromosome(s) to run, as strings without the 'chr' prefix.
chrom = '6 7 8 9 10 11 12 13 18 X'.split()

In [None]:
# Convert the VCF chunks of each chromosome into one MatrixTable per chromosome.
# import_vcf options used below:
#   force_bgz=True                - files are bgzipped but the extension is .gz
#   skip_invalid_loci=True        - some multiallelic sites are malformed
#   array_elements_required=False - import variants with missing data

import glob, os, re

# Index files must not be handed to import_vcf. Match on the file extension
# only, so a "csi" elsewhere in the path cannot drop real data files.
INDEX_FILE_RE = re.compile(r"\.(csi|tbi)$")

for chrom_name in chrom:  # not named `chr`, to avoid shadowing the builtin

    # output
    MT = mt_dir + 'chr' + chrom_name + '.mt'  # raw MT imported from VCFs

    # glob performs no tilde expansion, so resolve a leading '~' explicitly.
    pattern = os.path.expanduser(data_dir) + '*_chr' + chrom_name + '_*'

    # Sorted chunk paths for this chromosome, without index files.
    chr_files = sorted(
        path for path in glob.glob(pattern) if not INDEX_FILE_RE.search(path)
    )

    # Fail early with the offending pattern rather than inside import_vcf.
    if not chr_files:
        raise FileNotFoundError(
            'No VCF chunks found for chromosome ' + chrom_name + ' matching ' + pattern
        )

    hl.import_vcf(
        chr_files,
        reference_genome='GRCh38',
        force_bgz=True,
        find_replace=('nul', '.'),
        skip_invalid_loci=True,
        array_elements_required=False,
    ).write(MT, overwrite=True)



In [None]:
# Display the chunk list left in `chr_files` by the import loop, i.e. the
# files of the last chromosome it processed.
chr_files

In [None]:
# Genotype-level QC, run once per chromosome:
#   1. count variants and samples in the raw MT
#   2. flag variants outside the padded target intervals
#   3. drop sites with more than MAX_ALLELES alleles
#   4. split the remaining multiallelic sites
#   5. remove low-quality genotypes, checkpoint the result and save a
#      hard-call (GT only) copy

MAX_ALLELES = 6            # sites with more alleles than this are removed
HARDCALL_PARTITIONS = 512  # number of partitions of the hard-call MT

for chrom_name in chrom:  # not named `chr`, to avoid shadowing the builtin

    # inputs
    MT = mt_dir + 'chr' + chrom_name + '.mt'  # raw MT imported from VCFs

    # outputs
    MT_GT_MULTI = checkpoint_dir + 'chr' + chrom_name + '.GT_multi.mt'  # MT after filtering by genotype and splitting MA sites
    MT_GT_MULTI_HARDCALLS = checkpoint_dir + 'chr' + chrom_name + '.GT_multi.hardcalls.mt'  # genotypes only

    mt = hl.read_matrix_table(MT)

    # 1. Initial count: (n variants, n samples).
    n = mt.count()
    n_variants = n[0]
    print('Chromosome ' + chrom_name + ": " + str(n))

    # 2. Annotate variants that are not in the padded target intervals.
    mt = mt.annotate_rows(not_in_padded_target_intervals = ~hl.is_defined(padded_target_intervals[mt.locus]))

    # Collect both row tallies in a single pass over the data. Annotating rows
    # does not change their number, so the total from step 1 is reused instead
    # of being recounted.
    row_counts = mt.aggregate_rows(hl.struct(
        not_in_padded = hl.agg.count_where(mt.not_in_padded_target_intervals),
        within_allele_limit = hl.agg.count_where(mt.alleles.length() <= MAX_ALLELES),
    ))

    print('')
    print('Chromosome ' + chrom_name + ': n variants not in padded target intervals:')
    print(row_counts.not_in_padded)

    print('')
    print('Chromosome ' + chrom_name + ':% of variants not in padded target intervals marked:')
    # Guard against an empty MT, which would raise ZeroDivisionError.
    print((row_counts.not_in_padded / n_variants * 100) if n_variants else float('nan'))

    # 3. Removing multiallelic sites with more than MAX_ALLELES alleles.
    print('')
    print('Chromosome ' + chrom_name + ': All variants: ' + str(n_variants))

    mt = mt.filter_rows(mt.alleles.length() <= MAX_ALLELES)

    print('')
    print('Chromosome ' + chrom_name + ': n variants not more than ' + str(MAX_ALLELES)
          + ' alleles:' + str(row_counts.within_allele_limit))

    # 4. Splitting multiallelic sites.
    mt = hl.split_multi_hts(mt)

    # 5. Filtering by genotype quality, repartitioning and saving.
    # Each expression below is True for a genotype that FAILS QC.
    low_depth = mt.DP < 10

    # No allele-balance test for hom-ref calls: allele depth is no longer
    # defined for them.
    fail_hom_ref = mt.GT.is_hom_ref() & (
        (mt.GQ < 20) |
        low_depth
    )
    fail_het = mt.GT.is_het() & (
        (((mt.AD[0] + mt.AD[1]) / mt.DP) < 0.8) |
        ((mt.AD[1] / mt.DP) < 0.2) |
        (mt.PL[0] < 20) |
        low_depth
    )
    fail_hom_var = mt.GT.is_hom_var() & (
        ((mt.AD[1] / mt.DP) < 0.8) |
        (mt.PL[0] < 20) |
        low_depth
    )

    mt = mt.filter_entries(
        hl.is_defined(mt.GT) & (fail_hom_ref | fail_het | fail_hom_var),
        keep = False
    )

    # Save the full filtered and split matrix table with genotypes.
    # checkpoint() hands back the table it has just written, so no separate
    # read is needed before deriving the hard calls from it.
    mt = mt.checkpoint(MT_GT_MULTI, overwrite=True)

    # Save just the genotypes (hard calls).
    mt.select_entries(mt.GT).repartition(HARDCALL_PARTITIONS).write(MT_GT_MULTI_HARDCALLS, overwrite=True)

    mt = hl.read_matrix_table(MT_GT_MULTI_HARDCALLS)
    n = mt.count()

    print('Chromosome ' + chrom_name + ' n samples:' + str(n[1]))
    print('Chromosome ' + chrom_name + ' n variants:' + str(n[0]))
    

In [None]:
# Variant-level QC on the hard calls, run once per chromosome:
#   1. load the genotype hard calls
#   2. count variants before filtering
#   3. flag variants that failed VQSR, lie in an LCR, or fall outside the
#      (padded) target intervals, and report how many carry each flag
#   4. export the flags, remove flagged / non-canonical / invariant variants,
#      then save the filtered MT and the list of variants that passed

# Chromosomes 1-22, X and Y. Parsed once; it is the same for every chromosome.
canonical_intervals = [
    hl.parse_locus_interval(x, reference_genome='GRCh38')
    for x in ['chr1:START-chr22:END', 'chrX:START-chrX:END', 'chrY:START-chrY:END']
]

for chrom_name in chrom:  # not named `chr`, to avoid shadowing the builtin

    # inputs
    MT_GT_MULTI_HARDCALLS = checkpoint_dir + 'chr' + chrom_name + '.GT_multi.hardcalls.mt'  # genotypes only

    # outputs
    INITIAL_VARIANT_QC_FILE = './prefilter_metrics/chr' + chrom_name + '_prefilter_metrics.tsv'
    INITIAL_VARIANT_LIST = './prefilter_metrics/chr' + chrom_name + '.keep.variant_list'
    FILTERED_MT = checkpoint_dir + 'chr' + chrom_name + '.GT_multi.variant.filtered.mt'

    # 1. Loading genotype hard calls.
    mt = hl.read_matrix_table(MT_GT_MULTI_HARDCALLS)

    # 2. Counting the number of variants before filtering by variant.
    n = mt.count()
    print('Chromosome ' + chrom_name + " before filtering by variant: " + str(n))

    # 3. Annotating variants with flags: failed VQSR (non-empty, non-missing
    #    `filters`), in an LCR, outside the target / padded target intervals.
    mt = mt.annotate_rows(
        fail_VQSR = (hl.len(mt.filters) != 0) & ~hl.is_missing(mt.filters),
        in_LCR = hl.is_defined(LCR_intervals[mt.locus]),
        not_in_target_intervals = ~hl.is_defined(target_intervals[mt.locus]),
        not_in_padded_target_intervals = ~hl.is_defined(padded_target_intervals[mt.locus]),
    )

    # Number of variants carrying each flag, gathered in one pass over the rows.
    flag_counts = mt.aggregate_rows(hl.struct(
        fail_VQSR = hl.agg.count_where(mt.fail_VQSR),
        in_LCR = hl.agg.count_where(mt.in_LCR),
        not_in_target_intervals = hl.agg.count_where(mt.not_in_target_intervals),
        not_in_padded_target_intervals = hl.agg.count_where(mt.not_in_padded_target_intervals),
    ))

    print('Chromosome ' + chrom_name + ': n variants failing VQSR: ' + str(flag_counts.fail_VQSR))
    print('Chromosome ' + chrom_name + ': n variants in low complexity regions:' + str(flag_counts.in_LCR))
    print('Chromosome ' + chrom_name + ': n variants not in target intervals:' + str(flag_counts.not_in_target_intervals))
    print('Chromosome ' + chrom_name + ': n variants not in padded target intervals:' + str(flag_counts.not_in_padded_target_intervals))

    # 4. Variant filtering.
    # Export the per-variant flags before any variant is dropped.
    mt.rows().select('fail_VQSR', 'in_LCR', 'not_in_padded_target_intervals').export(INITIAL_VARIANT_QC_FILE)

    # Removing failed VQSR, variants in LCR and variants outside of the padded target intervals.
    mt = mt.filter_rows(mt.fail_VQSR | mt.in_LCR | mt.not_in_padded_target_intervals, keep=False)

    # Removing variants not in the 24 canonical chromosomes.
    mt = hl.filter_intervals(mt, canonical_intervals)

    # Removing invariant rows.
    mt = hl.variant_qc(mt, name='qc')
    mt = mt.filter_rows((mt.qc.AF[0] > 0.0) & (mt.qc.AF[0] < 1.0))

    # Save the filtered MT first and continue from the saved copy, so the
    # filters and variant_qc above are evaluated once rather than once per
    # downstream action.
    mt = mt.checkpoint(FILTERED_MT, overwrite=True)

    # Saving the list of variants that passed QC (row keys only).
    mt.rows().select().export(INITIAL_VARIANT_LIST)

    n_variants = mt.count_rows()

    print('Chromosome ' + chrom_name + ': n variants after filter:' + str(n_variants))
    

In [None]:
# Sample QC metrics, computed per chromosome and exported as one TSV each.
# Two metric sets are produced: 'qc_padded_twist' over the variants on the
# keep list, and 'qc_twist' over the subset that also lies inside the target
# intervals.

for contig in chrom:
    # inputs
    MT_GT_MULTI = f'{checkpoint_dir}chr{contig}.GT_multi.mt'
    INITIAL_VARIANT_LIST = f'./prefilter_metrics/chr{contig}.keep.variant_list'

    # outputs
    INITIAL_SAMPLE_QC_FILE = f'./sample_QC/chr{contig}_initial_sample_qc.tsv'

    # Sample QC runs on the full '.GT_multi.mt' table, not on the hard calls.
    mt = hl.read_matrix_table(MT_GT_MULTI)

    # Variants that passed variant QC, keyed like the MT rows so they can be
    # joined against them.
    keep_list = hl.import_table(
        INITIAL_VARIANT_LIST,
        types={
            'locus': hl.tlocus(reference_genome='GRCh38'),
            'alleles': hl.tarray(hl.tstr),
        },
    ).key_by('locus', 'alleles')

    on_keep_list = hl.is_defined(keep_list[mt.row_key])
    mt = mt.filter_rows(on_keep_list)

    n_variants, n_samples = mt.count()
    pprint('n samples:')
    print(n_samples)
    pprint('n variants:')
    print(n_variants)

    # Metrics over every kept variant.
    mt = hl.sample_qc(mt, name='qc_padded_twist')

    # Restrict to the target intervals; the flag is never missing, so
    # negating it keeps exactly the rows where it is False.
    mt = mt.annotate_rows(not_in_target_intervals = ~hl.is_defined(target_intervals[mt.locus]))
    mt = mt.filter_rows(~mt.not_in_target_intervals)

    n_variants, n_samples = mt.count()
    pprint('n samples:')
    print(n_samples)
    pprint('n variants:')
    print(n_variants)

    # Metrics over the kept variants inside the target intervals.
    mt = hl.sample_qc(mt, name='qc_twist')

    sample_metrics = mt.cols().select('qc_padded_twist', 'qc_twist').flatten()
    sample_metrics.export(output=INITIAL_SAMPLE_QC_FILE)

In [None]:
# Display the Spark context that Hail returns for the current session.
hl.spark_context()