<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.0.9/mpnn/examples/afdesign_and_proteinmpnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign + ProteinMPNN (v1.0.9)
Backprop through AlphaFold for protein design.

**WARNING**
1.   This notebook is in active development and was designed for demonstration purposes only.
2.   Using AfDesign as the only "loss" function for design might be a bad idea: you may find adversarial sequences (i.e., sequences that trick AlphaFold). To avoid this problem, we couple it with ProteinMPNN.

In [None]:
#@title install
%%bash
# One-time setup: install ColabDesign and fetch the AlphaFold parameters.
# The whole step is skipped when a "params" directory already exists.
if [ ! -d params ]; then
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.0.9
  # for debugging: symlink the installed package into the working directory.
  # The location is queried from the interpreter (run from / so that a local
  # "colabdesign" entry cannot shadow the installed package) instead of
  # hard-coding a python3.7 dist-packages path, which dangles on other versions.
  pkg_dir=$(cd / && python3 -c "import importlib.util as u; print(list(u.find_spec('colabdesign').submodule_search_locations)[0])")
  if [ -n "$pkg_dir" ]; then
    ln -sfn "$pkg_dir" colabdesign
  else
    echo "warning: could not locate the installed colabdesign package; skipping symlink" >&2
  fi
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params
fi

In [None]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from colabdesign.af import mk_af_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.shared.utils import copy_dict
from IPython.display import HTML
from google.colab import files

