# Model Description
- Apply a transformer-based model to the pfam/unirep_50 data and extract the embedding features
> In this tutorial, we train an nn.TransformerEncoder model on a language modeling task. The language modeling task is to assign a probability for the likelihood of a given word (or a sequence of words) to follow a sequence of words. A sequence of tokens is passed to the embedding layer first, followed by a positional encoding layer to account for the order of the words (see the next paragraph for more details). The nn.TransformerEncoder consists of multiple layers of nn.TransformerEncoderLayer. Along with the input sequence, a square attention mask is required because the self-attention layers in nn.TransformerEncoder are only allowed to attend to the earlier positions in the sequence. For the language modeling task, any tokens at future positions should be masked. To obtain the actual words, the output of the nn.TransformerEncoder model is sent to the final Linear layer, which is followed by a log-Softmax function.

## Math and model formulation and code reference:
- Attention is all you need https://arxiv.org/abs/1706.03762
- ResNet https://towardsdatascience.com/understanding-and-visualizing-resnets-442284831be8
- MIT Visualization http://jalammar.github.io/illustrated-transformer/
- An Annotated transformer http://nlp.seas.harvard.edu/2018/04/03/attention.html#a-real-world-example

In [1]:
import math
import torch.nn as nn
import argparse
import random
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.autograd import Variable
import itertools
import pandas as pd
# seed = 7
# torch.manual_seed(seed)
# np.random.seed(seed)

# Load the Pfam motor-family table and the development set from CSV.
pfamA_motors = pd.read_csv("../data/pfamA_motors.csv")
df_dev = pd.read_csv("../data/df_dev.csv")
# Drop the first column (presumably a saved index column -- TODO confirm).
pfamA_motors = pfamA_motors.iloc[:,1:]
# Training split: the first 4000 rows of every clan ...
clan_train_dat = pfamA_motors.groupby("clan").head(4000)
# ... shuffled. No seed is set (the seeding lines above are commented out),
# so the row order differs between runs.
clan_train_dat = clan_train_dat.sample(frac=1).reset_index(drop=True)
# Test split: up to 400 rows per clan whose "id" is not in the training split.
clan_test_dat = pfamA_motors.loc[~pfamA_motors["id"].isin(clan_train_dat["id"]),:].groupby("clan").head(400)

# Notebook display of the training-split dimensions.
clan_train_dat.shape

def df_to_tup(dat):
    """Convert a DataFrame into a list of ``(seq, clan)`` tuples.

    Args:
        dat: DataFrame that has a "seq" column and a "clan" column.

    Returns:
        A list with one ``(seq, clan)`` tuple per row, in row order.
        An empty frame yields an empty list.
    """
    # Zipping the two columns avoids materialising a Series for every row
    # through ``iloc``, which is very slow on frames with many rows.
    return list(zip(dat["seq"], dat["clan"]))

# Materialise both splits as lists of (seq, clan) tuples.
clan_training_data = df_to_tup(clan_train_dat)
clan_test_data = df_to_tup(clan_test_dat)
# Sanity check: print only the first training example, then stop.
for seq,clan in clan_training_data:
    print(seq)
    print(clan)
    break

# The 20 one-letter amino-acid codes that make up the token vocabulary.
aminoacid_list = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
    'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'
]
# Clan labels attached to the sequences.
clan_list = ["actin_like","tubulin_c","tubulin_binding","p_loop_gtpase"]
        
# Amino acids are numbered 1..20; index 0 is deliberately left unassigned
# (word_to_index returns 0 for any symbol missing from the mapping).
aa_to_ix = dict(zip(aminoacid_list, np.arange(1, 21)))
# Clans are numbered 0..3.
clan_to_ix = dict(zip(clan_list, np.arange(0, 4)))

def word_to_index(seq,to_ix):
    """Map every item of ``seq`` to its integer index in ``to_ix``.

    Items that are not keys of ``to_ix`` are mapped to 0.
    """
    indices = []
    for token in seq:
        indices.append(to_ix.get(token, 0))
    return indices

