In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict

In [2]:
### Global variables
# Project root directory. NOTE(review): hard-coded absolute path with a trailing
# slash; the path helpers below build file names by concatenating onto it.
ROOT = "/users/kcochran/projects/domain_adaptation_nosexchr/"

# TF names as used in file/directory names, and the matching display names
# (same order) used when printing the LaTeX table.
tfs = ["CTCF", "CEBPA", "Hnf4a", "RXRA"]
tfs_latex_names = ["CTCF", "CEBPA", "HNF4A", "RXRA"]

# Genome assemblies the models were trained on, and their display names.
all_trainspecies = ["mm10", "hg38"]
model_names_dict = {"mm10" : "Mouse", "hg38" : "Human"}

# Load Predictions and Labels

In [3]:
def get_preds_file(tf, train_species, test_species):
    """Return the path of the saved predictions array for one TF / model / test set.

    The file itself is created by the 0_generate_predictions notebook.

    Args:
        tf: TF name as used in file names (an entry of `tfs`).
        train_species: assembly the model was trained on (e.g. "mm10").
        test_species: assembly the test set comes from (e.g. "hg38").
    """
    filename_pieces = [tf, "_", train_species, "-trained_", test_species, "-test.preds.npy"]
    return ROOT + "/model_out/" + "".join(filename_pieces)


def load_average_test_set_preds(test_species):
    """Load saved test-set predictions and average them across model runs.

    Takes a while to run.

    Args:
        test_species: assembly of the test set (e.g. "hg38").

    Returns:
        Nested dict: preds_dict[train_species][tf] -> 1D array holding the mean
        over axis 1 of the saved predictions array (presumably one column per
        independent model run -- confirm against the generating notebook).
        Entries whose file could not be loaded are missing; a message is
        printed for each of them instead of raising.
    """
    preds_dict = defaultdict(dict)

    # loop over mouse-trained and human-trained models
    for train_species in all_trainspecies:
        for tf in tfs:
            print("=== " + tf + ", " + train_species + "-trained ===")

            # load predictions for all 5 independent model runs
            preds_file = get_preds_file(tf, train_species, test_species)
            try:
                preds_dict[train_species][tf] = np.mean(np.load(preds_file), axis = 1)
            except Exception as err:
                # Best-effort loading is kept, but only for ordinary errors:
                # a bare `except:` would also swallow KeyboardInterrupt, making
                # this slow loop impossible to interrupt, and hid the cause.
                print("Could not load preds file:", preds_file,
                      "(" + type(err).__name__ + ": " + str(err) + ")")

    return preds_dict


# Run-averaged predictions of every model in all_trainspecies on the hg38 test set.
avg_preds_human_test = load_average_test_set_preds("hg38")

=== CTCF, mm10-trained ===
=== CEBPA, mm10-trained ===
=== Hnf4a, mm10-trained ===
=== RXRA, mm10-trained ===
=== CTCF, hg38-trained ===
=== CEBPA, hg38-trained ===
=== Hnf4a, hg38-trained ===
=== RXRA, hg38-trained ===


In [4]:
def get_test_bed_file(tf, species):
    """Return the path of the chr2 test-set file for one TF and species.

    The file should be in BED format and is specific to each TF: its last
    column should contain the binding label for each window.
    """
    tf_data_dir = ROOT + "data/" + species + "/" + tf
    return tf_data_dir + "/chr2.bed"


def get_test_labels(tf, species = "hg38"):
    """Read the binding labels of every test-set example for one TF and species.

    Each line of the test bed file is split on whitespace and its last field
    is parsed as an integer label.

    Returns:
        1D numpy array of integer labels, in file order.
    """
    bed_path = get_test_bed_file(tf, species)
    binding_labels = []
    with open(bed_path) as bed_file:
        for bed_line in bed_file:
            fields = bed_line.split()
            binding_labels.append(int(fields[-1]))
    return np.array(binding_labels)
        

def get_all_test_labels_and_indices(preds_dict, species = "hg38"):
    """Load binding labels for every TF and locate the bound examples.

    Args:
        preds_dict: nested dict preds_dict[train_species][tf] -> predictions
            array; only the length of the "mm10" entry is used here.
        species: assembly whose test-set labels are loaded.

    Returns:
        (labels, bound_indices): two dicts keyed by TF. labels[tf] is the label
        array truncated to the length of the predictions; bound_indices[tf] is
        the array of positions where the label is nonzero.
    """
    labels = dict()
    bound_indices = dict()
    for tf in tfs:
        # load binding labels from bed file; `species` must be forwarded,
        # otherwise labels are always read for the hg38 default
        labels_for_tf = get_test_labels(tf, species)
        # assuming the mm10 entry exists, train species doesn't matter here:
        # shape of preds array should be the same in all cases
        len_to_truncate_by = preds_dict["mm10"][tf].shape[0]
        # truncate binding labels to be multiple of batch size, like preds
        labels[tf] = labels_for_tf[:len_to_truncate_by]
        # get indices where label is bound
        bound_indices[tf] = np.nonzero(labels[tf])[0]
    return labels, bound_indices

