[Generate](https://colab.research.google.com/github/evolutionaryscale/esm/blob/main/examples/generate.ipynb)

In [5]:
%set_env TOKENIZERS_PARALLELISM=false
import numpy as np
import torch
import py3Dmol
import huggingface_hub

from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3
from esm.sdk import client
from esm.sdk.api import (
    ESMProtein,
    GenerationConfig,
)

env: TOKENIZERS_PARALLELISM=false


In [6]:
# huggingface_hub.login()  # will prompt you to get an API key and accept the ESM3 license # only needed the first time the script is run
model =  ESM3.from_pretrained("esm3_sm_open_v1", device=torch.device("cuda"))



Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  state_dict = torch.load(


In [7]:
pdb_id = "1ITU" # PDB ID corresponding to Renal Dipeptidase
chain_id = "A" # Chain ID corresponding to Renal Dipeptidase in the PDB structure
renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id) # loads from the internet
# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file

In [8]:
print(renal_dipep_chain.sequence)

DFFRDEAERIMRDSPVIDGHNDLPWQLLDMFNNRLQDERANLTTLAGTHTNIPKLRAGFVGGQFWSVYTPCDTQNKDAVRRTLEQMDVVHRMCRMYPETFLYVTSSAGIRQAFREGKVASLIGVEGGHSIDSSLGVLRALYQLGMRYLTLTHSCNTPWADNWLVDTGDSEPQSQGLSPFGQRVVKELNRLGVLIDLAHVSVATMKATLQLSRAPVIFSHSSAYSVCASRRNVPDDVLRLVKQTDSLVMVNFYNNYISCTNKANLSQVADHLDHIKEVAGARAVGFGGDFDGVPRVPEGLEDVSKYPDLIAELLRRNWTEAEVKGALADNLLRVFEAVEQASNLTQAPEEEPIPLDQLGGSCRTHYGYSS


In [10]:
print("atom37_positions shape: ", renal_dipep_chain.atom37_positions.shape)
print(renal_dipep_chain.atom37_positions[0])

atom37_positions shape:  (369, 37, 3)
[[-40.525  -9.87   -2.643]
 [-39.79   -9.325  -3.825]
 [-38.765 -10.354  -4.294]
 [-39.096  -8.012  -3.45 ]
 [-37.878 -10.748  -3.53 ]
 [-38.41   -7.359  -4.629]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [-39.105  -7.036  -5.617]
 [-37.177  -7.161  -4.562]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan     nan]
 [    nan     nan

In [11]:
# First we can create a `py3Dmol` view object
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = renal_dipep_chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()

In [13]:
motif_inds = np.arange(123, 146)
# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues
motif_sequence = renal_dipep_chain[motif_inds].sequence
motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions
print("Motif sequence: ", motif_sequence)
print("Motif atom37_positions shape: ", motif_atom37_positions.shape)

Motif sequence:  VEGGHSIDSSLGVLRALYQLGMR
Motif atom37_positions shape:  (23, 37, 3)


In [14]:
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
motif_res_inds = (motif_inds + 1).tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

In [16]:
prompt_length = 200
# First, we can construct a sequence prompt of all masks
sequence_prompt = ["_"] * prompt_length
# Then we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)
sequence_prompt[72:72+len(motif_sequence)] = list(motif_sequence)
sequence_prompt = "".join(sequence_prompt)
print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

Sequence prompt:  ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Length of sequence prompt:  200


In [19]:
# Next, we can construct a structure prompt of all nan coordinates
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
# Then we can insert the motif atomic coordinates into the prompt, starting at index 72
structure_prompt[72:72+len(motif_atom37_positions)] = torch.tensor(motif_atom37_positions)
print("Structure prompt shape: ", structure_prompt.shape)
print(
    "Indices with structure conditioning: ",
    torch.where(~torch.isnan(structure_prompt).any(dim=-1))[0].tolist()
)

Structure prompt shape:  torch.Size([200, 37, 3])
Indices with structure conditioning:  [72, 72, 72, 72, 72, 72, 72, 73, 73, 73, 73, 73, 73, 73, 73, 73, 74, 74, 74, 74, 75, 75, 75, 75, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 77, 77, 77, 77, 77, 77, 78, 78, 78, 78, 78, 78, 78, 78, 79, 79, 79, 79, 79, 79, 79, 79, 80, 80, 80, 80, 80, 80, 81, 81, 81, 81, 81, 81, 82, 82, 82, 82, 82, 82, 82, 82, 83, 83, 83, 83, 84, 84, 84, 84, 84, 84, 84, 85, 85, 85, 85, 85, 85, 85, 85, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 87, 87, 87, 87, 87, 88, 88, 88, 88, 88, 88, 88, 88, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 90, 90, 90, 90, 90, 90, 90, 90, 90, 91, 91, 91, 91, 91, 91, 91, 91, 92, 92, 92, 92, 93, 93, 93, 93, 93, 93, 93, 93, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94]


In [24]:
# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3
protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

In [25]:
# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use
sequence_generation_config = GenerationConfig(
    track="sequence", # we want ESM3 to generate tokens for the sequence track
    num_steps=sequence_prompt.count("_")//2,
    temperature=0.5,
)

# Now, we can use the generate method of the model to decode the sequence
sequence_generation = model.generate(protein_prompt, sequence_generation_config)
print("Sequence Prompt:\n", protein_prompt.sequence)
print("Generated sequence:\n", sequence_generation.sequence)

  state_dict = torch.load(
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
100%|██████████| 88/88 [00:05<00:00, 15.08it/s]
  state_dict = torch.load(
  state_dict = torch.load(


Sequence Prompt:
	 ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Generated sequence:
	 LDKLKAGGVGAQFWSVYVPCSYKDKDPVRATLEHIDLVYRLAERYPDQIEIAVTAAEIKRIVAAGKIASLIGVEGGHSIDSSLGVLRALYQLGMRYMTLTWNCNNDWADSATDPKKKGVTAFGKEVVKEMNRLGMLVDISHVSEDTFWDVMEVSTAPVIASHSSARALCDHPRNMTDEQLKALAKKGGVVMINLYPGFLG


In [29]:
# use the generated sequence to determine the structure

structure_prediction_config = GenerationConfig(
    track="structure",
    num_steps=len(sequence_generation)//8,
    temperature=0.7,
)

structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)



100%|██████████| 25/25 [00:01<00:00, 22.32it/s]


In [30]:
# Convert generated structure back into a ProteinChain object
structure_prediction_chain = structure_prediction.to_protein_chain()

# Align the generated structure to the original structure using the motif residues
motif_inds_in_generation = np.arange(72, 72+len(motif_sequence))
structure_prediction_chain.align(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)

crmsd = structure_prediction_chain.rmsd(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)

print("cRMSD of the motif in the generated structure vs the original structure: ", crmsd)

view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (motif_inds_in_generation + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()


cRMSD of the motif in the generated structure vs the original structure:  0.20393863685881072


In [33]:
# secondary structure editing example

helix_shortening_chain = ProteinChain.from_rcsb("7XBQ", "A")
view = py3Dmol.view(width=500, height=500)
view.addModel(helix_shortening_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
helix_region = np.arange(38,111) # zero-indexed
view.addStyle(
    {"resi": (helix_region+1).tolist()}, {"cartoon": {"color": "lightblue"}}
)
view.zoomTo()
view.show()
helix_shortening_ss8="CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC"
print(
    "Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \n\t",
    helix_shortening_ss8,
)


Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) 
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC


In [37]:
print(len(helix_shortening_ss8))
print(len(helix_shortening_chain.sequence))

233
233


In [40]:
helix_region

array([ 38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,
        51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,
        64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,
        77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
       103, 104, 105, 106, 107, 108, 109, 110])

In [51]:
shortened_region_length = 45

#sequence prompt to mask the shortened region
sequence_prompt = (
    helix_shortening_chain.sequence[:helix_region[0]]
    + "_" * shortened_region_length
    + helix_shortening_chain.sequence[helix_region[-1] + 1:]
)
print("Sequence prompt:\n\t", sequence_prompt)

ss8_prompt = (
    helix_shortening_ss8[:helix_region[0]]
    + (
        ((shortened_region_length - 3)//2) * "H"
        + "C"*3
        + ((shortened_region_length - 3)//2) * "H"
    )
    + helix_shortening_ss8[helix_region[-1]+1:]
)
print("SS8 prompt:\n\t", ss8_prompt)
print(
    "Proposed SS8 for shortened helix-coil-helix region:\n\t",
    " " * helix_region[0] + ss8_prompt[helix_region[0] : helix_region[0] + 45],
)

print("Original sequence:\n\t", helix_shortening_chain.sequence)
print("Original SS8:\n\t", helix_shortening_ss8)
print(
    "Original SS8 for helix-coil-helix region:\n\t",
    " " * helix_region[0]
    + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],
)

Sequence prompt:
	 MAREENVYMAKLAEQAERYEEMVQFMEKVSTSLGSEEL_____________________________________________SASNGDSKVFYLKMKGDYHRYLAEFKTGAERKEAAESTLSAYKAAQDIANTELAPTHPIRLGLALNFSVFYYEILNSPDRACNLAKQAFDEAIAELDTLGEESYKDSTLIMQLLRDNLTLWT
SS8 prompt:
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCHHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHHGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC
Proposed SS8 for shortened helix-coil-helix region:
	                                       HHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHH
Original sequence:
	 MAREENVYMAKLAEQAERYEEMVQFMEKVSTSLGSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEESRGNEEHVKCIKEYRSKIESELSNICDGILKLLDSNLIPSASNGDSKVFYLKMKGDYHRYLAEFKTGAERKEAAESTLSAYKAAQDIANTELAPTHPIRLGLALNFSVFYYEILNSPDRACNLAKQAFDEAIAELDTLGEESYKDSTLIMQLLRDNLTLWT
Original SS8:
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHH

In [52]:
protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)

In [55]:
print("Generating protein sequence...")

config = GenerationConfig(track="sequence", num_steps=protein_prompt.sequence.count("_")//2, temperature=0.5)
sequence_generation = model.generate(protein_prompt, config)

print("Folding protein...")
config = GenerationConfig(track="structure", num_steps=len(protein_prompt)//4, temperature = 0)
structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), config)

Generating protein sequence...


100%|██████████| 22/22 [00:00<00:00, 22.52it/s]


Folding protein...


100%|██████████| 51/51 [00:02<00:00, 22.48it/s]


In [56]:
predicted_chain = structure_prediction.to_protein_chain()
predicted_chain = predicted_chain.align(
    helix_shortening_chain,
    mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),
    target_inds=np.arange(
        len(helix_shortening_chain) - 120, len(helix_shortening_chain)
    ),
)
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(helix_shortening_chain.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(predicted_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (helix_region + 1).tolist()},
    {"cartoon": {"color": "lightblue"}},
    viewer=(0, 0),
)
view.addStyle(
    {"resi": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},
    {"cartoon": {"color": "pink"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()

In [58]:
# SASA Example, expose the buried helix

lipase_chain = ProteinChain.from_rcsb("1LBS", "A")
span_start = 105
span_end = 116
view = py3Dmol.view(width=500, height=500)
view.addModel(lipase_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (np.arange(span_start, span_end) + 1).tolist()},
    {"cartoon": {"color": "red"}},
)
view.zoomTo()
view.show()
lipase_ss8 = "CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC"

In [59]:
structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)
structure_prompt[span_start:span_end] = torch.tensor(
    lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32
)

sasa_prompt = [None] * len(lipase_chain)
sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)

print("SASA prompt (just for buried region): ", sasa_prompt[span_start:span_end])

protein_prompt = ESMProtein(
    sequence="_" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt
)

SASA prompt (just for buried region):  [40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0]


In [60]:
generated_proteins = []
N_SAMPLES = 16
for i in range(N_SAMPLES):
    print("Generating protein sequence...")
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(
            track="sequence", num_steps=len(protein_prompt) // 8, temperature=0.7
        ),
    )
    print("Folding protein...")
    structure_prediction = model.generate(
        ESMProtein(sequence=sequence_generation.sequence),
        GenerationConfig(track="structure", num_steps=len(protein_prompt) // 32),
    )
    generated_proteins.append(structure_prediction)

# Sort generations by ptm
generated_proteins = sorted(
    generated_proteins, key=lambda x: x.ptm.item(), reverse=True
)

Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.93it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.40it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 22.07it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 22.69it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 20.07it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.19it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 20.15it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 19.54it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 20.61it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 20.78it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 20.31it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 22.03it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 20.19it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 20.90it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 21.49it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.47it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 21.08it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 20.43it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 21.46it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 20.97it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 20.16it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.27it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 21.90it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.43it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 21.67it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 22.01it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 22.20it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.87it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 22.42it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.85it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:01<00:00, 22.31it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 21.77it/s]


In [None]:
N_SAMPLES_TO_SHOW = 4
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))
view.addModel(lipase_chain.to_pdb_string(), "pdb", viewer=(0, 0))
for i in range(N_SAMPLES_TO_SHOW):
    print(
        "PTM of generated protein {}: {:.2f}".format(
            i + 1, generated_proteins[i].ptm.item()
        )
    )
    view.addModel(
        generated_proteins[i].to_protein_chain().to_pdb_string(),
        "pdb",
        viewer=(0, i + 1),
    )
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (np.arange(span_start, span_end) + 1).tolist()},
    {"cartoon": {"color": "red"}},
)
view.zoomTo()
view.show()

PTM of generated protein 1: 0.90
PTM of generated protein 2: 0.90
PTM of generated protein 3: 0.85
PTM of generated protein 4: 0.74
