- Purpose
    - filter annotated variant table for final off-tgt hits
    - off-target hits are entries that pass all of the following filters in every biological replicate of a condition:
        1. var    - sample was called as a variant by HaplotypeCaller
        2. VOI    - sample variant matches the variant of interest
        3. DP     - read depth >= 20
        4. GQ     - genotype quality >= 20 (Phred-scaled; corresponds to 99% confidence in the genotype call)
        5. non-wt - pct_ref in the matched wt sample is >99%, meaning the SNP is absent from wt and was introduced by editing
    - strategy similar to Huang, X. et al. Programmable C-to-U RNA editing using the human APOBEC3A deaminase. The EMBO Journal 39, e104741 (2020). 
- Outputs
    - a subdirectory for each condition containing tsvs of entries that passed each filter
    - a tsv containing a count matrix of number of entries that passed each filter
    - a tsv containing identity and %_snp data for each off-tgt hit

In [None]:
#####################
# import statements #
#####################

import os
import pandas as pd
import numpy as np
import tqdm as tqdm
from Bio.Seq import Seq
import pysam
import math
from functools import reduce

In [None]:
##########################
# User-Defined Variables #
##########################
# Set the paths below before running; no other cell should need editing.

# --- input files -------------------------------------------------------------
# compressed variant-table tsv produced by the initial processing step
tsv_path = ''
# sample-map tsv with 'sample', 'condition' and 'rep' columns that map each
# unique sample id to its biological condition and biological replicate
sample_map_path = ''
# reference genome gtf
gtfgz_path = ''
# ClinVar vcf
clinvar_vcf_path = ''

# --- analysis settings -------------------------------------------------------
# number of biological replicates per condition
replicates = 3
# condition used as the unedited (wild-type) reference
wt_condition = '01_transfection.control'

# SNP analyzed in this notebook, as (reference base, alternate base)
target_snp = ('C', 'T')

In [None]:
# Output directory: '<directory of tsv_path>/off-tgt-analysis/<replicates>-reps'.
# os.makedirs also creates the intermediate 'off-tgt-analysis' directory.
out_dir = os.path.join(os.path.dirname(tsv_path), 'off-tgt-analysis', f'{replicates}-reps')
os.makedirs(out_dir, exist_ok=True)

# Read the sample map and fail early, with a clear message, if it cannot support the
# per-condition / per-replicate sample lookups made by the later cells.
sample_map_df = pd.read_csv(sample_map_path, sep='\t')

_missing_map_cols = {'sample', 'condition', 'rep'} - set(sample_map_df.columns)
if _missing_map_cols:
    raise ValueError(f'sample map is missing required column(s): {sorted(_missing_map_cols)}')
if wt_condition not in set(sample_map_df['condition']):
    raise ValueError(f'wt_condition {wt_condition!r} does not appear in the sample map')

# every condition needs a sample for each replicate number 1..replicates
_missing_reps = {}
for _condition, _group in sample_map_df.groupby('condition', sort=False):
    _absent = sorted(set(range(1, replicates + 1)) - set(_group['rep']))
    if _absent:
        _missing_reps[_condition] = _absent
if _missing_reps:
    raise ValueError(f'sample map lacks replicate(s) among 1..{replicates} for: {_missing_reps}')

# display the sample map for inspection
display(sample_map_df)

In [None]:
# Read the annotated variant table (compression is inferred from the file extension).
var_df = pd.read_csv(tsv_path, sep='\t', compression='infer', low_memory=False)

# Print every var_df column, left-padded to a common width in rows of three, then show
# the head of var_df. Plain string concatenation is used instead of a backslash inside an
# f-string expression, which is a SyntaxError before Python 3.12.
_cols_per_row = 3
_col_names = [str(col) for col in var_df.columns]
_col_width = max((len(col) for col in _col_names), default=0)

print('\tcolumns:\n')
for _start in range(0, len(_col_names), _cols_per_row):
    _row = [col.ljust(_col_width) for col in _col_names[_start:_start + _cols_per_row]]
    print('\t' + '\t'.join(_row))
print('\n')
display(var_df.head())

