<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_relax_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# af_relax_design (WIP)


**Efficient and scalable de novo protein design using a relaxed sequence space**

Christopher Josef Frank, Ali Khoshouei, Yosta de Stigter, Dominik Schiewitz, Shihao Feng, Sergey Ovchinnikov, Hendrik Dietz

doi: https://doi.org/10.1101/2023.02.24.529906

**<font color="red">WARNING:</font>** This notebook is still in development; we are working on adding all the options from the manuscript above.

In [None]:
#@title setup
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git")
  # for debugging
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model

from IPython.display import HTML
from google.colab import files
import numpy as np

import requests, time
# download and compile the TMscore program used by tmscore() below
if not os.path.isfile("TMscore"):
  os.system("wget -qnc https://zhanggroup.org/TM-score/TMscore.cpp")
  # libraries go after the sources that need them (matters for static links)
  os.system("g++ -static -O3 -ffast-math -o TMscore TMscore.cpp -lm")
  # fail here with a clear message; otherwise tmscore() would run a missing
  # binary, parse no output, and the error would surface much later as a
  # KeyError on o["rms"]
  if not os.path.isfile("TMscore"):
    raise RuntimeError("failed to download/compile TMscore (check wget/g++ output above)")
def tmscore(x, y, tmscore_bin="./TMscore"):
  '''Compare two PDB files with the TMscore program and parse its summary.

  Args:
    x: path to the first PDB file (first argument to TMscore).
    y: path to the second PDB file (second argument to TMscore).
    tmscore_bin: path to the TMscore executable; defaults to the binary
      compiled in the setup cell.

  Returns:
    dict with whichever of these keys were found in TMscore's output:
      "rms" - value on the line starting with "RMSD"
      "tms" - value on the line starting with "TM-score"
      "gdt" - value on the line starting with "GDT-TS-score"
    A missing key means TMscore did not print that line (e.g. it failed).

  Raises:
    FileNotFoundError: if tmscore_bin does not exist.
  '''
  import subprocess
  # argv list, no shell: paths containing spaces or shell metacharacters are
  # passed through intact, and the child's pipes are always cleaned up
  result = subprocess.run([tmscore_bin, x, y], capture_output=True, text=True)
  # summary lines look like "<label> = <number> ..."; take the first token
  # after the first "="
  parse_float = lambda line: float(line.split("=")[1].split()[0])
  o = {}
  for line in result.stdout.splitlines():
    line = line.rstrip()
    if line.startswith("RMSD"): o["rms"] = parse_float(line)
    if line.startswith("TM-score"): o["tms"] = parse_float(line)
    if line.startswith("GDT-TS-score"): o["gdt"] = parse_float(line)
  return o
def esmfold_api(sequence, max_retries=60, retry_delay=5, timeout=300):
  '''Fold `sequence` with the public ESM Atlas ESMFold endpoint.

  Args:
    sequence: amino-acid sequence, sent as the raw POST body.
    max_retries: number of retries after a non-200 response before giving
      up (previously the loop retried forever).
    retry_delay: seconds to sleep between attempts.
    timeout: per-request timeout in seconds passed to requests.post, so a
      stalled connection cannot hang the notebook indefinitely.

  Returns:
    The response body (the predicted structure in PDB format).

  Raises:
    requests.HTTPError: on a non-retryable 4xx response, or if a 4xx/5xx
      response persists after max_retries retries.
    RuntimeError: if a non-200, non-error status persists.
    requests.Timeout / requests.ConnectionError: network failures are not
      retried (same as before).
  '''
  esmfold_api_url = 'https://api.esmatlas.com/foldSequence/v1/pdb/'
  attempts = max(0, max_retries) + 1
  for attempt in range(attempts):
    r = requests.post(esmfold_api_url, data=sequence, timeout=timeout)
    if r.status_code == 200:
      return r.text
    # other 4xx responses (e.g. an invalid or too-long sequence) will not
    # succeed on retry; only rate limiting (429) and timeout (408) are retried
    if 400 <= r.status_code < 500 and r.status_code not in (408, 429):
      break
    if attempt + 1 < attempts:
      time.sleep(retry_delay)
  r.raise_for_status()
  raise RuntimeError(f"ESMFold API returned HTTP {r.status_code}")

