<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.0.9/mpnn/examples/afdesign_and_proteinmpnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign + ProteinMPNN (v1.0.9)
Backprop through AlphaFold for protein design.

**WARNING**
1.   This notebook is in active development and was designed for demonstration purposes only.
2.   Using AfDesign as the only "loss" function for design might be a bad idea: you may find adversarial sequences (i.e., sequences that trick AlphaFold). To avoid this problem, we couple it with ProteinMPNN.

In [None]:
#@title install
%%bash
# One-time setup: everything below is skipped once a `params` directory
# exists, so re-running the cell does not reinstall or re-download anything.
if [ ! -d params ]; then
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.0.9
  # for debugging: expose the installed package as ./colabdesign.
  # The install location is looked up at run time (without importing the
  # package) instead of assuming a python3.7 dist-packages directory, so the
  # link stays valid whatever Python version the runtime ships.
  pkg_dir="$(python3 -c 'import importlib.util as u; print(u.find_spec("colabdesign").submodule_search_locations[0])')"
  ln -s "${pkg_dir}" colabdesign
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params
fi

In [None]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign.af import mk_af_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.shared.utils import copy_dict
from IPython.display import HTML
from google.colab import files
import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt

#########################
def get_pdb(pdb_code=""):
  '''Return the path of a local PDB file.

  With a non-empty `pdb_code`, "<pdb_code>.pdb" is fetched from files.rcsb.org
  with `wget -qnc` and that filename is returned (the wget exit status is not
  checked). With None or "", a Colab upload dialog is opened and the first
  uploaded file is written to, and returned as, "tmp.pdb".
  '''
  if pdb_code is not None and pdb_code != "":
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

  # no code given: ask the user for a file and keep only the first upload
  uploaded = files.upload()
  payload = list(uploaded.values())[0]
  with open("tmp.pdb","wb") as handle:
    handle.write(payload)
  return "tmp.pdb"

from colabdesign.af.alphafold.common import residue_constants
def af2mpnn(self, use_seq=False, use_aux=False, get_best=True):
  '''Build the keyword inputs for `mpnn_model.get_logits` from an AlphaFold design model.

  Args:
    self: the AlphaFold design model (reads `_inputs`, `aux`, `_tmp`, `key()`).
    use_seq: also include the amino-acid types under "S".
    use_aux: take the features from the model's aux output instead of the
      prepared inputs.
    get_best: with `use_aux`, prefer `self._tmp["best"]["aux"]` when it exists,
      otherwise fall back to `self.aux`.

  Returns:
    dict with "X", "mask", "residue_idx", optionally "S", plus "chain_idx"
    and "key"; every array gets a leading axis of size 1.
  '''
  atom_idx = tuple(residue_constants.atom_order[name] for name in ("N","CA","C","O"))

  if use_aux:
    if get_best and "aux" in self._tmp["best"]:
      source = self._tmp["best"]["aux"]
    else:
      source = self.aux
    coords = source["atom_positions"]
    atom_mask = source["atom_mask"]
    res_idx = source["residue_index"]
    seq_source = source
  else:
    # work on a copy of the prepared inputs
    snapshot = copy_dict(self._inputs)
    seq_source = snapshot["batch"]
    coords = seq_source["all_atom_positions"]
    atom_mask = seq_source["all_atom_mask"]
    res_idx = snapshot["residue_index"]

  features = {
    "X": coords[None,:,atom_idx],
    # column 1 of the atom mask (presumably CA - confirm against atom_order)
    "mask": atom_mask[None,:,1],
    "residue_idx": res_idx[None],
  }
  if use_seq:
    features["S"] = seq_source["aatype"][None]

  # chain ids always come from the prepared inputs; key() is called once per invocation
  features["chain_idx"] = self._inputs["asym_id"][None]
  features["key"] = self.key()
  return features

