- Purpose:
    - annotate the table of variant calls from the GATK results and prepare it for downstream analysis
    - infer strand and adjust sequences based on coding strand
    - add per-sample reference and alt base columns and convert GT calls to binary (0 = ref, 1 = alt)
    - cross reference variant entries to ClinVar
    - calculate %_ref for all samples in all entries
    - calculate %_snp if the target SNP is detected in a sample in a given entry
- Outputs
    1. a compressed tsv of the fully annotated variant table
    2. compressed tsvs with
        - binary genotype counts for each sample
        - strand assignment counts
        - histogram tables for %_ref, %_snp, read depth, QD, and qual
    3. an excel sheet compiling all data from 2
    4. a tsv file of the entry for the editing target

In [None]:
#####################
# import statements #
#####################
import os
import pandas as pd
import numpy as np
import tqdm as tqdm
from Bio.Seq import Seq
import pysam
import math
from functools import reduce

In [None]:
##########################
# User-Defined Variables #
##########################
# - define all variables below with paths to the required files
# - this should be the only cell that requires modification
# - results are written to an 'init-processing' directory created next to tsv_path

# full path to tsv output from GATK analysis (GATK VariantsToTable output; it is read with
# compression='infer', so a compressed file with a recognised extension also works)
tsv_path =''
# full path to gtf for the reference genome (opened with pysam.TabixFile, so it needs to be
# bgzip-compressed with a tabix index alongside it)
gtfgz_path = ''
# full path to clinvar vcf file (opened with pysam.VariantFile and queried by region, so it needs an index)
# NOTE(review): contig naming (e.g. 'chr15' vs '15') should be consistent with the chrom column of the tsv - confirm
clinvar_vcf_path=''

# these must match the sample identifiers in the column titles of the variants tsv
# (columns are expected as '<sample>.GT', '<sample>.AD', '<sample>.DP' and '<sample>.GQ')
samples = [
    'R9761',
    'R9762',
    'R9763',
    'R9764',
    'R9765',
    'R9766',
]

# target SNPs as a list of strings, each the reference base followed by the alt base as read on the
# assigned coding strand; for example we are interested in C-to-T (C-to-U) so ['CT']
target_snps = ['CT']
# (chrom, pos) in the format of the genomic reference
# NOTE(review): pos is presumably 1-based like the tsv POS column - confirm
target_edit = ('chr15', 64162933)

In [None]:
# automatically defined variables:
# every result of this notebook goes into an 'init-processing' sub-directory that sits in the
# same directory as the input tsv; it is created on first run and re-used afterwards
work_path = os.path.dirname(tsv_path)
out_path = os.path.join(work_path, 'init-processing')
os.makedirs(out_path, exist_ok=True)

In [None]:
# import dataset

def import_VariantsToTable_tsv(path, samples):
    ###################################################################################################
    # Purpose: load the tab-separated table written by GATK VariantsToTable with fixed column dtypes  #
    #          and lower-cased site-level headers                                                     #
    # Inputs: 1. path - string, location of the GATK output tsv (compression inferred from the name)  #
    #         2. samples - list of sample names; '<sample>.GT/.AD' are read as text and               #
    #            '<sample>.DP/.GQ' as floats                                                          #
    # Output: pandas dataframe; headers without a '.' (CHROM, POS, ...) are lower-cased while the     #
    #         per-sample '<sample>.<FIELD>' headers keep their original spelling                      #
    ###################################################################################################

    column_types = {'CHROM': 'str', 'POS': 'int', 'REF': 'str', 'ALT': 'str', 'QUAL': 'float', 'FILTER': 'str'}
    for sample_name in samples:
        column_types.update({
            f'{sample_name}.GT': 'str',
            f'{sample_name}.AD': 'str',
            f'{sample_name}.DP': 'float',
            f'{sample_name}.GQ': 'float',
        })

    table = pd.read_csv(path, sep='\t', dtype=column_types, compression='infer')
    # only site-level columns are renamed; per-sample columns contain a '.' separator
    lowercase_names = {name: name.lower() for name in table.columns if '.' not in name}
    return table.rename(columns=lowercase_names)


var_df = import_VariantsToTable_tsv(tsv_path, samples=samples)

# uncalled genotypes come in as NaN; replace them with the VCF no-call string './.'
gt_cols = [f'{sample}.GT' for sample in samples]
for col in gt_cols:
    var_df[col] = var_df[col].apply(lambda gt_value: './.' if pd.isna(gt_value) else gt_value)

# display dataframe info
total_count = len(var_df)
print(f'Total Entries:\t{total_count}\n')
display(var_df.head())

# for every column, count how many values have a python type whose name contains each label
# (numpy scalars are included, e.g. "<class 'numpy.float64'>" is counted under 'float')
type_labels = ['object', 'str', 'float', 'int']
dtypes_df = pd.DataFrame({'dtype': type_labels})
for column in var_df.columns:
    type_tally = var_df[column].apply(lambda x: str(type(x))).value_counts()
    per_label_counts = [
        sum((n for type_name, n in type_tally.items() if label in type_name), 0)
        for label in type_labels
    ]
    column_summary = pd.DataFrame({f'{column}:{var_df[column].dtype}': per_label_counts})
    dtypes_df = pd.concat([dtypes_df, column_summary], axis=1)