In [None]:
def generate_filter_list(var_df, sample_map, replicates, wt_condition, target_snp,
                         *, min_depth=20, min_gq=20, min_wt_pct_ref=99):
    """Build per-sample boolean filter masks for off-target variant calling.

    The filtering strategy is similar to the CURE paper (Huang et al., EMBO J 2020).
    Each sample is matched to the wt sample of the same biological replicate. The
    filters, in the order they are meant to be applied cumulatively:

    1. var    - the genotype ``<sample>.GT`` is present and is not one of
                '0/0', '0|0', './.', '.|.' (homozygous reference / no call)
    2. VOI    - ``<sample>_ref`` equals ``target_snp[0]`` and ``<sample>_all_1``
                or ``<sample>_all_2`` equals ``target_snp[1]``
    3. DP     - read depth ``<sample>.DP`` >= ``min_depth``
    4. GQ     - genotype quality ``<sample>.GQ`` >= ``min_gq``
                (Phred-scaled; 20 corresponds to 99% confidence)
    5. non-wt - ``<wt_sample>_pct_ref`` > ``min_wt_pct_ref`` for the wt sample of
                the same replicate, i.e. the SNP is absent from wt

    Args:
        var_df: annotated variant table containing the per-sample columns above.
        sample_map: DataFrame with 'sample', 'condition' and 'rep' columns.
        replicates: number of biological replicates; not used here, kept so
            existing calls keep working.
        wt_condition: the reference condition value in ``sample_map['condition']``.
        target_snp: (ref, alt) bases of the SNP being analyzed.
        min_depth, min_gq, min_wt_pct_ref: filter thresholds; the defaults are
            the original DP >= 20, GQ >= 20 and wt pct_ref > 99.

    Returns:
        ``(sample_filters, sample_filter_names)``. ``sample_filters`` maps every
        sample id in ``sample_map`` to ``{filter name: boolean mask aligned with
        var_df}``, with keys inserted in ``sample_filter_names`` order.
        ``sample_filter_names`` is ``['var', 'VOI', 'DP', 'GQ', 'non-wt']``.

    Raises:
        ValueError: if a replicate in ``sample_map`` has no ``wt_condition`` sample.
        KeyError: if ``var_df`` lacks a required per-sample column.
    """
    sample_filter_names = ['var', 'VOI', 'DP', 'GQ', 'non-wt']
    # homozygous-reference and no-call genotypes, unphased and phased
    non_variant_gts = ['0/0', '0|0', './.', '.|.']
    ref_base, alt_base = target_snp[0], target_snp[1]

    # wt sample for each replicate number; the first matching row wins
    wt_rows = sample_map[sample_map['condition'] == wt_condition]
    wt_sample_by_rep = {}
    for wt_rep, wt_sample in zip(wt_rows['rep'], wt_rows['sample']):
        wt_sample_by_rep.setdefault(wt_rep, wt_sample)

    sample_filters = {}
    # iterate the sample_map argument (not a notebook global) so any map passed in is honored
    for sample, rep in zip(sample_map['sample'], sample_map['rep']):
        if rep not in wt_sample_by_rep:
            raise ValueError(
                f'no {wt_condition!r} sample for replicate {rep!r} (needed by sample {sample!r})'
            )
        wt_sample = wt_sample_by_rep[rep]

        gt = var_df[f'{sample}.GT']
        alt_match = (var_df[f'{sample}_all_1'] == alt_base) | (var_df[f'{sample}_all_2'] == alt_base)

        # keys inserted in sample_filter_names order so callers iterating the dict see the same order
        sample_filters[sample] = {
            'var': gt.notna() & ~gt.isin(non_variant_gts),
            'VOI': (var_df[f'{sample}_ref'] == ref_base) & alt_match,
            'DP': var_df[f'{sample}.DP'] >= min_depth,
            'GQ': var_df[f'{sample}.GQ'] >= min_gq,
            'non-wt': var_df[f'{wt_sample}_pct_ref'] > min_wt_pct_ref,
        }

    return sample_filters, sample_filter_names

In [None]:
# Build a summary table of how many var_df entries pass each successive (cumulative) filter.
sample_filters_dict, sample_filter_names = generate_filter_list(var_df=var_df, sample_map=sample_map_df, replicates=replicates, wt_condition=wt_condition, target_snp=target_snp)


def _sample_for(condition, rep):
    """Return the sample id for one (condition, replicate) pair, matched on the 'rep' column."""
    match = (sample_map_df['condition'] == condition) & (sample_map_df['rep'] == rep)
    samples = sample_map_df.loc[match, 'sample']
    if samples.empty:
        raise ValueError(f'sample map has no sample for condition {condition!r}, replicate {rep}')
    return samples.iloc[0]


def _cumulative_mask(sample, n_filters):
    """AND together the first n_filters masks of one sample, taken in sample_filter_names order."""
    masks = [sample_filters_dict[sample][name] for name in sample_filter_names[:n_filters]]
    return reduce(lambda x, y: x & y, masks)


conditions = sample_map_df['condition'].unique()
filter_steps = range(1, len(sample_filter_names) + 1)

# 'subset' labels: the unfiltered table, then '<condition>_<filter1>_..._<filterK>' per cumulative step
subset_list = ['total_entries']
for bio_condition in conditions:
    subset_list.extend(f"{bio_condition}_{'_'.join(sample_filter_names[:i])}" for i in filter_steps)
filter_counts_df = pd.DataFrame({'subset': subset_list})

# r1..rN columns: entries passing each cumulative filter in that replicate's sample.
# Samples are matched on the 'rep' column (not row position) so unsorted sample maps work.
for rep in range(1, replicates + 1):
    count_col_list = [len(var_df)]
    for bio_condition in conditions:
        sample = _sample_for(bio_condition, rep)
        # mask.sum() counts passing rows without materializing the filtered frame
        count_col_list.extend(int(_cumulative_mask(sample, i).sum()) for i in filter_steps)
    filter_counts_df[f'r{rep}'] = count_col_list