def mpnn_callback(mpnn_model=None, update_bias=False):
  '''Create a design callback that scores the current design with ProteinMPNN.

  Args:
    mpnn_model: object exposing `get_logits(**inputs)`; when None a new model
      is created with `mk_mpnn_model()`.
    update_bias: when True the callback also blends fresh MPNN logits into
      `self.opt["bias"]` (0.9 * old + 0.1 * new) on every call.

  Returns:
    `get_mpnn_scores(self)`, which stores the score in
    `self.aux["losses"]["mpnn"]` and returns None.
  '''
  if mpnn_model is None:
    mpnn_model = mk_mpnn_model()

  def get_mpnn_scores(self):
    # compute mpnn score: mean over positions of the cross-entropy between the
    # one-hot designed sequence and the MPNN softmax distribution
    mpnn_inputs = af2mpnn(self, use_aux=True, use_seq=True)
    seq_logits = np.array(mpnn_model.get_logits(**mpnn_inputs))[0]
    seq = np.eye(20)[self.aux["aatype"]]
    self.aux["losses"]["mpnn"] = -(seq * np.log(softmax(seq_logits,-1))).sum(-1).mean()

    # update bias
    if update_bias:
      if self.protocol == "binder":
        # use the model passed to the callback, not a notebook-level global
        mpnn_inputs = af2mpnn(self, use_aux=True, use_seq=True)

        # get unconditional logits for binder
        # copy before masking: af2mpnn builds "S" as aux["aatype"][None], so an
        # in-place write could reach back into the model's own aux array
        mpnn_inputs["S"] = np.array(mpnn_inputs["S"])
        mpnn_inputs["S"][:,self._target_len:] = -1
        mpnn_inputs["decoding_order"] = np.append(np.arange(self._target_len),
                                                  np.full(self._binder_len,-1))[None]

        # keep only the binder positions
        seq_logits = np.array(mpnn_model.get_logits(**mpnn_inputs))[0,self._target_len:]

      else:
        # no sequence is passed, so the logits are not conditioned on the design
        mpnn_inputs = af2mpnn(self, use_aux=True, use_seq=False)
        seq_logits = np.array(mpnn_model.get_logits(**mpnn_inputs))[0]

      # exponential moving average of the bias
      self.opt["bias"] = 0.9 * self.opt["bias"] + 0.1 * seq_logits
  return get_mpnn_scores

# fixed backbone design (fixbb)
For a given protein backbone, generate/design a new sequence that AlphaFold thinks folds into that conformation. 

In [None]:
# Set up the fixed-backbone run: build the MPNN model and an AlphaFold model
# with the "fixbb" protocol, then prepare inputs from chain A of 1TEN.pdb
# (get_pdb fetches the file from files.rcsb.org).
clear_mem()
mpnn_model = mk_mpnn_model()
af_model = mk_af_model(protocol="fixbb")
af_model.prep_inputs(pdb_filename=get_pdb("1TEN"), chain="A")

# report the prepared length and the model's current "weights" options
print("length",  af_model._len)
print("weights", af_model.opt["weights"])

In [None]:
# get unconditional probabilities from mpnn
# (af2mpnn defaults: features come from the prepared inputs, no sequence is passed)
mpnn_inputs = af2mpnn(af_model)
seq_logits = np.array(mpnn_model.get_logits(**mpnn_inputs))[0]
# heat map of the per-position softmax, transposed before plotting
plt.imshow(softmax(seq_logits,-1).T,vmin=0,vmax=1)

In [None]:
# Restart the model, switch on the seq_bias weight, use the MPNN logits as the
# sequence bias, then run the three-stage design.
af_model.restart()
af_model.set_weights(seq_bias=1.0)
af_model.set_seq(bias=seq_logits)
# NOTE(review): the three numbers are presumably per-stage iteration counts - confirm against design_3stage
af_model.design_3stage(0,200,10)

In [None]:
# plot the design trajectory recorded by the model
af_model.plot_traj()  

In [None]:
# write the design to "<protocol>.pdb" (here "fixbb.pdb"), then display it
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# wrap the animation returned by the model in HTML so the notebook renders it inline
HTML(af_model.animate())

In [None]:
# show the designed sequence(s) as the cell output
af_model.get_seqs()

# hallucination
For a given length, generate/hallucinate a protein sequence that AlphaFold thinks folds into a well-structured protein (high pLDDT, low PAE, many contacts).

In [None]:
# Set up the hallucination run: build the MPNN model and an AlphaFold model
# with the "hallucination" protocol, prepared for a 100-residue sequence.
clear_mem()
mpnn_model = mk_mpnn_model()
af_model = mk_af_model(protocol="hallucination")
af_model.prep_inputs(length=100)

# report the prepared length and the model's current "weights" options
print("length",af_model._len)
print("weights",af_model.opt["weights"])

In [None]:
# pre-design with gumbel initialization and softmax activation
af_model.restart()
af_model.set_seq(mode="gumbel")
# NOTE(review): 50 is presumably the number of design steps - confirm against design_soft
af_model.design_soft(50, save_best=True)

