In [1]:
import itertools
import os
import random
import shutil
import time
from tqdm.notebook import tqdm
import uuid
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import matplotlib.pyplot as plt

In [2]:
# Use the GPU when torch can see one, otherwise the CPU. ModelWrapper moves the
# model and every batch to DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

# Config

In [32]:
TRAIN_ON_TF = 'TF_2'      # column of targets.tsv to train on
TRAIN_PROBE = 'A'         # 'Fold ID' used for training
TEST_PROBE = 'B'          # 'Fold ID' held out for evaluation
MOTIF_LEN = 24            # convolution width; also the padding width (MOTIF_LEN - 1 per side)
NUM_MOTIF_DETECTORS = 16  # number of convolution filters
BATCH_SIZE = 64
SEED = 42

# Seed every RNG the notebook relies on (DataLoader shuffling, weight init) so a
# Restart & Run All reproduces the reported numbers.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Alphabets used by seq2matrix; the character order fixes the row order of the
# one-hot matrix (row i <-> bases[i]).
DNA_BASES = 'ACGT'
RNA_BASES = 'ACGU'

In [5]:
# Directory holding the DREAM5 PBM `sequences.tsv` and `targets.tsv` files.
# Set the PBM_DATA environment variable to point elsewhere instead of editing
# this machine-specific default.
PBM_DATA = os.environ.get(
    "PBM_DATA", "/home/vazzu/Desktop/bioinfo/compbio-final/data/dream5/pbm"
)

# Data preparation

### Read DREAM5 sequence data with binding scores

In [6]:
# One row per probe; 'Fold ID' drives the train/test split below and 'seq' is
# the probe sequence.
df_seq = pd.read_csv(f"{PBM_DATA}/sequences.tsv", sep='\t')
df_seq.head()

Unnamed: 0,Fold ID,Event ID,seq
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT


In [7]:
# Measured intensities, one column per transcription factor. Rows are assumed to
# line up with df_seq (build_df pairs them positionally) — TODO confirm.
df_targets = pd.read_csv(f"{PBM_DATA}/targets.tsv", sep='\t')
df_targets.head()

Unnamed: 0,TF_40,TF_41,TF_42,TF_43,TF_44,TF_45,TF_46,TF_47,TF_48,TF_49,...,C_19,C_18,C_15,C_14,C_17,C_16,C_11,C_10,C_13,C_12
0,823.914118,12702.625538,2124.023125,2314.305782,1474.888697,1131.785521,4597.003319,14589.890994,1556.951404,34180.775942,...,1651.254953,1242.303363,724.10585,3184.883349,8935.394363,12689.558779,4102.312624,505.126184,12946.381724,1313.790253
1,1307.840222,4316.426121,2554.658908,3415.320661,3408.586803,1697.342725,5272.763446,22903.130555,2181.551097,10000.297243,...,3505.604759,2516.00012,1640.114829,3463.713253,19535.468264,18006.72169,6890.427794,1402.597,38309.856355,3024.107809
2,1188.353499,3436.803941,2088.909658,3708.324021,2219.741833,1571.646567,6225.376501,13858.014077,1971.053716,18800.025304,...,3270.572883,1693.419147,997.792996,3196.992198,16695.027604,14486.992627,13517.968701,10680.866586,25648.825592,2675.530918
3,1806.103795,6531.268855,2406.186212,3601.204703,2828.415329,2746.861783,5810.10465,25701.749693,2191.273065,19213.880658,...,2701.555739,2059.614815,1432.163042,4927.163643,18896.765835,18784.043322,8608.167421,4624.044391,23651.726053,3679.449867
4,1417.411525,3951.243575,2581.309532,3375.884699,2764.716964,1806.919566,5033.976283,26364.859152,2311.790793,16139.097553,...,2457.214141,1901.709222,1672.531034,3877.787322,14699.253953,17119.871513,8995.328144,12641.425965,27999.405431,3128.844808


### Build a dataframe for a single transcription factor

