# Documentation
> The following code adopted from transformer_encoder_ver2 (a smaller embedding and FDN with the hope of getting less concentrated features). It further exapanded by applying a masked language model. 

> In the previous model, a mask is applied so that at each word, it sees only the word prior to it. (When predicting #2 word, it sees only #1 word, and all rest are masked to -INF), so that the model follows an auto-regressive manner

> In the MLM setting, certain proportion of the sentense is randomly masked (15% in BERT), and they are masked throughout the training process. The loss is only on those positions's correctness. In the TransformerEncoderLayer, the mask pass in should be changed and should mask the position of mask to -INF while keeping all the rest to 0.


In [1]:
import math
import torch.nn as nn
import argparse
import random
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.autograd import Variable
import itertools
import pandas as pd
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math

seed = 7
torch.manual_seed(seed)
np.random.seed(seed)


pfamA_motors = pd.read_csv("../../data/pfamA_motors.csv")
df_dev = pd.read_csv("../../data/df_dev.csv")
motor_toolkit = pd.read_csv("../../data/motor_tookits.csv")

pfamA_motors_balanced = pfamA_motors.groupby('clan').apply(lambda _df: _df.sample(4500,random_state=1))
pfamA_motors_balanced = pfamA_motors_balanced.apply(lambda x: x.reset_index(drop = True))

pfamA_target_name = ["PF00349","PF00022","PF03727","PF06723",\
                       "PF14450","PF03953","PF12327","PF00091","PF10644",\
                      "PF13809","PF14881","PF00063","PF00225","PF03028"]

pfamA_target = pfamA_motors.loc[pfamA_motors["pfamA_acc"].isin(pfamA_target_name),:]


# shuffle pfamA_target and pfamA_motors_balanced
pfamA_target = pfamA_target.sample(frac = 1)
pfamA_target_ind = pfamA_target.iloc[:,0]
print(pfamA_target_ind[0:5])
print(pfamA_motors_balanced.shape)

pfamA_motors_balanced = pfamA_motors_balanced.sample(frac = 1) 
pfamA_motors_balanced_ind = pfamA_motors_balanced.iloc[:,0]
print(pfamA_motors_balanced_ind[0:5])
print(pfamA_target.shape)



179519      179519
1414859    1414859
12920        12920
1415258    1415258
13385        13385
Name: Unnamed: 0, dtype: int64
(18000, 6)
13493    180756
1539     166414
2688     131988
1691      37094
188      130155
Name: Unnamed: 0, dtype: int64
(59149, 6)


In [2]:
aminoacid_list = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
    'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'
]
clan_list = ["actin_like","tubulin_c","tubulin_binding","p_loop_gtpase"]
        
aa_to_ix = dict(zip(aminoacid_list, np.arange(1, 21)))
clan_to_ix = dict(zip(clan_list, np.arange(0, 4)))

def word_to_index(seq,to_ix):
    "Returns a list of indices (integers) from a list of words."
    return [to_ix.get(word, 0) for word in seq]

ix_to_aa = dict(zip(np.arange(1, 21), aminoacid_list))
ix_to_clan = dict(zip(np.arange(0, 4), clan_list))

def index_to_word(ixs,ix_to): 
    "Returns a list of words, given a list of their corresponding indices."
    return [ix_to.get(ix, 'X') for ix in ixs]



In [3]:
def prepare_sequence(seq):
    idxs = word_to_index(seq[:],aa_to_ix)
    return torch.tensor(idxs, dtype=torch.long)

# def prepare_labels(seq):
#     idxs = word_to_index(seq[1:],aa_to_ix)
#     return torch.tensor(idxs, dtype=torch.long)

def prepare_eval(seq):
    idxs = word_to_index(seq[:],aa_to_ix)
    return torch.tensor(idxs, dtype=torch.long)

prepare_sequence('YCHXXXXX')

tensor([20,  2,  7,  0,  0,  0,  0,  0])

