In [1]:
import os
# Pin this notebook to GPU 6. This must run before torch initialises CUDA,
# which is why it sits in the very first cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "6"})

In [2]:
# Reference genome FASTA per species, keyed by assembly name.
GENOMES = {
    "mm10": "/users/kcochran/genomes/mm10_no_alt_analysis_set_ENCODE.fasta",
    "hg38": "/users/kcochran/genomes/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta",
}

# Project layout: peak/label BED files and bigWig signal tracks live under ROOT.
ROOT = "/users/kcochran/projects/domain_adaptation_nosexchr/"
PEAKS_DIR = f"{ROOT}data/"
BIGWIGS_DIR = f"{ROOT}profile_model_data/"

# Shorthand names for every training species whose models appear in the plots.
all_trainspecies = ["mm10", "hg38"]

# Display names for each training species' models.
model_names_dict = {
    "mm10": "Mouse-trained",
    "hg38": "Human-trained",
}

# TF names as used in directory paths ...
tfs = ["CTCF", "CEBPA", "Hnf4a", "RXRA"]
# ... and their upper-cased, plot-ready counterparts (same order).
tfs_latex_names = [tf_name.upper() for tf_name in tfs]

In [3]:
# Extra flank (bp) kept on each side of every window so the data can be jittered.
MAX_JITTER = 200
# Model input sequence length and output profile length, in bp.
INPUT_SEQ_LEN, OUTPUT_PROF_LEN = 2114, 1000

In [49]:
import torch
from model_arch import *
from attr_prior_utils import *   # Alex's code
from data_transforms import *
from generators import *
import math
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


# This is a modified version of the usual generator used
# for binary model training. It is now in Pytorch, and it
# processes data on the fly instead of all when the generator
# is initialized, since the test set is large.