# Per-TF label arrays and bound-site indices (default species: hg38).
labels, bound_indices = get_all_test_labels_and_indices(avg_preds_human_test)

## Find FPs, FNs, etc.

In [5]:
def get_fp_fn_indices(preds_dict, bound_indices):
    """Find false-positive and false-negative test-set indices for each model.

    A site counts as "predicted bound" when its prediction is > 0.5.

    Args:
        preds_dict: nested dict preds_dict[train_species][tf] -> 1D array of
            predictions.
        bound_indices: dict tf -> array of indices whose label is bound.

    Returns:
        (fp_indices, fn_indices): nested defaultdicts indexed as [tf][train_species],
        each holding a list of indices. FPs are predicted-bound indices that are
        not in bound_indices[tf]; FNs are bound indices not predicted as bound.
    """
    fp_indices = defaultdict(lambda : defaultdict(lambda : []))
    fn_indices = defaultdict(lambda : defaultdict(lambda : []))

    for tf in tfs:
        truly_bound = np.asarray(bound_indices[tf])
        for train_species in all_trainspecies:
            print(tf, train_species)
            # using a threshold of 0.5 here to decide what is predicted as bound
            predicted_as_bound_indices = np.nonzero(preds_dict[train_species][tf] > 0.5)[0]

            # Vectorized membership tests replace per-element `index in array`
            # scans, which were quadratic in the number of sites. Boolean
            # masking keeps the original order of the indices.
            is_true_positive = np.isin(predicted_as_bound_indices, truly_bound)
            fp_indices[tf][train_species] = list(predicted_as_bound_indices[~is_true_positive])

            was_predicted = np.isin(truly_bound, predicted_as_bound_indices)
            fn_indices[tf][train_species] = list(truly_bound[~was_predicted])

    return fp_indices, fn_indices


def get_site_indices(preds_dict, bound_indices):
    """Group test-set sites into both-model and mouse-model error categories.

    Categories stored per TF:
      "bothFP" / "bothFN": sites that are FPs / FNs for both the mm10-trained
          and the hg38-trained model.
      "mFP" / "mFN": mm10-model FPs / FNs that are also "differentially
          predicted", i.e. the mm10 prediction is more than 0.5 above (FP) or
          more than 0.5 below (FN) the hg38 prediction.

    Returns:
        Nested dict: site_indices[tf][category] -> set of indices.
    """
    fp_indices, fn_indices = get_fp_fn_indices(preds_dict, bound_indices)

    site_indices = defaultdict(dict)

    for tf in tfs:
        mouse_preds = preds_dict["mm10"][tf]
        human_preds = preds_dict["hg38"][tf]

        mouse_fps = set(fp_indices[tf]["mm10"])
        mouse_fns = set(fn_indices[tf]["mm10"])

        # errors shared by both models
        site_indices[tf]["bothFP"] = mouse_fps & set(fp_indices[tf]["hg38"])
        site_indices[tf]["bothFN"] = mouse_fns & set(fn_indices[tf]["hg38"])

        # sites where the two models disagree by more than 0.5
        mouse_much_higher = set(np.nonzero(mouse_preds - human_preds > 0.5)[0])
        mouse_much_lower = set(np.nonzero(human_preds - mouse_preds > 0.5)[0])

        # mouse-model errors that are also differentially predicted
        site_indices[tf]["mFP"] = mouse_fps & mouse_much_higher
        site_indices[tf]["mFN"] = mouse_fns & mouse_much_lower

    return site_indices


# Per-TF sets of test-set indices for the bothFP / bothFN / mFP / mFN categories.
site_subset_indices = get_site_indices(avg_preds_human_test, bound_indices)

CTCF mm10
CTCF hg38
CEBPA mm10
CEBPA hg38
Hnf4a mm10
Hnf4a hg38
RXRA mm10
RXRA hg38


# Load Repeat Annotations

