# ESM Protein Folding using HuggingFace

The [ESMFold protein language model](https://github.com/facebookresearch/esm) predicts the 3D structure of a protein from its amino acid sequence alone. Accurate structure prediction is critical to the design and optimization of new medicines.

We'll use [HuggingFace esmfold_v1](https://huggingface.co/facebook/esmfold_v1) in this example, but there are many other protein folding models out there. This example also demonstrates how to handle predictions for multimers (proteins with multiple chains).

This code ran on a Microsoft Surface Laptop with an NVIDIA RTX 2000 Ada using [WSL2 Ubuntu](https://learn.microsoft.com/en-us/windows/wsl/install) (8 GB GPU RAM, Driver Version: 537.58, CUDA Version: 12.2). The run took about 11 minutes, most of it spent in model inference (about 10 minutes).

```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 537.58       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX 2000 Ada Gene...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   54C    P1              30W /  60W |   7950MiB /  8188MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     71450      C   /python3.12                               N/A      |
+---------------------------------------------------------------------------------------+
```

In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

## Load HuggingFace model

In [3]:
from transformers import AutoTokenizer, EsmForProteinFolding

In [4]:
model_name = "facebook/esmfold_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForProteinFolding.from_pretrained(model_name, low_cpu_mem_usage=True)

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Move the model to the desired hardware device: use the GPU (CUDA) if one is available, otherwise fall back to the CPU.

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

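# On limited GPU memory (16 GB or less), reduce the memory footprint.
if torch.cuda.is_available():
    # Run the ESM language-model stem in half precision (float16)
    model.esm = model.esm.half()
    # Allow TF32 matrix multiplications on GPUs that support them
    torch.backends.cuda.matmul.allow_tf32 = True
    # Process the folding trunk in chunks: lower peak memory, slower inference
    model.trunk.set_chunk_size(64)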

## Multimers

If the protein consists of multiple chains (a multimer), join the chains into one long sequence string by inserting a linker of glycine ("G") residues between them.

In this example, we'll use antibody [3HFM](https://www.rcsb.org/sequence/3HFM) which has heavy (H) and light (L) chains.

In [6]:
chain_H = "DVQLQESGPSLVKPSQTLSLTCSVTGDSITSDYWSWIRKFPGNRLEYMGYVSYSGSTYYNPSLKSRISITRDTSKNQYYLDLNSVTTEDTATYYCANWDGDYWGQGTLVTVSAAKTTPPSVYPLAPGSAAQTNSMVTLGCLVKGYF"
chain_L = "DIVLTQSPATLSVTPGNSVSLSCRASQSIGNNLHWYQQKSHESPRLLIKYASQSISGIPSRFSGSGSGTDFTLSINSVETEDFGMYFCQQSNSWPYTFGGGTKLEIKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSWTDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNEC"

linker_sequence = "G" * 25  # Put G linker in between chains (hide it later)

multimer_sequence = chain_H + linker_sequence + chain_L
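The same joining step extends to more than two chains. The sketch below is one possible generalization, not part of the original notebook: `join_chains` is a hypothetical helper, and the short toy sequences are placeholders. It also returns where each chain sits in the joined string, which is convenient for the later renumbering and masking steps.

```python
def join_chains(chains, linker_length=25, linker_residue="G"):
    """Join chain sequences into one string, separated by poly-G linkers.

    Returns the joined sequence and a list of (start, end) index pairs
    (end exclusive) giving the location of each chain in that string.
    """
    linker = linker_residue * linker_length
    spans = []
    cursor = 0
    for chain in chains:
        spans.append((cursor, cursor + len(chain)))
        # Skip past this chain and the linker that follows it
        cursor += len(chain) + linker_length
    return linker.join(chains), spans


# Toy chains (placeholders, not real proteins) with a short linker
sequence, spans = join_chains(["ACDE", "FGH", "IK"], linker_length=3)
print(sequence)  # ACDEGGGFGHGGGIK
print(spans)     # [(0, 4), (7, 10), (13, 15)]
```

For the antibody above, `join_chains([chain_H, chain_L])` should give the same string as `multimer_sequence`.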

Tokenize the input sequence string.

In [7]:
tokenized_multimer = tokenizer(
    [multimer_sequence], return_tensors="pt", add_special_tokens=False
)
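Before tokenizing, it can help to check the sequence for stray characters such as whitespace or stop symbols picked up when copying from a FASTA file. The following is a minimal sketch with a hypothetical helper, `find_invalid_residues`. It assumes that only the 20 standard one-letter amino acid codes are wanted; the ESM tokenizer itself may accept additional letters (for example ambiguity codes), so this check is deliberately stricter than the tokenizer.

```python
# The 20 standard amino acids in one-letter code
STANDARD_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")


def find_invalid_residues(sequence):
    """Return (index, character) pairs for letters outside the standard 20."""
    return [
        (index, letter)
        for index, letter in enumerate(sequence.upper())
        if letter not in STANDARD_AMINO_ACIDS
    ]


print(find_invalid_residues("DVQLQESG"))   # []
print(find_invalid_residues("DVQ LQ*SG"))  # [(3, ' '), (6, '*')]
```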

Offset the position IDs of the second chain by 512 so that the model treats it as a separate chain rather than a continuation of the first. The linker keeps the numbering of the first chain.

In [8]:
with torch.no_grad():
    position_ids = torch.arange(len(multimer_sequence), dtype=torch.long)
    position_ids[len(chain_H) + len(linker_sequence) :] += 512
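The renumbering can be written for any number of chains. The sketch below uses plain Python lists so that the arithmetic is easy to inspect. `multimer_position_ids` is a hypothetical helper, and adding a further 512 for every additional chain break is an assumption that extends the two-chain case shown above.

```python
def multimer_position_ids(chain_lengths, linker_length=25, chain_break_offset=512):
    """Position IDs for chains joined by linkers, with a jump at each chain break."""
    position_ids = []
    offset = 0
    for chain_number, chain_length in enumerate(chain_lengths):
        if chain_number > 0:
            # The linker continues the numbering of the preceding chain
            start = len(position_ids)
            position_ids.extend(i + offset for i in range(start, start + linker_length))
            # Every chain after the linker is pushed a further offset away
            offset += chain_break_offset
        start = len(position_ids)
        position_ids.extend(i + offset for i in range(start, start + chain_length))
    return position_ids


# Toy example: chains of length 3 and 2, a linker of 2, and an offset of 10
ids = multimer_position_ids([3, 2], linker_length=2, chain_break_offset=10)
print(ids)  # [0, 1, 2, 3, 4, 15, 16]
```

The resulting list can be turned into a tensor with `torch.tensor(ids, dtype=torch.long)` before it is added to the tokenizer output.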

In [9]:
tokenized_multimer["position_ids"] = position_ids.unsqueeze(0)

tokenized_multimer = {
    key: tensor.to(device) for key, tensor in tokenized_multimer.items()
}

### Use the model to predict the 3D structure from the sequence input

In [10]:
from datetime import datetime

start_time = datetime.now()
print(f"Model inference started at: {start_time}")

Model inference started at: 2024-04-08 12:28:03.990376


In [11]:
with torch.no_grad():
    output = model(**tokenized_multimer)

In [12]:
stop_time = datetime.now()

print(f"Model inference stopped at: {stop_time}")
print(f"Elapsed time = {stop_time - start_time}")

Model inference stopped at: 2024-04-08 12:38:19.069084
Elapsed time = 0:10:15.078708
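The start and stop timestamps above are split across three cells. As an alternative, a small context manager keeps the timing next to the code it measures. This is only a sketch: `timed` is a hypothetical helper and the summation stands in for the model call.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label):
    """Measure the wall-clock time of the enclosed block and print it."""
    record = {"label": label, "seconds": None}
    start = time.perf_counter()
    try:
        yield record
    finally:
        record["seconds"] = time.perf_counter() - start
        print(f"{label}: {record['seconds']:.2f} s")


# Stand-in workload; in the notebook this would be the model(...) call.
# CUDA work is launched asynchronously, so when timing GPU code it is common
# to call torch.cuda.synchronize() before leaving the block.
with timed("toy workload") as record:
    total = sum(i * i for i in range(100_000))
```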


In [13]:
output.keys()

odict_keys(['frames', 'sidechain_frames', 'unnormalized_angles', 'angles', 'positions', 'states', 's_s', 's_z', 'distogram_logits', 'lm_logits', 'aatype', 'atom14_atom_exists', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index', 'lddt_head', 'plddt', 'ptm_logits', 'ptm', 'aligned_confidence_probs', 'predicted_aligned_error', 'max_predicted_aligned_error'])

## Mask the linker section in the output

Zero out the linker positions in `atom37_atom_exists` so that the glycine linker inserted between the chains is left out when the output tensor is converted to a PDB file.

In [14]:
linker_mask = torch.tensor(
    [1] * len(chain_H) + [0] * len(linker_sequence) + [1] * len(chain_L)
)[None, :, None]

output["atom37_atom_exists"] = output["atom37_atom_exists"] * linker_mask.to(
    output["atom37_atom_exists"].device
)
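The mask is 1 for every chain residue and 0 for every linker residue. A sketch of the same construction for an arbitrary number of chains follows; `build_linker_mask` is a hypothetical helper and the chain lengths are toy values.

```python
def build_linker_mask(chain_lengths, linker_length=25):
    """Return a list with 1 for chain residues and 0 for linker residues."""
    mask = []
    for chain_number, chain_length in enumerate(chain_lengths):
        if chain_number > 0:
            # A linker precedes every chain except the first
            mask.extend([0] * linker_length)
        mask.extend([1] * chain_length)
    return mask


mask = build_linker_mask([3, 2], linker_length=2)
print(mask)  # [1, 1, 1, 0, 0, 1, 1]
```

As in the cell above, `torch.tensor(mask)[None, :, None]` adds the batch and atom dimensions so that the mask broadcasts against `atom37_atom_exists`.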

### Convert ESMFold output tensor to a PDB file

In [17]:
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from transformers.models.esm.openfold_utils.protein import Protein as OFProtein
from transformers.models.esm.openfold_utils.protein import to_pdb


def convert_outputs_to_pdb(outputs):
    """Convert ESMFold model output to a PDB-formatted protein string.

    Args:
        outputs: Output of the HuggingFace ESMFold model (a dict-like
            object of tensors, as returned by `model(**inputs)`)

    Returns:
        PDB-formatted string for the first structure in the batch

    """

    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs_dictionary = {key: value.to("cpu").numpy() for key, value in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs_dictionary["atom37_atom_exists"]
    
    pdbs = []
    
    for idx in range(outputs_dictionary["aatype"].shape[0]):
        aa = outputs_dictionary["aatype"][idx]
        predicted_positions = final_atom_positions[idx]
        mask = final_atom_mask[idx]
        residues = outputs_dictionary["residue_index"][idx] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=predicted_positions,
            atom_mask=mask,
            residue_index=residues,
            b_factors=outputs_dictionary["plddt"][idx],
            chain_index=outputs_dictionary["chain_index"][idx] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))

    return pdbs[0]

### Convert the model output

Convert the model output to a PDB-formatted string.

In [18]:
pdb = convert_outputs_to_pdb(output)
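Because the conversion function stores `plddt` in the B-factor field, the per-residue confidence can be read back from the PDB text. The sketch below parses the fixed-width `ATOM` records with plain Python. `ca_b_factors` and `toy_atom_line` are hypothetical helpers, the toy records use dummy coordinates, and the numeric scale of the values depends on how the model reports pLDDT.

```python
def ca_b_factors(pdb_text):
    """Return {(chain_id, residue_number): b_factor} for every CA atom."""
    values = {}
    for line in pdb_text.splitlines():
        # Fixed-width PDB columns (1-based): atom name 13-16, chain 22,
        # residue number 23-26, B-factor 61-66.
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            values[(line[21], int(line[22:26]))] = float(line[60:66])
    return values


def toy_atom_line(serial, atom, residue, chain, number, b_factor):
    """Hypothetical helper: build one ATOM record with dummy coordinates."""
    return (
        f"ATOM  {serial:5d}  {atom:<3s} {residue:3s} {chain}{number:4d}    "
        f"{0.0:8.3f}{0.0:8.3f}{0.0:8.3f}{1.0:6.2f}{b_factor:6.2f}"
    )


toy_pdb = "\n".join(
    [
        toy_atom_line(1, "N", "ASP", "A", 1, 0.50),
        toy_atom_line(2, "CA", "ASP", "A", 1, 0.50),
        toy_atom_line(3, "CA", "VAL", "A", 2, 1.00),
    ]
)
scores = ca_b_factors(toy_pdb)
print(scores)  # {('A', 1): 0.5, ('A', 2): 1.0}
print(sum(scores.values()) / len(scores))  # 0.75
```

Calling `ca_b_factors(pdb)` on the prediction should give one value per predicted residue, which makes it easy to spot low-confidence regions.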

### Display the protein folding prediction

In [19]:
import py3Dmol

view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js", width=800, height=400)
view.addModel(pdb, "pdb")  # pdb is already a single PDB-formatted string
view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()

## Write the PDB to a file

In [20]:
prediction_pdb_filename = "prediction_3HFM.pdb"

In [21]:
with open(prediction_pdb_filename, "w") as fptr:
    fptr.write("HEADER    COMPLEX(ANTIBODY-ANTIGEN)   3HFM ESMFold prediction\n")
    fptr.write(pdb)

You can compare the prediction with the experimental structure by using [US-align](https://zhanggroup.org/US-align/).

The experimental 3HFM PDB structure is on [RCSB](https://files.rcsb.org/download/3HFM.pdb). Note that this structure also includes the antigen, so only two of its three chains (the antibody heavy and light chains) have a counterpart in the prediction. In the run below, US-align reports the alignment of chain L of 3HFM (214 residues) against the predicted structure.

> Aligned length= 214, RMSD=   1.35, Seq_ID=n_identical/n_aligned= 1.000
> TM-score= 0.94787 (normalized by length of Structure_1: L=214, d0=5.44)
>
> The root-mean-square deviation (RMSD) between the predicted structure and the measured structure over the 214 aligned residues is 1.35 Angstrom. For comparison, the diameter of a single water molecule is about 2.75 Angstrom. The TM-score (template modeling score) is 0.948 when normalized by the 214-residue reference chain (perfect match = 1).
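To make the two metrics concrete, the sketch below computes them from a list of distances between aligned residue pairs. It uses the standard TM-score definition, with distance scale d0 = 1.24 * (L - 15)^(1/3) - 1.8, and assumes the structures are already superimposed. US-align additionally searches for the superposition that maximizes the score, which is not reproduced here. The function names are illustrative.

```python
import math


def tm_d0(length):
    """TM-score distance scale d0 in Angstrom (intended for length > 15)."""
    return 1.24 * (length - 15) ** (1.0 / 3.0) - 1.8


def tm_score(distances, normalizing_length):
    """TM-score for aligned-pair distances, normalized by the given length."""
    d0 = tm_d0(normalizing_length)
    return sum(1.0 / (1.0 + (d / d0) ** 2) for d in distances) / normalizing_length


def rmsd(distances):
    """Root-mean-square deviation of the aligned-pair distances."""
    return math.sqrt(sum(d * d for d in distances) / len(distances))


# d0 for the two normalizing lengths used in this comparison
print(f"{tm_d0(214):.2f}")  # 5.44
print(f"{tm_d0(360):.2f}")  # 6.90

# A perfect fit over 214 aligned residues, normalized two ways
perfect = [0.0] * 214
print(tm_score(perfect, 214))            # 1.0
print(round(tm_score(perfect, 360), 3))  # 0.594
```

The last line shows why the normalizing length matters: with 214 aligned residues, the score normalized by 360 cannot exceed 214/360, however good the fit is.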

In [22]:
%%sh 
wget --no-clobber https://zhanggroup.org/US-align/bin/module/USalign.cpp
wget --no-clobber https://files.rcsb.org/download/3HFM.pdb
if [ ! -f USalign ]; then
    g++ -v -static -O3 -ffast-math -o USalign USalign.cpp
fi

File ‘USalign.cpp’ already there; not retrieving.

File ‘3HFM.pdb’ already there; not retrieving.



In [23]:
!./USalign 3HFM.pdb $prediction_pdb_filename -o superimposed


 ********************************************************************
 * US-align (Version 20240319)                                      *
 * Universal Structure Alignment of Proteins and Nucleic Acids      *
 * Reference: C Zhang, M Shine, AM Pyle, Y Zhang. (2022) Nat Methods*
 *            C Zhang, AM Pyle (2022) iScience.                     *
 * Please email comments and suggestions to zhang@zhanggroup.org    *
 ********************************************************************

Name of Structure_1: 3HFM.pdb:L (to be superimposed onto Structure_2)
Name of Structure_2: prediction_3HFM.pdb:A
Length of Structure_1: 214 residues
Length of Structure_2: 360 residues

Aligned length= 214, RMSD=   1.35, Seq_ID=n_identical/n_aligned= 1.000
TM-score= 0.94787 (normalized by length of Structure_1: L=214, d0=5.44)
TM-score= 0.57416 (normalized by length of Structure_2: L=360, d0=6.90)
(You should use TM-score normalized by length of the reference structure)

(":" denotes residue pairs of d 
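To use these numbers programmatically instead of reading them off the screen, the text report can be parsed with regular expressions. This is a minimal sketch: `parse_usalign` is a hypothetical helper written against the report layout shown above, and the sample text is copied from that report.

```python
import re


def parse_usalign(text):
    """Extract aligned length, RMSD and TM-scores from a US-align text report."""
    aligned = re.search(r"Aligned length=\s*(\d+), RMSD=\s*([\d.]+)", text)
    tm_scores = re.findall(r"TM-score=\s*([\d.]+)", text)
    return {
        "aligned_length": int(aligned.group(1)),
        "rmsd": float(aligned.group(2)),
        # In report order: normalized by Structure_1, then by Structure_2
        "tm_scores": [float(score) for score in tm_scores],
    }


sample = """\
Aligned length= 214, RMSD=   1.35, Seq_ID=n_identical/n_aligned= 1.000
TM-score= 0.94787 (normalized by length of Structure_1: L=214, d0=5.44)
TM-score= 0.57416 (normalized by length of Structure_2: L=360, d0=6.90)
(You should use TM-score normalized by length of the reference structure)
"""
result = parse_usalign(sample)
print(result)
# {'aligned_length': 214, 'rmsd': 1.35, 'tm_scores': [0.94787, 0.57416]}
```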

In [None]:
view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js", width=800, height=400)
view.addModel(open("superimposed.pdb").read(), "pdb")
view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()