In [1]:
import copy
import gzip
import math
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from time import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from torch import relu, sigmoid
import torch.nn.modules.activation as activation
from Bio import SeqIO

from utils.pytorchtools import EarlyStopping

In [2]:
class DanQ(nn.Module):
    """DanQ architecture (Quang & Xie, 2016)."""

    def __init__(self, sequence_length, n_targets):
        super(DanQ, self).__init__()

        self._n_channels = math.floor((sequence_length - 25) / 13)

        self.Conv1 = nn.Conv1d(
            in_channels=4, out_channels=320, kernel_size=26
        )
        self.Maxpool = nn.MaxPool1d(kernel_size=13, stride=13)
        self.Drop1 = nn.Dropout(p=0.2)
        self.BiLSTM = nn.LSTM(
            input_size=320, hidden_size=320, num_layers=2,
            batch_first=True, dropout=0.5, bidirectional=True
        )
        self.Linear1 = nn.Linear(self._n_channels*640, 925)
        self.Linear2 = nn.Linear(925, n_targets)

    def forward(self, input):
        x = self.Conv1(input)
        x = F.relu(x)
        x = self.Maxpool(x)
        x = self.Drop1(x)
        x_x = torch.transpose(x, 1, 2)
        x, (h_n,h_c) = self.BiLSTM(x_x)
        x = x.contiguous().view(-1, self._n_channels*640)
        x = self.Linear1(x)
        x = F.relu(x)
        x = self.Linear2(x)

        return(x)

def get_criterion():
    """
    Specify the appropriate loss function (criterion) for this model.

    Returns
    -------
    torch.nn._Loss
    """
    return(nn.BCEWithLogitsLoss())

def get_optimizer(params, lr=0.003):
    return(torch.optim.Adam(params, lr=lr))

In [3]:
def one_hot_encode(seq):
    """One hot encodes a sequence."""

    seq = seq.replace("A", "0")
    seq = seq.replace("C", "1")
    seq = seq.replace("G", "2")
    seq = seq.replace("T", "3")

    encoded_seq = np.zeros((4, len(seq)), dtype="float16")

    for i in range(len(seq)):
        if seq[i].isdigit():
            encoded_seq[int(seq[i]), i] = 1
        else:
            # i.e. Ns
            encoded_seq[:, i] = 0.25

    return(encoded_seq)

def one_hot_decode(encoded_seq):
    """Reverts a sequence's one hot encoding."""

    seq = []
    code = list("ACGT")
 
    for i in encoded_seq.transpose(1, 0):
        try:
            seq.append(code[int(np.where(i == 1)[0])])
        except:
            # i.e. N?
            seq.append("N")

    return("".join(seq))

def reverse_complement(encoded_seqs):
    """Reverse complements one hot encoding for a list of sequences."""
    return(encoded_seqs[..., ::-1, ::-1])

In [4]:
# Parse FASTA sequences
pos_seqs = {}
neg_seqs = {}
with gzip.open("../Data/pos_seqs.fa.gz", "rt") as handle:
    for seq_record in SeqIO.parse(handle, "fasta"):
        pos_seqs[seq_record.id] = str(seq_record.seq).upper()
pos_seqs = pd.Series(pos_seqs)
with gzip.open("../Data/neg_seqs.fa.gz", "rt") as handle:
    for seq_record in SeqIO.parse(handle, "fasta"):
        neg_seqs[seq_record.id] = str(seq_record.seq).upper()
neg_seqs = pd.Series(neg_seqs)
pos_seqs

7          GGACAGGTCAACTTGAGGAGATTTTGGGCCTTCATAGGCCACCAGG...
16         CCACATTATACAGCTTCTGAAAGGGTTGCTTGACCCACAGATGTGA...
22         GAAGGAGACTGATGTGGTTTCTCCTCAGTTTCTCTGTGCGGCACCA...
49         ACCTCTATGGTGTCGGCGAAGACCCGCCCTTGTGACGTCACGGAAG...
107        GGGAATGCTAAACAGAGGCAGATCTAAACTTAGGAGTTAGGCTTCT...
                                 ...                        
