In [1]:
from Bio import SeqIO
import gzip
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    average_precision_score, precision_recall_curve,
    roc_auc_score, roc_curve,
    matthews_corrcoef
)
from sklearn.model_selection import train_test_split
from time import time
import torch
import torch.nn as nn
# from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset

from utils.pytorchtools import EarlyStopping

In [2]:
# Adapted from:
# https://github.com/FunctionLab/selene/blob/master/models/danQ.py
class DanQ(nn.Module):
    """DanQ architecture (Quang & Xie, 2016)."""

    def __init__(self, sequence_length, n_features):
        """
        Parameters
        ----------
        sequence_length : int
            Input sequence length
        n_features : int
            Total number of features to predict
        """
        super(DanQ, self).__init__()

        self.nnet = nn.Sequential(
            nn.Conv1d(4, 320, kernel_size=26),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=13, stride=13),
            nn.Dropout(0.2)
        )

        self.bdlstm = nn.Sequential(
            nn.LSTM(320, 320, num_layers=1, batch_first=True, bidirectional=True)
        )

        self._n_channels = math.floor((sequence_length - 25) / 13)

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self._n_channels * 640, 925),
            nn.ReLU(inplace=True),
            nn.Linear(925, n_features),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Forward propagation of a batch."""
        out = self.nnet(x)
        reshape_out = out.transpose(0, 1).transpose(0, 2)
        out, _ = self.bdlstm(reshape_out)
        out = out.transpose(0, 1)
        reshape_out = out.contiguous().view(
            out.size(0), 640 * self._n_channels)
        predict = self.classifier(reshape_out)

        return(predict)

def get_criterion():
    """
    Specify the appropriate loss function (criterion) for this model.

    Returns
    -------
    torch.nn._Loss
    """
    return(nn.BCELoss())

def get_optimizer(params, lr=0.001):
    return(torch.optim.Adam(params, lr=lr))

In [3]:
def one_hot_encode(seq):
    """One hot encodes a sequence."""

    seq = seq.replace("A", "0")
    seq = seq.replace("C", "1")
    seq = seq.replace("G", "2")
    seq = seq.replace("T", "3")

    encoded_seq = np.zeros((4, len(seq)), dtype="float16")

    for i in range(len(seq)):
        if seq[i].isdigit():
            encoded_seq[int(seq[i]), i] = 1
        else:
            # i.e. Ns
            encoded_seq[:, i] = 0.25

    return(encoded_seq)

def one_hot_decode(encoded_seq):
    """Reverts a sequence's one hot encoding."""

    seq = []
    code = list("ACGT")
 
    for i in encoded_seq.transpose(1, 0):
        try:
            seq.append(code[int(np.where(i == 1)[0])])
        except:
            # i.e. N?
            seq.append("N")

    return("".join(seq))

def reverse_complement(encoded_seqs):
    """Reverse complements one hot encoding for a list of sequences."""
    return(encoded_seqs[..., ::-1, ::-1])

In [4]:
# Parse FASTA sequences
pos_seqs = {}
neg_seqs = {}
with gzip.open("../Data/pos_seqs.fa.gz", "rt") as handle:
    for seq_record in SeqIO.parse(handle, "fasta"):
        pos_seqs[seq_record.id] = str(seq_record.seq).upper()
pos_seqs = pd.Series(pos_seqs)
with gzip.open("../Data/neg_seqs.fa.gz", "rt") as handle:
    for seq_record in SeqIO.parse(handle, "fasta"):
        neg_seqs[seq_record.id] = str(seq_record.seq).upper()
neg_seqs = pd.Series(neg_seqs)
pos_seqs

7          GGACAGGTCAACTTGAGGAGATTTTGGGCCTTCATAGGCCACCAGG...
16         CCACATTATACAGCTTCTGAAAGGGTTGCTTGACCCACAGATGTGA...
22         GAAGGAGACTGATGTGGTTTCTCCTCAGTTTCTCTGTGCGGCACCA...
49         ACCTCTATGGTGTCGGCGAAGACCCGCCCTTGTGACGTCACGGAAG...
107        GGGAATGCTAAACAGAGGCAGATCTAAACTTAGGAGTTAGGCTTCT...
                                 ...                        
1817711    TGCTAGGAGCCGCAGTCATACTGGCTGTGCATGAGACCATCCACCT...
1817721    AAGGCAAAGTGAGAAAAAGAGGAAACTAGAAGGCTGGTTGGGCTGT...
1817723    CCTTGTCTTGGCATTTTCGGAGAGAACATGGACTCTGTGTTGTTTG...
1817732    CTCTTACTCTTTCTGTGTGTGAAATGTGCAAGTAGCTTTACAGTCT...
1817832    TCTTCTTTCCCTTTCCCTCCTCCCTAGGGGGTGTGACTGTAGAGCA...
Length: 78983, dtype: object

