# Documentation
> 201103: This notebook generates embedding vectors for motor_toolkit, pfam_target, pfam_balanced, and kinesin_labelled from the ESM-based models that have finished training so far:
    - data/esm/
    - model_weights/esm1_t12_85M_UR50S_balanced_201102.pt
    - esm1_t12_85M_UR50S_kinesin_201102.pt
    - esm1_t12_85M_UR50S_motor_toolkit_201102.pt
    - kinesin_labelled.csv
    - motor_toolkit.csv
    - pfamA_motors_balanced.csv
    - pfamA_target_sub.csv

In [1]:
import math
import os
import time
from argparse import Namespace

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")
import esm
from esm.constants import proteinseq_toks
from esm.model import ProteinBertModel
from esm.modules import TransformerLayer, PositionalEmbedding  # noqa

from ych_util import prepare_mlm_mask

In [2]:
# Read the five sequence tables in one pass; the CSV paths are listed in the
# same order as the names they are unpacked into.
pfamA_random, motor_toolkit, pfamA_balanced, pfamA_target, kinesin_labelled = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../../data/pfamA_random_201027.csv",
        "../../data/esm/motor_toolkit.csv",
        "../../data/esm/pfamA_motors_balanced.csv",
        "../../data/esm/pfamA_target_sub.csv",
        "../../data/esm/kinesin_labelled.csv",
    )
)

In [3]:
# Build the alphabet and download the pretrained 12-layer ESM checkpoint.
alphabet = esm.Alphabet.from_dict(proteinseq_toks)
model_name = "esm1_t12_85M_UR50S"
url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
if not torch.cuda.is_available():
    # No GPU: remap the checkpoint tensors onto the CPU while loading.
    model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location=torch.device('cpu'))
else:
    print("cuda")
    model_data = torch.hub.load_state_dict_from_url(url, progress=False)


def pra(s):
    """Rename a checkpoint argument name.

    Names that do not contain 'decoder' are returned unchanged. Otherwise the
    name is split on 'decoder_', the first piece is dropped and the remaining
    pieces are concatenated (so a name containing 'decoder' but not
    'decoder_' collapses to an empty string).
    """
    if 'decoder' in s:
        return ''.join(s.split('decoder_')[1:])
    return s


def prs(s):
    """Rename a checkpoint state-dict key; same rule as `pra` but splitting on 'decoder.'."""
    if 'decoder' in s:
        return ''.join(s.split('decoder.')[1:])
    return s


# Re-key the checkpoint's arguments and weights with the renaming rules above.
model_args = {pra(name): value for name, value in vars(model_data["args"]).items()}
model_state = {prs(name): tensor for name, tensor in model_data["model"].items()}

# Only the architecture is built here; weights are loaded later, per weight set.
model_t12 = esm.ProteinBertModel(Namespace(**model_args), len(alphabet), padding_idx=alphabet.padding_idx)

cuda


In [4]:
# Fine-tuned t12 checkpoints, listed in the order: balanced, kinesin, motor_toolkit.
paths = ["esm1_t12_85M_UR50S_balanced_201102.pt","esm1_t12_85M_UR50S_kinesin_201102.pt","esm1_t12_85M_UR50S_motor_toolkit_201102.pt"]
model_t12_weights_paths = ["../../data/esm/model_weights/"+p for p in paths]
# Mirror the CPU fallback used when downloading the pretrained checkpoint:
# None keeps the saved device placement (unchanged behaviour on a GPU machine),
# while a CPU-only machine remaps the tensors to the CPU instead of failing.
weights_map_location = None if torch.cuda.is_available() else torch.device('cpu')
model_t12_weights = [torch.load(p, map_location=weights_map_location) for p in model_t12_weights_paths]
# Append the pretrained t12 state as the fourth weight set. This must run
# before `model_state` is rebound to the t34 checkpoint further down.
model_t12_weights.append(model_state)
len(model_t12_weights)


4

In [5]:
# Build the alphabet and download the pretrained 34-layer ESM checkpoint.
alphabet = esm.Alphabet.from_dict(proteinseq_toks)
model_name = "esm1_t34_670M_UR50S"
url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
if not torch.cuda.is_available():
    # No GPU: remap the checkpoint tensors onto the CPU while loading.
    model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location=torch.device('cpu'))
else:
    print("cuda")
    model_data = torch.hub.load_state_dict_from_url(url, progress=False)


