In [None]:
import pysam
import torch

torch.cuda.empty_cache()

import numpy as np

from torch.utils.data import Dataset, DataLoader
import h5py

import torch.optim as optim
from torch import nn
import torch.nn.functional as f
from enformer_pytorch import Enformer

import sys

sys.path.insert(0,'/hpc/compgen/projects/fragclass/analysis/mvivekanandan/script/madhu_scripts')

import config
import sequenceUtils

import importlib
import matplotlib.pyplot as plt

In [None]:
# Re-execute the helper module so edits made to it on disk are picked up in this kernel.
importlib.reload(sequenceUtils)

# Gather the run settings from the config module into a single dictionary.
# The output-store path is looked up in config.filePaths; every other entry is
# looked up in config.modelHyperParameters. Missing keys yield None (dict.get).
arguments = {
    "enformerOutputStoreFile": config.filePaths.get("enformerOutputStoreFile"),
    **{
        settingName: config.modelHyperParameters.get(settingName)
        for settingName in ("batchSize", "learningRate", "numberOfWorkers", "numberEpochs")
    },
}

In [None]:
# Prefer the GPU whenever torch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"The device used is : {device}")

In [None]:
class BasicDenseLayer(nn.Module):
    """Three-layer fully connected classifier head.

    The input passes through two ReLU-activated linear layers followed by a
    final linear layer. No softmax is applied: forward() returns the raw output
    of the last linear layer.

    Args:
        inputSize: Number of input features per sample (size of the last
            dimension of the tensor given to forward). Defaults to 2 * 5313.
        firstHiddenSize: Width of the first hidden layer. Defaults to 1000.
        secondHiddenSize: Width of the second hidden layer. Defaults to 200.
        numClasses: Number of output values per sample. Defaults to 2.

    The defaults reproduce the layer sizes that were previously hard-coded, so
    BasicDenseLayer() builds the same architecture as before. The attribute
    names fc1/fc2/fc3 are kept so saved state dicts remain loadable.
    """

    def __init__(self, inputSize=2 * 5313, firstHiddenSize=1000, secondHiddenSize=200, numClasses=2):
        super().__init__()
        self.fc1 = nn.Linear(inputSize, firstHiddenSize)
        self.fc2 = nn.Linear(firstHiddenSize, secondHiddenSize)
        self.fc3 = nn.Linear(secondHiddenSize, numClasses)

    def forward(self, x):
        """Return the un-normalised class scores for x.

        Args:
            x: Tensor whose last dimension equals the inputSize given at
                construction.

        Returns:
            Tensor with the same leading dimensions as x and a last dimension
            of numClasses.
        """
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
class TrainingDataset(Dataset):
    """Map-style dataset that serves stored Enformer outputs and their labels from an HDF5 file.

    The file is read through two datasets, 'enformerOutput' and 'labels', both
    indexed by sample number.

    Args:
        enformerOutputFilePath: Path of the HDF5 file. When None (the default),
            the path is taken from arguments["enformerOutputStoreFile"].
        numSamples: Number of samples reported by __len__. Defaults to 140, the
            value that used to be hard-coded, so TrainingDataset() behaves as
            before. Pass None to use every row of the 'labels' dataset.
    """

    def __init__(self, enformerOutputFilePath=None, numSamples=140):
        if enformerOutputFilePath is None:
            enformerOutputFilePath = arguments["enformerOutputStoreFile"]
        self.enformerOutputFilePath = enformerOutputFilePath
        self.numSamples = numSamples

    def __getitem__(self, index):
        """Return (enformer output as a float32 tensor, label) for one sample.

        The output and the label are read at the same index, so the shuffled
        access order used by the dataloader cannot misalign them.
        """
        # The handle is deliberately not called `f`: that name is the module-level
        # alias for torch.nn.functional.
        with h5py.File(self.enformerOutputFilePath, 'r') as h5File:
            # Indexing an h5py dataset reads the selected entry rather than
            # loading the whole dataset into memory.
            enformerOutput = h5File['enformerOutput'][index]
            encoded_enformer_output = torch.tensor(np.float32(enformerOutput))

            # Each sample should have only one label: take the first element so
            # a single value is returned instead of a 1D numpy array.
            label = h5File['labels'][index][0]

        return encoded_enformer_output, label

    def __len__(self):
        """Return the number of samples this dataset exposes."""
        if self.numSamples is None:
            # Resolve the real size once from the file and remember it, so the
            # file is not reopened on every len() call.
            with h5py.File(self.enformerOutputFilePath, 'r') as h5File:
                self.numSamples = len(h5File["labels"])
        return self.numSamples

