In [1]:
from collections import defaultdict
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

In [2]:
### Global variables
# Root directory of the project. File paths below are built by plain string
# concatenation with ROOT, so the trailing "/" is significant.
ROOT = "/users/kcochran/projects/domain_adaptation_nosexchr/"

# shorthand for each TF name (used to build file paths and as dictionary keys)
tfs = ["CTCF", "CEBPA", "Hnf4a", "RXRA"]

# plot-acceptable TF names, index-aligned with `tfs`
# (print_table maps a TF to its display name via tfs.index)
tfs_latex_names = ["CTCF", "CEBPA", "HNF4A", "RXRA"]

# shorthand names for the training setups whose predictions are loaded;
# "DA" is the domain-adaptive mouse-trained model
all_trainspecies = ["mm10", "DA", "hg38"]
# plot-acceptable names for species (note: there is no entry for "DA")
model_names_dict = {"mm10" : "Mouse", "hg38" : "Human"}

# Load Predictions, Alu Annotations, and Labels

In [3]:
def get_alu_intersect_file_chr2():
    """Return the path of the BED file listing chr2 test windows that overlap an Alu.

    The file is produced outside this notebook (see make_repeat_files.sh).
    Roughly:
        awk '$1 == "chr2"' [RepeatMasker Alu file] > rmsk_alus_chr2.bed
        bedtools intersect -a [test bed file] -b rmsk_alus_chr2.bed -u -sorted > chr2_alus_intersect.bed

    It is expected to hold the test-set *windows* that intersect an Alu,
    not the Alu annotations themselves, because the model works on windows
    of a fixed size. The file is shared by all TFs.
    """
    alu_windows_path = ROOT + "data/hg38/chr2_alus_intersect.bed"
    return alu_windows_path


def get_preds_file(tf, train_species, test_species):
    """Return the path of the saved predictions array for one model/test pairing.

    The file is created by the 0_generate_predictions notebook.

    Parameters:
        tf: TF shorthand name (an entry of `tfs`).
        train_species: key of the training setup (e.g. "mm10", "DA", "hg38").
        test_species: species whose test set was predicted on.

    Returns:
        Path string of the form
        <ROOT>/model_out/<tf>_<train_species>-trained_<test_species>-test.preds.npy
    """
    # ROOT already ends with "/"; strip it before appending "/model_out/" so the
    # result never contains "//" and is correct whether or not ROOT has the
    # trailing separator.
    model_out_dir = ROOT.rstrip("/") + "/model_out/"
    return model_out_dir + tf + "_" + train_species + "-trained_" + test_species + "-test.preds.npy"


def load_average_test_set_preds(test_species = "hg38"):
    """Load saved test-set predictions and average them across axis 1.

    Parameters:
        test_species: species whose test set the predictions were made on.

    Returns:
        Nested mapping preds_dict[train_species][tf] -> 1-D array holding the
        mean over axis 1 of the loaded predictions array (presumably one
        column per independent model run -- confirm against the notebook that
        writes the files). Combinations whose file cannot be loaded are
        reported on stdout and left out of the mapping.
    """
    # takes a while to run.
    preds_dict = defaultdict(lambda : dict())

    # loop over every training setup in all_trainspecies and every TF
    for train_species in all_trainspecies:
        for tf in tfs:
            print("=== " + tf + ", " + train_species + "-trained ===")

            preds_file = get_preds_file(tf, train_species, test_species)
            try:
                preds = np.load(preds_file)
                preds_dict[train_species][tf] = np.mean(preds, axis = 1)
            except Exception as err:
                # Best-effort: skip this combination, but say why. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit still stop this long-running loop.
                print("Could not load preds file:", preds_file)
                print("  reason:", repr(err))

    return preds_dict


def get_test_bed_file(tf, species):
    """Return the path of the chr2 test-set BED file for a TF and species.

    The file is TF-specific: its last column is expected to hold the
    binding label of each window.
    """
    species_data_dir = ROOT + "data/" + species + "/"
    return species_data_dir + tf + "/chr2.bed"


def get_test_labels(tf, species = "hg38"):
    """Read the binding label of every window in a TF's test-set BED file.

    Parameters:
        tf: TF shorthand name.
        species: species whose test set is read.

    Returns:
        numpy array with one integer per line of the file, taken from the
        last whitespace-separated field of that line.
    """
    bed_path = get_test_bed_file(tf, species)
    binding_labels = []
    with open(bed_path) as bed_handle:
        for record in bed_handle:
            # the label is the final column of each BED record
            binding_labels.append(int(record.split()[-1]))
    return np.array(binding_labels)
        

