# ESM Protein Folding using HuggingFace

The [ESMFold protein language model](https://github.com/facebookresearch/esm) predicts the 3D structure of a protein from its amino acid sequence alone. Accurate structure prediction is critical to the design and optimization of new medicines.

This example uses [HuggingFace esmfold_v1](https://huggingface.co/facebook/esmfold_v1), but many other protein structure prediction models are available. It also demonstrates how to predict multimers (proteins with multiple chains).

This code ran on a Microsoft Surface Laptop with an NVIDIA RTX 2000 Ada GPU under [WSL2 Ubuntu](https://learn.microsoft.com/en-us/windows/wsl/install) (8 GB GPU RAM, driver version 537.58, CUDA version 12.2). It took about 11 minutes.

On an Ubuntu 24.04 machine with an NVIDIA RTX A2000 GPU, it took about 40 seconds.

`nvidia-smi` output from the laptop run:

```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 537.58       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX 2000 Ada Gene...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   54C    P1              30W /  60W |   7950MiB /  8188MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     71450      C   /python3.12                               N/A      |
+---------------------------------------------------------------------------------------+
```

In [1]:
```python
import torch
```

In [2]:
```python
torch.cuda.is_available()
```

```
True
```

## Load HuggingFace model

In [3]:
```python
from transformers import AutoTokenizer, EsmForProteinFolding
```

In [4]:
```python
model_name = "facebook/esmfold_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForProteinFolding.from_pretrained(model_name, low_cpu_mem_usage=True)
```

```
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```

This warning can be ignored here: the newly initialized weights belong to the contact-prediction head, which structure prediction does not use.


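Select the hardware device: use the GPU if CUDA is available, otherwise fall back to the CPU. On the GPU, the ESM language model stem is converted to float16, and on cards with less than 16 GB of memory the folding trunk is processed in chunks to lower peak memory use.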

In [5]:
```python
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

if device_name == "cuda":
    gpu_ram_size = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU has {gpu_ram_size:.2f} GB of RAM")

    # Run the ESM-2 language model stem in float16 to save GPU memory
    model.esm = model.esm.half()

    # Allow TensorFloat-32 matrix multiplies (faster on Ampere and newer GPUs)
    torch.backends.cuda.matmul.allow_tf32 = True

    # On GPUs with less than 16 GB, process the folding trunk in smaller chunks
    # (lower peak memory at the cost of speed)
    if gpu_ram_size < 16:
        model.trunk.set_chunk_size(8)
```

```
GPU has 12.62 GB of RAM
```


In [6]:
```python
model = model.to(device)
```

## Multimers

If the protein consists of multiple chains (a multimer), join them into one long sequence string with a run of glycine ("G") residues between them.
Glycine is chosen because it is the smallest amino acid. The linker length of 25 residues is arbitrary (26, 27, 28, 29, etc. would also work).
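The concatenation generalizes to any number of chains. Below is a minimal sketch using a hypothetical helper, `join_chains` (not part of transformers), that also records the `(start, end)` span of each chain, since those offsets are needed again when renumbering positions and masking the linker:

```python
def join_chains(chains: list[str], linker_len: int = 25) -> tuple[str, list[tuple[int, int]]]:
    """Join chains with a poly-glycine linker.

    Returns the combined sequence and the (start, end) span of each chain
    in that sequence (end is exclusive, like Python slicing).
    """
    linker = "G" * linker_len
    spans = []
    pos = 0
    for i, chain in enumerate(chains):
        if i > 0:
            pos += linker_len  # skip over the linker inserted before this chain
        spans.append((pos, pos + len(chain)))
        pos += len(chain)
    return linker.join(chains), spans


seq, spans = join_chains(["DVQL", "DIV"], linker_len=3)
print(seq)    # DVQLGGGDIV
print(spans)  # [(0, 4), (7, 10)]
```

For the two-chain case in this notebook, `join_chains([chain_H, chain_L])` would produce the same `multimer_sequence` as the concatenation below.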

In this example, we'll use antibody [3HFM](https://www.rcsb.org/sequence/3HFM) which has heavy (H) and light (L) chains.

In [7]:
```python
chain_H = "DVQLQESGPSLVKPSQTLSLTCSVTGDSITSDYWSWIRKFPGNRLEYMGYVSYSGSTYYNPSLKSRISITRDTSKNQYYLDLNSVTTEDTATYYCANWDGDYWGQGTLVTVSAAKTTPPSVYPLAPGSAAQTNSMVTLGCLVKGYF"
chain_L = "DIVLTQSPATLSVTPGNSVSLSCRASQSIGNNLHWYQQKSHESPRLLIKYASQSISGIPSRFSGSGSGTDFTLSINSVETEDFGMYFCQQSNSWPYTFGGGTKLEIKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSWTDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNEC"

linker_sequence = "G" * 25  # Put a G linker in between the chains (hidden later)

multimer_sequence = chain_H + linker_sequence + chain_L
```

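Tokenize the combined sequence string. With `add_special_tokens=False`, the ESM tokenizer produces exactly one token per residue, so token positions line up with residue positions.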

In [8]:
```python
tokenized_multimer = tokenizer(
    [multimer_sequence], return_tensors="pt", add_special_tokens=False
)
```

Offset the position IDs of the second chain so that the model does not treat it as a continuation of the first chain (the two chains are not actually connected).
The value 512 is arbitrary: any number significantly greater than the length of the first chain (`len(chain_H)`) works.
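For more than two chains, the same renumbering can be generalized by shifting every chain after the first by a further offset, so the offsets accumulate. This is a sketch under that assumption, written with NumPy instead of torch and a hypothetical helper name, `offset_position_ids`; the `(start, end)` spans mark where each chain sits in the concatenated sequence. As in the notebook, the linker keeps the numbering of the chain before it. For two chains it reproduces the single `+= 512` shift in the next cell.

```python
import numpy as np


def offset_position_ids(spans, total_len, offset=512):
    """Residue position ids where each chain after the first is shifted by a further `offset`."""
    position_ids = np.arange(total_len, dtype=np.int64)
    for start, _end in spans[1:]:
        # Shift this chain and everything after it, so later chains accumulate offsets
        position_ids[start:] += offset
    return position_ids


ids = offset_position_ids([(0, 4), (7, 10)], total_len=10, offset=512)
print(ids.tolist())  # [0, 1, 2, 3, 4, 5, 6, 519, 520, 521]
```

To feed the result to the model, it could be converted with `torch.from_numpy(ids).unsqueeze(0)` to add the batch dimension.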

In [9]:
```python
# Residue positions 0..N-1, with the second chain (after the linker) shifted by 512
position_ids = torch.arange(len(multimer_sequence), dtype=torch.long)
position_ids[len(chain_H) + len(linker_sequence) :] += 512
```

In [10]:
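```python
# Add a batch dimension: (N,) -> (1, N)
tokenized_multimer["position_ids"] = position_ids.unsqueeze(0)

# Move every input tensor to the selected device
tokenized_multimer = {
    key: tensor.to(device) for key, tensor in tokenized_multimer.items()
}
```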

### Use the model to predict the 3D structure from the sequence input

In [11]:
```python
from datetime import datetime

start_time = datetime.now()
print(f"Model inference started at: {start_time}")
```

```
Model inference started at: 2024-07-01 16:01:14.081113
```


In [12]:
```python
with torch.no_grad():
    output = model(**tokenized_multimer)
```

In [13]:
```python
stop_time = datetime.now()

print(f"Model inference stopped at: {stop_time}")
print(f"Elapsed time = {stop_time - start_time}")
```

```
Model inference stopped at: 2024-07-01 16:01:51.573942
Elapsed time = 0:00:37.492829
```


In [14]:
```python
output.keys()
```

```
odict_keys(['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 's_s', 's_z', 'distogram_logits', 'lm_logits', 'aatype', 'atom14_atom_exists', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index', 'lddt_head', 'plddt', 'ptm_logits', 'ptm', 'aligned_confidence_probs', 'predicted_aligned_error', 'max_predicted_aligned_error'])
```
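Among these keys, `plddt` holds the per-atom confidence that is later written to the B-factor column. A rough per-chain confidence summary can be computed by averaging the C-alpha values over each chain while skipping the linker. This is a sketch with assumptions: a hypothetical helper `mean_plddt_per_chain`, a `plddt` array of shape `(batch, residues, 37)`, and atom37 slot 1 being the C-alpha atom (the OpenFold atom ordering N, CA, C, CB, O, ...). It runs on synthetic data so it does not need the model.

```python
import numpy as np

# Assumption: atom37 slot 1 is the C-alpha atom (OpenFold ordering N, CA, C, CB, O, ...)
CA_INDEX = 1


def mean_plddt_per_chain(plddt, spans):
    """Mean C-alpha pLDDT of the first batch item for each (start, end) residue span."""
    ca_plddt = np.asarray(plddt)[0, :, CA_INDEX]
    return [float(ca_plddt[start:end].mean()) for start, end in spans]


# Synthetic stand-in for output["plddt"]: 1 batch item, 5 residues, 37 atom slots
fake_plddt = np.zeros((1, 5, 37))
fake_plddt[0, 0:2, CA_INDEX] = 0.8  # chain 1
fake_plddt[0, 2, CA_INDEX] = 0.1    # linker residue, excluded by the spans
fake_plddt[0, 3:5, CA_INDEX] = 0.6  # chain 2
print(mean_plddt_per_chain(fake_plddt, [(0, 2), (3, 5)]))
```

On the real output, the spans would be `(0, len(chain_H))` and `(len(chain_H) + len(linker_sequence), len(multimer_sequence))`, and the tensor could be passed as `output["plddt"].float().cpu().numpy()`.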

## Mask the linker section in the output

Mask the linker so that the glycine residues inserted to separate the chains are hidden when the PDB file is created from the output tensors.
In the mask, 1 means a residue is kept and 0 means it is masked (hidden).
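The same mask can be built for any number of chains, and reshaping it to `(1, N, 1)` lets it broadcast over the 37 atom slots of `atom37_atom_exists`. A minimal NumPy sketch with a hypothetical helper, `linker_mask`, and a stand-in array of ones in place of the real model output:

```python
import numpy as np


def linker_mask(chain_lengths, linker_len):
    """1 for chain residues, 0 for the linker residues between consecutive chains."""
    parts = []
    for i, n in enumerate(chain_lengths):
        if i > 0:
            parts.append(np.zeros(linker_len, dtype=np.int64))  # linker before this chain
        parts.append(np.ones(n, dtype=np.int64))
    return np.concatenate(parts)


mask = linker_mask([4, 3], linker_len=3)
atom_exists = np.ones((1, mask.size, 37))   # stand-in for output["atom37_atom_exists"]
masked = atom_exists * mask[None, :, None]  # (1, N, 1) broadcasts over the 37 atom slots
print(mask.tolist())        # [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]
print(int(masked.sum()))    # 259 (7 unmasked residues * 37 atom slots)
```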

In [15]:
```python
# Shape (1, N, 1) so it broadcasts over the 37 atom slots of atom37_atom_exists
linker_mask = torch.tensor(
    [1] * len(chain_H) + [0] * len(linker_sequence) + [1] * len(chain_L)
)[None, :, None]

output["atom37_atom_exists"] = output["atom37_atom_exists"] * linker_mask.to(
    output["atom37_atom_exists"].device
)
```

### Convert ESMFold output tensor to a PDB file

In [16]:
```python
import string
import warnings

import numpy as np
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from transformers.models.esm.openfold_utils.protein import Protein as OFProtein
from transformers.models.esm.openfold_utils.protein import to_pdb


def convert_outputs_to_pdb(output: dict, len_chain_H: int) -> str:
    """Convert the ESMFold model output to a PDB-formatted protein string.

    Args:
        output: Output of the HuggingFace ESMFold model
        len_chain_H (int): Length of the H chain

    Returns:
        String with the PDB-formatted structure of the first item in the batch
    """
    # Note: this silences warnings for the rest of the session, not just this call
    warnings.simplefilter("ignore")

    # atom14_to_atom37 must be called first, as it fails on latest numpy if the
    # input is a numpy array. It works if the input is a torch tensor.
    final_atom_positions = atom14_to_atom37(output["positions"][-1], output)

    # To change the chain labels, create a list of integers the length of the
    # full sequence. (The linker is masked, so it is not written to the PDB.)
    # Change "H" and "L" to rename the chains with whatever letters you want.
    chain_index_temp = np.ones(len(final_atom_positions[0]), dtype=int)
    chain_index_temp[:len_chain_H] *= string.ascii_uppercase.index("H")
    chain_index_temp[len_chain_H:] *= string.ascii_uppercase.index("L")

    output = {k: v.to("cpu").numpy() for k, v in output.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = output["atom37_atom_exists"]
    pdbs = []
    for i in range(output["aatype"].shape[0]):
        pred = OFProtein(
            aatype=output["aatype"][i],
            atom_positions=final_atom_positions[i],
            atom_mask=final_atom_mask[i],
            residue_index=output["residue_index"][i] + 1,  # 1-based residue numbers
            b_factors=output["plddt"][i],  # store pLDDT in the B-factor column
            chain_index=chain_index_temp,
        )
        pdbs.append(to_pdb(pred))
    return pdbs[0]
```
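The chain-labelling step in `convert_outputs_to_pdb` multiplies an array of ones by each letter's position in `string.ascii_uppercase` (7 for "H", 11 for "L"). An equivalent and arguably more direct way to build the same array is `np.where`. The sketch below uses a hypothetical helper name and only NumPy, so it can be checked without running the model. As in the function above, the linker residues receive the second chain's label, which does not matter because they are masked and never written.

```python
import string

import numpy as np


def chain_index_for_two_chains(len_first: int, total_len: int,
                               first: str = "H", second: str = "L") -> np.ndarray:
    """Per-residue chain index: `first` for the first len_first residues, `second` after.

    Each integer is the letter's position in string.ascii_uppercase, the same
    encoding convert_outputs_to_pdb uses.
    """
    residue = np.arange(total_len)
    return np.where(
        residue < len_first,
        string.ascii_uppercase.index(first),
        string.ascii_uppercase.index(second),
    )


idx = chain_index_for_two_chains(len_first=2, total_len=5)
print("".join(string.ascii_uppercase[i] for i in idx))  # HHLLL
```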

### Convert the model output

Convert the model output to a PDB-formatted string.

In [17]:
```python
pdb = convert_outputs_to_pdb(output, len(chain_H))
```

### Display the protein folding prediction

In [18]:
```python
import py3Dmol

view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js", width=800, height=400)
view.addModel(pdb, "pdb")  # pdb is already a single PDB-formatted string
view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()
```

## Write the PDB to a file

In [19]:
```python
prediction_pdb_filename = "prediction_3HFM.pdb"
```

In [20]:
```python
with open(prediction_pdb_filename, "w") as fptr:
    fptr.write("HEADER    COMPLEX(ANTIBODY-ANTIGEN)   3HFM ESMFold prediction\n")
    fptr.write(pdb)
```
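A quick sanity check on the written PDB is to count residues per chain: each residue contributes one C-alpha `ATOM` record, and the masked linker should contribute none. The hypothetical helper below relies only on the fixed-column PDB layout (atom name in columns 13-16, chain ID in column 22).

```python
from collections import Counter


def residues_per_chain(pdb_text: str) -> dict[str, int]:
    """Count C-alpha ATOM records per chain ID (one C-alpha per residue)."""
    counts = Counter()
    for line in pdb_text.splitlines():
        # Fixed columns: atom name is line[12:16], chain ID is line[21]
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            counts[line[21]] += 1
    return dict(counts)
```

Calling `residues_per_chain(pdb)` should then report `len(chain_H)` residues for chain H and `len(chain_L)` for chain L, with no extra glycines from the linker.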

You can compare the prediction with the experimentally determined structure using [US-align](https://zhanggroup.org/US-align/).

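The experimental 3HFM structure is available from [RCSB](https://files.rcsb.org/download/3HFM.pdb). It also includes the antigen (chain Y), so only two of its three chains can align with the prediction. US-align leaves the antigen unaligned, which is why the TM-score normalized by the full 3HFM structure (558 residues) is much lower than the one normalized by the prediction (360 residues):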

```
Aligned length= 360, RMSD=   2.12, Seq_ID=n_identical/n_aligned= 0.958
TM-score= 0.92993 (normalized by length of Structure_1: L=360, d0=6.90)
TM-score= 0.61191 (normalized by length of Structure_2: L=558, d0=8.32)
```
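The two TM-scores differ only in the normalizing length L. For a given superposition, TM-score sums 1 / (1 + (d_i / d0)^2) over the aligned residue pairs (d_i is the distance between the i-th aligned pair) and divides by L, with the length-dependent scale d0 = 1.24 * (L - 15)^(1/3) - 1.8. That formula reproduces the d0 values reported above (6.90 for L=360, 8.32 for L=558). The sketch below only scores a fixed set of distances: it does not perform the alignment or the superposition search that US-align carries out, and it omits the clamping US-align applies to d0 for very short chains.

```python
import numpy as np


def tm_d0(length: int) -> float:
    """Standard TM-score distance scale d0 (angstroms) for a normalizing length L."""
    return float(1.24 * np.cbrt(length - 15) - 1.8)


def tm_score(distances, length: int) -> float:
    """TM-score of one fixed superposition.

    distances: per-pair distances (angstroms) between aligned residues
    length: normalizing length L (e.g. the residue count of one structure)
    """
    d = np.asarray(distances, dtype=float)
    d0 = tm_d0(length)
    return float(np.sum(1.0 / (1.0 + (d / d0) ** 2)) / length)


print(round(tm_d0(360), 2), round(tm_d0(558), 2))  # 6.9 8.32
```

With the same 360 aligned pairs, dividing by 558 instead of 360 (and using the larger d0) is what pulls the second TM-score down.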

In [21]:
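```sh
%%sh
# Download the US-align source and the experimental 3HFM structure (skipped if already present)
wget --no-clobber https://zhanggroup.org/US-align/bin/module/USalign.cpp
wget --no-clobber https://files.rcsb.org/download/3HFM.pdb
# Compile US-align only if the binary does not exist yet
if [ ! -f USalign ]; then
    g++ -v -static -O3 -ffast-math -o USalign USalign.cpp
fi
```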

```
--2024-07-01 16:01:51--  https://zhanggroup.org/US-align/bin/module/USalign.cpp
Resolving zhanggroup.org (zhanggroup.org)... 141.213.137.249
Connecting to zhanggroup.org (zhanggroup.org)|141.213.137.249|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 698457 (682K) [text/plain]
Saving to: ‘USalign.cpp’

     0K .......... .......... .......... .......... ..........  7%  413K 2s
    50K .......... .......... .......... .......... .......... 14%  818K 1s
   100K .......... .......... .......... .......... .......... 21% 22.8M 1s
   150K .......... .......... .......... .......... .......... 29%  850K 1s
   200K .......... .......... .......... .......... .......... 36%  102M 0s
   250K .......... .......... .......... .......... .......... 43% 29.6M 0s
   300K .......... .......... .......... .......... .......... 51%  853K 0s
   350K .......... .......... .......... .......... .......... 58%  146M 0s
   400K .......... .......... .......... .......... ..........
```

In [22]:
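```python
# -mm 1: align multi-chain complexes; -ter 0: use all chains; -o: write the superposition
!./USalign -mm 1 -ter 0 $prediction_pdb_filename 3HFM.pdb -o superimposed
```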


```
 ********************************************************************
 * US-align (Version 20240602)                                      *
 * Universal Structure Alignment of Proteins and Nucleic Acids      *
 * Reference: C Zhang, M Shine, AM Pyle, Y Zhang. (2022) Nat Methods*
 *            C Zhang, AM Pyle (2022) iScience.                     *
 * Please email comments and suggestions to zhang@zhanggroup.org    *
 ********************************************************************

Name of Structure_1: prediction_3HFM.pdb:1,H:1,L: (to be superimposed onto Structure_2)
Name of Structure_2: 3HFM.pdb:1,H:1,L:1,Y
Length of Structure_1: 360 residues
Length of Structure_2: 558 residues

Aligned length= 360, RMSD=   2.12, Seq_ID=n_identical/n_aligned= 0.958
TM-score= 0.92995 (normalized by length of Structure_1: L=360, d0=6.90)
TM-score= 0.61192 (normalized by length of Structure_2: L=558, d0=8.32)
(You should use TM-score normalized by length of the reference structure)

(":" denotes re
```

In [23]:
```python
view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js", width=800, height=400)
with open("superimposed.pdb") as f:  # written by US-align's -o option above
    view.addModel(f.read(), "pdb")
view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()
```