<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.0/mpnn/examples/afdesign_and_proteinmpnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign + ProteinMPNN (v1.1.0)
Backprop through AlphaFold for protein design.

**WARNING**
1.   This notebook is in active development and was designed for demonstration purposes only.
2.   Using AfDesign as the only "loss" function for design might be a bad idea: you may find adversarial sequences (i.e., sequences that trick AlphaFold). To avoid this problem, we couple it with ProteinMPNN.

In [None]:
#@title install
%%bash
if [ ! -d params ]; then
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0
  # for debugging: link the installed package into the working directory.
  # Ask the interpreter where the package lives instead of hard-coding a
  # python3.7 site-packages path, which breaks when the runtime's Python changes.
  ln -s "$(python -c 'import colabdesign, os; print(os.path.dirname(colabdesign.__file__))')" colabdesign
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params
fi

In [None]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from colabdesign.af import mk_af_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.shared.utils import copy_dict
from IPython.display import HTML
from google.colab import files

import numpy as np
from scipy.special import softmax, log_softmax
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

def get_pdb(pdb_code=""):
  """Return the filename of a local PDB file to design against.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "1TEN"). If empty or None, the user
      is asked to upload a file through the Colab upload dialog instead.

  Returns:
    "<pdb_code>.pdb", downloaded from files.rcsb.org unless a file with that
    name already exists in the working directory; or "tmp.pdb" holding the
    first uploaded file.

  Raises:
    ValueError: if the upload dialog returns no file.
    urllib.error.URLError: if the download fails (e.g. unknown PDB code).
  """
  if not pdb_code:
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no PDB file was uploaded")
    # only the first uploaded file is used; it is always saved as tmp.pdb
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  pdb_filename = f"{pdb_code}.pdb"
  # Reuse an existing file (no-clobber). Downloading with urllib avoids
  # interpolating user input into a shell command, and it raises on HTTP
  # errors instead of silently leaving no file behind.
  if not os.path.isfile(pdb_filename):
    import urllib.request
    urllib.request.urlretrieve(f"https://files.rcsb.org/view/{pdb_code}.pdb", pdb_filename)
  return pdb_filename

