# Zero Shot Protein Design using ESM2

In [1]:
import os
from getpass import getpass

import pandas as pd
import torch
from bionemo.api import BionemoClient
from bionemo.request_id import RequestId
from transformers import AutoTokenizer


# Credentials

In [2]:
# Use an NGC API key that has access to the BioNeMo service (not the BioNeMo framework).
# The key is read from the NGC_API_KEY environment variable, or prompted for
# interactively, so it is never written into the notebook or its outputs.
ngc_api_key = os.environ.get("NGC_API_KEY") or getpass("NGC API key: ")
api = BionemoClient(
    api_key=ngc_api_key,
)

# Prepare data

In [3]:
# Example input sequences.
# Lysozyme
protein1 = "MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRDPQGIRAWVAWRNRCQNRDVRQYVQGCGV"

# EGFR -- longer than the model accepts; it is truncated in a later cell
protein2 = "MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSCRNVSRGRECVDKCNLLEGEPREFVENSECIQCHPECLPQAMNITCTGRGPDNCIQCAHYIDGPHCVKTCPAGVMGENNTLVWKYADAGHVCHLCHPNCTYGCTGPGLEGCPTNGPKIPSIATGMVGALLLLLVVALGIGLFMRRRHIVRKRTLRRLLQERELVEPLTPSGEAPNQALLRILKETEFKKIKVLGSGAFGTVYKGLWIPEGEKVKIPVAIKELREATSPKANKEILDEAYVMASVDNPHVCRLLGICLTSTVQLITQLMPFGCLLDYVREHKDNIGSQYLLNWCVQIAKGMNYLEDRRLVHRDLAARNVLVKTPQHVKITDFGLAKLLGAEEKEYHAEGGKVPIKWMALESILHRIYTHQSDVWSYGVTVWELMTFGSKPYDGIPASEISSILEKGERLPQPPICTIDVYMIMVKCWMIDADSRPKFRELIIEFSKMARDPQRYLVIQGDERMHLPSPTDSNFYRALMDEEDMDDVVDADEYLIPQQGFFSSPSTSRTPLLSSLSATSNNSTVACIDRNGLQSCPIKEDSFLQRYSSDPTGALTEDSIDDTFLPVPEYINQSVPKRPAGSVQNPVYHNQPLNPAPSRDPHYQDPHSTAVGNPEYLNTVQPTCVNSTFDSPAHWAQKGSHQISLDNPDYQQDFFPKEAKPNGIFKGSTAENAEYLRVAPQSSEFIGA"

# Report each sequence length so over-long inputs are easy to spot.
for length_label, sequence in (
    ("Length of Protein 1: ", protein1),
    ("Length of Protein 2: ", protein2),
):
    print(length_label, len(sequence))

Length of Protein 1:  148
Length of Protein 2:  1210


Note that Protein 2 is longer than 1024 amino acids, which will cause the call to fail. Truncate it to 1024 amino acids.

In [4]:
# The ESM2 call fails on sequences longer than this many residues (see the note above).
max_input_length = 1024
# Truncate every input, not only the one currently known to be too long, so that
# swapping in a different sequence cannot trigger the same error. Sequences that
# are already short enough are left unchanged by the slice.
protein1 = protein1[:max_input_length]
protein2 = protein2[:max_input_length]

# ESM2

The current version only supports synchronous (`sync`) prediction; `async` is not available.

## Submit request

In [5]:
# Synchronous ESM2 request; available model sizes are "650m", "3b" and "15b".
esm2_sequences = [protein1, protein2]
result = api.esm2_sync(sequences=esm2_sequences, model="650m")


## Results

In [6]:
# one result entry per submitted protein
n_results = len(result)
n_results

2

In [7]:
# Inspect the fields returned for one protein (index into `result`).
protein_id = 0
protein_result = result[protein_id]
protein_result.keys()

dict_keys(['embeddings', 'logits', 'tokens', 'representations'])

### Per sequence embedding

Get the per-sequence embedding. Shape = (Number_of_features, )