1817711    TGCTAGGAGCCGCAGTCATACTGGCTGTGCATGAGACCATCCACCT...
1817721    AAGGCAAAGTGAGAAAAAGAGGAAACTAGAAGGCTGGTTGGGCTGT...
1817723    CCTTGTCTTGGCATTTTCGGAGAGAACATGGACTCTGTGTTGTTTG...
1817732    CTCTTACTCTTTCTGTGTGTGAAATGTGCAAGTAGCTTTACAGTCT...
1817832    TCTTCTTTCCCTTTCCCTCCTCCCTAGGGGGTGTGACTGTAGAGCA...
Length: 78983, dtype: object

In [5]:
# One-hot encode sequences
pos_seqs_1_hot = pos_seqs.map(lambda x: one_hot_encode(x))
pos_seqs_1_hot = np.stack(pos_seqs_1_hot, axis=0)
neg_seqs_1_hot = neg_seqs.map(lambda x: one_hot_encode(x))
neg_seqs_1_hot = np.stack(neg_seqs_1_hot, axis=0)
pos_seqs_1_hot

array([[[0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [1., 1., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 0., 1., ..., 0., 0., 0.],
        [1., 1., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 1., 1., ..., 1., 1., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 1., 1., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 1., 0., ..., 1., 0., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 1., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 1., ..., 0., 0., 1.]]], dtype=float16)

In [6]:
# Split sequences into train, validation and test
seed = 123
pos_train_seqs, pos_test_seqs = train_test_split(
    pos_seqs_1_hot, test_size=0.2, random_state=seed
)
pos_validation_seqs, pos_test_seqs = train_test_split(
    pos_test_seqs, test_size=0.5, random_state=seed
)
neg_train_seqs, neg_test_seqs = train_test_split(
    neg_seqs_1_hot, test_size=0.2, random_state=seed
)
neg_validation_seqs, neg_test_seqs = train_test_split(
    neg_test_seqs, test_size=0.5, random_state=seed
)
pos_train_seqs

array([[[0., 1., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 1., ..., 1., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 1., 1.],
        [1., 0., 1., ..., 1., 0., 0.]],

       ...,

       [[0., 0., 1., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 1., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 1., 1., ..., 0., 0., 1.],
        [1., 0., 0., ..., 0., 1., 0.]],

       [[1., 0., 1., ..., 0., 1., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 1., 0., ..., 0., 0., 0.]]], dtype=float16)

In [7]:
# Reverse complement train
pos_train_seqs_rc = np.append(pos_train_seqs, reverse_complement(pos_train_seqs), axis=0)
neg_train_seqs_rc = np.append(neg_train_seqs, reverse_complement(pos_train_seqs), axis=0)
pos_train_seqs_rc

array([[[0., 1., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 1., ..., 1., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 1., 1.],
        [1., 0., 1., ..., 1., 0., 0.]],

       ...,

       [[1., 1., 0., ..., 0., 1., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 1., 0., 0.]],

       [[0., 1., 0., ..., 0., 0., 1.],
        [1., 0., 0., ..., 1., 1., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 1., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 1., 0., ..., 1., 0., 1.]]], dtype=float16)

In [8]:
# Create a TensorDatasets
X = np.concatenate((pos_train_seqs_rc, neg_train_seqs_rc))
y = np.concatenate(
    (np.ones((len(pos_train_seqs_rc), 1)), np.zeros((len(neg_train_seqs_rc), 1)))
)
train_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
X = np.concatenate((pos_validation_seqs, neg_validation_seqs))
y = np.concatenate(
    (np.ones((len(pos_validation_seqs), 1)), np.zeros((len(neg_validation_seqs), 1)))
)
validation_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
X = np.concatenate((pos_test_seqs, neg_test_seqs))
y = np.concatenate(
    (np.ones((len(pos_test_seqs), 1)), np.zeros((len(neg_test_seqs), 1)))
)
test_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
train_dataset

<torch.utils.data.dataset.TensorDataset at 0x7fac2530ca00>

In [9]:
# Create DataLoaders
parameters = dict(batch_size=64, shuffle=True, num_workers=8)
train_dataloader = DataLoader(train_dataset, **parameters)
validation_dataloader = DataLoader(validation_dataset, **parameters)
test_dataloader = DataLoader(test_dataset, **parameters)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fac22f9e100>

In [10]:
# Train and validate
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
max_epochs = 100
learning_rates = [0.003, 0.001, 0.0003, 0.0001]
output_dir = "./CTCF/"
os.makedirs(output_dir, exist_ok=True)

for lr in learning_rates:

    # Initialize model, criterion, optimizer
    model = DanQ(len(pos_seqs[0]), 1).to(device)
    criterion = get_criterion()
    optimizer = get_optimizer(model.parameters(), lr)
    state_dict = os.path.join(output_dir, "model-%s.pth.tar" % lr)
    early_stopping = EarlyStopping(10, True, path=state_dict)
    train_losses = []
    validation_losses = []

    for epoch in range(1, max_epochs + 1):

        # Train
        t_time = time()
        model.train() # set the model in train mode
        train_losses.append([])
        for seqs, labels in train_dataloader:
            x = seqs.to(device) # shape = (batch_size, 4, 200)
            labels = labels.to(device)
            # Zero existing gradients so they don't add up
            optimizer.zero_grad()
            # Forward pass
            outputs = model(x)
            loss = criterion(outputs, labels) 
            # Backward and optimize
            loss.backward()
            optimizer.step()
            # Keep the loss
            train_losses[-1].append(loss.item())
        t_loss = np.average(train_losses[-1])
        t_time = time() - t_time

        # Validate
        v_time = time()
        model.eval() # set the model in evaluation mode
        validation_losses.append([])
        for seqs, labels in validation_dataloader:
            x = seqs.to(device) # shape = (batch_size, 4, 200)
            labels = labels.to(device)
            with torch.no_grad():
                # Forward pass
                outputs = model(x)
                loss = criterion(outputs, labels) 
                # Keep the loss
                validation_losses[-1].append(loss.item())
        v_loss = np.average(validation_losses[-1])
        v_time = time() - v_time

        print(f'[{epoch:>{3}}/{max_epochs:>{3}}] '
             +f'lr: {lr} '
             +f'train_loss: {t_loss:.5f} ({t_time:.3f} sec) '
             +f'valid_loss: {v_loss:.5f} ({v_time:.3f} sec)')

        # EarlyStopping needs to check if the validation loss has decresed, 
        # and if it has, it will save the current model.
        early_stopping(v_loss, model)
        if early_stopping.early_stop:
            # Empty cache
            with torch.cuda.device(device):
                torch.cuda.empty_cache()
            print("Stop!!!")
            break

[  1/100] lr: 0.003 train_loss: 0.69588 (81.969 sec) valid_loss: 0.69316 (2.451 sec)
Validation loss decreased (inf --> 0.693156), saving model ...
[  2/100] lr: 0.003 train_loss: 0.69318 (81.489 sec) valid_loss: 0.69319 (2.295 sec)
EarlyStopping counter: 1 out of 10
[  3/100] lr: 0.003 train_loss: 0.69318 (82.881 sec) valid_loss: 0.69317 (2.377 sec)
EarlyStopping counter: 2 out of 10
[  4/100] lr: 0.003 train_loss: 0.69318 (80.840 sec) valid_loss: 0.69333 (2.155 sec)
EarlyStopping counter: 3 out of 10
[  5/100] lr: 0.003 train_loss: 0.69321 (81.982 sec) valid_loss: 0.69315 (2.277 sec)
Validation loss decreased (0.693156 --> 0.693147), saving model ...
[  6/100] lr: 0.003 train_loss: 0.69320 (71.562 sec) valid_loss: 0.69316 (2.252 sec)
EarlyStopping counter: 1 out of 10
[  7/100] lr: 0.003 train_loss: 0.69319 (83.176 sec) valid_loss: 0.69315 (2.410 sec)
EarlyStopping counter: 2 out of 10
[  8/100] lr: 0.003 train_loss: 0.69319 (82.780 sec) valid_loss: 0.69336 (2.338 sec)
EarlyStopping 

[ 18/100] lr: 0.0001 train_loss: 0.54522 (83.577 sec) valid_loss: 0.37832 (2.308 sec)
EarlyStopping counter: 10 out of 10
Stop!!!


In [19]:
# Compute performance metrics
def compute_performance_metrics(predictions, labels, lr):

    # Metrics
    metrics = dict(AUCPR=None, AUCROC=None, MCC=None)

    # Metrics to DataFrame
    for metric in metrics:
        if metric == "AUCPR":
            score = average_precision_score(labels, predictions)
            self.metrics.setdefault(metric, score)
            prec, recall, _ = precision_recall_curve(labels, predictions)
            # i.e. precision = 0, recall = 1
            prec = np.insert(prec, 0, 0., axis=0)
            recall = np.insert(recall, 0, 1., axis=0)
            data = list(zip(recall, prec))
            self.__visualize_metric(data, ["Recall", "Precision"], metric)
        elif metric == "AUCROC":
            score = roc_auc_score(labels, predictions)
            self.metrics.setdefault(metric, score)
            fpr, tpr, _ = roc_curve(labels, predictions)
            data = list(zip(fpr, tpr))
            self.__visualize_metric(data, ["Fpr", "Tpr"], metric)
        elif metric == "MCC":
            score = matthews_corrcoef(labels, np.rint(predictions))
            self.metrics.setdefault(metric, score)
    
    if self._verbose:
        write(
            None,
            (f'Final performance metrics: '
            +f'AUCROC: {self.metrics["AUCROC"]:.5f}, '
            +f'AUCPR: {self.metrics["AUCPR"]:.5f}, '
            +f'MCC: {self.metrics["MCC"]:.5f}')
        )

def __visualize_metric(self, data, labels, metric):

    # Metric to DataFrame
    df = pd.DataFrame(data, columns=labels)

    # Seaborn aesthetics
    sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1.5})
    sns.set_palette(sns.color_palette(["#1965B0"]))

    # Plot metric
    kwargs = dict(estimator=None, ci=None)
    g = sns.lineplot(x=labels[0], y=labels[1], data=df, **kwargs)

    # Add metric score
    kwargs = dict(horizontalalignment="center", verticalalignment="center")
    plt.text(.5, 0, "%s = %.5f" % (metric, self.metrics[metric]), **kwargs)

    # Remove spines
    sns.despine()

    # Save & close
    fig = g.get_figure()
    fig.savefig(os.path.join(self.output_dir, "%s.png" % metric))
    plt.close(fig)

