<a href="https://colab.research.google.com/github/virtualscreenlab/AptaFold/blob/main/tools/invfold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ESM3 Inverse Folding Notebook

This notebook is a tool for inverse folding with the ESM3 model, that is, designing protein sequences that are compatible with a given structure.


### Setup

Install the dependencies and set up the Colab environment for asyncio requests.


In [1]:
!pip install git+https://github.com/evolutionaryscale/esm
!pip install pydssp pygtrie dna-features-viewer nest_asyncio py3dmol


In [2]:
import nest_asyncio

nest_asyncio.apply()
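Why the patch is needed: a Colab/Jupyter kernel already runs an asyncio event loop, and the standard library refuses to start another loop from inside a running one. The sketch below reproduces that failure with plain `asyncio` and no ESM code; `fake_request` is a hypothetical stand-in for a network call. The assumption is that the ESM widgets issue nested asyncio calls of this kind, which is what `nest_asyncio.apply()` allows.

```python
import asyncio


async def fake_request(i: int) -> int:
    """Hypothetical stand-in for an asynchronous network request."""
    await asyncio.sleep(0.01)
    return i * 2


async def notebook_cell() -> str:
    """Simulates code running inside an already-active event loop, as in Colab."""
    coro = fake_request(1)
    try:
        asyncio.run(coro)  # fails: an event loop is already running in this thread
        return "no error"
    except RuntimeError as exc:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return str(exc)


# In a plain script no loop is running yet, so this outer asyncio.run works.
message = asyncio.run(notebook_cell())
print(message)  # reports that asyncio.run() cannot be called from a running event loop
```

Without the patch, any library code that calls `asyncio.run` from a notebook cell hits this `RuntimeError`; with `nest_asyncio` applied, the nested call is allowed to run on the existing loop.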

### Inference Settings


In [3]:
from esm.widgets.utils.types import ClientInitContainer
from esm.widgets.views.inverse_folding import create_inverse_folding_ui
from esm.widgets.views.login import create_login_ui



In [10]:
client_init = ClientInitContainer()
create_login_ui(client_init)


If you run the model locally, log in to Hugging Face first so that the model weights can be downloaded.


In [9]:
from esm.utils.misc import huggingfacehub_login

if client_init.metadata["inference_option"] == "Local":
    huggingfacehub_login()


## Inverse Folding UI

If you are running on Google Colab, use the light theme and select the "View output fullscreen" option in the cell toolbar for the best experience.


In [11]:
client = client_init()
create_inverse_folding_ui(client)


In [14]:
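# Imports used in this cell (they were not imported earlier in the notebook)
from getpass import getpass

import py3Dmol
import torch
from esm.sdk import client as forge_client
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

# Initialize the ESM model on Forge. `forge_client` is aliased so it does not
# shadow the `client` object created by the UI cell above.
token = getpass("Token from Forge console: ")
model = forge_client(
    model="esm3-medium-2024-03",
    url="https://forge.evolutionaryscale.ai",
    token=token,
)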

# Load the alpha-hemolysin structure from RCSB (PDB ID: 7AHL, chain A)
template_hemolysin = ESMProtein.from_protein_chain(
    ProteinChain.from_rcsb("7AHL", chain_id="A")
)

# Encode the protein sequence and structure
template_hemolysin_tokens = model.encode(template_hemolysin)

# `tok` (not `token`) avoids visually clashing with the Forge API token above
print("Sequence tokens:")
print(
    "    ", ", ".join([str(tok) for tok in template_hemolysin_tokens.sequence.tolist()])
)

print("Structure tokens:")
print(
    "    ", ", ".join([str(tok) for tok in template_hemolysin_tokens.structure.tolist()])
)

# Initialize prompt sequence with placeholders
prompt_sequence = ["_"] * len(template_hemolysin.sequence)

# Specify mutation sites within the beta-barrel
# Residue numbers are 1-based positions in template_hemolysin.sequence,
# which may differ from the author residue numbering in the PDB file.
# The wild-type residues in the comments are as originally annotated;
# the loop below prints the actual residue at each site so they can be checked.
mutation_sites = {
    45: "A",   # annotated: Leucine (L) -> Alanine (A)
    85: "S",   # annotated: Glycine (G) -> Serine (S)
    150: "K",  # annotated: Arginine (R) -> Lysine (K)
    200: "Y",  # annotated: Phenylalanine (F) -> Tyrosine (Y)
    220: "T",  # annotated: Valine (V) -> Threonine (T)
    250: "D"   # annotated: Glutamic Acid (E) -> Aspartic Acid (D)
}

for res_num, new_aa in mutation_sites.items():
    index = res_num - 1  # the sequence string is 0-based
    if 0 <= index < len(prompt_sequence):
        print(f"Residue {res_num}: {template_hemolysin.sequence[index]} -> {new_aa}")
        prompt_sequence[index] = new_aa
    else:
        print(f"Warning: Residue number {res_num} is out of sequence range.")

# Convert prompt sequence list to string
prompt_sequence = "".join(prompt_sequence)

print("Original Sequence:")
print(template_hemolysin.sequence)

print("Prompt Sequence with Mutations:")
print(prompt_sequence)

# Encode the prompt sequence
prompt = model.encode(ESMProtein(sequence=prompt_sequence))

# Initialize all structure tokens to the mask token (4096)
prompt.structure = torch.full_like(prompt.sequence, 4096)
prompt.structure[0] = 4098  # <bos>
prompt.structure[-1] = 4097  # <eos>

# Copy template structure tokens at and around each mutation site (near the
# beta-barrel) to preserve the local structural context.
# Token index 0 is <bos>, so token index i holds residue i (1-based):
# range(44, 47) therefore covers residues 44-46, flanking mutation site 45.
preserved_residues = (
    list(range(44, 47)) +   # Residues 44-46
    list(range(84, 87)) +   # Residues 84-86
    list(range(149, 152)) + # Residues 149-151
    list(range(199, 202)) + # Residues 199-201
    list(range(219, 222)) + # Residues 219-221
    list(range(249, 252))   # Residues 249-251
)

for res in preserved_residues:
    # Residue tokens occupy indices 1..len(sequence); the last index is <eos>
    if 1 <= res <= len(template_hemolysin.sequence):
        prompt.structure[res] = template_hemolysin_tokens.structure[res]
    else:
        print(f"Warning: Preserved residue {res} is out of sequence range.")

print("Structure Tokens Overview:")
print("".join(["✔" if st < 4096 else "_" for st in prompt.structure]))


# Generate new structure tokens
structure_generation = model.generate(
    prompt,
    GenerationConfig(
        # Generate a structure.
        track="structure",
        # Sample one token per forward pass of the model.
        num_steps=(prompt.structure == 4096).sum().item(),
        # Sampling temperature trades perplexity with diversity.
        temperature=1.0,
    ),
)


print("Generated Structure Tokens:")
print(
    "    ", ", ".join([str(tok) for tok in structure_generation.structure.tolist()])
)

# Decode structure tokens to backbone coordinates
structure_generation_protein = model.decode(structure_generation)

print("")

# Visualize the modified protein structure
view = py3Dmol.view(width=1000, height=500)
view.addModel(
    structure_generation_protein.to_protein_chain().infer_oxygen().to_pdb_string(),
    "pdb",
)
view.setStyle({"cartoon": {"color": "lightblue"}})
view.zoomTo()
view.show()
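The prompt-building steps above mix two indexing conventions: the `_`-masked sequence string is 0-based, while the tokenized structure track carries `<bos>` at index 0, so token index i holds residue i (1-based). A minimal sketch of that bookkeeping, assuming the layout `[<bos>, r1, ..., rN, <eos>]` and the special-token IDs used in the cell (mask 4096, EOS 4097, BOS 4098). Plain lists stand in for `torch` tensors; `build_sequence_prompt`, `build_structure_prompt`, `flanking`, and the toy token values are hypothetical helpers, not part of the `esm` package.

```python
MASK, EOS, BOS = 4096, 4097, 4098  # structure-track special tokens as used in the cell above


def build_sequence_prompt(seq_len, mutations):
    """'_'-masked sequence string with residues fixed at 1-based positions."""
    chars = ["_"] * seq_len
    for res, aa in mutations.items():
        chars[res - 1] = aa  # string positions are 0-based
    return "".join(chars)


def build_structure_prompt(template_tokens, seq_len, keep_residues):
    """Masked structure prompt keeping template tokens at 1-based residues.

    template_tokens has length seq_len + 2: [BOS, r1, ..., r_seq_len, EOS],
    so residue number k sits at token index k.
    """
    assert len(template_tokens) == seq_len + 2
    prompt = [MASK] * (seq_len + 2)
    prompt[0], prompt[-1] = BOS, EOS
    for res in keep_residues:
        if 1 <= res <= seq_len:
            prompt[res] = template_tokens[res]  # no "- 1": BOS shifts indices by one
        else:
            print(f"Warning: residue {res} is outside 1..{seq_len}")
    return prompt


def flanking(sites, radius=1):
    """Sorted 1-based residue numbers within `radius` of each site."""
    return sorted({r for s in sites for r in range(s - radius, s + radius + 1)})


# Toy 6-residue "protein" with made-up structure tokens 11..16.
template = [BOS, 11, 12, 13, 14, 15, 16, EOS]
seq_prompt = build_sequence_prompt(6, {3: "A"})
prompt = build_structure_prompt(template, seq_len=6, keep_residues=flanking([3]))
num_steps = sum(t == MASK for t in prompt)  # one generation step per masked token

print(seq_prompt)  # __A___
print(prompt)      # [4098, 4096, 12, 13, 14, 4096, 4096, 4097]
print(num_steps)   # 3
```

For the notebook's sites, `flanking([45])` returns `[44, 45, 46]`, the same indices as `range(44, 47)` in the cell.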
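The `temperature` comment in the `GenerationConfig` call ("trades perplexity with diversity") refers to the common practice of dividing logits by a temperature T before the softmax: T < 1 sharpens the distribution toward the top token, and T > 1 flattens it. The sketch below illustrates that generic technique with made-up logits; it is an assumption about the meaning of the parameter and is not taken from ESM3's own sampling code.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature, rng):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for three candidate tokens
cold = softmax_with_temperature(logits, 0.5)  # sharper: top token dominates
warm = softmax_with_temperature(logits, 2.0)  # flatter: more diverse samples
print([round(p, 3) for p in cold])
print([round(p, 3) for p in warm])

rng = random.Random(0)  # seeded so the draws are reproducible
samples = [sample_token(logits, 1.0, rng) for _ in range(10)]
print(samples)
```

At `temperature=1.0`, as in the notebook, the model's distribution is sampled unchanged.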