In [8]:
sequence_embedding = result[protein_id]['embeddings']
sequence_embedding.shape

(1280,)

### Per residue embedding

Get the per-residue embedding. Shape = (Number_of_residues, Number_of_features)

In [9]:
# per-residue representations (one row per residue) -- not the logits
residue_representations = result[protein_id]['representations']
residue_representations.shape

(148, 1280)

Note that sequences in the same request are NOT padded to the length of the longest sequence in the submission.

## Zero-shot protein design workflow

The logits are the model's prediction for each token at each position. They can be used for zero-shot protein/antibody design, as described in [papers like this](https://www.nature.com/articles/s41587-023-01763-2).

The tokens used by ESM2nv-650M are the same as those in the [huggingface hub](https://huggingface.co/facebook/esm2_t33_650M_UR50D/blob/main/vocab.txt). There are 33 tokens, so the logits have shape (Number_of_residues, 33).

In [10]:
# raw scores before softmax: one row per residue, one column per vocabulary token
residue_logits = result[protein_id]['logits']
residue_logits.shape

(148, 33)

The model assigns a probability to each token at each position in the protein; the token with the highest probability is the predicted token. Convert the logits into probabilities with the `softmax` function.

In [11]:
logits_tensor = torch.tensor(result[protein_id]['logits'])
probs = logits_tensor.softmax(dim=-1)  # normalise over the token axis
probs.shape

torch.Size([148, 33])

Get the token with the highest probability: 

In [12]:
# id of the most likely token at every position
preds = probs.argmax(dim=-1)
preds

tensor([20, 15,  5,  4, 12,  7,  4,  6,  4,  7,  4,  4,  8,  7, 11,  7, 16,  6,
        15,  7, 18,  9, 10, 23,  9,  4,  5, 10, 11,  4, 15, 10,  4,  6, 20, 13,
         6, 19, 10,  6, 12,  8,  4,  5, 17,  4, 20, 23,  4,  5, 15, 22,  9,  8,
         6, 19, 17, 11, 10,  5, 11, 17, 19, 17,  5,  6, 13, 10,  8, 11,  7, 19,
         6, 12, 18, 16, 12, 17,  8, 10, 19, 22, 23, 17, 13,  6, 15, 11, 14,  6,
         5,  7, 17,  5, 23, 21,  4,  8, 23,  8,  5,  4,  4, 16, 13, 17, 12,  5,
        13,  5,  7,  5, 23,  5, 15, 10,  7,  7, 10, 13, 14, 16,  6, 12, 10,  5,
        22,  7,  5, 22, 10, 17, 10, 23, 16, 17, 10, 13,  7, 10, 16, 19,  7, 16,
         6, 23,  6,  7])

To convert the tokens back to amino acids, use the tokenizer:

In [13]:
# Tokenizer for ESM2nv-650M; point this at a different checkpoint
# if you queried a different model size.
tokenizer_checkpoint = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [14]:
# look at the predicted sequence
protein1_pred = ''.join(tokenizer.convert_ids_to_tokens(preds))
protein1_pred

'MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANLMCLAKWESGYNTRATNYNAGDRSTVYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRDPQGIRAWVAWRNRCQNRDVRQYVQGCGV'

Now we can align the original protein 1 sequence with the predicted protein 1 sequence. An example is shown below:

![image](https://github.com/xinyu-dev/bionemo-demo/blob/main/assets/images/notebook_lysozyme_zero_shot.png?raw=true)

Note that at most positions, the model predicted the same amino acid as the original input. This is expected for a protein that has evolved through natural selection.

The model proposed 2 mutations: 
- W46L
- D71V

Single and/or combination mutations can then be designed for testing in the wet lab.

You can also view the complete probability matrix like this: 

In [15]:
# get vocab
vocab = tokenizer.get_vocab()
probs = pd.DataFrame(probs.numpy())
probs.columns = list(vocab.keys())

probs

Unnamed: 0,<cls>,<pad>,<eos>,<unk>,L,A,G,V,S,E,...,C,X,B,U,Z,O,.,-,<null_1>,<mask>
0,8.093410e-08,1.267929e-09,1.089087e-07,1.267929e-09,0.009205,0.005532,0.008580,0.010299,0.004118,0.003821,...,0.002218,0.000059,1.713375e-07,9.171035e-08,6.709672e-08,1.482360e-09,6.475950e-10,7.571153e-10,7.512233e-10,1.248272e-09
1,1.896748e-07,8.251510e-10,1.524080e-07,7.997638e-10,0.082292,0.051321,0.029507,0.060990,0.033929,0.021998,...,0.012315,0.000009,7.544743e-08,6.353311e-08,4.298865e-08,5.551461e-09,6.644229e-09,3.163127e-09,5.465393e-09,8.123582e-10
2,5.547430e-08,1.098609e-09,1.020329e-07,1.098609e-09,0.041994,0.703366,0.021825,0.030682,0.022905,0.012020,...,0.008893,0.000017,1.555815e-07,1.394624e-07,8.864767e-08,9.945993e-09,5.985605e-09,5.756299e-09,4.625317e-09,1.081577e-09
3,5.958092e-08,4.620696e-10,4.862853e-08,4.549059e-10,0.805728,0.022834,0.016320,0.021736,0.017519,0.009459,...,0.005498,0.000009,4.497404e-08,5.958092e-08,2.369965e-08,2.700901e-09,1.401209e-09,1.468454e-09,2.153346e-09,4.549059e-10
4,1.463436e-07,1.703751e-09,1.632581e-07,1.651332e-09,0.065732,0.042089,0.032799,0.047821,0.039308,0.023232,...,0.012484,0.000035,1.893827e-07,2.489389e-07,1.291478e-07,6.330204e-09,4.631277e-09,4.055281e-09,4.216825e-09,1.677337e-09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
143,8.465729e-07,3.889901e-09,1.581606e-06,3.770221e-09,0.036411,0.032623,0.040649,0.041390,0.022561,0.018385,...,0.030646,0.000190,7.018336e-07,7.184771e-07,3.341229e-07,6.216045e-09,5.442951e-09,5.615729e-09,5.839434e-09,3.951159e-09
144,1.856874e-07,2.527396e-09,3.579217e-07,2.488212e-09,0.027457,0.029000,0.693233,0.029041,0.019126,0.012211,...,0.018875,0.000119,3.692834e-07,1.915817e-07,1.563644e-07,5.308859e-09,4.871683e-09,4.366948e-09,3.648722e-09,2.607624e-09
145,5.626312e-07,6.653376e-09,8.714202e-07,6.550224e-09,0.046078,0.029336,0.037663,0.044956,0.027696,0.017341,...,0.558753,0.000288,8.579102e-07,4.116298e-07,2.457954e-07,1.114231e-08,9.382764e-09,3.094107e-09,5.304534e-09,6.550224e-09
146,3.176321e-07,1.719690e-09,1.227158e-06,1.693028e-09,0.028059,0.019421,0.701208,0.039328,0.018682,0.010110,...,0.025169,0.000390,9.335728e-07,4.479331e-07,2.892073e-07,5.727441e-09,4.937363e-09,5.465163e-09,3.998399e-09,1.746771e-09


Let's look at amino acid 46 again, this time retrieving the top 3 predictions:

In [16]:
pos = 46  # amino acid position, 1-based
top_k = 3

# probabilities at this position, highest first (the table's row index is 0-based)
position_probs = probs.loc[pos - 1, :]
position_probs.sort_values(ascending=False).head(top_k)

L    0.513982
W    0.256936
V    0.064647
Name: 45, dtype: float32

In the table above,
- `L` is the amino acid with the highest probability, which is what the model recommended.
- `W` is the original amino acid.
- `V` is the amino acid with the 3rd highest probability. However, its probability is low compared to the first two proposals. If its probability were similar to the first two, it might be worth considering this proposal as well.