<a href="https://colab.research.google.com/github/sirius777coder/GPDL/blob/main/GPDL_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GPDL Jupyter Notebook

This is a Google Colab notebook for running GPDL (Generative Protein Design by protein-Language model). It sets up the required dependencies (ESMFold, OpenFold, PyTorch Geometric, etc.), runs motif-scaffolding design and structure prediction, and lets you experiment with generative protein models. The notebook has three parts: seeding (generating an initial design around the motif), sequence design for that backbone with ESM-IF, and sequence optimization scored by ESMFold. Within these parts, cells cover environment setup and dependency installation, model loading and initialization, and the design and prediction workflows. A Colab runtime with a GPU is required, because the models are loaded onto CUDA. An internet connection is needed to download models and dependencies. Run the cells sequentially from top to bottom; the setup cells install all necessary packages and download the model parameters automatically.

# Part one: seeding

In [None]:
#@title Install dependencies

import os
# Configures PyTorch's CUDA caching allocator (expandable segments reduce
# fragmentation); set before torch touches CUDA so it takes effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# # Install biotite
!pip install -q biotite

# # Install esmfold
# Python dependencies for OpenFold/ESMFold; the ESMFold package itself is
# installed later, in part three.
print("installing libs...")
os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

print("installing openfold...")
# # install openfold
# NOTE(review): os.system exit codes are not checked; a failed install is silent.
os.system(f"pip install -q git+https://github.com/sokrypton/openfold.git")
print("installing esmfold...")
# # install esmfold
# NOTE(review): the install below is disabled (it runs in part three), but the
# message above is still printed.
# os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")


# Install esmif torch geometry
# torch is not installed by this cell, so the runtime's existing build is used;
# its version and CUDA tag select the matching PyG wheels below.
import torch

def format_pytorch_version(version):
    """Drop the local build tag from a torch version string.

    Example: ``"2.1.0+cu121"`` -> ``"2.1.0"``; a string without ``+`` is returned as is.
    """
    public_version, _sep, _local_tag = version.partition('+')
    return public_version

TORCH_version = torch.__version__  # may carry a local tag such as "+cu121"
TORCH = format_pytorch_version(version=TORCH_version)  # tag-free form used in the PyG wheel index URL

def format_cuda_version(version):
    """Return the PyG wheel tag for a CUDA version string.

    Args:
        version: CUDA version such as ``"12.1"``, or ``None`` (the value of
            ``torch.version.cuda`` on CPU-only torch builds).

    Returns:
        ``"cu121"`` for ``"12.1"``; ``"cpu"`` for ``None``, which is the tag PyG
        uses for its CPU wheel index. Without this case ``None.replace`` raises
        AttributeError.
    """
    if version is None:
        return 'cpu'
    return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# PyG extension wheels are published per torch/CUDA combination; IPython
# interpolates {TORCH} and {CUDA} into these `!` commands.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# Install GPDL
# Cloned into the current directory (/content on Colab). NOTE(review): on a
# re-run the clone presumably fails because the directory exists; harmless.
!git clone https://github.com/sirius777coder/GPDL.git

import time
print("downloading parameters...")
os.system("apt-get install aria2 -qq")