In [5]:
# One-hot encode sequences
pos_seqs_1_hot = pos_seqs.map(lambda x: one_hot_encode(x))
pos_seqs_1_hot = np.stack(pos_seqs_1_hot, axis=0)
neg_seqs_1_hot = neg_seqs.map(lambda x: one_hot_encode(x))
neg_seqs_1_hot = np.stack(neg_seqs_1_hot, axis=0)
pos_seqs_1_hot

array([[[0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [1., 1., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 0., 1., ..., 0., 0., 0.],
        [1., 1., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 1., 1., ..., 1., 1., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 1., 1., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 1., 0., ..., 1., 0., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 1., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 1., ..., 0., 0., 1.]]], dtype=float16)

In [6]:
# Split sequences into train, validation and test
seed = 123
pos_train_seqs, pos_test_seqs = train_test_split(
    pos_seqs_1_hot, test_size=0.2, random_state=seed
)
pos_validation_seqs, pos_test_seqs = train_test_split(
    pos_test_seqs, test_size=0.5, random_state=seed
)
neg_train_seqs, neg_test_seqs = train_test_split(
    neg_seqs_1_hot, test_size=0.2, random_state=seed
)
neg_validation_seqs, neg_test_seqs = train_test_split(
    neg_test_seqs, test_size=0.5, random_state=seed
)
pos_train_seqs

array([[[0., 1., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 1., ..., 1., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 1., 1.],
        [1., 0., 1., ..., 1., 0., 0.]],

       ...,

       [[0., 0., 1., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 1., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 1., 1., ..., 0., 0., 1.],
        [1., 0., 0., ..., 0., 1., 0.]],

       [[1., 0., 1., ..., 0., 1., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 1., 0., ..., 0., 0., 0.]]], dtype=float16)

In [7]:
# Create TensorDatasets
X = np.concatenate((pos_train_seqs, neg_train_seqs))
y = np.concatenate(
    (np.ones((len(pos_train_seqs), 1)), np.zeros((len(neg_train_seqs), 1)))
)
train_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
X = np.concatenate((pos_validation_seqs, neg_validation_seqs))
y = np.concatenate(
    (np.ones((len(pos_validation_seqs), 1)), np.zeros((len(neg_validation_seqs), 1)))
)
validation_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
X = np.concatenate((pos_test_seqs, neg_test_seqs))
y = np.concatenate(
    (np.ones((len(pos_test_seqs), 1)), np.zeros((len(neg_test_seqs), 1)))
)
test_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
train_dataset

<torch.utils.data.dataset.TensorDataset at 0x7f2649adf190>

In [8]:
# Create DataLoaders
parameters = dict(batch_size=64, shuffle=True, num_workers=8)
train_dataloader = DataLoader(train_dataset, **parameters)
validation_dataloader = DataLoader(validation_dataset, **parameters)
test_dataloader = DataLoader(test_dataset, **parameters)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f2675799910>

In [10]:
# Train and validate
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DanQ(len(pos_seqs[0]), 1).to(device)
output_dir = "./CTCF/"
os.makedirs(output_dir, exist_ok=True)
losses_file = os.path.join(output_dir, "losses.csv")
state_dict = os.path.join(output_dir, "model.pth.tar")

if not os.path.exists(losses_file):

    max_epochs = 100
    criterion = get_criterion()
    optimizer = get_optimizer(model.parameters(), lr=0.0003)
    # scheduler = ReduceLROnPlateau(optimizer, "min", patience=5, verbose=True, factor=0.5)
    early_stopping = EarlyStopping(10, True, path=state_dict)
    train_losses = []
    validation_losses = []

    for epoch in range(1, max_epochs + 1):

        # Train
        t_time = time()
        model.train() # set the model in train mode
        train_losses.append([])
        for seqs, labels in train_dataloader:
            x = seqs.to(device) # shape = (batch_size, 4, 200)
            labels = labels.to(device)
            # Zero existing gradients so they don't add up
            optimizer.zero_grad()
            # Forward pass
            outputs = model(x)
            loss = criterion(outputs, labels) 
            # Backward and optimize
            loss.backward()
            optimizer.step()
            # Keep the loss
            train_losses[-1].append(loss.item())
        t_loss = np.average(train_losses[-1])
        t_time = time() - t_time

        # Validate
        v_time = time()
        model.eval() # set the model in evaluation mode
        validation_losses.append([])
        for seqs, labels in validation_dataloader:
            x = seqs.to(device) # shape = (batch_size, 4, 200)
            labels = labels.to(device)
            with torch.no_grad():
                # Forward pass
                outputs = model(x)
                loss = criterion(outputs, labels) 
                # Keep the loss
                validation_losses[-1].append(loss.item())
        v_loss = np.average(validation_losses[-1])
        v_time = time() - v_time

        print(f'[{epoch:>{3}}/{max_epochs:>{3}}] '
             +f'train_loss: {t_loss:.5f} ({t_time:.3f} sec) '
             +f'valid_loss: {v_loss:.5f} ({v_time:.3f} sec)')

        # Adjust learning rate
        # scheduler.step(math.ceil(v_loss * 1000.0) / 1000.0)

        # EarlyStopping needs to check if the validation loss has decresed, 
        # and if it has, it will save the current model
        early_stopping(v_loss, model)
        if early_stopping.early_stop:
            print("Stop!!!")
            break

    # Losses to DataFrame
    data = []
    for i in range(len(train_losses)):
        for j in range(len(train_losses[i])):
            data.append(["train", i+1, j+1, train_losses[i][j]])
    for i in range(len(validation_losses)):
        for j in range(len(validation_losses[i])):
            data.append(
                ["validation", i+1, j+1, validation_losses[i][j]]
            )
    df = pd.DataFrame(data, columns=["Mode", "Epoch", "Batch", "Loss"])

    # Save losses
    df.to_csv(losses_file)

[  1/100] train_loss: 0.46951 (90.089 sec) valid_loss: 0.31048 (3.426 sec)
Validation loss decreased (inf --> 0.310475), saving model ...
[  2/100] train_loss: 0.29060 (90.471 sec) valid_loss: 0.27290 (3.425 sec)
Validation loss decreased (0.310475 --> 0.272897), saving model ...
[  3/100] train_loss: 0.26428 (90.145 sec) valid_loss: 0.26039 (3.427 sec)
Validation loss decreased (0.272897 --> 0.260393), saving model ...
[  4/100] train_loss: 0.24563 (90.187 sec) valid_loss: 0.25642 (3.418 sec)
Validation loss decreased (0.260393 --> 0.256423), saving model ...
[  5/100] train_loss: 0.23120 (90.148 sec) valid_loss: 0.25111 (3.423 sec)
Validation loss decreased (0.256423 --> 0.251108), saving model ...
[  6/100] train_loss: 0.21669 (90.076 sec) valid_loss: 0.25783 (3.423 sec)
EarlyStopping counter: 1 out of 10
[  7/100] train_loss: 0.20354 (89.932 sec) valid_loss: 0.25682 (3.429 sec)
EarlyStopping counter: 2 out of 10
[  8/100] train_loss: 0.19187 (89.776 sec) valid_loss: 0.25954 (3.429 

In [11]:
# Losses
df = pd.read_csv(losses_file, index_col=0)
# Seaborn aesthetics
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1.5})
sns.set_palette(sns.color_palette(["#1965B0", "#DC050C"]))
# Plot losses
#kwargs = dict(estimator=None, ci=None)
g = sns.lineplot(x="Epoch", y="Loss", hue="Mode", data=df)
# Plot best epoch (i.e. lowest validation loss)
best_epoch = df[(df.Mode == "validation")][["Epoch", "Loss"]]\
    .groupby("Epoch").mean().idxmin()