class JITGenerator(Dataset):
    """Dataset that assembles whole batches of model inputs on the fly.

    This is a PyTorch version of the generator used for binary-model training.
    Instead of loading everything up front, it reads sequence and signal data
    only when a batch is requested, because the test set is large.

    Each item returned by ``__getitem__`` is one *batch*: up to ``batch_size``
    consecutive regions from ``self.peakfile``. Because items are already
    batches, wrap instances in a ``DataLoader`` with ``batch_size=1``.

    Args:
        filepaths_dict: every key/value pair is set as an attribute. The
            methods below read ``peakfile``, ``genome_file`` and, depending on
            the return flags, ``pos_bw``/``neg_bw`` and
            ``pos_control_bw``/``neg_control_bw``.
        seq_len: model input length. Sequence windows are expanded to
            ``seq_len + 2 * max_jitter``.
        profile_len: model output profile length. Profile windows are
            expanded to ``profile_len + 2 * max_jitter``.
        max_jitter: extra flank (bp) kept on each side for the optional
            jittering transform.
        batch_size: number of regions per item (set to the max the GPU holds).
        transform: optional callable applied to the list of numpy arrays
            before they are converted to tensors.
        return_labels: if True, append (profiles, logcounts) read from
            ``pos_bw``/``neg_bw``.
        return_controls: if True, append (profiles, logcounts) read from
            ``pos_control_bw``/``neg_control_bw``.
    """

    # One-hot rows in A, C, G, T order. N maps to all zeros, and convert()
    # uses .get() so any other character also maps to all zeros.
    letter_dict = {
        'a':[1,0,0,0],'c':[0,1,0,0],'g':[0,0,1,0],'t':[0,0,0,1],
        'n':[0,0,0,0],'A':[1,0,0,0],'C':[0,1,0,0],'G':[0,0,1,0],
        'T':[0,0,0,1],'N':[0,0,0,0]}

    def __init__(self, filepaths_dict,
                 seq_len,
                 profile_len,
                 max_jitter,
                 batch_size = 1024,  # set to max # GPU can hold
                 transform = None,
                 return_labels = True, return_controls = True):

        for key in filepaths_dict:
            setattr(self, key, filepaths_dict[key])

        self.prof_len = profile_len
        self.max_jitter = max_jitter
        self.transform = transform
        self.return_labels = return_labels
        self.return_controls = return_controls
        self.seq_len = seq_len
        self.batch_size = batch_size

        self.set_len()
        self.coords = self.get_coords()


    def __len__(self):
        """Number of batches (not regions) in the dataset."""
        return self.len


    def set_len(self):
        """Set ``self.len`` to ceil(#lines in ``self.peakfile`` / batch_size)."""
        with open(self.peakfile) as f:
            self.len = math.ceil(sum(1 for _ in f) / self.batch_size)


    def get_coords(self):
        """Load (chrom, start, end) windows for every region in ``self.peakfile``.

        Only the first three whitespace-separated columns (BED format) are
        used. Each region is re-centred and resized with ``expand_window`` to
        ``seq_len + 2 * max_jitter``, because the peak file may not already
        have the right window size. Strand is not considered.
        """
        with open(self.peakfile) as posf:
            coords_tmp = [line.split()[:3] for line in posf]  # expecting bed file format

        coords = []
        for coord in coords_tmp:
            chrom, start, end = coord[0], int(coord[1]), int(coord[2])
            window_start, window_end = expand_window(start, end,
                                                     self.seq_len + 2 * self.max_jitter)
            coords.append((chrom, window_start, window_end))  # no strand consideration
        return coords


    def get_profiles_and_logcounts(self, coords, pos_bw_file, neg_bw_file):
        """Read stranded signal profiles and their log-counts for ``coords``.

        Profile windows are resized to ``prof_len + 2 * max_jitter``. This is
        smaller than the input window, because of the model's receptive field
        and deconv kernel width.

        Returns:
            profiles: array of shape (len(coords), 2, window). Row 0 is the
                positive strand and row 1 the negative strand. NaNs returned
                by pyBigWig are replaced with 0.
            logcounts: array of shape (len(coords), 2) holding
                ``log(sum(strand profile) + 1)`` per strand. Counts are
                modelled in log space because they are roughly Poisson /
                negative-binomial distributed.
        """
        profiles = []
        logcounts = []
        window_len = self.prof_len + 2 * self.max_jitter

        # Open each bigWig once per batch rather than once per region:
        # the readers do not depend on the region, so re-opening them inside
        # the loop was pure overhead.
        with pyBigWig.open(pos_bw_file) as pos_bw_reader, \
                pyBigWig.open(neg_bw_file) as neg_bw_reader:
            for chrom, start, end in coords:
                prof_start, prof_end = expand_window(start, end, window_len)

                pos_profile = np.array(pos_bw_reader.values(chrom, prof_start, prof_end))
                neg_profile = np.array(neg_bw_reader.values(chrom, prof_start, prof_end))

                # pyBigWig sometimes returns nan when the real data is just 0
                pos_profile[np.isnan(pos_profile)] = 0
                neg_profile[np.isnan(neg_profile)] = 0

                profiles.append(np.array([pos_profile, neg_profile]))
                logcounts.append(np.array([np.log(np.sum(pos_profile) + 1),
                                           np.log(np.sum(neg_profile) + 1)]))

        return np.array(profiles), np.array(logcounts)


    def convert(self, coords):
        """Fetch and one-hot encode the genome sequence for each window.

        Returns an array of shape (len(coords), 4, window_len). The sequence
        axis is last, i.e. transposed relative to (window_len, 4) layouts.
        Raises AssertionError if a chromosome is missing from the FASTA.
        """
        seqs_onehot = []
        with Fasta(self.genome_file) as converter:
            for chrom, start, stop in coords:
                assert chrom in converter
                seq = converter[chrom][start:stop].seq
                seq = np.array([self.letter_dict.get(x,[0,0,0,0]) for x in seq]).T
                seqs_onehot.append(seq)

        return np.array(seqs_onehot)


    def __getitem__(self, batch_index):
        """Return one batch as a list of float tensors.

        Order: one-hot sequences, then (profiles, logcounts) if
        ``return_labels``, then control (profiles, logcounts) if
        ``return_controls``. ``self.transform`` (if set) is applied to the
        numpy arrays before tensor conversion.
        """
        batch_start = batch_index * self.batch_size
        # Slicing already clamps at the end of the list.
        coords_batch = self.coords[batch_start : batch_start + self.batch_size]

        onehot = self.convert(coords_batch)
        assert onehot.shape[0] > 0, onehot.shape
        to_return = [onehot]

        if self.return_labels:
            profiles, logcounts = self.get_profiles_and_logcounts(coords_batch,
                                                                  self.pos_bw,
                                                                  self.neg_bw)
            to_return.extend([profiles, logcounts])

        if self.return_controls:
            control_profiles, control_logcounts = self.get_profiles_and_logcounts(coords_batch,
                                                                                  self.pos_control_bw,
                                                                                  self.neg_control_bw)
            to_return.extend([control_profiles, control_logcounts])

        # run optional jittering on data, if applicable
        if self.transform is not None:
            to_return = self.transform(to_return)

        # NOTE(review): squeeze() also drops the batch axis when the final
        # batch holds a single region — confirm downstream code tolerates that.
        to_return = [torch.tensor(x.squeeze(), dtype=torch.float) for x in to_return]
        return to_return

    
    