def get_all_test_labels_and_indices(preds_dict, species = "hg38"):
    """Load binding labels for every TF and locate the bound windows.

    Parameters:
        preds_dict: nested mapping preds_dict[train_species][tf] -> predictions
            array; only the "mm10" entry is consulted, and only for its length.
        species: species whose test-set labels are loaded.

    Returns:
        (labels, bound_indices): two dicts keyed by TF. labels[tf] is the label
        array truncated to the length of the predictions; bound_indices[tf] is
        the array of positions where the truncated label is non-zero.

    Raises:
        KeyError: if preds_dict has no "mm10" predictions for some TF.
    """
    labels = dict()
    bound_indices = dict()
    for tf in tfs:
        # load binding labels from bed file for the requested species
        # (previously the species argument was dropped here, so the default
        # species was always read regardless of what the caller asked for)
        labels_for_tf = get_test_labels(tf, species)
        # assuming the mm10 entry exists, which training setup is used doesn't
        # matter here: the preds array length should be the same in all cases
        len_to_truncate_by = preds_dict["mm10"][tf].shape[0]
        # truncate binding labels to be multiple of batch size, like preds
        labels[tf] = labels_for_tf[:len_to_truncate_by]
        # get indices where label is bound
        bound_indices[tf] = np.nonzero(labels[tf])[0]
    return labels, bound_indices

# Process Data To Calculate Alu Overlap

In [8]:
def get_fp_fn_indices(preds_dict, bound_indices, species1 = "mm10", species2 = "hg38"):
    """Find false-positive and false-negative window indices for two models.

    A window counts as predicted-bound when its prediction is > 0.5.

    Parameters:
        preds_dict: nested mapping preds_dict[train_species][tf] -> 1-D array
            of predictions.
        bound_indices: mapping tf -> array of indices of truly bound windows.
        species1, species2: keys of the two training setups to evaluate.

    Returns:
        (fp_indices, fn_indices): nested defaultdicts indexed as [tf][train_species],
        each holding a list of window indices. False positives keep the order
        of the predicted-bound indices; false negatives keep the order of
        bound_indices[tf].
    """
    fp_indices = defaultdict(lambda : defaultdict(lambda : []))
    fn_indices = defaultdict(lambda : defaultdict(lambda : []))

    for tf in tfs:
        truly_bound = np.asarray(bound_indices[tf])
        for train_species in [species1, species2]:
            print(tf, train_species)
            # using a threshold of 0.5 here to decide what is predicted as bound
            predicted_as_bound_indices = np.nonzero(preds_dict[train_species][tf] > 0.5)[0]

            # Vectorized membership tests: `index in ndarray` inside a Python
            # loop rescans the whole array for every index.
            is_false_positive = ~np.isin(predicted_as_bound_indices, truly_bound)
            is_false_negative = ~np.isin(truly_bound, predicted_as_bound_indices)

            fp_indices[tf][train_species] = list(predicted_as_bound_indices[is_false_positive])
            fn_indices[tf][train_species] = list(truly_bound[is_false_negative])

    return fp_indices, fn_indices


def get_site_indices(preds_dict, bound_indices, species1 = "mm10", species2 = "hg38"):
    """Sort test-set windows into shared and species1-specific error categories.

    For every TF the result holds four sets of window indices:
        "bothFP": false positives of both models
        "bothFN": false negatives of both models
        "mFP":    species1 false positives whose species1 prediction exceeds
                  the species2 prediction by more than 0.5
        "mFN":    species1 false negatives whose species2 prediction exceeds
                  the species1 prediction by more than 0.5

    The 0.5 gap is the definition of a "differentially predicted site" used
    to call an error specific to the species1 model.
    """
    fp_indices, fn_indices = get_fp_fn_indices(preds_dict, bound_indices, species1, species2)

    site_indices = defaultdict(lambda : dict())

    for tf in tfs:
        fp_first = set(fp_indices[tf][species1])
        fn_first = set(fn_indices[tf][species1])
        categories = site_indices[tf]

        # errors made by both models
        categories["bothFP"] = fp_first & set(fp_indices[tf][species2])
        categories["bothFN"] = fn_first & set(fn_indices[tf][species2])

        # windows where the two models disagree by more than 0.5
        first_preds = preds_dict[species1][tf]
        second_preds = preds_dict[species2][tf]
        overpredicted_by_first = set(np.nonzero(first_preds - second_preds > 0.5)[0])
        underpredicted_by_first = set(np.nonzero(second_preds - first_preds > 0.5)[0])

        # errors of the species1 model that the species2 model does not share
        categories["mFP"] = fp_first & overpredicted_by_first
        categories["mFN"] = fn_first & underpredicted_by_first

    return site_indices


