In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# TF identifiers as they are spelled in the training log filenames
tfs = ["CTCF", "CEBPA", "Hnf4a", "RXRA"]
# display labels for the same TFs, listed in the same order as `tfs`
tfs_latex_names = ["CTCF", "CEBPα", "HNF4α", "RXRα"]

# genome identifiers of the species that models were trained on
all_trainspecies = ["mm10", "hg38"]
# genome identifier -> species name to show in plots
model_names_dict = {
    "mm10" : "Mouse",
    "hg38" : "Human",
}
# genome identifier -> matplotlib color used for that species' lines
colors_dict = {
    "mm10" : "tab:blue",
    "hg38" : "tab:orange",
}

# shared styling constants for the plotting functions
DOT_SIZE = 5
ALPHA = 0.5             # line transparency (replicate runs are overlaid)
AXIS_SIZE = 11          # base font size for axis labels
AX_OFFSET = 0.02
TF_TWINAX_OFFSET = 0.35 # how far left of the axes the TF name label sits
FIG_SIZE_UNIT = 5
# figure sizes are (width, height): two columns vs. one column of subplots
FIG_SIZE_2_by_4 = (FIG_SIZE_UNIT, FIG_SIZE_UNIT * 2)
FIG_SIZE_1_by_4 = (FIG_SIZE_UNIT / 2, FIG_SIZE_UNIT * 2)

In [3]:
log_out_root = "/users/kcochran/projects/domain_adaptation/logs/training/"

def get_model_files(species, tf, DA = False, runs = 5):
    # Build the list of training log paths for one (species, TF) pair,
    # one path per replicate run, numbered 1 through `runs`.
    # Paths look like <log_out_root><BM_|DA_><species>_<tf>_run<i>.log,
    # where the "DA_" tag is used when DA is truthy and "BM_" otherwise.
    # The logs themselves are written by run_training.sh and
    # run_DA_training.sh; get_both_species_auprcs() below describes
    # the content expected inside them.
    # Nothing is checked on disk here -- only the names are generated.

    model_tag = "DA_" if DA else "BM_"
    stem = "".join([log_out_root, model_tag, species, "_", tf, "_run"])

    filenames = []
    for run_number in range(1, runs + 1):
        filenames.append(stem + str(run_number) + ".log")
    return filenames

In [4]:
def get_both_species_auprcs(model_out_filename):
    # This function reads in the info stored in a single log file.
    # The log files contain a lot of info/junk, but we only care
    # about the auPRCs for each epoch. Those can be found on lines
    # that start with "auPRC:" and are listed in the file in order
    # of epoch. Each epoch, both the source species' and target
    # species' auPRCs are listed, with the target species' auPRC
    # listed first.
    # The source species is taken to be human if "hg38" appears
    # anywhere in the given path, and mouse otherwise.
    # This function returns a tuple of two lists: list 1 is the
    # auPRCs across each epoch when the model was evaluated on
    # mouse data; list 2 is the auPRCs across each epoch when the
    # model was evaluated on human data.
    # Raises ValueError (from float()) if an "auPRC:" line does not
    # hold a number, and AssertionError if the two lists end up with
    # different lengths (i.e. an odd number of auPRC lines).

    if "hg38" in model_out_filename:
        source, target = "hg38", "mm10"
    else:
        source, target = "mm10", "hg38"

    auprcs = {source : [], target : []}
    auprc_lines_seen = 0
    with open(model_out_filename) as f:
        # assuming auPRCs are listed by epoch
        # with target species listed first, then source species
        for line in f:
            if not line.startswith("auPRC:"):
                continue
            # Take everything after the first colon; float() ignores
            # surrounding whitespace, so the value may be separated
            # from the label by a tab, spaces, or nothing at all.
            auprc = float(line.split(":", 1)[1])
            species = target if auprc_lines_seen % 2 == 0 else source
            auprcs[species].append(auprc)
            auprc_lines_seen += 1

    assert len(auprcs["mm10"]) == len(auprcs["hg38"]), \
        "Unpaired auPRC line in " + str(model_out_filename)
    return auprcs["mm10"], auprcs["hg38"]