display(dtypes_df)

In [None]:
# infer entry strand and correct sequences to reflect coding strand

def _parse_gtf_attributes(attr_field):
    # Parse a GTF attribute column (e.g. 'gene_id "G1"; gene_name "ABC"; tag "basic";') into a
    # {key: value} dict. Only the first value of a repeated key is kept; quotes are removed.
    attrs = {}
    for item in attr_field.split(';'):
        key, _, value = item.strip().partition(' ')
        if key:
            attrs.setdefault(key, value.strip().strip('"').strip())
    return attrs


def _join_gtf_attr(attr_dicts, key):
    # Comma-join one attribute across several features. A missing attribute contributes '' so the
    # joined lists stay index-matched across gene/transcript/exon columns.
    return ','.join(attrs.get(key, '') for attrs in attr_dicts)


def strand_id(var_row, gtfgz, feat_priority):
    ###################################################################################################
    # Purpose: take variant information from a row of the variant dataframe and obtain relevant       #
    #          feature information from the reference gtf                                             #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads 'chrom', #
    #            'pos' (1-based, as in the GATK/VCF output) and 'warnings'                            #
    #         2. gtfgz - gtfgz file imported with pysam.TabixFile()                                   #
    #         3. feat_priority - a list of feature types in priority order for assigning strands (ie. #
    #                            if feat_priority=['exon','transcript'] and a variant location has    #
    #                            transcript features in both strands but exon features in only one,   #
    #                            the exon strand will be assigned)                                    #
    # Output: (strand, gene_name, gene_id, transcript_id, exon_id, warnings). strand is '+', '-' or   #
    #         'ambiguous'; id/name values are comma-separated and index-matched when several features #
    #         overlap; warnings is a comma-separated string. All five annotation values are '' when   #
    #         nothing is found or the contig is unknown to the index.                                 #
    ###################################################################################################

    # the warning notes will be kept as a list then combined into a string at the end
    warning_str = [var_row['warnings']] if (var_row['warnings'] != '') else []
    chrom = var_row['chrom']
    pos = int(var_row['pos'])

    try:
        # pysam's TabixFile.fetch uses 0-based, half-open coordinates; the 1-based variant position
        # `pos` is therefore the interval [pos - 1, pos). All rows are fetched once and then filtered
        # per feature type below.
        overlapping_rows = [row.strip().split('\t') for row in gtfgz.fetch(chrom, pos - 1, pos)]
    except ValueError:
        # pysam raises ValueError when the contig is not present in the tabix index
        warning_str.append('no features')
        return '', '', '', '', '', ','.join(warning_str)

    # iterate through feature types in priority order; the first type with any hit decides the
    # result. Hits on only one strand assign that strand, hits on both strands give 'ambiguous'.
    for target_feature in feat_priority:
        target_rows = [fields for fields in overlapping_rows if fields[2].strip() == target_feature]
        if not target_rows:
            warning_str.append(f'no {target_feature}s')
            continue

        if len(target_rows) > 1:
            warning_str.append(f'overlapping {target_feature}s')

        strands = {fields[6] for fields in target_rows}
        strand = strands.pop() if len(strands) == 1 else 'ambiguous'

        attr_dicts = [_parse_gtf_attributes(fields[8]) for fields in target_rows]
        exon_id = _join_gtf_attr(attr_dicts, 'exon_id') if target_feature == 'exon' else ''
        transcript_id = _join_gtf_attr(attr_dicts, 'transcript_id') if target_feature in ('exon', 'transcript') else ''
        gene_name = _join_gtf_attr(attr_dicts, 'gene_name')
        gene_id = _join_gtf_attr(attr_dicts, 'gene_id')

        return strand, gene_name, gene_id, transcript_id, exon_id, ','.join(warning_str)

    warning_str.append('feature search failed')
    return '', '', '', '', '', ','.join(warning_str)


def _rc_allele_string(allele_str):
    # Reverse-complement every allele in a ref/alt string. Alleles may be grouped with '|' and listed
    # inside a group with ',' (e.g. 'A,G|T'); both delimiters and the allele order are preserved.
    rc_groups = []
    for group in allele_str.split('|'):
        rc_alleles = [str(Seq(allele.strip()).reverse_complement()).strip() for allele in group.strip().split(',')]
        rc_groups.append(','.join(rc_alleles))
    return '|'.join(rc_groups)


def strand_sequences(var_row, samples):
    ###################################################################################################
    # Purpose: adjust variant and GT sequences if the coding strand is -                              #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads 'ref',   #
    #            'alt', 'strand' and '<sample>.GT' for every sample                                   #
    #         2. samples - a list of sample names                                                     #
    # Output: (ref, alt, *GTs) in sample order. For strand '-' every allele is reverse-complemented   #
    #         (delimiters kept); for any other strand value the inputs are returned unchanged.        #
    #         NaN genotypes on the - strand become './.'.                                             #
    ###################################################################################################

    gt_cols = [f'{sample}.GT' for sample in samples]

    # '+', 'ambiguous' or unassigned strands keep the reference-genome orientation
    if var_row['strand'] != '-':
        return var_row['ref'], var_row['alt'], *[var_row[col] for col in gt_cols]

    gt_rcs = []
    for col in gt_cols:
        gt = var_row[col]
        if pd.isna(gt):
            gt_rcs.append('./.')
            continue

        if '/' in gt:
            gt_delim = '/'
        elif '|' in gt:
            gt_delim = '|'
        else:
            # a single allele without a delimiter (e.g. a haploid call) is converted on its own;
            # previously this fell through with an unbound or stale allele list from another sample
            gt_delim = ''
        alleles = gt.split(gt_delim) if gt_delim else [gt]
        gt_rcs.append(gt_delim.join(str(Seq(allele).reverse_complement().strip()) for allele in alleles))

    return _rc_allele_string(var_row['ref']), _rc_allele_string(var_row['alt']), *gt_rcs



