In [1]:
import os
# Restrict TensorFlow/Keras to GPU 4. Set here, before TensorFlow is imported
# in the next cell, so the restriction is in place when CUDA initializes.
os.environ.update(CUDA_VISIBLE_DEVICES = "4")

In [2]:
import numpy as np
import keras
import tensorflow
from matplotlib import pyplot as plt
import matplotlib as mpl
import pandas as pd
import pybedtools as pbt
from collections import defaultdict
%matplotlib inline

Using TensorFlow backend.


In [3]:
# Global variables
import random

ROOT = "/users/kcochran/projects/domain_adaptation/"

# TF shorthand names (used in file paths) and their plot-acceptable labels;
# the two lists are index-aligned.
tfs = ["CTCF", "CEBPA", "Hnf4a", "RXRA"]
tfs_latex_names = ["CTCF", "CEBPα", "HNF4α", "RXRα"]

# Shorthand for every model type to include in plots, and the
# plot-acceptable name for each one.
all_trainspecies = ["mm10", "DA", "hg38", "NS"]
model_names_dict = dict(mm10 = "Mouse",
                        hg38 = "Human",
                        DA = "Mouse+DA",
                        NS = "Human-NS")

# Constants controlling plot appearance
DOT_SIZE = 5
ALPHA = 0.03
AXIS_SIZE = 11
AX_OFFSET = 0.02
TF_TWINAX_OFFSET = 0.35
FIG_SIZE_UNIT = 5
# (width, height) for the 4-row x 2-column figure and the 2-row x 1-column figure
FIG_SIZE_2_by_4 = (FIG_SIZE_UNIT, 2 * FIG_SIZE_UNIT)
FIG_SIZE_1_by_2 = (FIG_SIZE_UNIT / 2, FIG_SIZE_UNIT)
# Bound scatterplots draw a random 1-in-BOUND_SUBSAMPLE_RATE sample of sites
BOUND_SUBSAMPLE_RATE = 4

# Seed Python's RNG so the random.sample() subsampling of bound sites is reproducible
random.seed(1234)

# Every SKIP-th ***UNBOUND*** example is used in model evaluation (set to None
# to evaluate all of them, at the cost of speed). Bound sites are sparse, so
# SKIP never applies to them. The same SKIP value must be used *everywhere*
# in this notebook, or predictions and labels will not line up.
SKIP = 200

# Which saved checkpoint to load; early stopping models are no longer used.
MODEL_TYPE = "best"

In [4]:
# needed to load DA models

from flipGradientTF import GradientReversal

def custom_loss(y_true, y_pred):
    """Binary cross-entropy computed only over entries whose label is not -1.

    Entries where y_true == -1 are dropped from both tensors before the loss
    is computed. Required as a custom object when loading the DA models.
    """
    keep = tensorflow.not_equal(y_true, -1)
    masked_pred = tensorflow.boolean_mask(y_pred, keep)
    masked_true = tensorflow.boolean_mask(y_true, keep)
    return keras.losses.binary_crossentropy(masked_true, masked_pred)

In [6]:
from keras.utils import Sequence
from seqdataloader.batchproducers.coordbased.core import Coordinates
from seqdataloader.batchproducers.coordbased.coordstovals.fasta import PyfaidxCoordsToVals

# Re-declared so this cell also works when run on its own.
ROOT = "/users/kcochran/projects/domain_adaptation/"

# Reference genome FASTA per species; windows are converted to one-hot
# sequences from these files.
GENOMES = dict(
    mm10 = "/users/kcochran/genomes/mm10_no_alt_analysis_set_ENCODE.fasta",
    hg38 = "/users/kcochran/genomes/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta",
)


def get_test_bed_file(species, tf):
    """Return the path of the chr2 test-set BED file for a species and TF.

    The file lists chromosome, start and end for every test example, with the
    binding label (1 = bound, 0 = unbound) in the last column. Because the
    labels live in this file, it is specific to a TF.
    """
    return "".join([ROOT, "data/", species, "/", tf, "/chr2.bed"])


