In [None]:
import torch
import random

torch.cuda.empty_cache()

import numpy as np
import pandas as pd

from torch.utils.data import Dataset, DataLoader
import h5py

import torch.optim as optim
from torch import nn
import torch.nn.functional as f
from torch.nn.functional import one_hot, softmax

import sys
from datetime import datetime
import os

sys.path.insert(0,'/hpc/compgen/projects/fragclass/analysis/mvivekanandan/script/madhu_scripts')

import config
import utils

import importlib
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_curve

import seaborn as sns

import time

In [None]:
importlib.reload(utils)
importlib.reload(config)

# Collect every run argument from the config module. Each entry is
# (argument name, config section, key inside that section); the resulting dict keeps
# this order, so the printed summary lists the arguments in the same sequence.
arguments = {
    argumentName: configSection.get(configKey)
    for argumentName, configSection, configKey in (
        ("trainingEnformerOutputStoreFile", config.testFilePaths, "trainingEnformerOutputStoreFile"),
        ("validationEnformerOutputStoreFile", config.testFilePaths, "validationEnformerOutputStoreFile"),
        ("testEnformerOutputStoreFile", config.testFilePaths, "testEnformerOutputStoreFile"),
        ("batchSize", config.modelHyperParameters, "batchSize"),
        ("learningRate", config.modelHyperParameters, "learningRate"),
        ("numberOfWorkers", config.modelHyperParameters, "numberOfWorkers"),
        ("numberEpochs", config.modelHyperParameters, "numberEpochs"),
        ("threshold", config.modelHyperParameters, "threshold"),
        ("storePlots", config.modelGeneralConfigs, "storePlots"),
        ("modelName", config.modelGeneralConfigs, "modelName"),
        ("trainingAndValidationOutputsDirectory", config.filePaths, "trainingAndValidationOutputsDirectory"),
        ("interchangeLabels", config.modelGeneralConfigs, "interchangeLabels"),
        ("useClassWeights", config.modelGeneralConfigs, "useClassWeights"),
        ("useCosineLearningFunction", config.modelGeneralConfigs, "useCosineLearningFunction"),
        ("trainingStartIndex", config.modelGeneralConfigs, "startIndexEnformerSamplesTraining"),
        ("trainingEndIndex", config.modelGeneralConfigs, "endIndexEnformerSamplesTraining"),
        ("validationStartIndex", config.modelGeneralConfigs, "startIndexEnformerSamplesValidation"),
        ("validationEndIndex", config.modelGeneralConfigs, "endIndexEnformerSamplesValidation"),
        ("normalizeFeatures", config.modelGeneralConfigs, "normalizeFeatures"),
        ("runWithControls", config.modelGeneralConfigs, "runWithControls"),
    )
}
print(arguments)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The device used is : {device}")

# Every output of this run (plots, metrics, model state) goes into a directory named
# "<day_month_hour_minute_second>_<modelName>" under the configured outputs directory.
now = datetime.now()
filename_extension = now.strftime("%d_%m_%H_%M_%S")
plotsDirectoryName = "_".join([filename_extension, str(arguments["modelName"])])
plotsDirectoryPath = os.path.join(arguments["trainingAndValidationOutputsDirectory"], plotsDirectoryName)
if arguments["storePlots"]:
    os.mkdir(plotsDirectoryPath)

masterList = []

In [None]:
class BasicDenseLayer(nn.Module):
    """Four-layer fully connected classifier: input -> 5000 -> 1000 -> 200 -> 2 logits.

    Every linear layer has its own attribute (fc1..fc4) and ``forward`` chains all of them,
    so the output of each layer matches the input size of the next one.

    The layer sizes are keyword arguments; their defaults give the architecture above with an
    input of ``2 * 5313`` features, so ``BasicDenseLayer()`` still works as before.
    """

    def __init__(self, inputSize=2 * 5313, hiddenSize1=5000, hiddenSize2=1000, hiddenSize3=200, numClasses=2):
        super(BasicDenseLayer, self).__init__()
        self.fc1 = nn.Linear(inputSize, hiddenSize1)
        self.fc2 = nn.Linear(hiddenSize1, hiddenSize2)
        self.fc3 = nn.Linear(hiddenSize2, hiddenSize3)
        self.fc4 = nn.Linear(hiddenSize3, numClasses)

    def forward(self, x):
        """Return raw class scores (no softmax), one row of ``numClasses`` values per input row."""
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        x = f.relu(self.fc3(x))
        return self.fc4(x)

