In [1]:
import math
import torch.nn as nn
import argparse
import random
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.autograd import Variable
import itertools
import pandas as pd
from torch.nn import TransformerEncoder, TransformerEncoderLayer
# seed = 7
# torch.manual_seed(seed)
# np.random.seed(seed)

# Fixed seed so the shuffled training subset is reproducible across kernel restarts
# (same value as the commented-out seed above).
SEED = 7

pfamA_motors = pd.read_csv("../../data/pfamA_motors.csv")
df_dev = pd.read_csv("../../data/df_dev.csv")
# Drop the first column. NOTE(review): presumably an index column written by a previous
# `to_csv` call — confirm.
pfamA_motors = pfamA_motors.iloc[:, 1:]
# Training subset: the first 4000 rows of each clan, shuffled deterministically.
clan_train_dat = (
    pfamA_motors.groupby("clan").head(4000)
    .sample(frac=1, random_state=SEED)
    .reset_index(drop=True)
)
# Test subset: the first 400 rows of each clan among ids that are not in the training subset.
clan_test_dat = (
    pfamA_motors.loc[~pfamA_motors["id"].isin(clan_train_dat["id"]), :]
    .groupby("clan")
    .head(400)
)

def df_to_tup(dat):
    """
    Convert a DataFrame into a list of ``(seq, clan)`` tuples, one per row, in row order.

    Args:
        dat: DataFrame with at least ``"seq"`` and ``"clan"`` columns.

    Returns:
        list of ``(seq, clan)`` tuples; an empty list for an empty frame.
    """
    # Zipping the two columns walks them positionally (like the previous iloc loop) but
    # avoids materialising a pandas Series for every row, which is very slow on large frames.
    return list(zip(dat["seq"], dat["clan"]))

clan_training_data = df_to_tup(clan_train_dat)
clan_test_data = df_to_tup(clan_test_dat)
# Peek at the first training example; slicing to one element replaces the loop + break
# and is a no-op when the list is empty.
for seq,clan in clan_training_data[:1]:
    print(seq)
    print(clan)

aminoacid_list = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
    'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'
]
clan_list = ["actin_like","tubulin_c","tubulin_binding","p_loop_gtpase"]

# Amino acids are numbered from 1, leaving index 0 for residues outside the vocabulary
# (word_to_index falls back to 0 for unknown tokens). Ranges are derived from the list
# lengths so the mapping stays correct if a list is edited.
aa_to_ix = dict(zip(aminoacid_list, np.arange(1, len(aminoacid_list) + 1)))
# Clans are numbered from 0.
clan_to_ix = dict(zip(clan_list, np.arange(len(clan_list))))

def word_to_index(seq,to_ix):
    """Map each token of ``seq`` to its index in ``to_ix``; tokens absent from ``to_ix`` map to 0."""
    lookup = to_ix.get
    return list(map(lambda token: lookup(token, 0), seq))

# Inverse lookups over the same index ranges as the forward maps: amino acids from 1,
# clans from 0. Ranges are derived from the list lengths rather than hard-coded.
ix_to_aa = dict(zip(np.arange(1, len(aminoacid_list) + 1), aminoacid_list))
ix_to_clan = dict(zip(np.arange(len(clan_list)), clan_list))

def index_to_word(ixs,ix_to):
    """
    Return the words for a sequence of indices, using 'X' for indices missing from ``ix_to``.

    Args:
        ixs: iterable of integer indices — a list, a NumPy array, or a torch tensor.
        ix_to: dict mapping index -> word.

    Returns:
        list of words, one per index.
    """
    # Arrays/tensors are converted to plain Python ints first. Iterating a torch tensor
    # yields 0-d tensors, which hash by identity, so without this every lookup would miss
    # and return 'X'.
    if hasattr(ixs, "tolist"):
        ixs = ixs.tolist()
    return [ix_to.get(ix, 'X') for ix in ixs]

def prepare_sequence(seq):
    """Encode every residue except the last as a LongTensor of vocabulary indices (the model input)."""
    input_residues = seq[:-1]
    return torch.tensor(word_to_index(input_residues, aa_to_ix), dtype=torch.long)