# number of uncalled ('./.') genotypes per sample, recorded before any strand correction
init_dot_gt_counts = [int((var_df[f'{sample}.GT'] == './.').sum()) for sample in samples]

gtfgz = pysam.TabixFile(gtfgz_path)

# start every feature annotation column empty
feature_cols = ['strand', 'gene_name', 'gene_id', 'transcript_id', 'exon_id', 'warnings']
for feature_col in feature_cols:
    var_df[feature_col] = ''

# fill the strand and annotation columns, preferring exon over transcript over gene hits
feat_priority = ['exon', 'transcript', 'gene']
var_df[feature_cols] = var_df[['chrom', 'pos', 'warnings']].apply(
    strand_id, gtfgz=gtfgz, feat_priority=feat_priority, axis=1, result_type='expand'
)

# on '-' strand entries, rewrite ref/alt and every sample genotype as the reverse complement
gt_col_names = [f'{sample}.GT' for sample in samples]
seq_cols = ['ref', 'alt'] + gt_col_names
var_df[seq_cols] = var_df[['ref', 'alt', 'strand'] + gt_col_names].apply(
    strand_sequences, axis=1, samples=samples, result_type='expand'
)

# show the uncalled-genotype counts and the newly added columns
for sample_name, dot_count in zip(samples, init_dot_gt_counts):
    print(f'Initial {sample_name} ./. GTs:{dot_count}')
print('\n')
display_cols = ['chrom', 'pos'] + feature_cols
display(var_df[display_cols])

In [None]:
# check that all feature id columns have been populated with matched lists
# if there are mismatches the rows will be printed below

def mask_func(var_row):
    ###################################################################################################
    # Purpose: check that feature id lists are the same length and therefore index-matched            #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads          #
    #            'gene_id', 'transcript_id' and 'exon_id' (comma-separated strings, '' when unset)    #
    # Output: True if every non-empty id list has the same length, False otherwise (the three lists   #
    #         are printed for inspection when they disagree); used to build a boolean mask            #
    ###################################################################################################
    id_cols = ['gene_id', 'transcript_id', 'exon_id']
    count_list = [len(var_row[col].split(',')) for col in id_cols if var_row[col] != '']

    # all non-empty lists agree when there is at most one distinct length
    counts_equal = len(set(count_list)) <= 1

    if not counts_equal:
        # keys are pulled out of the f-strings: re-using the f-string's own quote character inside a
        # replacement field is a SyntaxError before Python 3.12 (PEP 701)
        gene_ids = var_row['gene_id'].split(',')
        transcript_ids = var_row['transcript_id'].split(',')
        exon_ids = var_row['exon_id'].split(',')
        print(f'\n{gene_ids}\t{len(gene_ids)}')
        print(f'\n{transcript_ids}\t{len(transcript_ids)}')
        print(f'\n{exon_ids}\t{len(exon_ids)}\n')

    return counts_equal

# True where the gene/transcript/exon id lists line up; only the rows that do NOT are shown
display_mask = var_df[['gene_id', 'transcript_id', 'exon_id']].apply(mask_func, axis=1)
mismatched_rows = var_df.loc[~display_mask]
display(mismatched_rows[display_cols])


In [None]:
# display stranding information for the dataset

def _counts_table(counts, label):
    # turn a value_counts() Series into a two-column table: the counted values and their counts
    return pd.DataFrame({label: counts.index, 'count': counts.values})


print(f'Total Hits: {len(var_df)}\n')

# how many entries were assigned to each strand value
value_counts = var_df['strand'].value_counts()
strand_counts = _counts_table(value_counts, 'strand')
display(strand_counts)

# how many entries share each (strand, warnings) combination
value_counts = pd.Series(zip(var_df['strand'], var_df['warnings'])).value_counts()
strand_warning_counts = _counts_table(value_counts, 'strand, warnings')
display(strand_warning_counts)