# This generator specifically points at the same
# test set used for the binary models (chromosome 2).

# It extends the generator above in that it also loads
# the binary labels for each example, in addition to
# the start and stop coordinates for the example.

class UnbalancedTestGenerator(JITGenerator):
    """Batch generator over the chr2 test set used for the binary models.

    All file paths are derived from ``species`` and ``tf`` instead of being
    passed in. On top of the coordinates, the peak file's 4th column is
    parsed as an integer label per region and kept in ``self.labels``, in the
    same order as ``self.coords``.
    """

    def __init__(self, species, tf,
                 seq_len = INPUT_SEQ_LEN,
                 profile_len = OUTPUT_PROF_LEN,
                 max_jitter = MAX_JITTER,
                 batch_size = 1024,
                 transform = None,
                 return_labels = True, return_controls = True):
        tf_subdir = species + "/" + tf

        # Region list (BED + label column) for the held-out chromosome.
        self.peakfile = PEAKS_DIR + tf_subdir + "/chr2.bed"

        # Stranded signal tracks for the TF itself and for its matched control.
        self.pos_bw = BIGWIGS_DIR + tf_subdir + "/all_reps.pos.bigWig"
        self.neg_bw = BIGWIGS_DIR + tf_subdir + "/all_reps.neg.bigWig"
        self.pos_control_bw = BIGWIGS_DIR + tf_subdir + "_control/all_reps.pos.bigWig"
        self.neg_control_bw = BIGWIGS_DIR + tf_subdir + "_control/all_reps.neg.bigWig"

        # Window sizes and output options (same meaning as in JITGenerator).
        self.prof_len = profile_len
        self.max_jitter = max_jitter
        self.transform = transform
        self.return_labels = return_labels
        self.return_controls = return_controls
        self.batch_size = batch_size

        self.genome_file = GENOMES[species]
        self.seq_len = seq_len

        self.set_len()
        self.coords, self.labels = self.get_coords_and_labels()


    def __len__(self):
        """Number of batches (not regions) in the dataset."""
        return self.len


    def get_coords_and_labels(self):
        """Parse ``self.peakfile`` into resized windows plus integer labels.

        Expects BED-like lines: chrom, start, end, label (extra columns are
        ignored). Each window is resized with ``expand_window`` to
        ``seq_len + 2 * max_jitter``. Strand is not considered.

        Returns:
            (coords, labels): a list of (chrom, start, end) tuples and a
            parallel list of ints.
        """
        with open(self.peakfile) as bed:
            rows = [line.split()[:4] for line in bed]

        coords, labels = [], []
        for fields in rows:
            chrom = fields[0]
            start, end, label = int(fields[1]), int(fields[2]), int(fields[3])
            new_start, new_end = expand_window(start, end,
                                               self.seq_len + 2 * self.max_jitter)
            coords.append((chrom, new_start, new_end))
            labels.append(label)
        return coords, labels

