In [137]:
import sys
import os
import math

import numpy as np
from numpy.random import shuffle
import scipy
import pandas as pd

from typing import Tuple

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.sampler import Sampler, SequentialSampler
from torch.backends import cudnn
from sklearn.model_selection import train_test_split, KFold

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import wandb


In [138]:
train_PECV = pd.read_csv('data/DeepPrime_PECV__train_220214.csv')
# test_PECV = pd.read_csv('data/DeepPrime_PECV__test_220214.csv')


In [371]:
class GeneInteractionModel(nn.Module):

    def __init__(self, hidden_size, num_layers):
        super(GeneInteractionModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.c1 = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=128, kernel_size=(2, 3), stride=1, padding=(0, 1)),
            nn.GELU(),
        )
        self.c2 = nn.Sequential(
            nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AvgPool1d(kernel_size=2, stride=2),

            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AvgPool1d(kernel_size=2, stride=2),

            nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
        )

        self.r = nn.GRU(32, hidden_size, num_layers,
                        batch_first=True, bidirectional=True)
        
        self.d = nn.Sequential(
            nn.Linear(27, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32), 
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 16)
        )
        
        self.head = nn.Linear(2 * hidden_size + 16, 1, bias=True)
    
    def forward(self, g, x):
        g = torch.squeeze(self.c1(g), 2)
        g = self.c2(g)
        g, _ = self.r(torch.transpose(g, 1, 2))
        
        x = self.d(x)

        out = self.head(torch.cat((g[:, -1, :], x), dim=1))

        # out = self.head(g[:, 0])

        return out

In [333]:
class GeneFeatureDataset(Dataset):
    
    def __init__(
        self,
        gene: torch.Tensor = None,
        features: torch.Tensor = None,
        target: torch.Tensor = None,
        fold: int = None,
        mode: str = 'train',
        fold_list: np.ndarray = None,
    ):
        self.fold = fold
        self.mode = mode
        self.fold_list = fold_list
        self.indices = self._select_fold()
        self.gene = gene[self.indices]
        self.features = features[self.indices]
        self.target = target[self.indices]

    def _select_fold(self):
        selected_indices = []

        if self.mode == 'train':
            for i in range(len(self.fold_list)):
                if self.fold_list[i] != self.fold:
                    selected_indices.append(i)
        elif self.mode == 'valid':
            for i in range(len(self.fold_list)):
                if self.fold_list[i] == self.fold:
                    selected_indices.append(i)
        else:
            for i in range(len(self.fold_list)):
                selected_indices.append(i)

        return selected_indices
    
    def __len__(self):
        return len(self.gene)
    
    def __getitem__(self, idx:int) -> Tuple[torch.Tensor, torch.Tensor]:
        gene = self.gene[idx]
        features = self.features[idx]
        target = self.target[idx]

        return gene, features, target




In [183]:
def preprocess_seq(data):
    print("Start preprocessing the sequence done 2d")
    length = 74

    DATA_X = np.zeros((len(data), 1, length, 4), dtype=float)
    print(np.shape(data), len(data), length)
    for l in tqdm(range(len(data))):
        for i in range(length):

            try:
                data[l][i]
            except:
                print(data[l], i, length, len(data))

            if data[l][i] in "Aa":
                DATA_X[l, 0, i, 0] = 1
            elif data[l][i] in "Cc":
                DATA_X[l, 0, i, 1] = 1
            elif data[l][i] in "Gg":
                DATA_X[l, 0, i, 2] = 1
            elif data[l][i] in "Tt":
                DATA_X[l, 0, i, 3] = 1
            elif data[l][i] in "Xx":
                DATA_X[l, 0, i, 0] = 0.5
                DATA_X[l, 0, i, 1] = 0.5
                DATA_X[l, 0, i, 2] = 0.5
                DATA_X[l, 0, i, 3] = 0.5
            else:
                print("Non-ATGC character " + data[l])
                print(i)
                print(data[l][i])
                sys.exit()

    print("Preprocessed the sequence")
    return DATA_X


In [143]:
if not os.path.isfile('data/g_train.npy'):
    wt_train = preprocess_seq(train_PECV.WT74_On)
    ed_train = preprocess_seq(train_PECV.Edited74_On)
    g_train = np.concatenate((wt_train, ed_train), axis=1)
    g_train = 2 * g_train - 1

    np.save('data/g_train.npy', g_train)
else:
    g_train = np.load('data/g_train.npy')


