# ESM2 Protein Folding using HuggingFace

This notebook uses the [ESMFold protein language model](https://github.com/facebookresearch/esm) to predict the folded structure of a protein from its amino-acid sequence alone.

It also demonstrates how to handle multimer (multi-chain) predictions.


In [None]:
import torch

In [None]:
# Check whether PyTorch can see a CUDA device; the result is shown as the cell output.
torch.cuda.is_available()

## Load HuggingFace model

In [None]:
from transformers import AutoTokenizer, EsmForProteinFolding

In [None]:
# Checkpoint identifier shared by the model and its tokenizer.
model_name = "facebook/esmfold_v1"

# Load the folding model first, then the matching tokenizer; the two loads are independent.
model = EsmForProteinFolding.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


Move the model to the desired hardware device. If CUDA (GPU) is available, use it; otherwise fall back to the CPU.

In [None]:
# Pick the GPU when PyTorch reports one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# The settings below only apply when the model actually lives on a CUDA device.
if device.type == "cuda":
    # Allow TF32 matrix multiplications on the CUDA backend.
    torch.backends.cuda.matmul.allow_tf32 = True
    # Convert the ESM sub-module to half precision.
    model.esm = model.esm.half()
    # Use chunks if your GPU memory is 16GB or less
    model.trunk.set_chunk_size(64)

## Multimers

If the protein consists of multiple chains (a multimer), join them into one long sequence string by inserting a poly-glycine ("G") linker between consecutive chains. The linker residues are masked out of the predicted structure later.

In [None]:
# Amino-acid sequences of the two chains to be folded together.
chain_A = "MRLIPLHNVDQVAKWSARYIVDRINQFQPTEARPFVLGLPTGGTPLKTYEALIELYKAGEVSFKHVVTFNMDEYVGLPKEHPESYHSFMYKNFFDHVDIQEKNINILNGNTEDHDAECQRYEEKIKSYGKIHLFMGGVGVDGHIAFNEPASSLSSRTRIKTLTEDTLIANSRFFDNDVNKVPKYALTIGVGTLLDAEEVMILVTGYNKAQALQAAVEGSINHLWTVTALQMHRRAIIVCDEPATQELKVKTVKYFTELEASAIRSVK"
chain_B = "HPESYHSFMYKNFFDHVDIQEKRTTDINRTQVAKWSARYIVDRINQFQPTHVGIQEKRATDIN"

# Glycine linker placed between the chains (it is hidden again later).
linker_sequence = "G" * 25

# Using the linker as the join separator yields chain_A + linker + chain_B.
multimer_sequence = linker_sequence.join([chain_A, chain_B])

Tokenize the combined multimer sequence string.

In [None]:
# Tokenize the multimer as a one-item batch and get PyTorch tensors back.
# Special tokens are disabled -- presumably so that tokens line up one-to-one
# with residues (TODO confirm against the tokenizer documentation).
tokenized_multimer = tokenizer(
    [multimer_sequence],
    add_special_tokens=False,
    return_tensors="pt",
)

Offset the position IDs of the second chain so that the model knows that the second chain is not really connected to the first.

In [None]:
with torch.no_grad():
    # Index at which chain B begins inside the concatenated sequence.
    chain_B_start = len(chain_A) + len(linker_sequence)
    # One consecutive position id per residue of the full multimer string...
    position_ids = torch.arange(len(multimer_sequence), dtype=torch.long)
    # ...then shift every chain-B position by 512 to mark the chain break.
    position_ids[chain_B_start:] += 512

In [None]:
# Attach the custom position ids with a leading axis of size one added in front.
tokenized_multimer['position_ids'] = position_ids[None, :]

# Rebuild the inputs as a plain dict with every tensor moved onto the compute device.
tokenized_multimer = {
    name: value.to(device)
    for name, value in tokenized_multimer.items()
}

In [None]:
# Run the model on the prepared inputs with gradient tracking disabled.
with torch.no_grad():
    output = model(**tokenized_multimer)

In [None]:
# List the names of the fields the model returned.
output.keys()

In [None]:
# Per-residue integer mask: 1 over chain A and chain B, 0 over the glycine linker.
# The two added singleton axes let it broadcast against `atom37_atom_exists`.
linker_mask = torch.cat(
    [
        torch.ones(len(chain_A), dtype=torch.long),
        torch.zeros(len(linker_sequence), dtype=torch.long),
        torch.ones(len(chain_B), dtype=torch.long),
    ]
)[None, :, None]

# Zero the atom-exists entries of linker residues so they are hidden downstream.
output['atom37_atom_exists'] = (
    output['atom37_atom_exists']
    * linker_mask.to(output['atom37_atom_exists'].device)
)

In [None]:
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

def convert_outputs_to_pdb(outputs):
    """Convert a batch of folding-model outputs into PDB-format strings.

    Args:
        outputs: Mapping of output name to tensor. The keys read directly here
            are "positions", "atom37_atom_exists", "aatype", "residue_index"
            and "plddt"; "chain_index" is used when present. The whole mapping
            is also passed to ``atom14_to_atom37``, so any keys that helper
            needs must be present as well.

    Returns:
        list: One ``to_pdb`` result per batch element (the leading axis of
        ``outputs["aatype"]``), in batch order.
    """
    # Only the last entry along the first axis of "positions" is converted
    # (presumably the final refinement step -- confirm against the model docs).
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    final_atom_positions = final_atom_positions.cpu().numpy()

    # From here on work with NumPy arrays on the host.
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_mask = outputs["atom37_atom_exists"]
    has_chain_index = "chain_index" in outputs

    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        pred = OFProtein(
            aatype=outputs["aatype"][i],
            atom_positions=final_atom_positions[i],
            atom_mask=final_atom_mask[i],
            # Residue indices are shifted by one before being written out.
            residue_index=outputs["residue_index"][i] + 1,
            # The pLDDT values are stored in the B-factor field.
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if has_chain_index else None,
        )
        pdbs.append(to_pdb(pred))
    # Return the list built above; the original returned the undefined name `pdb`.
    return pdbs

In [None]:
# Convert the model output into PDB text for the 3D viewer below.
pdb = convert_outputs_to_pdb(output)

In [None]:
import py3Dmol

# Create an 800x400 viewer that loads the hosted 3Dmol.js bundle.
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)
# Concatenate the PDB text held in `pdb` and load it as a single model.
view.addModel("".join(pdb), 'pdb')
# Draw the most recently added model (index -1) as a cartoon with the 'spectrum' colour scheme.
view.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})
# Fit the camera to the loaded structure so it is centred and fully in frame.
view.zoomTo()