In [50]:
def get_pred_logcounts(data_loader, model, device = "cuda"):
    """Run ``model`` over every batch in ``data_loader`` and collect log-count predictions.

    Each item from ``data_loader`` must unpack to
    (seq, control_profile, control_logcounts), each with a leading size-1
    DataLoader axis, which is squeezed away. The control profile is padded
    with ``pad_control_profile`` to ``model.untrimmed_prof_len``. Element
    ``[1]`` of the model output is taken as the log-count prediction.

    Args:
        data_loader: iterable of batches as described above.
        model: module with an ``untrimmed_prof_len`` attribute.
        device: where inference runs (default "cuda"). The model is moved
            back to the CPU afterwards, even if inference raises.

    Returns:
        Numpy array of per-example predictions, concatenated over batches
        along the example axis and then squeezed. An empty loader gives an
        empty array.
    """
    pred_logcounts = []

    model.to(device)
    model.eval()
    try:
        # Inference only: skip building the autograd graph (saves GPU memory).
        with torch.no_grad():
            for seq, control_profile, control_logcounts in data_loader:
                seq = seq.squeeze().to(device)
                control_profile = pad_control_profile(control_profile.squeeze(),
                                                      model.untrimmed_prof_len).to(device)
                control_logcounts = control_logcounts.squeeze().to(device)
                pred_logcount = model((seq, control_profile, control_logcounts))[1]
                pred_logcounts.append(pred_logcount.cpu().detach().numpy())
    finally:
        model.cpu()

    if not pred_logcounts:
        return np.array(pred_logcounts)
    # Concatenate along the example axis. np.array(list) would instead stack
    # batches into a (n_batches, batch, ...) array, or fail outright when the
    # final batch is smaller than the others.
    return np.concatenate(pred_logcounts).squeeze()

In [51]:
import pandas as pd

def format_data_for_seaborn(auPRC_dicts):
    """Flatten nested auPRC results into the long-form DataFrame seaborn expects.

    ``auPRC_dicts[train_species][tf]`` holds one auPRC value. The result has
    one row per (TF, training species) pair, ordered by ``tfs`` and then by
    ``all_trainspecies``. Its columns are "TF", "Species" (the display name
    from ``model_names_dict``) and "auPRC".
    """
    rows = [(tf, model_names_dict[species], auPRC_dicts[species][tf])
            for tf in tfs
            for species in all_trainspecies]

    tf_col = [row[0] for row in rows]
    species_col = [row[1] for row in rows]
    auprc_col = [row[2] for row in rows]

    return pd.DataFrame({"TF": tf_col, "Species": species_col, "auPRC": auprc_col})

In [None]:
from sklearn.metrics import average_precision_score
import pandas as pd
from collections import defaultdict