Start preprocessing the sequence done 2d
(259910,) 259910 74


100%|██████████| 259910/259910 [01:57<00:00, 2206.04it/s]


Preprocessed the sequence
Start preprocessing the sequence done 2d
(259910,) 259910 74


100%|██████████| 259910/259910 [02:48<00:00, 1539.94it/s]


Preprocessed the sequence


In [145]:
train_features = train_PECV.loc[:, ['PBSlen', 'RTlen', 'RT-PBSlen', 'Edit_pos', 'Edit_len', 'RHA_len', 'type_sub',
                                    'type_ins', 'type_del', 'Tm1', 'Tm2', 'Tm2new', 'Tm3', 'Tm4', 'TmD',
                                    'nGCcnt1', 'nGCcnt2', 'nGCcnt3', 'fGCcont1', 'fGCcont2', 'fGCcont3',
                                    'MFE1', 'MFE2', 'MFE3', 'MFE4', 'MFE5', 'DeepSpCas9_score']]
train_fold = train_PECV.Fold
train_target = train_PECV.Measured_PE_efficiency


In [146]:
x_train = (train_features - train_features.mean()) / train_features.std()
y_train = (train_target - train_target.mean()) / train_target.std()
x_train = x_train.to_numpy()
y_train = y_train.to_numpy()

In [148]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

g_train = torch.tensor(g_train, dtype=torch.float32, device=device)
x_train = torch.tensor(x_train, dtype=torch.float32, device=device)
y_train = torch.tensor(y_train, dtype=torch.float32, device=device)

In [234]:
x_train.shape

torch.Size([259910, 27])

In [374]:
# PARAMS

batch_size = 2048
learning_rate = 4e-3
weight_decay = 1e-2
T_0 = 10
T_mult = 1
hidden_size = 128
n_layers = 1
n_epochs = 10
n_models = 1

In [375]:
# TRAINING & VALIDATION
# preds = np.zeros((n_models, test_y.size(0)))