# Reverse lookup tables: index -> amino-acid letter (1..20) and index -> clan (0..3).
ix_to_aa = dict(zip(np.arange(1, 21), aminoacid_list))
ix_to_clan = dict(zip(np.arange(0, 4), clan_list))

def index_to_word(ixs,ix_to):
    """Map every index in ``ixs`` back to its word via ``ix_to``.

    Indices that are not keys of ``ix_to`` are rendered as 'X'.
    """
    words = []
    for index in ixs:
        words.append(ix_to.get(index, 'X'))
    return words

def prepare_sequence(seq):
    """Encode the model input for ``seq``: every symbol except the last one.

    Returns a 1-D ``torch.long`` tensor of vocabulary indices taken from
    ``aa_to_ix``.
    """
    model_input = seq[:-1]
    return torch.tensor(word_to_index(model_input, aa_to_ix), dtype=torch.long)

def prepare_labels(seq):
    """Encode the next-token targets for ``seq``: every symbol except the first one.

    Returns a 1-D ``torch.long`` tensor of vocabulary indices taken from
    ``aa_to_ix``.
    """
    shifted = seq[1:]
    return torch.tensor(word_to_index(shifted, aa_to_ix), dtype=torch.long)

# Demo: the leading 'Y' is dropped by the one-position shift, and 'X' is not in
# aa_to_ix, so each 'X' is encoded as 0.
prepare_labels('YCHXXXXX')



GKAMMGTGEAEGEKRAIQAAEAAISNPLLDEVSMKGAKGVLINITGSMDMTLFEVDEAANRIRAEVDPDANIIVGSTFNQDLEGRVRVSVVATGID
tubulin_c


tensor([2, 7, 0, 0, 0, 0, 0])

In [2]:
# set device: use the GPU when CUDA is available, otherwise fall back to the CPU
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Notebook display of the selected device.
device

device(type='cuda')

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    The encodings have the same dimension as the embeddings so that the two
    can be summed. Even feature columns hold sines and odd feature columns
    hold cosines of geometrically spaced frequencies.

    Args:
        d_model: Embedding dimension. Odd values are supported.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one odd column fewer than even columns,
        # so trim the angles to avoid a shape mismatch on assignment.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (max_len, 1, d_model) so it broadcasts over dimension 1 of
        # the input. A buffer moves with the module and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions, then apply dropout.

        Args:
            x: Tensor whose dimension 0 indexes sequence position and whose
                last dimension is ``d_model``.

        Returns:
            A tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [4]:

    
class TransformerModel(nn.Module):
    """Left-to-right language model built on ``nn.TransformerEncoder``.

    Token indices are embedded, scaled by ``sqrt(ninp)``, combined with
    positional encodings, passed through a stack of encoder layers under a
    mask that hides future positions, and projected to vocabulary logits.

    Args:
        ntoken: Vocabulary size (embedding rows and output logits).
        ninp: Embedding dimension.
        nhead: Number of attention heads per encoder layer.
        nhid: Hidden size of the feed-forward network in each encoder layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability for the encoder layers and the
            positional encoding.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.model_type = 'Transformer'
        # Cached attention mask; rebuilt in forward() when length or device changes.
        self.src_mask = None
        # Forward ``dropout`` so the positional encoding honours the same
        # setting as the encoder layers instead of its own default.
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` float mask: 0.0 on and below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def init_weights(self):
        """Initialise embedding and output weights uniformly in [-0.1, 0.1] and zero the output bias."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src):
        """Compute next-token logits for ``src``.

        Args:
            src: Integer token indices; dimension 0 is the sequence position.

        Returns:
            Raw (un-normalised) logits with a trailing dimension of ``ntoken``.
        """
        seq_len = src.size(0)
        # Rebuild the cached mask when the sequence length changes or when the
        # input lives on a different device than the cached mask.
        if (
            self.src_mask is None
            or self.src_mask.size(0) != seq_len
            or self.src_mask.device != src.device
        ):
            self.src_mask = self._generate_square_subsequent_mask(seq_len).to(src.device)

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        return self.decoder(output)