In [None]:
#@title # hallucination
#@markdown For a given length, generate/hallucinate a protein sequence that AlphaFold thinks folds into a well structured protein (high plddt, low pae, many contacts).
LENGTH = 100 #@param {type:"integer"}
COPIES = 1 #@param ["1", "2", "3", "4", "5", "6", "7", "8"] {type:"raw"}
MODE = "manuscript" #@param ["original", "manuscript"]
use_rg_loss = True #@param {type:"boolean"}
use_mpnn_loss = False #@param {type:"boolean"}
# LENGTH:  passed to af_model.prep_inputs(length=...).
# COPIES:  number of chains; {type:"raw"} inserts the choice unquoted, so it
#          is an int (used in prep_inputs(copies=...), `COPIES>1`, and to
#          slice the chain-id list for ProteinMPNN).
# MODE:    "original"   = gumbel/softmax pre-design with plddt/pae weights at 0,
#                         and the rg loss is thresholded with elu;
#          "manuscript" = gumbel+soft sequence init, and the rg loss is the raw
#                         radius of gyration of every 5th CA atom.
# use_rg_loss / use_mpnn_loss: whether add_rg_loss / add_mpnn_loss are
#          attached to the model when it is built.


In [None]:
import jax
import jax.numpy as jnp
from colabdesign.af.alphafold.common import residue_constants

def add_mpnn_loss(self, weight=0.1):
  '''Add a ProteinMPNN confidence loss to an AfDesign model.

  The loss is the mean per-position entropy of ProteinMPNN's amino-acid
  distribution, scored on AlphaFold's predicted backbone. Minimizing it
  maximizes ProteinMPNN's confidence in the designed structure.

  Args:
    self: the AfDesign model (the function is used like a method).
    weight: loss weight, stored in self.opt["weights"]["mpnn"].
  '''
  self.mpnn = mk_mpnn_model()
  self.mpnn.get_af_inputs(self)
  # backbone atoms passed to ProteinMPNN, in N, CA, C, O order; these are
  # constants, so look them up once instead of on every loss trace
  atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"])
  ca_idx = residue_constants.atom_order["CA"]
  def loss_fn(inputs, outputs, key):
    xyz = outputs["structure_module"]
    I = {"X":           xyz["final_atom_positions"][:,atom_idx],
         "mask":        xyz["final_atom_mask"][:,ca_idx],
         "residue_idx": inputs["residue_index"],
         "chain_idx":   inputs["asym_id"],
         "key":         key}
    logits = self.mpnn._score(**I)["logits"]
    # keep the first 20 logits (presumably the 20 standard amino acids) and
    # normalize q and log(q) over that SAME set; taking log_softmax over all
    # logits before slicing would make log_q inconsistent with q
    logits = logits[:,:20]
    log_q = jax.nn.log_softmax(logits)
    q = jax.nn.softmax(logits)
    # entropy per position, averaged over positions
    loss = -(q * log_q).sum(-1).mean()
    return {"mpnn":loss}
  self._callbacks["model"]["loss"].append(loss_fn)
  self.opt["weights"]["mpnn"] = weight

def add_rg_loss(self, weight=0.1, mode=None):
  '''Add a radius-of-gyration (rg) loss that favors compact structures.

  Args:
    self: the AfDesign model (the function is used like a method).
    weight: loss weight, stored in self.opt["weights"]["rg"].
    mode: "original" or "manuscript"; defaults to the notebook-level MODE.
      The mode is resolved once, here, so the loss no longer silently depends
      on whatever the global MODE happens to be when JAX traces it.
        "manuscript": rg of every 5th CA atom, used directly as the loss.
        "original":   rg of all CA atoms, penalized only above the threshold
                      2.38 * N**0.365 (N = number of CA atoms) via elu.

  Raises:
    ValueError: if mode is not one of the two supported values.
  '''
  if mode is None:
    mode = MODE
  if mode not in ("original", "manuscript"):
    raise ValueError(f"unknown rg loss mode: {mode!r}")
  ca_idx = residue_constants.atom_order["CA"]
  def loss_fn(inputs, outputs):
    xyz = outputs["structure_module"]
    ca = xyz["final_atom_positions"][:,ca_idx]
    if mode == "manuscript":
      ca = ca[::5]
    # the 1e-8 keeps sqrt (and its gradient) finite if all CA atoms coincide
    rg = jnp.sqrt(jnp.square(ca - ca.mean(0)).sum(-1).mean() + 1e-8)
    if mode == "original":
      rg_th = 2.38 * ca.shape[0] ** 0.365
      rg = jax.nn.elu(rg - rg_th)
    return {"rg":rg}
  self._callbacks["model"]["loss"].append(loss_fn)
  self.opt["weights"]["rg"] = weight