g.axvline(
    int(best_epoch), linestyle=":", color="dimgray", label="best epoch"
)
# Plot legend
g.legend_.remove()
handles, labels = g.axes.get_legend_handles_labels()
plt.legend(handles, labels, frameon=False)
# Modify axes
g.set(xlim=(0, int(df["Epoch"].max()) + 1))
g.set(ylim=(0, 0.5))
# Remove spines
sns.despine()
# Save & close
fig = g.get_figure()
fig.tight_layout()
fig.savefig(os.path.join(output_dir, "losses.png"))
plt.close(fig)

In [12]:
# Test
y = None
predictions = None
model.load_state_dict(torch.load(state_dict))

model.eval() # set the model in evaluation mode
for seqs, labels in test_dataloader:
    x = seqs.to(device) # shape = (batch_size, 4, 200)
    labels = labels.to(device)
    with torch.no_grad():
        # Forward pass
        outputs = model(x)
        # Save predictions
        if predictions is None and y is None:
            predictions = outputs.data.cpu().numpy()
            y = labels.data.cpu().numpy()
        else:
            predictions = np.append(
                predictions, outputs.data.cpu().numpy(), axis=0
            )
            y = np.append(y, labels.data.cpu().numpy(), axis=0)
predictions

array([[0.97409576],
       [0.01435998],
       [0.4332375 ],
       ...,
       [0.9999329 ],
       [0.05632421],
       [0.99999857]], dtype=float32)

In [13]:
# Metrics
metrics = dict(AUCPR=None, AUCROC=None, MCC=None)
p = predictions.flatten()
l = y.flatten()