class UnboundTestGenerator(Sequence):
    """Keras Sequence over the ***UNBOUND*** windows of an hg38 test BED file.

    Each batch is the one-hot encoding (via PyfaidxCoordsToVals) of up to
    `batchsize` consecutive unbound windows, in file order; the last batch
    may be smaller.

    If you don't care about testing the model on all examples and want to
    speed things up, you can set `skip` to not None; then only every skip-th
    unbound example is used in model evaluation. Be careful -- be consistent
    with the SKIP value you choose everywhere in this notebook!
    """

    def __init__(self, batchsize, val_file, skip = None):
        self.valfile = val_file
        self.get_steps(skip, batchsize)
        # the plots we're making are for the human test data only
        assert "hg38" in val_file
        self.converter = PyfaidxCoordsToVals(GENOMES["hg38"])
        # batchsize here is just how many examples to evaluate at once,
        # not the training batchsize; go as high as your GPU can fit
        self.batchsize = batchsize
        self.get_unbound_coords(skip)


    def get_steps(self, skip, batchsize):
        """Set self.steps to the number of batches covering every kept unbound example.

        Keras predict_generator needs to know this. Ceiling division is used
        so that a final, partially filled batch is evaluated instead of dropped
        (and so that fewer than `batchsize` examples still gives one step).
        """
        with open(self.valfile) as f:
            # last column is "0" for unbound examples
            num_unbound = sum(1 for line in f if line.rstrip().split()[-1] == "0")
        if skip is None:
            num_kept = num_unbound
        else:
            # get_unbound_coords keeps unbound examples 0, skip, 2*skip, ...,
            # i.e. ceil(num_unbound / skip) of them
            num_kept = -(-num_unbound // skip)
        self.steps = -(-num_kept // batchsize)


    def __len__(self):
        return self.steps


    def get_unbound_coords(self, skip):
        """Load coordinates of the (optionally every skip-th) UNBOUND examples into memory."""
        coords = []
        line_count = 0
        with open(self.valfile) as f:
            for line in f:
                line_split = line.rstrip().split()
                # if example is bound, value in last column is 1;
                # else it is 0
                if line_split[-1] == "0":  # if unbound
                    if skip is None or line_count % skip == 0:
                        coords.append(line_split[:3])
                    line_count += 1
        self.coords = [Coordinates(c[0], int(c[1]), int(c[2])) for c in coords]


    def __getitem__(self, batch_index):
        # convert a batch's worth of coordinates into one-hot sequences
        batch = self.coords[batch_index * self.batchsize : (batch_index + 1) * self.batchsize]
        return self.converter(batch)

        
class BoundTestGenerator(Sequence):
    """Keras Sequence over the ***BOUND*** windows of an hg38 test BED file.

    Each batch is the one-hot encoding (via PyfaidxCoordsToVals) of up to
    `batchsize` consecutive bound windows, in file order; the last batch may
    be smaller. Bound sites are never subsampled.
    """


    def __init__(self, batchsize, val_file):
        self.valfile = val_file
        self.get_steps(batchsize)
        # the plots we're making are for the human test data only
        assert "hg38" in val_file
        self.converter = PyfaidxCoordsToVals(GENOMES["hg38"])
        # batchsize here is just how many examples to evaluate at once,
        # not the training batchsize; go as high as your GPU can fit
        self.batchsize = batchsize
        self.get_bound_coords()


    def get_steps(self, batchsize):
        """Set self.steps to the number of batches covering every BOUND example.

        Keras predict_generator needs to know this. Ceiling division keeps the
        final partial batch, and still gives one step when there are fewer bound
        sites than `batchsize`.
        """
        with open(self.valfile) as f:
            # last column is "1" for bound examples
            num_bound = sum(1 for line in f if line.rstrip().split()[-1] == "1")
        self.steps = -(-num_bound // batchsize)


    def __len__(self):
        return self.steps


    def get_bound_coords(self):
        """Load coordinates of every BOUND example (last column "1") into memory."""
        with open(self.valfile) as f:
            coords_tmp = [line.split()[:3] for line in f if line.rstrip().split()[-1] == "1"]
            self.coords = [Coordinates(c[0], int(c[1]), int(c[2])) for c in coords_tmp]
            assert len(coords_tmp) > 0


    def __getitem__(self, batch_index):
        # convert a batch's worth of coordinates into one-hot sequences
        batch = self.coords[batch_index * self.batchsize : (batch_index + 1) * self.batchsize]
        assert len(batch) > 0
        return self.converter(batch)
    

def get_preds_batched_fast(model, batch_size, bound, val_file, skip = None):
    """Predict on the hg38 test windows of `val_file` with one Keras model.

    bound: if True, predict on bound windows only; otherwise on unbound
        windows only, subsampled by `skip` (keep this equal to SKIP everywhere).
        The split simplifies downstream analysis: Alu plots use unbound sites only.
    batch_size: examples per predict step; as large as the hardware allows.
    Returns the squeezed prediction array.
    """
    generator = (BoundTestGenerator(batch_size, val_file) if bound
                 else UnboundTestGenerator(batch_size, val_file, skip))
    raw_preds = model.predict_generator(generator,
                                        use_multiprocessing = True,
                                        workers = 4, verbose = 1)
    return np.squeeze(raw_preds)


def get_avg_preds_on_seqs(models, bound, val_file, DA = False, skip = None):
    """Average, across `models`, the per-example predictions on the test data.

    DA: set for domain-adaptive models, whose output holds two predictions
        (binding task first, species discriminator second); only the binding
        task is returned.
    """
    per_model_preds = []
    for model in models:
        per_model_preds.append(get_preds_batched_fast(model, 1024, bound, val_file, skip = skip))
    avg_preds = np.array(per_model_preds).mean(axis = 0)
    return avg_preds[0] if DA else avg_preds


In [None]:
### File and Model Loading


def get_alu_intersect_file_chr2():
    """Return the path of the BED file of hg38 chr2 test windows that overlap an Alu.

    The file holds test-set windows (already the size the model expects),
    not the annotated Alus themselves. It is not specific to a TF. See
    make_repeat_files.sh; roughly:
        awk '$1 == "chr2"' [repeatmasker alu file] > rmsk_alus_chr2.bed
        bedtools intersect -a [hg38 test BED file] -b rmsk_alus_chr2.bed -u -sorted > chr2_alus_intersect.bed
    """
    return "{}data/hg38/chr2_alus_intersect.bed".format(ROOT)


def get_model_file(tf, train_species, run = 1, model_type = MODEL_TYPE):
    """Return the path of the saved model for a TF, model type and run.

    tf: TF shorthand name (directory under ROOT/models/).
    train_species: "mm10", "hg38", "DA" (mouse-trained domain-adaptive model)
        or "NS" (human-trained model in basic_model_nosines).
    run: run number (anything int() accepts).
    model_type: "best" (best model across training epochs; default), "end"
        (model after the last epoch) or "earlystop".
    If several files match, the one with the latest os.path.getctime is
    returned (and its path is printed).

    Raises ValueError for a non-integer run or an unknown model_type, and
    FileNotFoundError when no model file matches.
    """
    try:
        run = int(run)
    except (TypeError, ValueError):
        raise ValueError("Error: You need to pass in a run number.")

    # (training-data directory, model subdirectory) per model type.
    # Built from components rather than str.replace on the whole path, so a
    # ROOT or TF name containing "DA"/"NS" cannot be corrupted.
    species_dir, model_subdir = {
        "DA": ("mm10", "DA"),  # assuming all DA models are mouse-trained
        "NS": ("hg38", "basic_model_nosines"),
    }.get(train_species, (train_species, "basic_model"))
    model_file_prefix = ROOT + "/".join(["models", tf, species_dir + "_trained", model_subdir]) + "/"

    suffix_by_type = {"end": "_15E_end.model",  # leftover from early-stopping experiments
                      "earlystop": "_earlystop.model",
                      "best": "_best.model"}
    if model_type not in suffix_by_type:
        raise ValueError("Unknown model_type: " + repr(model_type))
    model_file_suffix = "_run" + str(run) + suffix_by_type[model_type]

    candidates = [model_file_prefix + f for f in os.listdir(model_file_prefix)
                  if f.endswith(model_file_suffix)]
    if not candidates:
        raise FileNotFoundError("No file ending in " + model_file_suffix + " in " + model_file_prefix)
    if len(candidates) == 1:
        return candidates[0]
    # several saved models for this run: use the most recent one
    latest_file = max(candidates, key = os.path.getctime)
    print(latest_file)
    return latest_file


def load_keras_model(model_file, DA = False):
    """Load a saved Keras model; DA models need their custom layer and loss registered."""
    if not DA:
        return keras.models.load_model(model_file)
    da_custom_objects = {"GradientReversal":GradientReversal,
                         "custom_loss":custom_loss}
    return keras.models.load_model(model_file, custom_objects = da_custom_objects)


def get_models_all_runs(tf, train_species):
    """Load the models of runs 1-5 for a TF and model type; returns a list of Keras models."""
    is_da = train_species == "DA"
    model_files = [get_model_file(tf, train_species, run) for run in range(1, 6)]
    return [load_keras_model(path, DA = is_da) for path in model_files]

In [5]:
### Alu functions

def get_window_starts(bed_file, skip = None):
    """Return the integer start coordinate (column 2) of lines in a BED file.

    skip is None: every line is kept, bound or unbound (like awk '{print $2}').
    skip is not None: only lines whose last column is "0" (unbound) are
        considered, and every skip-th of those is kept (the 0th, skip-th, ...).
    Bound lines are filtered out only in the second case. Meant for the
    analysis of UNBOUND windows overlapping Alus; modify before using it for
    bound sites.
    """
    starts = []
    unbound_seen = 0
    with open(bed_file) as bed:
        for row in bed:
            fields = row.rstrip().split()
            if skip is None:
                starts.append(int(fields[1]))
                continue
            if fields[-1] != "0":
                continue
            if unbound_seen % skip == 0:
                starts.append(int(fields[1]))
            unbound_seen += 1
    return starts


def matches_across_sorted_lists(list_a, list_b):
    """Return [item in list_b for item in list_a] for two ascending-sorted lists.

    Uses one merge-style pass (O(len(list_a) + len(list_b))) instead of
    repeated linear membership tests. The pointer into list_b is not advanced
    on a match, so repeated values in list_a are all reported as matches.
    """
    matches = []
    b_index = 0
    b_len = len(list_b)
    for a_item in list_a:
        # Entries of list_b smaller than a_item cannot match it, nor any later
        # (larger or equal) item of list_a, so skip past them for good.
        while b_index < b_len and list_b[b_index] < a_item:
            b_index += 1
        matches.append(bool(b_index < b_len and list_b[b_index] == a_item))
    return matches


def get_alu_intersect(tf, skip = None):
    """Return one boolean per hg38 test window: does it intersect an annotated Alu?

    The windows (and their order) are those produced by
    get_window_starts(test BED, skip), so `skip` must match the value used to
    generate the predictions these labels are paired with.

    Windows are identified by start coordinate alone, which is unique because
    every test example is on the same chromosome (chr2). The Alu-intersect
    file already lists exactly the test windows that overlap Alus (made with
    bedtools intersect), so the work is:
      1. read the (sorted) starts of the test windows,
      2. read the (sorted) starts of the Alu-overlapping windows,
      3. flag each window from 1 that also appears in 2, using a merge over
         the two sorted lists rather than "[start in #2 for start in #1]".
    """
    test_windows = get_window_starts(get_test_bed_file("hg38", tf), skip)      # step 1
    alu_windows = get_window_starts(get_alu_intersect_file_chr2(), None)       # step 2
    return matches_across_sorted_lists(test_windows, alu_windows)              # step 3
    
    

In [7]:
# This cell takes a while to run. To decrease time, increase SKIP.

# preds_dict["bound" | "unbound"][tf][train_species] -> run-averaged predictions
preds_dict = defaultdict(lambda: defaultdict(dict))

for tf in tfs:
    print("\n=====", tf, "=====\n")
    val_file = get_test_bed_file("hg38", tf)  # hg38 test data for this TF

    for train_species in all_trainspecies:
        is_da = train_species == "DA"
        print("Loading models...")
        models = get_models_all_runs(tf, train_species)

        print("Predicting on bound sequences with " + train_species + "-trained models...")
        preds_dict["bound"][tf][train_species] = get_avg_preds_on_seqs(
            models, True, val_file, DA = is_da)
        print("Predicting on unbound sequences with " + train_species + "-trained models...")
        preds_dict["unbound"][tf][train_species] = get_avg_preds_on_seqs(
            models, False, val_file, DA = is_da, skip = SKIP)

        # free the models (and the Keras graph) before loading the next set
        del train_species, models, is_da
        keras.backend.clear_session()
del tf


===== CTCF =====

Loading models...
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/basic_model/2020-08-12_16-21-33_run1_best.model
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/basic_model/2020-08-12_17-16-29_run2_best.model
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/basic_model/2020-08-12_18-11-29_run3_best.model
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/basic_model/2020-08-12_19-06-35_run4_best.model
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/basic_model/2020-08-12_20-01-19_run5_best.model
Predicting on bound sequences with mm10-trained models...
Predicting on unbound sequences with mm10-trained models...
Loading models...
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/DA/2020-08-13_16-49-49_run1_best.model
/users/kcochran/projects/domain_adaptation/models/CTCF/mm10_trained/DA/2020-08-15_01-13-00_run2_best.model
/users/kcochran/projects/domain_a



In [8]:
### Plot functions

def bound_scatterplot(model1_preds, model2_preds, tf_name,
                      plot_index, model_names):
    """Draw one bound-site scatterplot (left column of the figure) on the current axes.

    model1_preds / model2_preds: x- and y-axis predictions, one per bound site.
    tf_name: plot-acceptable TF name, written to the left of the subplot.
    plot_index: row of this subplot (0 .. len(tfs) - 1); the top row gets the
        "Bound Sites" column title and the bottom row gets the x-axis label.
    model_names: [x-axis model name, y-axis model name], plot-acceptable.
    """
    # Draw a random subsample of the sites so the plot is not overcrowded.
    pairs = list(zip(model1_preds, model2_preds))
    n_keep = int(len(model1_preds) / BOUND_SUBSAMPLE_RATE)
    sampled = random.sample(pairs, k = n_keep)
    xs = [x for x, _ in sampled]
    ys = [y for _, y in sampled]
    plt.scatter(xs, ys, alpha = ALPHA, s = DOT_SIZE, c = "#007DEA")

    # Fixed [0, 1] range (plus a margin) so every point is visible.
    lo, hi = 0 - AX_OFFSET, 1 + AX_OFFSET
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)
    plt.xticks([0, 0.5, 1])
    plt.yticks([0, 0.5, 1])

    plt.ylabel(model_names[1] + " Model Prediction", fontsize = AXIS_SIZE)
    if plot_index == len(tfs) - 1:
        # bottom row only; slightly smaller font for longer model names
        x_label_size = AXIS_SIZE - 1 if len(model_names[0]) > 5 else AXIS_SIZE
        plt.xlabel(model_names[0] + " Model Prediction", fontsize = x_label_size)

    # A twin y-axis pushed further left carries the TF name. Only bound plots
    # do this because they form the left column of the figure.
    tf_axis = plt.gca().twinx()
    tf_axis.spines["left"].set_position(("axes", 0 - TF_TWINAX_OFFSET))
    tf_axis.yaxis.set_label_position('left')
    tf_axis.yaxis.set_ticks_position('none')
    tf_axis.set_yticklabels([])
    tf_axis.set_ylabel(tf_name, fontsize = AXIS_SIZE + 2)

    if plot_index == 0:
        # column title above the top row
        title_axis = plt.gca().twiny()
        title_axis.spines["top"].set_position(("axes", 1))
        title_axis.set_xticklabels([])
        title_axis.set_xticks([])
        title_axis.set_xlabel("Bound Sites", fontsize = AXIS_SIZE + 2)
    
    
def unbound_scatterplot(model1_preds, model2_preds,
                        plot_index, model_names):
    """Draw one unbound-site scatterplot (right column of the figure) on the current axes.

    model1_preds / model2_preds: x- and y-axis predictions, one per unbound site.
    plot_index: row of this subplot (0 .. len(tfs) - 1); the top row gets the
        "Unbound Sites" column title and the bottom row gets the x-axis label.
    model_names: [x-axis model name, y-axis model name], plot-acceptable.
    """
    # Every point is drawn: unlike bound sites, unbound sites were already
    # subsampled with SKIP when predictions were made.
    plt.scatter(model1_preds, model2_preds, alpha = ALPHA, s = DOT_SIZE, c = "#D60242")

    lo, hi = 0 - AX_OFFSET, 1 + AX_OFFSET
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)
    plt.xticks([0, 0.5, 1])

    if plot_index == len(tfs) - 1:
        # bottom row only; slightly smaller font for longer model names
        x_label_size = AXIS_SIZE - 1 if len(model_names[0]) > 5 else AXIS_SIZE
        plt.xlabel(model_names[0] + " Model Prediction", fontsize = x_label_size)

    if plot_index == 0:
        # column title above the top row
        title_axis = plt.gca().twiny()
        title_axis.spines["top"].set_position(("axes", 1))
        title_axis.set_xticklabels([])
        title_axis.set_xticks([])
        title_axis.set_xlabel("Unbound Sites", fontsize = AXIS_SIZE + 2)

        
        
def generate_bound_unbound_scatters(preds_dict, train_species,
                                    save_files = False):
    """Draw the full bound/unbound scatterplot figure (Figures 4, 7, or 10).

    preds_dict: 3-layer dict indexed as preds_dict[site_type][tf][model_type],
        where site_type is "bound" or "unbound", tf is a name in `tfs`, and
        model_type is a key of model_names_dict ("mm10", "DA", "hg38", "NS").
    train_species: [x-axis model type, y-axis model type]; used to index
        layer 3 of preds_dict.
    save_files: if True, save the figure as PDF and PNG under ROOT + "plots/"
        instead of calling plt.show().
    """
    # the assertion message shows the offending value
    assert len(train_species) == 2, train_species

    # translate short-hand model type names to plot-acceptable names
    model_names = [model_names_dict[string] for string in train_species]

    # one row per TF; left column = bound sites, right column = unbound sites
    mpl.rcParams.update(mpl.rcParamsDefault)
    fig, ax = plt.subplots(nrows = len(tfs), ncols = 2, figsize = FIG_SIZE_2_by_4,
                           sharex = True, sharey = True,
                           gridspec_kw = {'hspace': 0.08, 'wspace': 0.13})

    for plot_index, tf in enumerate(tfs):
        plt.sca(ax[plot_index][0])
        bound_scatterplot(preds_dict["bound"][tf][train_species[0]],
                          preds_dict["bound"][tf][train_species[1]],
                          tfs_latex_names[plot_index], plot_index, model_names)

        plt.sca(ax[plot_index][1])
        unbound_scatterplot(preds_dict["unbound"][tf][train_species[0]],
                            preds_dict["unbound"][tf][train_species[1]],
                            plot_index, model_names)

    if not save_files:
        plt.show()
    else:
        out_prefix = ROOT + "plots/scatter_" + train_species[0] + "_" + train_species[1]
        plt.savefig(out_prefix + ".pdf", bbox_inches='tight', pad_inches = 0)
        plt.savefig(out_prefix + ".png", bbox_inches='tight', pad_inches = 0)
        

        
def alu_unbound_scatterplot(model1_preds, model2_preds, tf_name, plot_index, model_names):
    """Draw one scatterplot of unbound Alu-overlapping windows on the current axes.

    model1_preds / model2_preds: x- and y-axis predictions, one per window.
    tf_name: plot-acceptable TF name, written to the left of the subplot.
    plot_index: row within the column (0 .. len(tfs) // 2 - 1); the top row gets
        the column title and the bottom row gets the x-axis label.
    model_names: [x-axis model name, y-axis model name], plot-acceptable.
    """
    plt.scatter(model1_preds, model2_preds, alpha = ALPHA, s = DOT_SIZE, c = "#D60242")
    lo, hi = 0 - AX_OFFSET, 1 + AX_OFFSET
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)
    plt.xticks([0, 0.5, 1])
    plt.yticks([0, 0.5, 1])
    plt.ylabel(model_names[1] + " Model Prediction", fontsize = AXIS_SIZE)

    is_bottom_row = plot_index == len(tfs) // 2 - 1
    if is_bottom_row:
        x_label_size = AXIS_SIZE - 1 if len(model_names[0]) > 5 else AXIS_SIZE
        plt.xlabel(model_names[0] + " Model Prediction", fontsize = x_label_size)

    # a twin y-axis, pushed left of the normal y-axis label, carries the TF name
    tf_axis = plt.gca().twinx()
    tf_axis.spines["left"].set_position(("axes", 0 - TF_TWINAX_OFFSET))
    tf_axis.yaxis.set_label_position('left')
    tf_axis.yaxis.set_ticks_position('none')
    tf_axis.set_yticklabels([])
    tf_axis.set_ylabel(tf_name, fontsize = AXIS_SIZE + 2)

    if plot_index == 0:
        # column "title" above the top subplot
        title_axis = plt.gca().twiny()
        title_axis.spines["top"].set_position(("axes", 1))
        title_axis.set_xticklabels([])
        title_axis.set_xticks([])
        title_axis.set_xlabel("Unbound " + r"$\bf{Alus}$", fontsize = AXIS_SIZE + 2)
        
        