def get_bed_start_coords_fast(filename):
    """Return the second column (window start coordinate) of a BED file.

    The file is read as headerless, tab-separated text. Later analysis
    assumes the coordinates are sorted, as in
    `sort -k1,1 -k2,2n $bed_file`; no sorting is done here.
    """
    bed_table = pd.read_csv(filename, sep='\t', header=None)
    start_column = bed_table[1]
    return np.array(start_column)


def is_lista_in_listb_sorted(list_a, list_b):
    """Flag, for each element of list_a, whether it also occurs in list_b.

    Returns a numpy array with len(list_a) entries; entry i is True when
    list_a[i] is in list_b and False otherwise. The function is NOT
    symmetric in its arguments.

    Both inputs are walked once in order, so they must be sorted the same
    way, and list_b must be a subset of list_a: an element of list_b that is
    smaller than the current element of list_a trips the assertion.
    """
    membership = []
    cursor = 0  # position of the next unmatched element of list_b
    for candidate in list_a:
        if cursor >= len(list_b) or list_b[cursor] > candidate:
            # list_b is exhausted, or its next element lies further ahead
            membership.append(False)
            continue
        assert list_b[cursor] == candidate
        membership.append(True)
        cursor += 1
    return np.array(membership)

   
def get_alu_site_indices(test_species = "hg38"):
    """Return the set of test-set window indices whose window overlaps an Alu.

    test_species only selects which test-set BED file is read; the Alu
    intersect file comes from get_alu_intersect_file_chr2() regardless.
    Windows are matched by start coordinate alone.
    """
    # start coords of the windows known to overlap an Alu
    # (taken from the pre-computed Alu intersect file)
    alu_window_starts = get_bed_start_coords_fast(get_alu_intersect_file_chr2())

    # start coords of every test-set window; any TF's file will do,
    # since the labels are not used here
    test_window_starts = get_bed_start_coords_fast(get_test_bed_file(tfs[0], test_species))

    # boolean flag per test window: does it appear among the Alu windows?
    overlaps_alu = is_lista_in_listb_sorted(test_window_starts, alu_window_starts)

    # keep only the positions flagged True
    return set(np.nonzero(overlaps_alu)[0])


def calc_alu_fracs_from_site_overlap(bound_indices, labels, site_subset_indices, alu_indices):
    """Compute the fraction of windows overlapping an Alu for several site categories.

    Parameters:
        bound_indices: mapping tf -> array of indices of bound windows.
        labels: mapping tf -> array of binding labels (0 = unbound).
        site_subset_indices: mapping tf -> {category name -> set of window indices}.
        alu_indices: set of indices of windows that overlap an Alu.

    Returns:
        Nested mapping alu_fracs[tf][category] -> fraction in [0, 1], with the
        categories "bound", "unbound" and every key of site_subset_indices[tf].
        A category with no windows gets np.nan.
    """
    alu_fracs = defaultdict(lambda : dict())

    for tf in tfs:
        tf_labels = np.asarray(labels[tf])
        bound_set = set(bound_indices[tf])

        # Bound site Alu fraction = (# bound sites overlapping an Alu) / (# bound sites)
        num_bound_sites = len(bound_indices[tf])
        if num_bound_sites > 0:
            alu_fracs[tf]["bound"] = len(bound_set.intersection(alu_indices)) / num_bound_sites
        else:
            alu_fracs[tf]["bound"] = np.nan

        # alu_indices may refer to windows past the end of the label array
        # (labels can be shorter than the full window list). Only windows that
        # have a label can be called unbound, so drop the rest; otherwise the
        # numerator would count windows the denominator cannot.
        num_labeled_sites = len(tf_labels)
        alu_with_label = {index for index in alu_indices if index < num_labeled_sites}

        num_unbound_sites = int(np.count_nonzero(tf_labels == 0))
        # Unbound site Alu fraction = (# unbound sites overlapping an Alu) / (# unbound sites)
        # where (# unbound sites overlapping an Alu) = (# labeled alu-overlap sites not in bound site set)
        if num_unbound_sites > 0:
            alu_fracs[tf]["unbound"] = len(alu_with_label.difference(bound_set)) / num_unbound_sites
        else:
            alu_fracs[tf]["unbound"] = np.nan

        # for each of the specific categories of sites we're interested in...
        # (e.g. "false positives", "mouse-model false negatives")
        for site_type, site_set in site_subset_indices[tf].items():
            num_sites = len(site_set)
            if num_sites > 0:
                # fraction of sites in this category that overlap an Alu
                alu_fracs[tf][site_type] = len(site_set.intersection(alu_indices)) / num_sites
            else:
                alu_fracs[tf][site_type] = np.nan

    return alu_fracs