In [None]:
# tabulate per sample nucleotides
def id_sample_nts(var_row, sample, all_samples):
    ###################################################################################################
    # Purpose: add sample specific ref and alt columns                                                #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads 'ref',   #
    #            'alt' and '<sample>.GT' for every sample                                             #
    #         2. sample - the relevant sample                                                         #
    #         3. all_samples - a list of all samples                                                  #
    # Output: for each sample, values for the reference nucleotide and genotype nucleotides 1 and 2   #
    #         as a tuple (sample_ref, allele_1, allele_2). Uncalled './.' genotypes give              #
    #         ('', '', ''); sample_ref is '' whenever no single reference could be resolved (a        #
    #         warning is printed when the sample's alleles match two different refs).                 #
    # NOTE(review): sample_alls[1] assumes a two-allele GT; a haploid / undelimited GT would raise    #
    #         IndexError - confirm such genotypes cannot reach this function.                         #
    ###################################################################################################
    if var_row[f'{sample}.GT'].strip() == './.':
        return '','',''
    else:
        init_ref = var_row['ref'].strip()
        init_gt = var_row[f'{sample}.GT'].strip()
        
        # in the GT '|' is the phased separator and '/' the unphased one
        gt_delim = '|' if ('|' in init_gt) else '/'
        sample_alls = init_gt.split(gt_delim)
        sample_alls = [sample_all.strip() for sample_all in sample_alls]

        # simple case: a single reference allele is shared by every sample
        if (',' not in init_ref) and ('|' not in init_ref):
            return init_ref, sample_alls[0], sample_alls[1]
        # (always true when reached; the ref holds several candidates)
        elif ('|' in init_ref) or (',' in init_ref):
            # in the ref column '|' separates groups and ',' separates refs within a group
            # NOTE(review): presumably groups come from several merged records/samples - confirm upstream
            grouped_refs = init_ref.split('|')
            grouped_refs = [refs.strip() for refs in grouped_refs]
            grouped_refs = [refs.split(',') for refs in grouped_refs]

            # one group per sample: the groups are taken to be in sample order
            if len(grouped_refs) == len(all_samples):
                sample_ref = grouped_refs[all_samples.index(sample)]
                sample_ref = ','.join(sample_ref)
                return sample_ref, sample_alls[0], sample_alls[1]
            
            # flatten every candidate ref into one list
            refs_list = []
            for ref_list in grouped_refs:
                for ref in ref_list:
                    refs_list.append(ref.strip())
            
            # flatten every alt (both '|' groups and ',' lists) into one list
            init_alt = var_row['alt'].strip()
            alts_list = init_alt.split('|')
            alts_list = [alt.strip() for alt in alts_list]
            alts_list = ','.join(alts_list)
            alts_list = alts_list.split(',')
            alts_list = [alt.strip() for alt in alts_list]
            
            # bases that are a ref in one group and an alt in another cannot identify the ref alone
            ref_alt_overlap = set(refs_list).intersection(set(alts_list))
            if len(ref_alt_overlap) > 0:
                # samples that have a call at this entry
                var_samples = []
                for samp in all_samples:
                    if var_row[f'{samp}.GT'] != './.':
                        var_samples.append(samp)

                # one group per called sample: the groups are taken to be in called-sample order
                if len(grouped_refs) == len(var_samples):
                    sample_ref = grouped_refs[var_samples.index(sample)]
                    sample_ref = ','.join(sample_ref)
                # otherwise use an allele of this sample that is a known ref but not an ambiguous base
                elif (sample_alls[0] not in ref_alt_overlap) or (sample_alls[1] not in ref_alt_overlap):
                    sample_ref = ''
                    for sample_all in sample_alls:
                        if sample_all not in ref_alt_overlap:
                            if sample_all in refs_list:
                                if (sample_ref != '') and (sample_ref != sample_all):
                                    print(f'[WARNING] Conflicting refs:\tindex:{var_row.name}ref:{init_ref}\talt:{init_alt}\tGT:{init_gt}')
                                    return '', sample_alls[0], sample_alls[1]
                                sample_ref = sample_all
                # both alleles are ambiguous: leave the ref unassigned
                else:
                    sample_ref = ''

            else:
                # no ambiguity: the ref is whichever of this sample's alleles is a known ref
                # (two different matching alleles are reported as a conflict and left unassigned)
                sample_ref = ''
                for sample_all in sample_alls:
                    if sample_all in refs_list:
                        if (sample_ref != '') and (sample_ref != sample_all):
                            print(f'[WARNING] Conflicting refs:\tindex:{var_row.name}ref:{init_ref}\talt:{init_alt}\tGT:{init_gt}')
                            return '', sample_alls[0], sample_alls[1]
                        sample_ref = sample_all
            return sample_ref, sample_alls[0], sample_alls[1]
        

        

# add '<sample>_ref', '<sample>_all_1' and '<sample>_all_2' columns for every sample
for sample in samples:
    new_nt_cols = [f'{sample}_ref', f'{sample}_all_1', f'{sample}_all_2']
    for new_nt_col in new_nt_cols:
        var_df[new_nt_col] = ''
    # the comprehension's `sample` is local to the comprehension (Python 3 scoping), so the loop
    # variable passed as sample=sample is not overwritten
    var_df[new_nt_cols] = var_df[['ref', 'alt'] + [f'{sample}.GT' for sample in samples]].apply(id_sample_nts, sample=sample, axis=1, result_type='expand', all_samples=samples)

# for print output
display_cols = ['ref', 'alt']
for sample in samples:
    display_cols = display_cols + [f'{sample}.GT', f'{sample}_ref', f'{sample}_all_1', f'{sample}_all_2'] 
print('var_df')
display(var_df[display_cols].head())

# non ./. genotyped entries without nucleotide assignments will be displayed at the bottom of the output, the table should be empty
# NOTE(review): the percentage below divides by len(var_df) * len(samples) and would raise
# ZeroDivisionError for an empty table
sample_masks = []
entry_count = 0
for sample in samples:
    # called genotype but no resolved reference base
    mask = ((var_df[f'{sample}.GT'] != './.') & (var_df[f'{sample}_ref'] == ''))
    sample_masks.append(mask)
    entry_count += len(var_df[mask])