In [21]:
def auprc_lineplot(model_files, tf_name, plot_index, plot_y_index, y_max = None, title = None):
    # This function draws a single line subplot on the current pyplot
    # axes (to be called repeatedly, once per subplot, after the caller
    # has selected the subplot with plt.sca()).
    # Arguments:
    #     - model_files: paths for the log files for all model runs,
    #           for a given TF (output of get_model_files())
    #     - tf_name: name of the TF to display on the plot
    #           (only drawn for subplots in the leftmost column)
    #     - plot_index: the top-to-bottom index of the subplot
    #     - plot_y_index: the left-to-right index of the subplot
    #     - y_max: optional, manually set the top limit of the y-axis
    #           for this subplot (default auto-detects max of data plotted);
    #           a margin of 0.02 is always added on top
    #     - title: optional, the column title drawn above the subplot when
    #           it is in the top row (default is "Mouse-trained Models" for
    #           the leftmost column and "Human-trained Models" otherwise)
    # Returns the [mouse line, human line] handles of the last model run
    # plotted, for use in a legend (an empty list if model_files is empty).
    # NOTE: the bottom row is assumed to be plot_index 3 and the x ticks
    # are fixed at epochs 1, 5, 10, 15.
    ax = plt.gca()

    # First, load in the auPRCs across all epochs for all model runs
    # Keep track of the max auPRC so the y-axis limits can be set properly
    # Also keep track of the legend handles to use later
    max_auprc_so_far = 0
    legend_handles = []
    for model_out_file in model_files:
        mm10_auprcs, hg38_auprcs = get_both_species_auprcs(model_out_file)
        l1 = ax.plot(range(1, len(mm10_auprcs) + 1), mm10_auprcs,
                     c = colors_dict["mm10"], alpha = ALPHA)[0]
        l2 = ax.plot(range(1, len(hg38_auprcs) + 1), hg38_auprcs,
                     c = colors_dict["hg38"], alpha = ALPHA)[0]
        legend_handles = [l1, l2]

        # keep track of max auPRC seen so far
        max_auprc_so_far = max([max_auprc_so_far] + mm10_auprcs + hg38_auprcs)

    # if we are plotting a subplot in the leftmost column...
    if plot_y_index == 0:
        # label the y-axis with "auPRC"
        ax.set_ylabel("auPRC", fontsize = AXIS_SIZE)

        # add the TF name label to the far left of the plot
        # (the twin is built from `ax` itself rather than from plt.gca(),
        # so it does not depend on which axes pyplot considers current)
        ax2 = ax.twinx()
        ax2.spines["left"].set_position(("axes", 0 - TF_TWINAX_OFFSET))
        ax2.yaxis.set_label_position('left')
        ax2.yaxis.set_ticks_position('none')
        ax2.set_yticklabels([])
        ax2.set_ylabel(tf_name, fontsize = AXIS_SIZE + 2)
    else:
        ax.set_yticklabels([])

    # set top limit of y-axis
    if y_max is None:
        y_max = max_auprc_so_far
    ax.set_ylim(0, y_max + 0.02)

    # if we're drawing a subplot in the top row of the plot...
    if plot_index == 0:
        # draw an invisible extra axis on top of the subplot
        # (again twinned from `ax` explicitly, not from plt.gca(), which
        # may no longer be `ax` once the TF-name twin has been created)
        ax3 = ax.twiny()
        ax3.spines["top"].set_position(("axes", 1))
        ax3.set_xticklabels([])
        ax3.set_xticks([])

        # add a column title (not the x-axis, just coded hackily);
        # the default depends on whether this is the left column or not
        if title is None:
            if plot_y_index == 0:
                title = "Mouse-trained Models"
            else:
                title = "Human-trained Models"
        ax3.set_xlabel(title, fontsize = AXIS_SIZE + 1)

    # if you're drawing a subplot in the bottom row of the plot...
    if plot_index == 3:
        # add an x-axis for epochs
        ax.set_xlabel("Epochs", fontsize = AXIS_SIZE)
        ax.set_xticks([1, 5, 10, 15])
        ax.set_xticklabels([1, 5, 10, 15])
    else:
        # otherwise don't label the x-axis
        ax.set_xticks([])

    return legend_handles
        
        
def get_y_max(list_of_file_lists):
    # Find the largest auPRC found in any of the given log files
    # (never less than 0), so that every subplot in a row can be
    # drawn with the same y-axis scale.
    # `list_of_file_lists` holds one list of log file paths per
    # subplot in the row; each path is one replicate run.
    overall_max = 0
    for log_files in list_of_file_lists:
        for log_file in log_files:
            # the reader returns (mouse auPRCs, human auPRCs);
            # fold each list into the running max in that order
            for species_auprcs in get_both_species_auprcs(log_file):
                overall_max = max([overall_max, *species_auprcs])
    return overall_max
    