def pra(s):
    """Rename a checkpoint argument name.

    Names that do not contain 'decoder' are returned unchanged. Otherwise the
    name is split on 'decoder_', the first piece is dropped and the remaining
    pieces are concatenated (so a name containing 'decoder' but not
    'decoder_' collapses to an empty string).
    """
    if 'decoder' in s:
        return ''.join(s.split('decoder_')[1:])
    return s


def prs(s):
    """Rename a checkpoint state-dict key; same rule as `pra` but splitting on 'decoder.'."""
    if 'decoder' in s:
        return ''.join(s.split('decoder.')[1:])
    return s


# Re-key the checkpoint's arguments and weights with the renaming rules above.
model_args = {pra(name): value for name, value in vars(model_data["args"]).items()}
model_state = {prs(name): tensor for name, tensor in model_data["model"].items()}

# Build the architecture and load the pretrained weights straight away; the
# load result is the cell's displayed value.
model_t34 = esm.ProteinBertModel(Namespace(**model_args), len(alphabet), padding_idx=alphabet.padding_idx)
model_t34.load_state_dict(model_state)

cuda


<All keys matched successfully>

In [6]:
# model.load_state_dict()

In [7]:
print_every = 2000
def generate_embedding_transformer_t12(model,batch_converter,weight,weight_name, dat,dat_name,out_dir,seq_col,repr_layer=12,device=None):
    """Embed every sequence of `dat` with `model` and save the result as a .npy file.

    The state dict `weight` is loaded into `model`, each row of `dat` is run
    through the model one sequence at a time, and the per-token representations
    of layer `repr_layer` are averaged over the sequence positions. The stacked
    per-sequence vectors are written to
    ``out_dir + weight_name + '/' + dat_name + ".npy"``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(tokens, repr_layers=[...])``; must return a dict whose
        ``"representations"`` entry is indexable by layer number.
    batch_converter : callable
        Maps a list of ``(label, sequence)`` pairs to a 3-tuple whose last
        element is the token tensor.
    weight : dict
        State dict loaded into `model` before embedding.
    weight_name : str
        Name of the weight set; used in the log output and as the output
        sub-directory.
    dat : pandas.DataFrame
        Table of sequences; column 1 is passed as the record label.
    dat_name : str
        Dataset name; used in the log output and as the output file stem.
    out_dir : str
        Output root; concatenated as-is, so it should end with a path separator.
    seq_col : int
        Positional index of the column of `dat` that holds the sequence.
    repr_layer : int, optional
        Layer whose representations are averaged (default 12, the previously
        hard-coded value).
    device : torch.device or str, optional
        Device to run on. Defaults to CUDA when available, otherwise the CPU.

    Returns
    -------
    None
        The embeddings are written to disk, not returned.
    """
    if device is None:
        # Same CUDA/CPU fallback the notebook uses when loading checkpoints.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.load_state_dict(weight)
    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()
    print("weight is: " + weight_name)
    print("output embedding for " + dat_name)

    # Create the target directory up front so a missing folder is caught before
    # the (potentially long) embedding pass rather than at the final save.
    save_path = out_dir + weight_name  + '/' + dat_name + ".npy"
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    sequence_embeddings = []
    for epoch in range(dat.shape[0]):
        seq = dat.iloc[epoch, seq_col]
        data = [(dat.iloc[epoch, 1], seq)]
        _, _, batch_tokens = batch_converter(data)
        with torch.no_grad():
            results = model(batch_tokens.to(device), repr_layers=[repr_layer])
            token_embeddings = results["representations"][repr_layer]
            # Average positions 1..len(seq); position 0 is skipped (assumes the
            # batch converter prepends one special token -- TODO confirm).
            sequence_embeddings.append(token_embeddings[0, 1:len(seq) + 1].mean(0).cpu().detach().numpy())
        if epoch % print_every == 0:
            print("At Epoch: %.2f" % epoch)
            print(seq)
    sequence_embeddings = np.array(sequence_embeddings)
    print(sequence_embeddings.shape)
    print(save_path)
    np.save(save_path, sequence_embeddings)
    return 



In [8]:
# Names of the four t12 weight sets, in the same order as `model_t12_weights`
# (the three fine-tuned checkpoints, then the pretrained state as "raw").
# Each name is also used as the output sub-directory when saving embeddings.
model_t12_weights_label = ["balanced","kinesin","motor_toolkit","raw"]