import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, uploading or downloading it first.

  Args:
    pdb_code: PDB identifier to download from files.rcsb.org. If it is None
      or "", the Colab upload dialog is opened instead.

  Returns:
    "tmp.pdb" when a file was uploaded, otherwise f"{pdb_code}.pdb" (the name
    wget gives the downloaded file in the current directory).

  Raises:
    ValueError: if the upload dialog returned no file.
  """
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      # e.g. the dialog was cancelled; fail with a clear message rather than
      # an IndexError on the empty dict
      raise ValueError("no PDB file was uploaded")
    # keep only the first uploaded file
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb","wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  # Call wget with an argument list (no shell), so that pdb_code can never be
  # interpreted as shell syntax. -q: quiet, -nc: keep an existing local copy.
  import subprocess
  url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
  subprocess.run(["wget", "-qnc", url], check=False)
  return f"{pdb_code}.pdb"

def setup_mpnn(self, precompute=True, entropy=False, backprop=False):
  """Attach a ProteinMPNN-based loss term ("mpnn_loss") to a design model.

  Args:
    self: design model (e.g. the object returned by mk_af_model) on which
      prep_inputs() has already been called. Its `opt`, `_inputs` and
      `_callbacks` are modified in place.
    precompute: if True, ProteinMPNN is run once on the backbone stored in
      `self._inputs["batch"]` and the logits are kept in `self.opt["mpnn"]`.
      If False, ProteinMPNN is run inside the loss on the atom positions
      found in `aux`, the logits are stored in `aux["mpnn"]`, and a "post"
      design callback blends them into `self._inputs["bias"]`.
    entropy: if True, the loss is the entropy of the ProteinMPNN
      distribution; otherwise it is the cross-entropy between the current
      soft sequence and the ProteinMPNN distribution.
    backprop: only used when precompute is False. If False, gradients are
      stopped at the ProteinMPNN logits.

  Side effects:
    Appends the loss callback to `self._callbacks["model"]["loss"]` and sets
    `self.opt["weights"]["mpnn_loss"]` to 1.0.
  """
  # backbone atoms handed to ProteinMPNN
  mpnn_atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"])

  # build the ProteinMPNN model once and share it between the precompute step
  # and the loss callback, instead of re-creating it on every call
  mpnn_model = mk_mpnn_model()

  def design_logits(feats):
    """Run ProteinMPNN on `feats` and keep the logits of the designed positions."""
    if self.protocol == "binder":
      # the binder residues follow the target residues
      L = self._target_len
      return mpnn_model.get_logits(**feats)[0,L:]
    L = self._params["seq"].shape[1]
    return mpnn_model.get_logits(**feats)[0,:L]

  def loss_callback(inputs, aux, opt, key):
    if precompute:
      logits = opt["mpnn"]
    else:
      # atom index 1 is used as the per-residue mask (presumably CA - TODO confirm)
      feats = {"X":           aux["atom_positions"][None,:,mpnn_atom_idx],
               "mask":        aux["atom_mask"][None,:,1],
               "residue_idx": inputs["residue_index"][None],
               "chain_idx":   inputs["asym_id"][None],
               "key":         key}
      logits = design_logits(feats)
      if not backprop:
        logits = jax.lax.stop_gradient(logits)
      # exposed so that design_callback can fold the logits into the bias
      aux["mpnn"] = logits

    # define loss function
    log_q = jax.nn.log_softmax(logits)
    if entropy:
      # entropy of the mpnn output; minimizing this loss lowers the entropy
      # (aka increases the confidence of mpnn)
      q = jax.nn.softmax(logits)
      mpnn_loss = -(q * log_q).sum(-1).mean()
    else:
      # cross-entropy between the current sequence and the mpnn output;
      # minimizing this loss increases the similarity to the mpnn output
      p = inputs["seq"]["soft"]
      mpnn_loss = -(p * log_q).sum(-1).mean()

    return {"mpnn_loss":mpnn_loss}

  if precompute:
    inputs = self._inputs
    feats = {"X":           inputs["batch"]["all_atom_positions"][None,:,mpnn_atom_idx],
             "mask":        inputs["batch"]["all_atom_mask"][None,:,1],
             "residue_idx": inputs["residue_index"][None],
             "chain_idx":   inputs["asym_id"][None],
             "key":         self.key()}
    self.opt["mpnn"] = np.asarray(design_logits(feats))

  else:
    def design_callback(self):
      # exponential moving average of the mpnn logits, used as sequence bias
      self._inputs["bias"] = 0.99 * self._inputs["bias"] + 0.01 * self.aux["mpnn"]
    self._callbacks["design"]["post"].append(design_callback)

  self._callbacks["model"]["loss"].append(loss_callback)
  self.opt["weights"]["mpnn_loss"] = 1.0

# fixed backbone design (fixbb)
For a given protein backbone, generate/design a new sequence that AlphaFold thinks folds into that conformation. 

In [None]:
# Start from a clean slate before building a new model
# (clear_mem is imported from colabdesign.af; presumably frees memory held by earlier models).
clear_mem()
# NOTE(review): mpnn_model is not used by any later cell shown in this notebook;
# setup_mpnn() creates its own ProteinMPNN model internally.
mpnn_model = mk_mpnn_model()
af_model = mk_af_model(protocol="fixbb")
# get_pdb("1TEN") downloads 1TEN.pdb from files.rcsb.org; chain A is passed as the input chain.
af_model.prep_inputs(pdb_filename=get_pdb("1TEN"), chain="A")
# precompute=True: ProteinMPNN is run once on the input backbone and the logits
# are stored in af_model.opt["mpnn"].
setup_mpnn(af_model, precompute=True)

print("length",  af_model._len)
print("weights", af_model.opt["weights"])

In [None]:
# unconditional per-position probabilities from the precomputed mpnn logits
mpnn_probs = softmax(af_model.opt["mpnn"], -1)
# negative mean (over positions) of the log-probability of the most likely residue
print("max_mpnn_loss", -np.log(mpnn_probs).max(-1).mean())
# show the probabilities as an image, transposed, on a fixed 0..1 color scale
plt.imshow(mpnn_probs.T, vmin=0, vmax=1)

In [None]:
# Restart the design run and use the precomputed ProteinMPNN logits as sequence bias.
af_model.restart()
af_model.set_seq(bias=af_model.opt["mpnn"])
# setup_mpnn() set the mpnn_loss weight to 1.0; lower it to 0.1 for this run.
af_model.set_weights(mpnn_loss=0.1)
# NOTE(review): the three arguments are presumably the iteration counts of the
# three design stages - confirm against the design_3stage documentation.
af_model.design_3stage(0,200,10)

In [None]:
# Plot the trajectory of the design run above (presumably losses/metrics per step - TODO confirm).
af_model.plot_traj()

In [None]:
# Save the result under the protocol name ("fixbb.pdb" for this model), then plot it.
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# Wrap the value returned by animate() in IPython's HTML so the notebook renders it.
HTML(af_model.animate())

In [None]:
# Last expression of the cell: the notebook echoes whatever get_seqs() returns.
af_model.get_seqs()

# hallucination
For a given length, generate/hallucinate a protein sequence that AlphaFold thinks folds into a well-structured protein (high pLDDT, low PAE, many contacts).

In [None]:
# Start from a clean slate before building a new model
# (clear_mem is imported from colabdesign.af; presumably frees memory held by earlier models).
clear_mem()
af_model = mk_af_model(protocol="hallucination")
# No input structure here: only the length of the protein to design is given.
af_model.prep_inputs(length=100)
# precompute=False: ProteinMPNN is evaluated inside the loss on the atom positions
# found in `aux`, and the resulting logits are blended into the sequence bias.
setup_mpnn(af_model, precompute=False)

print("length",af_model._len)
print("weights",af_model.opt["weights"])

In [None]:
# pre-design with gumbel initialization and softmax activation
af_model.restart()
af_model.set_seq(mode="gumbel")
af_model.set_weights(mpnn_loss=0.01)
af_model.design_soft(100)

In [None]:
# let's see what the PDB looks like (if you don't like it, rerun the previous cell)
af_model.plot_pdb()

In [None]:
# Raise the mpnn_loss weight from the 0.01 used in the pre-design cell to 0.1.
af_model.set_weights(mpnn_loss=0.1)
# No restart() here, so this continues from the state left by the pre-design cell.
# NOTE(review): the three arguments are presumably the iteration counts of the
# three design stages - confirm against the design_3stage documentation.
af_model.design_3stage(0, 100, 10)

In [None]:
# Save the result under the protocol name ("hallucination.pdb" for this model), then plot it.
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# Wrap the value returned by animate() in IPython's HTML so the notebook renders it.
HTML(af_model.animate())

In [None]:
# Last expression of the cell: the notebook echoes whatever get_seqs() returns.
af_model.get_seqs()

# binder hallucination
For a given protein target and protein binder length, generate/hallucinate a protein binder sequence that AlphaFold thinks will bind to the target structure.
To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder.

In [None]:
# Start from a clean slate before building a new model
# (clear_mem is imported from colabdesign.af; presumably frees memory held by earlier models).
clear_mem()
# NOTE(review): mpnn_model is not used by any later cell shown in this notebook;
# setup_mpnn() creates its own ProteinMPNN model internally.
mpnn_model = mk_mpnn_model()
af_model = mk_af_model(protocol="binder")
# get_pdb("4MZK") downloads 4MZK.pdb from files.rcsb.org; chain A is passed as the
# input chain and a binder of 18 residues is requested.
af_model.prep_inputs(pdb_filename=get_pdb("4MZK"), chain="A", binder_len=18)
# precompute=False: ProteinMPNN is evaluated inside the loss; for the "binder"
# protocol setup_mpnn() keeps only the logits after the first _target_len positions.
setup_mpnn(af_model, precompute=False)

print("target_length",af_model._target_len)
print("binder_length",af_model._binder_len)
print("weights",af_model.opt["weights"])

In [None]:
# Two design_3stage calls after a single restart(): the first with a small
# mpnn_loss weight (0.01), the second with the weight raised to 0.1.
# NOTE(review): the three arguments of design_3stage are presumably the iteration
# counts of its three stages - confirm against the design_3stage documentation.
af_model.restart()
af_model.set_weights(mpnn_loss=0.01)
af_model.design_3stage(100,0,0)
af_model.set_weights(mpnn_loss=0.1)
af_model.design_3stage(0,100,10)

In [None]:
# Save the result under the protocol name ("binder.pdb" for this model), then plot it.
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# Wrap the value returned by animate() in IPython's HTML so the notebook renders it.
HTML(af_model.animate())

In [None]:
# Last expression of the cell: the notebook echoes whatever get_seqs() returns.
af_model.get_seqs()