In [6]:
# Repeat types that get a sub-table; each is matched against column 11 of the
# RepeatMasker file when filtering annotations.
REPEAT_TYPES = ["DNA", "LINE", "Low_complexity", "LTR", "Simple_repeat", "SINE", "Unknown"]
# Removed due to < 500 instances in the test set: ["RC", "Retroposon", "RNA", "rRNA", "Satellite", "scRNA", "srpRNA", "tRNA"]

def get_rmsk_file():
    """Return the path of the hg38 RepeatMasker annotation file.

    The file is downloaded by an earlier script.
    """
    rmsk_path = ROOT + "data/hg38/rmsk.bed"
    return rmsk_path


def read_and_filter_rmsk_file(repeat_name, is_subfam = False, test_chrom = "chr2"):
    """Return sorted (start, end) coordinates of one kind of repeat annotation.

    Reads the tab-separated, header-less RepeatMasker file, keeps the rows
    whose name matches `repeat_name` and whose chromosome (column 5) equals
    `test_chrom`, and returns their coordinates (columns 6 and 7).

    All test-set examples are assumed to be on one chromosome, so the
    chromosome column is not returned.

    Args:
        repeat_name: value to match.
        is_subfam: if True match against column 10, otherwise column 11
            (presumably repeat name vs. repeat class in the UCSC rmsk
            layout -- TODO confirm).
        test_chrom: chromosome the test set lives on.

    Returns:
        numpy array built from the list of (start, end) pairs, ordered by start.
    """
    rmsk_table = pd.read_csv(get_rmsk_file(), sep = "\t",
                             usecols = [5, 6, 7, 10, 11], header = None)

    # the two modes differ only in which column holds the name to match
    name_column = 10 if is_subfam else 11
    matching_rows = rmsk_table[rmsk_table[name_column] == repeat_name]
    matching_rows = matching_rows[matching_rows[5] == test_chrom]

    coord_pairs = list(zip(matching_rows[6], matching_rows[7]))
    coord_pairs.sort(key = lambda pair : pair[0])
    return np.array(coord_pairs)


def get_repeat_and_test_set_overlap(list_a, list_b):
    """Flag, for every window in list_a, whether it overlaps any interval in list_b.

    Similar to bedtools intersect, but the result is a yes/no per window.

    Assumes everything is on the same chromosome and that both inputs are
    sequences of (start, stop) pairs. The scan keeps a single forward-moving
    cursor into list_b, so it relies on both inputs being ordered by start
    coordinate.

    Returns:
        numpy array with one entry per element of list_a (True = overlap).
    """
    num_intervals = len(list_b)
    overlaps = []
    cursor = 0
    for window_start, window_end in list_a:
        found_overlap = False
        while cursor < num_intervals:
            interval_start, interval_end = list_b[cursor]
            # interval begins at or after the window's end: no overlap, and
            # the cursor stays put for the next window
            if interval_start > window_end - 1:
                break
            # interval ends at or before the window's start: it is not
            # needed by this window, so move past it
            if interval_end <= window_start:
                cursor += 1
                continue
            found_overlap = True
            break
        overlaps.append(found_overlap)
    assert len(overlaps) == len(list_a)
    return np.array(overlaps)


def get_test_bed_coords(species = "hg38"):
    """Load the (start, end) coordinates of every test-set window.

    The test set is assumed to be a single chromosome, and later analysis
    assumes the coordinates are sorted, as in
    `sort -k1,1 -k2,2n $bed_file`.

    Returns:
        2-column numpy array read from columns 1 and 2 of the bed file.
    """
    # labels are not read here, so any TF's file will do; use the first one
    any_tf = tfs[0]
    bed_path = get_test_bed_file(any_tf, species)
    window_coords = pd.read_csv(bed_path, sep = "\t", usecols = [1, 2], header = None)
    return window_coords.to_numpy()


def get_all_repeat_labels_and_indices():
    """Compute repeat overlap of every test-set window for each repeat type.

    Returns:
        (repeat_labels, repeat_indices): dicts keyed by repeat type.
        repeat_labels[type] is a boolean array with one entry per test-set
        window; repeat_indices[type] is the set of window indices that
        overlap an annotation of that type.
    """
    window_coords = get_test_bed_coords()
    repeat_labels = {}
    repeat_indices = {}

    for repeat_type in REPEAT_TYPES:
        print(repeat_type)
        annotation_coords = read_and_filter_rmsk_file(repeat_type)
        num_annotations = len(annotation_coords)
        # only repeat types with more than 500 annotations are allowed,
        # so we don't get incorrectly extreme results
        assert num_annotations > 500, (repeat_type, len(annotation_coords))

        overlap_flags = get_repeat_and_test_set_overlap(window_coords, annotation_coords)
        repeat_labels[repeat_type] = overlap_flags
        repeat_indices[repeat_type] = set(np.nonzero(overlap_flags)[0])

    return repeat_labels, repeat_indices


