In [None]:
import os
# Expose only GPU 6 to this process; change the index to match your machine.
# NOTE(review): this has to be set before keras/tensorflow initialize CUDA,
# which is presumably why it sits in the first cell, ahead of the keras import.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [None]:
import keras
import numpy as np
from collections import defaultdict

In [None]:
# Project root directory. The data/, models/, and model_out/ paths used in
# this notebook are built by string concatenation onto ROOT, so the trailing
# slash is required.
ROOT = "/users/kcochran/projects/domain_adaptation_nosexchr/"

# shorthand for each TF name
# (the names are also used as directory names and in output filenames)
tfs = ["CTCF", "CEBPA", "Hnf4a", "RXRA"]

# shorthand names for all model training types to generate predictions for:
#   "mm10", "hg38" - models saved under <species>_trained/basic_model/
#   "DA"           - DA models; get_model_file assumes these are mouse-trained
#   "NS"           - no-SINEs models; get_model_file assumes these are human-trained
all_trainspecies = ["mm10", "DA", "hg38", "NS"]

# these are specifically the "species" with test datasets to evaluate on
# (each must also be a key of GENOMES)
all_testspecies = ["mm10", "hg38"]

# Data Loading

In [None]:
from keras.utils import Sequence
from seqdataloader.batchproducers.coordbased.core import Coordinates
from seqdataloader.batchproducers.coordbased.coordstovals.fasta import PyfaidxCoordsToVals

# Genome FASTA path for each test species. ValGenerator looks up the path
# for its val_species here and hands it to PyfaidxCoordsToVals.
GENOMES = {"mm10" : "/users/kcochran/genomes/mm10_no_alt_analysis_set_ENCODE.fasta",
            "hg38" : "/users/kcochran/genomes/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"}


def get_test_bed_file(species):
    # Builds the path of the BED-format file listing every test window
    # (chromosome name, start, end) for the given species.
    # Binding labels are not read from this file by this notebook, and the
    # windows are expected to be the same for every TF, so the first TF's
    # copy of the file is used regardless of which TF is being evaluated.
    reference_tf = tfs[0]
    relative_path = "/".join(["data", species, reference_tf, "chr2.bed"])
    return ROOT + relative_path


class ValGenerator(Sequence):
    # This generator retrieves all coordinates for windows in the test set
    # and converts the sequences in those windows to one-hot encodings.
    # Which species to retrieve test windows for is specified with
    # the "val_species" argument.
    #
    # Arguments:
    #   batchsize   - number of windows returned per batch
    #   val_species - species shorthand; passed to get_test_bed_file and
    #                 used as the key into GENOMES
    #
    # NOTE: the number of batches is floor(lines_in_file / batchsize), so
    # when the file length is not a multiple of batchsize the trailing
    # (lines_in_file % batchsize) windows are never requested by Keras and
    # get no prediction. Anything that pairs predictions with labels must
    # truncate the labels to the same length. This is left unchanged here
    # because it determines the length of the saved prediction files.

    def __init__(self, batchsize, val_species = "hg38"):
        self.valfile = get_test_bed_file(val_species)
        self.get_steps(batchsize)
        self.converter = PyfaidxCoordsToVals(GENOMES[val_species])
        self.batchsize = batchsize
        self.get_coords()


    def get_steps(self, batchsize):
        # calculates the number of steps needed to get through
        # all full batches of examples in the test dataset
        # (Keras predict_generator code needs to know this)
        # Sets self.steps; a final partial batch is not counted (see the
        # class-level note).
        with open(self.valfile) as f:
            lines_in_file = sum(1 for line in f)

        self.steps = lines_in_file // batchsize


    def __len__(self):
        # number of batches Keras will request from this generator
        return self.steps

    def get_coords(self):
        # load all coordinates for the test data into memory
        # Sets self.coords to a list of Coordinates(chrom, start, end),
        # one per line of the BED file, in file order.
        with open(self.valfile) as f:
            coords_tmp = [line.rstrip().split()[:3] for line in f]

        # The previous check asserted on the list itself, which is truthy
        # for any non-empty file, so short/blank lines were never caught.
        # Keep failing on an empty file, and actually verify every line.
        assert coords_tmp, "No test windows found in " + self.valfile
        assert all(len(line_split) == 3 for line_split in coords_tmp), \
            "Every line of " + self.valfile + " needs at least 3 fields (chrom, start, end)."
        self.coords = [Coordinates(coord[0], int(coord[1]), int(coord[2])) for coord in coords_tmp]

    def __getitem__(self, batch_index):
        # convert a batch's worth of coordinates into one-hot sequences
        # (batch_index is zero-based; the slice is clipped at the end of
        # self.coords, so an out-of-range index yields a short or empty batch)
        batch = self.coords[batch_index * self.batchsize : (batch_index + 1) * self.batchsize]
        return self.converter(batch)
    

def get_preds_batched_fast(model, batch_size, test_species = "hg38"):
    # Runs the given model over the whole test set of one species and
    # returns its predictions with singleton dimensions squeezed out.
    # batch_size is limited only by what your compute can handle.
    # test_species defaults to human ("hg38"); pass "mm10" for mouse.
    # Batches are produced by a ValGenerator using 8 worker processes.

    print("Generating predictions...")
    test_batches = ValGenerator(batch_size, test_species)
    raw_preds = model.predict_generator(test_batches,
                                        use_multiprocessing = True,
                                        workers = 8,
                                        verbose = 1)
    return np.squeeze(raw_preds)

# Model Loading

In [None]:
# needed to load DA models

from flipGradientTF import GradientReversal
import tensorflow