print(f'Non ./. Sample x Entries that are Unassigned ({entry_count}, {round((entry_count/(len(var_df)*len(samples)))*100,2)}%):')
# rows where at least one sample is unassigned
display(var_df.loc[reduce(lambda x,y: (x | y), sample_masks), display_cols].head())

In [None]:
# cross reference variants to ClinVar and add ClinVar information to the table
def _fetch_clinvar_records(clinvar_vcf, chrom, pos):
    # Return the ClinVar records at 1-based position `pos` as a list ([] if the contig is unknown).
    # pysam's VariantFile.fetch uses 0-based half-open coordinates, hence the (pos - 1, pos) query.
    # GATK output on UCSC-style references names contigs 'chr15'/'chrM', while the NCBI ClinVar VCF
    # typically uses '15'/'MT', so the prefix-stripped name is tried when the original one is rejected.
    candidates = [chrom]
    if chrom.startswith('chr'):
        bare = chrom[3:]
        candidates.append('MT' if bare == 'M' else bare)
    for contig in candidates:
        try:
            return list(clinvar_vcf.fetch(contig, pos - 1, pos))
        except ValueError:
            # pysam raises ValueError for a contig it does not know; try the next spelling
            continue
    return []


def check_in_clinvar(var_row, clinvar_vcf, samples):
    ###################################################################################################
    # Purpose: add columns with ClinVar data to each variant entry                                    #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads 'chrom', #
    #            'pos', 'strand' and '<sample>_ref/_all_1/_all_2' for every sample                    #
    #         2. clinvar_vcf - ClinVar vcf file read with pysam.VariantFile()                         #
    #         3. samples - a list of all samples                                                      #
    # Output: 1. in_clinvar - boolean of whether the variant appears in clinvar                       #
    #         2. clinvar_ids - all clinvar ids associated with the variant                            #
    #         3. clinvar_sig - the variant significance as annotated in clinvar                       #
    #         4. clinvar_disease - diseases associated with the variant in clinvar                    #
    #         Outputs 2-4 are comma-joined, de-duplicated and in first-match order.                   #
    # A record matches when its REF equals a sample's ref and one of its ALTs is among that sample's  #
    # non-ref alleles (converted back to the + strand for '-' entries).                               #
    ###################################################################################################

    chrom = var_row['chrom']
    pos = int(var_row['pos'])
    strand = var_row['strand']

    # the candidate records depend only on the position, so they are fetched once per entry
    clinvar_recs = _fetch_clinvar_records(clinvar_vcf, chrom, pos)

    in_clinvar = False
    # dicts act as insertion-ordered sets: a record matched by several samples is reported once and
    # the joined strings are reproducible between runs (joining a set of str is hash-order dependent)
    clinvar_ids = {}
    clinvar_sig = {}
    clinvar_disease = {}

    for sample in samples:
        ref = var_row[f'{sample}_ref']
        sample_alts = {
            sample_all
            for sample_all in (var_row[f'{sample}_all_1'], var_row[f'{sample}_all_2'])
            if sample_all != ref
        }
        if not sample_alts or not clinvar_recs:
            continue

        # clinvar entries correspond to the + strand so the nucleotide identities need to be reversed if the variant was assigned to the - strand
        if strand == '-':
            ref = str(Seq(ref).reverse_complement())
            sample_alts = {str(Seq(alt).reverse_complement()) for alt in sample_alts}

        for clinvar_rec in clinvar_recs:
            if clinvar_rec.alts is None or clinvar_rec.alleles[0] != ref:
                continue
            if sample_alts.isdisjoint(clinvar_rec.alts):
                continue
            in_clinvar = True
            # a record without an ID ('.') comes back as None and cannot be joined
            if clinvar_rec.id is not None:
                clinvar_ids[clinvar_rec.id] = None
            if 'CLNSIG' in clinvar_rec.info.keys():
                clinvar_sig[','.join(clinvar_rec.info['CLNSIG'])] = None
            if 'CLNDN' in clinvar_rec.info.keys():
                clinvar_disease[','.join(clinvar_rec.info['CLNDN'])] = None

    return in_clinvar, ','.join(clinvar_ids), ','.join(clinvar_sig), ','.join(clinvar_disease)



clinvar_vcf = pysam.VariantFile(clinvar_vcf_path, 'r')
clinvar_cols = ['in_clinvar', 'clinvar_ids', 'clinvar_sig', 'clinvar_disease']
var_df[clinvar_cols] = var_df.apply(check_in_clinvar, clinvar_vcf=clinvar_vcf, samples=samples, axis=1, result_type='expand')

# print aggregated value counts for how many entries appear in ClinVar with a given significance
# and the head of the dataframe with new ClinVar columns.
# The mask is bound outside the f-string: nesting the f-string's own quote character inside a
# replacement field (var_df[var_df['in_clinvar']]) is a SyntaxError before Python 3.12 (PEP 701).
in_clinvar_mask = var_df['in_clinvar']
print(f'Total Entries in ClinVar:\t{len(var_df[in_clinvar_mask])}')
display(var_df.loc[in_clinvar_mask, 'clinvar_sig'].value_counts())
display(var_df.loc[in_clinvar_mask, ['chrom', 'pos', 'ref', 'alt', 'clinvar_ids', 'clinvar_sig', 'clinvar_disease']].head())