In [None]:
def objectiveFn(batchSize, learningRate, numWorkers, numEpochs):
    """Train the dense classifier on the stored Enformer outputs and plot the loss curve.

    Args:
        batchSize: Batch size handed to the DataLoader.
        learningRate: Learning rate of the SGD optimizer (momentum is fixed at 0.9).
        numWorkers: Number of DataLoader worker processes.
        numEpochs: Number of passes over the training dataset.

    Returns:
        list: The values appended by trainingFn, one per epoch, in epoch order.
        This is the same list that is plotted.
    """
    #Set the model to eval mode first and then send it to cuda. This prevents the GPU node from running out of memory.
    # NOTE(review): the model is handed to trainingFn, which does not appear to call it
    # in this file - confirm whether loading it here is still required.
    enformerModel = Enformer.from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True).eval()
    enformerModel = enformerModel.to(device)

    trainingDataset = TrainingDataset()
    trainingDataloader = DataLoader(trainingDataset, batch_size=batchSize, shuffle=True, num_workers=numWorkers)
    denseLayerModel = BasicDenseLayer().to(device)

    #Define the loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(denseLayerModel.parameters(), lr=learningRate, momentum=0.9)

    # trainingFn appends one average loss per epoch to this list.
    loss_list = []
    #Training the model
    for epoch in range(numEpochs):
        trainingFn(epoch, criterion, optimizer, denseLayerModel, trainingDataloader, enformerModel, loss_list)

    print(f"Printing the type of loss list{type(loss_list)}")

    # The message previously claimed to print the type while printing the list itself.
    print(f"Finished training, loss_list is {loss_list}")

    # Loss curve: one point per epoch.
    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel("Epoch")
    plt.ylabel("Average training loss")
    plt.show()
    print("Training is complete !!! Starting validations ")
    #Validation of model predictions (currently disabled)
    # validationDataset = ValidationDataset()
    # validationFn(validationDataset, criterion, denseLayerModel, enformerModel)

    return loss_list


def trainingFn(epoch, criterion, optimizer, denseLayerModel, dataloader, enformerModel, loss_list):
    """Run one training epoch and append its average batch loss to loss_list.

    Args:
        epoch: Index of the current epoch; only used in log messages.
        criterion: Loss callable, invoked as criterion(prediction, label).
        optimizer: Optimizer that updates denseLayerModel's parameters.
        denseLayerModel: Model being trained. It must own at least one
            parameter, because the batch device is taken from its parameters.
        dataloader: Iterable yielding (features, label) batches.
        enformerModel: Accepted for interface compatibility; not used here.
        loss_list: List mutated in place. Exactly one value is appended per
            call: the mean of the batch losses, or NaN when the dataloader
            yielded no batches.
    """
    # Place every batch on the device the model actually lives on, so the inputs
    # and the model cannot end up on different devices.
    modelDevice = next(denseLayerModel.parameters()).device

    running_loss = 0.0
    count = 0

    for i, data in enumerate(dataloader, 0):
        count = count + 1
        print(f"Inside epoch {epoch}, Dataloader index is {i}")
        enformerPrediction, label = data

        #While creating torch.tensor, device can be passed as cuda. But that was a suspect for GPU node running out of memory.
        #After iterating through dataset and fetching each sample, send the labels and sequence to the model's device.
        enformerPrediction = enformerPrediction.to(modelDevice)

        print(f"After batching, printing enformer prediction and label shapes: {enformerPrediction.shape}, {label.shape}\n", flush=True)

        optimizer.zero_grad()

        # The output is already on the model's device, so no extra transfer is needed.
        modelPrediction = denseLayerModel(enformerPrediction)
        print(f"The shape of model prediction is {modelPrediction.shape}\n",flush=True)
        print(f"The model prediction is {modelPrediction}\n", flush=True)

        #Without this conversion, model throws RuntimeError: Expected Scalar type Long but found Float Error.
        label = label.to(device=modelDevice, dtype=torch.long)

        print(f"The shape of the label is {label.size()}\n", flush=True)
        print(f"The label is {label}\n")

        # Get cross entropy loss between model's prediction and true label.
        loss = criterion(modelPrediction, label)

        print(f"Just computed the loss for epoch {epoch} and dataloader iteration {i}. It is {loss}\n", flush=True)

        # Backward pass and calculate the gradients
        loss.backward()

        # Uses the gradients from backward pass to nudge the learning weights.
        optimizer.step()

        # Accumulate the batch loss; it should decrease over the course of training.
        running_loss += loss.item()

        print(f"Inside Training function, the running loss is {running_loss}\n",flush=True)

    #The final running loss should be divided by this number to get the average running loss.
    print(f"Total number of iterations for the current epoch is {count}\n")

    #The running_loss is the sum of individual losses for each batch.
    #The average running loss for the epoch is running_loss divided by the number of batches seen.
    num_batches = count
    print(f"Finished iterating through the dataLoader for epoch {epoch}. The number of batches is {num_batches} and running loss is {running_loss}\n",flush=True)

    if num_batches == 0:
        # Nothing was trained on: record NaN instead of dividing by zero, so
        # loss_list still holds one entry per epoch.
        print(f"No batches were produced for epoch {epoch}; recording NaN as the average loss\n", flush=True)
        loss_list.append(float("nan"))
        return

    avg_running_loss = running_loss/num_batches
    print(f"Average running loss for epoch {epoch} after backward pass is {avg_running_loss}\n")
    loss_list.append(avg_running_loss)

In [None]:
if __name__ == "__main__":
    # Entry point: read the run settings gathered in `arguments` and start training.
    print(f"The arguments is {arguments}")

    # Unpack the four hyper-parameters in the order objectiveFn expects them.
    batchsize, learningRate, numWorkers, numEpochs = (
        arguments[settingName]
        for settingName in ("batchSize", "learningRate", "numberOfWorkers", "numberEpochs")
    )
    print(f"The arguments are {batchsize}, {learningRate}, {numWorkers}, {numEpochs}")

    objectiveFn(batchsize, learningRate, numWorkers, numEpochs)