In [8]:
def build_df(tf, df_seq, df_targets):
    """Return a copy of ``df_seq`` with a ``Target`` column taken from ``df_targets[tf]``.

    The targets are attached positionally (by row order, not by index label),
    so both frames must have the same number of rows in the same order.

    Args:
        tf: column name in ``df_targets`` (e.g. 'TF_2').
        df_seq: per-probe sequence frame; it is not modified.
        df_targets: per-probe intensity frame.

    Returns:
        A new DataFrame: the columns of ``df_seq`` plus ``Target``.
    """
    target_values = df_targets[tf].to_numpy()
    return df_seq.assign(Target=target_values)

In [9]:
# Attach the raw intensities of the TF being trained on to every probe sequence.
df = build_df(TRAIN_ON_TF, df_seq, df_targets)
df.head()

Unnamed: 0,Fold ID,Event ID,seq,Target
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,2941.528352
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,6646.004089
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,4606.883308
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,8126.753206
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,7935.937598


### Remove probe-specific biases from each sequence's target

In [10]:
# Per-probe bias: the median intensity of each probe across all TF columns of
# df_targets (pandas skips NaNs by default).
biases = df_targets.median(axis=1).values
biases

array([2641.988164 , 4170.2475   , 3699.8877625, ..., 2287.3149815,
       1884.915039 , 1231.9427485], shape=(80856,))

In [11]:
# Normalise the TF's raw intensities by the per-probe bias. Reading from
# df_targets (rather than dividing the current df['Target'] in place) makes this
# cell idempotent: re-running it no longer divides by the biases a second time.
# Running it after rows were dropped raises a length-mismatch ValueError instead
# of silently mis-aligning values.
df['Target'] = df_targets[TRAIN_ON_TF].values / biases
df.head()

Unnamed: 0,Fold ID,Event ID,seq,Target
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,1.113377
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,1.593671
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,1.245141
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,1.758181
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,2.046512


### Measure NA content and drop rows with a missing target

In [12]:
# Fraction of probes with no measured target. It is only ~4%, so dropping
# those rows is safe.
df['Target'].isna().mean()

np.float64(0.04049173839912932)

In [13]:
# Drop probes whose Target is missing; rebinding replaces dropna(inplace=True).
df = df.dropna(subset=['Target'])
df['Target'].isna().mean()

np.float64(0.0)

### Train/Test data split

In [14]:
# Training split: probes whose 'Fold ID' equals TRAIN_PROBE.
df_train = df[df['Fold ID'] == TRAIN_PROBE]
df_train.head()

Unnamed: 0,Fold ID,Event ID,seq,Target
40526,A,HK26479,ACGCGCGTCAGCTTTTTGGAATATTGCGGAGAGTTCCTGT,0.713724
40527,A,HK26478,GACTATGGGATGGGCCGCCTTTGATTACGCGCGTCCCTGT,1.0
40528,A,HK22998,CCTTCGTGAGCCATGTGTTTCAGGCTGTGCGTGTCCCTGT,0.79075
40529,A,HK22999,GGCGGGTGGTAAAGGCCCCGGAAGCGGACACGCACCCTGT,0.622457
40530,A,HK26473,GTTGTGGTTTGTCCTTTTGTATTAACAGTGTATGGCCTGT,0.626426


In [15]:
# Held-out split: probes whose 'Fold ID' equals TEST_PROBE.
df_test = df[df['Fold ID'] == TEST_PROBE]
df_test.head()

Unnamed: 0,Fold ID,Event ID,seq,Target
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,1.113377
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,1.593671
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,1.245141
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,1.758181
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,2.046512


In [16]:
df_train.shape, df_test.shape

((38041, 4), (39541, 4))

### Encode DNA/RNA sequences as padded one-hot matrices

In [17]:
def fill_cell(motif_len, row, col, bases, seq):
    """Value of a single cell of the (un-transposed) padded one-hot matrix.

    The matrix has ``len(seq) + 2 * (motif_len - 1)`` rows. The first and last
    ``motif_len - 1`` rows are uniform padding (0.25). Every interior row is the
    one-hot encoding of one character of ``seq``.

    Args:
        motif_len: motif width; sets the padding of ``motif_len - 1`` rows per side.
        row: row index in the un-transposed matrix.
        col: column index, i.e. position of the base in ``bases``.
        bases: alphabet string, e.g. 'ACGT'.
        seq: the sequence being encoded.

    Returns:
        0.25 in a padding row. Otherwise 1.0 if the sequence character for
        this row equals ``bases[col]``, else 0.0.
    """
    pad = motif_len - 1
    total_rows = len(seq) + 2 * pad
    in_padding = row < pad or (total_rows - 1 - row) < pad
    if in_padding:
        return 0.25
    return 1.0 if seq[row - pad] == bases[col] else 0.0