In [None]:
# convert GTs
def get_bin_gt(var_row, sample):
    ###################################################################################################
    # Purpose: convert all genotypes to binary genotyping (0 for ref 1 for alt)                       #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads          #
    #            '<sample>.GT' and '<sample>_ref'                                                     #
    #         2. sample - the sample for which the genotype should be converted                       #
    # Output: The binary genotype value for the specified sample, keeping the original delimiter      #
    #         ('A/G' -> '0/1', 'G|G' -> '1|1'). './.' is returned unchanged, no-call alleles '.' stay  #
    #         '.', and a single undelimited allele gives a single digit ('A' -> '0').                 #
    #         NOTE: when <sample>_ref is '' (unresolved) every called allele is encoded as 1.         #
    ###################################################################################################

    sample_gt = var_row[f'{sample}.GT'].strip()
    if sample_gt == './.':
        return './.'

    sample_ref = var_row[f'{sample}_ref'].strip()

    # '|' takes precedence over '/' when both appear
    if '|' in sample_gt:
        gt_delim = '|'
    elif '/' in sample_gt:
        gt_delim = '/'
    else:
        # haploid / undelimited genotype: str.split('') would raise ValueError
        gt_delim = ''
    sample_alls = [sample_all.strip() for sample_all in sample_gt.split(gt_delim)] if gt_delim else [sample_gt]

    bin_gt = []
    for sample_all in sample_alls:
        if sample_all == '.':
            # a no-call allele is neither ref nor alt
            bin_gt.append('.')
        else:
            bin_gt.append('0' if sample_all == sample_ref else '1')

    return gt_delim.join(bin_gt)

def get_value_count_col(count_index, value_col_name, target_col, df_slice):
    ###################################################################################################
    # Purpose: build a single-column count table so that counts for several columns sharing the same #
    #          set of values can be placed side by side                                               #
    # Inputs: 1. count_index - the values to count, in output order                                   #
    #         2. value_col_name - header of the returned column                                       #
    #         3. target_col - column of df_slice whose values are counted                             #
    #         4. df_slice - dataframe holding target_col                                              #
    # Output: dataframe with the single column value_col_name; row i holds the number of times        #
    #         count_index[i] occurs in target_col (0 if it never occurs)                              #
    ###################################################################################################

    observed = df_slice[target_col].value_counts()
    counts = [int(observed[value]) if value in observed.index else 0 for value in count_index]
    return pd.DataFrame({value_col_name: counts})

# replace each sample's GT column in place with its binary (0 = ref, 1 = alt) encoding
for sample in samples:
    gt_col = f'{sample}.GT'
    input_cols = [gt_col, f'{sample}_ref']
    var_df[gt_col] = var_df[input_cols].apply(get_bin_gt, sample=sample, axis=1)

# aggregate binary GT counts per sample; phased '|' calls are counted with their unphased '/' form
binary_gt_values = ['1/1', '1/0', '0/1', '0/0', './.']
gt_counts_df = pd.DataFrame({'binary_gt': binary_gt_values})
for sample in samples:
    gt_col = f'{sample}.GT'
    simple_gt_col = var_df[gt_col].apply(lambda gt_value: gt_value.replace('|', '/')).to_frame()
    val_count_col = get_value_count_col(
        df_slice=simple_gt_col,
        target_col=gt_col,
        count_index=list(gt_counts_df['binary_gt']),
        value_col_name=gt_col,
    )
    gt_counts_df = pd.concat([gt_counts_df, val_count_col], axis=1)
display(gt_counts_df)

In [None]:
# calculate % ref values for each sample in all entries and %snp values for samples matching the target snp
def calc_pct_ref(ad_str):
    ###################################################################################################
    # Purpose: calculate the % of reads matching the reference nucleotide for a given AD value        #
    # Inputs: 1. ad_str - an AD value, comma-separated read counts with the reference count first     #
    # Output: %_ref rounded to 2 decimals; nan when the AD value is nan, not numeric (e.g. '.'), or   #
    #         sums to zero reads                                                                      #
    ###################################################################################################

    if pd.isna(ad_str):
        return np.nan

    try:
        ads_list = [int(var_ad) for var_ad in str(ad_str).strip().split(',')]
    except ValueError:
        # missing per-allele depths cannot give a percentage; previously this raised
        return np.nan

    total_ad = sum(ads_list)
    if total_ad <= 0:
        return np.nan

    return round(float((ads_list[0] / total_ad) * 100), 2)