for m in range(n_models):

    random_seed = m

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)

    for fold in range(5):
        model = GeneInteractionModel(hidden_size=hidden_size, num_layers=n_layers).to(device)

        train_set = GeneFeatureDataset(g_train, x_train, y_train, fold, 'train', train_fold)
        valid_set = GeneFeatureDataset(g_train, x_train, y_train, fold, 'valid', train_fold)

        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=0)
        valid_loader = DataLoader(dataset=valid_set, batch_size=batch_size, shuffle=True, num_workers=0)

        criterion = nn.MSELoss()
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=T_0, T_mult=T_mult, eta_min=learning_rate/100)

        n_iters = len(train_loader)

        for epoch in tqdm(range(n_epochs)):
            train_loss, valid_loss = [], []
            train_count, valid_count = 0, 0

            model.train()

            for i, (g, x, y) in enumerate(train_loader):
                g = torch.permute(g, (0, 3, 1, 2))
                y = y.reshape(-1, 1)

                pred = model(g, x)
                loss = criterion(pred, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step(epoch + i / n_iters)

                train_loss.append(x.size(0) * loss.detach().cpu().numpy())
                train_count += x.size(0)

            model.eval()

            pred_, y_ = None, None

            with torch.no_grad():
                for i, (g, x, y) in enumerate(valid_loader):
                    g = torch.permute(g, (0, 3, 1, 2))
                    y = y.reshape(-1, 1)

                    pred = model(g, x)
                    loss = criterion(pred, y)
                    
                    valid_loss.append(x.size(0) * loss.detach().cpu().numpy())
                    valid_count += x.size(0)

                    if pred_ is None:
                        pred_ = pred.detach().cpu().numpy()
                        y_ = y.detach().cpu().numpy()
                    else:
                        pred_ = np.concatenate((pred_, pred.detach().cpu().numpy()))
                        y_ = np.concatenate((y_, y.detach().cpu().numpy()))


            train_loss = sum(train_loss) / train_count
            valid_loss = sum(valid_loss) / valid_count

            SPR = scipy.stats.spearmanr(pred_, y_).correlation

            print('[FOLD {:02}/{:02}] [M {:03}/{:03}] [E {:03}/{:03}] : {:.4f} | {:.4f} | {:.4f}'.format(fold + 1, 5, m + 1,
                n_models, epoch + 1, n_epochs, train_loss, valid_loss, SPR))

 10%|█         | 1/10 [00:03<00:28,  3.21s/it]

[FOLD 01/05] [M 001/001] [E 001/010] : 0.7853 | 0.6470 | 0.6436


 20%|██        | 2/10 [00:06<00:26,  3.26s/it]

[FOLD 01/05] [M 001/001] [E 002/010] : 0.6611 | 0.5893 | 0.6798


 30%|███       | 3/10 [00:09<00:22,  3.23s/it]

[FOLD 01/05] [M 001/001] [E 003/010] : 0.5940 | 0.5330 | 0.7139


 40%|████      | 4/10 [00:12<00:19,  3.25s/it]

[FOLD 01/05] [M 001/001] [E 004/010] : 0.5117 | 0.4659 | 0.7406


 50%|█████     | 5/10 [00:16<00:16,  3.23s/it]

[FOLD 01/05] [M 001/001] [E 005/010] : 0.4501 | 0.4385 | 0.7480


 60%|██████    | 6/10 [00:19<00:12,  3.24s/it]

[FOLD 01/05] [M 001/001] [E 006/010] : 0.4136 | 0.4259 | 0.7572


 70%|███████   | 7/10 [00:22<00:09,  3.24s/it]

[FOLD 01/05] [M 001/001] [E 007/010] : 0.3833 | 0.4222 | 0.7650


 80%|████████  | 8/10 [00:25<00:06,  3.25s/it]

[FOLD 01/05] [M 001/001] [E 008/010] : 0.3631 | 0.4210 | 0.7672


 90%|█████████ | 9/10 [00:29<00:03,  3.23s/it]

[FOLD 01/05] [M 001/001] [E 009/010] : 0.3503 | 0.4234 | 0.7666


100%|██████████| 10/10 [00:32<00:00,  3.23s/it]

[FOLD 01/05] [M 001/001] [E 010/010] : 0.3435 | 0.4222 | 0.7664



 10%|█         | 1/10 [00:03<00:29,  3.32s/it]

[FOLD 02/05] [M 001/001] [E 001/010] : 0.7642 | 0.6278 | 0.6575


 20%|██        | 2/10 [00:06<00:26,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 002/010] : 0.6525 | 0.5882 | 0.6702


 30%|███       | 3/10 [00:09<00:23,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 003/010] : 0.5830 | 0.5090 | 0.7158


 40%|████      | 4/10 [00:13<00:19,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 004/010] : 0.5034 | 0.4789 | 0.7469


 50%|█████     | 5/10 [00:16<00:16,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 005/010] : 0.4356 | 0.4419 | 0.7528


 60%|██████    | 6/10 [00:19<00:13,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 006/010] : 0.3962 | 0.4258 | 0.7668


 70%|███████   | 7/10 [00:23<00:09,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 007/010] : 0.3679 | 0.4248 | 0.7665


 80%|████████  | 8/10 [00:26<00:06,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 008/010] : 0.3483 | 0.4240 | 0.7669


 90%|█████████ | 9/10 [00:29<00:03,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 009/010] : 0.3334 | 0.4216 | 0.7708


100%|██████████| 10/10 [00:32<00:00,  3.30s/it]

[FOLD 02/05] [M 001/001] [E 010/010] : 0.3267 | 0.4230 | 0.7705



 10%|█         | 1/10 [00:03<00:30,  3.34s/it]

[FOLD 03/05] [M 001/001] [E 001/010] : 0.7744 | 0.6488 | 0.6502


 20%|██        | 2/10 [00:06<00:26,  3.34s/it]

[FOLD 03/05] [M 001/001] [E 002/010] : 0.6464 | 0.6232 | 0.6932


 30%|███       | 3/10 [00:09<00:23,  3.30s/it]

[FOLD 03/05] [M 001/001] [E 003/010] : 0.5892 | 0.5302 | 0.7241


 40%|████      | 4/10 [00:13<00:19,  3.33s/it]

[FOLD 03/05] [M 001/001] [E 004/010] : 0.5177 | 0.4768 | 0.7338


 50%|█████     | 5/10 [00:16<00:16,  3.35s/it]

[FOLD 03/05] [M 001/001] [E 005/010] : 0.4482 | 0.4554 | 0.7375


 60%|██████    | 6/10 [00:20<00:13,  3.34s/it]

[FOLD 03/05] [M 001/001] [E 006/010] : 0.4089 | 0.4390 | 0.7489


 70%|███████   | 7/10 [00:23<00:10,  3.34s/it]

[FOLD 03/05] [M 001/001] [E 007/010] : 0.3814 | 0.4260 | 0.7552


 80%|████████  | 8/10 [00:26<00:06,  3.36s/it]

[FOLD 03/05] [M 001/001] [E 008/010] : 0.3597 | 0.4229 | 0.7596


 90%|█████████ | 9/10 [00:30<00:03,  3.37s/it]

[FOLD 03/05] [M 001/001] [E 009/010] : 0.3478 | 0.4172 | 0.7629


100%|██████████| 10/10 [00:33<00:00,  3.35s/it]

[FOLD 03/05] [M 001/001] [E 010/010] : 0.3406 | 0.4186 | 0.7618



 10%|█         | 1/10 [00:03<00:29,  3.25s/it]

[FOLD 04/05] [M 001/001] [E 001/010] : 0.7621 | 0.6725 | 0.6622


 20%|██        | 2/10 [00:06<00:26,  3.29s/it]

[FOLD 04/05] [M 001/001] [E 002/010] : 0.6234 | 0.5870 | 0.7150


 30%|███       | 3/10 [00:09<00:23,  3.31s/it]

[FOLD 04/05] [M 001/001] [E 003/010] : 0.5389 | 0.5389 | 0.7100


 40%|████      | 4/10 [00:13<00:19,  3.32s/it]

[FOLD 04/05] [M 001/001] [E 004/010] : 0.4619 | 0.4912 | 0.7498


 50%|█████     | 5/10 [00:16<00:16,  3.29s/it]

[FOLD 04/05] [M 001/001] [E 005/010] : 0.4162 | 0.5010 | 0.7582


 60%|██████    | 6/10 [00:19<00:13,  3.30s/it]

[FOLD 04/05] [M 001/001] [E 006/010] : 0.3833 | 0.4610 | 0.7633


 70%|███████   | 7/10 [00:23<00:09,  3.28s/it]

[FOLD 04/05] [M 001/001] [E 007/010] : 0.3581 | 0.4547 | 0.7642


 80%|████████  | 8/10 [00:26<00:06,  3.29s/it]

[FOLD 04/05] [M 001/001] [E 008/010] : 0.3398 | 0.4508 | 0.7675


 90%|█████████ | 9/10 [00:29<00:03,  3.29s/it]

[FOLD 04/05] [M 001/001] [E 009/010] : 0.3283 | 0.4518 | 0.7689


100%|██████████| 10/10 [00:32<00:00,  3.29s/it]

[FOLD 04/05] [M 001/001] [E 010/010] : 0.3211 | 0.4516 | 0.7704



 10%|█         | 1/10 [00:03<00:30,  3.44s/it]

[FOLD 05/05] [M 001/001] [E 001/010] : 0.7728 | 0.5593 | 0.6356


 20%|██        | 2/10 [00:06<00:27,  3.40s/it]

[FOLD 05/05] [M 001/001] [E 002/010] : 0.6389 | 0.4985 | 0.6656


 30%|███       | 3/10 [00:10<00:23,  3.42s/it]

[FOLD 05/05] [M 001/001] [E 003/010] : 0.5540 | 0.4489 | 0.6867


 40%|████      | 4/10 [00:13<00:20,  3.43s/it]

[FOLD 05/05] [M 001/001] [E 004/010] : 0.4874 | 0.4323 | 0.7007


 50%|█████     | 5/10 [00:17<00:17,  3.40s/it]

[FOLD 05/05] [M 001/001] [E 005/010] : 0.4281 | 0.4052 | 0.7190


 60%|██████    | 6/10 [00:20<00:13,  3.39s/it]

[FOLD 05/05] [M 001/001] [E 006/010] : 0.3911 | 0.4012 | 0.7165


 70%|███████   | 7/10 [00:23<00:10,  3.38s/it]

[FOLD 05/05] [M 001/001] [E 007/010] : 0.3644 | 0.3973 | 0.7281


 80%|████████  | 8/10 [00:27<00:06,  3.39s/it]

[FOLD 05/05] [M 001/001] [E 008/010] : 0.3453 | 0.3912 | 0.7257


 90%|█████████ | 9/10 [00:30<00:03,  3.40s/it]

[FOLD 05/05] [M 001/001] [E 009/010] : 0.3318 | 0.3905 | 0.7268


100%|██████████| 10/10 [00:33<00:00,  3.40s/it]

[FOLD 05/05] [M 001/001] [E 010/010] : 0.3251 | 0.3915 | 0.7276