def generate_alu_fracs_for_table():
    """Run the whole pipeline and return the Alu-overlap fractions for both tables.

    Returns:
        (alu_fracs, alu_fracs_DA): nested dictionaries of Alu-overlap fractions
        per TF and site category. The first compares the mouse ("mm10") models
        with the human ("hg38") models, the second compares the
        domain-adaptive mouse ("DA") models with the human models.
    """
    human_test_preds = load_average_test_set_preds()
    labels, bound_indices = get_all_test_labels_and_indices(human_test_preds)
    alu_indices = get_alu_site_indices()

    def fracs_versus_human(mouse_model_key):
        # categorize sites for this mouse-side model against the human model,
        # then measure the Alu overlap of each category
        site_subsets = get_site_indices(human_test_preds, bound_indices,
                                        mouse_model_key, "hg38")
        return calc_alu_fracs_from_site_overlap(bound_indices, labels,
                                                site_subsets,
                                                alu_indices)

    alu_fracs = fracs_versus_human("mm10")
    alu_fracs_DA = fracs_versus_human("DA")
    return alu_fracs, alu_fracs_DA


# Run the full pipeline once at cell level; the two results are the inputs
# of the print_table calls in the cells below.
alu_fracs, alu_fracs_DA = generate_alu_fracs_for_table()

=== CTCF, mm10-trained ===
=== CEBPA, mm10-trained ===
=== Hnf4a, mm10-trained ===
=== RXRA, mm10-trained ===
=== CTCF, DA-trained ===
=== CEBPA, DA-trained ===
=== Hnf4a, DA-trained ===
=== RXRA, DA-trained ===
=== CTCF, hg38-trained ===
=== CEBPA, hg38-trained ===
=== Hnf4a, hg38-trained ===
=== RXRA, hg38-trained ===
CTCF mm10
CTCF hg38
CEBPA mm10
CEBPA hg38
Hnf4a mm10
Hnf4a hg38
RXRA mm10
RXRA hg38
CTCF DA
CTCF hg38
CEBPA DA
CEBPA hg38
Hnf4a DA
Hnf4a hg38
RXRA DA
RXRA hg38


In [9]:
def print_table(alu_fracs, header = None, row_order = None, caption = None,
                col_order = None, label = "Tab:01"):
    """Print a fully LaTeX-formatted table of Alu-overlap percentages.

    There is one row per TF and one column per site category; the last
    column of every row is set in bold.

    Parameters:
        alu_fracs: nested mapping alu_fracs[tf][category] -> fraction in [0, 1].
        header: LaTeX header row (cells joined by " & "). Defaults to the
            six-category header matching the default col_order.
        row_order: TF shorthand names (entries of `tfs`) in row order.
            Defaults to `tfs`.
        caption: optional caption text; when given, a \\captionof line is printed.
        col_order: category keys in column order. Defaults to
            ["bound", "bothFN", "mFN", "unbound", "bothFP", "mFP"]. The tabular
            column spec is fixed at six data columns, so a different number
            of columns also needs a matching change to that spec.
        label: LaTeX label attached to the caption. Pass a distinct label for
            each table that ends up in the same document.
    """
    if header is None:
        header = "TF & Bound & FN (Both Models) & FN (Mouse Only) & Unbound & FP (Both Models) & FP (Mouse Only)"
    # The column order is resolved independently of the header; previously it
    # was only defined when no header was passed, so any custom header failed
    # with a NameError.
    if col_order is None:
        col_order = ["bound", "bothFN", "mFN", "unbound", "bothFP", "mFP"]
    if row_order is None:
        row_order = tfs

    print(r'\begin{table*}{')
    print(r'\centerline{ \begin{tabular}{@{}c|ccc|ccc@{}}\toprule')
    print(header + r' \\\midrule')

    last_row_num = len(row_order) - 1
    for row_num, row_key in enumerate(row_order):
        row = [alu_fracs[row_key][col] for col in col_order]
        row_as_str = ["%0.1f" % (100 * num) + r'\%' for num in row]
        if row_as_str:
            row_as_str[-1] = r'\textbf{' + row_as_str[-1] + r'}'
        tf_fancy_name = tfs_latex_names[tfs.index(row_key)]
        # decide by position, not by identity of the key string, whether this
        # row closes the table
        row_end = r' \\\bottomrule' if row_num == last_row_num else r' \\'
        print(tf_fancy_name + " & " + " & ".join(row_as_str) + row_end)

    print(r'\end{tabular}}}{}')
    if caption is not None:
        print(r'\captionof{table}{' + caption + r'\label{' + label + r'}}')
    print(r'\end{table*}')

