This file contains functions to:
1. Define a basic dense-layer network.
2. Create a dataset that reads the Enformer output (enformerOutput) from the HDF5 (h5py) file and returns the Enformer predictions, which the DataLoader groups into batches of training data.
3. Train the basic dense-layer model on the Enformer-prediction training data.
4. Run the trained model on the validation dataset (not yet enabled).
5. Plot the loss for the training dataset (and, once validation is enabled, for the validation dataset).

In [None]:
import pysam
import torch

torch.cuda.empty_cache()

import numpy as np

from torch.utils.data import Dataset, DataLoader
import h5py

import torch.optim as optim
from torch import nn
import torch.nn.functional as f
from enformer_pytorch import Enformer

import sys

sys.path.insert(0,'/hpc/compgen/projects/fragclass/analysis/mvivekanandan/script/madhu_scripts')

import config
import sequenceUtils

import importlib
import matplotlib.pyplot as plt

In [None]:
importlib.reload(sequenceUtils)

# Collect the run arguments from the project config module.
# The file path lives in config.filePaths; every hyper-parameter lives in
# config.modelHyperParameters under the same key name used in `arguments`.
arguments = {"enformerOutputStoreFile": config.filePaths.get("enformerOutputStoreFile")}
arguments.update(
    (hyperParameter, config.modelHyperParameters.get(hyperParameter))
    for hyperParameter in ("batchSize", "learningRate", "numberOfWorkers", "numberEpochs")
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"The device used is : {device}")

In [None]:
class BasicDenseLayer(nn.Module):
    """
    Basic dense (fully connected) network: three linear layers with ReLU between them,
    inputSize -> hidden1 -> hidden2 -> numClasses.

    The defaults reproduce the original architecture: an input of 2 * 5313 features,
    hidden layers of 1000 and 200 units, and 2 outputs. No softmax is applied, so the
    output is raw logits, which is what nn.CrossEntropyLoss expects.

    Args:
    inputSize(int) - number of input features per sample. Default 2 * 5313.
    hidden1(int) - width of the first hidden layer. Default 1000.
    hidden2(int) - width of the second hidden layer. Default 200.
    numClasses(int) - number of output logits. Default 2.
    """

    def __init__(self, inputSize=2 * 5313, hidden1=1000, hidden2=200, numClasses=2):
        super().__init__()
        # The attribute names fc1/fc2/fc3 are kept so existing state_dict checkpoints still load.
        self.fc1 = nn.Linear(inputSize, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, numClasses)

    def forward(self, x):
        """Return the logits for a batch x whose last dimension is inputSize."""
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
class TrainingDataset(Dataset):
    """
    Dataset that reads an HDF5 (h5py) file of Enformer output and returns each sample's
    Enformer output together with its label.

    The file is read through two HDF5 datasets, 'enformerOutput' and 'labels', both indexed
    by sample; the label is the first element of the sample's 'labels' row.

    __getitem__ returns a tuple (encoded_enformer_output, label):
    - encoded_enformer_output is a float32 torch tensor, expected to be 1D of size 10626
      (= 2 * 5313, the BasicDenseLayer input size).
    - label is 1 or 0, denoting whether the cfDNA fragment came from a tumour tissue or a
      regular tissue.

    Args:
    enformerOutputFilePath(str, optional) - path to the HDF5 file. When omitted, it
        defaults to arguments["enformerOutputStoreFile"] (enformerOutputStoreFile in the
        config file). Passing a path lets the same class read another file, e.g. a
        validation set.
    """

    def __init__(self, enformerOutputFilePath=None):
        if enformerOutputFilePath is None:
            enformerOutputFilePath = arguments["enformerOutputStoreFile"]
        self.enformerOutputFilePath = enformerOutputFilePath

    def __getitem__(self, index):
        """
        Return (enformer output tensor, label) for the sample at `index`.

        With shuffle=True the DataLoader fetches indexes out of order. This cannot cause a
        mismatch between enformer output and label, because both are read for the same index.
        The file is opened on every call. That is slower than keeping a handle open, but no
        h5py handle is ever shared between DataLoader worker processes.
        """
        # Named h5File (not f) so it does not shadow the module-level torch.nn.functional alias.
        with h5py.File(self.enformerOutputFilePath, 'r') as h5File:
            enformerOutput = h5File['enformerOutput'][index]
            label = h5File['labels'][index][0]

        encoded_enformer_output = torch.tensor(np.float32(enformerOutput))
        return encoded_enformer_output, label

    def __len__(self):
        """Number of samples, taken as the length of the 'labels' dataset."""
        with h5py.File(self.enformerOutputFilePath, 'r') as h5File:
            return len(h5File["labels"])

