In [2]:
# import esm
import torch
from argparse import Namespace
from esm.constants import proteinseq_toks
import math
import torch.nn as nn
import torch.nn.functional as F
from esm.modules import TransformerLayer, PositionalEmbedding  # noqa
from esm.model import ProteinBertModel

# model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")
import esm

In [3]:
from ych_util import prepare_mlm_mask
import pandas as pd
import time

In [4]:
kinesin_labelled = pd.read_csv("../../data/esm/kinesin_labelled.csv")
kinesin_labelled = kinesin_labelled.sample(frac = 1)
kinesin_labelled.head()

Unnamed: 0,Entry,Entry name,Status,Protein names,Gene names,Organism,Length,seq,type,label
720,Q9USU8,NGG1_SCHPO,reviewed,Chromatin-remodeling complexes subunit ngg1 (K...,ngg1 ada3 kap1 SPBC28F2.10c,Schizosaccharomyces pombe (strain 972 / ATCC 2...,551,MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVST...,kinesin,unlabeled
706,K7U9N8,OP1_MAIZE,reviewed,Protein OPAQUE1 (Myosin XI motor protein),O1 ZEAMMB73_923224 Zm.5032,Zea mays (Maize),1024,MSYRKGLKVWVEEKGEGWVEAEVVEAKERAVVVFSSQRKKITVSPE...,kinesin,unlabeled
969,O14656,TOR1A_HUMAN,reviewed,Torsin-1A (Dystonia 1 protein) (Torsin ATPase-...,TOR1A DQ2 DYT1 TA TORA,Homo sapiens (Human),332,MKLGRAVLGLLLLAPSVVQAVEPISLGLALAGVLTGYIYPRLYCLF...,kinesin,unlabeled
934,Q99K43,PRC1_MOUSE,reviewed,Protein regulator of cytokinesis 1,Prc1,Mus musculus (Mouse),603,MRRSEVLADESITCLQKALTHLREIWELIGIPEEQRLQRTEVVKKH...,kinesin,unlabeled
318,Q63850,NUP62_MOUSE,reviewed,Nuclear pore glycoprotein p62 (62 kDa nucleopo...,Nup62,Mus musculus (Mouse),526,MSGFNFGGTGAPAGGFTFGTAKTATTTPATGFSFSASGTGTGGFNF...,kinesin,unlabeled


In [5]:
alphabet = esm.Alphabet.from_dict(proteinseq_toks)
# model_name = "esm1_t34_670M_UR50S"
model_name = "esm1_t12_85M_UR50S"
url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"

In [6]:
if torch.cuda.is_available():
    print("cuda")
    model_data = torch.hub.load_state_dict_from_url(url, progress=False)
else:
    model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location=torch.device('cpu'))

cuda


In [7]:
pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}

In [8]:
model = esm.ProteinBertModel(
        Namespace(**model_args), len(alphabet), padding_idx=alphabet.padding_idx
    )

model.load_state_dict(model_state)

<All keys matched successfully>

In [9]:
# model.load_state_dict(torch.load("../../data/esm1_t12_85M_UR50S_balanced_201102.pt"))

In [10]:
model.cuda()
model.train()

ProteinBertModel(
  (embed_tokens): Embedding(35, 768, padding_idx=1)
  (embed_positions): PositionalEmbedding()
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): BertLayerNorm()
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): BertLayerNorm()
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=Tr

In [11]:
batch_converter = alphabet.get_batch_converter()

In [12]:
criterion = nn.CrossEntropyLoss()
lr = 0.0001 # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)


In [17]:
start_time = time.time()
print_every = 10000
for j in range(50):
    for i in range(kinesin_labelled.shape[0]):
        if len(kinesin_labelled.iloc[i,7])>1024:
            continue
        data = [(kinesin_labelled.iloc[i,0], kinesin_labelled.iloc[i,7])]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        true_aa,target_ind,masked_batch_tokens = prepare_mlm_mask(alphabet,batch_tokens)
        optimizer.zero_grad()
        results = model(masked_batch_tokens.to('cuda'), repr_layers=[34])   
        pred = results["logits"].squeeze(0)[target_ind,:]   
        target = true_aa.squeeze(0)
        loss = criterion(pred.cpu(),target)
        loss.backward()
        optimizer.step()

        if i % print_every == 0:
            print(batch_labels)
            print(batch_strs)
            print(batch_tokens.size())
            print(masked_batch_tokens.size())
            print(results["logits"].size())
            print(pred.size())
            print(target.size())
            print(f"At Epoch: %.1f"% i)
            print(f"Loss %.4f"% loss)
            elapsed = time.time() - start_time
            print(f"time elapsed %.4f"% elapsed)
            torch.save(model.state_dict(), "../../data/esm1_t12_85M_UR50S_kinesin_201102.pt")
    #     loss_vector.append(loss)
    #     break

['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSYNIPPLGERYYDLTPEDEMTNLCANSIYQNLQTSAQGSLEAFNEADTVSEEVRCGPLTERLMASLIPCYTQNDEEQKPSIAVGEFAETDSGSEKSKIGTSIDGIESGNNEYTEQPDIQESSLSICEDRLRYTLKQLGILYDGDVDWSKRQDDEISATLRSLNARLKVVSDENEKMRNALLQMLPEEMAFQEFQNVMDDLDKQIEQAYVKRNRSLKVKKKRIVTDKIGSSATSGSFPVIKSLMDKRSMWLEKLQPLFQDKLTQHLGSPTSIFNDLSDHTTSNYSTSV']
torch.Size([1, 552])
torch.Size([1, 552])
torch.Size([1, 552, 35])
torch.Size([82, 35])
torch.Size([82])
At Epoch: 0.0
Loss 2.6645
time elapsed 0.3691
['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSYNIP

['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSYNIPPLGERYYDLTPEDEMTNLCANSIYQNLQTSAQGSLEAFNEADTVSEEVRCGPLTERLMASLIPCYTQNDEEQKPSIAVGEFAETDSGSEKSKIGTSIDGIESGNNEYTEQPDIQESSLSICEDRLRYTLKQLGILYDGDVDWSKRQDDEISATLRSLNARLKVVSDENEKMRNALLQMLPEEMAFQEFQNVMDDLDKQIEQAYVKRNRSLKVKKKRIVTDKIGSSATSGSFPVIKSLMDKRSMWLEKLQPLFQDKLTQHLGSPTSIFNDLSDHTTSNYSTSV']
torch.Size([1, 552])
torch.Size([1, 552])
torch.Size([1, 552, 35])
torch.Size([82, 35])
torch.Size([82])
At Epoch: 0.0
Loss 2.5473
time elapsed 5272.5798
['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSY

['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSYNIPPLGERYYDLTPEDEMTNLCANSIYQNLQTSAQGSLEAFNEADTVSEEVRCGPLTERLMASLIPCYTQNDEEQKPSIAVGEFAETDSGSEKSKIGTSIDGIESGNNEYTEQPDIQESSLSICEDRLRYTLKQLGILYDGDVDWSKRQDDEISATLRSLNARLKVVSDENEKMRNALLQMLPEEMAFQEFQNVMDDLDKQIEQAYVKRNRSLKVKKKRIVTDKIGSSATSGSFPVIKSLMDKRSMWLEKLQPLFQDKLTQHLGSPTSIFNDLSDHTTSNYSTSV']
torch.Size([1, 552])
torch.Size([1, 552])
torch.Size([1, 552, 35])
torch.Size([82, 35])
torch.Size([82])
At Epoch: 0.0
Loss 2.4531
time elapsed 10540.6107
['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSS

['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSYNIPPLGERYYDLTPEDEMTNLCANSIYQNLQTSAQGSLEAFNEADTVSEEVRCGPLTERLMASLIPCYTQNDEEQKPSIAVGEFAETDSGSEKSKIGTSIDGIESGNNEYTEQPDIQESSLSICEDRLRYTLKQLGILYDGDVDWSKRQDDEISATLRSLNARLKVVSDENEKMRNALLQMLPEEMAFQEFQNVMDDLDKQIEQAYVKRNRSLKVKKKRIVTDKIGSSATSGSFPVIKSLMDKRSMWLEKLQPLFQDKLTQHLGSPTSIFNDLSDHTTSNYSTSV']
torch.Size([1, 552])
torch.Size([1, 552])
torch.Size([1, 552, 35])
torch.Size([82, 35])
torch.Size([82])
At Epoch: 0.0
Loss 2.0437
time elapsed 14566.7686
['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSS

['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSSYNIPPLGERYYDLTPEDEMTNLCANSIYQNLQTSAQGSLEAFNEADTVSEEVRCGPLTERLMASLIPCYTQNDEEQKPSIAVGEFAETDSGSEKSKIGTSIDGIESGNNEYTEQPDIQESSLSICEDRLRYTLKQLGILYDGDVDWSKRQDDEISATLRSLNARLKVVSDENEKMRNALLQMLPEEMAFQEFQNVMDDLDKQIEQAYVKRNRSLKVKKKRIVTDKIGSSATSGSFPVIKSLMDKRSMWLEKLQPLFQDKLTQHLGSPTSIFNDLSDHTTSNYSTSV']
torch.Size([1, 552])
torch.Size([1, 552])
torch.Size([1, 552, 35])
torch.Size([82, 35])
torch.Size([82])
At Epoch: 0.0
Loss 1.0917
time elapsed 18035.4885
['Q9USU8']
['MSSEQQNEADSKPAVIPQCFKIENQYETFSRLSETSTPGVVPSVSTLWRLLFELQKMIECEPSCVEYFRQRKEELESHVDSEIETSKDESSVNKVEEKVEEFKEDNVEQEIKQKRSLSESPQESMLEKVSKKPKVSEAHNEEISPENVETIENELDLPVKGKDEQTTGLVYKNANDLLTGSLLSFIVDDSFSYEQKKKLLCVDSFPTSDVRSLVAGTPATDDFSHNKPNNQISISTFYSSLDPYFRAFNDDDIAFLKKGFDVSSS

In [18]:
torch.save(model.state_dict(), "../../data/esm1_t12_85M_UR50S_kinesin_201102.pt")

In [19]:
print("done")

done
