### Setup

We use the BERT model and tokenizer classes from Hugging Face's `transformers` library, plus Python's `re` module to preprocess the protein sequence into the input format the model expects.

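We then load the ProtBert tokenizer (`Rostlab/prot_bert`) from the Hugging Face Hub, followed by the pre-trained ProtBert model, which acts as our encoder. Passing `do_lower_case=False` keeps the upper-case amino-acid letters intact, since the tokenizer's vocabulary is upper case.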

In [1]:
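import re

import torch
import torch.onnx
from transformers import BertConfig, BertModel, BertTokenizer

# ProtBert's vocabulary uses upper-case amino-acid letters, so disable lower-casing.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# Pre-trained ProtBert encoder (no task-specific head).
model = BertModel.from_pretrained("Rostlab/prot_bert")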



In [2]:
config = BertConfig.from_pretrained("Rostlab/prot_bert")
# Maximum number of token positions the model supports.
config.max_position_embeddings


40000

### Preprocessing

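We define a sample protein sequence and preprocess it. U (selenocysteine) and O (pyrrolysine) are rare amino acids, while B (Asp or Asn) and Z (Glu or Gln) are ambiguity codes. To standardize the input for the model, all four are replaced by "X". The residues are written with spaces between them because the tokenizer treats each whitespace-separated letter as one token.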

In [3]:
sequence_Example = "A E T C Z A O"
# Map rare or ambiguous residues (U, Z, O, B) to "X".
sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example)
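Raw sequences from databases usually arrive without spaces (for example `AETCZAO`). The following is a minimal sketch of a helper that applies the same substitution and also inserts the spaces between residues; `prepare_sequence` and `RARE_RESIDUES` are hypothetical names, not part of `transformers`.

```python
import re

# Rare or ambiguous residue codes that are mapped to "X" before tokenization.
RARE_RESIDUES = re.compile(r"[UZOB]")


def prepare_sequence(raw: str) -> str:
    """Turn a raw sequence such as "AETCZAO" into space-separated input.

    Hypothetical helper: removes any existing whitespace, upper-cases,
    maps U/Z/O/B to X, then separates residues with single spaces.
    """
    residues = "".join(raw.split()).upper()
    residues = RARE_RESIDUES.sub("X", residues)
    return " ".join(residues)


print(prepare_sequence("AETCZAO"))        # A E T C X A X
print(prepare_sequence("a e t c z a o"))  # A E T C X A X
```

Upper-casing happens before the substitution so that lower-case input such as `z` is also mapped to `X`.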

### Tokenizing

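Next, we tokenize the sequence into the input IDs, token type IDs, and attention mask that the BERT model consumes. `return_tensors='pt'` returns PyTorch tensors, since the model is a PyTorch module, and `truncation=True` with `max_length=512` caps the input at 512 tokens, including the special `[CLS]` and `[SEP]` tokens.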

In [5]:
tokens = tokenizer(sequence_Example, truncation=True, max_length=512, return_tensors='pt')

In [6]:
tokens

{'input_ids': tensor([[ 2,  6,  9, 15, 23, 25,  6, 25,  3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
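To make the three fields concrete, here is a toy re-implementation for a single space-separated sequence. It is an illustration, not how `BertTokenizer` works internally: `toy_tokenize` and `TOY_VOCAB` are hypothetical, and the vocabulary contains only the IDs that can be read off the output above (`[CLS]`=2, `[SEP]`=3, A=6, E=9, T=15, C=23, X=25). The real tokenizer also handles padding, batching, and unknown tokens.

```python
# Partial vocabulary read off the tokenizer output above; the real
# ProtBert vocabulary covers every residue code and more special tokens.
TOY_VOCAB = {"[CLS]": 2, "[SEP]": 3, "A": 6, "E": 9, "T": 15, "C": 23, "X": 25}


def toy_tokenize(sequence: str, max_length: int = 512) -> dict:
    """Mimic the tokenizer for one space-separated sequence (illustration only).

    Raises KeyError for residues missing from TOY_VOCAB.
    """
    residues = sequence.split()
    # Truncate so that [CLS] + residues + [SEP] fits within max_length.
    residues = residues[: max_length - 2]
    tokens = ["[CLS]"] + residues + ["[SEP]"]
    input_ids = [TOY_VOCAB[t] for t in tokens]
    return {
        "input_ids": input_ids,
        # A single segment, so every position belongs to segment 0.
        "token_type_ids": [0] * len(input_ids),
        # No padding, so every position is attended to.
        "attention_mask": [1] * len(input_ids),
    }


print(toy_tokenize("A E T C X A X")["input_ids"])
# [2, 6, 9, 15, 23, 25, 6, 25, 3]
```

The two extra positions at either end are the `[CLS]` and `[SEP]` tokens, which is why seven residues become nine IDs.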

### Fetching Embeddings from BERT

With the tokenized sequence, we run the model and average its last hidden state over the token dimension (`dim=1`), giving one fixed-size embedding per sequence in the batch. These embeddings can then be used for downstream tasks such as classification.

In [14]:

encodings = model(**tokens).last_hidden_state.mean(dim=1)

In [15]:
encodings


tensor([[ 0.0596,  0.0577, -0.0590,  ..., -0.0516, -0.0697,  0.0888],
        [ 0.0621,  0.0518, -0.0627,  ..., -0.0475, -0.0510,  0.0878],
        [-0.0183,  0.0620, -0.1026,  ..., -0.0218, -0.0510,  0.0492],
        [ 0.0062,  0.0601, -0.0934,  ..., -0.0455, -0.0555,  0.0700],
        [-0.0242,  0.1097, -0.1192,  ..., -0.0284, -0.0762,  0.0547]],
       grad_fn=<MeanBackward1>)

In [55]:
encodings[1]

tensor([ 0.0639,  0.0582, -0.0569,  ..., -0.0532, -0.0593,  0.0873],
       grad_fn=<SelectBackward0>)
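The `.mean(dim=1)` call averages over every position, including `[CLS]`, `[SEP]` and, in a padded batch, the padding tokens. A common alternative (not what the cell above does) weights the average by the attention mask so that padding does not dilute shorter sequences. Below is a minimal sketch with a hypothetical `masked_mean` function, run on small made-up tensors.

```python
import torch


def masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over positions where attention_mask == 1.

    hidden: (batch, seq_len, dim); attention_mask: (batch, seq_len).
    Returns a (batch, dim) tensor.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                   # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1.0)               # (batch, 1), avoids division by zero
    return summed / counts


# Two toy "sequences" of length 3 with 2-dimensional token vectors;
# the last position of the first sequence is padding.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                       [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]])
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
print(masked_mean(hidden, mask))
```

With the notebook's objects this would be `masked_mean(model(**tokens).last_hidden_state, tokens["attention_mask"])`. For a single unpadded sequence, as in this example, it gives the same result as `.mean(dim=1)`.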

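### Exporting to ONNX

Finally, we export the encoder to ONNX. Only `input_ids` is passed as an input, and the mean pooling above is not part of the exported graph.

In [None]: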
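# Put the model in inference mode (disables dropout) before tracing.
model.eval()
torch.onnx.export(model,
                  args=(tokens['input_ids'],),
                  f="protein_embedding.onnx",
                  input_names=['input_ids'],
                  # BertModel returns per-token hidden states and a pooled output.
                  output_names=['last_hidden_state', 'pooler_output'],
                  # Let batch size and sequence length vary at inference time.
                  dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'},
                                'last_hidden_state': {0: 'batch', 1: 'sequence'}},
                  opset_version=11)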