def seq2matrix(seq, motif_len, typ='DNA'):
    """Encode a nucleotide sequence as a padded one-hot matrix.

    Rows follow the order of DNA_BASES (or RNA_BASES when ``typ != 'DNA'``).
    The sequence columns are flanked by ``motif_len - 1`` columns of 0.25 on
    each side, so a width-``motif_len`` convolution can score motifs that
    overhang either end of the sequence.

    Args:
        seq: nucleotide string. Characters outside the alphabet encode as an
            all-zero column.
        motif_len: motif/kernel width, must be >= 1.
        typ: 'DNA' selects DNA_BASES; any other value selects RNA_BASES.

    Returns:
        float64 array of shape ``(4, len(seq) + 2 * motif_len - 2)``.

    Raises:
        ValueError: if ``motif_len < 1``.
    """
    if motif_len < 1:
        raise ValueError(f"motif_len must be >= 1, got {motif_len}")
    bases = DNA_BASES if typ == 'DNA' else RNA_BASES
    pad = motif_len - 1
    matrix = np.full((4, len(seq) + 2 * pad), 0.25)
    # Vectorised one-hot of the interior columns. This replaces
    # 4 * num_cols Python-level fill_cell calls per sequence.
    seq_chars = np.array(list(seq), dtype='<U1')
    base_chars = np.array(list(bases), dtype='<U1')
    matrix[:, pad:pad + len(seq)] = seq_chars[np.newaxis, :] == base_chars[:, np.newaxis]
    return matrix

In [18]:
# Test the function
# Toy check: "ATGG" with motif_len=3 should give 2 columns of 0.25 padding on
# each side of the one-hot sequence columns.
S = seq2matrix("ATGG", 3, 'DNA')
S

array([[0.25, 0.25, 1.  , 0.  , 0.  , 0.  , 0.25, 0.25],
       [0.25, 0.25, 0.  , 0.  , 0.  , 0.  , 0.25, 0.25],
       [0.25, 0.25, 0.  , 0.  , 1.  , 1.  , 0.25, 0.25],
       [0.25, 0.25, 0.  , 1.  , 0.  , 0.  , 0.25, 0.25]])

In [19]:
S.shape

(4, 8)

### Sequence Dataset and Loader

In [20]:
class SeqDataset(Dataset):
    """Map-style dataset of (one-hot sequence matrix, binding target) pairs.

    Every sequence is encoded with ``seq2matrix`` once, at construction, and
    cached. Previously each sample was re-encoded on every access, i.e. every
    epoch. ``__getitem__`` returns the cached tensors themselves, so callers
    must ``.clone()`` before modifying them in place.

    Args:
        df: DataFrame with a "seq" column of sequences and a "Target" column.
        motif_len: padding/kernel width passed to ``seq2matrix``. Defaults to
            the notebook-level MOTIF_LEN.
        typ: 'DNA' or 'RNA', passed to ``seq2matrix``.
    """

    def __init__(self, df, motif_len=None, typ='DNA'):
        self.sequences = df["seq"].values
        self.targets = df["Target"].values
        self.motif_len = MOTIF_LEN if motif_len is None else motif_len
        self.typ = typ
        # A list rather than one stacked tensor keeps variable-length sequences working.
        self._encoded = [
            torch.tensor(seq2matrix(seq, self.motif_len, typ), dtype=torch.float32)
            for seq in self.sequences
        ]
        self._targets = torch.tensor(self.targets, dtype=torch.float32)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # (4, len(seq) + 2 * motif_len - 2) float32 matrix and a 0-d float32 target.
        return self._encoded[idx], self._targets[idx]

In [23]:
# Training data: reshuffled every epoch, BATCH_SIZE samples per batch.
train_dataset = SeqDataset(df_train)
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True
)