def calc_pct_snp(var_row, sample, target_snps):
    ###################################################################################################
    # Purpose: calculate the % of reads matching the target SNP if the sample shows the target SNP    #
    # Inputs: 1. var_row - a subset of the dataframe row passed with pd.apply(axis=1); reads          #
    #            '<sample>.AD', '<sample>_ref', '<sample>_all_1' and '<sample>_all_2'                 #
    #         2. sample - the sample for which to target %_snp                                        #
    #         3. target_snps - a list of the target SNPs (ref base + alt base, e.g. 'CT')             #
    # Output: a %_snp value for the first target SNP the sample carries; nan if the sample does not   #
    #         have a variant that matches a target SNP, or the AD value is missing/unusable           #
    ###################################################################################################

    # the AD column is derived from `sample` (it used to be read from a global `ad_col`)
    ad_str = var_row[f'{sample}.AD']
    sample_ref = var_row[f'{sample}_ref']
    if pd.isna(ad_str) or sample_ref == '':
        return np.nan

    all_1 = var_row[f'{sample}_all_1']
    all_2 = var_row[f'{sample}_all_2']

    for target_snp in target_snps:
        if target_snp not in (sample_ref + all_1, sample_ref + all_2):
            continue

        # the first matching target decides; later non-matching targets must not overwrite it
        try:
            ads_list = [int(var_ad) for var_ad in str(ad_str).strip().split(',')]
        except ValueError:
            return np.nan
        total_ad = sum(ads_list)

        # no reads, no alt depth, or more than 2 alt nucleotides for a given sample: ignored
        # (the latter generally occurs for a minuscule fraction of sample entries)
        if total_ad <= 0 or len(ads_list) < 2 or len(ads_list) > 3:
            return np.nan

        if len(ads_list) == 2:
            alt_ad = ads_list[1]
        elif all_1 == all_2:
            # 3 ADs and a homozygous alt call: both alt depths are attributed to the target
            alt_ad = sum(ads_list[1:])
        else:
            # 3 ADs and two different alleles: use the AD slot of the allele that forms the target
            # NOTE(review): assumes the AD alt order follows the GT allele order - confirm
            target_all = target_snp[len(sample_ref):]
            alt_ad = ads_list[[all_1, all_2].index(target_all) + 1]

        return round(float((alt_ad / total_ad) * 100), 2)

    return np.nan


# for every sample add '<sample>_pct_ref' (all entries) and '<sample>_pct_snp' (nan unless the
# sample carries one of target_snps)
for sample in samples:
    # NOTE(review): ad_col stays visible as a module-level name; keep it assigned before
    # calc_pct_snp is applied in case that function reads it directly - confirm against calc_pct_snp
    ad_col = f'{sample}.AD'
    var_df[f'{sample}_pct_ref'] = var_df[ad_col].apply(calc_pct_ref)
    input_cols = [ad_col, f'{sample}_ref', f'{sample}_all_1', f'{sample}_all_2']
    var_df[f'{sample}_pct_snp'] = var_df[input_cols].apply(calc_pct_snp, sample=sample, target_snps=target_snps, axis=1)

In [None]:
# Survey and check %snp calculations

# print total var_df entries and entries*samples and counts for entries where all or some samples have %_snp values
total_count = len(var_df) * len(samples)
pct_snp_masks = [var_df[f'{sample}_pct_snp'].notna() for sample in samples]
# summing the boolean masks also works when a sample has no %_snp values at all
# (mask.value_counts()[True] raised KeyError in that case)
pct_snp_count = int(sum(mask.sum() for mask in pct_snp_masks))

all_pct_snp_mask = reduce(lambda x, y: (x & y), pct_snp_masks)
any_pct_snp_mask = reduce(lambda x, y: (x | y), pct_snp_masks)
# entries where some, but not all, samples have a %_snp value
some_pct_snp_mask = any_pct_snp_mask & ~all_pct_snp_mask

print(f'Total Entries * Samples: {total_count}')
print(f'Total Assigned pct_snp: {pct_snp_count}\n')
print(f'Total Entries: {len(var_df)}')
print(f'Entries with all samples pct_snp: {int(all_pct_snp_mask.sum())}')
print(f'Entries with not all samples pct_snp: {int(some_pct_snp_mask.sum())}\n')

# print var_df rows with all samples assigned %_snp and var_df rows with only some samples assigned %_snp for visual inspection
display_cols = ['chrom', 'pos', 'ref', 'alt']
for sample in samples:
    display_cols = display_cols + [f'{sample}.GT', f'{sample}.AD', f'{sample}_ref', f'{sample}_all_1', f'{sample}_all_2', f'{sample}_pct_snp']
print('All samples assigned pct_snp:')
display(var_df.loc[all_pct_snp_mask, display_cols])
print('Some samples assigned pct_snp:')
display(var_df.loc[some_pct_snp_mask, display_cols])