In [None]:
# Let's see what the PDB looks like (if you don't like it, rerun the previous cell).
af_model.plot_pdb()

In [None]:
# get unconditional probabilities from mpnn
# use_aux=True: features come from the model's aux output (the saved best one
# when present, per af2mpnn's get_best default) rather than the prepared inputs
mpnn_inputs = af2mpnn(af_model, use_aux=True)
seq_logits = np.array(mpnn_model.get_logits(**mpnn_inputs))[0]
# heat map of the per-position softmax, transposed before plotting
plt.imshow(softmax(seq_logits,-1).T,vmin=0,vmax=1)

In [None]:
# inspect the "bias" entry currently stored in the model options
af_model.opt["bias"]

In [None]:
# three stage design
# clear_best() presumably discards the best result saved by the pre-design - confirm
af_model.clear_best()
af_model.set_weights(seq_bias=1.0)
# start from the MPNN logits as bias; the callback then keeps blending new
# MPNN logits into the bias (update_bias=True) on each call
af_model.set_seq(bias=seq_logits)
af_model.design_3stage(0, 100, 10, callback=mpnn_callback(mpnn_model, update_bias=True))

In [None]:
# write the design to "<protocol>.pdb" (here "hallucination.pdb"), then display it
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# wrap the animation returned by the model in HTML so the notebook renders it inline
HTML(af_model.animate())

In [None]:
# show the designed sequence(s) as the cell output
af_model.get_seqs()

# binder hallucination
For a given protein target and protein binder length, generate/hallucinate a protein binder sequence that AlphaFold thinks will bind to the target structure.
To do this, we minimize PAE and maximize number of contacts at the interface and within the binder, and we maximize pLDDT of the binder.

In [None]:
# Set up the binder run: build the MPNN model and an AlphaFold model with the
# "binder" protocol, using chain A of 4MZK.pdb (fetched by get_pdb) as the
# target and a binder of length 19.
clear_mem()
mpnn_model = mk_mpnn_model()
af_model = mk_af_model(protocol="binder")
af_model.prep_inputs(pdb_filename=get_pdb("4MZK"), chain="A", binder_len=19)

# report target/binder lengths and the model's current "weights" options
print("target_length",af_model._target_len)
print("binder_length",af_model._binder_len)
print("weights",af_model.opt["weights"])

In [None]:
# pre-design the binder before asking MPNN for a bias
af_model.set_seq()
# NOTE(review): 50 is presumably the number of design steps - confirm against design_logits
af_model.design_logits(50, save_best=True)

In [None]:
# MPNN features from the model's aux output, including the amino-acid types "S"
mpnn_inputs = af2mpnn(af_model, use_aux=True, use_seq=True)

# get unconditional logits for binder
# binder positions (everything after the target) are set to -1 in both the
# sequence and the decoding order; target positions keep their values
# NOTE(review): this writes in place, and af2mpnn builds "S" as
# aux["aatype"][None], so it may alias the model's aux array - confirm
mpnn_inputs["S"][:,af_model._target_len:] = -1
mpnn_inputs["decoding_order"] = np.append(np.arange(af_model._target_len),
                                          np.full(af_model._binder_len,-1))[None]

seq_logits = np.array(mpnn_model.get_logits(**mpnn_inputs))[0]
# keep only the binder positions
seq_logits = seq_logits[af_model._target_len:]

# heat map of the binder's per-position softmax probabilities
plt.figure(dpi=100)
plt.imshow(softmax(seq_logits,-1).T,vmin=0,vmax=1)
plt.xlabel("positions");plt.ylabel("amino_acids")
plt.show()

In [None]:
# clear_best() presumably discards the best result saved by the pre-design - confirm
af_model.clear_best()
# start from the binder MPNN logits as bias and switch on the seq_bias weight
af_model.set_seq(bias=seq_logits)
af_model.set_weights(seq_bias=1.0)
# three-stage design; the callback scores each call with MPNN and keeps
# blending new binder logits into the bias (update_bias=True)
af_model.design_3stage(0, 100, 10, callback=mpnn_callback(mpnn_model, update_bias=True))

In [None]:
# write the design to "<protocol>.pdb" (here "binder.pdb"), then display it
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# wrap the animation returned by the model in HTML so the notebook renders it inline
HTML(af_model.animate())

In [None]:
# show the designed sequence(s) as the cell output
af_model.get_seqs()