In [135]:
# Materialise the full training set for evaluate(). A single pass over the
# dataset replaces the previous two (one for x, one for y).
train_pairs = [train_dataset[i] for i in range(len(train_dataset))]
x_train = torch.stack([x for x, _ in train_pairs])
y_train = torch.stack([y for _, y in train_pairs])

In [25]:
# Peek at one batch: x is (batch, 4, len(seq) + 2 * MOTIF_LEN - 2), target is (batch,).
x, target = next(iter(train_loader))
x.shape, target.shape

(torch.Size([64, 4, 86]), torch.Size([64]))

In [112]:
test_dataset = SeqDataset(df_test)
# Materialise the held-out set in a single pass over the dataset (previously two).
test_pairs = [test_dataset[i] for i in range(len(test_dataset))]
x_test = torch.stack([x for x, _ in test_pairs])
y_test = torch.stack([y for _, y in test_pairs])

In [113]:
x_test.shape, y_test.shape

(torch.Size([39541, 4, 86]), torch.Size([39541]))

# Model

### DeepBind Model

In [41]:
import torch.nn as nn
import torch.nn.functional as F

class DeepBindShallow(nn.Module):
    """Single-layer DeepBind: convolution -> ReLU -> global max-pool -> linear readout.

    Args:
        num_motif_detectors: number of convolution filters (motif detectors).
        motif_len: filter width.

    Expects input of shape (batch, 4, length) and returns (batch, 1).
    """

    def __init__(self, num_motif_detectors, motif_len):
        super().__init__()
        # The attribute names `conv` and `fc` are the state_dict keys; keep them stable.
        self.conv = nn.Conv1d(4, num_motif_detectors, kernel_size=motif_len)
        self.fc = nn.Linear(in_features=num_motif_detectors, out_features=1)

    def forward(self, x):
        activations = F.relu(self.conv(x))
        # Strongest response of each detector anywhere along the sequence.
        pooled = activations.max(dim=2).values
        return self.fc(pooled)

