In [11]:
# Inference settings for this notebook. The methodology closely follows
# EvolutionaryScale's GFP design example:
#   https://github.com/evolutionaryscale/esm/blob/main/examples/gfp_design.ipynb
# That example is similar to what the authors did to generate a new candidate
# GFP protein, so it is a sensible methodology to start with.
# Tweak the values below to change the run.

# --- Template protein --------------------------------------------------------
protein_pdb_id: str = "1YFP"  # PDB ID of the protein to run inference on
protein_chain_id: str = "A"  # chain ID within that PDB entry

# --- Sampling ----------------------------------------------------------------
# Despite the name, this is a fraction (not a percentage) of residues to mask,
# in [0.0, 1.0].
percent_masked: float = 0.7
num_seqs: int = 10  # number of sequences to generate

# --- Model -------------------------------------------------------------------
# Change this for a bigger/smaller model. Other options include
# esm3-large-2024-03, esm3-medium-2024-08, esm3-small-2024-03, etc.
model_name: str = "esm3-medium-2024-03"
# Sampling temperature for sequence generation; higher values give more
# stochastic samples. TODO(review): confirm the accepted range in the ESM3
# GenerationConfig documentation.
model_temperature: float = 1.0

In [2]:
# Optional environment setup: uncomment to install the ESM SDK and py3Dmol.
# Prefer `%pip` over `!pip` so packages land in the kernel's environment, and
# pin versions if reproducibility matters.
# from IPython.display import clear_output
# !pip install git+https://github.com/evolutionaryscale/esm.git
# !pip install py3Dmol
# clear_output()  # Suppress pip install log lines after installation is complete.

In [6]:
import biotite.sequence as seq
import biotite.sequence.align as align
import biotite.sequence.graphics as graphics
from getpass import getpass
import matplotlib.pyplot as pl
# import py3Dmol
import torch

from esm.sdk import client
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

In [7]:
# Forge API token (get it from the EvolutionaryScale Forge console).
# Read from the ESM_FORGE_TOKEN environment variable when set, so the notebook
# can run non-interactively; otherwise prompt for it without echoing.
# Never paste the token into the notebook: comments and outputs get committed.
# NOTE(review): this cell previously contained what looks like a real token in
# a comment; treat it as compromised and revoke/rotate it in Forge.
import os

token = os.environ.get("ESM_FORGE_TOKEN") or getpass("Token from Forge console: ")

In [8]:
# Connect to the hosted ESM3 model on EvolutionaryScale Forge, using the model
# selected in the configuration cell and the token obtained above.
model = client(
    token=token,
    model=model_name,
    url="https://forge.evolutionaryscale.ai",
)

In [12]:
# Download the configured chain from RCSB and wrap it as an ESMProtein; its
# sequence is masked and its structure reused as the prompt in the loop below.
rcsb_chain = ProteinChain.from_rcsb(protein_pdb_id, chain_id=protein_chain_id)
template_gfp = ESMProtein.from_protein_chain(rcsb_chain)

print("Original Sequence:", template_gfp.sequence, sep="\n")

Original Sequence:
KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFLQCFARYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI


In [13]:
%%time

import csv
import random 
from tqdm import tqdm

def mask_sequence(
    s: str,
    fraction: float,
    rng: "random.Random | None" = None,
    mask_char: str = "_",
) -> str:
    """Return a copy of ``s`` with a random subset of positions masked.

    Exactly ``int(len(s) * fraction)`` distinct positions are replaced by
    ``mask_char`` (the product is truncated toward zero). All other characters
    are kept in place, so the result has the same length as ``s``.

    Args:
        s: The input sequence.
        fraction: Fraction of positions to mask, in [0.0, 1.0].
        rng: Optional ``random.Random`` used to choose the positions. Pass a
            seeded instance (e.g. ``random.Random(0)``) for reproducible masks.
            When omitted, the global ``random`` module state is used, which is
            the original behavior.
        mask_char: Single character written at masked positions. Defaults to
            "_", the placeholder this notebook passes to
            ``ESMProtein(sequence=...)``.

    Returns:
        The masked string.

    Raises:
        ValueError: If ``fraction`` is outside [0.0, 1.0] (NaN also fails the
            check), or if ``mask_char`` is not exactly one character.
    """
    if not (0.0 <= fraction <= 1.0):
        raise ValueError("Fraction must be between 0.0 and 1.0")
    if len(mask_char) != 1:
        raise ValueError("mask_char must be a single character")

    # Fall back to the module-level functions so the default call consumes the
    # global random state exactly as before.
    sampler = random if rng is None else rng
    num_to_mask = int(len(s) * fraction)

    masked = list(s)
    for index in sampler.sample(range(len(s)), num_to_mask):
        masked[index] = mask_char
    return "".join(masked)


# Id that masked ("_") positions take in the encoded sequence track; used to
# count how many positions still need to be generated.
# NOTE(review): 32 is the value used by the upstream gfp_design example;
# confirm it against the installed esm version's sequence mask token.
SEQUENCE_MASK_TOKEN_ID = 32

# Set up the CSV file for logging: truncate it and write the header once. Rows
# are appended per iteration so results survive an interrupted run.
csv_file = "generation_results.csv"
with open(csv_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["Sequence_ID", "Generated_Sequence", "Sequence_Identity", "Backbone_RMSD"])

