In [8]:
from transformers import BertForMaskedLM, BertModel, BertTokenizer, pipeline

## ProtBERT Protein Language Model

In [2]:
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/361 [00:00<?, ?B/s]

In [3]:
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")

Downloading:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### prediction of MASKed token

In [7]:
unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)
# original sequence
#         M V L S P A D K T N V K   A    A W G K V G A H A G E Y
unmasker('M V L S P A D K T N V K [MASK] A W G K V G A H A G E Y')

[{'score': 0.10394349694252014,
  'token': 6,
  'token_str': 'A',
  'sequence': 'M V L S P A D K T N V K A A W G K V G A H A G E Y'},
 {'score': 0.0907510295510292,
  'token': 12,
  'token_str': 'K',
  'sequence': 'M V L S P A D K T N V K K A W G K V G A H A G E Y'},
 {'score': 0.07521520555019379,
  'token': 8,
  'token_str': 'V',
  'sequence': 'M V L S P A D K T N V K V A W G K V G A H A G E Y'},
 {'score': 0.06997817009687424,
  'token': 5,
  'token_str': 'L',
  'sequence': 'M V L S P A D K T N V K L A W G K V G A H A G E Y'},
 {'score': 0.06744708865880966,
  'token': 7,
  'token_str': 'G',
  'sequence': 'M V L S P A D K T N V K G A W G K V G A H A G E Y'}]

### Get embeddings

In [18]:
sequence = "M V L S P A D K T N V K A A W G K V G A H A G E Y"
model = BertModel.from_pretrained("Rostlab/prot_bert")
encoded_input = tokenizer(sequence, return_tensors='pt')
output = model(**encoded_input)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [31]:
embeddings = output.last_hidden_state
embeddings[0, :27]

tensor([[ 1.1234e-01,  3.8124e-02,  4.5919e-02,  ..., -4.7524e-02,
         -3.0284e-02,  2.1404e-02],
        [ 1.0622e-01,  8.6489e-02,  1.3534e-01,  ...,  2.9468e-02,
         -9.1207e-02, -1.0648e-01],
        [ 1.4300e-02, -2.4049e-02,  9.7432e-02,  ..., -3.1925e-02,
          3.6396e-02,  1.0129e-01],
        ...,
        [ 1.8457e-01, -2.3567e-02,  1.0809e-02,  ...,  6.4580e-02,
         -1.7176e-01,  7.9359e-02],
        [ 1.5196e-01,  2.1889e-02,  2.6714e-02,  ...,  9.3726e-02,
          7.4341e-03,  1.4771e-01],
        [ 1.3322e-01,  4.9643e-02,  4.4365e-02,  ..., -4.4930e-02,
         -4.8270e-02,  6.1550e-05]], grad_fn=<SelectBackward0>)

In [25]:
embeddings.size()

torch.Size([27, 1024])

In [26]:
embeddings[0].mean(dim=0)

tensor([ 0.1047,  0.0174,  0.0590,  ..., -0.0019, -0.0286,  0.0559],
       grad_fn=<MeanBackward1>)

In [27]:
embeddings[0].mean(dim=0).size()

torch.Size([1024])