def generate_all_auprc_plots(tf_list, save_file = None):
    # This function draws Figure 2: a grid with one row of subplots per
    # TF in tf_list, where the left column shows the mouse-trained models
    # and the right column shows the human-trained models.
    # Arguments:
    #     - tf_list: TF shorthand names (as used in the log filenames),
    #           one per row, in top-to-bottom order
    #     - save_file: optional, path to save the figure to; if None,
    #           the figure is shown with plt.show() instead
    # NOTE: auprc_lineplot() only draws the "Epochs" x-axis on the row
    # with index 3, so the layout is designed for exactly four TFs.
    # NOTE: this resets matplotlib's rcParams to their defaults.

    # materialize once so tf_list may be any iterable
    tf_list = list(tf_list)

    # For each TF and each species, retrieve the model log files
    mm10_trained_files = {tf : get_model_files("mm10", tf, False) for tf in tf_list}
    hg38_trained_files = {tf : get_model_files("hg38", tf, False) for tf in tf_list}

    # Map each TF shorthand to its plot-acceptable name, so rows are
    # labeled correctly even if tf_list is a subset/reordering of tfs
    display_names = dict(zip(tfs, tfs_latex_names))

    plt.rcParams.update(plt.rcParamsDefault)

    # squeeze = False keeps `ax` two-dimensional even for a single row
    fig, ax = plt.subplots(nrows = len(tf_list), ncols = 2, figsize = FIG_SIZE_2_by_4,
                           squeeze = False,
                           gridspec_kw = {'hspace': 0.08, 'wspace': 0.08})

    legend_handles = []
    for plot_index, tf in enumerate(tf_list):  # iterating over rows of subplots
        # both subplots in a row share the same y-axis limit
        y_max = get_y_max([mm10_trained_files[tf], hg38_trained_files[tf]])
        tf_label = display_names.get(tf, tf)

        # draw the left subplot in this row
        plt.sca(ax[plot_index][0])
        legend_handles = auprc_lineplot(mm10_trained_files[tf],
                                        tf_label,
                                        plot_index, 0, y_max = y_max)
        # draw the right subplot in this row
        plt.sca(ax[plot_index][1])
        _ = auprc_lineplot(hg38_trained_files[tf],
                           tf_label,
                           plot_index, 1, y_max = y_max)

    # add a legend below all the subplots
    if len(legend_handles) > 0:
        fig.legend(legend_handles,
                   ["Mouse Validation Set", "Human Validation Set"],
                  loc = "lower center", ncol = 2,
                  bbox_to_anchor=[0.5, 0.012])

    if save_file is None:
        plt.show()
    else:
        plt.savefig(save_file, bbox_inches = 'tight', pad_inches = 0)
        
        

def generate_all_auprc_plots_DA(tf_list, save_file = None):
    # Same as generate_all_auprc_plots(), but for DA models: a single
    # column of subplots, one row per TF in tf_list, drawn from the
    # "DA_mm10_..." log files.
    # Arguments:
    #     - tf_list: TF shorthand names (as used in the log filenames),
    #           one per row, in top-to-bottom order
    #     - save_file: optional, path to save the figure to; if None,
    #           the figure is shown with plt.show() instead
    # NOTE: auprc_lineplot() only draws the "Epochs" x-axis on the row
    # with index 3, so the layout is designed for exactly four TFs.
    # NOTE: this resets matplotlib's rcParams to their defaults.

    # materialize once so tf_list may be any iterable
    tf_list = list(tf_list)

    mm10_trained_files = {tf : get_model_files("mm10", tf, DA = True) for tf in tf_list}

    # Map each TF shorthand to its plot-acceptable name, so rows are
    # labeled correctly even if tf_list is a subset/reordering of tfs
    display_names = dict(zip(tfs, tfs_latex_names))

    plt.rcParams.update(plt.rcParamsDefault)

    # squeeze = False keeps `ax` a 2-D array (rows x 1 column) no matter
    # how many TFs are plotted
    fig, ax = plt.subplots(nrows = len(tf_list), ncols = 1, figsize = FIG_SIZE_1_by_4,
                           squeeze = False,
                           gridspec_kw = {'hspace': 0.08, 'wspace': 0.08})

    legend_handles = []
    for plot_index, tf in enumerate(tf_list):
        y_max = get_y_max([mm10_trained_files[tf]])

        plt.sca(ax[plot_index][0])
        legend_handles = auprc_lineplot(mm10_trained_files[tf],
                                        display_names.get(tf, tf),
                                        plot_index, 0, y_max = y_max)

    # add a legend below all the subplots
    if len(legend_handles) > 0:
        fig.legend(legend_handles,
                   ["Mouse Validation Set", "Human Validation Set"],
                  loc = "lower center", ncol = 1,
                  bbox_to_anchor=[0.75, 0.005])

    if save_file is None:
        plt.show()
    else:
        plt.savefig(save_file, bbox_inches = 'tight', pad_inches = 0)
    
    

In [11]:
# Draw the two-column auPRC-over-epochs figure (mouse-trained vs.
# human-trained "BM_" models) for every TF in `tfs` and write it to a PDF.
# Because save_file is given, plt.savefig() is used rather than plt.show().
# NOTE(review): assumes the ../plots directory already exists -- confirm.
generate_all_auprc_plots(tfs, save_file = "../plots/auprc_over_epochs_mm10_hg38.pdf")

In [23]:
# Draw the single-column auPRC-over-epochs figure for the mouse-trained
# "DA_" models for every TF in `tfs` and write it to a PDF.
# Because save_file is given, plt.savefig() is used rather than plt.show().
# NOTE(review): assumes the ../plots directory already exists -- confirm.
generate_all_auprc_plots_DA(tfs, save_file = "../plots/auprc_over_epochs_mm10_hg38_DA.pdf")