In [2]:
import sys
import os
import math

import numpy as np
from numpy.random import shuffle
import scipy
import pandas as pd

from typing import Tuple

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.sampler import Sampler, SequentialSampler
from torch.backends import cudnn
from sklearn.model_selection import train_test_split, KFold

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import wandb

# Use the first CUDA device when PyTorch can see one, otherwise run on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner so
# that seeded runs are as repeatable as the backend allows.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [3]:
class GeneInteractionModel(nn.Module):
    """Regressor that fuses a convolutional/recurrent sequence branch with an MLP.

    Sequence branch: ``c1`` (Conv2d, kernel height 2) -> ``c2`` (three
    Conv1d/GELU/AvgPool1d stages) -> ``r`` (bidirectional GRU) -> ``s``
    (linear projection to 24 values).
    Feature branch: ``d`` maps 27 tabular features to 64 values.
    ``head`` turns the concatenated 24 + 64 = 88 values into one prediction.

    Args:
        hidden_size: Number of GRU hidden units per direction.
        num_layers: Number of stacked GRU layers.

    Note:
        Attribute names, positions inside each ``nn.Sequential`` and the order
        in which layers are created are part of the contract: they determine
        the ``state_dict`` keys and the seeded weight initialisation.
    """

    def __init__(self, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Kernel height 2 with no vertical padding reduces a height-2 input
        # to height 1, which forward() then squeezes away.
        self.c1 = nn.Sequential(
            nn.Conv2d(4, 72, kernel_size=(2, 3), stride=1, padding=(0, 1)),
            nn.GELU(),
        )

        # Each stage keeps the length (padding=1) and then halves it by pooling.
        self.c2 = nn.Sequential(
            nn.Conv1d(72, 64, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AvgPool1d(kernel_size=2, stride=2),

            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AvgPool1d(kernel_size=2, stride=2),

            nn.Conv1d(64, 96, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
        )

        self.r = nn.GRU(
            input_size=96,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Both GRU directions are concatenated, hence 2 * hidden_size inputs.
        self.s = nn.Linear(2 * hidden_size, 24, bias=False)

        self.d = nn.Sequential(
            nn.Linear(27, 96, bias=False),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(96, 32, bias=False),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 64, bias=False),
        )

        self.head = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(88, 1, bias=True),
        )

    def forward(self, g, x):
        """Predict one value per sample.

        Args:
            g: Sequence tensor laid out as (batch, 4, 2, length); the training
                code produces it by permuting a (batch, 2, length, 4) encoding.
            x: Feature tensor of shape (batch, 27).

        Returns:
            Tensor of shape (batch, 1).
        """
        seq = self.c1(g).squeeze(2)
        seq = self.c2(seq)

        # The GRU is batch-first, so channels move to the last axis.
        rnn_out, _ = self.r(seq.transpose(1, 2))
        # Only the output at the final time step is used.
        # NOTE(review): for the backward direction this is the state after a
        # single step - confirm that this is intended.
        seq_embedding = self.s(rnn_out[:, -1, :])

        feature_embedding = self.d(x)

        joint = torch.cat((seq_embedding, feature_embedding), dim=1)
        return self.head(joint)


In [4]:
class GeneFeatureDataset(Dataset):
    """Dataset of (gene encoding, feature vector, target) triples.

    When ``fold_list`` is given, the three tensors are restricted to the rows
    selected by ``fold`` and ``mode``; otherwise they are stored unchanged.

    Args:
        gene: Sequence encodings, indexed along the first axis.
        features: Per-sample feature vectors, indexed along the first axis.
        target: Per-sample regression targets.
        fold: Fold id that ``mode`` refers to.
        mode: ``'valid'`` keeps rows whose fold id equals ``fold``;
            ``'train'`` keeps all other rows; any other value keeps every row.
        fold_list: Fold id of every row (array, list or pandas Series). Rows
            are matched by position, so a Series index is ignored.
    """

    def __init__(
        self,
        gene: torch.Tensor = None,
        features: torch.Tensor = None,
        target: torch.Tensor = None,
        fold: int = None,
        mode: str = 'train',
        fold_list: np.ndarray = None,
    ):
        self.fold = fold
        self.mode = mode
        self.fold_list = fold_list

        if self.fold_list is not None:
            self.indices = self._select_fold()
            self.gene = gene[self.indices]
            self.features = features[self.indices]
            self.target = target[self.indices]
        else:
            self.gene = gene
            self.features = features
            self.target = target

    def _select_fold(self):
        """Return the positional row indices selected by ``fold`` and ``mode``."""
        # np.asarray drops any pandas index, so the comparison below is purely
        # positional and also works for a Series that was filtered or shuffled.
        fold_ids = np.asarray(self.fold_list)

        if self.mode == 'valid':  # SELECT A SINGLE GROUP
            keep = fold_ids == self.fold
        elif self.mode == 'train':  # SELECT OTHERS
            keep = fold_ids != self.fold
        else:  # FOR FINALIZING
            keep = np.ones(len(fold_ids), dtype=bool)

        # A plain list of Python ints, the same type the per-row loop produced.
        return np.flatnonzero(keep).tolist()

    def __len__(self):
        return len(self.gene)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the ``(gene, features, target)`` triple stored at ``idx``."""
        return self.gene[idx], self.features[idx], self.target[idx]


In [5]:
def preprocess_seq(data, length=74):
    """One-hot encode nucleotide sequences.

    Args:
        data: Sized iterable of strings (list, array or pandas Series). Items
            are read by position, so a Series index is ignored.
        length: Number of leading characters encoded per sequence; anything
            beyond it is ignored. Defaults to 74.

    Returns:
        Float array of shape ``(len(data), 1, length, 4)``. The last axis is
        ordered A, C, G, T (case-insensitive); ``X``/``x`` is encoded as 0.5
        in all four channels.

    Raises:
        ValueError: If an item is not a string, is shorter than ``length``, or
            contains a character other than A, C, G, T or X.
    """
    print("Start preprocessing the sequence done 2d")

    encoding = {
        'A': (1, 0, 0, 0),
        'C': (0, 1, 0, 0),
        'G': (0, 0, 1, 0),
        'T': (0, 0, 0, 1),
        'X': (0.5, 0.5, 0.5, 0.5),
    }
    # Lower-case bases map to the same vectors as their upper-case form.
    encoding.update({base.lower(): code for base, code in list(encoding.items())})

    sequences = list(data)
    encoded = np.zeros((len(sequences), 1, length, 4), dtype=float)
    print(np.shape(data), len(data), length)

    for row in tqdm(range(len(sequences))):
        seq = sequences[row]

        # Fail with a precise message instead of printing and then crashing
        # on the next lookup.
        if not isinstance(seq, str):
            raise ValueError(
                "Sequence {} is not a string: {!r}".format(row, seq))
        if len(seq) < length:
            raise ValueError(
                "Sequence {} has {} characters, expected at least {}: {}".format(
                    row, len(seq), length, seq))

        for pos in range(length):
            code = encoding.get(seq[pos])
            if code is None:
                # Raising keeps the kernel alive and tells the caller what to
                # fix; sys.exit() would only surface as a bare SystemExit.
                raise ValueError(
                    "Non-ATGC character {!r} at position {} of sequence {}: {}".format(
                        seq[pos], pos, row, seq))
            encoded[row, 0, pos] = code

    print("Preprocessed the sequence")
    return encoded


In [6]:
def seq_concat(data):
    """Encode the wild-type and edited sequences of ``data`` as one array.

    ``data.WT74_On`` and ``data.Edited74_On`` are each encoded with
    ``preprocess_seq`` and joined along axis 1 (wild type first). The result
    is then mapped through ``2 * v - 1``, so 0 becomes -1, 0.5 becomes 0 and
    1 stays 1.
    """
    channels = (
        preprocess_seq(data.WT74_On),
        preprocess_seq(data.Edited74_On),
    )
    stacked = np.concatenate(channels, axis=1)

    return 2 * stacked - 1


In [7]:
# Source tables, resolved relative to the notebook's working directory.
pecv_train_csv = 'data/DeepPrime_PECV__train_220214.csv'
pecv_test_csv = 'data/DeepPrime_PECV__test_220214.csv'
# NOTE(review): this file name says "test" while the frame is called
# train_PF - confirm the naming is intentional.
profiling_csv = 'data/DeepPrime_on_target_test_wProfiling_220209.csv'

train_PECV = pd.read_csv(pecv_train_csv)
test_PECV = pd.read_csv(pecv_test_csv)
train_PF = pd.read_csv(profiling_csv)


In [8]:
# PREPROCESS GENES

def _load_or_build_encoding(cache_path, frame):
    """Return the sequence encoding for ``frame``, cached at ``cache_path``.

    An existing file is loaded as-is; otherwise the encoding is built with
    ``seq_concat`` and saved. Only the file's existence is checked, so a stale
    cache must be deleted by hand after the source table changes.
    """
    if os.path.isfile(cache_path):
        return np.load(cache_path)

    encoding = seq_concat(frame)
    np.save(cache_path, encoding)
    return encoding


g_train = _load_or_build_encoding('data/g_train.npy', train_PECV)
g_test = _load_or_build_encoding('data/g_test.npy', test_PECV)
g_pf = _load_or_build_encoding('data/g_pf.npy', train_PF)


In [9]:
# FEATURE SELECTION

# The 27 columns fed to the feature branch of the model, in input order.
# The same list is applied to all three tables so their columns line up.
feature_columns = [
    'PBSlen', 'RTlen', 'RT-PBSlen', 'Edit_pos', 'Edit_len', 'RHA_len',
    'type_sub', 'type_ins', 'type_del',
    'Tm1', 'Tm2', 'Tm2new', 'Tm3', 'Tm4', 'TmD',
    'nGCcnt1', 'nGCcnt2', 'nGCcnt3',
    'fGCcont1', 'fGCcont2', 'fGCcont3',
    'MFE1', 'MFE2', 'MFE3', 'MFE4', 'MFE5',
    'DeepSpCas9_score',
]

train_features = train_PECV.loc[:, feature_columns]
train_fold = train_PECV.Fold  # per-row fold id used for cross-validation
train_target = train_PECV.Measured_PE_efficiency

test_features = test_PECV.loc[:, feature_columns]
test_target = test_PECV.Measured_PE_efficiency

pf_features = train_PF.loc[:, feature_columns]
pf_target = train_PF.Measured_PE_efficiency


In [10]:
# NORMALIZATION

# Every split is z-scored with the statistics of the PECV training split.
feature_mean = train_features.mean()
feature_std = train_features.std()
target_mean = train_target.mean()
target_std = train_target.std()


def _standardize(values, mean, std):
    """Z-score ``values`` with the given statistics and return a NumPy array."""
    return ((values - mean) / std).to_numpy()


def _as_device_tensor(array):
    """Copy ``array`` into a float32 tensor that lives on ``device``."""
    return torch.tensor(array, dtype=torch.float32, device=device)


x_train = _standardize(train_features, feature_mean, feature_std)
y_train = _standardize(train_target, target_mean, target_std)

x_test = _standardize(test_features, feature_mean, feature_std)
y_test = _standardize(test_target, target_mean, target_std)

x_pf = _standardize(pf_features, feature_mean, feature_std)
y_pf = _standardize(pf_target, target_mean, target_std)

# NOTE: the g_* names are rebound from NumPy arrays to tensors here.
g_train = _as_device_tensor(g_train)
x_train = _as_device_tensor(x_train)
y_train = _as_device_tensor(y_train)

g_test = _as_device_tensor(g_test)
x_test = _as_device_tensor(x_test)
y_test = _as_device_tensor(y_test)

g_pf = _as_device_tensor(g_pf)
x_pf = _as_device_tensor(x_pf)
y_pf = _as_device_tensor(y_pf)


In [11]:
# PARAMS
# Settings for the main cross-validated training run. finetune_model defines
# its own local learning_rate, weight_decay, T_0, T_mult and n_epochs, so the
# values below do not apply to fine-tuning.

batch_size = 2048  # samples per mini-batch for the DataLoaders
learning_rate = 4e-3  # initial AdamW learning rate
weight_decay = 1e-2  # AdamW weight decay
T_0 = 12  # CosineAnnealingWarmRestarts: epochs until the first restart
T_mult = 1  # CosineAnnealingWarmRestarts: period multiplier after a restart
hidden_size = 128  # GRU hidden units per direction
n_layers = 1  # number of stacked GRU layers
n_epochs = 10  # training epochs per fold
n_models = 1  # number of seeded passes over the folds (pass m uses seed m)

In [12]:
def finetune_model(model, fold, pf_loader, valid_loader):
    """Fine-tune ``model`` on ``pf_loader`` and report validation metrics.

    All parameters stay trainable. The model is trained for five epochs with
    AdamW and a cosine warm-restart schedule; after each epoch the
    sample-weighted training loss, the validation loss and the Spearman
    correlation on ``valid_loader`` are printed.

    Args:
        model: Network called as ``model(g, x)``; updated in place.
        fold: Zero-based fold index, used only in the progress message.
        pf_loader: Loader yielding ``(g, x, y)`` fine-tuning batches.
        valid_loader: Loader yielding ``(g, x, y)`` validation batches.

    Returns:
        The same ``model`` object, in the state reached after the last epoch.
    """
    # PARAMETERS FOR FINETUNING
    # Local on purpose: they shadow the notebook-level training settings.
    learning_rate, weight_decay = 2e-5, 1e-3
    T_0, T_mult = 5, 1
    n_epochs = 5

    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=T_0, T_mult=T_mult, eta_min=learning_rate/100)

    n_iters = len(pf_loader)

    for epoch in tqdm(range(n_epochs)):
        weighted_train, weighted_valid = [], []
        seen_train = seen_valid = 0

        model.train()

        for step, (g, x, y) in enumerate(pf_loader):
            batch = x.size(0)

            # Loaders yield (batch, 2, length, 4); the model wants channels first.
            pred = model(torch.permute(g, (0, 3, 1, 2)), x)
            loss = criterion(pred, y.reshape(-1, 1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Fractional epochs advance the cosine schedule once per batch.
            scheduler.step(epoch + step / n_iters)

            # Weight by batch size so the epoch mean is per sample.
            weighted_train.append(batch * loss.detach().cpu().numpy())
            seen_train += batch

        model.eval()

        pred_chunks, true_chunks = [], []

        with torch.no_grad():
            for g, x, y in valid_loader:
                batch = x.size(0)
                y = y.reshape(-1, 1)

                pred = model(torch.permute(g, (0, 3, 1, 2)), x)
                loss = criterion(pred, y)

                weighted_valid.append(batch * loss.detach().cpu().numpy())
                seen_valid += batch

                pred_chunks.append(pred.detach().cpu().numpy())
                true_chunks.append(y.detach().cpu().numpy())

        train_loss = sum(weighted_train) / seen_train
        valid_loss = sum(weighted_valid) / seen_valid

        pred_ = np.concatenate(pred_chunks)
        y_ = np.concatenate(true_chunks)
        SPR = scipy.stats.spearmanr(pred_, y_).correlation

        print('FINETUNING: [FOLD {:02}/{:02}] [E {:03}/{:03}] : {:.4f} | {:.4f} | {:.4f}'.format(
            fold + 1, 5, epoch + 1, n_epochs, train_loss, valid_loss, SPR))

    return model


In [13]:
# TRAINING & VALIDATION
# For every seed and fold: train a fresh model on the other folds, evaluate on
# the held-out fold after each epoch, and keep the weights of the epoch with
# the highest validation Spearman correlation. The kept checkpoint is renamed
# so that its file name records the train loss, valid loss and Spearman score.

n_folds = 5

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)

for m in range(n_models):

    random_seed = m

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)

    for fold in range(n_folds):

        # [train loss, valid loss, Spearman] of the best epoch so far. Only an
        # epoch with a Spearman score above 0 can replace the initial value.
        best_score = [0, 0, 0]
        checkpoint_saved = False
        auxiliary_path = 'models/F{:02}_auxiliary.pt'.format(fold + 1)

        model = GeneInteractionModel(
            hidden_size=hidden_size, num_layers=n_layers).to(device)

        train_set = GeneFeatureDataset(
            g_train, x_train, y_train, fold, 'train', train_fold)
        valid_set = GeneFeatureDataset(
            g_train, x_train, y_train, fold, 'valid', train_fold)
        # Only needed by the optional in-loop fine-tuning at the end of the fold.
        pf_set = GeneFeatureDataset(g_pf, x_pf, y_pf, None, 'train', None)

        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=0)
        valid_loader = DataLoader(
            dataset=valid_set, batch_size=batch_size, shuffle=True, num_workers=0)
        pf_loader = DataLoader(
            dataset=pf_set, batch_size=batch_size, shuffle=True, num_workers=0)

        criterion = nn.MSELoss()
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=T_0, T_mult=T_mult, eta_min=learning_rate/100)

        n_iters = len(train_loader)

        for epoch in tqdm(range(n_epochs)):
            train_loss, valid_loss = [], []
            train_count, valid_count = 0, 0

            model.train()

            for i, (g, x, y) in enumerate(train_loader):
                # Loaders yield (batch, 2, length, 4); the model wants channels first.
                g = torch.permute(g, (0, 3, 1, 2))
                y = y.reshape(-1, 1)

                pred = model(g, x)
                loss = criterion(pred, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Fractional epochs advance the cosine schedule once per batch.
                scheduler.step(epoch + i / n_iters)

                # Weight by batch size so the epoch mean is per sample.
                train_loss.append(x.size(0) * loss.detach().cpu().numpy())
                train_count += x.size(0)

            model.eval()

            pred_, y_ = None, None

            with torch.no_grad():
                for i, (g, x, y) in enumerate(valid_loader):
                    g = torch.permute(g, (0, 3, 1, 2))
                    y = y.reshape(-1, 1)

                    pred = model(g, x)
                    loss = criterion(pred, y)

                    valid_loss.append(x.size(0) * loss.detach().cpu().numpy())
                    valid_count += x.size(0)

                    if pred_ is None:
                        pred_ = pred.detach().cpu().numpy()
                        y_ = y.detach().cpu().numpy()
                    else:
                        pred_ = np.concatenate(
                            (pred_, pred.detach().cpu().numpy()))
                        y_ = np.concatenate((y_, y.detach().cpu().numpy()))

            train_loss = sum(train_loss) / train_count
            valid_loss = sum(valid_loss) / valid_count

            SPR = scipy.stats.spearmanr(pred_, y_).correlation

            if SPR > best_score[2]:
                best_score = [train_loss, valid_loss, SPR]

                torch.save(model.state_dict(), auxiliary_path)
                checkpoint_saved = True


            print('[FOLD {:02}/{:02}] [M {:03}/{:03}] [E {:03}/{:03}] : {:.4f} | {:.4f} | {:.4f}'.format(fold + 1, n_folds, m + 1,
                                                                                                         n_models, epoch + 1, n_epochs, train_loss, valid_loss, SPR))

        if checkpoint_saved:
            # os.replace overwrites an existing target on every platform, so
            # re-running a fold with identical scores does not fail.
            os.replace(auxiliary_path, 'models/F{:02}_T{:.4f}_V{:.4f}_S{:.4f}.pt'.format(fold + 1, *best_score))
        else:
            # No epoch had a Spearman score above 0 (or it was NaN), so no
            # auxiliary file exists and there is nothing to rename.
            print('[FOLD {:02}/{:02}] [M {:03}/{:03}] no epoch reached a positive Spearman correlation; no checkpoint saved'.format(
                fold + 1, n_folds, m + 1, n_models))

        # finetune_model(model, fold, pf_loader, valid_loader)


 10%|█         | 1/10 [00:03<00:27,  3.01s/it]

[FOLD 01/05] [M 001/001] [E 001/010] : 0.7719 | 0.6546 | 0.6841


 20%|██        | 2/10 [00:05<00:23,  2.99s/it]

[FOLD 01/05] [M 001/001] [E 002/010] : 0.6669 | 0.5923 | 0.6851


 30%|███       | 3/10 [00:08<00:20,  2.95s/it]

[FOLD 01/05] [M 001/001] [E 003/010] : 0.5929 | 0.5261 | 0.7124


 40%|████      | 4/10 [00:11<00:17,  2.95s/it]

[FOLD 01/05] [M 001/001] [E 004/010] : 0.5139 | 0.4610 | 0.7418


 50%|█████     | 5/10 [00:14<00:14,  2.96s/it]

[FOLD 01/05] [M 001/001] [E 005/010] : 0.4545 | 0.4436 | 0.7598


 60%|██████    | 6/10 [00:17<00:11,  2.96s/it]

[FOLD 01/05] [M 001/001] [E 006/010] : 0.4182 | 0.4366 | 0.7591


 70%|███████   | 7/10 [00:20<00:08,  2.95s/it]

[FOLD 01/05] [M 001/001] [E 007/010] : 0.3900 | 0.4193 | 0.7627


 80%|████████  | 8/10 [00:23<00:05,  2.96s/it]

[FOLD 01/05] [M 001/001] [E 008/010] : 0.3677 | 0.4155 | 0.7733


 90%|█████████ | 9/10 [00:26<00:02,  2.97s/it]

[FOLD 01/05] [M 001/001] [E 009/010] : 0.3525 | 0.4144 | 0.7714


100%|██████████| 10/10 [00:29<00:00,  2.96s/it]

[FOLD 01/05] [M 001/001] [E 010/010] : 0.3360 | 0.4268 | 0.7702



 10%|█         | 1/10 [00:03<00:27,  3.01s/it]

[FOLD 02/05] [M 001/001] [E 001/010] : 0.7597 | 0.6339 | 0.6614


 20%|██        | 2/10 [00:06<00:24,  3.02s/it]

[FOLD 02/05] [M 001/001] [E 002/010] : 0.6457 | 0.5795 | 0.7014


 30%|███       | 3/10 [00:09<00:20,  2.99s/it]

[FOLD 02/05] [M 001/001] [E 003/010] : 0.5718 | 0.5208 | 0.7124


 40%|████      | 4/10 [00:11<00:17,  2.99s/it]

[FOLD 02/05] [M 001/001] [E 004/010] : 0.5124 | 0.4764 | 0.7364


 50%|█████     | 5/10 [00:14<00:14,  2.97s/it]

[FOLD 02/05] [M 001/001] [E 005/010] : 0.4563 | 0.4485 | 0.7523


 60%|██████    | 6/10 [00:17<00:11,  2.98s/it]

[FOLD 02/05] [M 001/001] [E 006/010] : 0.4201 | 0.4380 | 0.7516


 70%|███████   | 7/10 [00:20<00:08,  2.98s/it]

[FOLD 02/05] [M 001/001] [E 007/010] : 0.3920 | 0.4202 | 0.7650


 80%|████████  | 8/10 [00:23<00:05,  2.98s/it]

[FOLD 02/05] [M 001/001] [E 008/010] : 0.3715 | 0.4275 | 0.7581


 90%|█████████ | 9/10 [00:26<00:02,  2.99s/it]

[FOLD 02/05] [M 001/001] [E 009/010] : 0.3531 | 0.4211 | 0.7650


100%|██████████| 10/10 [00:29<00:00,  2.99s/it]

[FOLD 02/05] [M 001/001] [E 010/010] : 0.3387 | 0.4194 | 0.7697



 10%|█         | 1/10 [00:03<00:27,  3.02s/it]

[FOLD 03/05] [M 001/001] [E 001/010] : 0.7646 | 0.6602 | 0.6671


 20%|██        | 2/10 [00:05<00:23,  2.98s/it]

[FOLD 03/05] [M 001/001] [E 002/010] : 0.6581 | 0.5812 | 0.6858


 30%|███       | 3/10 [00:09<00:21,  3.00s/it]

[FOLD 03/05] [M 001/001] [E 003/010] : 0.5647 | 0.5149 | 0.7215


 40%|████      | 4/10 [00:12<00:18,  3.02s/it]

[FOLD 03/05] [M 001/001] [E 004/010] : 0.4828 | 0.4593 | 0.7327


 50%|█████     | 5/10 [00:15<00:15,  3.02s/it]

[FOLD 03/05] [M 001/001] [E 005/010] : 0.4350 | 0.4406 | 0.7533


 60%|██████    | 6/10 [00:18<00:11,  3.00s/it]

[FOLD 03/05] [M 001/001] [E 006/010] : 0.3995 | 0.4346 | 0.7607


 70%|███████   | 7/10 [00:21<00:08,  3.00s/it]

[FOLD 03/05] [M 001/001] [E 007/010] : 0.3756 | 0.4256 | 0.7643


 80%|████████  | 8/10 [00:24<00:06,  3.00s/it]

[FOLD 03/05] [M 001/001] [E 008/010] : 0.3519 | 0.4230 | 0.7632


 90%|█████████ | 9/10 [00:26<00:02,  2.99s/it]

[FOLD 03/05] [M 001/001] [E 009/010] : 0.3336 | 0.4255 | 0.7632


100%|██████████| 10/10 [00:30<00:00,  3.00s/it]

[FOLD 03/05] [M 001/001] [E 010/010] : 0.3195 | 0.4223 | 0.7673



 10%|█         | 1/10 [00:03<00:27,  3.07s/it]

[FOLD 04/05] [M 001/001] [E 001/010] : 0.7557 | 0.7025 | 0.6977


 20%|██        | 2/10 [00:06<00:24,  3.05s/it]

[FOLD 04/05] [M 001/001] [E 002/010] : 0.6490 | 0.6281 | 0.6880


 30%|███       | 3/10 [00:09<00:21,  3.05s/it]

[FOLD 04/05] [M 001/001] [E 003/010] : 0.5720 | 0.5585 | 0.7172


 40%|████      | 4/10 [00:12<00:18,  3.01s/it]

[FOLD 04/05] [M 001/001] [E 004/010] : 0.5072 | 0.5207 | 0.7351


 50%|█████     | 5/10 [00:15<00:15,  3.02s/it]

[FOLD 04/05] [M 001/001] [E 005/010] : 0.4475 | 0.4736 | 0.7531


 60%|██████    | 6/10 [00:18<00:12,  3.04s/it]

[FOLD 04/05] [M 001/001] [E 006/010] : 0.4090 | 0.4606 | 0.7631


 70%|███████   | 7/10 [00:21<00:09,  3.04s/it]

[FOLD 04/05] [M 001/001] [E 007/010] : 0.3826 | 0.4510 | 0.7735


 80%|████████  | 8/10 [00:24<00:06,  3.03s/it]

[FOLD 04/05] [M 001/001] [E 008/010] : 0.3578 | 0.4524 | 0.7749


 90%|█████████ | 9/10 [00:27<00:03,  3.01s/it]

[FOLD 04/05] [M 001/001] [E 009/010] : 0.3413 | 0.4427 | 0.7758


100%|██████████| 10/10 [00:30<00:00,  3.03s/it]

[FOLD 04/05] [M 001/001] [E 010/010] : 0.3272 | 0.4441 | 0.7738



 10%|█         | 1/10 [00:03<00:27,  3.11s/it]

[FOLD 05/05] [M 001/001] [E 001/010] : 0.7677 | 0.5695 | 0.6365


 20%|██        | 2/10 [00:06<00:24,  3.10s/it]

[FOLD 05/05] [M 001/001] [E 002/010] : 0.6518 | 0.5017 | 0.6670


 30%|███       | 3/10 [00:09<00:21,  3.09s/it]

[FOLD 05/05] [M 001/001] [E 003/010] : 0.5596 | 0.4485 | 0.6862


 40%|████      | 4/10 [00:12<00:18,  3.09s/it]

[FOLD 05/05] [M 001/001] [E 004/010] : 0.4889 | 0.4164 | 0.6972


 50%|█████     | 5/10 [00:15<00:15,  3.10s/it]

[FOLD 05/05] [M 001/001] [E 005/010] : 0.4360 | 0.3970 | 0.7201


 60%|██████    | 6/10 [00:18<00:12,  3.08s/it]

[FOLD 05/05] [M 001/001] [E 006/010] : 0.4071 | 0.3858 | 0.7250


 70%|███████   | 7/10 [00:21<00:09,  3.08s/it]

[FOLD 05/05] [M 001/001] [E 007/010] : 0.3821 | 0.3835 | 0.7311


 80%|████████  | 8/10 [00:24<00:06,  3.09s/it]

[FOLD 05/05] [M 001/001] [E 008/010] : 0.3616 | 0.3823 | 0.7185


 90%|█████████ | 9/10 [00:27<00:03,  3.09s/it]

[FOLD 05/05] [M 001/001] [E 009/010] : 0.3453 | 0.3899 | 0.7259


100%|██████████| 10/10 [00:30<00:00,  3.09s/it]

[FOLD 05/05] [M 001/001] [E 010/010] : 0.3324 | 0.3775 | 0.7306





In [16]:
# MODEL FINE TUNING
# Load the checkpoint selected for each fold, fine-tune it on the profiling
# data, and save the result under models/final/.

# One checkpoint per fold. The file names embed the train loss, validation
# loss and Spearman score of the saved epoch, so this list has to be updated
# whenever the folds are re-trained.
models = ['models/F01_T0.3677_V0.4155_S0.7733.pt', 'models/F02_T0.3387_V0.4194_S0.7697.pt',
          'models/F03_T0.3195_V0.4223_S0.7673.pt', 'models/F04_T0.3413_V0.4427_S0.7758.pt', 'models/F05_T0.3821_V0.3835_S0.7311.pt']

# torch.save does not create missing directories.
os.makedirs('models/final', exist_ok=True)

pf_set = GeneFeatureDataset(g_pf, x_pf, y_pf, None, 'train', None)
pf_loader = DataLoader(
    dataset=pf_set, batch_size=batch_size, shuffle=True, num_workers=0)

for fold in range(5):
    # The held-out fold of the original training run is used for monitoring.
    valid_set = GeneFeatureDataset(
        g_train, x_train, y_train, fold, 'valid', train_fold)
    valid_loader = DataLoader(
        dataset=valid_set, batch_size=batch_size, shuffle=True, num_workers=0)

    model = GeneInteractionModel(
            hidden_size=hidden_size, num_layers=n_layers).to(device)

    # map_location lets a checkpoint written on a GPU be restored on a
    # machine where `device` is the CPU.
    model.load_state_dict(torch.load(models[fold], map_location=device))

    model = finetune_model(model, fold, pf_loader, valid_loader)

    torch.save(model.state_dict(), 'models/final/F{:02}.pt'.format(fold + 1))

 20%|██        | 1/5 [00:01<00:04,  1.25s/it]

FINETUNING: [FOLD 01/05] [E 001/005] : 0.9098 | 0.4304 | 0.7748


 40%|████      | 2/5 [00:02<00:03,  1.25s/it]

FINETUNING: [FOLD 01/05] [E 002/005] : 0.8909 | 0.4287 | 0.7753


 60%|██████    | 3/5 [00:03<00:02,  1.25s/it]

FINETUNING: [FOLD 01/05] [E 003/005] : 0.8786 | 0.4302 | 0.7755


 80%|████████  | 4/5 [00:04<00:01,  1.25s/it]

FINETUNING: [FOLD 01/05] [E 004/005] : 0.8702 | 0.4304 | 0.7755


100%|██████████| 5/5 [00:06<00:00,  1.25s/it]

FINETUNING: [FOLD 01/05] [E 005/005] : 0.8687 | 0.4307 | 0.7755



 20%|██        | 1/5 [00:01<00:04,  1.21s/it]

FINETUNING: [FOLD 02/05] [E 001/005] : 0.9079 | 0.4368 | 0.7716


 40%|████      | 2/5 [00:02<00:03,  1.17s/it]

FINETUNING: [FOLD 02/05] [E 002/005] : 0.8778 | 0.4360 | 0.7725


 60%|██████    | 3/5 [00:03<00:02,  1.23s/it]

FINETUNING: [FOLD 02/05] [E 003/005] : 0.8633 | 0.4345 | 0.7726


 80%|████████  | 4/5 [00:04<00:01,  1.19s/it]

FINETUNING: [FOLD 02/05] [E 004/005] : 0.8519 | 0.4365 | 0.7728


100%|██████████| 5/5 [00:05<00:00,  1.20s/it]

FINETUNING: [FOLD 02/05] [E 005/005] : 0.8530 | 0.4363 | 0.7728



 20%|██        | 1/5 [00:01<00:04,  1.18s/it]

FINETUNING: [FOLD 03/05] [E 001/005] : 0.9385 | 0.4383 | 0.7698


 40%|████      | 2/5 [00:02<00:03,  1.13s/it]

FINETUNING: [FOLD 03/05] [E 002/005] : 0.9090 | 0.4393 | 0.7707


 60%|██████    | 3/5 [00:03<00:02,  1.15s/it]

FINETUNING: [FOLD 03/05] [E 003/005] : 0.8900 | 0.4393 | 0.7711


 80%|████████  | 4/5 [00:04<00:01,  1.13s/it]

FINETUNING: [FOLD 03/05] [E 004/005] : 0.8795 | 0.4409 | 0.7712


100%|██████████| 5/5 [00:05<00:00,  1.14s/it]

FINETUNING: [FOLD 03/05] [E 005/005] : 0.8802 | 0.4409 | 0.7712



 20%|██        | 1/5 [00:01<00:04,  1.06s/it]

FINETUNING: [FOLD 04/05] [E 001/005] : 0.9053 | 0.4489 | 0.7774


 40%|████      | 2/5 [00:02<00:03,  1.07s/it]

FINETUNING: [FOLD 04/05] [E 002/005] : 0.8740 | 0.4508 | 0.7777


 60%|██████    | 3/5 [00:03<00:02,  1.09s/it]

FINETUNING: [FOLD 04/05] [E 003/005] : 0.8646 | 0.4495 | 0.7777


 80%|████████  | 4/5 [00:04<00:01,  1.08s/it]

FINETUNING: [FOLD 04/05] [E 004/005] : 0.8576 | 0.4492 | 0.7777


100%|██████████| 5/5 [00:05<00:00,  1.09s/it]

FINETUNING: [FOLD 04/05] [E 005/005] : 0.8538 | 0.4494 | 0.7777



 20%|██        | 1/5 [00:00<00:03,  1.06it/s]

FINETUNING: [FOLD 05/05] [E 001/005] : 0.9381 | 0.3909 | 0.7326


 40%|████      | 2/5 [00:01<00:02,  1.06it/s]

FINETUNING: [FOLD 05/05] [E 002/005] : 0.9081 | 0.3938 | 0.7329


 60%|██████    | 3/5 [00:02<00:01,  1.10it/s]

FINETUNING: [FOLD 05/05] [E 003/005] : 0.8933 | 0.3941 | 0.7330


 80%|████████  | 4/5 [00:03<00:00,  1.08it/s]

FINETUNING: [FOLD 05/05] [E 004/005] : 0.8865 | 0.3939 | 0.7330


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]

FINETUNING: [FOLD 05/05] [E 005/005] : 0.8843 | 0.3938 | 0.7331





In [17]:
# TEST EVALUATION
# Average the predictions of the five fine-tuned fold models on the test
# split, print the Spearman correlation of that ensemble, and write the
# predictions (converted back to the original target scale) to a CSV file.

test_set = GeneFeatureDataset(g_test, x_test, y_test)
# shuffle=False keeps the row order identical for every fold model, which the
# element-wise averaging below relies on.
test_loader = DataLoader(
    dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=0)

preds = []

for fold in range(5):
    model = GeneInteractionModel(
        hidden_size=hidden_size, num_layers=n_layers).to(device)

    # map_location lets a checkpoint written on a GPU be restored on a
    # machine where `device` is the CPU.
    model.load_state_dict(torch.load(
        'models/final/F{:02}.pt'.format(fold + 1), map_location=device))

    pred_, y_ = None, None

    model.eval()
    with torch.no_grad():
        for i, (g, x, y) in enumerate(test_loader):
            g = torch.permute(g, (0, 3, 1, 2))
            y = y.reshape(-1, 1)

            pred = model(g, x)

            if pred_ is None:
                pred_ = pred.detach().cpu().numpy()
                y_ = y.detach().cpu().numpy()
            else:
                pred_ = np.concatenate(
                    (pred_, pred.detach().cpu().numpy()))
                y_ = np.concatenate((y_, y.detach().cpu().numpy()))

    preds.append(pred_)

# (folds, samples, 1) -> (folds, samples). Explicit indexing, unlike
# np.squeeze, keeps the sample axis even when there is a single test row.
preds = np.stack(preds)[:, :, 0]
preds = np.mean(preds, axis=0)

print(scipy.stats.spearmanr(preds, y_).correlation)

# Undo the target z-scoring applied before training.
preds = preds * train_target.std() + train_target.mean()
y_ = y_ * train_target.std() + train_target.mean()

# DataFrame.to_csv does not create missing directories.
os.makedirs('results', exist_ok=True)

preds = pd.DataFrame(preds, columns=['Predicted PE efficiency'])
preds.to_csv('results/220218.csv', index=False)


0.8417103535717387


In [19]:
# Scatter plot of measured against predicted efficiency on the test split.
_, ax = plt.subplots(figsize=(6, 6))

ax.scatter(y_, preds, s=0.1)
ax.set_xlim([-1, 31])
ax.set_ylim([-1, 31])
ax.set_xticks(range(0, 35, 5))
ax.set_yticks(range(0, 35, 5))

ax.set_title("Evaluation of DeepPE2")
ax.set_xlabel("Measured PE2 efficiency (%)")
ax.set_ylabel("DeepPE prediction score (%)")

# Compute the annotated correlation from the plotted data so that the label
# cannot drift from the figure after a re-run. Spearman correlation is
# rank-based, so it is the same before and after undoing the z-scoring.
rho = scipy.stats.spearmanr(
    np.asarray(y_).ravel(), np.asarray(preds).ravel()).correlation

ax.annotate('R = {:.4f}'.format(rho),
            xy=(1, 0), xycoords='axes fraction',
            xytext=(-20, 20), textcoords='offset pixels',
            horizontalalignment='right',
            verticalalignment='bottom')

plt.savefig('Evaluation of DeepPE2.jpg', bbox_inches="tight", dpi=200)
# The figure is only written to disk; closing it prevents inline rendering.
plt.close()