In [156]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class DeepBind(nn.Module):
    """DeepBind with one hidden layer.

    The forward pass is conv -> ReLU -> global max-pool -> FC(hidden_dim) -> ReLU -> FC(1).
    Every layer gets Kaiming-normal weights (ReLU gain) and zero biases.

    Args:
        num_motif_detectors: number of convolution filters (motif detectors).
        motif_len: filter width.
        hidden_dim: width of the hidden fully-connected layer. The default of
            32 is the previously hard-coded value.

    Expects input of shape (batch, 4, length) and returns (batch, 1).
    """

    def __init__(self, num_motif_detectors, motif_len, hidden_dim=32):
        super().__init__()

        self.conv = nn.Conv1d(in_channels=4, out_channels=num_motif_detectors, kernel_size=motif_len)
        self.fc1 = nn.Linear(num_motif_detectors, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

        # Same order as the layers are created, so a fixed seed yields the same weights.
        for layer in (self.conv, self.fc1, self.fc2):
            self.init_weights(layer)

    def init_weights(self, component):
        """Kaiming-normal init of ``component.weight`` (ReLU gain) and zero bias."""
        init.kaiming_normal_(component.weight, nonlinearity='relu')
        init.zeros_(component.bias)

    def forward(self, x):
        x = F.relu(self.conv(x))
        # Global max-pool over sequence positions: one score per motif detector.
        x, _ = torch.max(x, dim=2)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

### Model Wrapper for training

In [157]:
from scipy.stats import pearsonr, spearmanr

class ModelWrapper:
    """Couples a model with an Adam optimiser and MSE loss for training and evaluation.

    The model is moved to the module-level DEVICE on construction, and every
    batch is moved there before the forward pass.

    Args:
        m: torch.nn.Module producing one prediction per sample, shaped
            (batch,) or (batch, 1).
        lr: Adam learning rate. The default of 1e-3 is the previously
            hard-coded value.
    """

    def __init__(self, m, lr=1e-3):
        self.model = m.to(DEVICE)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = torch.nn.MSELoss()

    @staticmethod
    def _as_vector(out):
        """Drop a trailing singleton output dim so predictions are shaped (batch,).

        Unlike a bare ``squeeze()``, a batch of one stays 1-D instead of
        collapsing to a 0-d tensor. That collapse would mis-broadcast against
        the (1,) target in MSELoss.
        """
        return out.squeeze(-1) if out.dim() > 1 else out

    def train_step(self, x, target):
        """Run one optimiser step on a batch and return its scalar MSE loss."""
        self.model.train()

        x = x.to(DEVICE)
        target = target.to(DEVICE)

        self.opt.zero_grad()
        pred = self._as_vector(self.model(x))
        loss = self.criterion(pred, target)
        loss.backward()
        self.opt.step()

        return loss.item()

    def predict(self, x, batch_size=None):
        """Return model predictions for ``x`` as a 1-D numpy array.

        Runs in eval mode with gradients disabled, so no autograd graph is
        kept alive for (potentially very large) inputs.

        Args:
            x: input tensor with samples along dim 0.
            batch_size: if given, feed the model chunks of this many samples
                to bound device memory. None processes ``x`` in one pass.
        """
        self.model.eval()
        chunks = (x,) if batch_size is None else torch.split(x, batch_size)
        with torch.no_grad():
            preds = [self._as_vector(self.model(chunk.to(DEVICE))).cpu() for chunk in chunks]
        return torch.cat(preds).numpy()

    def evaluate(self, x, y_true, batch_size=None):
        """Pearson and Spearman correlation between predictions on ``x`` and ``y_true``.

        ``y_true`` may be a torch tensor on any device, or array-like.
        Returns a dict with keys 'pearson' and 'spearman'.
        """
        y_pred = self.predict(x, batch_size=batch_size)
        if torch.is_tensor(y_true):
            y_true = y_true.detach().cpu().numpy()
        y_true = np.asarray(y_true).reshape(-1)

        pearson_corr, _ = pearsonr(y_true, y_pred)
        spearman_corr, _ = spearmanr(y_true, y_pred)

        return {
            'pearson': pearson_corr,
            'spearman': spearman_corr,
        }

    def train_one_epoch(self, loader):
        """Train for one pass over ``loader`` and return the mean of the per-batch losses.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        total_loss = 0.0
        num_batches = 0
        for x, y in tqdm(loader):
            total_loss += self.train_step(x, y)
            num_batches += 1
        if num_batches == 0:
            raise ValueError("loader yielded no batches")
        return total_loss / num_batches

### Sanity check: overfit a single mini-batch

In [158]:
# One fixed mini-batch that the sanity-check model should be able to memorise.
x, y = next(iter(train_loader))
x.shape, y.shape

(torch.Size([64, 4, 86]), torch.Size([64]))

In [67]:
# Shallow model (no hidden layer) used only for the overfitting sanity check.
m_sanity = DeepBindShallow(num_motif_detectors=NUM_MOTIF_DETECTORS, motif_len=MOTIF_LEN)

In [68]:
# Separate optimiser/loss for the sanity model; ModelWrapper also moves it to DEVICE.
mw_sanity = ModelWrapper(m_sanity)

In [69]:
# Repeatedly fit the same batch; the loss should fall to ~0 if the model and
# training step are wired correctly. Print every 1000 steps.
for i in range(10000):
    loss = mw_sanity.train_step(x, y)
    if i % 1000 == 0:
        print(loss)

0.9674995541572571
1.17877929639576e-13
5.3013149425851225e-15
1.6375789613221059e-15
1.4710455076283324e-15
2.3037127760972e-15
8.076872504148014e-15
2.6922908347160046e-15
4.3021142204224816e-15
7.355227538141662e-15


In [84]:
# predict() already returns a 1-D numpy array. The previous
# `.squeeze().detach().cpu().numpy()` chain raised AttributeError on that
# ndarray (numpy arrays have no .detach()).
pred_sanity = mw_sanity.predict(x)
pred_sanity[:10]

array([0.52148795, 1.5151204 , 0.8145001 , 0.7346965 , 0.79065657,
       0.876678  , 0.53783476, 0.78819394, 0.43525562, 0.79170513],
      dtype=float32)

In [83]:
# Targets of the same batch, for side-by-side comparison with pred_sanity.
target_sanity = y.cpu().numpy()
target_sanity[:10]

array([0.521488  , 1.5151204 , 0.8145001 , 0.7346965 , 0.79065657,
       0.876678  , 0.5378348 , 0.7881938 , 0.43525565, 0.791705  ],
      dtype=float32)

In [91]:
# Mean absolute value of each parameter tensor: a quick look at weight scale
# after overfitting.
for name, p in mw_sanity.model.named_parameters():
    print(name, p.data.abs().mean().item())

conv.weight 0.06720635294914246
conv.bias 0.04288124665617943
fc.weight 0.1710122972726822
fc.bias 0.187648743391037


# Training

In [159]:
def train(mw, loader, epochs):
    """Train ``mw`` for ``epochs`` passes over ``loader``, printing each epoch's loss.

    Args:
        mw: object exposing ``train_one_epoch(loader) -> float``, e.g. ModelWrapper.
        loader: iterable of (x, target) batches.
        epochs: number of epochs; 0 trains nothing.

    Returns:
        List of per-epoch mean losses in order, so the learning curve can be
        plotted (previously the losses were only printed).
    """
    history = []
    for epoch in range(1, epochs + 1):
        print(f"===== EPOCH {epoch} =====")
        epoch_loss = mw.train_one_epoch(loader)
        print(f"Loss: {epoch_loss}")
        history.append(epoch_loss)
    return history

In [131]:
# Shallow baseline kept under its own name. It used to be assigned to `m` and
# then silently overwritten by the DeepBind model in the next cell.
m_shallow = DeepBindShallow(num_motif_detectors=NUM_MOTIF_DETECTORS, motif_len=MOTIF_LEN)

In [160]:
# Model trained below: DeepBind with one hidden layer.
m = DeepBind(num_motif_detectors=NUM_MOTIF_DETECTORS, motif_len=MOTIF_LEN)

In [161]:
# Adam + MSE training wrapper; also moves the model to DEVICE.
mw = ModelWrapper(m)

In [162]:
# 30 epochs over the shuffled training loader; prints the mean batch loss per epoch.
train(mw, train_loader, 30)

===== EPOCH 1 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 1.9820068546516054
===== EPOCH 2 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 1.6470541815296942
===== EPOCH 3 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 1.159998936669416
===== EPOCH 4 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.790402569152227
===== EPOCH 5 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.5830417976033788
===== EPOCH 6 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.47954962704675036
===== EPOCH 7 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.42584325829912134
===== EPOCH 8 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.3773121386578604
===== EPOCH 9 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.35072723782989157
===== EPOCH 10 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.31905806282175186
===== EPOCH 11 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.29785882313148815
===== EPOCH 12 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.2768552167043716
===== EPOCH 13 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.2636920803015222
===== EPOCH 14 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.24967472308761432
===== EPOCH 15 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.217961900875348
===== EPOCH 16 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.20959082186660347
===== EPOCH 17 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.1930756077139067
===== EPOCH 18 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.18597091944778668
===== EPOCH 19 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.17144461934368652
===== EPOCH 20 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.16562955916679206
===== EPOCH 21 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.1537855327849378
===== EPOCH 22 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.15149174677599378
===== EPOCH 23 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.13767513518511248
===== EPOCH 24 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.13881803313661523
===== EPOCH 25 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.13244946813132583
===== EPOCH 26 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.12720774313231475
===== EPOCH 27 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.12713246807894285
===== EPOCH 28 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.1290014826825687
===== EPOCH 29 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.11805249273839619
===== EPOCH 30 =====


  0%|          | 0/595 [00:00<?, ?it/s]

Loss: 0.11314517335981882


In [164]:
# Pearson/Spearman correlation on the held-out probe set (TEST_PROBE).
mw.evaluate(x_test, y_test)

{'pearson': np.float32(0.54699504), 'spearman': np.float64(0.3566680088846224)}

In [166]:
# Same metrics on the training probes. In the saved run the gap to the held-out
# score (Pearson ~0.97 vs ~0.55) suggests overfitting.
mw.evaluate(x_train, y_train)

{'pearson': np.float32(0.9737026), 'spearman': np.float64(0.5306100004864945)}