for lr in learning_rates:
   
    # Load the best model
    labels = None
    predictions = None
    model = DanQ(len(pos_seqs[0]), 1).to(device)
    state_dict = os.path.join(output_dir, "model-%s.pth.tar" % lr)
    model.load_state_dict(torch.load(state_dict))
    model.eval() # set the model in evaluation mode

    for inputs, targets in test_dataloader:

        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.no_grad():

            # Forward pass
            outputs = torch.sigmoid(model(inputs))

            if predictions is None and labels is None:
                predictions = outputs.data.cpu().numpy()
                labels = targets.data.cpu().numpy()
            else:
                predictions = np.append(
                    predictions, outputs.data.cpu().numpy(), axis=0
                )
                labels = np.append(
                    labels, targets.data.cpu().numpy(), axis=0
                )

    compute_performance_metrics(predictions.flatten(), labels.flatten(), lr)

tensor([[0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0.4997],
        [0

In [16]:
model = None

In [46]:
fwd = pos_train[:1]
fwd

array([[[0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
         1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1.,
         

In [48]:
one_hot_decode(fwd[0])

'GCATCCACACACCCTCAGATGCTTCCTTTGACGCCCTCTGCTGTGCCCCTAGACACCCCTATCCCGCCACTGGCTGAAGCTGGACTTTGGAGCCATCTGCCTCCCTTGCCTGCGTCCACACCCCGCGCCAGTCCTCAGCCTCCAAGCCCATCTCAGTCGGACCCTTTCTCATTCCTGCCACTCGCTGCCTGTTCCAGGCC'

In [47]:
rev = reverse_complement_one_hot_encoding(fwd)
rev

array([[[0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1.,
         0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
         0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0.,
         1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
         

In [49]:
one_hot_decode(rev[0])

'GGCCTGGAACAGGCAGCGAGTGGCAGGAATGAGAAAGGGTCCGACTGAGATGGGCTTGGAGGCTGAGGACTGGCGCGGGGTGTGGACGCAGGCAAGGGAGGCAGATGGCTCCAAAGTCCAGCTTCAGCCAGTGGCGGGATAGGGGTGTCTAGGGGCACAGCAGAGGGCGTCAAAGGAAGCATCTGAGGGTGTGTGGATGC'

In [20]:
#function to train a model - no validation
def train_model(train_loader, model, device, criterion, optimizer, num_epochs,
               weights_folder, name_ind, verbose):
    total_step = len(train_loader)
    train_error = []
    
    for epoch in range(num_epochs):
        model.train() #tell model explicitly that we train
        logs = {}
        running_loss = 0.0
        for seqs, labels in train_loader:
            x = seqs.to(device) #the input here is (batch_size, 4, 200)
            labels = labels.to(device)
            #zero the existing gradients so they don't add up
            optimizer.zero_grad()
            # Forward pass
            outputs = model(x)
            loss = criterion(outputs, labels) 
            # Backward and optimize
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        #save training loss 
        epoch_loss = running_loss / len(train_loader)
        train_error.append(epoch_loss)
        
        if verbose:
            print ('Epoch [{}], Current Train Loss: {:.5f}' 
                       .format(epoch+1, epoch_loss))
            
        #save model weights
        model_wts = copy.deepcopy(model.state_dict())
        torch.save(model_wts, weights_folder + "/"+"model_epoch_"+str(epoch+1)+"_"+
                       name_ind+".pth") 
    return model, train_error

#function to one-hot encode
############################################################
def dna_one_hot(seq, seq_len=None, flatten=True):
    if seq_len == None:
        seq_len = len(seq)
        seq_start = 0
    else:
        if seq_len <= len(seq):
            # trim the sequence
            seq_trim = (len(seq)-seq_len) // 2
            seq = seq[seq_trim:seq_trim+seq_len]
            seq_start = 0
        else:
            seq_start = (seq_len-len(seq)) // 2

    seq = seq.upper()

    seq = seq.replace('A','0')
    seq = seq.replace('C','1')
    seq = seq.replace('G','2')
    seq = seq.replace('T','3')

    # map nt's to a matrix 4 x len(seq) of 0's and 1's.
    #  dtype='int8' fails for N's
    seq_code = np.zeros((4,seq_len), dtype='float16')
    for i in range(seq_len):
        if i < seq_start:
            seq_code[:,i] = 0.25
        else:
            try:
                seq_code[int(seq[i-seq_start]),i] = 1
            except:
                seq_code[:,i] = 0.25

    # flatten and make a column vector 1 x len(seq)
    if flatten:
        seq_code = seq_code.flatten()[None,:]

    return seq_code