def custom_loss(y_true, y_pred):
    # Binary cross-entropy computed only over entries whose label is not -1.
    # This should be the same implementation as what was used when the DA
    # model trained, since Keras needs it to deserialize the saved model.
    # NOTE(review): -1 presumably marks examples without a binding label - confirm
    # against the training code.
    is_labeled = tensorflow.not_equal(y_true, -1)
    kept_pred = tensorflow.boolean_mask(y_pred, is_labeled)
    kept_true = tensorflow.boolean_mask(y_true, is_labeled)
    return keras.losses.binary_crossentropy(kept_true, kept_pred)

In [None]:
def get_model_file(tf, train_species, run = 1, model_type = "best"):
    # This function returns the filepath where the model for a given
    # TF, training species, and run is saved.
    #
    # Arguments:
    #   tf            - TF name; used as a directory name under ROOT + "models/"
    #   train_species - "mm10", "hg38", "DA", or "NS" (mapped to directories
    #                   below); any other value is used as the species name
    #   run           - run number; must be castable to int
    #   model_type    - "best" selects the "_best.model" file (best model
    #                   across all training epochs); any other value selects
    #                   the "_15E_end.model" file (last model)
    #
    # If there are multiple files for the same run-TF-species combo, the
    # one with the largest os.path.getctime is returned. Note that on Unix
    # getctime is the time of the last metadata change, not creation time.
    #
    # Raises ValueError if run cannot be cast to int, or if no saved model
    # matches. os.listdir raises as usual if the directory does not exist.
    try:
        int(run)
    except (TypeError, ValueError, OverflowError):
        # previously this only printed a message and carried on with the bad value
        raise ValueError("You need to pass in a run number that can be cast to int; got "
                         + repr(run) + ".")

    # Map the training-type shorthand onto directory names explicitly.
    # (Rewriting the assembled path with str.replace would also alter any
    # "DA" / "NS" / "basic_model" substring inside ROOT or the TF name.)
    if train_species == "DA":
        # assuming all DA models are mouse-trained
        species_dir, model_dir = "mm10", "DA"
    elif train_species == "NS":
        # assuming the no-SINEs models are trained on human data
        species_dir, model_dir = "hg38", "basic_model_nosines"
    else:
        species_dir, model_dir = train_species, "basic_model"

    model_file_prefix = ROOT + "/".join(["models", tf, species_dir + "_trained", model_dir]) + "/"

    # these models were saved as part of training
    # see ../2_train_and_test_models/callbacks.py for model saving details
    if model_type == "best":
        model_file_suffix = "_run" + str(run) + "_best.model"
    else:
        model_file_suffix = "_run" + str(run) + "_15E_end.model"

    # get all files that match the prefix and suffix
    files = [f for f in os.listdir(model_file_prefix) if f.endswith(model_file_suffix)]
    if not files:
        # max() on an empty list raises ValueError too, but without saying what was searched
        raise ValueError("No model file ending in '" + model_file_suffix
                         + "' found in " + model_file_prefix)

    # return the matching file that is most recent
    return max((model_file_prefix + f for f in files), key=os.path.getctime)


def load_keras_model(model_file, DA = False):
    # Loads and returns the Keras model saved at model_file, announcing the
    # path first. Set DA = True for DA models so that Keras is told how the
    # gradient reversal layer and the custom loss were implemented (these
    # need to match the definitions from when the model was saved).
    print("Loading " + model_file + ".")
    if not DA:
        return keras.models.load_model(model_file)

    da_custom_objects = {"GradientReversal":GradientReversal,
                         "custom_loss":custom_loss}
    return keras.models.load_model(model_file, custom_objects = da_custom_objects)


def get_models_all_runs(tf, train_species, runs = 5):
    # Loads the saved model of every run (numbered 1 through runs) for a
    # given TF and training species.
    # Returns the Keras model objects as a list, ordered by run number.
    needs_da_objects = train_species == "DA"
    return [load_keras_model(get_model_file(tf, train_species, run_number),
                             DA = needs_da_objects)
            for run_number in range(1, runs + 1)]

# Generate + Save Predictions

In [None]:
def get_preds_file(tf, train_species, test_species):
    # Returns the path, inside ROOT + "model_out/", under which predictions
    # for this TF / training type / test species combination are stored.
    # The model_out directory is created first if it does not exist yet.
    out_dir = ROOT + "model_out/"
    os.makedirs(out_dir, exist_ok=True)
    filename = "_".join([tf, train_species + "-trained", test_species + "-test.preds"])
    return out_dir + filename

In [None]:
### This cell takes a while (hours) to run.

# For every combination of test species, model training type, and TF:
# load all replicate models, predict on that species' test windows,
# and write the predictions to disk.
for test_species in all_testspecies:
    for train_species in all_trainspecies:
        for tf in tfs:
            print("\n===== " + tf + " " + test_species + " test, " + train_species + " trained =====\n")

            # the independently trained replicate models for this TF and training type
            run_models = get_models_all_runs(tf, train_species)

            # one entry of predictions per replicate, all on the test_species data
            preds_by_run = np.array([
                get_preds_batched_fast(run_model, 1024, test_species = test_species)
                for run_model in run_models
            ])

            # DA models output species predictions alongside binding predictions;
            # keep only the binding predictions
            if preds_by_run.ndim > 2 and train_species == "DA":
                preds_by_run = preds_by_run[:, 0, :]
            assert preds_by_run.ndim == 2, preds_by_run.shape

            # save predictions to file, transposed so that the runs axis comes last
            # (np.save appends ".npy" because the path does not already end in it)
            preds_file = get_preds_file(tf, train_species, test_species)
            np.save(preds_file, np.transpose(preds_by_run))

            # drop references and reset the Keras session to avoid unnecessary memory usage
            del preds_by_run, tf, run_models
            keras.backend.clear_session()
        del train_species
    del test_species