In [None]:
class SimpleDenseLayer(nn.Module):
    """Small three-layer MLP mapping 25 input features to 2 class logits (25 -> 20 -> 10 -> 2)."""

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation consumes the RNG
        # in the same sequence and the state_dict keys stay fc1/fc2/fc3.
        self.fc1 = nn.Linear(25, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        """ReLU after each hidden layer; the output layer returns raw logits (no softmax)."""
        hidden = x
        for hiddenLayer in (self.fc1, self.fc2):
            hidden = f.relu(hiddenLayer(hidden))
        return self.fc3(hidden)

In [None]:
class EnformerOutputDataset(Dataset):
    """Dataset that reads batches of stored Enformer outputs and labels from an HDF5 file.

    ``sampleType`` (e.g. ``"training"`` or ``"validation"``) selects, through the global
    ``arguments`` dict:
      * the HDF5 file path ``arguments[sampleType + "EnformerOutputStoreFile"]``,
      * the datasets ``sampleType + "EnformerOutput"`` and ``sampleType + "Labels"`` in that file,
      * the allowed index window ``[arguments[sampleType + "StartIndex"], arguments[sampleType + "EndIndex"])``;
        an end index of ``"all"`` means "up to the end of the file".

    ``__getitem__`` receives a whole batch of indices at once (the notebook drives it with a
    ``BatchSampler``), not a single index.
    """

    #Number of leading Enformer output columns used as model features.
    NUM_FEATURES = 25

    def __init__(self, sampleType):
        self.sampleType = sampleType
        self.enformerOutputDatasetName = sampleType + "EnformerOutput"
        self.labelsDatasetName = sampleType + "Labels"

        self.enformerOutputFileKey = sampleType + "EnformerOutputStoreFile"
        self.enformerOutputFilePath = arguments[self.enformerOutputFileKey]
        self.startIndex = arguments[sampleType + "StartIndex"]
        self.endIndex = arguments[sampleType + "EndIndex"]
        print(f"inside init fn, start and end index for type {self.sampleType} are {self.startIndex} and {self.endIndex}")

    @staticmethod
    def _swapLabels(labels):
        """Return a copy of ``labels`` with 0 and 1 exchanged; any other value is left unchanged."""
        swapped = labels.copy()
        swapped[labels == 0] = 1
        swapped[labels == 1] = 0
        return swapped

    def __getitem__(self, indices):
        """Load one batch of features and labels for the given absolute file indices.

        Features and labels are read with the same indices, so they stay aligned whatever
        order the sampler produces them in.

        Returns:
            (features, labels): a float32 tensor holding the first ``NUM_FEATURES`` columns of
            each requested row (optionally normalised and with controls injected), and the
            labels as read from the file (optionally with 0 and 1 swapped).
        """
        indexArray = np.asarray(indices)
        assert np.all(indexArray >= self.startIndex), f"Some indices are smaller than {self.startIndex} - {indices}"
        if self.endIndex != "all":
            assert np.all(indexArray < self.endIndex), f"Some indices are greater than {self.endIndex} - {indices}"

        with h5py.File(self.enformerOutputFilePath, 'r') as h5File:
            #The whole batch is loaded at once; only the feature columns are read from disk.
            enformerOutput = h5File[self.enformerOutputDatasetName][indices, 0:self.NUM_FEATURES]
            labels = h5File[self.labelsDatasetName][indices]

        encoded_enformer_output = torch.tensor(np.asarray(enformerOutput, dtype=np.float32))

        #Z-score normalise each sample (row) across its features. A row with zero spread would
        #be divided by zero and turn into NaN, so its std is replaced by 1 (the row becomes ~0).
        if arguments["normalizeFeatures"]:
            rowMean = encoded_enformer_output.mean(dim=1, keepdim=True)
            rowStd = encoded_enformer_output.std(dim=1, keepdim=True)
            rowStd = torch.where(rowStd > 0, rowStd, torch.ones_like(rowStd))
            encoded_enformer_output = (encoded_enformer_output - rowMean) / rowStd

        #For old coordinate files, the positives were incorrectly labelled as 0 and the negatives as 1.
        if arguments["interchangeLabels"]:
            labels = self._swapLabels(labels)

        #POSITIVE AND NEGATIVE CONTROL ADDITION.
        #Overwrite features 1 and 2 of every sample with one random value: drawn from [0.5, 1.5]
        #for positives and from [-1.5, -0.5] for everything else.
        if arguments["runWithControls"]:
            replacementFeatList = [1, 2]
            for i in range(len(encoded_enformer_output)):
                if labels[i] == 1:
                    replacement_val = random.uniform(0.5, 1.5)
                else:
                    #uniform(a, b) with a > b samples from [b, a]; the argument order is kept so
                    #seeded runs reproduce the same values.
                    replacement_val = random.uniform(-0.5, -1.5)
                encoded_enformer_output[i, replacementFeatList] = replacement_val

        return encoded_enformer_output, labels

    def __len__(self):
        """Number of samples labelled 0 or 1 inside the configured index window.

        Side effect: stores the positive and negative counts in ``arguments`` under
        ``sampleType + "PositivesLength"`` and ``sampleType + "NegativesLength"`` (read by
        ``getParametersDescription``). The counts follow the optional label swap, so
        "positives" always means label 1 as the model sees it.
        """
        with h5py.File(self.enformerOutputFilePath, 'r') as h5File:
            labelsDataset = h5File[self.labelsDatasetName]
            if self.endIndex == "all":
                labels = labelsDataset[self.startIndex:]
            else:
                labels = labelsDataset[self.startIndex:self.endIndex]

        length_positives = int((labels == 1).sum())
        length_negatives = int((labels == 0).sum())
        if arguments["interchangeLabels"]:
            #__getitem__ swaps 0 and 1 for old files, so the raw counts must be swapped too.
            length_positives, length_negatives = length_negatives, length_positives
        arguments[self.sampleType + "PositivesLength"] = length_positives
        arguments[self.sampleType + "NegativesLength"] = length_negatives
        print(f"Inside length of the {self.sampleType} dataset, the total number of samples is {length_positives + length_negatives}")
        return length_positives + length_negatives

In [None]:
def getParametersDescription():
    """Build the multi-line run summary used as a caption on plots and in the metrics CSV.

    Reads the learning rate, batch size and the per-split positive/negative sample counts
    from the global ``arguments`` dict.
    """
    learningRate = arguments["learningRate"]
    batchSize = arguments["batchSize"]
    trainPos = arguments["trainingPositivesLength"]
    trainNeg = arguments["trainingNegativesLength"]
    validPos = arguments["validationPositivesLength"]
    validNeg = arguments["validationNegativesLength"]

    #Number of layers of the dense model trained in this notebook.
    layerCount = 3
    captionLines = [
        f"learning rate = {learningRate},",
        f"number of training samples = {trainPos + trainNeg} ({trainPos} positives and {trainNeg} negatives),",
        f"number of validation samples = {validPos + validNeg} ({validPos} positives and {validNeg} negatives),",
        f"batch size = {batchSize},",
        f" number of layers = {layerCount}",
    ]
    return "\n".join(captionLines)

def storeLossFunctionPlot(training_loss_list, validation_loss_list):
    """Plot the per-epoch training and validation losses, captioned with the run parameters.

    The figure is saved as ``lossFunctionPlot`` in the run directory when ``storePlots`` is
    set, and is always shown.
    """
    fig = plt.figure()
    fig.text(.5, -0.3, getParametersDescription(), ha = 'center')
    plt.plot(range(len(training_loss_list)), training_loss_list, '-.', label = "Training")
    plt.plot(range(len(validation_loss_list)), validation_loss_list, label = "validation")

    plt.xlabel("Epochs")
    plt.ylabel("Cross Entropy Loss")
    plt.legend(loc="upper left")
    plt.title("Training and Validation Cross entropy loss over epochs.")
    if arguments["storePlots"]:
        plt.savefig(os.path.join(plotsDirectoryPath, "lossFunctionPlot"), bbox_inches='tight')

    plt.show()

"""
Save the model's state_dict to a file named "modelState" inside this run's output directory (the directory name contains the run's date and time).
"""
def saveModel(model):
    """Save ``model.state_dict()`` to ``<plotsDirectoryPath>/modelState`` when ``storePlots`` is set.

    The file is opened in exclusive-create mode, so an existing file is never overwritten;
    ``FileExistsError`` is raised instead.
    """
    if arguments["storePlots"]:
        filepath = os.path.join(plotsDirectoryPath, "modelState")
        #One handle serves both as the "must not exist yet" check and as the write target,
        #and the `with` block closes it even if torch.save raises.
        with open(filepath, "xb") as modelFile:
            torch.save(model.state_dict(), modelFile)

#TODO To be removed once confusion matrix problems are fixed.
def printDonorRecipientLabelsVsPredictions(true_labels, predictions, sampleType):
    """Debug helper: print how many true labels and how many predictions are 0 and 1.

    Both counts are printed for the given split so they can be compared with the confusion
    matrix. Values other than 0 and 1 are not counted.
    """
    true_count_0 = sum(1 for label in true_labels if label == 0)
    true_count_1 = sum(1 for label in true_labels if label == 1)
    pred_count_0 = sum(1 for prediction in predictions if prediction == 0)
    pred_count_1 = sum(1 for prediction in predictions if prediction == 1)

    print(f"num of 0's and 1's true labels in the {sampleType} set is {true_count_0} and {true_count_1}")
    print(f"num of 0's and 1's predictions in the {sampleType} set is {pred_count_0} and {pred_count_1}")

def getClassWeights(numSamples=10000):
    """Estimate "balanced" class weights for the loss from a prefix of the stored labels.

    Reading every training and validation label could be large and slow, so only the first
    ``numSamples`` labels of each file are used, on the assumption that they reflect the
    overall class balance. When ``arguments["interchangeLabels"]`` is set, the weights are
    swapped to match the 0 <-> 1 label swap ``EnformerOutputDataset`` applies. As a result,
    index 1 is always the weight of the class the model sees as label 1.

    Returns:
        Contiguous numpy float32 array ``[weight_for_label_0, weight_for_label_1]``.
    """
    with h5py.File(arguments["trainingEnformerOutputStoreFile"], 'r') as h5File:
        training_labels = h5File["trainingLabels"][0:numSamples].flatten()

    with h5py.File(arguments["validationEnformerOutputStoreFile"], 'r') as h5File:
        validation_labels = h5File["validationLabels"][0:numSamples].flatten()

    all_labels = np.concatenate([training_labels, validation_labels])
    class_weights = compute_class_weight(class_weight = "balanced", classes = np.array([0, 1]), y = all_labels)
    if arguments["interchangeLabels"]:
        #The stored labels are swapped before training; with "balanced" weights (inverse class
        #frequency) swapping the labels simply exchanges the two weights.
        class_weights = class_weights[::-1]
    #Contiguous float32, so torch.tensor(...) gives weights matching the model's float32 logits.
    return np.ascontiguousarray(class_weights, dtype=np.float32)

#Once all errors with the function are fixed, move this inside objective function. This is pulled outside only to ensure
#that the model variable is not lost if there are errors during saving.
#The model goes on the device chosen in the configuration cell ("cuda" when available, otherwise "cpu"),
#so the notebook also runs on machines without a GPU.
denseLayerModel = SimpleDenseLayer().to(device)

def getConfusionMatrixAndLabels(true_labels, predictions, sampleType):
    """Compute the 2x2 confusion matrix of one split and the annotation text for each cell.

    Returns:
        (cf_matrix, labels): the confusion matrix (rows = true label 0/1, columns = predicted
        label 0/1) and a 2x2 array of strings, each holding the cell name, count and
        percentage on separate lines, for seaborn's heatmap ``annot``.
    """
    #TODO To be removed after testing is done
    printDonorRecipientLabelsVsPredictions(true_labels, predictions, sampleType)

    #Fix the label set to [0, 1]: without it, when only one class occurs in both the true labels
    #and the predictions, sklearn returns a 1x1 matrix and the reshape below fails.
    cf_matrix = confusion_matrix(true_labels, predictions, labels=[0, 1])

    group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
    cell_counts = cf_matrix.flatten()
    total = cell_counts.sum()
    #Guard against an empty split, which would otherwise divide by zero.
    cell_fractions = cell_counts / total if total else np.zeros(len(cell_counts))
    group_counts = ["{0:0.0f}".format(value) for value in cell_counts]
    group_percentages = ["{0:.2%}".format(value) for value in cell_fractions]
    labels = [f"{name}\n{count}\n{percentage}" for name, count, percentage in zip(group_names, group_counts, group_percentages)]
    return cf_matrix, np.asarray(labels).reshape(2, 2)

def storeConfusionMatrixHeatMap(training_true_labels, validation_true_labels, training_predictions, validation_predictions):
    """Show annotated confusion-matrix heatmaps for training (left) and validation (right).

    The figure is captioned with the run parameters and saved as ``confusionMatrix`` when
    ``storePlots`` is set.
    """
    fig, (trainingAxis, validationAxis) = plt.subplots(1,2, figsize=(14, 8))
    fig.text(.5, -0.1, getParametersDescription(), ha = 'center', fontsize=12)

    #Both matrices are computed (and their debug counts printed) before anything is drawn.
    panels = [
        (trainingAxis, 'Training', *getConfusionMatrixAndLabels(training_true_labels, training_predictions, "training")),
        (validationAxis, 'Validation', *getConfusionMatrixAndLabels(validation_true_labels, validation_predictions, "validation")),
    ]
    for axis, _, cfMatrix, cfAnnotations in panels:
        heatmapAxis = sns.heatmap(cfMatrix, annot=cfAnnotations, fmt = '', cmap="Blues", ax=axis, annot_kws={"fontsize":12})
        heatmapAxis.set_xlabel("Predicted Label", fontsize=12)
        heatmapAxis.set_ylabel("True Label", fontsize=12)
    fig.subplots_adjust(hspace=0.75, wspace=0.75)

    for axis, panelTitle, _, _ in panels:
        axis.title.set_text(panelTitle)

    if arguments["storePlots"]:
        plt.savefig(os.path.join(plotsDirectoryPath, "confusionMatrix"), bbox_inches='tight')

    plt.show()


def storePerformanceMetrics(true_labels, predictions, sampleType):
    """Write sklearn's classification report for one split to ``performanceMetrics_<sampleType>.csv``.

    Label 0 is reported as "donor" and label 1 as "recipient". An extra row indexed
    "model parameters" holds the run description. The CSV is written only when
    ``arguments["storePlots"]`` is set.
    """
    target_names = ["donor", "recipient"]
    #labels=[0, 1] keeps the report valid when one class is missing from a split.
    report = classification_report(true_labels, predictions, labels=[0, 1], target_names=target_names, output_dict=True)
    report_df = pd.DataFrame(data = report).transpose()
    parameters_row = pd.DataFrame(
        [{'donor': 'model parameters', 'recipient': getParametersDescription()}],
        index=["model parameters"],
    )
    #DataFrame.append was removed in pandas 2.0; pd.concat works in every version. The report's
    #row names (donor, recipient, accuracy, averages) are kept rather than renumbered.
    report_df = pd.concat([report_df, parameters_row])

    if arguments["storePlots"]:
        filename = "performanceMetrics_" + sampleType + ".csv"
        csv_path = os.path.join(plotsDirectoryPath, filename)
        report_df.to_csv(csv_path, index= True)

def storeProbabilityFrequencyPlot(positive_training, negative_training, positive_validation, negative_validation):
    """Histogram (log-scaled counts) of predicted probabilities, one step line per group.

    Each argument is a sequence of probabilities and becomes one hue group named after it
    ("positive_training", "negative_training", "positive_validation", "negative_validation").
    ``objectiveFn`` passes the softmax probabilities of class 1 ("positive") and class 0
    ("negative") for every sample. The figure is saved as ``probabilityDistributionPlot``
    when ``storePlots`` is set.
    """
    groupFrames = []
    for groupName, probabilities in (("positive_training", positive_training),
                                     ("negative_training", negative_training),
                                     ("positive_validation", positive_validation),
                                     ("negative_validation", negative_validation)):
        groupDf = pd.DataFrame(probabilities, columns = ["probabilities"])
        groupDf["type"] = groupName
        groupFrames.append(groupDf)
    prob_data_df = pd.concat(groupFrames, ignore_index=True, axis=0)

    #One figure holds both the caption and the histogram, so the saved image includes the
    #caption and no empty extra figure is shown.
    fig, ax = plt.subplots(figsize=(10,5))
    fig.text(.5, -0.1, getParametersDescription(), ha = 'center')
    sns.histplot(data = prob_data_df, x = "probabilities", hue = "type", element = "step", ax = ax)
    ax.set_yscale('log')
    ax.set_title("Positive and Negative probability distributions from model predictions. ")

    if arguments["storePlots"]:
        plotPath = os.path.join(plotsDirectoryPath, "probabilityDistributionPlot")
        fig.savefig(plotPath, bbox_inches='tight')

    plt.show()

def plotPrecisionRecallCurve(training_positive_probs, training_true_labels, validation_positive_probs, validation_true_labels):
    """Plot precision-recall curves for the training (top) and validation (bottom) splits.

    ``*_positive_probs`` are the class-1 scores and ``*_true_labels`` the matching 0/1 labels.
    The figure is saved as ``precisionRecallPlot`` when ``storePlots`` is set.
    """
    fig, axes = plt.subplots(2, figsize=(6, 12))
    panels = (
        ("Training", training_true_labels, training_positive_probs),
        ("Validation", validation_true_labels, validation_positive_probs),
    )
    for ax, (panelTitle, trueLabels, positiveProbs) in zip(axes, panels):
        precision, recall, _ = precision_recall_curve(trueLabels, positiveProbs)
        ax.plot(recall, precision)
        ax.title.set_text(panelTitle)
        #Label each panel explicitly; plt.xlabel/plt.ylabel would only label the last axes.
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

    #Slightly larger vertical gap so the top panel's x label does not overlap the bottom title.
    fig.subplots_adjust(hspace=0.3, wspace=0.2)

    if arguments["storePlots"]:
        plotPath = os.path.join(plotsDirectoryPath, "precisionRecallPlot")
        fig.savefig(plotPath, bbox_inches = "tight")

    plt.show()

def getProbsForEachConfusionMatrixBlock(pos_probs, class_labels, threshold, sampleType):
    """Split class-1 probabilities into the four confusion-matrix cells.

    A sample counts as predicted positive when its probability is strictly greater than
    ``threshold`` and as predicted negative otherwise. This is the same rule
    ``getModelPredictionAndLoss`` uses for its predicted labels, so a probability exactly
    equal to the threshold lands in a negative cell instead of being dropped. Samples whose
    label is neither 0 nor 1 are ignored.

    Args:
        pos_probs: class-1 probability per sample, indexed in step with ``class_labels``.
        class_labels: true 0/1 label per sample.
        threshold: decision threshold.
        sampleType: prefix used in the returned description.

    Returns:
        (final_probs_df, desc): a DataFrame with columns "probabilities" and "type", with rows
        grouped as true_pos, false_pos, true_neg, false_neg, and a one-line summary of the
        cell counts.
    """
    cells = {"true_pos": [], "false_pos": [], "true_neg": [], "false_neg": []}
    for i, label in enumerate(class_labels):
        prob = pos_probs[i]
        predictedPositive = prob > threshold
        if label == 1:
            cells["true_pos" if predictedPositive else "false_neg"].append(prob)
        elif label == 0:
            cells["false_pos" if predictedPositive else "true_neg"].append(prob)

    cellFrames = []
    for cellName in ("true_pos", "false_pos", "true_neg", "false_neg"):
        cellDf = pd.DataFrame(cells[cellName], columns = ["probabilities"])
        cellDf["type"] = cellName
        cellFrames.append(cellDf)
    final_probs_df = pd.concat(cellFrames, ignore_index=True, axis=0)

    desc = (f"{sampleType} True Pos: {len(cells['true_pos'])}, False Pos: {len(cells['false_pos'])}, "
            f"True Neg: {len(cells['true_neg'])}, False Neg: {len(cells['false_neg'])} \n")
    return final_probs_df, desc

def confsionMatrixLevelProbDistribtuionPlot(training_pos_probs, training_class_labels, valid_pos_probs,
                                            valid_class_labels, threshold, epoch):
    """KDE plots of class-1 probabilities per confusion-matrix cell (TP/FP/TN/FN).

    Training is drawn on the left and validation on the right; the caption lists the
    per-cell counts of both splits. ``epoch`` only appears in the panel titles. The figure
    is saved as ``confusionMatrixLevelProbabilityDistributionPlot`` when ``storePlots`` is set.
    """
    panelFrames = []
    captionParts = []
    for probs, labels, splitName in ((training_pos_probs, training_class_labels, "Training"),
                                     (valid_pos_probs, valid_class_labels, "Validation")):
        probsDf, splitDesc = getProbsForEachConfusionMatrixBlock(probs, labels, threshold, splitName)
        panelFrames.append((splitName, probsDf))
        captionParts.append(splitDesc)

    fig, axes = plt.subplots(1,2, figsize=(9, 6))
    for axis, (_, probsDf) in zip(axes, panelFrames):
        sns.kdeplot(data = probsDf, x = "probabilities", hue = "type", ax = axis)
    fig.text(.5, -0.2, "".join(captionParts), ha = 'center')
    for axis, (splitName, _) in zip(axes, panelFrames):
        axis.title.set_text(f'{splitName} Total epochs {epoch}')

    if arguments["storePlots"]:
        plt.savefig(os.path.join(plotsDirectoryPath, "confusionMatrixLevelProbabilityDistributionPlot"), bbox_inches = "tight")

    plt.show()

def plotLearningRate(scheduler_learning_rates, numEpochs, num_batches):
    """Plot the recorded learning rate against the training step.

    Each epoch boundary (every ``num_batches`` steps) is marked with a red dashed vertical
    line. The figure is saved as ``learningRatePlot`` when ``storePlots`` is set.
    """
    fig, ax = plt.subplots()
    ax.plot(range(len(scheduler_learning_rates)), scheduler_learning_rates)

    epochBoundaries = [epochNumber * num_batches for epochNumber in range(1, numEpochs)]
    ax.vlines(x=epochBoundaries, ymin = 0, ymax = max(scheduler_learning_rates), colors='r', ls="--")
    ax.set_xlabel("Training steps")
    ax.set_title("Learning rate curve used by optimizer")
    if arguments["storePlots"]:
        plt.savefig(os.path.join(plotsDirectoryPath, "learningRatePlot"), bbox_inches = "tight")

    plt.show()

In [None]:
def getModelPredictionAndLoss(denseLayerModel, dataloader, criterion, lossList, threshold, predictions, isTraining=False, epoch=False, optimizer = False, learning_rates = False):
    """Run the model once over every batch of ``dataloader``, either training or evaluating.

    Args:
        denseLayerModel: the model; inputs and labels are moved to the device of its parameters.
        dataloader: yields ``(features, labels)`` batches. Because the dataset is indexed through
            a ``BatchSampler``, every batch has an extra leading dimension of size 1, which is removed.
        criterion: loss applied to the logits and the one-hot (float) labels.
        lossList: the epoch's average batch loss is appended to it.
        threshold: a sample is predicted as 1 when its softmax class-1 probability is > threshold.
        predictions: the predicted 0/1 labels (Python ints) are appended to it.
        isTraining: when True, backpropagate and take one optimizer step per batch.
        epoch: epoch number used in the log message; at epoch 1 violin plots of the input
            features are drawn for label-1 and other samples.
        optimizer: a torch optimizer, or an LR scheduler wrapping one (``objectiveFn`` passes its
            cosine scheduler here). For a scheduler, the wrapped optimizer is zeroed and stepped,
            and the scheduler is advanced once per batch.
        learning_rates: list that receives the learning rate(s) used at every training step when
            a scheduler is passed.

    Returns:
        (all_positive_probs, all_negative_probs, trueLabels): per-sample softmax probabilities
        of class 1 and class 0, and the true labels, in dataloader order.
    """
    start_time = time.time()
    print(f"Start time is {start_time}")

    modelDevice = next(denseLayerModel.parameters()).device
    stepOptimizer, scheduler = _splitOptimizerAndScheduler(optimizer)
    #Input features are only needed at epoch 1, where the violin plots are drawn.
    collectViolinData = epoch == 1
    violinPositives = []
    violinNegatives = []

    running_loss = 0.0
    trueLabels = []
    all_positive_probs = []
    all_negative_probs = []

    for enformerPrediction, classLabels in dataloader:
        #Drop the extra leading dimension added by the BatchSampler. one_hot needs int64 labels
        #on every device, not only on the GPU.
        classLabels = classLabels.to(modelDevice, dtype=torch.int64)[0].flatten()
        enformerPrediction = enformerPrediction.to(modelDevice)[0]

        if collectViolinData:
            #One boolean mask per batch instead of stacking arrays sample by sample.
            featuresCpu = enformerPrediction.detach().cpu().numpy()
            positiveMask = (classLabels == 1).cpu().numpy()
            violinPositives.append(featuresCpu[positiveMask])
            violinNegatives.append(featuresCpu[~positiveMask])

        #The class labels have to be encoded into probabilities of type floating point
        probabilityLabels = one_hot(classLabels, num_classes=2).to(torch.float32)

        modelPrediction = denseLayerModel(enformerPrediction)

        #Softmax over the class dimension turns the two logits into probabilities. They are only
        #used for reporting and thresholding, so they are detached from the graph.
        softmaxProbabilities = softmax(modelPrediction.detach().cpu(), dim=1).numpy()
        softmax_positive_probs = softmaxProbabilities[:, 1]
        softmax_negative_probs = softmaxProbabilities[:, 0]
        all_positive_probs.extend(softmax_positive_probs)
        all_negative_probs.extend(softmax_negative_probs)

        #Predict class 1 when its probability is strictly above the threshold. With a threshold
        #of 0.5 this essentially picks the more probable class.
        predictions.extend(int(prob > threshold) for prob in softmax_positive_probs)

        trueLabels.extend(classLabels.cpu().numpy())

        # Get cross entropy loss between model's prediction and true label.
        loss = criterion(modelPrediction, probabilityLabels)

        if isTraining:
            stepOptimizer.zero_grad()
            loss.backward()
            stepOptimizer.step()
            if scheduler is not None:
                #Record the rate used for this step, then advance the schedule.
                if isinstance(learning_rates, list):
                    learning_rates.append(scheduler.get_last_lr())
                scheduler.step()

        running_loss += loss.item()

    #running_loss sums one mean loss per batch, so dividing by the batch count gives the epoch average.
    num_batches = len(dataloader)
    avg_running_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Average running loss for epoch {epoch} is {avg_running_loss}\n")
    lossList.append(avg_running_loss)

    #Make a violin plot of the values of enformer prediction to check if controls are working properly
    if collectViolinData:
        print("Making a violin plot .......")
        for groupBatches in (violinPositives, violinNegatives):
            groupRows = np.concatenate(groupBatches) if groupBatches else np.empty((0, 0))
            if len(groupRows) == 0:
                #No samples of this class in the data: nothing to draw.
                continue
            column_values = [str(i) for i in range(groupRows.shape[1])]
            plt.figure()
            sns.violinplot(data = pd.DataFrame(data = groupRows, columns = column_values))
            plt.show()

    return all_positive_probs, all_negative_probs, trueLabels


def _splitOptimizerAndScheduler(optimizer):
    """Return ``(optimizer_to_step, scheduler_or_None)`` for an optimizer or an LR scheduler.

    A torch LR scheduler keeps the optimizer it wraps in its ``optimizer`` attribute. Its own
    ``step()`` only changes the learning rate, so the wrapped optimizer must be stepped too.
    """
    wrappedOptimizer = getattr(optimizer, "optimizer", None)
    if wrappedOptimizer is not None and hasattr(optimizer, "get_last_lr"):
        return wrappedOptimizer, optimizer
    return optimizer, None


def objectiveFn(batchSize, learningRate, numWorkers, numEpochs):
    """Train the global ``denseLayerModel`` for ``numEpochs`` epochs, validating after each one.

    After training, the model is saved and all evaluation plots and metrics files are
    produced. Batches are read in file order, because the sampler is a BatchSampler over a
    plain range with no shuffling.
      * Last epoch only: performance metrics, confusion matrices, precision-recall curves
        and the confusion-matrix-level probability plot.
      * All epochs: the loss curves and the probability histogram.

    Raises:
        ValueError: if ``numEpochs`` is smaller than 1 (there would be no results to report).
    """
    if numEpochs < 1:
        raise ValueError(f"numEpochs must be at least 1, got {numEpochs}")

    #Training dataloader
    trainingDataset = EnformerOutputDataset("training")
    trainingDataloader = _makeBatchDataloader(trainingDataset, arguments["trainingStartIndex"], batchSize, numWorkers)
    print(f"Finished getting training data loader,time for validation...")
    #Validation dataloader
    validationDataset = EnformerOutputDataset("validation")
    validationDataloader = _makeBatchDataloader(validationDataset, arguments["validationStartIndex"], batchSize, numWorkers)

    training_num_batches = len(trainingDataloader)

    #Class weights counter the positive/negative imbalance; with a balanced dataset both are 1.
    if arguments["useClassWeights"]:
        training_class_weights = getClassWeights()
    else:
        training_class_weights = [1, 1]

    print(f"Training class weights are {training_class_weights}")
    #Force float32 to match the model's float32 logits: the weights may arrive as a float64
    #numpy array or as Python ints, which would otherwise give a float64/int64 tensor.
    classWeightsTensor = torch.tensor(training_class_weights, dtype=torch.float32).to(device)
    training_criterion = nn.CrossEntropyLoss(weight = classWeightsTensor)
    #Validation uses the same weights, otherwise the training and validation losses would be on
    #different scales and hard to compare.
    validation_criterion = nn.CrossEntropyLoss(weight = classWeightsTensor)

    optimizer = optim.Adam(denseLayerModel.parameters(), lr=learningRate)

    #`scheduler` is handed to getModelPredictionAndLoss as its `optimizer` argument.
    #NOTE(review): when it is a CosineAnnealingLR, that function must also zero and step the
    #Adam optimizer the scheduler wraps, otherwise the weights never change.
    if arguments["useCosineLearningFunction"]:
        #One cosine period spans every optimizer step of the run (batches per epoch * epochs).
        optimizer_steps = training_num_batches * numEpochs
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, optimizer_steps, last_epoch = -1, eta_min=0)
    else:
        scheduler = optimizer

    threshold = arguments["threshold"]

    training_loss_list = []
    validation_loss_list = []

    training_positive_probabilities_list = []
    training_negative_probabilities_list = []
    validation_positive_probabilities_list = []
    validation_negative_probabilities_list = []

    scheduler_learning_rates = []

    #Training the model and validating for each epoch
    for epoch in range(1, numEpochs + 1):
        print(f"Starting training for epoch {epoch}")

        #Reset every epoch, so after the loop these hold the last epoch's predictions only.
        training_predictions = []
        validation_predictions = []

        #Training
        train_positive_probs, train_negative_probs, train_class_labels = getModelPredictionAndLoss(denseLayerModel, trainingDataloader, training_criterion,
            training_loss_list, threshold, training_predictions, True, epoch, scheduler, scheduler_learning_rates)

        print(f"Finished training for epoch {epoch}. Starting validations")

        training_positive_probabilities_list.extend(train_positive_probs)
        training_negative_probabilities_list.extend(train_negative_probs)

        #Validation
        with torch.no_grad():
            validation_positive_probs, validation_negative_probs, validation_class_labels = getModelPredictionAndLoss(denseLayerModel, validationDataloader, validation_criterion,
                validation_loss_list, threshold, validation_predictions, False, epoch)

        validation_positive_probabilities_list.extend(validation_positive_probs)
        validation_negative_probabilities_list.extend(validation_negative_probs)

    print(f"Completed training and validation. Saving model and plotting loss function graphs. ")

    saveModel(denseLayerModel)
    storeLossFunctionPlot(training_loss_list, validation_loss_list)
    storePerformanceMetrics(train_class_labels, training_predictions, "training")
    storePerformanceMetrics(validation_class_labels, validation_predictions, "validation")
    storeProbabilityFrequencyPlot(training_positive_probabilities_list, training_negative_probabilities_list,
                                  validation_positive_probabilities_list, validation_negative_probabilities_list)
    #Each label must be paired with the probability from the same epoch. The accumulated
    #*_probabilities_list hold every epoch, and indexing them with the last epoch's labels would
    #silently use the first epoch's probabilities. Use the last epoch's probabilities and the
    #same threshold as the predictions.
    confsionMatrixLevelProbDistribtuionPlot(train_positive_probs, train_class_labels,
                                            validation_positive_probs, validation_class_labels, threshold, numEpochs)

    #Plot confusion matrix and precision recall plot only for the probabilities in the last epoch
    storeConfusionMatrixHeatMap(train_class_labels, validation_class_labels, training_predictions, validation_predictions)
    plotPrecisionRecallCurve(train_positive_probs, train_class_labels, validation_positive_probs, validation_class_labels)
    if arguments["useCosineLearningFunction"]:
        plotLearningRate(scheduler_learning_rates, numEpochs, training_num_batches)


def _makeBatchDataloader(dataset, startIndex, batchSize, numWorkers):
    """DataLoader that hands ``dataset`` whole batches of consecutive absolute file indices.

    The BatchSampler gives ``__getitem__`` a list of indices (one batch). The DataLoader
    itself keeps its default batch size of 1, so every loaded batch has an extra leading
    dimension of size 1.
    """
    indexRange = range(startIndex, len(dataset) + startIndex)
    batchSampler = torch.utils.data.BatchSampler(indexRange, batch_size=batchSize, drop_last=False)
    return DataLoader(dataset, num_workers=numWorkers, sampler=batchSampler)