In [2]:
%set_env TOKENIZERS_PARALLELISM=false
!pip install esm
import numpy as np
import torch

!pip install py3Dmol
import py3Dmol

from esm.utils.structure.protein_chain import ProteinChain
from esm.sdk import client
from esm.sdk.api import (
    ESMProtein,
    GenerationConfig,
)

env: TOKENIZERS_PARALLELISM=false
Collecting esm
  Using cached esm-3.0.5-py3-none-any.whl.metadata (9.4 kB)
Collecting torchtext (from esm)
  Using cached torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting einops (from esm)
  Using cached einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting biotite==0.41.2 (from esm)
  Using cached biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.1 kB)
Collecting msgpack-numpy (from esm)
  Using cached msgpack_numpy-0.4.8-py2.py3-none-any.whl.metadata (5.0 kB)
Collecting biopython (from esm)
  Using cached biopython-1.84-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Using cached esm-3.0.5-py3-none-any.whl (148 kB)
Using cached biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.0 MB)
Using cached biopython-1.84-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
Using cached einops-0.8.0-py3-none-any.whl (43 kB

In [3]:
from getpass import getpass

token = getpass("Token from Forge console: ")
model = client(
    model="esm3-small-2024-08",
    url="https://forge.evolutionaryscale.ai",
    token=token,
)

Token from Forge console:  ········


In [4]:
from Bio.PDB import PDBParser

# Load the structure from the PDB ID
parser = PDBParser(QUIET=True)
structure = parser.get_structure("structure", "./8db4.pdb")

# Iterate over the chains in the structure
for model in structure:
    for chain in model:
        print(f"Chain ID: {chain.id}, Number of residues: {len(chain)}")

with open('8db4.pdb', 'r') as pdb_file:
    for line in pdb_file:
        if line.startswith('COMPND'):
            print(line.strip())

Chain ID: A, Number of residues: 292
Chain ID: B, Number of residues: 250
Chain ID: C, Number of residues: 249
Chain ID: D, Number of residues: 236
Chain ID: E, Number of residues: 108
Chain ID: F, Number of residues: 104
Chain ID: G, Number of residues: 235
Chain ID: H, Number of residues: 218
Chain ID: I, Number of residues: 258
Chain ID: J, Number of residues: 270
COMPND    MOL_ID: 1;
COMPND   2 MOLECULE: 13T1 HEAVY CHAIN;
COMPND   3 CHAIN: A, G;
COMPND   4 ENGINEERED: YES;
COMPND   5 MOL_ID: 2;
COMPND   6 MOLECULE: 13T1 LIGHT CHAIN;
COMPND   7 CHAIN: B, H;
COMPND   8 ENGINEERED: YES;
COMPND   9 MOL_ID: 3;
COMPND  10 MOLECULE: 22S1 HEAVY CHAIN;
COMPND  11 CHAIN: C, I;
COMPND  12 ENGINEERED: YES;
COMPND  13 MOL_ID: 4;
COMPND  14 MOLECULE: 22S1 LIGHT CHAIN;
COMPND  15 CHAIN: D, J;
COMPND  16 ENGINEERED: YES;
COMPND  17 MOL_ID: 5;
COMPND  18 MOLECULE: ARA H 2 ALLERGEN;
COMPND  19 CHAIN: E, F;
COMPND  20 ENGINEERED: YES




In [24]:
pdb_id = "8DB4"  # PDB ID corresponding to Ara h 2 bound by two neutralizing antibodies
pdb_file = "./8db4.pdb"
chain_id = "E"  # Chain ID corresponding to Ara h 2 in the PDB structure
arah2_chain = ProteinChain.from_pdb(pdb_file, chain_id)
# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file

In [25]:
print(arah2_chain.sequence)

AARRCQSQLERANLRPCEQHLMQKIQRSQHQERCCNELNEFENNQRCMCEALQQIMENQSDRLQGRQQEQQFKRELRNLPQQCGLRAPQRCDLDV


In [7]:
print("atom37_positions shape: ", arah2_chain.atom37_positions.shape)
print(arah2_chain.atom37_positions[:3])

atom37_positions shape:  (95, 37, 3)
[[[ -5.856 -14.267   4.598]
  [ -5.618 -14.534   6.013]
  [ -4.121 -14.612   6.297]
  [ -6.309 -15.819   6.433]
  [ -3.596 -13.876   7.135]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan 

In [8]:
# First we can create a `py3Dmol` view object
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = arah2_chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()

In [9]:
!pip install freesasa

Collecting freesasa
  Using cached freesasa-2.2.1-cp311-cp311-linux_x86_64.whl
Installing collected packages: freesasa
Successfully installed freesasa-2.2.1


In [23]:
import freesasa
from Bio import PDB
from Bio.SeqUtils import seq1
import csv

# Load the structure using BioPython's PDBParser
pdb_parser = PDB.PDBParser(QUIET=True)  # QUIET mode to suppress warnings
structure = pdb_parser.get_structure("structure", "./8db4.pdb")

# Specify the chain of interest (e.g., Chain E)
chain_id = "E"

# Extract just the chain of interest from the structure
chain_structure = None
for model in structure:
    chain_structure = model[chain_id]
    break  # Exit after extracting the first model (if multiple models exist)

# Write a temporary PDB file containing only the selected chain
with open("temp_chain.pdb", "w") as temp_pdb:
    io = PDB.PDBIO()
    io.set_structure(chain_structure)
    io.save(temp_pdb)

# Initialize FreeSASA structure for the specific chain
freesasa_structure = freesasa.Structure("temp_chain.pdb")

# Run FreeSASA to calculate ASA for each atom
result = freesasa.calc(freesasa_structure)

# Max ASA values for RSA calculation (adjusted per residue type)
max_asa = {
    'A': 113, 'R': 241, 'N': 158, 'D': 151, 'C': 140, 'Q': 189, 'E': 183,
    'G': 85,  'H': 194, 'I': 182, 'L': 180, 'K': 211, 'M': 204, 'F': 218,
    'P': 143, 'S': 122, 'T': 146, 'W': 259, 'Y': 229, 'V': 160
}

# Function to check for N-glycosylation motif (N-X-S/T)
def is_nglycosylated(seq, pos):
    # Ensure the position is within the bounds of the sequence
    if pos + 2 < len(seq) and seq[pos] == 'N':
        if seq[pos + 2] in ['S', 'T'] and seq[pos + 1] != 'P':
            return 1
    return 0

# Prepare CSV output
with open('asa_rsa_output.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["Residue", "ASA", "RSA"])

    # FreeSASA iterates over atoms, not residues, so we have to match atoms
    atom_idx = 0  # Keep track of FreeSASA atom index

    # Get the sequence of residues for N-glycosylation check
    chain_sequence = [seq1(res.resname) for res in chain_structure.get_residues()]
    print("Chain E Sequence:", "".join(chain_sequence))  # Print sequence for debugging

    # Iterate over residues in the selected chain
    for i, residue in enumerate(chain_structure.get_residues()):
        try:
            # Filter out non-protein residues (e.g., metals, water)
            if residue.id[0] != ' ':  # Standard residues have an empty field for id[0]
                continue

            res_id = residue.id[1]  # residue position in the chain
            amino = residue.resname  # 3-letter amino acid code
            amino_one_letter = seq1(amino)  # Convert to 1-letter code

            # Initialize ASA for the entire residue
            residue_asa = 0.0

            # Iterate over atoms in the residue to sum up their ASA values
            for atom in residue:
                residue_asa += result.atomArea(atom_idx)
                atom_idx += 1  # Move to the next atom

            # Calculate RSA (Relative Solvent Accessibility)
            rsa = residue_asa / max_asa.get(amino_one_letter, 1)  # Use the max ASA for that residue

            # Check if the residue is N-glycosylated and ensure index is valid
            if i + 2 < len(chain_sequence):
                n_glycosylation = is_nglycosylated(chain_sequence, i)
            else:
                n_glycosylation = 0  # Set to 0 if out of range

            # Create the "Residue" field as position:amino:n-glycosylation
            residue_field = f"{res_id}:{amino_one_letter}:{n_glycosylation}"

            # Write to CSV (position:amino:n-glycosylation:ASA:RSA)
            writer.writerow([residue_field, residue_asa, rsa])
        except Exception as e:
            print(f"Error processing residue {residue}: {e}")


Chain E Sequence: AARRCQSQLERANLRPCEQHLMQKIQRSQHQERCCNELNEFENNQRCMCEALQQIMENQSDRLQGRQQEQQFKRELRNLPQQCGLRAPQRCDLDVXXXXXXXXXXXX


In [None]:
motif_inds = np.arange(123, 146)
# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues
motif_sequence = renal_dipep_chain[motif_inds].sequence
motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions
print("Motif sequence: ", motif_sequence)
print("Motif atom37_positions shape: ", motif_atom37_positions.shape)

In [None]:
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
motif_res_inds = (
    motif_inds + 1
).tolist()  # residue indices are 1-indexed in PDB files, so we add 1 to the indices
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

In [None]:
prompt_length = 200
# First, we can construct a sequence prompt of all masks
sequence_prompt = ["_"] * prompt_length
# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)
sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)
sequence_prompt = "".join(sequence_prompt)
print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

# Next, we can construct a structure prompt of all nan coordinates
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72
structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(
    motif_atom37_positions
)
print("Structure prompt shape: ", structure_prompt.shape)
print(
    "Indices with structure conditioning: ",
    torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),
)

# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3
protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

In [None]:
# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use
sequence_generation_config = GenerationConfig(
    track="sequence",  # We want ESM3 to generate tokens for the sequence track
    num_steps=sequence_prompt.count("_")
    // 2,  # We'll use num(mask tokens) // 2 steps to decode the sequence
    temperature=0.5,  # We'll use a temperature of 0.5 to control the randomness of the decoding process
)

# Now, we can use the `generate` method of the model to decode the sequence
sequence_generation = model.generate(protein_prompt, sequence_generation_config)
print("Sequence Prompt:\n\t", protein_prompt.sequence)
print("Generated sequence:\n\t", sequence_generation.sequence)

In [None]:
structure_prediction_config = GenerationConfig(
    track="structure",  # We want ESM3 to generate tokens for the structure track
    num_steps=len(sequence_generation) // 8,
    temperature=0.7,
)
structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction = model.generate(
    structure_prediction_prompt, structure_prediction_config
)

In [None]:
# Convert the generated structure to a back into a ProteinChain object
structure_prediction_chain = structure_prediction.to_protein_chain()
# Align the generated structure to the original structure using the motif residues
motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))
structure_prediction_chain.align(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
crmsd = structure_prediction_chain.rmsd(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
print(
    "cRMSD of the motif in the generated structure vs the original structure: ", crmsd
)

view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (motif_inds_in_generation + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()