In [1]:
# import esm
import torch
from argparse import Namespace
from esm.constants import proteinseq_toks
import math
import torch.nn as nn
import torch.nn.functional as F
from esm.modules import TransformerLayer, PositionalEmbedding  # noqa
from esm.model import ProteinBertModel

# model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")
import esm

In [2]:
from ych_util import prepare_mlm_mask
import pandas as pd
import time

In [3]:
motor_toolkit = pd.read_csv("../../data/esm/motor_toolkit.csv")
motor_toolkit = motor_toolkit.sample(frac = 1)
motor_toolkit.head()

Unnamed: 0,Entry,Entry name,Status,Protein names,Gene names,Organism,Length,seq,type
1641,Q0P5A1,DCTN3_BOVIN,reviewed,Dynactin subunit 3,DCTN3,Bos taurus (Bovine),186,MAAVTDVQRLQARVEELERWVYGPGGSRGSRKVADGLVKVQVALGN...,dynein
730,Q6URW6,MYH14_MOUSE,reviewed,Myosin-14 (Myosin heavy chain 14) (Myosin heav...,Myh14,Mus musculus (Mouse),1024,MAAVTMSVSGRKVASRPGPVPEAAQSFLYAPRTPNVGGPGGPQVEW...,kinesin
645,Q63HQ0,AP1AR_HUMAN,reviewed,AP-1 complex-associated regulatory protein (2c...,AP1AR C4orf16 PRO0971,Homo sapiens (Human),302,MGNCCWTQCFGLLRKEAGRLQRVGGGGGSKYFRTCSRGEHLTIEFE...,kinesin
1133,Q5R9P5,GCR_PONAB,reviewed,Glucocorticoid receptor (GR) (Nuclear receptor...,NR3C1 GRL,Pongo abelii (Sumatran orangutan) (Pongo pygma...,777,MDSKESLTPGREENPSSVLAQERGNVMDFYKTLRGGATVKVSASSP...,dynein
2707,P07313,MYLK2_RABIT,reviewed,"Myosin light chain kinase 2, skeletal/cardiac ...",MYLK2,Oryctolagus cuniculus (Rabbit),608,MATENGAVELGIQSLSTDEASKGAASEESLAAEKDPAPPDPEKGPG...,myosin_v


In [4]:
alphabet = esm.Alphabet.from_dict(proteinseq_toks)
# model_name = "esm1_t34_670M_UR50S"
model_name = "esm1_t12_85M_UR50S"
url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"

In [5]:
if torch.cuda.is_available():
    print("cuda")
    model_data = torch.hub.load_state_dict_from_url(url, progress=False)
else:
    model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location=torch.device('cpu'))

cuda


In [6]:
pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}

In [7]:
model = esm.ProteinBertModel(
        Namespace(**model_args), len(alphabet), padding_idx=alphabet.padding_idx
    )

model.load_state_dict(model_state)

<All keys matched successfully>

In [8]:
# model.load_state_dict(torch.load("../../data/esm1_t12_85M_UR50S_balanced_201102.pt"))

In [9]:
model.cuda()
model.train()

ProteinBertModel(
  (embed_tokens): Embedding(35, 768, padding_idx=1)
  (embed_positions): PositionalEmbedding()
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): BertLayerNorm()
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): BertLayerNorm()
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=Tr

In [10]:
batch_converter = alphabet.get_batch_converter()

In [11]:
criterion = nn.CrossEntropyLoss()
lr = 0.0001 # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)


In [12]:
start_time = time.time()
print_every = 1000
for j in range(10):
    for i in range(motor_toolkit.shape[0]):
        if len(motor_toolkit.iloc[i,7])>1024:
            continue
        data = [(motor_toolkit.iloc[i,0], motor_toolkit.iloc[i,7])]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        true_aa,target_ind,masked_batch_tokens = prepare_mlm_mask(alphabet,batch_tokens)
        optimizer.zero_grad()
        results = model(masked_batch_tokens.to('cuda'), repr_layers=[34])   
        pred = results["logits"].squeeze(0)[target_ind,:]   
        target = true_aa.squeeze(0)
        loss = criterion(pred.cpu(),target)
        loss.backward()
        optimizer.step()

        if i % print_every == 0:
            print(batch_labels)
            print(batch_strs)
            print(batch_tokens.size())
            print(masked_batch_tokens.size())
            print(results["logits"].size())
            print(pred.size())
            print(target.size())
            print(f"At Epoch: %.1f"% i)
            print(f"Loss %.4f"% loss)
            elapsed = time.time() - start_time
            print(f"time elapsed %.4f"% elapsed)
            torch.save(model.state_dict(), "../../data/esm1_t12_85M_UR50S_motor_toolkit_201102.pt")
    #     loss_vector.append(loss)
    #     break

['Q0P5A1']
['MAAVTDVQRLQARVEELERWVYGPGGSRGSRKVADGLVKVQVALGNIASKRERVKVLYKKIEDLIKYLDPEYIDRIALPDASKLQFILAEEQFILSQVALLEQVEALVPMLDSTHIKAVPEHAARLQRLAQIHIQQQDQCVEITEESKALLEEYNKTTMLLSKQFVQWDELLCQLEAAKQVKPVEE']
torch.Size([1, 187])
torch.Size([1, 187])
torch.Size([1, 187, 35])
torch.Size([27, 35])
torch.Size([27])
At Epoch: 0.0
Loss 2.6544
time elapsed 0.2296
['P18266']
['MSGRPRTTSFAESCKPVQQPSAFGSMKVSRDKDGSKVTTVVATPGQGPDRPQEVSYTDTKVIGNGSFGVVYQAKLCDSGELVAIKKVLQDKRFKNRELQIMRKLDHCNIVRLRYFFYSSGEKKDEVYLNLVLDYVPETVYRVARHYSRAKQTLPVIYVKLYMYQLFRSLAYIHSFGICHRDIKPQNLLLDPDTAVLKLCDFGSAKQLVRGEPNVSYICSRYYRAPELIFGATDYTSSIDMWSAGCVLAELLLGQPIFPGDSGVDQLVEIIKVLGTPTREQIREMNPNYTEFKFPQIKAHPWTKVFRPRTPPEAIALCSRLLEYTPTARLTPLEACAHSFFDELRDPNVKLPNGRDTPALFNFTTQELSSNPPLATILIPPHARIQAAASPPANATAASDTNAGDRGQTNNAASASASNST']
torch.Size([1, 421])
torch.Size([1, 421])
torch.Size([1, 421, 35])
torch.Size([63, 35])
torch.Size([63])
At Epoch: 1000.0
Loss 2.8208
time elapsed 253.8363
['Q9BY41']
['MEEPEEPADSGQSLVPVYIYSPEYVSMCDSLAKIPKRASMVH

['P18266']
['MSGRPRTTSFAESCKPVQQPSAFGSMKVSRDKDGSKVTTVVATPGQGPDRPQEVSYTDTKVIGNGSFGVVYQAKLCDSGELVAIKKVLQDKRFKNRELQIMRKLDHCNIVRLRYFFYSSGEKKDEVYLNLVLDYVPETVYRVARHYSRAKQTLPVIYVKLYMYQLFRSLAYIHSFGICHRDIKPQNLLLDPDTAVLKLCDFGSAKQLVRGEPNVSYICSRYYRAPELIFGATDYTSSIDMWSAGCVLAELLLGQPIFPGDSGVDQLVEIIKVLGTPTREQIREMNPNYTEFKFPQIKAHPWTKVFRPRTPPEAIALCSRLLEYTPTARLTPLEACAHSFFDELRDPNVKLPNGRDTPALFNFTTQELSSNPPLATILIPPHARIQAAASPPANATAASDTNAGDRGQTNNAASASASNST']
torch.Size([1, 421])
torch.Size([1, 421])
torch.Size([1, 421, 35])
torch.Size([63, 35])
torch.Size([63])
At Epoch: 1000.0
Loss 2.2273
time elapsed 4011.1810
['Q9BY41']
['MEEPEEPADSGQSLVPVYIYSPEYVSMCDSLAKIPKRASMVHSLIEAYALHKQMRIVKPKVASMEEMATFHTDAYLQHLQKVSQEGDDDHPDSIEYGLGYDCPATEGIFDYAAAIGGATITAAQCLIDGMCKVAINWSGGWHHAKKDEASGFCYLNDAVLGILRLRRKFERILYVDLDLHHGDGVEDAFSFTSKVMTVSLHKFSPGFFPGTGDVSDVGLGKGRYYSVNVPIQDGIQDEKYYQICESVLKEVYQAFNPKAVVLQLGADTIAGDPMCSFNMTPVGIGKCLKYILQWQLATLILGGGGYNLANTARCWTYLTGVILGKTLSSEIPDHEFFTAYGPDYVLEITPSCRPDRNEPHRIQQILNYIKGNLKHVV']
torch.Size([1,

['Q9BY41']
['MEEPEEPADSGQSLVPVYIYSPEYVSMCDSLAKIPKRASMVHSLIEAYALHKQMRIVKPKVASMEEMATFHTDAYLQHLQKVSQEGDDDHPDSIEYGLGYDCPATEGIFDYAAAIGGATITAAQCLIDGMCKVAINWSGGWHHAKKDEASGFCYLNDAVLGILRLRRKFERILYVDLDLHHGDGVEDAFSFTSKVMTVSLHKFSPGFFPGTGDVSDVGLGKGRYYSVNVPIQDGIQDEKYYQICESVLKEVYQAFNPKAVVLQLGADTIAGDPMCSFNMTPVGIGKCLKYILQWQLATLILGGGGYNLANTARCWTYLTGVILGKTLSSEIPDHEFFTAYGPDYVLEITPSCRPDRNEPHRIQQILNYIKGNLKHVV']
torch.Size([1, 378])
torch.Size([1, 378])
torch.Size([1, 378, 35])
torch.Size([56, 35])
torch.Size([56])
At Epoch: 2000.0
Loss 2.7920
time elapsed 8188.3776
['A5DKH0']
['MAIVKRGARSKAKQEAPAKSGIKKAEFDLHKKKEVGVSDLTLLSKISDDSINDNLHKRFMNNTIYTYIGHVLISVNPFQDLGIYTKEYLNMYKGKNRLEVPPHVFAIAESMYYHLKSYGESQCVIISGESGAGKTEAAKQIMQYIANVSVDDKVSTTSEITQIKDMVLATNPLLESFGCAKTLRNNNSSRHGKYLEIYFNPSNYQPVAAHITNYLLEKQRVVSQITNERNFHIFYQLTKSCPPEYKQSFGLQGPETYVYTSAAKCIDVEGINDGKDFAETLQAMNTIGLSKAEQDNIFRSLASILWIGNISFVENEDGNAAIRDDTVTTFVAYLLEVDANVLKKSILERVIETSHGMRRGSTYHVPLNIVQATASRDALAKGIYNYLFDWIVERVNISLRGRAEAMEKKTIGILDIYGFEIFEHNSFEQICINYVNE

['A5DKH0']
['MAIVKRGARSKAKQEAPAKSGIKKAEFDLHKKKEVGVSDLTLLSKISDDSINDNLHKRFMNNTIYTYIGHVLISVNPFQDLGIYTKEYLNMYKGKNRLEVPPHVFAIAESMYYHLKSYGESQCVIISGESGAGKTEAAKQIMQYIANVSVDDKVSTTSEITQIKDMVLATNPLLESFGCAKTLRNNNSSRHGKYLEIYFNPSNYQPVAAHITNYLLEKQRVVSQITNERNFHIFYQLTKSCPPEYKQSFGLQGPETYVYTSAAKCIDVEGINDGKDFAETLQAMNTIGLSKAEQDNIFRSLASILWIGNISFVENEDGNAAIRDDTVTTFVAYLLEVDANVLKKSILERVIETSHGMRRGSTYHVPLNIVQATASRDALAKGIYNYLFDWIVERVNISLRGRAEAMEKKTIGILDIYGFEIFEHNSFEQICINYVNEKLQQIFIQLTLKAEQDEYVQEQIKWTPIDYFNNKVVCDLIEATRPQPGLFAALNDSIKTAHADSDAADQVFAQRLSMVGANNRHFEDRKGKFIIKHYAGDVVYDVAGMTDKNKDAMLRDLLEMLSTSQNTFVNSVLFPPDLLAVLTDKKKRPETASDKIKKSANLLVDTLSQCQPSYIRTIKPNQTKRPKEYDNAQVLHQVKYLGLKENVRIRRAGFAYRTTFDKFVQRFYLLSPKTGYAGDYIWNGDDISAVREILKSCHIPDTEFQMGTSKVFIKTPETLFAMEDMRDKYWHNMAARIQRAWRRYVKRKEDAARLIQNAWKVKKHGNQFEQLRDYGNGLLQGRKERRRMSMLGSRAFMGDYLGCNYSSGFGRFVLNQVGLNEHVVFSGKGEILLSKFGRSSKRLPRIFVLGRSSLYIIAENLVERRLQLSKEFVIPINSINYVGLSTFQDNWLAVSLHSPTPTTPDVLINLDFKTELVTHLKKLNPGLTIKIGPTIEYQKKPGKFHTVKFVRSDVSTIPIHGDVYKSGTVSVRPGLSPDSQNPKRPRA

In [13]:
torch.save(model.state_dict(), "../../data/esm1_t12_85M_UR50S_motor_toolkit_201102.pt")

In [14]:
print("done")

done