# Download the three checkpoints in parallel into the current directory (later
# cells load them from /content): the seeding model and ESM-IF from Zenodo, and
# ESMFold from the ColabFold mirror. `wait` blocks until all downloads end.
command = """
aria2c -q -x 16 https://zenodo.org/records/17254400/files/seeding.model &
aria2c -q -x 16 https://zenodo.org/records/17254400/files/esm_if.model &
aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &
wait
"""
os.system(command)
# NOTE(review): the exit status of os.system is ignored, so this message is
# printed even if a download failed.
print("\nAll parameters have been downloaded successfully!")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m84.8 MB/s[0m eta [36m0:00:00[0m
[?25hinstalling libs...
installing openfold...
installing esmfold...
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m23.0 MB/s[0m 

In [2]:
#@title Loading GPDL

%cd GPDL/gpdl_inpainting
import modules

seed_model = torch.load("/content/seeding.model",weights_only=False,map_location="cuda")
seed_model.eval().requires_grad_(False)

/content/GPDL/gpdl_inpainting


esm_inpaint(
  (esmfold): ESMFold(
    (esm): ESM2(
      (embed_tokens): Embedding(33, 2560, padding_idx=1)
      (layers): ModuleList(
        (0-35): 36 x TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
            (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
            (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
            (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
            (rot_emb): RotaryEmbedding()
          )
          (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
          (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        )
      )
      (contact_head): ContactPredictionHead(
        (regression): Linear(in_fea

In [None]:
#@title Motif-scaffolding preparation

# protein_name = '2KL8' #@param ['1QJG', '4JHW', '5TPN', '1PRW', '4ZYP', '1YCR', '7r8f', '1BCF', '6E6R', '5TRV', '2KL8', '5WN9', '5YUI', '7MRX', '5IUS', '6EXZ', '3IXT', '6VW1']
protein_name = '2KL8' #@param {type:"string"}
protein_name = protein_name.upper()
inpaint_seq = "0,A1-7,20,A28-79,0" #@param {type:"string"}
# Spec: comma-separated segments in design order. A segment starting with a
# letter is a motif copied from the reference, "<chain><start>-<end>" or
# "<chain><resid>"; any other segment is a scaffold length "<n>" or a range
# "<a>-<b>" whose length is sampled below.
input_pdb = f"./benchmark_set/{protein_name}.pdb"

import json
import os
import pickle
import shutil
import string
import urllib.request
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import biotite
import biotite.structure as struc
import biotite.structure.io as strucio
from biotite.sequence import ProteinSequence
from biotite.structure.residues import get_residues
from openfold.utils.rigid_utils import Rigid

import utils

# Persist the spec for part three, which runs after a kernel reset.
info_data = {"protein_name": protein_name, "inpaint_seq": inpaint_seq}
with open(f"/content/data.json", "w") as f:
    json.dump(info_data, f, indent=4)

if not os.path.exists(input_pdb):
    pdb_url = f"https://files.rcsb.org/download/{protein_name}.pdb"
    print(f"Downloading pdb from {pdb_url}")
    os.makedirs(os.path.dirname(input_pdb), exist_ok=True)
    try:
        # urllib instead of `wget` via os.system: the typed name is never put into
        # a shell command, and a failed download raises here instead of leaving an
        # empty file that breaks parsing later.
        urllib.request.urlretrieve(pdb_url, input_pdb)
    except OSError as e:
        if os.path.exists(input_pdb):
            os.remove(input_pdb)
        raise RuntimeError(f"Could not download {protein_name} from {pdb_url}: {e}") from e
else:
    print(f"Loading PDB from: {input_pdb}")

inapint_info = []  # parsed segments, e.g. [{'mask': 9}, {'A': [119, 140]}, ...] (name read by later cells)
motif_mask = ""    # one char per design position: "1" = fixed motif, "0" = masked scaffold
mask_aa = "A"      # placeholder residue type written at scaffold positions
shutil.copy(f"{input_pdb}", f"/content/{protein_name}.pdb")  # reference used in part three

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parse the spec into inapint_info / motif_mask.
segment = [seg.strip() for seg in inpaint_seq.split(",")]
for seg in segment:
    if not seg:
        continue  # tolerate stray commas, e.g. a trailing ","
    if seg[0] not in string.ascii_letters:
        # Scaffold region: a fixed length, or a length sampled uniformly from [a, b].
        if "-" in seg:
            a, b = (int(x) for x in seg.split("-"))
            if a == 0:
                a = 1
            # Plain int so the list repr later written to temp.txt is a Python literal.
            scaffold = int(np.random.randint(a, b + 1))
        else:
            scaffold = int(seg)
        motif_mask += "0" * scaffold
        inapint_info.append({"mask": scaffold})
    else:
        # Motif region: residues start..end of the given chain.
        chain = seg[0]
        if "-" in seg:
            start, end = (int(x) for x in seg[1:].split("-"))
        else:
            start = end = int(seg[1:])
        motif_mask += "1" * (end - start + 1)
        inapint_info.append({f"{chain}": [start, end]})

# Reference structure (per the original note, only standard amino acids are kept).
structure = utils.load_structure(input_pdb)

# Design template: one-letter sequence (motif residues + mask_aa) and N/CA/C/O
# coordinates per position (zeros at scaffold positions).
# NOTE: `inpaint_seq` is reused for the residue sequence; the spec string was
# already saved to data.json above.
inpaint_seq = ""
inpaint_coord = np.zeros((len(motif_mask), 4, 3))
location = 0
backbone_atoms = ("N", "CA", "C", "O")
for item in inapint_info:
    chain_name = next(iter(item))
    if chain_name == "mask":  # scaffold region
        inpaint_seq += mask_aa * item["mask"]
        location += item["mask"]
        continue
    start, end = int(item[chain_name][0]), int(item[chain_name][1])
    for res_id in range(start, end + 1):
        res_atom_array = structure[(structure.chain_id == chain_name) & (structure.res_id == res_id)]
        if len(res_atom_array) == 0:
            raise ValueError(f"Motif residue {chain_name}{res_id} not found in {input_pdb}")
        inpaint_seq += ProteinSequence.convert_letter_3to1(get_residues(res_atom_array)[1][0])
        for atom_idx, atom_name in enumerate(backbone_atoms):
            atom_coord = res_atom_array[res_atom_array.atom_name == atom_name].coord
            if len(atom_coord) == 0:
                raise ValueError(f"Motif residue {chain_name}{res_id} has no {atom_name} atom in {input_pdb}")
            inpaint_coord[location][atom_idx] = atom_coord[0]
        location += 1

seq = torch.tensor([utils.restype_order[aa] for aa in inpaint_seq], dtype=torch.long).unsqueeze(0).to(device)
coord = torch.from_numpy(inpaint_coord).to(torch.float).unsqueeze(0).to(device)

Loading PDB from: ./benchmark_set/2KL8.pdb


In [4]:
#@title Motif-scaffolding running

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with torch.no_grad():
        output = seed_model.infer(coord, seq, T=1, motif_mask=torch.tensor(
            [int(i) for i in list(motif_mask)], device=coord.device).unsqueeze(0))

%cd ../../
output_file = "./temp.pdb"
output_txt = "./temp.txt"
output_esmif = "./esmif.fasta"
num_samples = 1
with open(f"{output_file}", "w") as f:
    f.write(output[0])

# simultaneously note the sequence location of each design
with open(f"{output_txt}", "a") as f:
    f.write(f"0\n{inapint_info}\n")

/content


# Part two: from seeding to optimization (ESM-IF sequence design)

In [6]:
#@title Loading esmif

# ESM-IF inverse-folding model stored as a pickled module. weights_only=False
# unpickles arbitrary objects, so only load this trusted checkpoint. It is
# loaded without map_location and then moved to the GPU; the last expression
# displays the module.
esmif = torch.load("/content/esm_if.model",weights_only=False)
esmif.eval().to("cuda").requires_grad_(False)

GVPTransformerModel(
  (encoder): GVPTransformerEncoder(
    (dropout_module): Dropout(p=0.1, inplace=False)
    (embed_tokens): Embedding(35, 512, padding_idx=1)
    (embed_positions): SinusoidalPositionalEmbedding()
    (embed_gvp_input_features): Linear(in_features=15, out_features=512, bias=True)
    (embed_confidence): Linear(in_features=16, out_features=512, bias=True)
    (embed_dihedrals): DihedralFeatures(
      (node_embedding): Linear(in_features=6, out_features=512, bias=True)
      (norm_nodes): Normalize()
    )
    (gvp_encoder): GVPEncoder(
      (embed_graph): GVPGraphEmbedding(
        (embed_node): Sequential(
          (0): GVP(
            (wh): Linear(in_features=3, out_features=256, bias=False)
            (ws): Linear(in_features=263, out_features=1024, bias=True)
            (wv): Linear(in_features=256, out_features=256, bias=False)
          )
          (1): LayerNorm(
            (scalar_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
         

In [None]:
#@title Sampling sequences

# esmif code
# Resolved relative to the current directory (/content after the seeding
# cell's `%cd`), i.e. the `esm` package vendored in the cloned GPDL repo.
import sys
import GPDL.gpdl_inpainting.esm.inverse_folding as inverse_folding
from pathlib import Path


def sample_seq_singlechain(model):
    """Sample sequences for the seeded backbone with ESM-IF, keeping motif residues fixed.

    Reads the backbone from the global ``output_file`` (chain "A") and the segment
    layout written to ``output_txt`` by the seeding cell. It builds a partial sequence
    where motif positions keep their residue and scaffold positions are ``'<mask>'``,
    then writes ``num_samples`` sampled sequences to ``output_esmif`` as FASTA.

    Args:
        model: ESM-IF model exposing
            ``sample(coords, temperature=..., partial_seq=..., device=...)``.

    Raises:
        ValueError: if the recorded layout does not cover exactly the residues
            of the loaded structure.
    """
    import ast
    import re

    coords, native_seq = inverse_folding.util.load_coords(output_file, "A")
    print('Native sequence loaded from structure file:')
    print(native_seq)

    print(f'Saving sampled sequences to {output_esmif}.')

    with open(output_txt, "r") as f:
        records = [line.strip() for line in f if line.strip()]
    # The layout is the last non-empty line (robust to an appended file). Use
    # literal_eval rather than eval so file contents are never executed; numpy>=2
    # scalars repr as e.g. "np.int64(12)", so unwrap any such values first.
    layout_text = re.sub(r"np\.\w+\((-?\d+)\)", r"\1", records[-1])
    motif_info = ast.literal_eval(layout_text)

    binary_string = ""  # "1" = motif position, "0" = scaffold position
    for item in motif_info:
        if 'mask' in item:
            binary_string += '0' * item['mask']
        else:
            chain_name = list(item.keys())[0]
            start, end = item[chain_name]
            binary_string += '1' * (end - start + 1)
    if len(binary_string) != len(native_seq):
        raise ValueError(
            f"Segment layout covers {len(binary_string)} residues but "
            f"{output_file} has {len(native_seq)}"
        )
    partial_seq = [native_seq[i] if flag == '1' else '<mask>' for i, flag in enumerate(binary_string)]

    Path(output_esmif).parent.mkdir(parents=True, exist_ok=True)
    with open(output_esmif, 'w') as f:
        for i in range(num_samples):
            print(f'\nSampling.. ({i+1} of {num_samples})')
            sampled_seq = model.sample(coords, temperature=0.1, partial_seq=partial_seq, device=torch.device('cuda'))
            print('Sampled sequence:')
            print(sampled_seq)
            f.write(f'>sampled_seq_{i+1}\n')
            f.write(sampled_seq + '\n')

            recovery = np.mean([(a == b) for a, b in zip(native_seq, sampled_seq)])
            print('Sequence recovery:', recovery)
with torch.no_grad():
    sample_seq_singlechain(esmif)

# Free GPU memory before part three: drop the seeding model (including its
# `esmfold` submodule when present) and ESM-IF, then release cached CUDA blocks.
import gc
if hasattr(seed_model, 'esmfold'):
    del seed_model.esmfold
del seed_model
del esmif
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()



# NOTE(review): these imports are cleared by `%reset -f` below and re-imported
# right after it, so they have no lasting effect.
import sys
import os
import shutil
import gc


# 1. reset variable
# Clear every user variable. Part three relies only on files in /content
# (data.json, esmif.fasta, the copied reference PDB).
%reset -f

# 2. reset packages and path
# Drop GPDL entries from sys.path and unload cached modules whose name contains
# 'GPDL' or 'esm'; presumably so the `esm` package installed in part three is
# imported fresh instead of the repo's vendored copy.
# NOTE(review): the 'esm' substring test also matches unrelated module names.
import sys, os, gc, shutil
paths_to_remove = [path for path in sys.path if 'GPDL' in path]
for path in paths_to_remove:
    try: sys.path.remove(path)
    except ValueError: pass
modules_to_remove = [mod for mod in sys.modules if 'GPDL' in mod or 'esm' in mod]
for mod in modules_to_remove:
    del sys.modules[mod]


# 3. remove GPDL file
# Move out of the repository, then delete the cloned copy.
if os.path.exists('/content/GPDL'):
    try:
        os.chdir('/content/')
        shutil.rmtree('/content/GPDL')
    except OSError as e:
        print(f"  Failed: {e}")
else:
    print("  - '/content/GPDL' not exists。")

gc.collect()
!ls /content/

Native sequence loaded from structure file:
MEMDIRFSVESFELPSEFWCFPCTSQCKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
Saving sampled sequences to ./esmif.fasta.

Sampling.. (1 of 1)
Sampled sequence:
MEMDIRFTLSPDLTPPAELPPPALAALKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
Sequence recovery: 0.7721518987341772
2KL8.pdb   esmfold.model  esm_if.model	seeding.model  temp.txt
data.json  esmif.fasta	  sample_data	temp.pdb


# Part three: optimization

In [None]:
#@title Preparing for optimization
import json
import os

# Spec saved in part one; the kernel was reset at the end of part two.
# `with` closes the file (json.load(open(...)) leaked the handle).
with open("/content/data.json", "r") as f:
    info_data = json.load(f)
protein_name = info_data["protein_name"]
inpaint_seq = info_data["inpaint_seq"]

pre_sequence = "/content/esmif.fasta"  # starting sequences sampled by ESM-IF in part two
reference = f"./{protein_name}.pdb"    # reference structure copied to /content in part one
bb_suffix = "GPDL"                     # tag used in output file names
step = 100 #@param [100, 500] {type:"raw"}
loss = 10      # RMSD weight in the loss: 100 - pLDDT + loss * motif_RMSD
t1 = 1         # initial simulated-annealing temperature
t2 = 500       # temperature half-life, in steps
max_mut = 15   # initial mutation count as a percentage of the scaffold length (decays to 1)
number = 1     # NOTE(review): not used anywhere in this notebook

def extract_from_inpaint_seq(inpaint_seq):
    """Split an inpainting spec into its scaffold-length part and its motif part.

    Args:
        inpaint_seq: comma-separated spec string, e.g. "0,A1-7,20,A28-79,0".

    Returns:
        tuple ``(mask_len, motif_id)`` of comma-joined strings, e.g.
        ``("0,20,0", "A1-7,A28-79")``. Segments whose first character is a letter
        are motifs; every other non-empty segment is a scaffold length.
    """
    scaffold_parts, motif_parts = [], []
    for segment in inpaint_seq.split(","):
        if not segment:
            continue
        target = motif_parts if segment[0].isalpha() else scaffold_parts
        target.append(segment)
    return ",".join(scaffold_parts), ",".join(motif_parts)

# mask_len="0,20,0"
# motif_id="A1-7,A28-79"
import ast
import os
import re

import numpy as np

atoms = "N,CA,C,O".split(',')  # backbone atoms compared in the motif RMSD

output_dir = "./output"
final_des_dir = "./final_des_dir"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(final_des_dir, exist_ok=True)


def _parse_motif_segment(segment):
    """Parse "A28-79" -> ("A", 28, 79); a single residue "A5" -> ("A", 5, 5)."""
    start, _, end = segment[1:].partition('-')
    return segment[0], int(start), int(end or start)


# Ordered design layout: ("mask", length) for scaffold segments and
# (chain, start, end) for motifs, e.g. "0,A1-7,20,A28-79,0" ->
# [("mask", 0), ("A", 1, 7), ("mask", 20), ("A", 28, 79), ("mask", 0)].
spec_segments = [seg.strip() for seg in inpaint_seq.split(",") if seg.strip()]
layout = []
if any(not seg[0].isalpha() and "-" in seg for seg in spec_segments):
    # Scaffold ranges ("a-b") were sampled in part one; the lengths actually used
    # are only recorded in /content/temp.txt (last non-empty line).
    with open("/content/temp.txt") as f:
        records = [line.strip() for line in f if line.strip()]
    # numpy>=2 scalars repr as e.g. "np.int64(12)"; unwrap any before literal_eval.
    for item in ast.literal_eval(re.sub(r"np\.\w+\((-?\d+)\)", r"\1", records[-1])):
        key = next(iter(item))
        if key == "mask":
            layout.append(("mask", int(item["mask"])))
        else:
            layout.append((key, int(item[key][0]), int(item[key][1])))
else:
    for seg in spec_segments:
        layout.append(_parse_motif_segment(seg) if seg[0].isalpha() else ("mask", int(seg)))

mask_len = [entry[1] for entry in layout if entry[0] == "mask"]
motif_id = [f"{entry[0]}{entry[1]}-{entry[2]}" for entry in layout if entry[0] != "mask"]
scaf_len = sum(mask_len)
AA_freq = {'A': 0.07421620506799341,
 'R': 0.05161448614128464,
 'N': 0.044645808512757915,
 'D': 0.05362600083855441,
 'C': 0.02468745716794485,
 'Q': 0.03425965059141602,
 'E': 0.0543119256845875,
 'G': 0.074146941452645,
 'H': 0.026212984805266227,
 'I': 0.06791736761895376,
 'L': 0.09890786849715096,
 'K': 0.05815568230307968,
 'M': 0.02499019757964311,
 'F': 0.04741845974228475,
 'P': 0.038538003320306206,
 'S': 0.05722902947649442,
 'T': 0.05089136455028703,
 'W': 0.013029956129972148,
 'Y': 0.03228151231375858,
 'V': 0.07291909820561925}
letters = {'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'}


from Bio.PDB import *
parser = PDBParser()
structure = parser.get_structure("ref", reference)
model = structure[0]
motif_seq = {}  # motif index -> one-letter sequence of that motif in the reference
coord = []      # reference N, CA, C, O coordinates of every motif residue, in motif order
dm_id = []      # 1-based positions of motif residues in the designed sequence
position = 1    # 1-based design position of the next segment
motif_idx = 0
for entry in layout:
    if entry[0] == "mask":
        position += entry[1]
        continue
    chain_id, s, e = entry
    motif_len = e - s + 1
    dm_id += list(range(position, position + motif_len))
    position += motif_len

    motif_seq[motif_idx] = ''
    chain = model[chain_id]
    for res in chain.get_residues():
        resname = res.get_resname()
        res_id = res.get_id()[1]
        if resname in letters and s <= int(res_id) <= e:
            motif_seq[motif_idx] += letters[resname]
            for atom in atoms:
                coord.append(res[atom].get_coord())
    motif_idx += 1

# Catch missing/non-standard reference residues here rather than as an opaque
# shape error in the RMSD superposition later.
if len(coord) != len(atoms) * len(dm_id):
    raise ValueError(
        f"Found backbone atoms for {len(coord) // len(atoms)} of {len(dm_id)} motif residues in {reference}"
    )
ref = np.array(coord)

def parse_fasta(fasta_string: str):
    """Parse FASTA text into sequences and their header descriptions.

    Args:
        fasta_string: full contents of a FASTA file, e.g. obtained with
            ``with open(path) as f: text = f.read()``.

    Returns:
        ``(sequences, descriptions)``: two lists in file order. Each sequence is
        the concatenation of its stripped sequence lines; each description is the
        header line without the leading ``'>'``. Blank lines are ignored.
    """
    descriptions = []
    chunks = []  # one list of sequence lines per record
    for raw_line in fasta_string.splitlines():
        text = raw_line.strip()
        if not text:
            continue
        if text[0] == '>':
            descriptions.append(text[1:])
            chunks.append([])
        else:
            chunks[-1].append(text)
    return [''.join(parts) for parts in chunks], descriptions


import typing as T
from pathlib import Path

PathLike = T.Union[str, Path]

# Starting sequences sampled by ESM-IF in part two.
with open(pre_sequence, 'r') as f:
    fasta_string = f.read()
sequences, de = parse_fasta(fasta_string)
des_len = len(sequences[0])  # design length, taken from the first sequence

def create_batched_sequence_datasest(
    sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
    """Group ``(header, sequence)`` pairs into batches of bounded total length.

    Sequences are consumed in order. A new batch starts when adding the next
    sequence would push the batch's residue count above ``max_tokens_per_batch``;
    a single sequence longer than the limit still forms its own batch.

    Yields:
        ``(headers, sequences)`` lists for each non-empty batch. Nothing is yielded
        for empty input, so callers never pass an empty batch to the model.
    """
    batch_headers, batch_sequences, num_tokens = [], [], 0
    for header, seq in sequences:
        if num_tokens > 0 and num_tokens + len(seq) > max_tokens_per_batch:
            yield batch_headers, batch_sequences
            batch_headers, batch_sequences, num_tokens = [], [], 0
        batch_headers.append(header)
        batch_sequences.append(seq)
        num_tokens += len(seq)

    if batch_headers:
        yield batch_headers, batch_sequences


def _parse_esmfold_pdb(pdb_string, motif_ids):
    """Return (per-residue mean B-factor list, (n, 3) motif backbone coords) from PDB text."""
    per_residue = {}  # residue number (str) -> B-factors (pLDDT) of its atoms, in file order
    motif_rows = []
    for line in pdb_string.split("\n"):
        if not line.startswith("ATOM"):
            continue
        # Fixed PDB columns (atom name 13-16, resSeq 23-26, x/y/z 31-54,
        # B-factor 61-66); whitespace splitting breaks when adjacent fields
        # touch, e.g. large negative coordinates.
        atom_name = line[12:16].strip()
        res_num = line[22:26].strip()
        per_residue.setdefault(res_num, []).append(float(line[60:66]))
        if int(res_num) in motif_ids and atom_name in atoms:
            motif_rows.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
    coord = np.array(motif_rows, dtype=float).reshape(-1, 3)
    plddts = [np.mean(values) for values in per_residue.values()]
    return plddts, coord


def main(model, all_sequences, num, motif):
    """Fold sequences with ESMFold and extract motif coordinates and confidences.

    Args:
        model: ESMFold model exposing ``infer(sequences, num_recycles=...)`` and
            ``output_to_pdb(output)``.
        all_sequences: list of ``(header, sequence)`` pairs.
        num: identifier used in the (not written) file name ``{output_dir}/{num}_{header}.pdb``.
        motif: 1-based residue numbers of the motif; the backbone atoms named in
            the global ``atoms`` are collected for these residues.

    Returns:
        ``(output_file, pdb_string, coord, mean_plddt, plddts, ptm, mean_pae)`` for the
        LAST sequence: its would-be path, PDB text, ``(n, 3)`` motif backbone coordinates
        (for that sequence only), mean pLDDT, per-residue pLDDT list, pTM and mean PAE.

    Raises:
        ValueError: if ``all_sequences`` is empty.
        OSError: if ESMFold inference raises a ``RuntimeError``.
    """
    if not all_sequences:
        raise ValueError("main() needs at least one (header, sequence) pair")
    motif_ids = set(motif)

    for headers, sequences in create_batched_sequence_datasest(all_sequences, 1024):
        try:
            output = model.infer(sequences, num_recycles=3)
        except RuntimeError as e:
            raise os.error(f"Error during prediction: {e}") from e

        output = {key: value.cpu() for key, value in output.items()}
        pdbs = model.output_to_pdb(output)
        for header, pdb_string, mean_plddt, ptm, pae in zip(
            headers, pdbs, output["mean_plddt"], output["ptm"], output['predicted_aligned_error']
        ):
            output_file = Path(f"{output_dir}/{num}_{header}.pdb")
            mean_pae = torch.mean(pae)
            plddts, coord = _parse_esmfold_pdb(pdb_string, motif_ids)
    return output_file, pdb_string, coord, mean_plddt, plddts, ptm, mean_pae


import os
# if 'esm' in sys.modules:
#     print("Unloading the cached esm module...")
#     del sys.modules['esm']
print("installing esmfold...")
# install esmfold
# Installs the sokrypton fork of `esm`, which provides ESMFold.
os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")


# loss file
# Imports for the scoring helpers below (RMSD, pLDDT, clash counting).
# NOTE(review): `shape` from `turtle` is never used and `turtle` needs tkinter;
# this looks like an accidental IDE auto-import and can likely be removed.
from turtle import shape
from Bio.PDB import *
import numpy as np
from Bio.SVDSuperimposer import SVDSuperimposer
import sys, os, argparse, copy, subprocess, glob, time, pickle, json, tempfile, random

def get_coord(pdb, m_id, chain, option="CA"):
    """Return CA coordinates of selected residues of one chain in the first model.

    Args:
        pdb: path to a PDB file.
        m_id: residue sequence numbers to extract, in the desired output order.
        chain: chain identifier to read residues from.
        option: atom selection; only ``"CA"`` is supported.

    Returns:
        ``np.ndarray`` with one CA coordinate row per entry of ``m_id``.

    Raises:
        ValueError: for an unsupported ``option`` (previously ``None`` was returned silently).
        KeyError: if a requested residue or its CA atom is missing.
    """
    if option != "CA":
        raise ValueError(f"Unsupported option {option!r}; only 'CA' is implemented")
    structure = PDBParser().get_structure("pdb", pdb)
    # Key residues by number within the requested chain of model 0 only; keying
    # every model and chain would let same-numbered residues elsewhere (other
    # chains, later NMR models) overwrite these.
    res_dict = {residue.get_id()[1]: residue for residue in structure[0][chain].get_residues()}
    return np.array([res_dict[res_id]["CA"].get_coord() for res_id in m_id])

def get_rmsd(ref_coord, des_coord):
    """Superimpose ``des_coord`` onto ``ref_coord`` and return ``(rmsd, rot, tran)``.

    Both inputs are coordinate arrays with one xyz row per atom, in matching order.
    ``rot`` and ``tran`` come from ``SVDSuperimposer.get_rotran`` and are applied as
    ``np.dot(coords, rot) + tran``.
    """
    superimposer = SVDSuperimposer()
    superimposer.set(ref_coord, des_coord)
    superimposer.run()
    rotation, translation = superimposer.get_rotran()
    return superimposer.get_rms(), rotation, translation

def get_lddt(pdb):
    """Read per-residue and overall confidence from the B-factor column of a PDB file.

    For ESMFold output the B-factor column holds pLDDT.

    Args:
        pdb: path to a PDB file.

    Returns:
        ``(plddt, plddts, lddt)``:
            ``plddt``: dict mapping residue number (str) -> list of atom B-factors.
            ``plddts``: list of per-residue means, in file order.
            ``lddt``: mean over all ATOM records.
    """
    plddt = {}
    with open(pdb) as f:
        for line in f:
            # Skip blank lines and non-ATOM records without indexing into them.
            if not line.startswith("ATOM"):
                continue
            # Fixed PDB columns: resSeq 23-26, B-factor 61-66.
            plddt.setdefault(line[22:26].strip(), []).append(float(line[60:66]))

    plddts = [np.mean(values) for values in plddt.values()]
    lddt = np.mean([value for values in plddt.values() for value in values])
    return plddt, plddts, lddt

def get_potential(po, all_coord, rot, tran, van_r):
    """Count points of ``all_coord`` lying closer than ``van_r`` to ``po`` after transformation.

    The points are first mapped with ``np.dot(all_coord, rot) + tran`` (the rotation
    and translation returned by ``get_rmsd``).
    """
    transformed = np.dot(all_coord, rot) + tran
    return sum(1 for point in transformed if np.linalg.norm(point - po) < van_r)

#!/usr/bin/env python
# coding: utf-8

import sys, os, argparse, copy, subprocess, glob, time, pickle, json, tempfile, random
import numpy as np



def select_positions(plddts, n_mutation, dm_id, des_len, option='r'):
    """Choose 0-based scaffold positions to mutate; motif positions are never selected.

    Args:
        plddts: per-residue pLDDT of the current design (index i = residue i+1).
        n_mutation: number of distinct positions to return.
        dm_id: 1-based positions of motif residues in the design.
        des_len: design length; sets the terminal weights and the size
            (``round(0.25 * des_len)``) of the low-pLDDT candidate pool.
        option:
            ``'r'``: uniform over all scaffold positions.
            ``'p'``: weighted draw from the lowest-pLDDT scaffold positions,
                where the three residues at each terminus get lower weight.
            ``'pr'``: half drawn as in ``'p'``, the rest uniformly from the
                remaining scaffold positions.

    Returns:
        numpy array of ``n_mutation`` position indices.

    Raises:
        ValueError: for an unknown ``option`` (previously this failed with
            UnboundLocalError). numpy also raises ValueError when more
            positions are requested than are available.
    """
    mutate_plddt_quantile = 0.25  # default worst pLDDT quantile to mutate.
    # Lower selection weight for the three residues at each terminus.
    weights = np.array([0.25, 0.5, 0.75] + [1] * (des_len - 6) + [0.75, 0.5, 0.25])
    n_potential = round(des_len * mutate_plddt_quantile)
    motif_positions = set(dm_id)
    # Scaffold positions ordered from lowest to highest pLDDT (dm_id is 1-based).
    sca = [i for i in np.argsort(plddts) if int(i) + 1 not in motif_positions]
    potential_sites = sca[:n_potential]
    sub_w = weights[potential_sites]
    sub_w = sub_w / np.sum(sub_w)

    if option == 'p':
        sites = np.random.choice(potential_sites, size=n_mutation, replace=False, p=sub_w)
    elif option == 'r':
        sites = np.random.choice(sca, size=n_mutation, replace=False, p=None)
    elif option == 'pr':
        sites1 = np.random.choice(potential_sites, size=round(n_mutation/2), replace=False, p=sub_w)
        sites2 = np.random.choice(np.setdiff1d(sca, sites1), size=n_mutation-round(n_mutation/2), replace=False, p=None)
        sites = np.append(sites1, sites2)
    else:
        raise ValueError(f"Unknown option {option!r}; expected 'p', 'r' or 'pr'")
    return sites

def random_mutate(seq, sites):
    """Replace the residue at each index in ``sites`` with a different amino acid.

    Each of the 19 alternatives to the current residue is equally likely. Sites are
    processed in order, so a repeated index is mutated again. Draws from numpy's
    global RNG.
    """
    alphabet = np.array(list('ARNDCQEGHILKMFPSTWYV'))
    transition = np.ones([20, 20])
    np.fill_diagonal(transition, 0)  # a residue never "mutates" to itself
    for site in sites:
        column = transition[:, np.argwhere(alphabet == seq[site])[0][0]]
        probabilities = column / column.sum()
        replacement = np.random.choice(list(alphabet), p=list(probabilities))
        seq = seq[:site] + replacement + seq[site + 1:]
    return seq

installing esmfold...


In [13]:
#@title Loading model

import torch
# ESMFold, used to fold and score designs during optimization. weights_only=False
# unpickles arbitrary objects, so only load trusted checkpoints. The last
# expression displays the module.
esm_model = torch.load("/content/esmfold.model",weights_only=False,map_location="cuda")
esm_model.eval().to("cuda").requires_grad_(False)

ESMFold(
  (esm): ESM2(
    (embed_tokens): Embedding(33, 2560, padding_idx=1)
    (layers): ModuleList(
      (0-35): 36 x TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=2560, out_features=10240, bias=True)
        (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
    )
    (contact_head): ContactPredictionHead(
      (regression): Linear(in_features=1440, out_features=1, bias=True)
      (activation): Sigmo

In [None]:
#@title Running optimization

# Simulated annealing over scaffold mutations, one trajectory per starting
# sequence. Loss = 100 - pLDDT + loss * motif_RMSD (lower is better); a trajectory
# stops early once motif RMSD < 1 and pLDDT > 80.
des_seqs = []      # final sequence of each trajectory
fst_suc_step = []  # step index at which each trajectory stopped
for init_seq_idx, des_seq in enumerate(sequences):
    traj = []
    num = init_seq_idx

    desc = f'{num}_mut0'
    all_sequences = [(desc, des_seq)]

    save_path,pdb,des_coord,plddt,plddts, ptm, mean_pae=main(esm_model,all_sequences,num,dm_id)

    M = np.linspace(round(max_mut/100*scaf_len), 1, step) # stepped linear decay of the mutation rate

    for i in range(step):
        # Temperature halves every t2 steps. Named `temperature` rather than `T`
        # so it does not shadow the `typing as T` alias used in annotations.
        temperature = t1*(np.exp(np.log(0.5) / t2) ** i)
        n_mutation = round(M[i]) # update mutation rate
        accepted = False # reset

        if i == 0: # do a first pass through the network before mutating anything -- baseline
            rmsd,rot,tran=get_rmsd(ref,des_coord)
            print(f"RMSD is {rmsd}")
            print(f"plddt is {plddt}")
            current_loss=100-plddt+loss*rmsd
            print(f"current loss is {current_loss}")
            print(f"{loss},{t1},{t2}")

            traj.append((i, desc, des_seq, pdb, rmsd, plddt, mean_pae, ptm, True))

        else:
            # Mutate scaffold positions of the current (last accepted) design.
            sites=select_positions(plddts,n_mutation,dm_id, des_len, option='r')
            print(f'mut{i} mutation sites: {sites}')
            mut_seq=random_mutate(des_seq,sites)
            print(f'mut{i} seq: {mut_seq}')

            desc = f'{num}_mut{i}'
            all_sequences = [(desc, mut_seq)]

            # Fold and score the mutant.
            save_path,pdb,mut_coord,mut_plddt,mut_plddts, ptm, mean_pae=main(esm_model,all_sequences,num,dm_id)
            mut_rmsd,rot,tran=get_rmsd(ref,mut_coord)
            print(f"RMSD is {mut_rmsd}")
            try_loss=100-mut_plddt+loss*mut_rmsd

            delta = try_loss - current_loss
            print(f'current loss is {current_loss}')
            print(f'try loss is {try_loss}, {i}')

            if delta < 0:
                # Better solution: always accept.
                accepted = True
            else:
                # Worse solution: accept with probability e^(-delta / temperature).
                accepted = bool(np.random.uniform(0, 1) < np.exp( -delta / temperature))

            if accepted:
                print(f"do accept")
                current_loss = try_loss # accept loss change
                des_seq=mut_seq # accept the mutation
                rmsd=mut_rmsd
                plddt=mut_plddt
                # Keep per-residue pLDDT in sync with the accepted design so the
                # next site selection ranks the sequence actually being mutated.
                plddts=mut_plddts
            else:
                print(f'not accept')

            traj.append((i, desc, des_seq, pdb, mut_rmsd, mut_plddt, mean_pae, ptm, accepted))

        if rmsd < 1 and plddt > 80:
            break

    des_seqs.append(des_seq)
    fst_suc_step.append(i)

    with open(f"{output_dir}/{bb_suffix}_{num}.pkl", 'wb') as f:
        pickle.dump(traj, f)

    # Refold the final sequence and save its structure.
    desc = f'final_des{bb_suffix}_{num}'
    all_sequences = [(desc, des_seq)]
    save_path,pdb, des_coord,plddt,plddts, ptm, mean_pae=main(esm_model,all_sequences,f'final_des{bb_suffix}_{num}',dm_id)
    save_path = Path(f"{final_des_dir}/final_des{bb_suffix}_{num}.pdb")
    save_path.write_text(pdb)
    final_rmsd,rot,tran=get_rmsd(ref,des_coord)
    print(f'****** final_des{bb_suffix}_{num}: ,motif_RMSD:{final_rmsd}, plddt:{plddt} *******')


RMSD is 1.0276070141791274
plddt is 54.63172912597656
current loss is 55.64434051513672
10,1,500
mut1 mutation sites: [ 9  8 10]
mut1 seq: MEMDIRFTNNDDLTPPAELPPPALAALKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 1.1572841547899568
current loss is 55.64434051513672
try loss is 54.36073303222656, 1
do accept
mut2 mutation sites: [17 16 26]
mut2 seq: MEMDIRFTNNDDLTPPMGLPPPALAAPKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 1.1359753270338386
current loss is 54.36073303222656
try loss is 45.40742874145508, 2
do accept
mut3 mutation sites: [25 26 19]
mut3 seq: MEMDIRFTNNDDLTPPMGLDPPALALDKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 0.8930911597087826
current loss is 45.40742874145508
try loss is 50.43828582763672, 3
not accept
mut4 mutation sites: [15 18 14]
mut4 seq: MEMDIRFTNNDDLTSNMGAPPPALAAPKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE


  if np.random.uniform(0, 1) < np.exp( -delta / T):


RMSD is 1.1256092902341437
current loss is 45.40742874145508
try loss is 56.283042907714844, 4
not accept
mut5 mutation sites: [19 25 26]
mut5 seq: MEMDIRFTNNDDLTPPMGLDPPALAHGKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 0.9359676560267194
current loss is 45.40742874145508
try loss is 47.47319030761719, 5
not accept
mut6 mutation sites: [21  9 20]
mut6 seq: MEMDIRFTNQDDLTPPMGLPMLALAAPKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 0.7831963317989128
current loss is 45.40742874145508
try loss is 50.73295974731445, 6
not accept
mut7 mutation sites: [ 8 14  7]
mut7 seq: MEMDIRFNMNDDLTYPMGLPPPALAAPKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 1.0362295571240767
current loss is 45.40742874145508
try loss is 57.034278869628906, 7
not accept
mut8 mutation sites: [23 12 16]
mut8 seq: MEMDIRFTNNDDATPPSGLPPPAEAAPKFAGTVTYTLDGNDLEIRITGVPEQVRKELAKEAERLAKEFNITVTYTIRLE
RMSD is 1.1157428991734746
current loss is 45.40742874145508
try loss is 47.53910064697265

In [15]:
#@title Saving output

# Reload the first refined design, written by the optimization cell as
# {final_des_dir}/final_des{bb_suffix}_{num}.pdb, for display. The path is built
# from the same settings so changing bb_suffix or final_des_dir keeps working.
final_pdb_path = Path(final_des_dir) / f"final_des{bb_suffix}_0.pdb"
with open(final_pdb_path, "r") as f:
    pdb_str = f.read()
lengths = [len(sequences[0])]

In [16]:
#@title display (optional) {run: "auto"}
import py3Dmol

# Per-chain colours used by show_pdb(color="chain"), assigned in chain order.
pymol_color_list = [
    "#33ff33", "#00ffff", "#ff33cc", "#ffff00", "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
    "#7fff7f", "#199999", "#ff007f", "#ffdd5e", "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
    "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5", "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
    "#00ff7f", "#337fcc", "#d8337f", "#bfff3f", "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
    "#339933", "#66b2b2", "#ba8c84", "#84bf00", "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
    """Render a PDB string with py3Dmol.

    Args:
        pdb_str: PDB file contents.
        show_sidechains: draw side chains as white-carbon sticks (GLY shown as a CA sphere).
        show_mainchains: draw backbone C, O, N, CA atoms as sticks.
        color: "pLDDT" (B-factor gradient between vmin and vmax), "rainbow"
            (3Dmol 'spectrum' colouring) or "chain" (one colour per chain).
        chains: number of chains to colour in "chain" mode; defaults to len(Ls), or 1.
        vmin, vmax: B-factor range of the pLDDT gradient.
        size: viewer (width, height).
        hbondCutoff: passed to py3Dmol when the model is added.
        Ls: per-chain lengths; only used to default ``chains``.
        animate: load the PDB as animation frames and animate them.

    Returns:
        The ``py3Dmol.view``; call ``.show()`` on it to display.
    """
    import string

    # Chain IDs in colouring order. Defined locally because no `alphabet_list`
    # exists elsewhere in this notebook, so color="chain" raised NameError.
    alphabet_list = list(string.ascii_uppercase + string.ascii_lowercase)

    if chains is None:
        chains = 1 if Ls is None else len(Ls)
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])
    if animate:
        view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
    else:
        view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
    if color == "pLDDT":
        view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})
    elif color == "rainbow":
        view.setStyle({'cartoon': {'color':'spectrum'}})
    elif color == "chain":
        # One colour per chain; at most len(pymol_color_list) chains are coloured.
        for _, chain_id, chain_color in zip(range(chains), alphabet_list, pymol_color_list):
            view.setStyle({'chain':chain_id},{'cartoon': {'color':chain_color}})
    if show_sidechains:
        BB = ['C','O','N']
        view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                      {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                      {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                      {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    if show_mainchains:
        BB = ['C','O','N','CA']
        view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.zoomTo()
    if animate:
        view.animate()
    return view

# Viewer options (Colab form fields). "confidence" is the form's label for
# pLDDT (B-factor) colouring.
color = "rainbow" #@param ["confidence", "rainbow", "chain"]
if color == "confidence": color = "pLDDT"
show_sidechains = True #@param {type:"boolean"}
show_mainchains = True #@param {type:"boolean"}
# Render the design loaded by the "Saving output" cell.
show_pdb(pdb_str, color=color,
         show_sidechains=show_sidechains,
         show_mainchains=show_mainchains,
         Ls=lengths).show()