# Score every (training species, TF) profile model on the held-out hg38 chr2 set.
human_auprc_dict = defaultdict(dict)
for train_species, tf in ((sp, t) for sp in all_trainspecies for t in tfs):
    print(f"=== {tf} {train_species} ===")
    model_dir = ROOT + "models/profile_models/" + train_species + "-trained/" + tf
    model = torch.load(model_dir + "/bestprof.model")

    unbalval_gen = UnbalancedTestGenerator(
        "hg38", tf,
        transform=NoJitter(MAX_JITTER, INPUT_SEQ_LEN, OUTPUT_PROF_LEN),
        return_labels=False,
        return_controls=True,
    )
    unbalval_data_loader = DataLoader(unbalval_gen, batch_size=1, shuffle=False)

    print("Predicting log-counts...")
    pred_logcounts = get_pred_logcounts(unbalval_data_loader, model)
    # Collapse the prediction columns (presumably one per strand — confirm)
    # into a single score per example.
    auPRC = average_precision_score(unbalval_gen.labels,
                                    np.sum(pred_logcounts, axis=1))
    print(auPRC)
    human_auprc_dict[train_species][tf] = auPRC

human_auprc_df = format_data_for_seaborn(human_auprc_dict)

=== CTCF mm10 ===
Predicting log-counts...


In [None]:
# to avoid re-running, save results to a file
human_auprc_df.to_csv(path_or_buf="hg38_test_profile_model_auPRCs.csv")

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
import matplotlib as mpl

In [None]:
# Plotting code

# Constants to specify plot appearance details
DOT_SIZE = 10

# Figure sizes in inches, derived from one base unit.
FIG_SIZE_UNIT = 5
_FIG_HEIGHT = FIG_SIZE_UNIT - 1
FIG_SIZE = (FIG_SIZE_UNIT + 1.5, _FIG_HEIGHT)   # (6.5, 4)
FIG_SIZE_SMALL = (FIG_SIZE_UNIT, _FIG_HEIGHT)   # (5, 4)

# One color per training species in the plots.
COLORS = ["#0062B8", "#FF0145"]

# Font sizes: axis labels, tick labels, title.
AX_FONTSIZE, AXTICK_FONTSIZE, TITLESIZE = 16, 13, 15

from matplotlib.lines import Line2D


def make_boxplot(df, species, save_files = False,
                 fig_size = FIG_SIZE, colors_to_use = COLORS,
                 dot_size = DOT_SIZE, titlesize = TITLESIZE,
                 ax_fontsize = AX_FONTSIZE,
                 axtick_fontsize = AXTICK_FONTSIZE):
    """Draw a grouped seaborn bar plot of per-TF auPRCs for one test species.

    Despite its name, this draws a bar plot (one bar per TF per training
    species), not a box plot. The name is kept so existing calls still work.
    It draws on the current matplotlib axes, then calls ``plt.show()``.

    Args:
        df: DataFrame with the columns "TF", "Species" and "auPRC"
            (case-sensitive), e.g. the output of ``format_data_for_seaborn``.
        species: test-data species key ("mm10" / "hg38"). Used for the title
            and for the output file names.
        save_files: if True, also save PNG (300 dpi) and PDF copies under
            ``ROOT + "plots/"``.
        fig_size: (width, height) in inches applied to the current figure.
        colors_to_use: bar colors. Seaborn assigns them to "Species" values
            in order of appearance in ``df``. The custom legend pairs them
            with ``all_trainspecies`` order.
        dot_size: unused; kept only so existing callers that pass it still
            work (presumably a leftover from a dot-plot version).
        titlesize: title font size.
        ax_fontsize: font size for the y-axis label and the x tick labels.
        axtick_fontsize: font size for the y tick labels.
    """
    # Upper y limit: a little headroom above the best auPRC in this df.
    yax_max = max(df["auPRC"]) + 0.05

    palette = sns.color_palette(colors_to_use)

    sns.set(style = "white")

    ax = sns.barplot(x = "TF", y = "auPRC", hue = "Species",
                     data = df,
                     dodge = True,
                     palette = palette,
                     linewidth = 1)
    # fig_size was previously accepted but ignored; apply it to the figure.
    ax.figure.set_size_inches(fig_size)

    # Custom legend: one colored marker per training species.
    legend_labels = [model_names_dict[train_species] for train_species in all_trainspecies]
    legend_elements = [Line2D([0], [0], marker='o', color='w', label=label,
                              markeredgecolor='k', markeredgewidth=1,
                              markerfacecolor=color, markersize=10)
                       for color, label in zip(colors_to_use, legend_labels)]
    ax.legend(handles=legend_elements, loc = 'upper right', ncol = 1)

    # format and label axes
    ax.set_xlabel("", fontsize = 0)
    ax.set_ylabel("Area Under PRC", fontsize = ax_fontsize)
    ax.set_xticklabels(labels = tfs_latex_names, fontsize = ax_fontsize)
    ax.tick_params(axis='y', which='major', pad = -2, labelsize = axtick_fontsize)

    # Ticks every 0.2 from 0 to 0.6. More are added above 0.6 only while they
    # stay within yax_max, so high auPRCs still get labelled ticks.
    # matplotlib expands the view limits so every given tick is visible,
    # which means the axis always reaches at least 0.6.
    yticks = [0, 0.2, 0.4, 0.6]
    while round(yticks[-1] + 0.2, 1) <= yax_max:
        yticks.append(round(yticks[-1] + 0.2, 1))
    plt.ylim(0, yax_max)
    plt.yticks(yticks)

    # Title names the test-data species, e.g. "hg38" -> "Human" in bold.
    title = "Model Performance, "
    title += r"$\bf{" + model_names_dict[species].replace("-trained", "") + "}$"
    title += " Test Data"
    plt.title(title, fontsize = titlesize)

    if save_files:
        plt.savefig(ROOT + "plots/profile_dotplots_" + species + "_test.png",
                    bbox_inches='tight', pad_inches = 0.1, dpi = 300)
        plt.savefig(ROOT + "plots/profile_dotplots_" + species + "_test.pdf",
                    bbox_inches='tight', pad_inches = 0.1)

    plt.show()