# run twice, to get 2 x 2 for 4 TFs
# once with tf_split_half = 1 and once with tf_split_half = 2
def generate_unbound_alu_scatters(preds_dict, train_species, tf_split_half = 1,
                                  skip = None, save_files = False):
    """Draw one half of the unbound-Alu scatterplot figure (Figures 5 and 11).

    Run twice, with tf_split_half = 1 (left column: first half of `tfs`) and
    tf_split_half = 2 (right column: second half), to build the full 2 x 2 plot.

    preds_dict: 3-layer dict indexed as preds_dict[site_type][tf][model_type],
        where site_type is "bound" or "unbound", tf is a name in `tfs`, and
        model_type is a key of model_names_dict ("mm10", "DA", "hg38", "NS").
    train_species: [x-axis model type, y-axis model type]; used to index
        layer 3 of preds_dict.
    skip: must equal the skip used to generate the unbound predictions
        (SKIP), otherwise the Alu labels do not line up with the predictions.
    save_files: if True, save PNG and PDF under ROOT + "plots/" (the same
        directory as the bound/unbound figures) instead of calling plt.show().

    Raises ValueError if tf_split_half is not 1 or 2.
    """
    if tf_split_half not in (1, 2):
        raise ValueError("tf_split_half must be 1 or 2, got " + repr(tf_split_half))

    model_names = [model_names_dict[string] for string in train_species]
    rows = len(tfs) // 2
    offset = 0 if tf_split_half == 1 else rows

    mpl.rcParams.update(mpl.rcParamsDefault)
    # one column of subplots, one row per TF in this half
    fig, ax = plt.subplots(nrows = rows, ncols = 1, figsize = FIG_SIZE_1_by_2,
                           sharex = True, gridspec_kw = {'hspace': 0.08})

    for plot_index, tf in enumerate(tfs[offset:offset + rows]):
        x_preds = np.array(preds_dict["unbound"][tf][train_species[0]])
        y_preds = np.array(preds_dict["unbound"][tf][train_species[1]])
        alu_labels = get_alu_intersect(tf, skip = skip)
        # The prediction generators may not cover a trailing partial batch, so
        # keep only the labels for windows that actually have predictions.
        assert len(alu_labels) >= len(x_preds), (
            "fewer Alu labels than predictions for " + tf + "; is skip consistent with SKIP?")
        alu_mask = np.array(alu_labels[:len(x_preds)], dtype = bool)
        plt.sca(ax[plot_index])
        alu_unbound_scatterplot(x_preds[alu_mask], y_preds[alu_mask],
                                tfs_latex_names[plot_index + offset], plot_index, model_names)

    if not save_files:
        plt.show()
    else:
        out_prefix = (ROOT + "plots/scatter_" + train_species[0] + "_" + train_species[1]
                      + "_alus_" + str(tf_split_half))
        plt.savefig(out_prefix + ".png", bbox_inches='tight', pad_inches = 0)
        plt.savefig(out_prefix + ".pdf", bbox_inches='tight', pad_inches = 0)
        
        
        

In [17]:
SAVE_FILES = True

In [18]:
generate_bound_unbound_scatters(preds_dict, train_species = ["mm10", "hg38"], save_files = SAVE_FILES)

In [19]:
generate_unbound_alu_scatters(preds_dict, train_species = ["mm10", "hg38"], tf_split_half = 1, skip = SKIP, save_files = SAVE_FILES)

In [20]:
generate_unbound_alu_scatters(preds_dict, train_species = ["mm10", "hg38"], tf_split_half = 2, skip = SKIP, save_files = SAVE_FILES)

In [21]:
generate_bound_unbound_scatters(preds_dict, train_species = ["DA", "hg38"], save_files = SAVE_FILES)

In [22]:
generate_unbound_alu_scatters(preds_dict, train_species = ["DA", "hg38"], tf_split_half = 1, save_files = SAVE_FILES)

In [23]:
generate_unbound_alu_scatters(preds_dict, train_species = ["DA", "hg38"], tf_split_half = 2, save_files = SAVE_FILES)

In [24]:
generate_bound_unbound_scatters(preds_dict, train_species = ["NS", "hg38"], save_files = SAVE_FILES)