In [9]:
# Converter from the alphabet; the embedding functions call it on a list of
# (label, sequence) pairs and use the third returned element as the token tensor.
batch_converter = alphabet.get_batch_converter()

In [10]:
# Output root for the t12 embeddings; a per-weight sub-directory and the
# dataset file name are appended to it when saving.
out_dir = "../../out/201102/embedding/t12/"

In [11]:
# One (name, frame, sequence-column index) triple per dataset, transposed into
# the three parallel lists that the embedding loops index by position.
data_names, data, seq_cols = (
    list(column)
    for column in zip(
        ("pfamA_random", pfamA_random, 2),
        ("motor_toolkit", motor_toolkit, 7),
        ("pfamA_balanced", pfamA_balanced, 3),
        ("pfamA_target", pfamA_target, 3),
        ("kinesin_labelled", kinesin_labelled, 7),
    )
)

In [12]:
# Embed every dataset once per t12 weight set (weight sets outer, datasets inner).
for i, weight in enumerate(model_t12_weights):
    weight_name = model_t12_weights_label[i]
    for j, dat in enumerate(data):
        dat_name = data_names[j]
        seq_col = seq_cols[j]
        generate_embedding_transformer_t12(
            model_t12,
            batch_converter,
            weight,
            weight_name,
            dat,
            dat_name,
            out_dir,
            seq_col,
        )

weight is: balanced
output embedding for pfamA_random
At Epoch: 0.00
NVVYVGNKEVMSYVLAVTTQFNEGSDEVVIKARGRAISTAVDTAEVVRNRFLEDVEVEDIKIST
(1600, 768)
../../out/201102/embedding/t12/balanced/pfamA_random.npy
weight is: balanced
output embedding for motor_toolkit
At Epoch: 0.00
MASQPNSSAKKKEEKGKNIQVVVRCRPFNLAERKASAHSIVECDPVRKEVSVRTGGLADKSSRKTYTFDMVFGASTKQIDVYRSVVCPILDEVIMGYNCTIFAYGQTGTGKTFTMEGERSPNEEYTWEEDPLAGIIPRTLHQIFEKLTDNGTEFSVKVSLLEIYNEELFDLLNPSSDVSERLQMFDDPRNKRGVIIKGLEEITVHNKDEVYQILEKGAAKRTTAATLMNAYSSRSHSVFSVTIHMKETTIDGEELVKIGKLNLVDLAGSENIGRSGAVDKRAREAGNINQSLLTLGRVITALVERTPHVPYRESKLTRILQDSLGGRTRTSIIATISPASLNLEETLSTLEYAHRAKNILNKPEVNQKLTKKALIKEYTEEIERLKRDLAAAREKNGVYISEENFRVMSGKLTVQEEQIVELIEKIGAVEEELNRVTELFMDNKNELDQCKSDLQNKTQELETTQKHLQETKLQLVKEEYITSALESTEEKLHDAASKLLNTVEETTKDVSGLHSKLDRKKAVDQHNAEAQDIFGKNLNSLFNNMEELIKDGSSKQKAMLEVHKTLFGNLLSSSVSALDTITTVALGSLTSIPENVSTHVSQIFNMILKEQSLAAESKTVLQELINVLKTDLLSSLEMILSPTVVSILKINSQLKHIFKTSLTVADKIEDQKKELDGFLSILCNNLHELQENTICSLVESQKQCGNLTEDLKTIKQTHSQELCKLMN