# Metrics to DataFrame
for metric in metrics:
    if metric == "AUCPR":
        score = average_precision_score(l, p)
    elif metric == "AUCROC":
        score = roc_auc_score(l, p)
    elif metric == "MCC":
        score = matthews_corrcoef(l, np.rint(p))
    metrics[metric] = score

print(f'Final performance metrics: '
     +f'AUCROC: {metrics["AUCROC"]:.5f}, '
     +f'AUCPR: {metrics["AUCPR"]:.5f}, '
     +f'MCC: {metrics["MCC"]:.5f}')

Final performance metrics: AUCROC: 0.96213, AUCPR: 0.96623, MCC: 0.81622


In [14]:
# Plot
def __plot(data, columns, metric, score):

    # Metric to DataFrame
    df = pd.DataFrame(data, columns=columns)
    # Seaborn aesthetics
    sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1.5})
    sns.set_palette(sns.color_palette(["#1965B0"]))
    # Plot metric
    kwargs = dict(estimator=None, ci=None)
    g = sns.lineplot(x=columns[0], y=columns[1], data=df, **kwargs)
    # Add metric score
    kwargs = dict(horizontalalignment="center", verticalalignment="center")
    plt.text(.5, 0, "%s = %.5f" % (metric, score), **kwargs)
    # Remove spines
    sns.despine()
    # Save & close
    fig = g.get_figure()
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "%s.png" % metric))
    plt.close(fig)

# AUCROC
fpr, tpr, _ = roc_curve(l, p)
data = list(zip(fpr, tpr))
__plot(data, ["Fpr", "Tpr"], "AUCROC", metrics["AUCROC"])

In [15]:
# AUCPR
prec, recall, _ = precision_recall_curve(l, p)
# i.e. precision = 0, recall = 1
prec = np.insert(prec, 0, 0., axis=0)
recall = np.insert(recall, 0, 1., axis=0)
data = list(zip(recall, prec))
__plot(data, ["Recall", "Precision"], "AUCPR", metrics["AUCPR"])

In [16]:
# Sanity check:
# Reverse complement test sequences
pos_test_seqs_rc = reverse_complement(pos_test_seqs)
neg_test_seqs_rc = reverse_complement(neg_test_seqs)
# Create TensorDatasets
X = np.concatenate((pos_test_seqs_rc, neg_test_seqs_rc))
y = np.concatenate(
    (np.ones((len(pos_test_seqs_rc), 1)), np.zeros((len(neg_test_seqs_rc), 1)))
)
test_dataset_rc = TensorDataset(torch.Tensor(X), torch.Tensor(y))
test_dataset_rc

<torch.utils.data.dataset.TensorDataset at 0x7f25e8a7d130>

In [17]:
# Create DataLoaders
parameters = dict(batch_size=64, shuffle=True, num_workers=8)
test_dataloader_rc = DataLoader(test_dataset_rc, **parameters)
test_dataloader_rc

<torch.utils.data.dataloader.DataLoader at 0x7f25e8c8da00>

In [18]:
# Test reverse complement sequences
y = None
predictions = None
state_dict = os.path.join(output_dir, "model.pth.tar")
model.load_state_dict(torch.load(state_dict))

model.eval() # set the model in evaluation mode
for seqs, labels in test_dataloader_rc:
    x = seqs.to(device) # shape = (batch_size, 4, 200)
    labels = labels.to(device)
    with torch.no_grad():
        # Forward pass
        outputs = model(x)
        # Save predictions
        if predictions is None and y is None:
            predictions = outputs.data.cpu().numpy()
            y = labels.data.cpu().numpy()
        else:
            predictions = np.append(
                predictions, outputs.data.cpu().numpy(), axis=0
            )
            y = np.append(y, labels.data.cpu().numpy(), axis=0)
predictions

array([[0.9998809 ],
       [0.09347926],
       [0.97753906],
       ...,
       [0.99864525],
       [0.9995192 ],
       [0.96260846]], dtype=float32)

In [19]:
# Metrics
metrics = dict(AUCPR=None, AUCROC=None, MCC=None)
p = predictions.flatten()
l = y.flatten()

# Metrics to DataFrame
for metric in metrics:
    if metric == "AUCPR":
        score = average_precision_score(l, p)
    elif metric == "AUCROC":
        score = roc_auc_score(l, p)
    elif metric == "MCC":
        score = matthews_corrcoef(l, np.rint(p))
    metrics[metric] = score

print(f'Final performance metrics: '
     +f'AUCROC: {metrics["AUCROC"]:.5f}, '
     +f'AUCPR: {metrics["AUCPR"]:.5f}, '
     +f'MCC: {metrics["MCC"]:.5f}')

Final performance metrics: AUCROC: 0.96119, AUCPR: 0.96613, MCC: 0.81351