class setup_mpnn:
  """Attach a ProteinMPNN-based loss to a ColabDesign AlphaFold model.

  Construction registers `_loss_callback` on `af_model`. That callback adds
  two terms to the AF loss dict, whose weights are initialised here in
  `af_model.opt["weights"]`:

    * "mpnn_loss" (weight 1.0): cross-entropy between the designed sequence
      `inputs["seq"]["soft"]` and MPNN's amino-acid distribution.
    * "mpnn_ent" (weight 0.0): entropy of MPNN's amino-acid distribution.

  Args:
    af_model: model from `mk_af_model`. With precompute=True its `_inputs`
      are read immediately, so call `prep_inputs` first.
    precompute: behaviour depends on the flag.
      - True: score the input structure with MPNN once. The logits are
        stored in `af_model.opt["mpnn"]` and `self.logits`.
      - False: MPNN is rerun on the predicted structure inside each loss
        evaluation. `_design_callback` is registered to blend the latest
        MPNN logits into the AF sequence bias.
    conditional: if True, MPNN is also conditioned on the current sequence.
    backprop: if False, gradients are stopped at the MPNN logits computed in
      the loss callback.
    replace: blending rate used by `_design_callback`.
  """

  def __init__(self, af_model, precompute=True, conditional=False, backprop=False, replace=0.01):
    self.af = af_model
    self.mpnn = mk_mpnn_model()
    # atom37 indices of the N, CA, C, O backbone atoms passed to MPNN as "X"
    self.atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"])
    self.replace = replace
    self.conditional = conditional
    self.backprop = backprop

    if precompute:
      self.precompute()
    else:
      # logits change as the structure changes, so feed them back into the
      # bias via a "design"/"post" callback (presumably run after each step)
      self.af._callbacks["design"]["post"].append(self._design_callback)

    self.af._callbacks["model"]["loss"].append(self._loss_callback)
    self.af.opt["weights"]["mpnn_loss"] = 1.0
    self.af.opt["weights"]["mpnn_ent"] = 0.0

  def run(self, seq, atom_positions, atom_mask, residue_index, asym_id, key):
    """Score a structure with ProteinMPNN and return logits for the designed positions.

    Args:
      seq: current sequence. Only used when `self.conditional` is True, in
        which case seq[0] is passed to MPNN. This assumes a leading batch
        axis (TODO confirm for every caller).
      atom_positions: per-residue atom coordinates. Only the N, CA, C, O
        atoms (self.atom_idx) are passed to MPNN.
      atom_mask: per-residue atom mask. Column 1 (presumably CA) is used as
        the residue mask.
      residue_index: residue numbering forwarded to MPNN.
      asym_id: chain ids forwarded to MPNN.
      key: random key forwarded to MPNN.

    Returns:
      MPNN logits restricted to the designed positions. For the "binder"
      protocol these are all positions after the first `_target_len`;
      otherwise they are the first `_params["seq"].shape[1]` positions.
    """
    # INPUTS
    I = {"X":           atom_positions[:,self.atom_idx],
         "mask":        atom_mask[:,1],
         "residue_idx": residue_index,
         "chain_idx":   asym_id,
         "key":         key}
    if self.conditional:
      I["S"] = seq[0]
      # 1 everywhere except the diagonal (presumably: each position sees the
      # identity of every other position but not its own)
      I["ar_mask"] = 1 - np.eye(I["S"].shape[0])
      if self.af.protocol == "binder":
        # zero the binder-binder block (presumably: binder positions are
        # conditioned on the target sequence only, not on each other)
        L = self.af._target_len
        I["ar_mask"][L:,L:] = 0

    # RUN
    logits = self.mpnn._score(**I)["logits"]

    # OUTPUTS
    if self.af.protocol == "binder":
      L = self.af._target_len
      logits = logits[L:]
    else:
      L = self.af._params["seq"].shape[1]
      logits = logits[:L]
    return logits

  def precompute(self):
    """Score the input structure once and cache the logits in af.opt["mpnn"] and self.logits."""
    inputs = self.af._inputs
    # NOTE(review): `aatype` is passed as `seq`. With conditional=True, run()
    # uses seq[0], which assumes a leading batch axis. Verify the shape of
    # aatype before combining precompute=True with conditional=True (not
    # exercised in this notebook).
    logits = self.run(inputs["batch"]["aatype"],
                      inputs["batch"]["all_atom_positions"],
                      inputs["batch"]["all_atom_mask"],
                      inputs["residue_index"],
                      inputs["asym_id"],
                      self.af.key())
    self.af.opt["mpnn"] = self.logits = logits

  def _design_callback(self, af_model):
    """Blend the latest MPNN logits (first 20 columns) into the AF sequence bias.

    Reads af_model.aux["mpnn"], which `_loss_callback` writes into the aux
    dict when logits are not precomputed. The update is
    bias <- (1 - replace) * bias + replace * logits[:, :20].
    """
    self.logits = af_model.aux["mpnn"]
    af_model._inputs["bias"] = (1-self.replace) * af_model._inputs["bias"] + self.replace * af_model.aux["mpnn"][:,:20]

  def _loss_callback(self, inputs, aux, opt, seq, key):
    """AF loss callback: return {"mpnn_ent", "mpnn_loss"} for the current design.

    Uses the precomputed logits in opt["mpnn"] when present. Otherwise it
    runs MPNN on the predicted structure (aux["atom_positions"],
    aux["atom_mask"]), conditioned on seq["hard"] if `conditional`, and
    stores the logits in aux["mpnn"].
    """
    if "mpnn" in opt:
      logits = opt["mpnn"]
    else:
      logits = self.run(seq["hard"],
                        aux["atom_positions"],
                        aux["atom_mask"],
                        inputs["residue_index"],
                        inputs["asym_id"],
                        key)
      if not self.backprop:
        logits = jax.lax.stop_gradient(logits)

      aux["mpnn"] = logits

    # Normalise over the same 20 amino-acid columns for both q and log_q, so
    # that they describe one distribution and mpnn_ent is a true entropy.
    # MPNN's logits may carry extra columns beyond these 20 (presumably
    # X/unknown); they are dropped before normalising. The designed
    # sequence p has no mass on them.
    aa_logits = logits[:,:20]
    log_q = jax.nn.log_softmax(aa_logits)
    q = jax.nn.softmax(aa_logits)
    p = inputs["seq"]["soft"]
    losses = {}
    # mpnn_ent depends only on the logits. When they are stop_gradient-ed
    # (backprop=False) or taken from opt, it is reported but presumably does
    # not steer the design.
    losses["mpnn_ent"] = -(q * log_q).sum(-1).mean()
    losses["mpnn_loss"] = -(p * log_q).sum(-1).mean()
    return losses