In [5]:
ntokens = len(aminoacid_list) + 1 # the size of vocabulary: the amino-acid codes plus index 0 for symbols outside aa_to_ix
emsize = 768 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 6 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 12 # the number of heads in the multiheadattention models
dropout = 0.1 # the dropout value
# Build the language model with the settings above.
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)

In [6]:
import time

In [7]:
# Cross-entropy over the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# NOTE(review): no scheduler.step() call is visible in the training cells of
# this notebook, so the learning rate appears to stay at 5.0 -- confirm that
# this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

# Move the parameters to the selected device before training.
model.to(device)
model.train() # Turn on the train mode

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=200, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=200, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=20

In [8]:
# Smoke test: run one optimisation step on the first row of df_dev and stop.
start_time = time.time()
print_every = 1
loss_vector = []

for epoch in np.arange(0, df_dev.shape[0]):
    seq = df_dev.iloc[epoch, 6]
    # Input: all symbols but the last, with an extra dimension 1 added.
    sentence_in = prepare_sequence(seq).unsqueeze(1).to(device=device)
    # Targets: all symbols but the first.
    targets = prepare_labels(seq).to(device=device)

    optimizer.zero_grad()
    output = model(sentence_in)

    print("targets size: ", targets.size())
    # Flatten the logits to (positions, ntokens) to line up with the 1-D targets.
    loss = criterion(output.view(-1, ntokens), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

    if not epoch % print_every:
        print("At Epoch: %.1f" % epoch)
        print("Loss %.4f" % loss)
    loss_vector.append(loss)
    # Only the first sequence is processed in this cell.
    break
  


targets size:  torch.Size([115])
At Epoch: 0.0
Loss 4.3648


In [9]:
# Main training pass: one SGD step per row of df_dev, in a single pass.
start_time = time.time()
print_every = 1000
checkpoint_path = "../data/transformer_encoder_201012.pt"

# NOTE: ``epoch`` is a per-sequence step counter, not a pass over the data;
# the name is kept so other cells that read it keep working.
for epoch in np.arange(0, df_dev.shape[0]):
    # NOTE(review): column 6 presumably holds the sequence string -- confirm.
    seq = df_dev.iloc[epoch, 6]
    if len(seq) < 2:
        # At least two symbols are needed to form one (input, target) pair;
        # shorter sequences would give empty input and target tensors.
        continue

    # Input: all symbols but the last, with an extra dimension 1 added.
    sentence_in = prepare_sequence(seq).unsqueeze(1).to(device=device)
    # Targets: all symbols but the first.
    targets = prepare_labels(seq).to(device=device)

    optimizer.zero_grad()
    output = model(sentence_in)

    # Flatten the logits to (positions, ntokens) to line up with the 1-D targets.
    loss = criterion(output.view(-1, ntokens), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

    if epoch % print_every == 0:
        print("At Epoch: %.1f" % epoch)
        print("Loss %.4f" % loss.item())
        elapsed = time.time() - start_time
        print("time elapsed %.4f" % elapsed)
        # Periodic checkpoint; each save overwrites the previous one.
        torch.save(model.state_dict(), checkpoint_path)

  

At Epoch: 0.0
Loss 8.4750
time elapsed 0.0296
At Epoch: 1000.0
Loss 3.0725
time elapsed 23.4438
At Epoch: 2000.0
Loss 3.2927
time elapsed 45.1294
At Epoch: 3000.0
Loss 2.9190
time elapsed 66.5057
At Epoch: 4000.0
Loss 3.0607
time elapsed 87.3455
At Epoch: 5000.0
Loss 2.9439
time elapsed 108.5937
At Epoch: 6000.0
Loss 3.0892
time elapsed 129.0390
At Epoch: 7000.0
Loss 3.0167
time elapsed 149.5365
At Epoch: 8000.0
Loss 3.0786
time elapsed 170.1348
At Epoch: 9000.0
Loss 2.8547
time elapsed 193.8191
At Epoch: 10000.0
Loss 3.0577
time elapsed 215.4009
At Epoch: 11000.0
Loss 2.8724
time elapsed 237.3885
At Epoch: 12000.0
Loss 2.9954
time elapsed 258.8655
At Epoch: 13000.0
Loss 2.9179
time elapsed 279.4901
At Epoch: 14000.0
Loss 2.9298
time elapsed 300.0160
At Epoch: 15000.0
Loss 2.9484
time elapsed 321.0145
At Epoch: 16000.0
Loss 2.8884
time elapsed 342.1072
At Epoch: 17000.0
Loss 2.8484
time elapsed 363.1312
At Epoch: 18000.0
Loss 2.8564
time elapsed 384.3074
At Epoch: 19000.0
Loss 2.8983
t

At Epoch: 155000.0
Loss 2.9901
time elapsed 3300.3507
At Epoch: 156000.0
Loss 3.0519
time elapsed 3322.6611
At Epoch: 157000.0
Loss 2.9507
time elapsed 3344.3627
At Epoch: 158000.0
Loss 3.1861
time elapsed 3365.4376
At Epoch: 159000.0
Loss 2.9261
time elapsed 3387.0268
At Epoch: 160000.0
Loss 2.8583
time elapsed 3408.6589
At Epoch: 161000.0
Loss 3.4816
time elapsed 3430.0375
At Epoch: 162000.0
Loss 2.9583
time elapsed 3450.8343
At Epoch: 163000.0
Loss 2.9407
time elapsed 3472.6823
At Epoch: 164000.0
Loss 3.3737
time elapsed 3493.4988
At Epoch: 165000.0
Loss 3.0869
time elapsed 3513.9564
At Epoch: 166000.0
Loss 2.9903
time elapsed 3534.4022
At Epoch: 167000.0
Loss 2.7824
time elapsed 3555.2207
At Epoch: 168000.0
Loss 2.9817
time elapsed 3576.8597
At Epoch: 169000.0
Loss 3.4274
time elapsed 3599.9424
At Epoch: 170000.0
Loss 3.1884
time elapsed 3623.3197
At Epoch: 171000.0
Loss 3.0298
time elapsed 3645.1108
At Epoch: 172000.0
Loss 3.0394
time elapsed 3666.0750
At Epoch: 173000.0
Loss 3.12

At Epoch: 307000.0
Loss 2.9097
time elapsed 6541.1349
At Epoch: 308000.0
Loss 3.0504
time elapsed 6562.0790
At Epoch: 309000.0
Loss 2.9092
time elapsed 6582.8036
At Epoch: 310000.0
Loss 3.0297
time elapsed 6605.1886
At Epoch: 311000.0
Loss 3.0360
time elapsed 6627.4690
At Epoch: 312000.0
Loss 3.0126
time elapsed 6649.3154
At Epoch: 313000.0
Loss 3.1299
time elapsed 6670.2051
At Epoch: 314000.0
Loss 2.9979
time elapsed 6691.4245
At Epoch: 315000.0
Loss 3.2214
time elapsed 6712.2494
At Epoch: 316000.0
Loss 3.2847
time elapsed 6733.4585
At Epoch: 317000.0
Loss 3.0457
time elapsed 6754.4249
At Epoch: 318000.0
Loss 3.0197
time elapsed 6775.4901
At Epoch: 319000.0
Loss 3.0796
time elapsed 6796.3298
At Epoch: 320000.0
Loss 2.9445
time elapsed 6819.6741
At Epoch: 321000.0
Loss 2.9948
time elapsed 6842.8050
At Epoch: 322000.0
Loss 3.0061
time elapsed 6864.4155
At Epoch: 323000.0
Loss 3.0284
time elapsed 6885.2709
At Epoch: 324000.0
Loss 3.2249
time elapsed 6906.6269
At Epoch: 325000.0
Loss 2.91

At Epoch: 459000.0
Loss 3.0351
time elapsed 9775.3689
At Epoch: 460000.0
Loss 3.1621
time elapsed 9796.3341
At Epoch: 461000.0
Loss 3.0106
time elapsed 9817.5880
At Epoch: 462000.0
Loss 3.0930
time elapsed 9838.3271
At Epoch: 463000.0
Loss 2.9860
time elapsed 9859.1370
At Epoch: 464000.0
Loss 3.1868
time elapsed 9879.9488
At Epoch: 465000.0
Loss 2.9623
time elapsed 9902.2903
At Epoch: 466000.0
Loss 2.9450
time elapsed 9924.3228
At Epoch: 467000.0
Loss 3.1079
time elapsed 9945.3653
At Epoch: 468000.0
Loss 3.2383
time elapsed 9966.4299
At Epoch: 469000.0
Loss 3.1915
time elapsed 9987.2209
At Epoch: 470000.0
Loss 3.0520
time elapsed 10008.0602
At Epoch: 471000.0
Loss 2.8817
time elapsed 10029.2279
At Epoch: 472000.0
Loss 2.7903
time elapsed 10049.8056
At Epoch: 473000.0
Loss 3.0450
time elapsed 10070.4982
At Epoch: 474000.0
Loss 3.0319
time elapsed 10091.7836
At Epoch: 475000.0
Loss 2.8645
time elapsed 10112.4572
At Epoch: 476000.0
Loss 2.9519
time elapsed 10133.5103
At Epoch: 477000.0
Lo

At Epoch: 609000.0
Loss 3.0097
time elapsed 12974.7748
At Epoch: 610000.0
Loss 2.9496
time elapsed 12995.4356
At Epoch: 611000.0
Loss 3.0796
time elapsed 13016.1354
At Epoch: 612000.0
Loss 3.2714
time elapsed 13036.9276
At Epoch: 613000.0
Loss 3.0017
time elapsed 13057.7669
At Epoch: 614000.0
Loss 3.0042
time elapsed 13078.5821
At Epoch: 615000.0
Loss 3.0605
time elapsed 13099.6022
At Epoch: 616000.0
Loss 3.0771
time elapsed 13122.1544
At Epoch: 617000.0
Loss 3.0882
time elapsed 13144.3192
At Epoch: 618000.0
Loss 3.0673
time elapsed 13165.8255
At Epoch: 619000.0
Loss 3.0015
time elapsed 13188.8262
At Epoch: 620000.0
Loss 2.9510
time elapsed 13212.1374
At Epoch: 621000.0
Loss 3.0618
time elapsed 13235.3811
At Epoch: 622000.0
Loss 3.0854
time elapsed 13259.0063
At Epoch: 623000.0
Loss 2.9219
time elapsed 13281.8440
At Epoch: 624000.0
Loss 3.1809
time elapsed 13303.9608
At Epoch: 625000.0
Loss 3.3179
time elapsed 13324.8973
At Epoch: 626000.0
Loss 3.0477
time elapsed 13345.9977
At Epoch: 

At Epoch: 758000.0
Loss 2.9175
time elapsed 16169.1145
At Epoch: 759000.0
Loss 2.8408
time elapsed 16190.0553
At Epoch: 760000.0
Loss 3.0965
time elapsed 16211.9299
At Epoch: 761000.0
Loss 2.9654
time elapsed 16232.7546
At Epoch: 762000.0
Loss 2.8741
time elapsed 16253.6610
At Epoch: 763000.0
Loss 3.0440
time elapsed 16274.4400
At Epoch: 764000.0
Loss 3.0382
time elapsed 16295.3461
At Epoch: 765000.0
Loss 2.9673
time elapsed 16316.4184
At Epoch: 766000.0
Loss 3.1912
time elapsed 16337.4174
At Epoch: 767000.0
Loss 2.9253
time elapsed 16358.7169
At Epoch: 768000.0
Loss 3.2451
time elapsed 16381.1928
At Epoch: 769000.0
Loss 3.0235
time elapsed 16402.2664
At Epoch: 770000.0
Loss 3.1147
time elapsed 16423.2226
At Epoch: 771000.0
Loss 3.0430
time elapsed 16445.8516
At Epoch: 772000.0
Loss 3.1818
time elapsed 16470.2880
At Epoch: 773000.0
Loss 2.9886
time elapsed 16491.9517
At Epoch: 774000.0
Loss 3.0105
time elapsed 16512.6158
At Epoch: 775000.0
Loss 2.9727
time elapsed 16533.1834
At Epoch: 

At Epoch: 907000.0
Loss 3.1337
time elapsed 19362.9045
At Epoch: 908000.0
Loss 3.3058
time elapsed 19384.3172
At Epoch: 909000.0
Loss 2.9737
time elapsed 19406.2345
At Epoch: 910000.0
Loss 3.2482
time elapsed 19427.1677
At Epoch: 911000.0
Loss 2.9501
time elapsed 19448.0542
At Epoch: 912000.0
Loss 2.9741
time elapsed 19468.9023
At Epoch: 913000.0
Loss 2.9041
time elapsed 19489.6929
At Epoch: 914000.0
Loss 3.1414
time elapsed 19510.5792
At Epoch: 915000.0
Loss 2.9654
time elapsed 19531.9944
At Epoch: 916000.0
Loss 3.0026
time elapsed 19554.2332
At Epoch: 917000.0
Loss 3.1066
time elapsed 19575.5949
At Epoch: 918000.0
Loss 2.9325
time elapsed 19596.5881
At Epoch: 919000.0
Loss 2.9819
time elapsed 19617.3959
At Epoch: 920000.0
Loss 2.8544
time elapsed 19638.2291
At Epoch: 921000.0
Loss 3.0360
time elapsed 19659.1393
At Epoch: 922000.0
Loss 2.8270
time elapsed 19680.2341
At Epoch: 923000.0
Loss 3.0087
time elapsed 19701.2047
At Epoch: 924000.0
Loss 2.9978
time elapsed 19722.1817
At Epoch: 

At Epoch: 1055000.0
Loss 2.9863
time elapsed 22519.7654
At Epoch: 1056000.0
Loss 2.9858
time elapsed 22543.9701
At Epoch: 1057000.0
Loss 3.1834
time elapsed 22568.1609
At Epoch: 1058000.0
Loss 2.8317
time elapsed 22592.0220
At Epoch: 1059000.0
Loss 2.8598
time elapsed 22613.9067
At Epoch: 1060000.0
Loss 3.0616
time elapsed 22634.8585
At Epoch: 1061000.0
Loss 3.1385
time elapsed 22655.6932
At Epoch: 1062000.0
Loss 3.3219
time elapsed 22676.5380
At Epoch: 1063000.0
Loss 2.8477
time elapsed 22697.2584
At Epoch: 1064000.0
Loss 2.7897
time elapsed 22718.3303
At Epoch: 1065000.0
Loss 3.1456
time elapsed 22741.9437
At Epoch: 1066000.0
Loss 2.9621
time elapsed 22763.9049
At Epoch: 1067000.0
Loss 3.1099
time elapsed 22784.6238
At Epoch: 1068000.0
Loss 3.3556
time elapsed 22806.1021
At Epoch: 1069000.0
Loss 3.1260
time elapsed 22826.7791
At Epoch: 1070000.0
Loss 3.1175
time elapsed 22850.9238
At Epoch: 1071000.0
Loss 3.0189
time elapsed 22872.5136
At Epoch: 1072000.0
Loss 2.9729
time elapsed 228

At Epoch: 1202000.0
Loss 2.7824
time elapsed 25695.7325
At Epoch: 1203000.0
Loss 3.2458
time elapsed 25718.9835
At Epoch: 1204000.0
Loss 2.8969
time elapsed 25740.3366
At Epoch: 1205000.0
Loss 3.0189
time elapsed 25761.3719
At Epoch: 1206000.0
Loss 2.8373
time elapsed 25782.1579
At Epoch: 1207000.0
Loss 3.0137
time elapsed 25804.6218
At Epoch: 1208000.0
Loss 3.0992
time elapsed 25826.0223
At Epoch: 1209000.0
Loss 2.9757
time elapsed 25847.0148
At Epoch: 1210000.0
Loss 2.9546
time elapsed 25868.0455
At Epoch: 1211000.0
Loss 3.0597
time elapsed 25890.3536
At Epoch: 1212000.0
Loss 3.0609
time elapsed 25911.0759


In [11]:
# Save the final weights (state_dict only) to the same path used for the
# periodic checkpoints written during training.
torch.save(model.state_dict(), "../data/transformer_encoder_201012.pt")

In [12]:
# Completion marker for the notebook run.
print("done")

done