In [None]:
def hist_table(var_df, target_cols, bucket_lims, mode):
    ###################################################################################################
    # Purpose: create a histogram table of values in multiple dataframe columns                       #
    # Inputs: 1. var_df - the dataframe holding variant data                                          #
    #         2. target_cols - the columns whose values should be counted for the table               #
    #         3. bucket_lims - an ascending list of numerical limits for the histogram buckets;       #
    #                          bucket i covers [bucket_lims[i], bucket_lims[i+1])                     #
    #         4. mode - 'scalar' for linear buckets, or any string containing 'log' (e.g. 'log2')     #
    #                   for log-scale buckets; in log mode a lowest limit of 0 is replaced by 1       #
    #                   (only in a private copy) because 0 cannot be placed on a log scale            #
    # Output: a dataframe with a 'bucket' label column ('lo-hi') and one count column per target col  #
    # Notes: NaN values are skipped; values below the lowest limit are counted in the first bucket    #
    #        and values at or above the highest limit are counted in the last bucket                  #
    # Raises: ValueError for an unrecognised mode or fewer than two bucket limits                     #
    ###################################################################################################
    if mode != 'scalar' and 'log' not in mode:
        raise ValueError(f"mode must be 'scalar' or contain 'log', got {mode!r}")
    if len(bucket_lims) < 2:
        raise ValueError('bucket_lims must contain at least two limits')

    # work on a copy so the caller's list is never modified
    lims = list(bucket_lims)
    if 'log' in mode and lims[0] == 0:
        lims[0] = 1

    buckets = [f'{lo}-{hi}' for lo, hi in zip(lims[:-1], lims[1:])]
    num_buckets = len(buckets)
    lims_arr = np.asarray(lims, dtype=float)

    count_frames = [pd.DataFrame({'bucket': buckets})]
    for target_col in target_cols:
        values = var_df[target_col].dropna().to_numpy(dtype=float)
        # locate each value against the real limits instead of recomputing a log/linear index,
        # which avoids float error (e.g. math.log(1000, 10) < 3) and works for any limit spacing
        bucket_idx = np.searchsorted(lims_arr, values, side='right') - 1
        # out-of-range values fall into the first/last bucket
        bucket_idx = np.clip(bucket_idx, 0, num_buckets - 1)
        counts = np.bincount(bucket_idx, minlength=num_buckets)
        count_frames.append(pd.DataFrame({target_col: counts}))

    return pd.concat(count_frames, axis=1)

In [None]:
# create histogram tables for %_ref, %_snp and read depth

# %_ref for every entry: equal-width buckets across 0-100
target_cols = [f'{sample}_pct_ref' for sample in samples]
num_buckets = 10
bucket_lims = list(range(0, 101, int(100 / num_buckets)))
pct_hist_df = hist_table(var_df=var_df, target_cols=target_cols, bucket_lims=bucket_lims, mode='scalar')

# %_ref and %_snp, where each sample's pair of columns is restricted to the entries in which
# that sample has an assigned %_snp (the per-sample slices are aligned side by side on the index)
target_cols = [*target_cols, *(f'{sample}_pct_snp' for sample in samples)]
pct_snp_hist_cols = pd.DataFrame()
for sample in samples:
    snp_col = f'{sample}_pct_snp'
    mask = var_df[snp_col].notna()
    new_cols = var_df.loc[mask, [f'{sample}_pct_ref', snp_col]]
    pct_snp_hist_cols = pd.concat([pct_snp_hist_cols, new_cols], axis=1)
pct_snp_hist_df = hist_table(var_df=pct_snp_hist_cols, target_cols=target_cols, bucket_lims=bucket_lims, mode='scalar')

# per-sample read depth and genotype quality plus site quality: log2 buckets from 2**0 to 2**21
num_buckets = 21
bucket_lims = [2 ** exponent for exponent in range(num_buckets + 1)]
target_cols = [f'{sample}.{metric}' for metric in ('DP', 'GQ') for sample in samples] + ['qual']
dp_hist_df = hist_table(var_df=var_df, target_cols=target_cols, bucket_lims=bucket_lims, mode='log2')

display(pct_hist_df)
print('\n')
display(pct_snp_hist_df)
print('\n')
display(dp_hist_df)

In [None]:
# output var_df as a compressed tsv
file_name = os.path.split(tsv_path)[1].split('.tsv')[0]
out_tsv_path = os.path.join(out_path, f'{file_name}-proc.tsv.gz')
var_df.to_csv(out_tsv_path, sep='\t', index=False, compression='gzip')

# output various metrics tables as compressed tsvs (one file per table, suffixed by its dict key)
export_dfs = {
    'bin-gt-counts': gt_counts_df,
    'strand-counts': strand_counts,
    'strand-warning-counts': strand_warning_counts,
    # 'ref-rc-counts': nt_counts,
    # 'gt-rc-counts': gt_rc_counts,
    'pct-hist': pct_hist_df,
    'pct-snp-hist': pct_snp_hist_df,
    'dp-hist': dp_hist_df,
}

for suffix, export_df in export_dfs.items():
    out_tsv_path = os.path.join(out_path, f'{file_name}-{suffix}.tsv.gz')
    export_df.to_csv(out_tsv_path, sep='\t', index=False, compression='gzip')

# compile metrics tables into an excel output, one sheet per table
# (the dict keys double as sheet names, so they must stay within Excel's 31-character limit)
excel_path = os.path.join(out_path, f'{file_name}-summary-stats.xlsx')
with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    for sheet_name, export_df in export_dfs.items():
        export_df.to_excel(writer, sheet_name=sheet_name, index=False)

# print target edit %_ref and %_snp and output full var_df row for the target edit to a tsv file
display_col_list = []
for sample in samples:
    display_col_list.append(f'{sample}_pct_ref')
    display_col_list.append(f'{sample}_pct_snp')

display_mask = (var_df['chrom'] == target_edit[0]) & (var_df['pos'] == target_edit[1])
if not display_mask.any():
    print(f'WARNING: target edit {target_edit[0]}:{target_edit[1]} was not found in the variant table')
# show only the per-sample %_ref / %_snp columns collected above
display(var_df.loc[display_mask, display_col_list])

out_tsv_path = os.path.join(out_path, 'tgt-edit.tsv')
var_df[display_mask].to_csv(out_tsv_path, sep='\t', index=False)