def prepare_labels(seq):
    """Encode every residue except the first as a LongTensor of vocabulary indices (next-residue targets)."""
    target_residues = seq[1:]
    return torch.tensor(word_to_index(target_residues, aa_to_ix), dtype=torch.long)

# Sanity check: labels drop the first residue, and residues outside the vocabulary ('X') map to 0.
example_labels = prepare_labels('YCHXXXXX')
assert example_labels.tolist() == [2, 7, 0, 0, 0, 0, 0]
example_labels



KTIAINAGSSSLKWQLYQMPDEIVLAKGLIERIGLHQSKSTVKFNGQSESQTVDIPDHTKAVKVLLDDLLRLNIIDSYQEITGIGHRIVAGGEYFNQSTVVGEKELALIEELSALAPLHNPGAAAGIRAFMELLPGVTSVAVFDTAFHTTMKDYTYLYPIPRKYYNELKVRKYGAHGTSHQYVAQEAAKLLGKPLDQLKLITAHIGNGVSITANYHGESVDTSMGFTPLAGPMMGTRSGDIDPAIIPYLIANDDELNDAADVIDMLNKKSGLGGVSEISSDMRDIEDGLQAKNKDAVLAYNMFIDRIKKFIGQYLAVLNGADAIVFTAGMGENGYLMRQDVIEAMSWFGMKLDPEKNVFGYHGEISTPDSLIKVLVIPTDEELMIAR
actin_like


tensor([2, 7, 0, 0, 0, 0, 0])

In [2]:
# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal positional encodings to a sequence of embeddings, then apply dropout.

    For position ``pos`` and channel ``i`` the encoding is ``sin(pos / 10000^(i / d_model))``
    for even ``i`` and ``cos(pos / 10000^((i - 1) / d_model))`` for odd ``i``. It has the
    same width as the embeddings, so the two can be summed.

    Args:
        d_model: embedding dimension (odd values are supported).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence that can be encoded.

    Input/output shape: ``(seq_len, batch, d_model)`` with ``seq_len <= max_len``.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, 1, d_model): broadcasts over the batch axis of a (seq_len, batch, d_model) input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Registered as a buffer: follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape (seq_len, batch, d_model)."""
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        # The buffer never requires grad, so the deprecated Variable wrapper is unnecessary.
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [4]:
class TransformerModel(nn.Module):
    """
    Causal (left-to-right) Transformer-encoder language model over token indices.

    Tokens are embedded and scaled by sqrt(ninp), then given sinusoidal positional encodings.
    They pass through ``nlayers`` TransformerEncoderLayers under a square subsequent (causal)
    mask, and are finally projected to ``ntoken`` logits per position.

    Args:
        ntoken: vocabulary size.
        ninp: embedding / model dimension (PyTorch requires ninp % nhead == 0).
        nhead: number of attention heads.
        nhid: hidden size of each encoder layer's feed-forward sublayer.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability, used by the encoder layers and the positional encoding.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()

        self.model_type = 'Transformer'
        # Cached causal mask; rebuilt in forward when the sequence length or device changes.
        self.src_mask = None
        # Forward the model's dropout rate instead of silently using PositionalEncoding's default.
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask: 0 on and below the diagonal, -inf above (no peeking ahead)."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        """Uniformly initialise embedding and output weights in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """
        Args:
            src: LongTensor of token indices, sequence-first: (seq_len, batch).

        Returns:
            Logits of shape (seq_len, batch, ntoken).
        """
        # Rebuild the causal mask when the length changes, or when the input moves to another
        # device (e.g. evaluating on CPU after training on GPU with the same length).
        if (self.src_mask is None
                or self.src_mask.size(0) != src.size(0)
                or self.src_mask.device != src.device):
            self.src_mask = self._generate_square_subsequent_mask(src.size(0)).to(src.device)
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