# 'intersection' column: entries passing each cumulative filter in every replicate of the condition
count_col_list = [len(var_df)]
for bio_condition in conditions:
    for i in filter_steps:
        rep_masks = [_cumulative_mask(_sample_for(bio_condition, rep), i) for rep in range(1, replicates + 1)]
        count_col_list.append(int(reduce(lambda x, y: x & y, rep_masks).sum()))
filter_counts_df['intersection'] = count_col_list
display(filter_counts_df)

In [None]:
# For each biological condition, write one gzipped tsv per cumulative filter step holding the
# var_df entries that pass that step in every replicate; files go into a per-condition
# subdirectory of out_dir. The filter count matrix is written as a tsv directly in out_dir.
for bio_condition in sample_map_df['condition'].unique():
    print(f'{bio_condition}:\n')

    bio_condition_out_dir = os.path.join(out_dir, bio_condition)
    os.makedirs(bio_condition_out_dir, exist_ok=True)

    # replicate samples matched on the 'rep' column rather than row position in the sample map
    condition_mask = sample_map_df['condition'] == bio_condition
    rep_samples = [
        sample_map_df.loc[condition_mask & (sample_map_df['rep'] == rep), 'sample'].iloc[0]
        for rep in range(1, replicates + 1)
    ]

    for i in range(1, len(sample_filter_names) + 1):
        filter_names_list = sample_filter_names[:i]
        combined_filter_name = '_'.join(filter_names_list)

        # AND of the first i filters across all replicate samples
        combined_filter = reduce(
            lambda x, y: x & y,
            [sample_filters_dict[sample][name] for sample in rep_samples for name in filter_names_list],
        )
        passing_df = var_df[combined_filter]

        if len(passing_df) > 0:
            out_tsv_path = os.path.join(bio_condition_out_dir, f'{bio_condition}-{combined_filter_name}.tsv.gz')
            passing_df.to_csv(out_tsv_path, sep='\t', float_format='%.2f', compression='gzip')

        print(f'\tFilter:     \t{combined_filter_name}')
        print(f'\ttsv entries:\t{len(passing_df)}\n')

out_tsv_path = os.path.join(out_dir, 'filter-counts.tsv')
filter_counts_df.to_csv(out_tsv_path, sep='\t', index=False, float_format='%.2f')

In [None]:
# Write a tsv of the final off-target hits: entries passing every filter in every replicate of
# at least one non-wt condition, with each replicate's pct_snp and the per-condition mean pct_snp.
merge_cols = ['chrom', 'pos', 'ref', 'alt', 'gene_name']
final_hits_df = pd.DataFrame(columns=merge_cols)
pct_snp_out_cols = []   # per-sample pct_snp columns, in the order they were collected
mean_out_cols = []      # one mean pct_snp column per non-wt condition

for bio_condition in sample_map_df['condition'].unique():
    if bio_condition == wt_condition:
        continue

    # replicate samples matched on the 'rep' column rather than row position in the sample map
    condition_mask = sample_map_df['condition'] == bio_condition
    rep_sample_list = [
        sample_map_df.loc[condition_mask & (sample_map_df['rep'] == rep), 'sample'].iloc[0]
        for rep in range(1, replicates + 1)
    ]

    # AND of every filter across all replicate samples
    combined_filter = reduce(
        lambda x, y: x & y,
        [sample_filters_dict[sample][name] for sample in rep_sample_list for name in sample_filter_names],
    )

    pct_snp_col_list = [f'{sample}_pct_snp' for sample in rep_sample_list]
    # drop only the leading '<prefix>_' so later underscores in the condition name are kept
    # (distinct conditions no longer collapse to the same column name)
    mean_col = str(bio_condition).split('_', 1)[-1]
    if mean_col in mean_out_cols:
        raise ValueError(f'two conditions map to the same output column {mean_col!r}')

    # .copy() so adding the mean column does not write into a view of var_df
    merge_df = var_df.loc[combined_filter, merge_cols + pct_snp_col_list].copy()
    # skipna=False: a missing replicate value yields NaN, as the explicit sum / n did
    merge_df[mean_col] = merge_df[pct_snp_col_list].mean(axis=1, skipna=False)

    final_hits_df = pd.merge(left=final_hits_df, right=merge_df, on=merge_cols, how='outer')
    pct_snp_out_cols.extend(pct_snp_col_list)
    mean_out_cols.append(mean_col)

# collapse duplicated gene names; dict.fromkeys keeps first-seen order, so the output is
# identical across runs (set iteration order varies with string hash randomization)
final_hits_df['gene_name'] = final_hits_df['gene_name'].apply(
    lambda x: ','.join(dict.fromkeys(x.split(','))) if not pd.isna(x) else ''
)

# only columns that were actually produced above, so extra sample-map rows cannot cause a KeyError
out_col_order = merge_cols + pct_snp_out_cols + mean_out_cols

out_tsv_path = os.path.join(out_dir, 'final-hit-mean-pct-snps.tsv')
final_hits_df[out_col_order].to_csv(out_tsv_path, sep='\t', index=True, float_format='%.2f')