In [1]:
import torch.nn.functional as F
import torch.utils.data
import torch.nn as nn
from torch import relu, sigmoid
import torch.nn.modules.activation as activation

import math
import numpy as np
import pandas as pd

import copy

import os
from Bio import SeqIO

In [2]:
#function to train a model - no validation
def train_model(train_loader, model, device, criterion, optimizer, num_epochs,
               weights_folder, name_ind, verbose):
    """Train ``model`` for ``num_epochs`` epochs on ``train_loader`` (no validation).

    After every epoch the model's ``state_dict`` is saved to
    ``<weights_folder>/model_epoch_<epoch>_<name_ind>.pth`` (epochs are 1-based).

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        model: the ``nn.Module`` to train; put in train mode every epoch.
        device: device each batch is moved to (the model is assumed to
            already live there — TODO confirm at call sites).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        weights_folder: existing directory the checkpoints are written into.
        name_ind: tag appended to every checkpoint file name.
        verbose: if true, print the mean training loss after every epoch.

    Returns:
        ``(model, train_error)`` where ``train_error[i]`` is the mean of the
        per-batch losses of epoch ``i + 1``.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    train_error = []

    for epoch in range(num_epochs):
        model.train()  # tell the model explicitly that we train
        running_loss = 0.0
        n_batches = 0
        for seqs, labels in train_loader:
            x = seqs.to(device)
            labels = labels.to(device)
            # zero the existing gradients so they don't add up
            optimizer.zero_grad()
            # forward pass
            outputs = model(x)
            loss = criterion(outputs, labels)
            # backward and optimize
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1

        # Fail loudly instead of with a ZeroDivisionError on an empty loader.
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # mean of the per-batch losses of this epoch
        epoch_loss = running_loss / n_batches
        train_error.append(epoch_loss)

        if verbose:
            print('Epoch [{}], Current Train Loss: {:.5f}'
                  .format(epoch + 1, epoch_loss))

        # The state_dict is serialized as-is; deep-copying it first (as
        # before) only duplicated every weight tensor in memory.
        checkpoint_path = os.path.join(
            weights_folder,
            "model_epoch_{}_{}.pth".format(epoch + 1, name_ind))
        torch.save(model.state_dict(), checkpoint_path)
    return model, train_error

#function to one-hot encode
############################################################
def dna_one_hot(seq, seq_len=None, flatten=True):
    """One-hot encode a DNA string into a float16 matrix.

    Rows are A, C, G, T (in that order); columns are positions. Lower-case
    input is accepted. Any other character (e.g. 'N') and every padding
    column is encoded as 0.25 in all four rows.

    Args:
        seq: nucleotide string.
        seq_len: output length. None keeps ``len(seq)``. A shorter value
            centre-trims ``seq`` (an odd leftover base is removed from the
            right); a longer value centre-pads it with 0.25 columns (an odd
            extra column goes on the right).
        flatten: if true, return shape ``(1, 4 * seq_len)`` (the 4 x seq_len
            matrix flattened row by row) instead of ``(4, seq_len)``.

    Returns:
        numpy float16 array as described above.
    """
    if seq_len is None:
        seq_len = len(seq)
        seq_start = 0
    elif seq_len <= len(seq):
        # trim the sequence, keeping its centre
        seq_trim = (len(seq) - seq_len) // 2
        seq = seq[seq_trim:seq_trim + seq_len]
        seq_start = 0
    else:
        # pad: the sequence begins this many columns into the output
        seq_start = (seq_len - len(seq)) // 2

    seq = seq.upper()
    nt_row = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

    # float16 rather than an integer dtype so unknown/padding columns can hold 0.25.
    seq_code = np.zeros((4, seq_len), dtype='float16')
    for col in range(seq_len):
        pos = col - seq_start
        # Explicit lookup instead of a bare ``except``: padding positions on
        # either side and characters other than A/C/G/T both yield None.
        row = nt_row.get(seq[pos]) if 0 <= pos < len(seq) else None
        if row is None:
            seq_code[:, col] = 0.25
        else:
            seq_code[row, col] = 1

    # flatten into a row vector of shape (1, 4 * seq_len)
    if flatten:
        seq_code = seq_code.flatten()[None, :]

    return seq_code

In [3]:
#READ THE FILES
POS_FASTA = "./examples/example_pos_sequences.fa"
NEG_FASTA = "./examples/example_neg_sequences.fa"


def read_fasta(path):
    """Read a FASTA file into a pandas Series of record id -> upper-case sequence.

    Later records with a duplicate id overwrite earlier ones.
    """
    seqs = {}
    # Passing the path (not ``open(path)``) lets Biopython open and close the
    # file itself, so no file handle is leaked.
    for record in SeqIO.parse(path, 'fasta'):
        seqs[record.id] = str(record.seq).upper()
    return pd.Series(seqs)


pos_seqs = read_fasta(POS_FASTA)
neg_seqs = read_fasta(NEG_FASTA)

In [4]:
#ONE HOT ENCODE
# Each sequence becomes a (4, len) matrix; stacking gives (n_seqs, 4, len),
# so every sequence within a class must have the same length.
# NOTE(review): pos_seqs/neg_seqs are overwritten (Series -> ndarray), so this
# cell only works right after the FASTA-loading cell has been (re)run.
pos_seqs = np.stack([dna_one_hot(s, flatten=False) for s in pos_seqs], axis=0)
neg_seqs = np.stack([dna_one_hot(s, flatten=False) for s in neg_seqs], axis=0)

# Positives first, then negatives; the labels built next follow this order.
all_seq = np.concatenate((pos_seqs, neg_seqs))

In [5]:
# Labels line up with all_seq: positives (1) first, then negatives (0).
pos_labels = np.ones((len(pos_seqs), 1))
# Size the negatives by neg_seqs (previously len(pos_seqs) was used, which
# mislabels the data whenever the two classes differ in size).
neg_labels = np.zeros((len(neg_seqs), 1))

all_labels = np.concatenate((pos_labels, neg_labels))
# Every sequence must have exactly one label.
assert len(all_labels) == len(all_seq), (len(all_labels), len(all_seq))

In [6]:
#create a dataloader
# torch.Tensor copies each numpy array into a tensor of torch's default float
# dtype (the one-hot arrays above are float16, the labels float64).
x, x_lab = torch.Tensor(all_seq), torch.Tensor(all_labels)

all_dataset = torch.utils.data.TensorDataset(x, x_lab)
dataloader = torch.utils.data.DataLoader(
    all_dataset,
    batch_size=100,
    shuffle=True,
    num_workers=0,  # load batches in the main process
)
x_lab

tensor([[1.],
        [1.],
        [1.],
        ...,
        [0.],
        [0.],
        [0.]])

In [7]:
#the model
class DanQ(nn.Module):
    """DanQ-style hybrid CNN + BiLSTM classifier for one-hot encoded DNA.

    Layers, in order: Conv1d(4 -> 320, kernel 26) -> ReLU ->
    MaxPool1d(13, stride 13) -> Dropout(0.2) -> 2-layer bidirectional LSTM
    (hidden size 320 per direction, inter-layer dropout 0.5) -> flatten ->
    Linear(-> 925) -> ReLU -> Linear(-> num_classes).

    Args:
        sequence_length: length of the input sequences; it fixes the input
            size of the first fully connected layer, so inputs of any other
            length are rejected in ``forward``.
        num_classes: number of outputs. ``forward`` returns raw scores
            (no sigmoid/softmax is applied).
        weight_path: optional path of a saved state dict loaded right after
            construction via ``load_weights``.
    """

    def __init__(self, sequence_length, num_classes, weight_path=None):
        super().__init__()

        # Time steps left after the conv (L - 25 outputs) and the stride-13
        # pool: floor((L - 25) / 13). Despite its name this is a number of
        # positions, not channels; each position carries 2 * 320 = 640 features.
        self._n_channels = math.floor(
            (sequence_length - 25) / 13)

        self.Conv1 = nn.Conv1d(in_channels=4, out_channels=320, kernel_size=26)
        self.Maxpool = nn.MaxPool1d(kernel_size=13, stride=13)
        self.Drop1 = nn.Dropout(p=0.2)
        self.BiLSTM = nn.LSTM(input_size=320, hidden_size=320, num_layers=2,
                              batch_first=True,
                              dropout=0.5,
                              bidirectional=True)
        self.Linear1 = nn.Linear(self._n_channels * 640, 925)
        self.Linear2 = nn.Linear(925, num_classes)

        if weight_path:
            self.load_weights(weight_path)

    def forward(self, input):
        """Map a (batch, 4, sequence_length) tensor to (batch, num_classes) scores."""
        x = F.relu(self.Conv1(input))
        x = self.Drop1(self.Maxpool(x))
        # The LSTM is batch_first, so move channels last: (batch, time, 320).
        # .contiguous() hands cuDNN a dense tensor rather than a transposed view.
        x = torch.transpose(x, 1, 2).contiguous()
        x, _ = self.BiLSTM(x)
        # Flatten per sample. The former view(-1, n * 640) could silently
        # regroup rows across samples when the input length did not match
        # sequence_length; reshaping with an explicit batch size raises instead.
        x = x.reshape(x.size(0), self._n_channels * 640)
        x = F.relu(self.Linear1(x))
        return self.Linear2(x)

    def load_weights(self, weight_path):
        """Load a saved state dict into this model, matching tensors by position.

        The file's values are assigned to this model's state-dict keys in
        order, so the file's key names are ignored but its tensor ORDER must
        match ``self.state_dict()``. Tensors with more than one dimension whose
        last dimension is 1 have that dimension squeezed (presumably to accept
        weights exported with a trailing singleton axis — TODO confirm).

        Raises:
            ValueError: if the file holds a different number of tensors than
                this model's state dict.
        """
        # OrderedDict was used here without ever being imported (NameError).
        from collections import OrderedDict

        # Load onto the CPU so GPU-saved checkpoints also open on CPU-only
        # hosts; load_state_dict copies the values onto the parameters' device.
        sd = torch.load(weight_path, map_location='cpu')
        keys = list(self.state_dict().keys())
        values = list(sd.values())
        if len(values) != len(keys):
            raise ValueError(
                "checkpoint has {} tensors but the model expects {}".format(
                    len(values), len(keys)))

        new_dict = OrderedDict()
        for key, v in zip(keys, values):
            if v.dim() > 1 and v.shape[-1] == 1:
                v = v.squeeze(-1)
            new_dict[key] = v
        self.load_state_dict(new_dict)

In [8]:
#create a model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so weight initialisation (and, presumably, the
# DataLoader's shuffle order, drawn later during training) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# The model's first Linear layer is sized from the input length, so take it
# from the one-hot data (shape (n_seqs, 4, len)) instead of hard-coding 1000.
SEQ_LEN = all_seq.shape[-1]
model = DanQ(SEQ_LEN, 1).to(device)

# One raw output score per sequence; BCEWithLogitsLoss applies the sigmoid.
criterion = nn.BCEWithLogitsLoss()

# NOTE(review): lr=0.01 is high for Adam on an LSTM (the library default is
# 1e-3) — consider lowering it if the training loss diverges.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [9]:
# Checkpoint directory that train_model writes one .pth file per epoch into.
# exist_ok=True makes this idempotent and avoids the exists()/makedirs() race.
os.makedirs('weights_DanQ', exist_ok=True)

In [11]:
num_epochs = 30
# Train on the full dataset (no validation split); a checkpoint is saved to
# weights_DanQ/ after every epoch, tagged "CTCF".
# NOTE(review): the saved output of this cell is a cuDNN
# CUDNN_STATUS_EXECUTION_FAILED error whose cause is not visible here (GPU
# memory pressure or a driver/CUDA/cuDNN mismatch are common suspects) —
# TODO rerun with CUDA_LAUNCH_BLOCKING=1, or on CPU, to locate it.
model, train_error = train_model(
    train_loader=dataloader,
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    weights_folder="weights_DanQ",
    name_ind="CTCF",
    verbose=True,
)

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED