In [None]:
# Mount Google Drive into the Colab runtime under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Switch the working directory to the project folder on Drive; relative paths used later
# (e.g. PBM_DATA = "./data/dream5/pbm") are resolved against this directory.
%cd /content/drive/MyDrive/bioinformatics

In [None]:
import itertools
import os
import random
import shutil
import time
from tqdm import tqdm
import uuid
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import matplotlib.pyplot as plt

In [None]:
SEED = 42

# Seed every RNG the notebook draws from: Python, NumPy, PyTorch CPU and all CUDA devices.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Disable cuDNN autotuning and force deterministic kernels so repeated runs match.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
DEVICE

# Config

In [None]:
# Data selection: which TF column to model, and which array design ("Fold ID") trains vs. tests.
TRAIN_ON_TF = 'TF_2'
TRAIN_PROBE, TEST_PROBE = 'A', 'B'

# Model shape
MOTIF_LEN = 24              # convolution kernel width; seq2matrix pads MOTIF_LEN - 1 columns per side
NUM_MOTIF_DETECTORS = 10    # number of convolution filters
MODEL_NAME = "lstm"         # key passed to get_model
DROPOUT = 0.5

# Optimization
BATCH_SIZE = 64
LR = 1e-4
WEIGHT_DECAY = 0.01

In [None]:
# Alphabets in the row order of the one-hot encoding (A, C, G, then T for DNA / U for RNA).
DNA_BASES, RNA_BASES = 'ACGT', 'ACGU'

In [None]:
# Folder holding the DREAM5 PBM tables (sequences.tsv, targets.tsv), relative to the working directory.
PBM_DATA = './data/dream5/pbm'

# Data preparation

### Read DREAM5 sequence data with binding scores

In [None]:
df_seq = pd.read_csv(f"{PBM_DATA}/sequences.tsv", sep='\t')
df_seq.head()

In [None]:
df_targets = pd.read_csv(f"{PBM_DATA}/targets.tsv", sep='\t')
df_targets.head()

### Build a dataframe for single transcription factor

In [None]:
def build_df(tf, df_seq, df_targets):
    """Return a copy of `df_seq` with a 'Target' column holding transcription factor `tf`'s intensities.

    The intensities are attached positionally (as a plain array), so `df_targets` must list the probes
    in the same order as `df_seq`. Neither input frame is modified.

    Args:
        tf: column name in `df_targets`, e.g. 'TF_2'.
        df_seq: probe table.
        df_targets: intensity table with one column per transcription factor.
    """
    return df_seq.assign(Target=df_targets[tf].to_numpy())

In [None]:
df = build_df(TRAIN_ON_TF, df_seq, df_targets)
df.head()

In [None]:
df['Target'].mean()

In [None]:
df['Target'].std()

### Remove probe-specific biases from each sequence's intensity

In [None]:
biases = df_targets.median(axis=1).values
biases

In [None]:
df['TargetNorm'] = df['Target'].values / biases
df.head()

### Measure missing (NA) targets and drop those rows

In [None]:
# Only 4% NA content, it is safe to drop the rows
df['Target'].isna().mean()

In [None]:
df.dropna(subset=['Target'], inplace=True)
df['Target'].isna().mean()

### Train/Test data split

In [None]:
df_train = df[df['Fold ID'] == TRAIN_PROBE]
df_train.head()

In [None]:
df_test = df[df['Fold ID'] == TEST_PROBE]
df_test.head()

In [None]:
df_train.shape, df_test.shape

### Add labels for ROC and AUC

In [None]:
def add_label(d):
    """Add a binary 'Label' column to `d` in place (positives = 1, everything else = 0).

    Positives follow the DREAM5 PBM definition (https://pmc.ncbi.nlm.nih.gov/articles/PMC3687085/):
    probes whose raw 'Target' exceeds mean + 4 * std. The count is bounded:
      * at least 50 probes: if fewer than 50 clear the threshold, the 50 highest-scoring probes are used;
      * at most 1300 probes: among those above the threshold, only the 1300 highest are kept.

    Args:
        d: DataFrame with a numeric 'Target' column; mutated in place.

    Returns:
        None.
    """
    target = d['Target']
    threshold = target.mean() + 4 * target.std()
    above = d[target > threshold]

    # Enough probes clear the threshold -> take the best of them (capped at 1300);
    # otherwise fall back to the top 50 of the whole frame.
    enough = len(above) >= 50
    pool = above if enough else d
    n_keep = 1300 if enough else 50
    positives = pool.sort_values(by='Target', ascending=False).head(n_keep)

    d['Label'] = 0
    d.loc[positives.index, 'Label'] = 1

In [None]:
add_label(df_train)
df_train.head()

In [None]:
df_train[df_train['Label'] == 1].head()

In [None]:
add_label(df_test)
df_test.head()

In [None]:
df_test[df_test['Label'] == 1].head()

### Encode DNA/RNA sequences as padded one-hot matrices

In [None]:
def fill_cell(motif_len, row, col, bases, seq):
    """Value of cell (row, col) of the padded one-hot matrix for `seq` (rows = positions).

    The matrix has len(seq) + 2*(motif_len-1) rows. The first and last motif_len-1 rows are
    padding filled with 0.25. Every other row is 1.0 in the column whose base matches the sequence
    character at that position and 0.0 elsewhere.
    """
    pad = motif_len - 1
    total_rows = len(seq) + 2 * pad

    in_leading_pad = row < pad
    in_trailing_pad = total_rows - 1 - row < pad
    if in_leading_pad or in_trailing_pad:
        return 0.25

    return 1.0 if seq[row - pad] == bases[col] else 0.0

def seq2matrix(seq, motif_len, typ='DNA', bases=None):
    """One-hot encode `seq` as a (len(bases), len(seq) + 2*(motif_len-1)) float64 matrix.

    Rows are bases in alphabet order and columns are positions. The sequence is flanked by
    motif_len-1 columns of 0.25 on each side (a uniform base distribution), so a width-`motif_len`
    convolution can also score motifs that only partly overlap the sequence. Characters outside
    the alphabet (e.g. 'N') become all-zero columns. The values are identical to the cell-by-cell
    construction with fill_cell, but built with one NumPy broadcast, since this runs for every
    sample of every epoch.

    Args:
        seq: sequence string.
        motif_len: motif / convolution kernel width, must be >= 1.
        typ: 'DNA' selects DNA_BASES, anything else RNA_BASES. Ignored when `bases` is given.
        bases: optional explicit alphabet (row order) overriding `typ`.

    Returns:
        np.ndarray of shape (len(bases), len(seq) + 2*motif_len - 2); 4 rows for DNA/RNA.

    Raises:
        ValueError: if motif_len < 1.
    """
    if motif_len < 1:
        raise ValueError(f"motif_len must be >= 1, got {motif_len}")
    if bases is None:
        bases = DNA_BASES if typ == 'DNA' else RNA_BASES

    pad = motif_len - 1
    result = np.full((len(bases), len(seq) + 2 * pad), 0.25)

    # (n_bases, 1) == (1, len(seq)) -> boolean one-hot block for the unpadded positions.
    chars = np.array(list(seq), dtype=str)
    alphabet = np.array(list(bases), dtype=str)
    result[:, pad:pad + len(seq)] = alphabet[:, None] == chars[None, :]
    return result

In [None]:
# Test the function
S = seq2matrix("ATGG", 3, 'DNA')
S

In [None]:
S.shape

### Sequence Dataset and Loader

In [None]:
class SeqDataset(Dataset):
    """Map-style dataset yielding (encoded sequence, normalized target, label) per probe, without augmentation.

    `df` must provide the columns 'seq', 'TargetNorm' and 'Label'; they are snapshotted as arrays here.
    Sequences are encoded lazily in __getitem__ with seq2matrix using the DNA alphabet.

    Args:
        df: source frame.
        motif_len: padding width passed to seq2matrix. Defaults to the notebook's MOTIF_LEN, read once
            when the dataset is constructed. This mirrors AugmentedSeqDataset's explicit parameter.
    """

    def __init__(self, df, motif_len=None):
        self.sequences = df['seq'].values
        self.targets = df['TargetNorm'].values
        self.labels = df['Label'].values
        self.motif_len = MOTIF_LEN if motif_len is None else motif_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return (float32 matrix from seq2matrix, float32 0-d target tensor, label scalar)."""
        M = seq2matrix(self.sequences[idx], self.motif_len, 'DNA')

        x = torch.tensor(M, dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.float32)
        label = self.labels[idx]

        return x, y, label

In [None]:
class AugmentedSeqDataset(Dataset):
    """Training dataset that randomly reverse-complements probe sequences before encoding.

    Each __getitem__ draws one np.random.random() value. With probability `augment_prob` the sequence
    is replaced by its reverse complement before seq2matrix. The target and label stay unchanged, i.e.
    a probe and its reverse complement are treated as equally bound.

    Args:
        df: frame with 'seq', 'TargetNorm' and 'Label' columns.
        motif_len: padding width passed to seq2matrix.
        augment_prob: probability of reverse-complementing a sample.
    """

    def __init__(self, df, motif_len, augment_prob=0.7):
        self.sequences = df['seq'].values
        self.targets = df['TargetNorm'].values
        self.labels = df['Label'].values
        self.motif_len = motif_len
        self.augment_prob = augment_prob

    def __len__(self):
        return self.sequences.shape[0]

    def reverse_complement_seq(self, seq):
        """Reverse-complement a sequence string; characters other than A/C/G/T are kept as-is."""
        return seq[::-1].translate(str.maketrans('ATCG', 'TAGC'))

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        flip = np.random.random() < self.augment_prob
        encoded = seq2matrix(self.reverse_complement_seq(seq) if flip else seq, self.motif_len, 'DNA')
        return (
            torch.tensor(encoded, dtype=torch.float32),
            torch.tensor(self.targets[idx], dtype=torch.float32),
            self.labels[idx].copy(),
        )

In [None]:
train_dataset = AugmentedSeqDataset(df_train, MOTIF_LEN, augment_prob=0.7)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
xs, ys, labels = [], [], []

for x, y, label in train_dataset:
    xs.append(x)
    ys.append(y)
    labels.append(label)

x_train = torch.stack(xs)
y_train = torch.tensor(ys, dtype=torch.float32)
label_train = torch.tensor(labels)

In [None]:
x, target, label = next(iter(train_loader))
x.shape, target.shape, label.shape

In [None]:
test_dataset = SeqDataset(df_test)

xs, ys, labels = [], [], []

for x, y, label in test_dataset:
    xs.append(x)
    ys.append(y)
    labels.append(label)

x_test = torch.stack(xs)
y_test = torch.tensor(ys, dtype=torch.float32)
label_test = torch.tensor(labels)

In [None]:
x_test.shape, y_test.shape, label_test.shape

# Model

### DeepBind Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DeepBindShallow(nn.Module):
    """Single-layer DeepBind: conv motif scan -> ReLU -> global max over positions -> linear score.

    Takes a (batch, 4, length) tensor and returns (batch, 1) scores. It scores only the given strand
    (no reverse-complement pass).

    Args:
        num_motif_detectors: number of convolution filters.
        motif_len: convolution kernel width.
    """

    def __init__(self, num_motif_detectors, motif_len):
        super().__init__()
        # Attribute names `conv` / `fc` are part of the state_dict keys; keep them stable.
        self.conv = nn.Conv1d(4, num_motif_detectors, motif_len)
        self.fc = nn.Linear(num_motif_detectors, 1)

    def forward(self, x):
        activations = F.relu(self.conv(x))
        strongest_hit = activations.max(dim=2).values
        return self.fc(strongest_hit)

In [None]:
def reverse_complement(x):
    """Reverse-complement one-hot encoded sequences.

    The channel axis (second to last) must be in A, C, G, T order, as produced by seq2matrix with
    DNA_BASES. Complementing (A<->T, C<->G) then maps A C G T -> T G C A, which is exactly a flip
    of the channel axis. Reversing the sequence is a flip of the last (position) axis.

    Args:
        x: tensor of shape (B, 4, L), or a single (4, L) matrix.

    Returns:
        A new tensor of the same shape; applying it twice gives back `x`.
    """
    # One flip over both trailing axes does the reversal and the complement at once,
    # and works with or without a leading batch dimension.
    return torch.flip(x, dims=(-2, -1))

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class DeepBindOriginal(nn.Module):
    """DeepBind-style scorer that is symmetric under reverse complement.

    forward_pass: Conv1d motif scan -> ReLU -> max over positions -> Linear(32) -> ReLU -> Linear(1).
    forward: runs forward_pass on the input and on reverse_complement(input) and returns the
    element-wise maximum of the two (B, 1) scores.

    Args:
        num_motif_detectors: number of convolution filters (motif detectors).
        motif_len: convolution kernel width.
    """
    def __init__(self, num_motif_detectors, motif_len):
        super().__init__()

        # 4 input channels, one per base row of the encoded sequence.
        self.conv = nn.Conv1d(4, num_motif_detectors, kernel_size=motif_len)

        self.fc1 = nn.Linear(num_motif_detectors, 32)
        self.fc2 = nn.Linear(32, 1)

        self.init_weights(self.conv)
        self.init_weights(self.fc1)
        self.init_weights(self.fc2)

    def init_weights(self, component):
        """Kaiming-normal (ReLU gain) init for `component.weight`, zeros for `component.bias`."""
        # NOTE(review): fc2 is the linear output layer (no ReLU after it) but receives the same ReLU-gain init.
        init.kaiming_normal_(component.weight, nonlinearity='relu')
        init.zeros_(component.bias)

    def forward_pass(self, x):
        """Score one strand. x is expected as (B, 4, L); returns (B, 1)."""
        x = F.relu(self.conv(x))
        # Global max pooling over positions: the strongest match of each motif detector.
        x, _ = torch.max(x, dim=2)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def forward(self, x):
        """Return max(score(x), score(reverse_complement(x))), shape (B, 1)."""
        r = self.forward_pass(x)
        r_comp = self.forward_pass(reverse_complement(x))
        return torch.max(r, r_comp)

In [None]:
class DeepBindDeeper(nn.Module):
    """Two-conv-layer DeepBind variant with BatchNorm and a 3-layer dropout MLP head.

    forward_pass: conv(motif_len) -> BN -> ReLU -> conv(3, same length) -> BN -> ReLU ->
    max over positions -> [Linear -> BN -> ReLU -> Dropout] x2 -> Linear -> ReLU -> Dropout -> Linear(1).
    forward: element-wise max of forward_pass on the input and on its reverse complement.

    Args:
        num_motif_detectors: filters in the first conv; the second conv has twice as many.
        motif_len: kernel width of the first conv.
        dropout: rate of the first two dropout layers; the last one uses dropout / 2.
    """
    def __init__(self, num_motif_detectors, motif_len, dropout=0.3):
        super().__init__()

        self.conv1 = nn.Conv1d(4, num_motif_detectors, kernel_size=motif_len)
        self.bn1 = nn.BatchNorm1d(num_motif_detectors)

        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv2 = nn.Conv1d(num_motif_detectors, num_motif_detectors * 2, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_motif_detectors * 2)

        # NOTE(review): BatchNorm1d on (B, features) in train mode cannot normalize a batch of size 1;
        # consider drop_last=True on the training DataLoader when using this model.
        self.fc1 = nn.Linear(num_motif_detectors * 2, 128)
        self.bn_fc1 = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(dropout)

        self.fc2 = nn.Linear(128, 64)
        self.bn_fc2 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(dropout)

        self.fc3 = nn.Linear(64, 32)
        self.dropout3 = nn.Dropout(dropout / 2)

        self.fc4 = nn.Linear(32, 1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (ReLU gain) weights and zero biases for every Conv1d / Linear submodule."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward_pass(self, x):
        """Score one strand. x is expected as (B, 4, L); returns (B, 1)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        # Global max pooling over positions -> (B, 2 * num_motif_detectors).
        x, _ = torch.max(x, dim=2)

        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = F.relu(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = F.relu(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = F.relu(x)
        x = self.dropout3(x)

        x = self.fc4(x)
        return x

    def forward(self, x):
        """Return max(score(x), score(reverse_complement(x))), shape (B, 1).

        forward_pass runs twice per call, so in train mode the BatchNorm running statistics
        are updated with both strands.
        """
        r = self.forward_pass(x)
        r_comp = self.forward_pass(reverse_complement(x))
        return torch.max(r, r_comp)

In [None]:
class DeepBindResidual(nn.Module):
    """DeepBind variant with residual conv blocks and a residual MLP layer.

    forward_pass: conv(motif_len) -> BN -> ReLU, then two same-width conv(3) blocks each added back to
    their input, max over positions, then Linear(128) -> residual Linear(128) -> Linear(64) -> Linear(1)
    with BN / ReLU / Dropout in between.
    forward: element-wise max of forward_pass on the input and on its reverse complement.

    Args:
        num_motif_detectors: filters in every conv layer (kept constant so the skips line up).
        motif_len: kernel width of the first conv.
        dropout: rate of the first two dropout layers; the last one uses dropout / 2.
    """
    def __init__(self, num_motif_detectors, motif_len, dropout=0.3):
        super().__init__()

        self.conv1 = nn.Conv1d(4, num_motif_detectors, kernel_size=motif_len)
        self.bn1 = nn.BatchNorm1d(num_motif_detectors)

        # kernel_size=3 with padding=1 preserves length and width, so the identity can be added.
        self.conv2 = nn.Conv1d(num_motif_detectors, num_motif_detectors, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_motif_detectors)

        self.conv3 = nn.Conv1d(num_motif_detectors, num_motif_detectors, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(num_motif_detectors)

        # NOTE(review): BatchNorm1d on (B, features) in train mode cannot normalize a batch of size 1;
        # consider drop_last=True on the training DataLoader when using this model.
        self.fc1 = nn.Linear(num_motif_detectors, 128)
        self.bn_fc1 = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(dropout)

        # 128 -> 128 so fc2's output can be added to its input.
        self.fc2 = nn.Linear(128, 128)
        self.bn_fc2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout)

        self.fc3 = nn.Linear(128, 64)
        self.bn_fc3 = nn.BatchNorm1d(64)
        self.dropout3 = nn.Dropout(dropout / 2)

        self.fc4 = nn.Linear(64, 1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (ReLU gain) weights and zero biases for every Conv1d / Linear submodule."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward_pass(self, x):
        """Score one strand. x is expected as (B, 4, L); returns (B, 1)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        # Residual conv block 1 (skip added after the ReLU).
        identity = x
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = x + identity

        # Residual conv block 2.
        identity = x
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = x + identity

        # Global max pooling over positions -> (B, num_motif_detectors).
        x, _ = torch.max(x, dim=2)

        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = F.relu(x)
        x = self.dropout1(x)

        # Residual MLP layer.
        identity = x
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = x + identity

        x = self.fc3(x)
        x = self.bn_fc3(x)
        x = F.relu(x)
        x = self.dropout3(x)

        x = self.fc4(x)
        return x

    def forward(self, x):
        """Return max(score(x), score(reverse_complement(x))), shape (B, 1)."""
        r = self.forward_pass(x)
        r_comp = self.forward_pass(reverse_complement(x))
        return torch.max(r, r_comp)

In [None]:
class DeepBindLSTM(nn.Module):
    """Conv + LSTM DeepBind variant (bidirectional by default), symmetric under reverse complement.

    forward_once: padded Conv1d motif scan -> ReLU -> Dropout -> LSTM over positions ->
    max over positions -> Linear -> LayerNorm -> ReLU -> Dropout -> Linear(1).
    forward: element-wise max of forward_once on the input and on its reverse complement.

    Args:
        num_motif_detectors: number of conv filters; also the LSTM input size.
        motif_len: conv kernel width.
        dropout: dropout after the hidden FC layer; the conv output uses dropout * 0.5.
        lstm_hidden: LSTM hidden size per direction.
        fc_hidden: width of the hidden FC layer.
        bidirectional: whether the LSTM runs in both directions (doubles the features fed to fc1).
    """
    def __init__(
        self,
        num_motif_detectors: int,
        motif_len: int,
        dropout: float = 0.2,
        lstm_hidden: int = 32,
        fc_hidden: int = 32,
        bidirectional: bool = True,
    ):
        super().__init__()

        # With stride 1 the conv output length is L + 2 * (motif_len // 2) - motif_len + 1,
        # i.e. L + 1 for even motif_len and L for odd motif_len.
        padding = motif_len // 2
        self.conv = nn.Conv1d(
            4,
            num_motif_detectors,
            kernel_size=motif_len,
            padding=padding,
        )

        self.relu = nn.ReLU()
        self.dropout_conv = nn.Dropout(dropout * 0.5)

        self.bidirectional = bidirectional
        self.lstm_hidden = lstm_hidden

        # batch_first: the LSTM consumes (B, positions, features). Single layer, so no inter-layer dropout.
        self.lstm = nn.LSTM(
            input_size=num_motif_detectors,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.0,
        )

        lstm_out_dim = lstm_hidden * (2 if bidirectional else 1)

        self.fc1 = nn.Linear(lstm_out_dim, fc_hidden)
        self.ln_fc = nn.LayerNorm(fc_hidden)
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(fc_hidden, 1)

        self.init_weights()

    def init_weights(self):
        """Initialize conv/FC layers with Kaiming, LSTM weights with Xavier, and a near-zero output layer."""
        init.kaiming_normal_(self.conv.weight, nonlinearity="relu")
        init.zeros_(self.conv.bias)

        for name, p in self.lstm.named_parameters():
            if "weight" in name:
                init.xavier_uniform_(p)
            elif "bias" in name:
                p.data.fill_(0)
                n = p.size(0)
                # Assuming PyTorch's documented gate layout (input, forget, cell, output), this slice is
                # the forget-gate bias; it is set to 1 in both bias_ih and bias_hh.
                p.data[n // 4 : n // 2].fill_(1.0)

        init.kaiming_normal_(self.fc1.weight, nonlinearity="relu")
        init.zeros_(self.fc1.bias)
        # Small weights and zero bias: initial predictions start close to 0.
        init.normal_(self.fc2.weight, mean=0.0, std=0.01)
        init.zeros_(self.fc2.bias)

    def forward_once(self, x):
        """Score one strand. x is expected as (B, 4, L); returns (B, 1)."""
        x = self.conv(x)
        x = self.relu(x)
        x = self.dropout_conv(x)

        # (B, channels, positions) -> (B, positions, channels) for the batch_first LSTM.
        x = x.transpose(1, 2)

        out, _ = self.lstm(x)

        # Max over positions of the per-step LSTM outputs.
        feat = torch.max(out, dim=1).values

        feat = self.fc1(feat)
        feat = self.ln_fc(feat)
        feat = F.relu(feat)
        feat = self.drop(feat)
        y = self.fc2(feat)

        return y

    def forward(self, x):
        """Return max(score(x), score(reverse_complement(x))), shape (B, 1)."""
        y_fwd = self.forward_once(x)
        y_rc = self.forward_once(reverse_complement(x))
        return torch.max(y_fwd, y_rc)


In [None]:
def get_model(model_name, num_motif_detectors=16, motif_len=24, dropout=0.3):
    """Instantiate one of the DeepBind variants by name.

    Args:
        model_name: one of 'shallow', 'original', 'deeper', 'residual', 'lstm'.
        num_motif_detectors: number of convolution filters.
        motif_len: convolution kernel width.
        dropout: dropout rate. Ignored by 'shallow' and 'original', whose constructors take none.

    Returns:
        A freshly initialized nn.Module.

    Raises:
        ValueError: if `model_name` is not a known architecture.
    """
    models = {
        'shallow': DeepBindShallow,
        'original': DeepBindOriginal,
        'deeper': DeepBindDeeper,
        'residual': DeepBindResidual,
        'lstm': DeepBindLSTM,
    }
    # Architectures whose constructor is (num_motif_detectors, motif_len) only.
    without_dropout = {'shallow', 'original'}

    if model_name not in models:
        raise ValueError(f"Model {model_name} not found. Choose from {list(models.keys())}")

    model_class = models[model_name]
    if model_name in without_dropout:
        return model_class(num_motif_detectors, motif_len)
    return model_class(num_motif_detectors, motif_len, dropout)

### Model Wrapper for training

In [None]:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

class ModelWrapper:
    """Training / evaluation harness for a model that maps a batch of inputs to one score each.

    Optimizes MSE against the continuous target with Adam. Supports optional mixup and an optional
    ReduceLROnPlateau scheduler driven by validation AUC (stepped only by `train`).

    Args:
        model: nn.Module returning a (B, 1) tensor for a batch of B inputs.
        device: device the model and batches are moved to.
        lr, weight_decay: Adam hyper-parameters.
        use_scheduler: create a ReduceLROnPlateau(mode='max', factor=0.5, patience=3) scheduler.
        use_mixup: replace each training batch by convex combinations of random sample pairs.
        mixup_alpha: Beta(alpha, alpha) parameter of the mixup coefficient.
    """

    def __init__(self, model, device=DEVICE, lr=1e-3, weight_decay=1e-5,
                 use_scheduler=True, use_mixup=False, mixup_alpha=0.2):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                    weight_decay=weight_decay)
        self.criterion = nn.MSELoss()

        self.use_mixup = use_mixup
        self.mixup_alpha = mixup_alpha

        self.use_scheduler = use_scheduler
        if use_scheduler:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='max', factor=0.5, patience=3
            )

        self.history = {
            'train_loss': [],
            'val_auc': [],
            'val_pearson': [],
            'val_spearman': [],
        }

    @staticmethod
    def _to_numpy(values):
        """Return `values` as a NumPy array, moving torch tensors to the CPU first."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def train_step(self, x, target):
        """Run one optimizer step on a batch and return the batch's mean MSE loss as a float.

        x: model input batch. target: regression targets, expected shape (B,).
        Gradients are clipped to a global L2 norm of 1.0 before the step.
        """
        self.model.train()
        x = x.to(self.device)
        target = target.to(self.device)

        if self.use_mixup:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            batch_size = x.size(0)
            index = torch.randperm(batch_size).to(self.device)
            x = lam * x + (1 - lam) * x[index]
            target = lam * target + (1 - lam) * target[index]

        self.optimizer.zero_grad()
        # squeeze(-1) rather than squeeze(): a batch of one stays (1,) and still matches `target`.
        pred = self.model(x).squeeze(-1)
        loss = self.criterion(pred, target)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        return loss.item()

    def predict(self, x, batch_size=1024):
        """Predict one score per sample of `x` in eval mode, without gradients.

        Args:
            x: input tensor with samples along dim 0 (may live on any device).
            batch_size: samples per forward pass. Chunking bounds memory on large evaluation sets;
                in eval mode the per-sample outputs do not depend on the chunking.
                None runs everything in one pass.

        Returns:
            1-D NumPy array of length len(x), also when len(x) == 1.
        """
        self.model.eval()
        n = len(x)
        if batch_size is None:
            batch_size = max(n, 1)

        chunks = []
        with torch.no_grad():
            for start in range(0, n, batch_size):
                batch = x[start:start + batch_size].to(self.device)
                chunks.append(self.model(batch).squeeze(-1).cpu())

        if not chunks:
            return np.empty(0, dtype=np.float32)
        return torch.cat(chunks).numpy()

    def evaluate(self, x, y_true, label_true, plot=True):
        """Compute ROC AUC (vs. binary labels) and Pearson/Spearman correlation (vs. continuous targets).

        Args:
            x: inputs for `predict`.
            y_true: continuous targets (tensor or array-like).
            label_true: binary labels (tensor or array-like); needs both classes for a defined AUC.
            plot: draw the ROC curve.

        Returns:
            dict with keys 'pearson', 'spearman', 'auc'.
        """
        y_pred = self.predict(x)
        y_true = self._to_numpy(y_true)
        label_true = self._to_numpy(label_true)

        fpr, tpr, _ = roc_curve(label_true, y_pred)
        roc_auc = auc(fpr, tpr)

        pearson_corr, _ = pearsonr(y_true, y_pred)
        spearman_corr, _ = spearmanr(y_true, y_pred)

        if plot:
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}", linewidth=2)
            ax.plot([0, 1], [0, 1], 'k--', linewidth=1)
            ax.set(xlabel="False Positive Rate", ylabel="True Positive Rate", title="ROC curve")
            ax.legend()
            ax.grid(alpha=0.3)
            fig.tight_layout()
            plt.show()
            # Release the figure; train() calls this once per epoch.
            plt.close(fig)

        return {
            'pearson': pearson_corr,
            'spearman': spearman_corr,
            'auc': roc_auc,
        }

    def train_one_epoch(self, loader):
        """Train over `loader` once and return the per-sample mean training loss.

        Batch losses are weighted by batch size, so a short final batch does not skew the average.
        Returns NaN for an empty loader.
        """
        total_loss = 0.0
        n_samples = 0
        for x, y, _label in tqdm(loader):
            loss = self.train_step(x, y)
            total_loss += loss * x.size(0)
            n_samples += x.size(0)
        return total_loss / n_samples if n_samples else float('nan')

    def train(self, train_loader, x_val, y_val, label_val, epochs=30, plot=True):
        """Train for `epochs` epochs, evaluating on the validation tensors after each epoch.

        Records train loss and validation AUC / Pearson / Spearman in `self.history`. When the
        scheduler is enabled, it steps on the validation AUC.

        Returns:
            self.history.
        """
        for epoch in range(1, epochs + 1):
            print(f"===== EPOCH {epoch} =====")

            epoch_loss = self.train_one_epoch(train_loader)
            self.history['train_loss'].append(epoch_loss)

            val_metrics = self.evaluate(x_val, y_val, label_val, plot=plot)
            self.history['val_auc'].append(val_metrics['auc'])
            self.history['val_pearson'].append(val_metrics['pearson'])
            self.history['val_spearman'].append(val_metrics['spearman'])

            print(f"Loss: {epoch_loss}")
            print(f"Val AUC: {val_metrics['auc']:.4f} | Pearson: {val_metrics['pearson']:.4f} "
                  f"| Spearman: {val_metrics['spearman']:.4f}")

            if self.use_scheduler:
                self.scheduler.step(val_metrics['auc'])

        return self.history

### Sanity Check: Overfit on single mini-batch

In [None]:
x, y, label = next(iter(train_loader))
x.shape, y.shape, label.shape

In [None]:
m_sanity = get_model(model_name=MODEL_NAME, num_motif_detectors=NUM_MOTIF_DETECTORS, motif_len=MOTIF_LEN, dropout=DROPOUT)

In [None]:
mw_sanity = ModelWrapper(m_sanity, lr=LR, weight_decay=WEIGHT_DECAY)

In [None]:
# for i in range(10000):
#     loss = mw_sanity.train_step(x, y)
#     if i % 1000 == 0:
#         print(loss)

In [None]:
pred_sanity = mw_sanity.predict(x)
pred_sanity[:10]

In [None]:
target_sanity = y.cpu().numpy()
target_sanity[:10]

In [None]:
for name, p in mw_sanity.model.named_parameters():
    print(name, p.data.abs().mean().item())

In [None]:
m = get_model(model_name=MODEL_NAME, num_motif_detectors=NUM_MOTIF_DETECTORS, motif_len=MOTIF_LEN, dropout=DROPOUT)

In [None]:
mw = ModelWrapper(m, lr=LR, weight_decay=WEIGHT_DECAY)

### Sanity Check: Reverse-complement invariance

In [None]:
# rc(rc(x)) == x for any input, so comparing predictions on those two says nothing about the model.
# The RC-symmetric models (original / deeper / residual / lstm) score max(f(x), f(rc(x))), so what must
# hold is that predictions do not change when the *input* is reverse-complemented.
assert torch.equal(reverse_complement(reverse_complement(x)), x), "reverse_complement must be an involution"
np.allclose(mw.predict(x), mw.predict(reverse_complement(x)), atol=1e-5)

# Training

In [None]:
def train(mw, loader, epochs):
    """Train `mw` for `epochs` epochs on `loader`, evaluating on x_test / y_test / label_test after each.

    Prints an epoch header and the epoch's training loss. Unlike ModelWrapper.train, this loop does not
    step mw's LR scheduler and records nothing in mw.history.
    """
    epoch = 0
    while epoch < epochs:
        epoch += 1
        print(f"===== EPOCH {epoch} =====")
        epoch_loss = mw.train_one_epoch(loader)
        mw.evaluate(x_test, y_test, label_test)
        print(f"Loss: {epoch_loss}")

In [None]:
train(mw, train_loader, 30)

In [None]:
mw.evaluate(x_train, y_train, label_train)

In [None]:
mw.evaluate(x_test, y_test, label_test)