# fixed backbone design (fixbb)
For a given protein backbone, generate/design a new sequence that AlphaFold thinks folds into that conformation. 

In [None]:
# fixed-backbone design of chain A of PDB 1TEN
# clear_mem() before building a new model (presumably frees memory held by earlier models)
clear_mem()
af_model = mk_af_model(protocol="fixbb")
af_model.prep_inputs(pdb_filename=get_pdb("1TEN"), chain="A")
# the target backbone is known, so MPNN logits can be computed once from it
mpnn_model = setup_mpnn(af_model, precompute=True)

print("length",  af_model._len)
print("weights", af_model.opt["weights"])

In [None]:
# inspect the unconditional MPNN logits precomputed by setup_mpnn above
# NOTE(review): -max(log_softmax) averaged over positions is (roughly) the
# lowest mpnn_loss achievable by taking MPNN's top residue at every position,
# despite the "max" in the label. Here the max also ranges over any classes
# beyond the first 20.
print("max_mpnn_loss",-log_softmax(mpnn_model.logits,-1).max(-1).mean())
# heatmap of per-position class probabilities (rows: classes, columns: positions)
plt.imshow(softmax(mpnn_model.logits,-1).T,vmin=0,vmax=1)

In [None]:
# start from MPNN's preferences: use the precomputed logits (first 20 columns) as the sequence bias
af_model.restart()
af_model.set_seq(bias=mpnn_model.logits[:,:20])
af_model.set_weights(mpnn_loss=0.1)
# NOTE(review): arguments presumably (soft_iters, temp_iters, hard_iters) — confirm against the ColabDesign API
af_model.design_3stage(0,200,10)

In [None]:
# plot the trajectory recorded during design
af_model.plot_traj()

In [None]:
# save the design to "<protocol>.pdb" (here fixbb.pdb) and display it
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# render the design trajectory animation inline
HTML(af_model.animate())

In [None]:
# show the designed sequence(s)
af_model.get_seqs()

# hallucination
For a given length, generate/hallucinate a protein sequence that AlphaFold thinks folds into a well structured protein (high plddt, low pae, many contacts).

In [None]:
def rg_loss(inputs, outputs):
  """Radius-of-gyration penalty on the predicted CA atoms.

  Intended as a `loss_callback`: `inputs` matches the callback signature
  and is unused. Rg is computed from the CA coordinates in
  outputs["structure_module"]["final_atom_positions"] and compared with the
  length-dependent threshold 2.38 * N**0.365, where N = ca.shape[0]
  (presumably the residue count). The excess goes through ELU, so Rg above
  the threshold is penalised roughly linearly. Rg below it gives a small
  negative value bounded by -1.

  Returns:
    {"rg": scalar penalty}
  """
  ca_idx = residue_constants.atom_order["CA"]
  ca_xyz = outputs["structure_module"]["final_atom_positions"][:, ca_idx]
  n_res = ca_xyz.shape[0]
  centroid = ca_xyz.mean(axis=0)
  mean_sq_dev = jnp.square(ca_xyz - centroid).sum(axis=-1).mean()
  radius = jnp.sqrt(mean_sq_dev + 1e-8)  # epsilon keeps sqrt (and its gradient) finite at 0
  threshold = 2.38 * n_res ** 0.365
  return {"rg": jax.nn.elu(radius - threshold)}

