In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

In [2]:
df_aaindex = pd.read_csv('../data/aaindex/df_aaindex19.csv')
print(df_aaindex.shape)
df_aaindex.head(1)
tmp = df_aaindex.drop('Unnamed: 0',axis=1).T
aa2val = dict()
for aa, val in zip(tmp.index, tmp.values):
    aa2val[aa]=val

(19, 21)


In [3]:
df_detect_peptide_train = pd.read_csv('../data/df_detect_peptide_train.csv')
test = pd.read_csv('../data/df_detect_peptide_test.csv')
train, val = train_test_split(df_detect_peptide_train, test_size=0.2, random_state=7)

df = pd.concat([train, val, test], axis=0).reset_index(drop=True)

train_idx = df.iloc[:len(train), :].index
val_idx = df.iloc[len(train):len(train)+len(val), :].index
test_idx = df.iloc[len(train)+len(val):, :].index

train.head(1)

Unnamed: 0,peptide,En,Ec,E1,E2,protein,PEP,ID
595411,K.QELNEPPKQSTSFLVLQEILESEEKGDPNK.P,VYKMLQEKQELNEPP,EEKGDPNKPSGFRSV,QELNEPPKQSTSFLV,EILESEEKGDPNKPS,sp|O00151|PDLI1_HUMAN,QELNEPPKQSTSFLVLQEILESEEKGDPNK,0


In [4]:
data = [[n[:8]+p+c[8:], prot, lab] for _, n, c, __, ___, prot, p, lab in train.values]
train_new = pd.DataFrame(data, columns=['peptide', 'protein', 'ID'])

data = [[n[:8]+p+c[8:], prot, lab] for _, n, c, __, ___, prot, p, lab in val.values]
val_new = pd.DataFrame(data, columns=['peptide', 'protein', 'ID'])

data = [[n[:8]+p+c[8:], prot, lab] for _, n, c, __, ___, prot, p, lab in test.values]
test_new = pd.DataFrame(data, columns=['peptide', 'protein', 'ID'])

In [5]:
train_new.head(1)

Unnamed: 0,peptide,protein,ID
0,VYKMLQEKQELNEPPKQSTSFLVLQEILESEEKGDPNKPSGFRSV,sp|O00151|PDLI1_HUMAN,0


# torch gpu

In [6]:
import sys
PATH_TO_REPO = "/home/bis/2021_AIhub/esm/"
sys.path.append(PATH_TO_REPO)

import torch
import esm

from tqdm.notebook import tqdm
from tqdm import tqdm, tqdm_notebook
import time
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import AdamW
# from transformers.optimization import get_cosine_schedule_with_warmup
from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup

In [7]:
##GPU 사용 시
device = torch.device("cuda")

In [8]:
esm_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

In [9]:
ct = 0
for child in esm_model.children():
    ct += 1
#     print(ct, child)
#     if ct < 7:
    for param in child.parameters():
        param.requires_grad = False

In [10]:
class ESMDataset(Dataset):
    def __init__(self, datasets, idxes):
        pep_idx, label_idx = idxes
        pep_data = [(label, seq) for label, seq in zip(datasets[:, label_idx], datasets[:, pep_idx])]
        labels, pep_strs, pep_tokens = batch_converter(pep_data)

        self.sentences = pep_tokens
        self.labels = labels

    def __getitem__(self, i):
        return ((self.sentences[i], ) + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

# [PAD] = 1, [MASK] = 21  [CLS] = 0 (special classification token), [SEP] = 2 (seperate segment), Z = 27, '-' = 30, .=29, ,=28
# J 없음
# A 2, B 25, C 23, D 13, E 9, F 18, G 6, H 21, I 12, K 15, L 4, M 20, N 17, 
# O 28, P 14, Q 16, R 10, S 8, T 11, U 26, V 7, W 22, X 24, Y 19, Z 27
# 3 5 없음

In [11]:
## Setting parameters
max_len = 30
batch_size = 256
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200
learning_rate =  1e-4

In [13]:
s = time.time()

dataset_train = train_new[['peptide', 'ID']].values
dataset_valid = val_new[['peptide', 'ID']].values
dataset_test = test_new[['peptide', 'ID']].values

data_train = ESMDataset(dataset_train, [0, 1])
data_valid = ESMDataset(dataset_valid, [0, 1])
data_test = ESMDataset(dataset_test, [0, 1])

train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=48)
valid_dataloader = torch.utils.data.DataLoader(data_valid, batch_size=batch_size, num_workers=48)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=48)

e = time.time()
print(round(e-s, 2),'sec')

224.47 sec


