# ESM2 Embedder using HuggingFace

This notebook uses the [ESM2 protein language model](https://github.com/facebookresearch/esm), loaded through HuggingFace `transformers`, to map protein sequences to per-residue embedding vectors.


In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

## Load HuggingFace model

In [3]:
from transformers import AutoTokenizer, EsmModel

In [4]:
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name)


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Move the model to the best available hardware device: the GPU (CUDA) if one is available, otherwise the CPU. The input tensors are moved to the same device before the forward pass.

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


## Multimers

If the protein consists of multiple chains (a multimer), join them into one long sequence string by inserting a linker of glycine ("G") residues between the chains.

In [6]:
chain_A = "MRLIPLHNVDQVAKWSARYIVDRINQFQPTEARPFVLGLPTGGTPLKTYEALIELYKAGEVSFKHVVTFNMDEYVGLPKEHPESYHSFMYKNFFDHVDIQEKNINILNGNTEDHDAECQRYEEKIKSYGKIHLFMGGVGVDGHIAFNEPASSLSSRTRIKTLTEDTLIANSRFFDNDVNKVPKYALTIGVGTLLDAEEVMILVTGYNKAQALQAAVEGSINHLWTVTALQMHRRAIIVCDEPATQELKVKTVKYFTELEASAIRSVK"
chain_B = "HPESYHSFMYKNFFDHVDIQEKRTTDINRT"

linker_sequence = "G" * 25  # Put G linker in between chains (remove it later)

multimer_sequence = chain_A + linker_sequence + chain_B
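
The same joining step extends to more than two chains. The following is a minimal sketch, assuming a hypothetical helper `join_chains` (not part of the original notebook) and made-up short chains used only for illustration:

```python
def join_chains(chains, linker_length=25, linker_residue="G"):
    """Join chains into one string with a poly-residue linker between neighbours.

    `join_chains` is a hypothetical helper, not part of any library.
    """
    if not chains:
        raise ValueError("need at least one chain")
    linker = linker_residue * linker_length
    # str.join places the linker between chains only, never at the ends
    return linker.join(chains)


# Made-up short chains, for illustration only
demo = join_chains(["MKT", "HPE", "AAV"], linker_length=4)
print(demo)  # MKTGGGGHPEGGGGAAV
```

For the two chains above, `join_chains([chain_A, chain_B])` would give the same string as `chain_A + linker_sequence + chain_B`.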

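Tokenize the input sequence string. Special tokens are disabled (`add_special_tokens=False`) so that the number of tokens equals the number of residues.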

In [7]:
tokenized_multimer = tokenizer([multimer_sequence], return_tensors="pt", add_special_tokens=False)

Renumber the positions of the second chain by adding an offset of 512 to its position indices, so that the two chains do not appear to the model as one directly connected sequence.

In [8]:
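# One position index per residue of the joined sequence
position_ids = torch.arange(len(multimer_sequence), dtype=torch.long)
# Shift every residue after chain A and the linker (i.e. chain B) by 512 positions
position_ids[len(chain_A) + len(linker_sequence):] += 512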
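
The renumbering can also be written for any number of chains. The sketch below is one possible generalization, assuming a hypothetical helper `build_position_ids` and assuming that each additional chain is shifted by a further 512 positions (the notebook itself only covers two chains). It works on plain Python lists so the arithmetic is easy to check:

```python
def build_position_ids(chain_lengths, linker_length=25, chain_offset=512):
    """Return one position index per residue of the joined sequence.

    Hypothetical helper. Each linker is numbered together with the chain
    before it; every later chain is shifted by a further `chain_offset`.
    """
    position_ids = []
    index = 0   # running index in the joined sequence
    shift = 0   # accumulated offset applied to the current chain
    last = len(chain_lengths) - 1
    for i, length in enumerate(chain_lengths):
        position_ids.extend(range(index + shift, index + shift + length))
        index += length
        if i != last:
            # Linker residues keep the numbering of the preceding chain
            position_ids.extend(range(index + shift, index + shift + linker_length))
            index += linker_length
            shift += chain_offset
    return position_ids


# Lengths from this notebook: chain A = 267, linker = 25, chain B = 30
ids = build_position_ids([267, 30])
print(len(ids), ids[291], ids[292])  # 322 291 804
```

The list can then be converted with `torch.tensor(ids, dtype=torch.long)`, which for two chains matches the `position_ids` tensor built above.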

In [9]:
tokenized_multimer['position_ids'] = position_ids.unsqueeze(0)

tokenized_multimer = {key: tensor.to(device) for key, tensor in tokenized_multimer.items()}

In [10]:
with torch.no_grad():
    output = model(**tokenized_multimer)

In [11]:
output.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

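The per-residue embeddings are the last hidden state of the model. The `pooler_output` is produced by the pooler weights that the warning above reports as newly initialized, so it is not used here.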

In [12]:
output["last_hidden_state"].shape

torch.Size([1, 322, 320])

In [13]:
position_ids.shape

torch.Size([322])
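
The linker was only inserted to join the chains, so its rows are usually dropped before the embeddings are used. A minimal sketch of that step, assuming a hypothetical helper `chain_slices` and made-up short chains: it computes one Python `slice` per chain, and the demo checks the slices against the joined string itself.

```python
def chain_slices(chain_lengths, linker_length=25):
    """Return one slice per chain into the joined (linker-separated) sequence.

    Hypothetical helper; assumes one token per residue, i.e. the sequence
    was tokenized with add_special_tokens=False.
    """
    slices = []
    start = 0
    for length in chain_lengths:
        slices.append(slice(start, start + length))
        # Skip over this chain and the linker that follows it
        start += length + linker_length
    return slices


# Made-up short chains, for illustration only
demo_A, demo_B = "MKTAY", "HPE"
demo_linker = "G" * 4
demo_multimer = demo_A + demo_linker + demo_B

slice_A, slice_B = chain_slices([len(demo_A), len(demo_B)], linker_length=len(demo_linker))
print(demo_multimer[slice_A], demo_multimer[slice_B])  # MKTAY HPE
```

With the tensors from this notebook, the same slices would index the sequence dimension, for example `output["last_hidden_state"][0, slice_A]` for chain A. For chain lengths 267 and 30 with a 25-residue linker, the slices cover rows 0 to 266 and 292 to 321.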