In [None]:
# TODO : Normalize
# TODO : Convolution + Deconvolution

In [1]:
import pickle
import numpy as np
from tqdm import tqdm
# Histone modification (HM) marks. NoisedDataset uses each name as a key into the
# loaded per-biosample pickle and stacks the "chr1" tracks in this list order.
HMs = ["H3K4me3", "H3K27ac", "H3K4me1", "H3K36me3"]

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Current computational resource :", device)

Current computational resource : cuda


In [2]:
class NoisedDataset(Dataset):
    """Paired (noisy, clean) signal dataset built from pickled per-biosample tracks.

    For every biosample the pickle at ``<dataset_dir><biosample>/<biosample>.pkl``
    is loaded, the ``"chr1"`` entry of each histone modification listed in the
    module-level ``HMs`` is stacked into one clean array, and ``amount`` noisy
    copies of that array are generated (100 when ``run_type == "train"``,
    otherwise 20).

    Indexing returns ``(noisy, clean)`` as a tuple of float32 tensors.
    """

    def __init__(self, biosamples, dataset_dir, run_type):
        """Load every biosample and materialise all noisy/clean pairs in memory.

        Args:
            biosamples: iterable of biosample names (sub-directory / file stems).
            dataset_dir: path prefix; it is concatenated as-is, so it must end
                with a path separator.
            run_type: ``"train"`` yields 100 noisy copies per biosample, any
                other value yields 20.

        Raises:
            ValueError: if ``biosamples`` is empty, or if the biosamples do not
                all produce a clean signal of the same shape.
        """
        biosamples = list(biosamples)
        if not biosamples:
            raise ValueError("biosamples must contain at least one biosample name")

        amount = 100 if run_type == "train" else 20
        x_data = None
        y_data = None

        for ith, biosample in enumerate(biosamples):
            path = dataset_dir + biosample + "/{}.pkl".format(biosample)
            # NOTE(security): pickle.load can execute arbitrary code; only point
            # this at dataset files you created or otherwise trust.
            with open(path, "rb") as f:
                dataset = pickle.load(f)

            clean_signal = np.array([dataset[HM]["chr1"] for HM in HMs])

            if x_data is None:
                # Size the buffers from the data instead of hard-coding the
                # number of biosamples and the signal shape.
                shape = (amount * len(biosamples),) + clean_signal.shape
                x_data = np.zeros(shape, dtype=np.float32)
                y_data = np.zeros(shape, dtype=np.float32)
            elif clean_signal.shape != x_data.shape[1:]:
                raise ValueError(
                    "signal shape {} of biosample {!r} differs from the first "
                    "biosample's shape {}".format(
                        clean_signal.shape, biosample, x_data.shape[1:]
                    )
                )

            start, stop = ith * amount, (ith + 1) * amount
            x_data[start:stop] = self.generateNoisedData(clean_signal, amount)
            # Broadcasting repeats the clean target `amount` times without
            # building a tiled temporary.
            y_data[start:stop] = clean_signal

        self.x_tensor = torch.from_numpy(x_data)
        self.y_tensor = torch.from_numpy(y_data)

    def __getitem__(self, index):
        """Return the ``(noisy, clean)`` tensor pair stored at ``index``."""
        return (self.x_tensor[index], self.y_tensor[index])

    def __len__(self):
        """Return the total number of noisy samples."""
        return len(self.x_tensor)

    def generateNoisedData(self, clean_signal, amount, mean=0, std=5):
        """Return ``amount`` copies of ``clean_signal`` with Gaussian noise added.

        Args:
            clean_signal: array holding the clean signal.
            amount: number of noisy copies to stack along a new leading axis.
            mean: mean of the additive Gaussian noise.
            std: standard deviation of the additive Gaussian noise.

        Returns:
            Array of the tiled clean signal plus noise drawn from NumPy's
            global random state (seed it with ``np.random.seed`` to reproduce).
        """
        clean_signals = np.tile(clean_signal, (amount,1,1))
        noise = np.random.normal(mean, std, size=(clean_signals.shape))

        return clean_signals + noise