(3255, 768)
../../out/201102/embedding/t12/kinesin/motor_toolkit.npy
weight is: kinesin
output embedding for pfamA_balanced
At Epoch: 0.00
HQDNVHARSLMGLVRNVFEQAGLEKTALDAVAVSSGPGSYTGLRIGVSVAKGLAYALDKPVIGVGTLEALAFRAIPFSDSTDTIIPMLDARRMEVYALVMDGLGDTLISPQPFILEDNPFMEYLEKGKVFFLGDGVPKSKEILSHPNSRFVPLFNSSQSIGELAYKKFLKADFESLAYFEPNYIKEFRI
At Epoch: 2000.00
ARKIGIDLGTTNLLICVDNKGILVDEPSIITVDATTKKCIAAGLDARDMLGRTPKNMICIRPLKDGVVADFEATDMMLNYFLKKCDLKGMFKKNVILICHPTKITSVEKNAIRDCAYRAGAKKVYLEEEPKIAALGAGLDIGKASGNMVLDIGGGTSDIAVLSLGDIVCSTSIKTAGNKITQDILENVRIQKKMYIGEQTADEIKRRIANALVVKEPETITISGRDVETGLPHSIDINSNEVESYIRSSLQEIVHATKTILEVTPPELAADIVQHGLVLTGGGALLKNLDQLMRNELQIPVYVAENALKCVVDGCTIMLQNL
At Epoch: 4000.00
HIAVDIGGSLAKLVYFSRDPTSKELGGRLNFLKFETARIDECIDFLRKLKLKYEIINGSRPSDLCVMATGGGAFKYYDEIKGALEVEVVREDEMECLIIGLDFFITEIPHEVFTYSQEEPMRFIAARPNIYPYLLVNIGSGVSMVKVSGPRQYERVGGTSLGGGTLWGLLSLLTGARTFEDMLSLAERGDNTAVDMLVGDIYGSGYGKIGLKSTTIASSFGKVYKMKRQAEQEAEDTGNLKEDSSQEHGRSFKSEDISKSLLYAVSNNIGQIAYLHAEKHNLEHIYFGGSFIGGHPQTMHTLSYAIKFWSKG

At Epoch: 14000.00
GVAIMSTGYGEGENRVKHAIDEALHSPLLNNDDIFNSKKVLLSITFCAKDQDQLTMEEMNEINDFMTKFGEDVETKWGVATDDTLEKKVKITVLATGFG
At Epoch: 16000.00
GMAMMGSGFAQGIDRARLATEQAISSPFLDDVTLDGARGILVNITTAPGCLKMSEYREIMKAVNANAHPDAECKVGTAEDDSMSEDAIRVTIIATGLK
(18000, 768)
../../out/201102/embedding/t12/motor_toolkit/pfamA_balanced.npy
weight is: motor_toolkit
output embedding for pfamA_target
At Epoch: 0.00
IIKVLGVGGGGSNAVTHMFRQGIVGVDFAICNTDSQAMELSPVTTRIQLGPNLTEGRGAGSKPNIGKMACEESIEAVKAYLENNCRMLFITAGMGGGTGTGAAPIIAKTAKEMDILTVGIVTLPFTFEGRRRTSQGFEGLEELKKNVDTLIVISNDKLRQIHG
At Epoch: 2000.00
EVRTGTYRQLFSPSNLITGKEDAANNYARGHYTVGKEQKNVVVEAIRKQQEMCSGLQGFLIFHSFGGGTGSGFGSLLLEELSVEYPKKSKLCFSILSSTTASLRFDGAL
At Epoch: 4000.00
GVAMIGLGEADSDAKAADSVQSALRSPLLDVDISSANSALVNVTGGPGMSIEEAEGVVEQLYDRIDPDARIIWGTSIDEQIQEEMRTMVVVTGVD
(5544, 768)
../../out/201102/embedding/t12/motor_toolkit/pfamA_target.npy
weight is: motor_toolkit
output embedding for kinesin_labelled
At Epoch: 0.00
MASQPNSSAKKKEEKGKNIQVVVRCRPFNLAERKASAHSIVECDPVRKEVSVRT

In [13]:
# Rebind the output root to the t34 location; from here on `out_dir` no longer
# points at the t12 directory used by the cells above.
out_dir = "../../out/201102/embedding/t34/"