In [10]:
# Caption and table for the mouse-model vs. human-model comparison.
caption1 = (
    r'Percent of windows overlapping an \textit{Alu} element, '
    r'for various categories of genomic windows from the held-out test set. \textit{Alu} '
    r'elements dominate the false positives unique to the mouse models. FPs: false '
    r'positives. FNs: false negatives. See Methods for more details on site categorization.'
)

print_table(alu_fracs, caption = caption1)

\begin{table*}{
\centerline{ \begin{tabular}{@{}c|ccc|ccc@{}}\toprule
TF & Bound & FN (Both Models) & FN (Mouse Only) & Unbound & FP (Both Models) & FP (Mouse Only) \\\midrule
CTCF & 12.6\% & 12.8\% & 9.9\% & 21.3\% & 10.0\% & \textbf{78.6\%} \\
CEBPA & 18.3\% & 11.1\% & 0.0\% & 21.3\% & 22.9\% & \textbf{84.8\%} \\
HNF4A & 13.6\% & 10.4\% & 8.0\% & 21.3\% & 16.9\% & \textbf{95.1\%} \\
RXRA & 13.7\% & 10.6\% & 5.5\% & 21.4\% & 20.3\% & \textbf{97.4\%} \\\bottomrule
\end{tabular}}}{}
\captionof{table}{Percent of windows overlapping an \textit{Alu} element, for various categories of genomic windows from the held-out test set. \textit{Alu} elements dominate the false positives unique to the mouse models. FPs: false positives. FNs: false negatives. See Methods for more details on site categorization.\label{Tab:01}}
\end{table*}


In [11]:
# Caption and table for the domain-adaptive mouse-model vs. human-model comparison.
# NOTE(review): "The fraction ... have decreased" should probably read
# "has decreased"; the saved output of this cell was produced from an older
# caption text, so re-run the cell before copying the table; and the printed
# label is the same one used for the first table.
caption2 = (
    r'Percent of windows overlapping an \textit{Alu} element when '
    r'domain-adaptive mouse models are compared to human models (compare to Table 1). '
    r'The fraction of mouse-model-unique false positives overlapping \textit{Alu} elements '
    r'(right-most column) have decreased notably for all TFs. '
    r'FPs: false positives. FNs: false negatives.'
)

print_table(alu_fracs_DA, caption = caption2)

\begin{table*}{
\centerline{ \begin{tabular}{@{}c|ccc|ccc@{}}\toprule
TF & Bound & FN (Both Models) & FN (Mouse Only) & Unbound & FP (Both Models) & FP (Mouse Only) \\\midrule
CTCF & 12.6\% & 13.5\% & 13.7\% & 21.3\% & 9.0\% & \textbf{28.8\%} \\
CEBPA & 18.3\% & 16.8\% & 0.0\% & 21.3\% & 21.9\% & \textbf{49.5\%} \\
HNF4A & 13.6\% & 14.8\% & 13.7\% & 21.3\% & 14.0\% & \textbf{34.3\%} \\
RXRA & 13.7\% & 17.7\% & 10.7\% & 21.4\% & 15.8\% & \textbf{58.7\%} \\\bottomrule
\end{tabular}}}{}
\captionof{table}{Percent of windows overlapping an Alu element when domain-adaptive mouse models are compared to human models (compare to Table 1). The fraction of mouse-model-unique false positives overlapping Alu elements (right-most column) have decreased drastically for all TFs. FPs: false positives. FNs: false negatives.\label{Tab:01}}
\end{table*}
