In [1]:
# import esm
import torch
from argparse import Namespace
from esm.constants import proteinseq_toks
import math
import torch.nn as nn
import torch.nn.functional as F
from esm.modules import TransformerLayer, PositionalEmbedding  # noqa
from esm.model import ProteinBertModel

# model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")
import esm

In [2]:
from ych_util import prepare_mlm_mask
import pandas as pd
import time

In [3]:
# Fixed seed so the shuffled row order (and therefore the order in which
# sequences are seen during training) is the same on every run.
SHUFFLE_SEED = 0

# Read and shuffle in one expression so re-running this cell always yields the
# same frame instead of re-shuffling an already shuffled one.
pfamA_balanced = (
    pd.read_csv("../../data/esm/pfamA_motors_balanced.csv")
    .sample(frac=1, random_state=SHUFFLE_SEED)
)
pfamA_balanced.head()

Unnamed: 0.1,Unnamed: 0,id,description,seq,pfamA_acc,clan_x,pfamA_name
8049,318809,A0A1V0A8E6_9ACTN/194-355,A0A1V0A8E6_9ACTN/194-355 A0A1V0A8E6.1 PF13304....,TPHPHHRTACSTMRWFERGPGGDQLSLSDLEEDRSFEADRAQALAL...,PF13304,p_loop_gtpase,AAA_21
7421,1483573,A0A1G7KN25_9PROT/48-166,A0A1G7KN25_9PROT/48-166 A0A1G7KN25.1 PF01926.2...,EVAFVGRSNVGKSSLVNALTGRKTLARTSNTPGRTQEVIFFDLGGR...,PF01926,p_loop_gtpase,MMR_HSR1
7270,368351,R5DZV7_9FIRM/322-526,R5DZV7_9FIRM/322-526 R5DZV7.1 PF13604.7;AAA_30;,ELDEIQKEAVKKTVQNGLVVITGGPGTGKTTTINTIIRYFQMEGLD...,PF13604,p_loop_gtpase,AAA_30
6993,683920,A0A2B8ATP4_9ACTN/40-200,A0A2B8ATP4_9ACTN/40-200 A0A2B8ATP4.1 PF00005.2...,VNGVDYSVDAGETLAVLGESGSGKSVTAQAVMGILDMPPGRIPHGE...,PF00005,p_loop_gtpase,ABC_tran
14189,187106,A0A0G4IPH7_PLABS/263-393,A0A0G4IPH7_PLABS/263-393 A0A0G4IPH7.1 PF03953....,PRIHFPLCALAPVISAEIAYHEQLSVAEITNSVFEPANQMVKCDPR...,PF03953,tubulin_c,Tubulin_C


In [4]:
# Checkpoint to fine-tune; the larger 34-layer variant is kept for reference.
model_name = "esm1_t12_85M_UR50S"
# model_name = "esm1_t34_670M_UR50S"

# Download location of the selected checkpoint.
url = "https://dl.fbaipublicfiles.com/fair-esm/models/" + model_name + ".pt"

# Token vocabulary built from the constants shipped with the esm package.
alphabet = esm.Alphabet.from_dict(proteinseq_toks)

In [5]:
# Fetch the checkpoint from `url`. Without CUDA the tensors are remapped to
# the CPU; with CUDA the default (no remapping) is used.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print("cuda")

model_data = torch.hub.load_state_dict_from_url(
    url,
    progress=False,
    map_location=None if use_cuda else torch.device('cpu'),
)

cuda


In [6]:
def _strip_through_marker(name, marker):
    """Drop everything up to and including `marker` from `name`.

    Any later occurrences of `marker` are removed as well (the pieces that
    follow the first marker are concatenated). Names that do not contain
    `marker` are returned unchanged.

    The guard tests for the exact marker rather than the bare word "decoder":
    a name that merely contains "decoder" (but not the marker) would otherwise
    be split into a single piece and collapse to the empty string.
    """
    if marker not in name:
        return name
    return ''.join(name.split(marker)[1:])


def pra(s):
    """Rename a checkpoint argument name by stripping its 'decoder_' prefix."""
    return _strip_through_marker(s, 'decoder_')


def prs(s):
    """Rename a checkpoint state-dict key by stripping its 'decoder.' prefix."""
    return _strip_through_marker(s, 'decoder.')
# Hyper-parameters stored in the checkpoint, with argument names passed through `pra`.
model_args = {
    pra(arg_name): arg_value
    for arg_name, arg_value in vars(model_data["args"]).items()
}
# Weights stored in the checkpoint, with parameter keys passed through `prs`.
model_state = {
    prs(param_key): param_value
    for param_key, param_value in model_data["model"].items()
}

In [7]:
# Rebuild the architecture from the checkpoint's hyper-parameters.
model_hparams = Namespace(**model_args)
model = esm.ProteinBertModel(
    model_hparams,
    len(alphabet),
    padding_idx=alphabet.padding_idx,
)

# Restore the renamed weights; kept as the cell's last expression so the
# key-matching report is displayed.
model.load_state_dict(model_state)

<All keys matched successfully>

In [8]:
# model.load_state_dict(torch.load("../../data/esm1_t12_85M_UR50S_balanced_201102.pt"))

In [9]:
# Use the GPU when one is present and fall back to the CPU otherwise, matching
# the CPU fallback already used when the checkpoint is downloaded.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# Switch to training mode; as the last expression this also displays the model.
model.train()