In [None]:
clear_mem()
af_model = mk_af_model(protocol="hallucination",
                       loss_callback=rg_loss) # add custom Radius of Gyration loss
af_model.prep_inputs(length=100)
# weights are keyed by loss name (as setup_mpnn does for "mpnn_loss"/"mpnn_ent");
# rg_loss returns {"rg": ...}, so its weight lives under "rg"
af_model.opt["weights"]["rg"] = 0.1
mpnn_model = setup_mpnn(af_model,
                        precompute=False) # since we do not know what structure we want, we cannot precompute the mpnn logits
mpnn_model.replace = 0.01 # rate at which to copy output mpnn logits to alphafold bias

print("length",af_model._len)
print("weights",af_model.opt["weights"])

In [None]:
# pre-design with gumbel initialization and softmax activation
# NOTE(review): setup_mpnn above uses its default backprop=False, so the MPNN
# logits are stop_gradient-ed. mpnn_ent depends only on those logits, so it
# likely only reports a value here and does not steer the design.
af_model.restart()
af_model.set_seq(mode="gumbel")
af_model.set_weights(mpnn_ent=0.1,   # maximize confidence of mpnn output
                     mpnn_loss=0.01, # minimize difference between mpnn output and input sequence
                     helix=-0.1,     # encourage non-helical content
                     ) 
af_model.design_soft(100, verbose=10)

In [None]:
# let's see what the structure looks like (if you don't like it, rerun the previous cell)
af_model.plot_pdb()

In [None]:
# refinement round!
# continue from the sequence reached in pre-design (aux["seq"]["pseudo"])
af_model.set_seq(seq=af_model.aux["seq"]["pseudo"])
af_model.set_weights(mpnn_ent=0.1, mpnn_loss=0.1, helix=0, pae=0.1) # increase mpnn weights
af_model.design_3stage(100, 100, 10)

In [None]:
# save the design to "<protocol>.pdb" (here hallucination.pdb) and display it colored by pLDDT
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb(color="pLDDT")

In [None]:
# render the design trajectory animation inline
HTML(af_model.animate())

In [None]:
# show the designed sequence(s)
af_model.get_seqs()

# binder hallucination
For a given protein target and protein binder length, generate/hallucinate a protein binder sequence that AlphaFold thinks will bind to the target structure.
To do this, we minimize PAE, maximize the number of contacts at the interface and within the binder, and maximize the pLDDT of the binder.

In [None]:
clear_mem()
af_model = mk_af_model(protocol="binder")
af_model.prep_inputs(pdb_filename=get_pdb("4MZK"), chain="A", binder_len=18)
# Keep a handle on the wrapper, as in the other sections. setup_mpnn creates
# its own MPNN model, so a separate mk_mpnn_model() here would only load an
# unused second copy right after clear_mem().
mpnn_model = setup_mpnn(af_model,
                        precompute=False, # the binder structure is not known in advance
                        conditional=True) # conditioned both on the structure and sequence of target

print("target_length",af_model._target_len)
print("binder_length",af_model._binder_len)
print("weights",af_model.opt["weights"])

In [None]:
# two phases: 100 iterations with small MPNN weights, then stronger MPNN weights for the rest
# NOTE(review): design_3stage arguments presumably (soft_iters, temp_iters, hard_iters) — confirm
af_model.restart()
af_model.set_weights(mpnn_loss=0.01, mpnn_ent=0.01)
af_model.design_3stage(100,0,0)
af_model.set_weights(mpnn_loss=0.1, mpnn_ent=0.1)
af_model.design_3stage(0,100,10)

In [None]:
# save the design to "<protocol>.pdb" (here binder.pdb) and display it
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# render the design trajectory animation inline
HTML(af_model.animate())

In [None]:
# show the designed binder sequence(s)
af_model.get_seqs()