In [5]:
# Vocabulary: one index per amino acid plus index 0, which word_to_index uses for unknown residues.
ntokens = len(aminoacid_list) + 1
emsize = 12    # embedding dimension (d_model)
nhid = 100     # hidden size of each encoder layer's feed-forward sublayer
nlayers = 6    # number of stacked TransformerEncoderLayers
# NOTE(review): with emsize = 12 and nhead = 12 each attention head is 1-dimensional — consider fewer heads.
nhead = 12     # number of attention heads
dropout = 0.1  # dropout probability
model = TransformerModel(
    ntoken=ntokens,
    ninp=emsize,
    nhead=nhead,
    nhid=nhid,
    nlayers=nlayers,
    dropout=dropout,
)

In [6]:
# Move the model to the target device before building the optimizer, so the optimizer is
# constructed over the parameters as they will be trained.
model.to(device)

criterion = nn.CrossEntropyLoss()
lr = 3.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# StepLR's step_size is an integer number of scheduler steps.
# NOTE(review): the training cells in this notebook never call scheduler.step(), so the
# learning rate stays at `lr` throughout — TODO step it (e.g. once per pass) or drop it.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

model.train()  # Turn on the train mode

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=100, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=100, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=100, bias=T

In [7]:
import time

In [11]:
start_time = time.time()
print_every = 1