In [None]:
# Set the default figure size through seaborn's rc, open a fresh figure,
# then draw (and save) the hg38 test-set plot.
sns.set(rc={"figure.figsize": FIG_SIZE})
plt.figure()
make_boxplot(df=human_auprc_df, species="hg38", save_files=True)

In [None]:
from sklearn.metrics import average_precision_score
import pandas as pd
from collections import defaultdict


# Score every (training species, TF) profile model on the held-out mm10 chr2 set.
mouse_auprc_dict = defaultdict(lambda : dict())
for train_species in all_trainspecies:
    for tf in tfs:
        print("=== " + tf + " " + train_species + " ===")
        model_save_path = ROOT + "models/profile_models/" + train_species + "-trained/" + tf + "/bestprof.model"
        model = torch.load(model_save_path)

        unbalval_gen = UnbalancedTestGenerator("mm10", tf,
                           transform = NoJitter(MAX_JITTER, INPUT_SEQ_LEN, OUTPUT_PROF_LEN),
                           return_labels = False, return_controls = True)
        unbalval_data_loader = DataLoader(unbalval_gen, batch_size = 1, shuffle = False)

        # Same progress message as the hg38 evaluation cell.
        print("Predicting log-counts...")
        pred_logcounts = get_pred_logcounts(unbalval_data_loader, model)
        # Collapse the prediction columns (presumably one per strand — confirm)
        # into a single score per example.
        auPRC = average_precision_score(unbalval_gen.labels, np.sum(pred_logcounts, axis = 1))
        print(auPRC)
        mouse_auprc_dict[train_species][tf] = auPRC


mouse_auprc_df = format_data_for_seaborn(mouse_auprc_dict)

# to avoid re-running, save results to a file
mouse_auprc_df.to_csv("mm10_test_profile_model_auPRCs.csv")

In [None]:
# Set the default figure size through seaborn's rc, open a fresh figure,
# then draw (and save) the mm10 test-set plot.
sns.set(rc={"figure.figsize": FIG_SIZE})
plt.figure()
make_boxplot(df=mouse_auprc_df, species="mm10", save_files=True)