In [3]:
class DenoisingModel(nn.Module):
    """Fully connected autoencoder mapping a flattened signal to one of the same size.

    Layer widths are ``input_dim -> 256 -> 128 -> 64 -> 128 -> 256 -> input_dim``
    with ReLU after every hidden layer and a Sigmoid after the last one, so the
    outputs are bounded to the interval [0, 1].
    """

    def __init__(self, input_dim=4*248957):
        """Build the layer stack.

        Args:
            input_dim: length of the flattened input (and output) vector.
                Defaults to 4 * 248957.

        Raises:
            ValueError: if ``input_dim`` is smaller than 1.
        """
        super(DenoisingModel,self).__init__()
        if input_dim < 1:
            raise ValueError("input_dim must be a positive integer")

        widths = [input_dim, 256, 128, 64, 128, 256, input_dim]
        last = len(widths) - 1

        self.dense = nn.Sequential()
        for idx in range(1, len(widths)):
            self.dense.add_module(
                "dense{}".format(idx), nn.Linear(widths[idx - 1], widths[idx])
            )
            # The final activation is a Sigmoid but keeps the historical name
            # "relu6" so the sub-module names stay unchanged.
            activation = nn.Sigmoid() if idx == last else nn.ReLU()
            self.dense.add_module("relu{}".format(idx), activation)

    def forward(self,x):
        """Run ``x`` (last dimension of size ``input_dim``) through the stack."""
        return self.dense(x)

In [4]:
# Location of the pickled signals: dataset/<category>/<biosample>/<biosample>.pkl
category = "cell line"
biosamples = ["A549", "H1", "H9"]
dataset_dir = "dataset/{}/".format(category)

# Seed both RNGs in play: NumPy draws the synthetic noise, while PyTorch drives
# weight initialisation and DataLoader shuffling.
SEED = 52
np.random.seed(SEED)
torch.manual_seed(SEED)

train_set = NoisedDataset(biosamples, dataset_dir, run_type="train")
test_set = NoisedDataset(biosamples, dataset_dir, run_type="test")
train_loader = DataLoader(train_set, batch_size=30, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps it deterministic.
test_loader = DataLoader(test_set, batch_size=10, shuffle=False)

In [5]:
# Build the network, move its parameters to the selected device, then switch it
# to training mode. The final expression is left bare so the cell displays it.
model = DenoisingModel()
model = model.to(device)
model.train()

DenoisingModel(
  (dense): Sequential(
    (dense1): Linear(in_features=995828, out_features=256, bias=True)
    (relu1): ReLU()
    (dense2): Linear(in_features=256, out_features=128, bias=True)
    (relu2): ReLU()
    (dense3): Linear(in_features=128, out_features=64, bias=True)
    (relu3): ReLU()
    (dense4): Linear(in_features=64, out_features=128, bias=True)
    (relu4): ReLU()
    (dense5): Linear(in_features=128, out_features=256, bias=True)
    (relu5): ReLU()
    (dense6): Linear(in_features=256, out_features=995828, bias=True)
    (relu6): Sigmoid()
  )
)

In [6]:

epochs = 5

# NOTE(review): `criterion` and `optimizer` were not defined in any visible cell,
# so this cell failed with NameError on a fresh kernel. Mean-squared error and
# Adam with default settings are assumed here — confirm against the original run.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(epochs):
    progress_bar = tqdm(train_loader)
    for x, y in progress_bar:
        # Flatten each sample to one vector per batch row, as the dense model expects.
        x = x.view(x.size(0),-1).to(device)
        y = y.view(y.size(0),-1).to(device)

        pred = model(x)
        # PyTorch loss modules take (input, target) in that order.
        loss = criterion(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float for display.
        progress_bar.set_description("Current loss : {}".format(loss.item()))


Current loss : 1923.6422119140625: 100%|██████████| 10/10 [00:01<00:00,  5.78it/s]
Current loss : 2189.42626953125: 100%|██████████| 10/10 [00:01<00:00,  6.36it/s]
Current loss : 1909.4498291015625: 100%|██████████| 10/10 [00:01<00:00,  6.43it/s]
Current loss : 1241.249755859375: 100%|██████████| 10/10 [00:01<00:00,  6.38it/s]
Current loss : 838.9619750976562: 100%|██████████| 10/10 [00:01<00:00,  6.46it/s]