# Smoke test: one training step on the first dev sequence (slicing to the first index makes
# the loop run at most once), printing sizes to confirm logits and targets line up.
for epoch in np.arange(0, df_dev.shape[0])[:1]:
    seq = df_dev.iloc[epoch, 6]
    # Input is every residue but the last, given a batch axis of size 1; targets are every
    # residue but the first.
    sentence_in = prepare_sequence(seq)
    targets = prepare_labels(seq)
    sentence_in, targets = sentence_in.unsqueeze(1).to(device = device), targets.to(device = device)

    optimizer.zero_grad()
    output = model(sentence_in)

    print("targets size: ", targets.size())
    print(output.view(-1, ntokens).size())
    print(targets.size())
    loss = criterion(output.view(-1, ntokens), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if epoch % print_every == 0:
        print(f"At Epoch: %.1f"% epoch)
        print(f"Loss %.4f"% loss)

targets size:  torch.Size([115])
torch.Size([115, 21])
torch.Size([115])
At Epoch: 0.0
Loss 2.9394


In [None]:
start_time = time.time()
print_every = 1000

# Training: one SGD step per dev sequence (batch size 1), so `epoch` is really a step
# counter over the rows of df_dev. Every `print_every` steps the current loss and elapsed
# time are logged and the checkpoint file is overwritten.
for epoch in range(df_dev.shape[0]):
    # NOTE(review): column 6 is assumed to hold the amino-acid string — confirm against df_dev.
    seq = df_dev.iloc[epoch, 6]
    # At least two residues are needed to form an (input, next-residue) pair. Shorter
    # sequences would give empty tensors, and a missing value (NaN) cannot be sliced,
    # so skip both rather than crash or log a NaN loss hours into the run.
    if not isinstance(seq, str) or len(seq) < 2:
        continue

    # Input: residues 0..n-2 with a batch axis of size 1 -> (n-1, 1). Targets: residues 1..n-1.
    sentence_in = prepare_sequence(seq).unsqueeze(1).to(device=device)
    targets = prepare_labels(seq).to(device=device)

    optimizer.zero_grad()
    output = model(sentence_in)  # logits, (n-1, 1, ntokens)
    loss = criterion(output.view(-1, ntokens), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if epoch % print_every == 0:
        # Same log format as before, e.g. "At Epoch: 1000.0" / "Loss 2.8395".
        print(f"At Epoch: {epoch:.1f}")
        print(f"Loss {loss.item():.4f}")
        elapsed = time.time() - start_time
        print(f"time elapsed {elapsed:.4f}")
        torch.save(model.state_dict(), "../../data/transformer_encoder_201025.pt")


At Epoch: 0.0
Loss 2.8715
time elapsed 0.0409
At Epoch: 1000.0
Loss 2.8395
time elapsed 19.6486
At Epoch: 2000.0
Loss 3.2027
time elapsed 38.9114
At Epoch: 3000.0
Loss 2.8683
time elapsed 58.1720
At Epoch: 4000.0
Loss 2.9009
time elapsed 77.3369
At Epoch: 5000.0
Loss 2.8267
time elapsed 96.2997
At Epoch: 6000.0
Loss 3.0502
time elapsed 115.7323
At Epoch: 7000.0
Loss 2.9993
time elapsed 135.0577
At Epoch: 8000.0
Loss 2.9352
time elapsed 154.7522
At Epoch: 9000.0
Loss 2.8490
time elapsed 174.1527
At Epoch: 10000.0
Loss 2.9785
time elapsed 194.0793
At Epoch: 11000.0
Loss 2.8689
time elapsed 213.3307
At Epoch: 12000.0
Loss 2.9383
time elapsed 232.5846
At Epoch: 13000.0
Loss 2.8743
time elapsed 254.0461
At Epoch: 14000.0
Loss 2.8858
time elapsed 273.6493
At Epoch: 15000.0
Loss 2.8853
time elapsed 293.2797
At Epoch: 16000.0
Loss 2.8956
time elapsed 315.4301
At Epoch: 17000.0
Loss 2.8313
time elapsed 336.0866
At Epoch: 18000.0
Loss 2.8642
time elapsed 355.1280
At Epoch: 19000.0
Loss 2.8091
ti

At Epoch: 155000.0
Loss 2.8806
time elapsed 3191.6272
At Epoch: 156000.0
Loss 3.0133
time elapsed 3210.6789
At Epoch: 157000.0
Loss 2.8179
time elapsed 3229.7174
At Epoch: 158000.0
Loss 3.0357
time elapsed 3248.8882
At Epoch: 159000.0
Loss 2.8781
time elapsed 3267.9811
At Epoch: 160000.0
Loss 2.8095
time elapsed 3287.2994
At Epoch: 161000.0
Loss 3.2922
time elapsed 3306.5120
At Epoch: 162000.0
Loss 2.8721
time elapsed 3325.5989
At Epoch: 163000.0
Loss 2.9506
time elapsed 3344.8434
At Epoch: 164000.0
Loss 2.9230
time elapsed 3364.0068
At Epoch: 165000.0
Loss 2.9316
time elapsed 3383.2638
At Epoch: 166000.0
Loss 2.8873
time elapsed 3402.2399
At Epoch: 167000.0
Loss 2.7954
time elapsed 3421.3486
At Epoch: 168000.0
Loss 2.8975
time elapsed 3440.6770
At Epoch: 169000.0
Loss 3.0561
time elapsed 3459.8477
At Epoch: 170000.0
Loss 2.9578
time elapsed 3478.9750
At Epoch: 171000.0
Loss 2.8686
time elapsed 3498.0510
At Epoch: 172000.0
Loss 2.9308
time elapsed 3518.4108
At Epoch: 173000.0
Loss 2.89

At Epoch: 307000.0
Loss 2.8775
time elapsed 6029.1911
At Epoch: 308000.0
Loss 2.8559
time elapsed 6047.4638
At Epoch: 309000.0
Loss 2.8172
time elapsed 6065.7279
At Epoch: 310000.0
Loss 2.7845
time elapsed 6084.0873
At Epoch: 311000.0
Loss 2.8599
time elapsed 6102.4510
At Epoch: 312000.0
Loss 2.9258
time elapsed 6120.8129
At Epoch: 313000.0
Loss 2.8173
time elapsed 6139.0597
At Epoch: 314000.0
Loss 2.9676
time elapsed 6157.5454
At Epoch: 315000.0
Loss 2.9411
time elapsed 6175.9573
At Epoch: 316000.0
Loss 3.1104
time elapsed 6194.4250
At Epoch: 317000.0
Loss 2.8771
time elapsed 6212.7808
At Epoch: 318000.0
Loss 2.9468
time elapsed 6231.0867
At Epoch: 319000.0
Loss 2.8686
time elapsed 6249.5511
At Epoch: 320000.0
Loss 2.8366
time elapsed 6268.3688
At Epoch: 321000.0
Loss 2.9200
time elapsed 6286.8242
At Epoch: 322000.0
Loss 2.8766
time elapsed 6305.3147
At Epoch: 323000.0
Loss 2.9355
time elapsed 6323.6173
At Epoch: 324000.0
Loss 3.2143
time elapsed 6341.9410
At Epoch: 325000.0
Loss 2.87

At Epoch: 459000.0
Loss 2.9435
time elapsed 8850.8029
At Epoch: 460000.0
Loss 2.9570
time elapsed 8869.0668
At Epoch: 461000.0
Loss 2.8651
time elapsed 8887.4601
At Epoch: 462000.0
Loss 2.8269
time elapsed 8905.8473
At Epoch: 463000.0
Loss 2.8568
time elapsed 8924.1660
At Epoch: 464000.0
Loss 3.0386
time elapsed 8942.5264
At Epoch: 465000.0
Loss 2.9307
time elapsed 8961.2746
At Epoch: 466000.0
Loss 2.7949
time elapsed 8979.7514
At Epoch: 467000.0
Loss 3.0489
time elapsed 8998.2499
At Epoch: 468000.0
Loss 2.9722
time elapsed 9016.9176
At Epoch: 469000.0
Loss 2.8544
time elapsed 9035.3789
At Epoch: 470000.0
Loss 2.9462
time elapsed 9053.8929
At Epoch: 471000.0
Loss 2.8361
time elapsed 9072.3931
At Epoch: 472000.0
Loss 2.7769
time elapsed 9090.9295
At Epoch: 473000.0
Loss 2.9340
time elapsed 9109.2548
At Epoch: 474000.0
Loss 2.8458
time elapsed 9128.0724
At Epoch: 475000.0
Loss 2.7567
time elapsed 9146.6238
At Epoch: 476000.0
Loss 3.0525
time elapsed 9165.2325
At Epoch: 477000.0
Loss 2.95

At Epoch: 610000.0
Loss 2.8243
time elapsed 11731.4131
At Epoch: 611000.0
Loss 2.9508
time elapsed 11752.4188
At Epoch: 612000.0
Loss 2.8084
time elapsed 11773.4216
At Epoch: 613000.0
Loss 2.9531
time elapsed 11792.3002
At Epoch: 614000.0
Loss 2.8914
time elapsed 11811.3826
At Epoch: 615000.0
Loss 2.9300
time elapsed 11830.7118
At Epoch: 616000.0
Loss 2.9586
time elapsed 11849.9899
At Epoch: 617000.0
Loss 2.8596
time elapsed 11869.0656
At Epoch: 618000.0
Loss 3.0116
time elapsed 11888.2264
At Epoch: 619000.0
Loss 2.8701
time elapsed 11907.0233
At Epoch: 620000.0
Loss 2.8568
time elapsed 11926.4092
At Epoch: 621000.0
Loss 2.9078
time elapsed 11946.1712
At Epoch: 622000.0
Loss 3.0070
time elapsed 11965.4985
At Epoch: 623000.0
Loss 2.8487
time elapsed 11984.0576
At Epoch: 624000.0
Loss 2.9811
time elapsed 12002.7569
At Epoch: 625000.0
Loss 3.2832
time elapsed 12021.5913
At Epoch: 626000.0
Loss 2.9213
time elapsed 12040.2917
At Epoch: 627000.0
Loss 2.8322
time elapsed 12058.9424
At Epoch: 

At Epoch: 759000.0
Loss 2.8880
time elapsed 14556.6057
At Epoch: 760000.0
Loss 2.9698
time elapsed 14575.6309
At Epoch: 761000.0
Loss 2.8610
time elapsed 14594.8027
At Epoch: 762000.0
Loss 2.8224
time elapsed 14613.4885
At Epoch: 763000.0
Loss 2.9481
time elapsed 14632.1755
At Epoch: 764000.0
Loss 2.9306
time elapsed 14651.0971
At Epoch: 765000.0
Loss 2.8459
time elapsed 14669.9566
At Epoch: 766000.0
Loss 2.9278
time elapsed 14688.8862
At Epoch: 767000.0
Loss 2.8784
time elapsed 14707.7395
At Epoch: 768000.0
Loss 2.8879
time elapsed 14726.4881
At Epoch: 769000.0
Loss 2.8470
time elapsed 14745.3844
At Epoch: 770000.0
Loss 2.8543
time elapsed 14764.2524
At Epoch: 771000.0
Loss 2.9128
time elapsed 14783.0968
At Epoch: 772000.0
Loss 2.9646
time elapsed 14802.0488
At Epoch: 773000.0
Loss 2.9332
time elapsed 14821.0186
At Epoch: 774000.0
Loss 2.9094
time elapsed 14839.7873
At Epoch: 775000.0
Loss 2.9222
time elapsed 14858.5540
At Epoch: 776000.0
Loss 2.8687
time elapsed 14877.2631
At Epoch: 

At Epoch: 908000.0
Loss 2.9858
time elapsed 17274.6437
At Epoch: 909000.0
Loss 2.9433
time elapsed 17292.9117
At Epoch: 910000.0
Loss 3.0915
time elapsed 17311.0133
At Epoch: 911000.0
Loss 2.8551
time elapsed 17328.9797
At Epoch: 912000.0
Loss 2.9050
time elapsed 17346.9724
At Epoch: 913000.0
Loss 2.8358
time elapsed 17364.9435
At Epoch: 914000.0
Loss 2.9518
time elapsed 17382.9500
At Epoch: 915000.0
Loss 2.9720
time elapsed 17400.9365
At Epoch: 916000.0
Loss 2.8104
time elapsed 17419.2658
At Epoch: 917000.0
Loss 2.8956
time elapsed 17437.2424
At Epoch: 918000.0
Loss 2.9082
time elapsed 17455.2336
At Epoch: 919000.0
Loss 2.8838
time elapsed 17473.1848
At Epoch: 920000.0
Loss 2.8740
time elapsed 17491.0102
At Epoch: 921000.0
Loss 2.8611
time elapsed 17509.1697
At Epoch: 922000.0
Loss 2.8218
time elapsed 17527.1296
At Epoch: 923000.0
Loss 2.9158
time elapsed 17545.2853
At Epoch: 924000.0
Loss 2.9162
time elapsed 17563.3983
At Epoch: 925000.0
Loss 2.8070
time elapsed 17581.4274
At Epoch: 

At Epoch: 1056000.0
Loss 2.8459
time elapsed 19949.1973
At Epoch: 1057000.0
Loss 2.8766
time elapsed 19967.1730
At Epoch: 1058000.0
Loss 2.8846
time elapsed 19985.1466
At Epoch: 1059000.0
Loss 2.8213
time elapsed 20002.9675
At Epoch: 1060000.0
Loss 2.7298
time elapsed 20020.8227
At Epoch: 1061000.0
Loss 2.9432
time elapsed 20038.7911
At Epoch: 1062000.0
Loss 2.9522
time elapsed 20056.8324
At Epoch: 1063000.0
Loss 2.7336
time elapsed 20074.8105
At Epoch: 1064000.0
Loss 2.7487
time elapsed 20092.7204
At Epoch: 1065000.0
Loss 2.9396
time elapsed 20111.2150
At Epoch: 1066000.0
Loss 2.8143
time elapsed 20129.3690
At Epoch: 1067000.0
Loss 2.9287
time elapsed 20147.3509
At Epoch: 1068000.0
Loss 3.0326
time elapsed 20165.2983
At Epoch: 1069000.0
Loss 2.8376
time elapsed 20183.3212
At Epoch: 1070000.0
Loss 2.9297
time elapsed 20201.2703
At Epoch: 1071000.0
Loss 2.9264
time elapsed 20219.1286
At Epoch: 1072000.0
Loss 2.8023
time elapsed 20237.0304
At Epoch: 1073000.0
Loss 3.1396
time elapsed 202

In [None]:
# Persist the final weights to the same path used for the periodic training checkpoints.
torch.save(obj=model.state_dict(), f="../../data/transformer_encoder_201025.pt")

In [None]:
# Completion marker for the notebook run.
print("done", end="\n")