ProteinBertModel(
  (embed_tokens): Embedding(35, 768, padding_idx=1)
  (embed_positions): PositionalEmbedding()
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): BertLayerNorm()
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): BertLayerNorm()
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=Tr

In [10]:
# Converter obtained from the alphabet; the training loop calls it on a list of
# (label, sequence) pairs and unpacks (labels, strings, token tensor) from it.
batch_converter = alphabet.get_batch_converter()

In [11]:
# Step size handed to Adam.
lr = 1e-4

# Loss over the predicted token distribution at the selected positions.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter, plus a plateau-based learning-rate scheduler
# attached to it (default scheduler settings).
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)


In [None]:
# Fine-tune the model with a masked-token objective, one sequence per optimizer step.
NUM_EPOCHS = 10
MAX_SEQ_LEN = 1024  # sequences with more residues than this are skipped
CHECKPOINT_PATH = "../../data/esm1_t12_85M_UR50S_balanced_201102.pt"
print_every = 10000

# Follow whatever device the model already lives on instead of hard-coding 'cuda'.
device = next(model.parameters()).device

start_time = time.time()
for j in range(NUM_EPOCHS):
    epoch_loss_sum = 0.0
    epoch_steps = 0

    # Columns are selected by name rather than by position so that an extra
    # index column in the CSV cannot silently shift which field is used.
    for i, (seq_id, seq) in enumerate(zip(pfamA_balanced["id"], pfamA_balanced["seq"])):
        if len(seq) > MAX_SEQ_LEN:
            continue

        data = [(seq_id, seq)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        # NOTE(review): prepare_mlm_mask comes from ych_util and is not visible
        # here; presumably it returns the true token ids at the masked
        # positions, the masked position indices, and the masked token tensor
        # -- confirm against ych_util.
        true_aa, target_ind, masked_batch_tokens = prepare_mlm_mask(alphabet, batch_tokens)

        optimizer.zero_grad()
        # Only the logits are used, so no hidden representations are requested
        # (the previous value 34 was a layer index of the 34-layer model).
        results = model(masked_batch_tokens.to(device), repr_layers=[])
        pred = results["logits"].squeeze(0)[target_ind, :]
        # Move the small target tensor to the model's device rather than
        # copying the logits back to the CPU on every step.
        target = true_aa.squeeze(0).to(device)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        epoch_loss_sum += loss_value
        epoch_steps += 1

        if i % print_every == 0:
            print(batch_labels)
            print(batch_strs)
            print(batch_tokens.size())
            print(masked_batch_tokens.size())
            print(results["logits"].size())
            print(pred.size())
            print(target.size())
            # `i` is the row index within the epoch, not the epoch number.
            print(f"Epoch {j}, step {i}")
            print(f"Loss {loss_value:.4f}")
            elapsed = time.time() - start_time
            print(f"time elapsed {elapsed:.4f}")
            # Periodic checkpoint so a long run can be resumed or inspected.
            torch.save(model.state_dict(), CHECKPOINT_PATH)

    # The plateau scheduler only acts when it is stepped with a metric; feed it
    # the mean training loss of the epoch that just finished.
    if epoch_steps:
        scheduler.step(epoch_loss_sum / epoch_steps)

['A0A1V0A8E6_9ACTN/194-355']
['TPHPHHRTACSTMRWFERGPGGDQLSLSDLEEDRSFEADRAQALALLRLADLGIDDVLIDQCEVAHSDGPRTQRRIRLVHQTAHEKAPLDFAAESAGTRTWFHLIGPVLAALKAGSLLLFDELDASLHPTLCVQLLRLFQDPAMNPKGAQLVFTSHDTSLLN']
torch.Size([1, 163])
torch.Size([1, 163])
torch.Size([1, 163, 35])
torch.Size([24, 35])
torch.Size([24])
At Epoch: 0.0
Loss 2.1437
time elapsed 0.2020
['A0A2N5Y3G4_9GAMM/13-174']
['VIKVIGVGGGGGNAVKHMIENAVEGVDFICANTDAQALSDISSKTVLQLGGDITKGLGAGANPEIGRAAALEDRERIADALRGADMVFITAGMGGGTGTGGAPVVAEVAREMGILTVAVVTRPFAFEGKKRLAIAQEGVRELQQHVDSLITIPNEKLLEVLG']
torch.Size([1, 163])
torch.Size([1, 163])
torch.Size([1, 163, 35])
torch.Size([24, 35])
torch.Size([24])
At Epoch: 10000.0
Loss 1.3429
time elapsed 2207.5310
['A0A1V0A8E6_9ACTN/194-355']
['TPHPHHRTACSTMRWFERGPGGDQLSLSDLEEDRSFEADRAQALALLRLADLGIDDVLIDQCEVAHSDGPRTQRRIRLVHQTAHEKAPLDFAAESAGTRTWFHLIGPVLAALKAGSLLLFDELDASLHPTLCVQLLRLFQDPAMNPKGAQLVFTSHDTSLLN']
torch.Size([1, 163])
torch.Size([1, 163])
torch.Size([1, 163, 35])
torch.Size([24, 35])
torch.Size([24])

In [None]:
# Write the model's current weights to the same file the training loop uses
# for its periodic checkpoints.
torch.save(model.state_dict(), "../../data/esm1_t12_85M_UR50S_balanced_201102.pt")

In [None]:
# Completion marker printed once every preceding cell has finished.
print("done")