# Loop-invariant inputs, computed once instead of on every iteration.
template_gfp_tokens = model.encode(template_gfp)
template_sequence = seq.ProteinSequence(template_gfp.sequence)
template_chain = template_gfp.to_protein_chain()
substitution_matrix = align.SubstitutionMatrix.std_protein_matrix()

# Loop through the generation and logging process.
# NOTE: mask_sequence draws from the global `random` state, which this cell
# does not seed, so the masked positions differ between runs.
for i in tqdm(range(num_seqs), desc="Generating Sequences"):
    # 1. Prompt = partially masked template sequence + the template's structure tokens.
    prompt_sequence = mask_sequence(template_gfp.sequence, percent_masked)
    prompt = model.encode(ESMProtein(sequence=prompt_sequence))
    prompt.structure = template_gfp_tokens.structure

    # 2. Generate the masked sequence positions: one decoding step per masked token.
    num_tokens_to_decode = (prompt.sequence == SEQUENCE_MASK_TOKEN_ID).sum().item()
    sequence_generation = model.generate(
        prompt,
        GenerationConfig(
            track="sequence",
            num_steps=num_tokens_to_decode,
            temperature=model_temperature,
        ),
    )

    # 3. Discard the template structure and predict a new structure for the
    #    generated sequence at temperature 0.0.
    sequence_generation.structure = None
    # "- 2" presumably excludes the BOS/EOS special tokens (as in the upstream example).
    length_of_sequence = sequence_generation.sequence.numel() - 2
    sequence_generation = model.generate(
        sequence_generation,
        GenerationConfig(
            track="structure",
            num_steps=length_of_sequence,
            temperature=0.0,
        ),
    )

    # Decode to AA string and coordinates.
    sequence_generation_protein = model.decode(sequence_generation)
    generated_sequence = sequence_generation_protein.sequence
    print(f"Generated Sequence {i+1}: {generated_sequence}")

    # Sequence identity from an optimal global alignment against the template.
    alignment = align.align_optimal(
        template_sequence,
        seq.ProteinSequence(generated_sequence),
        substitution_matrix,
        gap_penalty=(-10, -1),
    )[0]
    sequence_identity = 100 * align.get_sequence_identity(alignment)
    print(f"Sequence {i+1} Identity: {sequence_identity:.2f}%")

    # Backbone RMSD between the template and the predicted structure.
    generation_chain = sequence_generation_protein.to_protein_chain()
    backbone_rmsd = template_chain.rmsd(generation_chain)
    print(f"Backbone RMSD {i+1}: {backbone_rmsd:.2f}")

    # Log the results in the CSV file.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([f"Sequence_{i+1}", generated_sequence, f"{sequence_identity:.2f}%", f"{backbone_rmsd:.2f}"])

print(f"Generation complete and results logged in {csv_file}")


Generating Sequences:  10%|█         | 1/10 [01:24<12:41, 84.65s/it]

Generated Sequence 1: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFLFCFARYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYITADKQKNGIKVNFKIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 1 Identity: 98.21%
Backbone RMSD 1: 0.46


Generating Sequences:  20%|██        | 2/10 [02:49<11:17, 84.68s/it]

Generated Sequence 2: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFXQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 2 Identity: 98.66%
Backbone RMSD 2: 0.46


Generating Sequences:  30%|███       | 3/10 [04:14<09:55, 85.00s/it]

Generated Sequence 3: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFLVXFSRYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 3 Identity: 98.21%
Backbone RMSD 3: 0.51


Generating Sequences:  40%|████      | 4/10 [05:39<08:29, 84.97s/it]

Generated Sequence 4: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFCPCFSRYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 4 Identity: 98.21%
Backbone RMSD 4: 0.47


Generating Sequences:  50%|█████     | 5/10 [07:04<07:04, 84.89s/it]

Generated Sequence 5: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFTSCFARYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYITADKQKNGIKVNFKIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 5 Identity: 97.77%
Backbone RMSD 5: 0.46


Generating Sequences:  60%|██████    | 6/10 [08:28<05:38, 84.75s/it]

Generated Sequence 6: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFLQQFARYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 6 Identity: 99.11%
Backbone RMSD 6: 0.46


Generating Sequences:  70%|███████   | 7/10 [09:54<04:14, 84.87s/it]

Generated Sequence 7: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFYLCFSRYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 7 Identity: 98.21%
Backbone RMSD 7: 0.46


Generating Sequences:  80%|████████  | 8/10 [11:19<02:49, 84.92s/it]

Generated Sequence 8: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFXLCFARYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 8 Identity: 99.11%
Backbone RMSD 8: 0.46


Generating Sequences:  90%|█████████ | 9/10 [12:44<01:24, 84.96s/it]

Generated Sequence 9: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFLQCFARYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVHLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 9 Identity: 99.55%
Backbone RMSD 9: 0.46


Generating Sequences: 100%|██████████| 10/10 [14:08<00:00, 84.89s/it]

Generated Sequence 10: KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFXQCFARYPDYMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGI
Sequence 10 Identity: 99.11%
Backbone RMSD 10: 0.46
Generation complete and results logged in generation_results.csv
CPU times: user 1.32 s, sys: 106 ms, total: 1.42 s
Wall time: 14min 11s