In [None]:
def objectiveFn(batchSize, learningRate, numWorkers, numEpochs):
    """
    Train a BasicDenseLayer model on the Enformer predictions and plot the training loss.

    Steps:
    1. Creates a shuffled DataLoader over TrainingDataset, which reads the Enformer predictions
       from the enformerOutput HDF5 file.
    2. For each epoch, calls trainingFn. It trains the model (cross-entropy loss, SGD with
       momentum 0.9) and appends the epoch's average batch loss to a list.
    3. Plots the average training loss per epoch.
    Validation is not run yet (see the TODO at the end of the function).

    Args:
    batchSize(int) - the batch size for training. Read from the config file.
    learningRate(float) - learning rate for SGD. Read from the config file.
    numWorkers(int) - number of DataLoader worker processes used to load data. Read from the config file.
    numEpochs(int) - number of epochs for training. Read from the config file.

    Returns:
    list of float - the average training loss of each epoch, in epoch order.
    """
    trainingDataset = TrainingDataset()
    trainingDataloader = DataLoader(trainingDataset, batch_size=batchSize, shuffle=True, num_workers=numWorkers)
    denseLayerModel = BasicDenseLayer().to(device)

    # Define the loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(denseLayerModel.parameters(), lr=learningRate, momentum=0.9)

    # Train the model; trainingFn appends one average loss per epoch to loss_list.
    loss_list = []
    for epoch in range(numEpochs):
        trainingFn(epoch, criterion, optimizer, denseLayerModel, trainingDataloader, loss_list)

    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel("Epoch")
    plt.ylabel("Average training loss")
    plt.title("Training loss")
    plt.show()
    print("Training is complete !!! Starting validations ")

    # TODO: run validation once a ValidationDataset / validationFn exist, e.g.
    #   validationFn(ValidationDataset(), criterion, denseLayerModel)
    return loss_list

def trainingFn(epoch, criterion, optimizer, denseLayerModel, dataloader, loss_list, targetDevice=None):
    """
    Train denseLayerModel for one epoch and record the epoch's average loss.

    For each batch from dataloader, the function:
    - moves the inputs and labels to the target device,
    - runs the forward pass and computes the loss with criterion,
    - back-propagates and takes an optimizer step.
    The per-batch losses are summed into running_loss. The epoch average is running_loss
    divided by the number of batches (not the number of samples), and it is appended to
    loss_list.

    Args:
    epoch(int) - the current epoch; used only in log messages.
    criterion(callable) - loss function called as criterion(logits, labels).
    optimizer(torch.optim.Optimizer) - optimizer over denseLayerModel's parameters.
    denseLayerModel(nn.Module) - the dense layer model to train; it should already be on the target device.
    dataloader(iterable) - yields (enformerPrediction, label) batches, e.g. a DataLoader.
    loss_list(list of float) - running list of per-epoch average losses; this epoch's average is appended.
    targetDevice(str or torch.device, optional) - where batches are moved. Defaults to the
        module-level `device`.

    Returns:
    float - this epoch's average loss (also appended to loss_list); 0.0 if the dataloader
    yielded no batches.
    """
    if targetDevice is None:
        targetDevice = device

    # Ensure training-mode behaviour for any dropout/batch-norm layers (e.g. after an eval pass).
    denseLayerModel.train()
    running_loss = 0.0
    num_batches = 0

    for i, (enformerPrediction, label) in enumerate(dataloader):
        num_batches += 1

        # Batches are built on the CPU by the dataset and moved to the device here. Creating the
        # tensors directly on CUDA was suspected of making the GPU node run out of memory.
        enformerPrediction = enformerPrediction.to(targetDevice)
        # nn.CrossEntropyLoss (used by objectiveFn) needs int64 class indices; without this it raises
        # "RuntimeError: Expected scalar type Long but found Float".
        label = label.long().to(targetDevice)

        optimizer.zero_grad()
        modelPrediction = denseLayerModel(enformerPrediction)
        # Loss between the model's prediction and the true label.
        loss = criterion(modelPrediction, label)
        loss.backward()   # compute gradients
        optimizer.step()  # nudge the weights using those gradients

        running_loss += loss.item()

        # Periodic progress log, to check that the loss keeps decreasing during training.
        if i % 500 == 0:
            print(f"Running loss for sample index {i} inside epoch {epoch} is {running_loss}\n", flush=True)

    # Average over batches; guard against an empty dataloader (would otherwise divide by zero).
    avg_running_loss = running_loss / num_batches if num_batches else 0.0
    print(f"Average running loss for epoch {epoch} after backward pass is {avg_running_loss}\n")
    loss_list.append(avg_running_loss)
    return avg_running_loss

In [None]:
if __name__ == "__main__":
    # Unpack the run configuration assembled from the config file, then start training.
    batchsize, learningRate, numWorkers, numEpochs = (
        arguments[settingName]
        for settingName in ("batchSize", "learningRate", "numberOfWorkers", "numberEpochs")
    )
    objectiveFn(batchsize, learningRate, numWorkers, numEpochs)