In [4]:
# set device
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
class PositionalEncoding(nn.Module):
    """
    PositionalEncoding module injects some information about the relative or absolute position of
    the tokens in the sequence. The positional encodings have the same dimension as the embeddings 
    so that the two can be summed. Here, we use sine and cosine functions of different frequencies.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
#         x = x + self.pe[:x.size(0), :]
#         print("x.size() : ", x.size())
#         print("self.pe.size() :", self.pe[:x.size(0),:,:].size())
        x = torch.add(x ,Variable(self.pe[:x.size(0),:,:], requires_grad=False))
        return self.dropout(x)

    
    
class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout,activation='gelu')
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
#     def _generate_square_mlm_mask(self, sz):
#         # 0's are the masked position
#         zeros_num = int(sz * mask_frac)
#         ones_num = sz - zeros_num
#         lm_mask = torch.cat([torch.zeros(zeros_num), torch.ones(ones_num)])
#         lm_mask = lm_mask[torch.randperm(sz)]
#         masked_ind = lm_mask.eq(0)
#         lm_mask = lm_mask.repeat(sz, 1)
#         mask = lm_mask.float().masked_fill(lm_mask == 0, float('-inf')).masked_fill(lm_mask == 1, float(0.0))
#         mask = mask.to(device)
#         return mask,masked_ind
        
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src,lm_mask):
#         if model.training:
#             device = src.device
#             mask = self._generate_square_mlm_mask(src.size(0))
            
#         print("src.device: ", src.device)
        src = self.encoder(src) * math.sqrt(self.ninp)
#         print("self.encoder(src) size: ", src.size())
        src = self.pos_encoder(src)
#         print("elf.pos_encoder(src) size: ", src.size())
        output = self.transformer_encoder(src, lm_mask)
#         print("output size: ", output.size())
        output = self.decoder(output)
        return output

In [6]:
ntokens = len(aminoacid_list) + 1 # the size of vocabulary
emsize = 12 # embedding dimension
nhid = 100 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 6 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 12 # the number of heads in the multiheadattention models
dropout = 0.1 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)

In [7]:
criterion = nn.CrossEntropyLoss()
lr = 3.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [8]:
model.to(device)
model.train() # Turn on the train mode

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=100, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=100, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=100, bias=T

In [9]:
import time

In [10]:
start_time = time.time()
print_every = 1000
# loss_vector = []
mask_frac = 0.15

for epoch in np.arange(0, df_dev.shape[0]): 
    seq = df_dev.iloc[epoch, 6]
    sz = len(seq)
    zeros_num = int(sz * mask_frac)
    ones_num = sz - zeros_num
    lm_mask = torch.cat([torch.zeros(zeros_num), torch.ones(ones_num)])
    lm_mask = lm_mask[torch.randperm(sz)]
    mask_ind = lm_mask.eq(0)
    lm_mask = lm_mask.repeat(sz, 1)
    lm_mask = lm_mask.float().masked_fill(lm_mask == 0, float('-inf')).masked_fill(lm_mask == 1, float(0.0))
    lm_mask = lm_mask.to(device)

    sentence_in = prepare_sequence(seq)
#     sentence_in = sentence_in.to(device = device)
    sentence_in = sentence_in.unsqueeze(1).to(device = device)

    optimizer.zero_grad()
    output = model(sentence_in,lm_mask)
#     print(mask_ind)
    targets = sentence_in[mask_ind]
#     targets = targets.to(device = device)
    
    print(targets.squeeze(1).size())
    print(output[mask_ind].squeeze(1).size())

    loss = criterion(output[mask_ind].squeeze(1), targets.squeeze(1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if epoch % print_every == 0:
        print(f"At Epoch: %.1f"% epoch)
        print(f"Loss %.4f"% loss)
        elapsed = time.time() - start_time
        print(f"time elapsed %.4f"% elapsed)
#         torch.save(model.state_dict(), "../../data/transformer_encoder_201025.pt")
#     loss_vector.append(loss)
    break

torch.Size([17])
torch.Size([17, 21])
At Epoch: 0.0
Loss 3.0147
time elapsed 0.0404


In [15]:
start_time = time.time()
print_every = 1000
# loss_vector = []
mask_frac = 0.15

for epoch in np.arange(409000, df_dev.shape[0]): 
    seq = df_dev.iloc[epoch, 6]
#     print(seq)
    
    sz = len(seq)
#     print(sz)
    zeros_num = int(sz * mask_frac)
    if zeros_num < 1:
        continue
    ones_num = sz - zeros_num
    lm_mask = torch.cat([torch.zeros(zeros_num), torch.ones(ones_num)])
    lm_mask = lm_mask[torch.randperm(sz)]
    mask_ind = lm_mask.eq(0)
    lm_mask = lm_mask.repeat(sz, 1)
    lm_mask = lm_mask.float().masked_fill(lm_mask == 0, float('-inf')).masked_fill(lm_mask == 1, float(0.0))
    lm_mask = lm_mask.to(device)

    sentence_in = prepare_sequence(seq)
#     sentence_in = sentence_in.to(device = device)
    sentence_in = sentence_in.unsqueeze(1).to(device = device)

    optimizer.zero_grad()
    output = model(sentence_in,lm_mask)
#     print(mask_ind)
    targets = sentence_in[mask_ind]
#     targets = targets.to(device = device)
    
#     print(targets.squeeze(1).size())
#     print(output[mask_ind].squeeze(1).size())

    loss = criterion(output[mask_ind].squeeze(1), targets.squeeze(1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if epoch % print_every == 0:
        print(f"At Epoch: %.1f"% epoch)
        print(f"Loss %.4f"% loss)
        elapsed = time.time() - start_time
        print(f"time elapsed %.4f"% elapsed)
#         torch.save(model.state_dict(), "../../data/transformer_encoder_201025.pt")
#     loss_vector.append(loss)
#     break

At Epoch: 409000.0
Loss 2.9233
time elapsed 0.0420
At Epoch: 410000.0
Loss 3.0972
time elapsed 21.4877
At Epoch: 411000.0
Loss 2.8950
time elapsed 42.8958
At Epoch: 412000.0
Loss 2.8483
time elapsed 63.4584
At Epoch: 413000.0
Loss 2.9233
time elapsed 84.2736
At Epoch: 414000.0
Loss 3.0290
time elapsed 104.3078
At Epoch: 415000.0
Loss 2.8642
time elapsed 124.5025
At Epoch: 416000.0
Loss 2.9257
time elapsed 144.6149
At Epoch: 417000.0
Loss 3.1042
time elapsed 164.7017
At Epoch: 418000.0
Loss 2.9663
time elapsed 184.9746
At Epoch: 419000.0
Loss 2.8315
time elapsed 204.7974
At Epoch: 420000.0
Loss 3.0034
time elapsed 224.5893
At Epoch: 421000.0
Loss 2.9070
time elapsed 244.4433
At Epoch: 422000.0
Loss 2.8272
time elapsed 264.1716
At Epoch: 423000.0
Loss 2.9028
time elapsed 284.2062
At Epoch: 424000.0
Loss 3.1271
time elapsed 304.0006
At Epoch: 425000.0
Loss 2.8654
time elapsed 323.6824
At Epoch: 426000.0
Loss 2.9949
time elapsed 343.4006
At Epoch: 427000.0
Loss 2.9270
time elapsed 363.1939

At Epoch: 562000.0
Loss 2.8067
time elapsed 2997.5154
At Epoch: 563000.0
Loss 2.7862
time elapsed 3016.9134
At Epoch: 564000.0
Loss 3.0033
time elapsed 3036.4787
At Epoch: 565000.0
Loss 3.0746
time elapsed 3055.8076
At Epoch: 566000.0
Loss 3.1346
time elapsed 3075.2744
At Epoch: 567000.0
Loss 2.8001
time elapsed 3094.5744
At Epoch: 568000.0
Loss 2.9719
time elapsed 3113.8396
At Epoch: 569000.0
Loss 3.0518
time elapsed 3133.1251
At Epoch: 570000.0
Loss 2.8023
time elapsed 3153.2113
At Epoch: 571000.0
Loss 3.1707
time elapsed 3172.6268
At Epoch: 572000.0
Loss 3.0161
time elapsed 3192.5412
At Epoch: 573000.0
Loss 3.0183
time elapsed 3211.8664
At Epoch: 574000.0
Loss 2.7574
time elapsed 3231.1118
At Epoch: 575000.0
Loss 2.9552
time elapsed 3251.5716
At Epoch: 576000.0
Loss 2.9036
time elapsed 3271.1131
At Epoch: 577000.0
Loss 3.0611
time elapsed 3290.7239
At Epoch: 578000.0
Loss 3.0125
time elapsed 3310.1727
At Epoch: 579000.0
Loss 2.8550
time elapsed 3329.6687
At Epoch: 580000.0
Loss 2.82

At Epoch: 714000.0
Loss 2.8673
time elapsed 5956.9187
At Epoch: 715000.0
Loss 3.0382
time elapsed 5976.3020
At Epoch: 716000.0
Loss 2.9460
time elapsed 5995.6193
At Epoch: 717000.0
Loss 3.0634
time elapsed 6015.2330
At Epoch: 718000.0
Loss 2.9278
time elapsed 6034.5269
At Epoch: 719000.0
Loss 3.0794
time elapsed 6053.8210
At Epoch: 720000.0
Loss 2.9188
time elapsed 6073.3006
At Epoch: 721000.0
Loss 3.0978
time elapsed 6093.3564
At Epoch: 722000.0
Loss 3.0004
time elapsed 6113.0592
At Epoch: 723000.0
Loss 2.7015
time elapsed 6132.3735
At Epoch: 724000.0
Loss 3.2598
time elapsed 6152.1439
At Epoch: 725000.0
Loss 2.9024
time elapsed 6171.6905
At Epoch: 726000.0
Loss 3.2313
time elapsed 6191.5199
At Epoch: 727000.0
Loss 2.9044
time elapsed 6210.8693
At Epoch: 728000.0
Loss 2.9072
time elapsed 6230.9512
At Epoch: 729000.0
Loss 3.2125
time elapsed 6250.0366
At Epoch: 730000.0
Loss 2.9290
time elapsed 6269.1898
At Epoch: 731000.0
Loss 3.0676
time elapsed 6288.5590
At Epoch: 732000.0
Loss 3.05

At Epoch: 866000.0
Loss 2.8241
time elapsed 8910.4189
At Epoch: 867000.0
Loss 2.8173
time elapsed 8929.8072
At Epoch: 868000.0
Loss 2.7985
time elapsed 8949.6901
At Epoch: 869000.0
Loss 3.0638
time elapsed 8969.0518
At Epoch: 870000.0
Loss 3.0494
time elapsed 8988.3854
At Epoch: 871000.0
Loss 2.9990
time elapsed 9007.8692
At Epoch: 872000.0
Loss 3.0540
time elapsed 9027.8354
At Epoch: 873000.0
Loss 2.8170
time elapsed 9048.9102
At Epoch: 874000.0
Loss 3.0366
time elapsed 9069.0785
At Epoch: 875000.0
Loss 2.8885
time elapsed 9089.0950
At Epoch: 876000.0
Loss 3.1424
time elapsed 9108.7287
At Epoch: 877000.0
Loss 3.0683
time elapsed 9128.5453
At Epoch: 878000.0
Loss 3.1583
time elapsed 9148.0700
At Epoch: 879000.0
Loss 2.9296
time elapsed 9167.6477
At Epoch: 880000.0
Loss 2.9904
time elapsed 9187.3637
At Epoch: 881000.0
Loss 3.0627
time elapsed 9206.5252
At Epoch: 882000.0
Loss 3.4990
time elapsed 9226.8080
At Epoch: 883000.0
Loss 2.9818
time elapsed 9246.1304
At Epoch: 884000.0
Loss 3.06

At Epoch: 1016000.0
Loss 2.9246
time elapsed 11832.5901
At Epoch: 1017000.0
Loss 2.8810
time elapsed 11851.6546
At Epoch: 1018000.0
Loss 2.9882
time elapsed 11870.7110
At Epoch: 1019000.0
Loss 3.4012
time elapsed 11889.9895
At Epoch: 1020000.0
Loss 2.8464
time elapsed 11909.1442
At Epoch: 1021000.0
Loss 2.9913
time elapsed 11928.4331
At Epoch: 1022000.0
Loss 2.9649
time elapsed 11947.5233
At Epoch: 1023000.0
Loss 2.7080
time elapsed 11966.7119
At Epoch: 1024000.0
Loss 3.3483
time elapsed 11986.1007
At Epoch: 1025000.0
Loss 3.0975
time elapsed 12005.2706
At Epoch: 1026000.0
Loss 2.8761
time elapsed 12024.3902
At Epoch: 1027000.0
Loss 3.0020
time elapsed 12043.5033
At Epoch: 1028000.0
Loss 2.9670
time elapsed 12062.7518
At Epoch: 1029000.0
Loss 2.9908
time elapsed 12082.1296
At Epoch: 1030000.0
Loss 3.3704
time elapsed 12101.4696
At Epoch: 1031000.0
Loss 2.7275
time elapsed 12120.9378
At Epoch: 1032000.0
Loss 2.9323
time elapsed 12140.3520
At Epoch: 1033000.0
Loss 3.3307
time elapsed 121

At Epoch: 1168000.0
Loss 3.1625
time elapsed 14833.1363
At Epoch: 1169000.0
Loss 3.1762
time elapsed 14852.8635
At Epoch: 1170000.0
Loss 3.2426
time elapsed 14872.8542
At Epoch: 1171000.0
Loss 2.8256
time elapsed 14892.8476
At Epoch: 1172000.0
Loss 3.0895
time elapsed 14912.8235
At Epoch: 1173000.0
Loss 3.4598
time elapsed 14932.7872
At Epoch: 1174000.0
Loss 3.1454
time elapsed 14952.8656
At Epoch: 1175000.0
Loss 3.2269
time elapsed 14972.8702
At Epoch: 1176000.0
Loss 3.0172
time elapsed 14992.8344
At Epoch: 1177000.0
Loss 2.7145
time elapsed 15012.9722
At Epoch: 1178000.0
Loss 2.9531
time elapsed 15033.0231
At Epoch: 1179000.0
Loss 2.7678
time elapsed 15053.0379
At Epoch: 1180000.0
Loss 3.0395
time elapsed 15072.9713
At Epoch: 1181000.0
Loss 2.6566
time elapsed 15093.0153
At Epoch: 1182000.0
Loss 3.0592
time elapsed 15113.1308
At Epoch: 1183000.0
Loss 2.7146
time elapsed 15133.3558
At Epoch: 1184000.0
Loss 3.1270
time elapsed 15153.2265
At Epoch: 1185000.0
Loss 2.9956
time elapsed 151

In [16]:
torch.save(model.state_dict(), "../../data/transformer_encoder_mlm_201025.pt")

In [17]:
print('done')

done