In [14]:
for batch_id, (pep_token_ids, label) in enumerate(train_dataloader):
    print(batch_id)
    print(pep_token_ids)
    print(label)
    break

0
tensor([[ 0,  7, 19,  ...,  8,  7,  2],
        [ 0, 10,  5,  ...,  1,  1,  1],
        [ 0, 14, 10,  ...,  1,  1,  1],
        ...,
        [ 0,  5,  4,  ...,  1,  1,  1],
        [ 0,  9,  4,  ...,  1,  1,  1],
        [ 0,  9,  4,  ...,  1,  1,  1]])
tensor([0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,
        0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0,
        0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0,
        1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0,
        0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
        1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 1

# model

In [13]:
class ESMClassifier(nn.Module):
    def __init__(self,
                 esm,
                 num_classes=1,
                 params=None):
        
        super(ESMClassifier, self).__init__()
        self.esm = esm
        self.pep_lstm1 = nn.LSTM(input_size=1280, hidden_size=1, batch_first=True)        
        self.ts_lstm1 = nn.LSTM(input_size=1280, hidden_size=1, batch_first=True)
                
        self.fc1 = nn.Linear(5, 1)


#     def gen_attention_mask(self, token_ids, valid_length):
#         attention_mask = torch.zeros_like(token_ids)
#         for i, v in enumerate(valid_length):
#             attention_mask[i][:v] = 1
#         return attention_mask.float()

    def forward(self, token_ids):
#         attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        pep_dig_embed = self.esm(token_ids, repr_layers=[33])['representations'][33]
        
        pad_ids = torch.tensor([[0, 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2]]*len(token_ids)).long().to(device)
        pad_vec = self.esm(pad_ids, repr_layers=[33])['representations'][33]  # 15mer
        pad_vec = torch.stack([pad_vec[row_idx][1:-1] for row_idx in range(len(pad_vec))])
        
        pep_embed = torch.stack([pep_dig_embed[row_idx][8:-7] for row_idx in range(len(pep_dig_embed))])
        en_embed = torch.stack([pep_dig_embed[row_idx][:15] for row_idx in range(len(pep_dig_embed))])
        ec_embed = torch.stack([pep_dig_embed[row_idx][-15:] for row_idx in range(len(pep_dig_embed))])
        
        lpad = torch.tensor([0]*8).long().to(device)
        rpad = torch.tensor([0]*8).long().to(device)
        pep_ids = torch.stack([torch.cat([lpad, token_ids[row_idx][8:-8], rpad]) for row_idx in range(len(token_ids))])  # __KR에서 마지막 ts 제외한 ids.
        e1_embed = []
        e2_embed = []
        for row_idx, pep_id in enumerate(pep_ids):
            k_idx = (pep_id==15).nonzero(as_tuple=True)[0]
            r_idx = (pep_id==10).nonzero(as_tuple=True)[0]
            ts_idx = torch.sort(torch.cat([k_idx, r_idx],dim=0))[0]  # 0 : nterm, 1: m1, 2: m2, 3: cterm, if len(ts_idx)==4
            
            if len(ts_idx)==1:
                e1 = pep_dig_embed[row_idx][ts_idx[0]-7:ts_idx[0]+8]  # 15mer (15*1280 shape)
                e2 = pad_vec[row_idx]  # 15*1280
                # e2_embed = torch.tensor([[0]*1280]*15).long().to(device)
            elif len(ts_idx)==2:
                e1 = pep_dig_embed[row_idx][ts_idx[0]-7:ts_idx[0]+8]
                e2 = pep_dig_embed[row_idx][ts_idx[1]-7:ts_idx[1]+8]
            else:
                e1 = pad_vec[row_idx]
                e2 = pad_vec[row_idx]
            e1_embed.append(e1)
            e2_embed.append(e2)
            
        e1_embed = torch.stack(e1_embed)
        e2_embed = torch.stack(e2_embed)
        
        pep_lstm, (pep_hn, __) = self.pep_lstm1(pep_embed)
        en_lstm, (en_hn, __) = self.ts_lstm1(en_embed)
        ec_lstm, (ec_hn, __) = self.ts_lstm1(ec_embed)
        e1_lstm, (e1_hn, __) = self.ts_lstm1(e1_embed)
        e2_lstm, (e2_hn, __) = self.ts_lstm1(e2_embed)
        
        merge = torch.cat([pep_hn[0], en_hn[0], ec_hn[0], e1_hn[0], e2_hn[0]], dim=1)

        merge = self.fc1(merge)        
        out = torch.sigmoid(merge)
        return out

In [14]:
model = ESMClassifier(esm_model).to(device)

In [15]:
model.eval()

ESMClassifier(
  (esm): ProteinBertModel(
    (embed_tokens): Embedding(33, 1280, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (v_proj): Linear(in_fe

In [16]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

+------------------------+------------+
|        Modules         | Parameters |
+------------------------+------------+
| pep_lstm1.weight_ih_l0 |    5120    |
| pep_lstm1.weight_hh_l0 |     4      |
|  pep_lstm1.bias_ih_l0  |     4      |
|  pep_lstm1.bias_hh_l0  |     4      |
| ts_lstm1.weight_ih_l0  |    5120    |
| ts_lstm1.weight_hh_l0  |     4      |
|  ts_lstm1.bias_ih_l0   |     4      |
|  ts_lstm1.bias_hh_l0   |     4      |
|       fc1.weight       |     5      |
|        fc1.bias        |     1      |
+------------------------+------------+
Total Trainable Params: 10270


10270

In [17]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()
# loss_fn = F.binary_cross_entropy()

t_total = len(train_dataloader) * num_epochs
# warmup_step = int(t_total * warmup_ratio)

# scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps=warmup_step, t_total=t_total)

def calc_accuracy(X,Y):
    train_acc = ((X>0.5)==Y).sum().data.cpu().numpy() / len(Y)
    return train_acc

In [19]:
best_acc = 0
for e in range(num_epochs):
    t0 = time.time()
    train_acc = 0.0
    test_acc = 0.0
    
    model.train()
    for batch_id, (pep_token_ids, label) in enumerate(train_dataloader):
#         print(batch_id, round(time.time()-t0, 2))  # batch256->2100 loop, each 5 sec -> per 1 epoch, 3h
        
        pep_token_ids = pep_token_ids.long().to(device)
        label = torch.reshape(label.float(), (-1, 1)).to(device)
        
        pred = model(pep_token_ids)
        loss = F.binary_cross_entropy(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_acc += calc_accuracy(pred, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {} time {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1), round(time.time()-t0, 2)))
        
    print("epoch {} train acc {} time {}".format(e+1, train_acc / (batch_id+1), round(time.time()-t0,2)))
    
    model.eval()
    for batch_id, (pep_token_ids, label) in enumerate(valid_dataloader):
        pep_token_ids = pep_token_ids.long().to(device)
        label = label.long().to(device)
        label = torch.reshape(label, (-1, 1))
        pred = model(pep_token_ids)
        
        test_acc += calc_accuracy(pred, label)
    test_acc = test_acc / (batch_id+1)
    if test_acc > best_acc:
        best_acc=test_acc
        print(f"best_acc: {best_acc}")
        torch.save({"best_acc":best_acc, "model":model.state_dict()},f'./finetuning_211221.pl')
    print("epoch {} test acc {}".format(e+1, test_acc))

epoch 1 batch id 1 loss 0.6745380163192749 train acc 0.65625 time 7.3
epoch 1 batch id 201 loss 0.6721905469894409 train acc 0.5851990049751243 time 586.32
epoch 1 batch id 401 loss 0.6573207974433899 train acc 0.5945293017456359 time 1180.67
epoch 1 batch id 601 loss 0.6337117552757263 train acc 0.6046563019966722 time 1781.37
epoch 1 batch id 801 loss 0.6262001991271973 train acc 0.611725577403246 time 2380.62
epoch 1 batch id 1001 loss 0.6685328483581543 train acc 0.6163367882117882 time 2979.72
epoch 1 batch id 1201 loss 0.6001158952713013 train acc 0.6222711542464613 time 3578.73
epoch 1 batch id 1401 loss 0.6021649837493896 train acc 0.6285939730549608 time 4173.06
epoch 1 batch id 1601 loss 0.6278054714202881 train acc 0.6342300710493441 time 4765.72
epoch 1 batch id 1801 loss 0.6127115488052368 train acc 0.6391392802609661 time 5364.42
epoch 1 batch id 2001 loss 0.6508327126502991 train acc 0.643426333708146 time 5963.24
epoch 1 train acc 0.6458929727434759 time 6329.8
best_acc

epoch 8 batch id 1801 loss 0.5259740352630615 train acc 0.7258445828706275 time 14599.82
epoch 8 batch id 2001 loss 0.5849611759185791 train acc 0.725414636431784 time 16220.93
epoch 8 train acc 0.7256116849105462 time 17210.5
best_acc: 0.7243402777777778
epoch 8 test acc 0.7243402777777778
epoch 9 batch id 1 loss 0.5039335489273071 train acc 0.77734375 time 12.73
epoch 9 batch id 201 loss 0.5269186496734619 train acc 0.7282143967661692 time 1646.51
epoch 9 batch id 401 loss 0.5371017456054688 train acc 0.7269034445137157 time 3261.51
epoch 9 batch id 601 loss 0.514401912689209 train acc 0.7272839538269551 time 4880.08
epoch 9 batch id 801 loss 0.5709639191627502 train acc 0.7274256788389513 time 6500.87
epoch 9 batch id 1001 loss 0.5772814750671387 train acc 0.7267068868631369 time 8128.12
epoch 9 batch id 1201 loss 0.5546128749847412 train acc 0.7266373074521232 time 9742.69
epoch 9 batch id 1401 loss 0.5579428672790527 train acc 0.7269138115631691 time 11355.86
epoch 9 batch id 1601

# 211220 - pep / ts seperate

In [16]:
best_acc = 0
for e in range(num_epochs):
    t0 = time.time()
    train_acc = 0.0
    test_acc = 0.0
    
    model.train()
    for batch_id, (pep_token_ids, label) in enumerate(train_dataloader):
#         print(batch_id, round(time.time()-t0, 2))  # batch256->2100 loop, each 5 sec -> per 1 epoch, 3h
        
        pep_token_ids = pep_token_ids.long().to(device)
        label = torch.reshape(label.float(), (-1, 1)).to(device)
        
        pred = model(pep_token_ids)
        loss = F.binary_cross_entropy(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_acc += calc_accuracy(pred, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {} time {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1), round(time.time()-t0, 2)))
        
    print("epoch {} train acc {} time {}".format(e+1, train_acc / (batch_id+1), round(time.time()-t0,2)))
    
    model.eval()
    for batch_id, (pep_token_ids, label) in enumerate(valid_dataloader):
        pep_token_ids = pep_token_ids.long().to(device)
        label = label.long().to(device)
        label = torch.reshape(label, (-1, 1))
        pred = model(pep_token_ids)
        
        test_acc += calc_accuracy(pred, label)
    if test_acc > best_acc:
        best_acc=test_acc
        torch.save({"best_acc":best_acc / (batch_id+1),"model":model.state_dict()},f'./finetuning.pl')
        print(f"best_acc: {best_acc}")
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

epoch 1 batch id 1 loss 0.7282492518424988 train acc 0.5078125 time 10.04
epoch 1 batch id 201 loss 0.6480684280395508 train acc 0.4993198072139303 time 2702.1
epoch 1 batch id 401 loss 0.6338706016540527 train acc 0.5020554083541147 time 5607.34
epoch 1 batch id 601 loss 0.5915693044662476 train acc 0.5678556572379367 time 8516.95
epoch 1 batch id 801 loss 0.6123077869415283 train acc 0.6009968008739076 time 11428.16
epoch 1 batch id 1001 loss 0.6052110195159912 train acc 0.6203171828171828 time 14324.08
epoch 1 batch id 1201 loss 0.5544191598892212 train acc 0.6339085917985012 time 17219.76
epoch 1 batch id 1401 loss 0.5964603424072266 train acc 0.6426603765167738 time 20251.29
epoch 1 batch id 1601 loss 0.5929796695709229 train acc 0.6501771353841349 time 23284.44
epoch 1 batch id 1801 loss 0.5663135051727295 train acc 0.6554800284564131 time 26316.99
epoch 1 batch id 2001 loss 0.6105095148086548 train acc 0.6596974950024987 time 29351.44
epoch 1 train acc 0.6622013300376648 time 31

epoch 8 batch id 1801 loss 0.48304346203804016 train acc 0.7543465435868961 time 25036.29
epoch 8 batch id 2001 loss 0.5375452041625977 train acc 0.7543611006996501 time 27808.03
epoch 8 train acc 0.754491346768227 time 29500.04
best_acc: 400.52453125
epoch 8 test acc 0.754283486346516
epoch 9 batch id 1 loss 0.4690837264060974 train acc 0.75390625 time 14.26
epoch 9 batch id 201 loss 0.4643460214138031 train acc 0.755733053482587 time 2164.47
epoch 9 batch id 401 loss 0.5001388788223267 train acc 0.7550557200748129 time 4316.5
epoch 9 batch id 601 loss 0.4926818609237671 train acc 0.7554466514143094 time 6466.44
epoch 9 batch id 801 loss 0.5099382996559143 train acc 0.7557837858926342 time 8617.96
epoch 9 batch id 1001 loss 0.49767738580703735 train acc 0.7550496378621379 time 10769.21
epoch 9 batch id 1201 loss 0.4714639186859131 train acc 0.7552820566194838 time 12920.19
epoch 9 batch id 1401 loss 0.5621832609176636 train acc 0.7552334270164168 time 15071.85
epoch 9 batch id 1601 lo