# Per-repeat-type overlap flags and overlapping window indices for the test set.
repeat_labels, repeat_indices = get_all_repeat_labels_and_indices()

DNA
LINE
Low_complexity
LTR
Simple_repeat
SINE
Unknown


In [7]:
def _fraction_or_nan(numerator, denominator):
    # Return numerator / denominator, or NaN when the denominator is zero.
    if denominator > 0:
        return numerator / denominator
    return np.nan


def calc_repeat_fracs_from_site_overlap(bound_indices, labels, site_subset_indices, repeat_indices):
    """Compute the fraction of sites overlapping one repeat type, per site category.

    Args:
        bound_indices: dict tf -> indices of bound sites.
        labels: dict tf -> numpy array of 0/1 binding labels.
        site_subset_indices: nested dict [tf][site_type] -> set of indices
            (e.g. "false positives", "mouse-model false negatives").
        repeat_indices: set of test-set indices overlapping the repeat type.

    Returns:
        Nested dict repeat_fracs[tf][category] -> fraction, with categories
        "bound", "unbound" and every key of site_subset_indices[tf]. The
        results are per TF because the site categorizations are TF-specific.
        A category with no sites gets NaN rather than a division error.
    """
    repeat_fracs = defaultdict(dict)

    for tf in tfs:
        bound_set = set(bound_indices[tf])

        # Bound site repeat fraction = (# bound sites overlapping repeat) / (# bound sites)
        num_bound_sites_with_repeat = len(bound_set.intersection(repeat_indices))
        repeat_fracs[tf]["bound"] = _fraction_or_nan(num_bound_sites_with_repeat,
                                                     len(bound_indices[tf]))

        # this arithmetic reverses binary-numeric labels (0 is now 1, 1 is now 0);
        # np.sum avoids a slow element-by-element Python-level sum
        num_unbound_sites = int(np.sum(labels[tf] * -1 + 1))

        # Unbound site repeat fraction = (# unbound sites overlapping repeat) / (# unbound sites)
        # where (# unbound sites overlapping repeat) = (# repeat-overlap sites in test set not in bound site set)
        num_unbound_sites_with_repeat = len(repeat_indices.difference(bound_set))
        repeat_fracs[tf]["unbound"] = _fraction_or_nan(num_unbound_sites_with_repeat,
                                                       num_unbound_sites)

        # for each of the specific categories of sites we're interested in...
        for site_type, category_sites in site_subset_indices[tf].items():
            # fraction of sites in this category that overlap the repeat type
            num_sites_with_repeat = len(category_sites.intersection(repeat_indices))
            repeat_fracs[tf][site_type] = _fraction_or_nan(num_sites_with_repeat,
                                                           len(category_sites))

    return repeat_fracs


def generate_all_repeat_fracs_for_table(bound_indices, labels, site_subset_indices, repeat_indices):
    """Build the dictionary consumed by print_full_table().

    Each key is a repeat type from REPEAT_TYPES; each value is the nested
    dictionary returned by calc_repeat_fracs_from_site_overlap() for that
    repeat type, i.e. per-TF fractions of sites overlapping it (sets of sites
    such as "bound site" and "mouse-model false positive" differ for each TF).
    """
    return {
        repeat_type: calc_repeat_fracs_from_site_overlap(
            bound_indices, labels, site_subset_indices, repeat_indices[repeat_type]
        )
        for repeat_type in REPEAT_TYPES
    }



# Repeat-overlap fractions for every repeat type, TF and site category.
all_repeat_fracs = generate_all_repeat_fracs_for_table(bound_indices, labels,
                                                       site_subset_indices,
                                                       repeat_indices)

In [8]:
def fix_repeat_name(repeat_name):
    """Make a repeat name safe to print in the LaTeX table.

    Underscores would mess up the LaTeX formatting, so a name containing any
    has them replaced by spaces and each word title-cased. Names without an
    underscore are returned unchanged.
    """
    if "_" not in repeat_name:
        return repeat_name
    return repeat_name.replace("_", " ").title()
    

