In [3]:
import biotite.sequence as seq
import biotite.sequence.align as align
import biotite.sequence.graphics as graphics
from getpass import getpass
import matplotlib.pyplot as pl
import py3Dmol
import torch

from esm.sdk import client
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3


In [4]:
# Load the open ESM3 weights. Use the GPU when one is present and fall back to
# the CPU otherwise, so the notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ESM3.from_pretrained("esm3_sm_open_v1", device=device)



Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  state_dict = torch.load(


In [5]:
# Fetch chain A of PDB entry 1qy3 from RCSB and wrap it as an ESMProtein.
template_gfp = ESMProtein.from_protein_chain(
    ProteinChain.from_rcsb("1qy3", chain_id="A")
)

# Tokenize the template with the model; the resulting object exposes one
# token tensor per track (`.sequence`, `.structure`).
template_gfp_tokens = model.encode(template_gfp)

# Dump both token tracks as comma-separated integers.
for heading, track in (
    ("Sequence tokens:", template_gfp_tokens.sequence),
    ("Structure tokens:", template_gfp_tokens.structure),
):
    print(heading)
    print("    ", ", ".join(map(str, track.tolist())))

  state_dict = torch.load(
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore


Sequence tokens:
     0, 15, 6, 9, 9, 4, 18, 11, 6, 7, 7, 14, 12, 4, 7, 9, 4, 13, 6, 13, 7, 17, 6, 21, 15, 18, 8, 7, 8, 6, 9, 6, 9, 6, 13, 5, 11, 19, 6, 15, 4, 11, 4, 15, 18, 12, 23, 11, 11, 6, 15, 4, 14, 7, 14, 22, 14, 11, 4, 7, 11, 11, 4, 11, 19, 6, 7, 16, 23, 18, 8, 10, 19, 14, 13, 21, 20, 15, 16, 21, 13, 18, 18, 15, 8, 5, 20, 14, 9, 6, 19, 7, 16, 9, 5, 11, 12, 8, 18, 15, 13, 13, 6, 17, 19, 15, 11, 10, 5, 9, 7, 15, 18, 9, 6, 13, 11, 4, 7, 17, 10, 12, 9, 4, 15, 6, 12, 13, 18, 15, 9, 13, 6, 17, 12, 4, 6, 21, 15, 4, 9, 19, 17, 19, 17, 8, 21, 17, 7, 19, 12, 11, 5, 13, 15, 16, 15, 17, 6, 12, 15, 5, 17, 18, 15, 12, 10, 21, 17, 12, 9, 13, 6, 8, 7, 16, 4, 5, 13, 21, 19, 16, 16, 17, 11, 14, 12, 6, 13, 6, 14, 7, 4, 4, 14, 13, 17, 21, 19, 4, 8, 11, 16, 8, 5, 4, 8, 15, 13, 14, 17, 9, 15, 10, 13, 21, 20, 7, 4, 4, 9, 18, 7, 11, 5, 5, 6, 12, 2
Structure tokens:
     4098, 2221, 3124, 1129, 3395, 3019, 1645, 2037, 2490, 60, 1591, 3819, 457, 2708, 383, 2219, 653, 2545, 2984, 3370, 66, 608, 2410, 103

In [6]:
# Report the number of tokens on the sequence track and then on the structure
# track, so the two lengths can be compared.
for token_track in (template_gfp_tokens.sequence, template_gfp_tokens.structure):
    print(len(token_track.tolist()))

229
229


In [7]:
# Build the sequence prompt: every residue is masked ("_") except the few
# residues we pin. Indices here are 0-based positions in the residue string.
prompt_residues = ["_"] * len(template_gfp.sequence)
prompt_residues[59] = "T"
prompt_residues[62] = "T"
prompt_residues[63] = "Y"
prompt_residues[64] = "G"
prompt_residues[93] = "R"
prompt_residues[219] = "E"
prompt_sequence = "".join(prompt_residues)

print(template_gfp.sequence)
print(prompt_sequence)

prompt = model.encode(
    ESMProtein(sequence=prompt_sequence)
)

# We construct an empty structure track like |<bos> <mask> ... <mask> <eos>|...
prompt.structure = torch.full_like(prompt.sequence, 4096)
prompt.structure[0] = 4098
prompt.structure[-1] = 4097

# ... and then we fill in structure tokens at key residues near the alpha helix
# kink and at the stabilizing R and E positions on the beta barrel.
#
# Token tracks start with a BOS token, so residue i lives at token index i + 1.
# The offset is applied on BOTH sides of each assignment so that every template
# structure token lands on the residue it was computed for, and so that the R/E
# structure constraints sit on the same residues (93, 219) that the sequence
# prompt pins. Previously the destination indices omitted the offset, shifting
# the structure prompt one residue towards the N-terminus.
bos_offset = 1
helix_start, helix_stop = 55, 70  # residue range [55, 70) around the chromophore
prompt.structure[helix_start + bos_offset : helix_stop + bos_offset] = (
    template_gfp_tokens.structure[helix_start + bos_offset : helix_stop + bos_offset]
)
for residue_index in (93, 219):
    token_index = residue_index + bos_offset
    prompt.structure[token_index] = template_gfp_tokens.structure[token_index]

# Show which residues carry a structure constraint. BOS/EOS are sliced off so
# this line aligns column-for-column with the two residue strings printed above.
print("".join(["✔" if st < 4096 else "_" for st in prompt.structure[1:-1]]))

KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQEATISFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYITADKQKNGIKANFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGI
___________________________________________________________T__TYG____________________________R_____________________________________________________________________________________________________________________________E_______
_______________________________________________________✔✔✔✔✔✔✔✔✔✔✔✔✔✔✔_______________________✔_____________________________________________________________________________________________________________________________✔_________


In [8]:
%%time

num_tokens_to_decode = (prompt.structure == 4096).sum().item()

structure_generation = model.generate(
    prompt,
    GenerationConfig(
        # Generate a structure.
        track="structure",
        # Sample one token per forward pass of the model.
        num_steps=num_tokens_to_decode,
        # Sampling temperature trades perplexity with diversity.
        temperature=1.0,
    )
)

print("These are the structure tokens corresponding to our new design:")
print("    ", ", ".join([
    str(token) for token in structure_generation.structure.tolist()
]))

# Decodes structure tokens to backbone coordinates.
structure_generation_protein = model.decode(structure_generation)

print("")

100%|██████████| 210/210 [00:09<00:00, 21.50it/s]
  state_dict = torch.load(


These are the structure tokens corresponding to our new design:
     4098, 3715, 1046, 3452, 1599, 1360, 385, 2835, 1716, 3320, 4005, 3153, 808, 2780, 808, 705, 2746, 845, 681, 3161, 2407, 2416, 3046, 231, 1319, 598, 3638, 1903, 1, 2171, 1218, 1047, 2693, 1595, 903, 1349, 2490, 1537, 1502, 1803, 1825, 1634, 2513, 3598, 128, 3903, 3926, 232, 3384, 3295, 3803, 2453, 1067, 588, 1053, 1774, 732, 1797, 748, 3403, 2370, 2582, 3704, 2737, 3007, 1660, 499, 484, 2202, 2786, 1400, 2618, 2809, 3059, 3328, 2572, 1853, 3453, 3629, 1940, 748, 4043, 1552, 2138, 86, 3135, 686, 3512, 659, 111, 4018, 2965, 1383, 1066, 3182, 2547, 3585, 1726, 3148, 1416, 3652, 1364, 981, 2995, 1867, 3421, 2078, 3729, 2693, 3607, 2159, 2942, 1998, 880, 4008, 1275, 753, 1165, 306, 2247, 2963, 37, 4079, 26, 145, 1186, 3808, 2224, 767, 197, 3554, 3560, 521, 4032, 69, 1646, 1046, 2930, 3089, 3697, 1912, 2507, 4055, 1853, 2857, 318, 216, 3586, 699, 824, 987, 598, 3101, 2056, 2048, 1476, 1195, 1035, 3205, 2103, 3954, 2874, 532,

  state_dict = torch.load(



CPU times: user 10.1 s, sys: 192 ms, total: 10.3 s
Wall time: 17.5 s


In [None]:
# Serialize the generated protein to PDB text for the viewer.
# NOTE(review): infer_oxygen() presumably adds backbone oxygens that the decoded
# structure lacks — confirm against the ProteinChain implementation.
generated_pdb = (
    structure_generation_protein
    .to_protein_chain()
    .infer_oxygen()
    .to_pdb_string()
)

# Render the model as a light-green cartoon in an interactive 3D widget.
view = py3Dmol.view(width=1000, height=500)
view.addModel(generated_pdb, "pdb")
view.setStyle({"cartoon": {"color": "lightgreen"}})
view.zoomTo()
view.show()

In [None]:
# 0-based residue indices of the residues pinned in the sequence prompt.
constrained_site_positions = [59, 62, 63, 64, 93, 219]

template_chain = template_gfp.to_protein_chain()
generation_chain = structure_generation_protein.to_protein_chain()

# RMSD restricted to the constrained residues, then RMSD over the whole chain.
template_site = template_chain[constrained_site_positions]
generated_site = generation_chain[constrained_site_positions]
constrained_site_rmsd = template_site.rmsd(generated_site)
backbone_rmsd = template_chain.rmsd(generation_chain)

# The constrained site passes when its RMSD is BELOW 1.5, whereas the backbone
# passes when its RMSD is ABOVE 1.5.
# NOTE(review): the inverted backbone criterion looks deliberate (the design is
# presumably meant to differ from the template overall) — confirm.
c_pass, b_pass = (
    "✅" if passed else "❌"
    for passed in (constrained_site_rmsd < 1.5, backbone_rmsd > 1.5)
)

print(f"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}")
print(f"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}")

In [None]:
%%time

num_tokens_to_decode = (prompt.sequence == 32).sum().item()

sequence_generation = model.generate(
    # Generate a sequence.
    structure_generation,
    GenerationConfig(
        track="sequence",
        num_steps=num_tokens_to_decode,
        temperature=1.0,
    )
)

# Refold
sequence_generation.structure = None
length_of_sequence = sequence_generation.sequence.numel() - 2
sequence_generation = model.generate(
    sequence_generation,
    GenerationConfig(
        track="structure",
        num_steps=length_of_sequence,
        temperature=0.0,
    )
)

# Decode to AA string and coordinates.
sequence_generation_protein = model.decode(sequence_generation)

In [None]:
# Display the designed amino-acid sequence (the cell's last expression is
# rendered as its output).
sequence_generation_protein.sequence

In [None]:
# Wrap the template and the designed sequence as biotite protein sequences.
seq1 = seq.ProteinSequence(template_gfp.sequence)
seq2 = seq.ProteinSequence(sequence_generation_protein.sequence)

# Optimal pairwise alignment with the standard protein substitution matrix and
# gap penalties of -10 and -1; only the first returned alignment is kept.
substitution_matrix = align.SubstitutionMatrix.std_protein_matrix()
alignments = align.align_optimal(
    seq1, seq2, substitution_matrix, gap_penalty=(-10, -1)
)
alignment = alignments[0]

identity = align.get_sequence_identity(alignment)
print(f"Sequence identity: {100*identity:.2f}%")

print("\nSequence alignment:")
# Single-axes figure holding the similarity-coloured alignment plot.
fig, ax = pl.subplots(figsize=(8.0, 4.0))
graphics.plot_alignment_similarity_based(
    ax,
    alignment,
    symbols_per_line=45,
    spacing=2,
    show_numbers=True,
)
fig.tight_layout()
pl.show()