In [14]:
print_every = 2000
def generate_embedding_transformer_t34(model,batch_converter, dat,dat_name,out_dir,seq_col,repr_layer=34,device=None):
    """Embed every sequence of `dat` with `model` and save the result as a .npy file.

    Each row of `dat` is run through the model one sequence at a time, and the
    per-token representations of layer `repr_layer` are averaged over the
    sequence positions. The stacked per-sequence vectors are written to
    ``out_dir + "t34_" + dat_name + ".npy"``. No weights are loaded here; the
    model is used with whatever weights it already holds.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(tokens, repr_layers=[...])``; must return a dict whose
        ``"representations"`` entry is indexable by layer number.
    batch_converter : callable
        Maps a list of ``(label, sequence)`` pairs to a 3-tuple whose last
        element is the token tensor.
    dat : pandas.DataFrame
        Table of sequences; column 1 is passed as the record label.
    dat_name : str
        Dataset name; used in the log output and in the output file name.
    out_dir : str
        Output root; concatenated as-is, so it should end with a path separator.
    seq_col : int
        Positional index of the column of `dat` that holds the sequence.
    repr_layer : int, optional
        Layer whose representations are averaged (default 34, the previously
        hard-coded value).
    device : torch.device or str, optional
        Device to run on. Defaults to CUDA when available, otherwise the CPU.

    Returns
    -------
    None
        The embeddings are written to disk, not returned.
    """
    if device is None:
        # Same CUDA/CPU fallback the notebook uses when loading checkpoints.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()
    print("output embedding for " + dat_name)

    # Create the target directory up front so a missing folder is caught before
    # the (potentially long) embedding pass rather than at the final save.
    save_path = out_dir + "t34_" + dat_name + ".npy"
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    sequence_embeddings = []
    for epoch in range(dat.shape[0]):
        seq = dat.iloc[epoch, seq_col]
        data = [(dat.iloc[epoch, 1], seq)]
        _, _, batch_tokens = batch_converter(data)
        with torch.no_grad():
            results = model(batch_tokens.to(device), repr_layers=[repr_layer])
            token_embeddings = results["representations"][repr_layer]
            # Average positions 1..len(seq); position 0 is skipped (assumes the
            # batch converter prepends one special token -- TODO confirm).
            sequence_embeddings.append(token_embeddings[0, 1:len(seq) + 1].mean(0).cpu().detach().numpy())
        if epoch % print_every == 0:
            print("At Epoch: %.2f" % epoch)
            print(seq)
    sequence_embeddings = np.array(sequence_embeddings)
    print(sequence_embeddings.shape)
    print(save_path)
    np.save(save_path, sequence_embeddings)
    return 

In [15]:
# Embed every dataset with the pretrained t34 model.
for j, dat in enumerate(data):
    dat_name = data_names[j]
    seq_col = seq_cols[j]
    generate_embedding_transformer_t34(
        model_t34,
        batch_converter,
        dat,
        dat_name,
        out_dir,
        seq_col,
    )

output embedding for pfamA_random
At Epoch: 0.00
NVVYVGNKEVMSYVLAVTTQFNEGSDEVVIKARGRAISTAVDTAEVVRNRFLEDVEVEDIKIST
(1600, 1280)
../../out/201102/embedding/t34/t34_pfamA_random.npy
output embedding for motor_toolkit
At Epoch: 0.00
MASQPNSSAKKKEEKGKNIQVVVRCRPFNLAERKASAHSIVECDPVRKEVSVRTGGLADKSSRKTYTFDMVFGASTKQIDVYRSVVCPILDEVIMGYNCTIFAYGQTGTGKTFTMEGERSPNEEYTWEEDPLAGIIPRTLHQIFEKLTDNGTEFSVKVSLLEIYNEELFDLLNPSSDVSERLQMFDDPRNKRGVIIKGLEEITVHNKDEVYQILEKGAAKRTTAATLMNAYSSRSHSVFSVTIHMKETTIDGEELVKIGKLNLVDLAGSENIGRSGAVDKRAREAGNINQSLLTLGRVITALVERTPHVPYRESKLTRILQDSLGGRTRTSIIATISPASLNLEETLSTLEYAHRAKNILNKPEVNQKLTKKALIKEYTEEIERLKRDLAAAREKNGVYISEENFRVMSGKLTVQEEQIVELIEKIGAVEEELNRVTELFMDNKNELDQCKSDLQNKTQELETTQKHLQETKLQLVKEEYITSALESTEEKLHDAASKLLNTVEETTKDVSGLHSKLDRKKAVDQHNAEAQDIFGKNLNSLFNNMEELIKDGSSKQKAMLEVHKTLFGNLLSSSVSALDTITTVALGSLTSIPENVSTHVSQIFNMILKEQSLAAESKTVLQELINVLKTDLLSSLEMILSPTVVSILKINSQLKHIFKTSLTVADKIEDQKKELDGFLSILCNNLHELQENTICSLVESQKQCGNLTEDLKTIKQTHSQELCKLMNLWTERFCALEEKCENIQKPLSSVQENIQQKSKDIVNKMTFHSQK

In [16]:
# Completion marker: printed once execution reaches the end of the notebook.
print("done")

done