def print_full_table(all_repeat_fracs, header = None, row_order = None):
    """Print a fully LaTeX-formatted table of repeat-overlap percentages.

    The table is made of one sub-table per repeat type in REPEAT_TYPES; the
    rows of each sub-table are TFs and the columns are site categories.
    The last column of every row is printed in bold.

    Args:
        all_repeat_fracs: dict repeat_type -> {tf -> {category -> fraction}}.
        header: LaTeX header row (cells joined by " & "); a default matching
            the six category columns is used when None.
        row_order: TF keys in the order to print them; defaults to `tfs`.
    """
    # Defined unconditionally: it used to be set only when `header` was None,
    # so passing a custom header raised a NameError further down.
    col_order = ["bound", "bothFN", "mFN", "unbound", "bothFP", "mFP"]
    if header is None:
        header = "TF & Bound & FN (Both Models) & FN (Mouse Only) & Unbound & FP (Both Models) & FP (Mouse Only)"

    print(r'\begin{table*}{')
    print(r'\setlength{\tabcolsep}{0.8em}')
    print(r'\centering \begin{tabular}{@{}ccccccc@{}}\toprule')
    print(header + r' \\\midrule')

    if row_order is None:
        row_order = tfs
    last_row = row_order[-1]

    last_repeat_type = REPEAT_TYPES[-1]
    for repeat_type in REPEAT_TYPES:
        print(r'\multicolumn{7}{c}{\textbf{' + fix_repeat_name(repeat_type) + r'}} \\\midrule')
        repeat_fracs = all_repeat_fracs[repeat_type]

        for row_key in row_order:
            row = [repeat_fracs[row_key][col] for col in col_order]
            row_as_str = ["%0.1f" % (100 * num) + r'\%' for num in row]
            row_as_str[-1] = r'\textbf{' + row_as_str[-1] + r'}'
            tf_fancy_name = tfs_latex_names[tfs.index(row_key)]

            # Compare by value: `is not` tests object identity, which is not
            # guaranteed to hold for equal strings.
            if row_key != last_row:
                row_end = r' \\'
            elif repeat_type == last_repeat_type:
                # very last row of the whole table
                row_end = r' \\\bottomrule'
            else:
                # last row of a sub-table
                row_end = r' \\\midrule'
            print(tf_fancy_name + " & " + " & ".join(row_as_str) + row_end)

    print(r'\end{tabular}}{}')
    # The space after "specific to" was missing, which fused two words in the caption.
    print(r'\captionof{table}{Percent of windows overlapping various RepeatMasker-defined ' + \
            'repeat elements, for different categories of genomic windows from the held-out test set. ' + \
            'Only RepeatMasker repeat classes with at least 500 distinct annotations within the test ' + \
            'set are shown. FPs: false positives. FNs: false negatives. Mouse Only: specific to ' + \
            r'mouse-trained models. See Methods for more details on site categorization.}')
    print(r'\end{table*}')
    

# Emit the LaTeX table with the default header and TF row order.
print_full_table(all_repeat_fracs)

\begin{table*}{
\setlength{\tabcolsep}{0.8em}
\centering \begin{tabular}{@{}ccccccc@{}}\toprule
TF & Bound & FN (Both Models) & FN (Mouse Only) & Unbound & FP (Both Models) & FP (Mouse Only) \\\midrule
\multicolumn{7}{c}{\textbf{DNA}} \\\midrule
CTCF & 10.1\% & 11.4\% & 7.3\% & 11.4\% & 8.9\% & \textbf{9.0\%} \\
CEBPA & 12.3\% & 10.4\% & 8.3\% & 11.3\% & 13.0\% & \textbf{9.2\%} \\
HNF4A & 10.7\% & 12.0\% & 11.3\% & 11.4\% & 9.5\% & \textbf{9.0\%} \\
RXRA & 10.1\% & 11.7\% & 8.8\% & 11.4\% & 10.0\% & \textbf{9.4\%} \\\midrule
\multicolumn{7}{c}{\textbf{LINE}} \\\midrule
CTCF & 18.3\% & 22.5\% & 21.3\% & 37.6\% & 17.8\% & \textbf{31.6\%} \\
CEBPA & 25.6\% & 26.3\% & 25.0\% & 37.6\% & 29.0\% & \textbf{32.3\%} \\
HNF4A & 21.0\% & 25.3\% & 26.3\% & 37.6\% & 21.5\% & \textbf{30.5\%} \\
RXRA & 21.0\% & 27.9\% & 22.1\% & 37.8\% & 22.1\% & \textbf{33.2\%} \\\midrule
\multicolumn{7}{c}{\textbf{Low Complexity}} \\\midrule
CTCF & 2.5\% & 1.0\% & 2.6\% & 1.9\% & 4.0\% & \textbf{1.5\%} \\
CEBPA & 1.