In [None]:
# build a fresh hallucination model for the requested length / copy number
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination")
af_model.prep_inputs(length=LENGTH, copies=COPIES)

# attach the optional extra losses selected in the form
if use_rg_loss:
  add_rg_loss(af_model)
if use_mpnn_loss:
  add_mpnn_loss(af_model)

# report chain lengths and the active loss weights
for label, value in (("length", af_model._lengths),
                     ("weights", af_model.opt["weights"])):
  print(label, value)

In [None]:
# restart the design from scratch before choosing an initialization
af_model.restart()
if MODE == "original":
  # pre-design: 50 soft steps from a gumbel-initialized sequence with the
  # plddt/pae weights zeroed, then seed from the resulting pseudo-sequence
  af_model.set_weights(plddt=0.0, pae=0.0)
  af_model.set_seq(mode=["gumbel"])
  af_model.design_soft(50)
  af_model.set_seq(af_model.aux["seq"]["pseudo"])
elif MODE == "manuscript":
  # gumbel + soft initialization, no pre-design stage
  af_model.set_seq(mode=["gumbel","soft"])

# main stage: plddt/pae weights on, 40 logits steps, then 10 with save_best
af_model.set_weights(plddt=1.0, pae=1.0)
af_model.design_logits(40)
af_model.design_logits(10, save_best=True)

In [None]:
# write the design to disk; the ProteinMPNN and TMscore cells read this file
design_pdb = f"{af_model.protocol}.pdb"
af_model.save_pdb(design_pdb)
af_model.plot_pdb()

In [None]:
# render the model's animation output as HTML
animation_html = af_model.animate()
HTML(animation_html)

In [None]:
# display the designed sequence(s)
designed_seqs = af_model.get_seqs()
designed_seqs

# Redesign with ProteinMPNN

In [None]:
from colabdesign.shared.protein import alphabet_list as chain_list
mpnn_model = mk_mpnn_model()
# redesign the hallucinated backbone saved earlier: one chain id per copy,
# homooligomer=True when there is more than one copy, and rm_aa="C" to
# exclude cysteine. NOTE: the keyword is spelled "homooligomer"; a misspelled
# keyword may not be rejected by prep_inputs and would silently drop tying.
mpnn_model.prep_inputs(pdb_filename=f"{af_model.protocol}.pdb",
                       chain=",".join(chain_list[:COPIES]),
                       homooligomer=COPIES>1,
                       rm_aa="C")
out = mpnn_model.sample(num=1, batch=8)

In [None]:
# one line per ProteinMPNN sample: score, then the first "/"-separated chain
for mpnn_seq, mpnn_score in zip(out["seq"], out["score"]):
  first_chain = mpnn_seq.split("/")[0]
  print(mpnn_score, first_chain)

# ESMFold test

In [None]:
# Refold each ProteinMPNN sequence (first chain only) with ESMFold, compare it
# to the AF design with TMscore, and keep the lowest-RMSD result in `best`.
print("# rmsd tmscore sequence")
best = {}
best_rmsd = None
for n, seq in enumerate(out["seq"]):
  x = seq.split("/")[0]
  esm_pdb = f"{af_model.protocol}.esmfold.{n}.pdb"
  # call the API before opening the file so a failed request does not leave
  # an empty PDB behind
  pdb_str = esmfold_api(x)
  with open(esm_pdb, "w") as handle:
    handle.write(pdb_str)
  o = tmscore(f"{af_model.protocol}.pdb", esm_pdb)
  if "rms" not in o or "tms" not in o:
    raise RuntimeError(f"TMscore reported no RMSD/TM-score for {esm_pdb}: {o}")
  print(n, o["rms"], o["tms"], x)
  if best_rmsd is None or o["rms"] < best_rmsd:
    best_rmsd = o["rms"]
    best = {**o, "seq": x}

In [None]:
# lowest-RMSD ESMFold result: TMscore metrics ("rms", "tms", "gdt" when
# reported) plus the first-chain sequence under "seq"
best