<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/af_design_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF Design
WARNING: This notebook is in BETA and is not intended for serious use!

In [1]:
#@title install
%%bash
# One-time setup; every step below is skipped if its output already exists.
# 1) clone the af_backprop fork of AlphaFold and install its Python dependencies
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install biopython dm-haiku ml-collections py3Dmol
fi
# 2) download and unpack the AlphaFold parameters (2021-07-14 release) into ./params
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi
# 3) fetch the ColabFold helper module (imported later as `cf`); -nc = don't re-download
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

In [2]:
#@title import libraries
import sys
sys.path.append('/content/af_backprop')

import numpy as np

import jax
import jax.numpy as jnp
from jax.example_libraries.optimizers import adam, sgd

def clear_mem():
  """Free accelerator memory by deleting every live buffer on the default XLA backend.

  Called before building a new design model so buffers held by a previous
  model do not accumulate.
  """
  live = jax.lib.xla_bridge.get_backend().live_buffers()
  for buffer in live:
    buffer.delete()

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding


# custom functions
from utils import *
import colabfold as cf

##############################################################
# PLOTTING FUNCTIONS
##############################################################

import py3Dmol
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from IPython.display import HTML


def get_seqs(seq):
  """Decode sequence arrays into amino-acid strings via argmax over the last axis.

  A 2-D input (positions x alphabet) yields one string; a 3-D input
  (sequences x positions x alphabet) yields a list with one string per sequence.
  Letters are looked up in ``order_restype``.
  """
  idx = seq.argmax(-1)
  decode = lambda row: "".join([order_restype[a] for a in row])
  if idx.ndim != 1:
    return [decode(row) for row in idx]
  return decode(idx)

@jax.jit
def subsample_params(n, params):
  """Select entry ``n`` along the leading (stacked) axis of every leaf in ``params``.

  Args:
    n: integer index into the leading axis.
    params: pytree whose leaves all share a leading axis, e.g. model parameter
      sets stacked with ``jnp.stack``.

  Returns:
    A pytree with the same structure, each leaf replaced by ``leaf[n]``.
  """
  # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias, which
  # is no longer available in recent JAX releases.
  return jax.tree_util.tree_map(lambda x: x[n], params)

def make_animation(xyz, seq, plddt=None, pae=None,
                   pos_ref=None, line_w=2.0,
                   dpi=100, interval=60, color_msa="Taylor",
                   length=None):
  """Render a design trajectory as an HTML5 video.

  Each frame shows the structure as a pseudo-3D line plot, colored N->C or
  by pLDDT. It also shows the sequence: a heatmap for a single sequence, or
  an MSA colored with ``color_msa``. When ``pae`` is given, a pAE matrix
  panel is added.

  Args:
    xyz: per-frame coordinates; each frame is an (L, 3) array.
    seq: per-frame sequence arrays; ``seq[k].shape[0]`` is the number of
      sequences in frame k (presumably (N, L, 20) -- TODO confirm).
    plddt: optional per-frame, per-position values used to color the structure.
    pae: optional per-frame pAE matrices, shown in a third panel.
    pos_ref: reference coordinates every frame is aligned to. Defaults to
      the last frame.
    line_w: line width passed to ``cf.plot_pseudo_3D``.
    dpi: figure resolution.
    interval: milliseconds between frames.
    color_msa: key into ``jalview_color_list`` used to color an MSA.
    length: align on the first ``length`` positions only (e.g. the target in
      binder design). Defaults to the full reference length.

  Returns:
    str: HTML5 ``<video>`` markup for the animation.
  """

  def align(P, Q, P_trim=None):
    """Superimpose P onto Q with the Kabsch rotation fitted on the rows P_trim.

    P is centered on P_trim's centroid before rotating. The result therefore
    lives in a frame whose origin is Q's centroid.
    """
    if P_trim is None: P_trim = P
    center = P_trim.mean(0, keepdims=True)
    p_trim = P_trim - center
    p = P - center
    q = Q - Q.mean(0, keepdims=True)
    return p @ cf.kabsch(p_trim, q)

  # reference defaults to the last frame; fit on all positions unless told otherwise
  if pos_ref is None: pos_ref = xyz[-1]
  if length is None: length = len(pos_ref)

  # align every frame to the reference, fitting only on the first `length` positions
  pos_ref_trim = pos_ref[:length]
  pos = np.asarray([align(frame, pos_ref_trim, frame[:length]) for frame in xyz])

  # rotate everything into a common "best view" frame
  pos_mean = pos.mean(0)
  rot_mtx = cf.kabsch(pos_mean, pos_mean, return_v=True)
  pos = pos @ rot_mtx
  # Put the full reference in the same frame as the aligned positions by
  # centering per axis on the fitted region, as `align` does. A scalar
  # .mean() would shift all three axes by the same amount.
  pos_ref_full = (pos_ref - pos_ref_trim.mean(0)) @ rot_mtx

  # initialize figure: structure (ax1), sequence (ax2), optional pAE (ax3)
  fig = plt.figure()
  gs = GridSpec(4,3, figure=fig)
  if pae is None:
    ax1, ax2 = fig.add_subplot(gs[:3,:]), fig.add_subplot(gs[3:,:])
  else:
    ax1, ax2, ax3 = fig.add_subplot(gs[:3,:2]), fig.add_subplot(gs[3:,:]), fig.add_subplot(gs[:3,2:])

  fig.subplots_adjust(top=0.95,bottom=0.1,right=0.95,left=0.05,hspace=0,wspace=0)
  fig.set_figwidth(8); fig.set_figheight(6); fig.set_dpi(dpi)
  ax2.set_xlabel("positions"); ax2.set_yticks([])
  ax2.set_ylabel("sequences" if seq[0].shape[0] > 1 else "amino acids")

  ax1.set_title("N→C" if plddt is None else "pLDDT")
  if pae is not None:
    ax3.set_title("pAE")
    ax3.set_xticks([])
    ax3.set_yticks([])

  # set boundaries: average per-frame extent, widened to include the reference, plus a margin of 5
  x_min,y_min,z_min = np.minimum(pos.min(1).mean(0),pos_ref_full.min(0)) - 5
  x_max,y_max,z_max = np.maximum(pos.max(1).mean(0),pos_ref_full.max(0)) + 5

  # pad x or y so that the x-range is exactly twice the y-range
  x_pad = ((y_max - y_min) * 2 - (x_max - x_min)) / 2
  y_pad = ((x_max - x_min) / 2 - (y_max - y_min)) / 2
  if x_pad > 0:
    x_min -= x_pad
    x_max += x_pad
  else:
    y_min -= y_pad
    y_max += y_pad

  ax1.set_xlim(x_min, x_max)
  ax1.set_ylim(y_min, y_max)
  ax1.set_xticks([])
  ax1.set_yticks([])

  # build the list of artists drawn in each animation frame
  ims = []
  for k in range(len(pos)):
    frame_artists = []
    if plddt is None:
      frame_artists.append(cf.plot_pseudo_3D(pos[k], ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max))
    else:
      frame_artists.append(cf.plot_pseudo_3D(pos[k], c=plddt[k], cmin=0.5, cmax=0.9, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max))
    if seq[k].shape[0] == 1:
      frame_artists.append(ax2.imshow(seq[k][0].T, animated=True, cmap="bwr_r",vmin=-1, vmax=1))
    else:
      cmap = matplotlib.colors.ListedColormap(jalview_color_list[color_msa])
      vmax = len(jalview_color_list[color_msa]) - 1
      frame_artists.append(ax2.imshow(seq[k].argmax(-1), animated=True, cmap=cmap, vmin=0, vmax=vmax, interpolation="none"))
    if pae is not None:
      frame_artists.append(ax3.imshow(pae[k], animated=True, cmap="bwr",vmin=0, vmax=30))
    ims.append(frame_artists)

  # make animation; close the figure so it is not also displayed as a static image
  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  plt.close(fig)
  return ani.to_html5_video()

In [3]:
#@title design function
class mk_design_model:
  """Gradient-based protein sequence design through AlphaFold (af_backprop).

  Protocols (``protocol`` argument):
    - "fixbb": design a sequence for a fixed backbone read from a PDB file.
    - "hallucination": design a sequence of a given length from scratch.
    - "binder": design a binder chain against a target chain read from a PDB
      file. The target structure is given to AlphaFold as a template.

  Workflow:
    1. Construct the model.
    2. Call ``prep_inputs(...)``, which ``__init__`` binds to the protocol's
       ``_prep_*`` method.
    3. Call ``design(iters)``.

  Inspect results with ``get_seqs()``, ``get_loss()``, ``animate()``,
  ``plot_pdb()`` and ``save_pdb()``.
  """

  ######################################
  # model initialization
  ######################################
  def __init__(self, num_seq=1, protocol="fixbb",
               num_models=5, model_mode="sample",
               num_recycles=0, recycle_mode="sample",
               seq_mode="logits", dropout=True, save_traj=True):
    """Configure AlphaFold, load model parameters and jit the gradient function.

    Args:
      num_seq: number of sequences designed jointly (rows of the input MSA).
      protocol: "fixbb", "hallucination" or "binder".
      num_models: number of AlphaFold parameter sets to use. At most 5, or 2
        for the template-based "binder" protocol.
      model_mode: "sample" uses one randomly chosen parameter set per step.
        "parallel" vmaps over all parameter sets and averages the gradients.
      num_recycles: maximum number of recycles.
      recycle_mode: "sample" draws a random number of recycles per step.
        "add_prev" and "backprop" set the corresponding config flags. Any
        other value (e.g. "last") always runs ``num_recycles`` recycles.
      seq_mode: "logits", "softmax" or "softmax_gumbel". Controls how
        positions that are not yet hard (one-hot) are fed to the model.
      dropout: if False, the config is passed through ``set_dropout(cfg, 0.0)``.
      save_traj: keep per-step structures/sequences for ``animate``.

    Raises:
      ValueError: if ``protocol`` is not one of the supported protocols.
    """
    # input prep function per protocol; validate before loading any params
    prep_fns = {"fixbb": self._prep_fixbb,
                "hallucination": self._prep_hallucination,
                "binder": self._prep_binder}
    if protocol not in prep_fns:
      raise ValueError(f"unknown protocol {protocol!r}, expected one of {sorted(prep_fns)}")

    use_templates = protocol == "binder"

    self.opt = {"num_seq":num_seq, "seq_mode":seq_mode,
                "num_models":num_models, "model_mode":model_mode,
                "num_recycles":num_recycles, "recycle_mode":recycle_mode,
                "dropout":dropout, "use_templates":use_templates,
                "save_traj":save_traj}

    self.protocol = protocol
    self._k = -1  # -1 = inputs not prepared / optimizer not started

    # setup which model params to use first (and whose config is built)
    model_name = "model_1_ptm" if use_templates else "model_3_ptm"
    cfg = config.model_config(model_name)

    # enable checkpointing (rematerialization) to reduce memory
    cfg.model.global_config.use_remat = True

    # number of sequences
    if use_templates:
      cfg.data.eval.max_templates = 1
      cfg.data.eval.max_msa_clusters = num_seq + 1
    else:
      cfg.data.eval.max_msa_clusters = num_seq

    cfg.data.common.max_extra_msa = 1
    cfg.data.eval.masked_msa_replace_fraction = 0

    # number of recycles
    cfg.model.num_recycle = num_recycles
    cfg.data.common.num_recycle = num_recycles

    # backprop through recycles
    cfg.model.add_prev = recycle_mode == "add_prev"
    cfg.model.backprop_recycle = recycle_mode == "backprop"
    cfg.model.embeddings_and_evoformer.backprop_dgram = recycle_mode == "backprop"

    # dropout
    if not dropout: cfg = set_dropout(cfg, 0.0)

    # setup model
    self._params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
    self._runner = model.RunModel(cfg, self._params[0], is_training=True)

    # Load the remaining parameter sets; in "sample" mode one is picked at
    # random each step. The first set is already loaded above, so only
    # num_models - 1 more are needed. Loading num_models here would make
    # "parallel" mode run one model too many.
    if use_templates:
      model_names = ["model_2_ptm"]
    else:
      model_names = ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]

    for name in model_names[:max(num_models - 1, 0)]:
      params = data.get_model_haiku_params(name, '.')
      # keep only the parameter keys used by the runner built above
      self._params.append({k: params[k] for k in self._runner.params.keys()})

    # define gradient function
    if model_mode == "parallel":
      # stack parameter sets along a new leading axis so they can be vmapped over
      # (jax.tree_util.tree_map with several trees replaces the removed jax.tree_multimap)
      self._params = jax.tree_util.tree_map(lambda *values: jnp.stack(values, axis=0), *self._params)
      self._grad = jax.jit(jax.vmap(self._get_grad_fn(),(None,0,None,None,None)))
    else:
      self._grad = jax.jit(self._get_grad_fn())

    # define input function
    self.prep_inputs = prep_fns[protocol]

  ######################################
  # setup gradient
  ######################################
  def _get_grad_fn(self):
    """Return ``value_and_grad`` of the design loss w.r.t. ``params`` (argnum 0).

    The returned function has signature
    ``(params, model_params, inputs, key, opt)``. It returns
    ``((loss, aux), grad)``, where ``aux`` has these keys:
      - "losses": dict of individual losses.
      - "outputs": selected model outputs.
      - "seq": the hard (one-hot, straight-through) sequence.
      - "seq_pseudo": the sequence actually fed to the model.

    ``opt`` must provide:
      - "weight": dict of per-loss weights; losses without an entry get weight 1.
      - "mask": per-position 1 = hard, 0 = soft.
      - "recycles": number of recycles.

    NOTE(review): ``self._batch`` and ``self._target_len`` are read inside the
    traced function, so they are baked in when the jitted wrapper first
    traces. Re-running ``prep_inputs`` with same-shaped inputs may reuse
    stale values.
    """

    # setup function to get gradients
    def mod(params, model_params, inputs, key, opt):

      # initialize the loss function
      losses = {}
      w = opt["weight"]

      # set sequence
      seq_logits = params["seq_logits"]
      if self.opt["seq_mode"] == "softmax_gumbel":
        seq_logits += jax.random.gumbel(key, seq_logits.shape)

      # shuffle msa: swap a random row into position 0
      if self.opt["num_seq"] > 1:
        i = jax.random.randint(key,[],0,seq_logits.shape[0])
        seq_logits = seq_logits.at[0].set(seq_logits[i]).at[i].set(seq_logits[0])

      # reparameterization: one-hot forward pass, softmax gradient (straight-through)
      seq_soft = jax.nn.softmax(seq_logits)
      seq_hard = jax.nn.one_hot(seq_logits.argmax(-1),20)
      seq_hard = jax.lax.stop_gradient(seq_hard - seq_soft) + seq_soft
      mask_pseudo = opt["mask"][:,None]

      # positions with mask == 1 use the hard sequence, the rest a soft version
      if self.opt["seq_mode"] == "logits":
        seq_pseudo = jnp.where(mask_pseudo, seq_hard, seq_logits * 0.1)
      else:
        seq_pseudo = jnp.where(mask_pseudo, seq_hard, seq_soft)

      # entropy loss for msa
      if self.opt["num_seq"] > 1:
        seq_prf = seq_hard.mean(0)
        ent_loss = -(seq_prf * jnp.log(seq_prf + 1e-8)).sum(-1).mean()
        losses["ent"] = ent_loss

      if self.protocol == "binder":
        # concatenate target and binder sequence
        seq_target = jax.nn.one_hot(self._batch["aatype"][:self._target_len],20)
        seq_hard = jnp.concatenate([seq_target[None], seq_hard], 1)
        seq_soft = jnp.concatenate([seq_target[None], seq_soft], 1)
        seq_pseudo = jnp.concatenate([seq_target[None], seq_pseudo], 1)

      update_seq(seq_pseudo, inputs)

      # set sidechains identity
      N,L = inputs["aatype"].shape[:2]
      aatype = jnp.zeros((N,L,21)).at[...,:20].set(seq_hard[0])
      update_aatype(aatype, inputs)

      # set number of recycles to use
      inputs["num_iter_recycling"] = opt["recycles"]

      # get outputs
      outputs = self._runner.apply(model_params, key, inputs)

      # confidence losses: expected bin index divided by the number of bins
      # (so each lies in [0, 1); lower is better)
      pae_prob = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"])
      pae_loss = (pae_prob * jnp.arange(pae_prob.shape[-1])).mean(-1)

      plddt_prob = jax.nn.softmax(outputs["predicted_lddt"]["logits"])
      plddt_loss = (plddt_prob * jnp.arange(plddt_prob.shape[-1])[::-1]).mean(-1)

      # note: we find maximizing just the confidence results in single extended helices
      # to promote compact structure we add a loss to maximize number of contacts via dgram.
      con_prob = jax.nn.log_softmax(outputs["distogram"]["logits"])
      con_loss = -jax.nn.logsumexp(con_prob[...,:-1],-1)

      # protocol specific losses
      if self.protocol == "binder":
        TL = self._target_len
        losses.update({"con_intra":con_loss[...,TL:,TL:].mean(),
                       "con_inter":con_loss[...,:TL,TL:].mean(),
                       "plddt":plddt_loss[...,TL:].mean(),
                       "pae":pae_loss.mean()})

      if self.protocol == "hallucination":
        losses.update({"con":con_loss.mean(),
                       "plddt":plddt_loss.mean(),
                       "pae":pae_loss.mean()})

      if self.protocol == "fixbb":
        # note: we find dgram loss easier to backprop through
        fape_loss = get_fape_loss(self._batch, outputs, model_config=self._runner.config)
        dgram_loss = get_dgram_loss(self._batch, outputs, model_config=self._runner.config)
        losses.update({"plddt":plddt_loss.mean(),
                       "pae":pae_loss.mean(),
                       "dgram":dgram_loss,
                       "fape":fape_loss})

      # weighted sum; losses without a weight entry count with weight 1
      loss = sum([v*w[k] if k in w else v for k,v in losses.items()])

      # save aux outputs
      outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
              "final_atom_mask":outputs["structure_module"]["final_atom_mask"],
              "plddt":get_plddt(outputs)}
      # protocol specific outputs
      if self.protocol == "fixbb": outs.update({"rmsd":get_rmsd_loss(self._batch, outputs)})
      if self.protocol == "binder": outs.update({"pae":get_pae(outputs)})

      return loss, ({"losses":losses,"outputs":outs,"seq":seq_hard,"seq_pseudo":seq_pseudo})

    return jax.value_and_grad(mod, has_aux=True, argnums=0)


  ######################################
  # input prep functions
  ######################################

  def _prep_inputs(self, length, template_features=None):
    """Build AlphaFold input features for a poly-alanine placeholder of ``length``.

    The real sequence is injected later (inside the loss) via ``update_seq``.
    When ``num_seq > 1`` the MSA masks are set to all ones.
    """
    num_seq = self.opt["num_seq"]
    sequence = "A" * length
    # one MSA row (and matching deletion row) per designed sequence
    feature_dict = {
        **pipeline.make_sequence_features(sequence=sequence, description="none",
                                          num_res=length),
        **pipeline.make_msa_features(msas=[num_seq*[sequence]],
                                     deletion_matrices=[num_seq*[[0]*length]]),
    }
    if template_features is not None: feature_dict.update(template_features)
    inputs = self._runner.process_features(feature_dict, random_seed=0)
    if num_seq > 1:
      inputs["msa_row_mask"] = jnp.ones_like(inputs["msa_row_mask"])
      inputs["msa_mask"] = jnp.ones_like(inputs["msa_mask"])
    return inputs

  def _prep_pdb(self, pdb_filename, chain=None):
    """Parse a PDB chain into a target batch, template features and residue indices.

    Returns:
      dict with:
        - "batch": aatype, atom positions/masks and frames from ``atom37_to_frames``.
        - "template_features": the chain supplied as a single template.
        - "residue_index": residue numbers of the kept residues.
    """
    protein_obj = protein.from_pdb_string(pdb_to_string(pdb_filename), chain_id=chain)
    batch = {'aatype': protein_obj.aatype,
              'all_atom_positions': protein_obj.atom_positions,
              'all_atom_mask': protein_obj.atom_mask}

    # NOTE(review): despite its name this checks atom slot 0; elsewhere in this
    # notebook slot 1 is used for CA positions ([:,1,:]). Confirm which
    # backbone atom is meant.
    has_ca = batch["all_atom_mask"][:,0] == 1
    batch = jax.tree_util.tree_map(lambda x:x[has_ca], batch)
    batch.update(all_atom.atom37_to_frames(**batch))

    template_features = {"template_aatype":jax.nn.one_hot(protein_obj.aatype[has_ca],22)[None],
                         "template_all_atom_masks":protein_obj.atom_mask[has_ca][None],
                         "template_all_atom_positions":protein_obj.atom_positions[has_ca][None],
                         "template_domain_names":np.asarray(["None"])}
    return {"batch":batch,
            "template_features":template_features,
            "residue_index": protein_obj.residue_index[has_ca]}

  def _prep_binder(self, pdb_filename, chain=None, binder_len=50):
    """Prepare inputs for binder design: target chain (as template) + ``binder_len`` designed residues."""
    pdb = self._prep_pdb(pdb_filename, chain=chain)

    target_len = pdb["residue_index"].shape[0]
    self._inputs = self._prep_inputs(target_len, pdb["template_features"])
    self._inputs["residue_index"][:,:] = pdb["residue_index"]

    # pad inputs and target batch to the combined target + binder length
    total_len = target_len + binder_len
    self._inputs = make_fixed_size(self._inputs, self._runner, total_len)
    self._batch = make_fixed_size(pdb["batch"], self._runner, total_len, batch_axis=False)

    # offset residue index for binder (gap of 100 after the last target residue)
    self._inputs["residue_index"] = self._inputs["residue_index"].copy()
    self._inputs["residue_index"][:,target_len:] = pdb["residue_index"][-1] + np.arange(binder_len) + 100

    self._inputs["seq_mask"] = np.ones_like(self._inputs["seq_mask"])
    self._inputs["msa_mask"] = np.ones_like(self._inputs["msa_mask"])

    self._target_len = target_len
    self._binder_len = self._len = binder_len
    self._k = -1

  def _prep_fixbb(self, pdb_filename, chain=None):
    """Prepare inputs for fixed-backbone design of the given PDB chain."""
    pdb = self._prep_pdb(pdb_filename, chain=chain)
    length = pdb["residue_index"].shape[0]
    self._inputs = self._prep_inputs(length, pdb["template_features"])

    # update residue index from pdb
    self._inputs["residue_index"][:,:] = pdb["residue_index"]

    self._batch = pdb["batch"]
    self._len = length
    self._k = -1

  def _prep_hallucination(self, length=100):
    """Prepare inputs for hallucinating a protein of ``length`` residues."""
    self._inputs = self._prep_inputs(length)
    self._len = length
    self._k = -1

  ######################################
  # design function
  ######################################
  def restart(self, seed=None, lr=0.1, hard=True):
    """Reset losses/trajectory, the Adam optimizer and the sequence logits.

    Args:
      seed: PRNG seed. If None, one is drawn with ``np.random.randint``.
      lr: Adam learning rate.
      hard: if True, logits start as small Gaussian noise and every position
        is hard (one-hot) from the start. If False, logits start at zero and
        every position is soft.
    """
    # self.losses is always needed by design(); the trajectory only when requested
    self.losses = []
    if self.opt["save_traj"]:
      self._traj = {"xyz":[],"seq":[],"plddt":[]}
      if self.protocol == "binder": self._traj["pae"] = []
    self._best_loss, self._best_outs = np.inf, None

    # setup optimizer
    self._init_fun, self._update_fun, self._get_params = adam(lr)
    self._k = 0

    # initialize sequence and mask_pseudo
    if seed is None: seed = np.random.randint(10000)
    self._key = jax.random.PRNGKey(seed)
    N,L,A = self.opt["num_seq"],self._len,20

    if hard:
      seq_logits = 0.01 * jax.random.normal(self._key, (N,L,A))
      self._mask = jnp.ones((L,))
    else:
      seq_logits = jnp.zeros((N,L,A))
      self._mask = jnp.zeros((L,))

    self._best_hard = self._mask.sum()
    self._state = self._init_fun({"seq_logits":seq_logits})

  def design(self, iters=300, weight=None,
             hard=True, hard_switch=1,
             verbose=True, seed=None, lr=0.1,
             restart=False, print_all=False):
    """Run ``iters`` optimization steps on the sequence logits.

    Args:
      iters: number of gradient steps.
      weight: optional dict overriding per-loss weights. Keys are loss names:
        "pae", "plddt", "con", "con_intra", "con_inter", "dgram", "fape", "ent".
      hard: passed to ``restart``. When True, one randomly chosen soft
        position is also switched to hard every ``hard_switch`` steps.
      hard_switch: step interval for the soft -> hard switching.
      verbose: print a line of losses whenever the loss improves.
      seed: PRNG seed passed to ``restart``.
      lr: learning rate passed to ``restart``.
      restart: re-initialize even if optimization has already started.
      print_all: with ``verbose``, print every step instead of only improvements.
    """

    # gradient step function
    def step(k, state, key, opt):
      """One optimizer step; returns (new_state, aux_outputs, loss)."""
      model_mode = self.opt["model_mode"]
      tot_models = 2 if self.opt["use_templates"] else 5
      num_models = min(self.opt["num_models"], tot_models)

      # compute loss & gradient
      if model_mode == "sample":
        n = jax.random.randint(key,[], 0, num_models)
        (loss, outs), grad = self._grad(self._get_params(state), self._params[n], self._inputs, key, opt)
        outs["outputs"]["model_used"] = n

      if model_mode == "parallel":
        (loss, outs), grad = self._grad(self._get_params(state), self._params, self._inputs, key, opt)
        # take the mean of gradients and losses over models
        grad = jax.tree_util.tree_map(lambda x: x.mean(0), grad)
        outs["losses"] = jax.tree_util.tree_map(lambda x: x.mean(0), outs["losses"])
        loss = loss.mean(0)

        # use first model for reported outputs and sequences
        outs["outputs"] = jax.tree_util.tree_map(lambda x:x[0], outs["outputs"])
        # the aux dict holds "seq" and "seq_pseudo" (there is no "seq_logits" key)
        outs["seq"] = outs["seq"][0]
        outs["seq_pseudo"] = outs["seq_pseudo"][0]
        outs["outputs"]["model_used"] = "all"

      # normalize the gradients
      grad["seq_logits"] /= jnp.sqrt(jnp.square(grad["seq_logits"]).sum(-1,keepdims=True).mean(-2,keepdims=True))

      # apply gradient
      state = self._update_fun(k, grad, state)
      return state, outs, loss

    # default weights; keys must match the loss names produced by _get_grad_fn
    # (the entropy loss is named "ent"; losses without an entry get weight 1)
    w = {"ent":0.01}
    if self.protocol == "fixbb": w.update({"dgram":1.0,"fape":0.0,
                                           "pae":0.1,"plddt":0.1})
    if self.protocol == "hallucination": w.update({"pae":1.0,"plddt":1.0,"con":0.5})
    if self.protocol == "binder": w.update({"pae":1.0,"plddt":1.0,
                                            "con_intra":0.25,"con_inter":0.25})
    if weight is not None:
      for k,v in weight.items(): w[k] = v

    # start optimization
    if restart or self._k == -1:
      self.restart(seed=seed, lr=lr, hard=hard)

    for _ in range(iters):
      self._key, subkey, _subkey  = jax.random.split(self._key, 3)
      recycles = jnp.array([self.opt["num_recycles"]])
      if self.opt["recycle_mode"] == "sample":
        recycles = jax.random.randint(_subkey,(1,), 0, self.opt["num_recycles"] + 1)

      # options passed to compiled model
      opt = {"mask":self._mask,"weight":w,"recycles":recycles}

      # take step
      self._state, outs, loss = step(self._k, self._state, subkey, opt)

      recycle_used, model_used = int(recycles[0]), outs["outputs"]["model_used"]
      all_recycles = recycle_used == self.opt["num_recycles"]

      # print output
      losses = outs["losses"]
      losses_print = f'{self._k}\t'
      losses.update({"model":model_used, "recycle":recycle_used, "n_hard":int(self._mask.sum()), "loss":loss})

      if self.protocol == "fixbb":
        losses["seqid"] = (outs["seq"].argmax(-1) == self._batch["aatype"]).mean()
        losses["rmsd"] = outs["outputs"]["rmsd"]
      self.losses.append(losses)

      for l in ["model","recycle","n_hard"]:
        if l in losses: losses_print += f' {l}: {losses[l]}'
      for l in ["loss","seqid","ent","pae","plddt","con","con_intra","con_inter","dgram","fape","rmsd"]:
        if l in losses: losses_print += f' {l}: {losses[l]:.3f}'

      if verbose:
        if print_all or loss < self._best_loss:
          print(losses_print)

      # save for animation (only steps that used all recycles)
      if all_recycles and self.opt["save_traj"]:
        traj = {"xyz":outs["outputs"]["final_atom_positions"][:,1,:],
                "plddt":outs["outputs"]["plddt"], "seq":outs["seq_pseudo"]}
        if self.protocol == "binder":
          traj.update({"pae":outs["outputs"]["pae"]})
          traj["seq"] = traj["seq"][...,self._target_len:,:]
        for k,v in traj.items():
          self._traj[k].append(np.array(v))

      # save best result (a step with more hard positions always wins)
      if all_recycles and (loss < self._best_loss or \
                           self._mask.sum() > self._best_hard):
        self._best_loss = loss
        self._best_outs = outs
        self._best_hard = self._mask.sum()

      if hard and (self._k + 1) % hard_switch == 0 and self._mask.sum() < self._len:
        # pick random position to flip to hard
        i = np.random.choice(np.where(self._mask == 0)[0])
        self._mask = self._mask.at[i].set(1)

      # increment
      self._k += 1

  def animate(self, s=0, e=None, dpi=100):
    """Return an HTML5 video of the saved trajectory frames ``[s:e]``.

    fixbb frames are aligned to the input structure. Other protocols align to
    the best design. For binder, the fit uses the target positions only.
    """
    sub_traj = {k:v[s:e] for k,v in self._traj.items()}
    if self.protocol == "fixbb":
      pos_ref = self._batch["all_atom_positions"][:,1,:]
      return make_animation(**sub_traj, pos_ref=pos_ref, dpi=dpi)

    elif self.protocol == "binder":
      pos_ref = self._best_outs["outputs"]["final_atom_positions"][:,1,:]
      TL = self._target_len
      return make_animation(**sub_traj,pos_ref=pos_ref,length=TL, dpi=dpi)

    else:
      pos_ref = self._best_outs["outputs"]["final_atom_positions"][:,1,:]
      return make_animation(**sub_traj, pos_ref=pos_ref, dpi=dpi)

  def save_pdb(self, filename=None):
    """Convert the best design to PDB text (B-factors = pLDDT).

    Returns the PDB string when ``filename`` is None, otherwise writes it to
    ``filename``.
    """
    # read from this instance (previously this read the global `design_model`)
    best = self._best_outs
    p = {"residue_index":np.asarray(self._inputs["residue_index"][0]),
         "aatype":np.asarray(best["seq"].argmax(-1)[0]),
         "atom_positions":np.asarray(best["outputs"]["final_atom_positions"]),
         "atom_mask":np.asarray(best["outputs"]["final_atom_mask"])}
    b_factors = np.asarray(best["outputs"]["plddt"])[:,None] * p["atom_mask"]
    p = protein.Protein(**p,b_factors=b_factors)
    pdb_lines = protein.to_pdb(p)
    if filename is None:
      return pdb_lines
    else:
      with open(filename, 'w') as f:
        f.write(pdb_lines)

  def plot_pdb(self):
    """Show the best design in an interactive py3Dmol viewer (cartoon + side-chain sticks)."""
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    view.addModel(self.save_pdb(),'pdb')
    view.setStyle({'cartoon': {}})
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.zoomTo()
    view.show()

  def get_loss(self, x = "loss"):
    """Return the per-step values of loss term ``x`` as a float numpy array."""
    return np.array([float(loss[x]) for loss in self.losses])

  def get_seqs(self):
    """Return the best design's sequence(s) as amino-acid strings."""
    return get_seqs(self._best_outs["seq"])

In [4]:
#@title setup model


##############################################################
# GET OPTIONS
##############################################################
dropout = True #@param ["True", "False"] {type:"raw"}
#@markdown - `dropout` - Use dropout during design (helps jump out of local minima)

#@markdown ###model options

num_models = 5 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
#@markdown - `num_models` - number of model params to use
model_mode = "sample" #@param ["sample", "parallel"]
#@markdown - `sample` - at each iteration, randomly select one model param to use.
#@markdown - `parallel` - run `num_models` in parallel, average the gradients. 

#@markdown ###recycle options
num_recycles = 0 #@param ["0", "1", "2", "3"] {type:"raw"}
#@markdown - `num_recycles` - max number of recycles to use during design (for denovo proteins we find 0 is often enough)
recycle_mode = "sample" #@param ["sample", "add_prev", "last", "backprop"]
#@markdown - `sample` - at each iteration, randomly select number of recycles to use. (Recommended)
#@markdown - `add_prev` - add prediction logits (dgram, pae, plddt) across all recycles. (Most stable, but slow and requires more memory).
#@markdown - `last` - only use gradients from last recycle.
#@markdown - `backprop` - use outputs from last recycle, but backprop through all recycles.


# Shared model options, unpacked into every mk_design_model(**OPT, ...) call below.
# Any recycle_mode other than "sample" always runs num_recycles recycles per step.
OPT = {"dropout":dropout,
       "num_models":num_models, "model_mode":model_mode,
       "num_recycles":num_recycles, "recycle_mode":recycle_mode}

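# Fixed backbone design (fixbb)
Given a backbone from a PDB file, design a sequence that AlphaFold predicts will fold into that structure.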

In [9]:
protocol = "fixbb" #@param ["fixbb", "hallucination", "binder"] 
seq_mode = "logits" #@param ["softmax","softmax_gumbel","logits"] 
example = "1QYS" #@param {type:"string"}
chain = "A" #@param ["A", "B", "C"] {allow-input: true}

# download the target structure from RCSB (-nc: skip if already present)
!wget -qnc https://files.rcsb.org/view/{example}.pdb
# free device buffers held by any previously built model before building a new one
clear_mem()
design_model = mk_design_model(**OPT, protocol=protocol, seq_mode=seq_mode)
# parse the chosen chain and set up the fixed-backbone inputs
design_model.prep_inputs(pdb_filename=f"{example}.pdb", chain=chain)

In [7]:
# the reported pae and plddt losses are scaled to [0, 1]; lower is better
design_model.design(iters=100)

0	 model: 3 recycle: 0 n_hard: 92 loss: 3.994 seqid: 0.065 pae: 0.523 plddt: 0.453 dgram: 3.896 fape: 2.522 rmsd: 18.106
1	 model: 0 recycle: 0 n_hard: 92 loss: 3.661 seqid: 0.098 pae: 0.563 plddt: 0.562 dgram: 3.549 fape: 2.157 rmsd: 13.446
12	 model: 1 recycle: 0 n_hard: 92 loss: 3.656 seqid: 0.043 pae: 0.530 plddt: 0.560 dgram: 3.547 fape: 2.099 rmsd: 12.930
13	 model: 2 recycle: 0 n_hard: 92 loss: 3.652 seqid: 0.043 pae: 0.577 plddt: 0.592 dgram: 3.535 fape: 1.932 rmsd: 10.824
14	 model: 2 recycle: 0 n_hard: 92 loss: 3.585 seqid: 0.043 pae: 0.603 plddt: 0.621 dgram: 3.462 fape: 2.088 rmsd: 15.088
17	 model: 2 recycle: 0 n_hard: 92 loss: 3.578 seqid: 0.054 pae: 0.563 plddt: 0.572 dgram: 3.465 fape: 1.983 rmsd: 9.998
20	 model: 4 recycle: 0 n_hard: 92 loss: 3.571 seqid: 0.054 pae: 0.556 plddt: 0.569 dgram: 3.458 fape: 1.955 rmsd: 10.068
21	 model: 1 recycle: 0 n_hard: 92 loss: 3.525 seqid: 0.076 pae: 0.542 plddt: 0.586 dgram: 3.412 fape: 1.894 rmsd: 13.263
22	 model: 3 recycle: 0 n_h

In [9]:
# continue the same optimization for 200 more iterations (no restart)
design_model.design(iters=200)

131	 model: 4 recycle: 0 n_hard: 92 loss: 2.909 seqid: 0.076 pae: 0.308 plddt: 0.324 dgram: 2.845 fape: 1.452 rmsd: 2.563
160	 model: 3 recycle: 0 n_hard: 92 loss: 2.908 seqid: 0.065 pae: 0.381 plddt: 0.409 dgram: 2.829 fape: 1.671 rmsd: 3.364
162	 model: 4 recycle: 0 n_hard: 92 loss: 2.863 seqid: 0.065 pae: 0.384 plddt: 0.404 dgram: 2.784 fape: 1.597 rmsd: 2.881
213	 model: 1 recycle: 0 n_hard: 92 loss: 2.863 seqid: 0.087 pae: 0.371 plddt: 0.396 dgram: 2.786 fape: 1.551 rmsd: 3.119
214	 model: 3 recycle: 0 n_hard: 92 loss: 2.849 seqid: 0.087 pae: 0.351 plddt: 0.374 dgram: 2.776 fape: 1.489 rmsd: 2.884
217	 model: 1 recycle: 0 n_hard: 92 loss: 2.836 seqid: 0.087 pae: 0.336 plddt: 0.363 dgram: 2.766 fape: 1.429 rmsd: 2.615
222	 model: 1 recycle: 0 n_hard: 92 loss: 2.812 seqid: 0.076 pae: 0.345 plddt: 0.364 dgram: 2.741 fape: 1.369 rmsd: 2.802
231	 model: 3 recycle: 0 n_hard: 92 loss: 2.804 seqid: 0.087 pae: 0.355 plddt: 0.371 dgram: 2.731 fape: 1.457 rmsd: 3.061


Let's try again: first optimize a soft (continuous) sequence, then gradually switch positions to one-hot.

In [7]:
# "cheat" first: restart and optimize a soft (non one-hot) input sequence for 200 steps
design_model.design(iters=200, hard=False, restart=True)

0	 model: 1 recycle: 0 n_hard: 0 loss: 5.368 seqid: 0.065 pae: 0.361 plddt: 0.125 dgram: 5.320 fape: 4.720 rmsd: 36.158
2	 model: 1 recycle: 0 n_hard: 0 loss: 4.984 seqid: 0.076 pae: 0.580 plddt: 0.346 dgram: 4.891 fape: 3.024 rmsd: 23.611
7	 model: 2 recycle: 0 n_hard: 0 loss: 4.909 seqid: 0.065 pae: 0.533 plddt: 0.334 dgram: 4.822 fape: 3.142 rmsd: 23.360
9	 model: 3 recycle: 0 n_hard: 0 loss: 4.314 seqid: 0.054 pae: 0.660 plddt: 0.443 dgram: 4.204 fape: 3.006 rmsd: 23.701
12	 model: 3 recycle: 0 n_hard: 0 loss: 4.104 seqid: 0.043 pae: 0.607 plddt: 0.420 dgram: 4.001 fape: 2.864 rmsd: 21.914
15	 model: 0 recycle: 0 n_hard: 0 loss: 3.916 seqid: 0.033 pae: 0.647 plddt: 0.575 dgram: 3.794 fape: 2.690 rmsd: 22.732
16	 model: 3 recycle: 0 n_hard: 0 loss: 3.791 seqid: 0.033 pae: 0.653 plddt: 0.588 dgram: 3.667 fape: 2.622 rmsd: 20.683
21	 model: 3 recycle: 0 n_hard: 0 loss: 3.687 seqid: 0.033 pae: 0.633 plddt: 0.688 dgram: 3.555 fape: 2.319 rmsd: 16.944
23	 model: 3 recycle: 0 n_hard: 0 lo

In [8]:
# now make it one-hot: switch one random soft position to hard every step
design_model.design(iters=100, hard_switch=1, hard=True)

201	 model: 0 recycle: 0 n_hard: 1 loss: 1.646 seqid: 0.196 pae: 0.112 plddt: 0.100 dgram: 1.625 fape: 0.330 rmsd: 0.854
204	 model: 4 recycle: 0 n_hard: 4 loss: 1.722 seqid: 0.185 pae: 0.110 plddt: 0.097 dgram: 1.701 fape: 0.417 rmsd: 0.894
206	 model: 3 recycle: 0 n_hard: 6 loss: 1.713 seqid: 0.196 pae: 0.117 plddt: 0.109 dgram: 1.690 fape: 0.366 rmsd: 0.819
207	 model: 1 recycle: 0 n_hard: 7 loss: 1.697 seqid: 0.196 pae: 0.112 plddt: 0.085 dgram: 1.678 fape: 0.379 rmsd: 0.834
208	 model: 3 recycle: 0 n_hard: 8 loss: 1.688 seqid: 0.196 pae: 0.117 plddt: 0.107 dgram: 1.666 fape: 0.333 rmsd: 0.765
209	 model: 3 recycle: 0 n_hard: 9 loss: 1.685 seqid: 0.185 pae: 0.116 plddt: 0.100 dgram: 1.664 fape: 0.318 rmsd: 0.788
212	 model: 0 recycle: 0 n_hard: 12 loss: 1.708 seqid: 0.185 pae: 0.109 plddt: 0.099 dgram: 1.687 fape: 0.334 rmsd: 0.839
215	 model: 2 recycle: 0 n_hard: 15 loss: 1.737 seqid: 0.196 pae: 0.107 plddt: 0.085 dgram: 1.718 fape: 0.340 rmsd: 0.910
217	 model: 4 recycle: 0 n_har

In [9]:
HTML(data=design_model.animate())

In [12]:
# amino-acid sequence(s) of the best design
design_model.get_seqs()

['RILLHVMVRCPGKMHMITYEFDDPQELQQVMEEIKTMLRKHSDMCVCICFKMPSPAECIVCMIAAYALAREVGYTRIPLLIVPGWCTVIAFR']

In [13]:
# interactive 3D view of the best design
design_model.plot_pdb()

In [14]:
design_model.save_pdb(filename=f"{example}.design.pdb")

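# Hallucination
Design a protein of a given length from scratch by optimizing AlphaFold's confidence (pLDDT, pAE) together with a contact term.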

In [5]:
protocol = "hallucination" #@param ["fixbb", "hallucination", "binder"] 
seq_mode = "logits" #@param ["softmax","softmax_gumbel","logits"] 
length = 100 #@param {type:"raw"}

# free device buffers from any previously built model, then build the new one
clear_mem()
design_model = mk_design_model(**OPT, protocol=protocol, seq_mode=seq_mode)
# set up inputs for a de novo protein of `length` residues
design_model.prep_inputs(length)

In [6]:
design_model.design(iters=50)

0	 model: 3 recycle: 0 n_hard: 100 loss: 1.559 pae: 0.643 plddt: 0.651 con: 0.530
1	 model: 4 recycle: 0 n_hard: 100 loss: 1.480 pae: 0.625 plddt: 0.636 con: 0.437
2	 model: 1 recycle: 0 n_hard: 100 loss: 1.324 pae: 0.549 plddt: 0.550 con: 0.451
8	 model: 2 recycle: 0 n_hard: 100 loss: 1.175 pae: 0.467 plddt: 0.484 con: 0.446
16	 model: 3 recycle: 0 n_hard: 100 loss: 1.150 pae: 0.463 plddt: 0.486 con: 0.402
22	 model: 2 recycle: 0 n_hard: 100 loss: 1.125 pae: 0.435 plddt: 0.434 con: 0.512
26	 model: 2 recycle: 0 n_hard: 100 loss: 1.108 pae: 0.431 plddt: 0.424 con: 0.506
31	 model: 0 recycle: 0 n_hard: 100 loss: 1.099 pae: 0.437 plddt: 0.423 con: 0.477
32	 model: 0 recycle: 0 n_hard: 100 loss: 1.063 pae: 0.408 plddt: 0.402 con: 0.506
36	 model: 0 recycle: 0 n_hard: 100 loss: 1.000 pae: 0.359 plddt: 0.349 con: 0.585
42	 model: 3 recycle: 0 n_hard: 100 loss: 0.959 pae: 0.326 plddt: 0.332 con: 0.600


In [7]:
# interactive 3D view of the best design so far
design_model.plot_pdb()

In [8]:
design_model.design(iters=100)

53	 model: 3 recycle: 0 n_hard: 100 loss: 0.897 pae: 0.272 plddt: 0.273 con: 0.706
59	 model: 2 recycle: 0 n_hard: 100 loss: 0.889 pae: 0.281 plddt: 0.278 con: 0.661
62	 model: 1 recycle: 0 n_hard: 100 loss: 0.856 pae: 0.274 plddt: 0.278 con: 0.607
82	 model: 2 recycle: 0 n_hard: 100 loss: 0.837 pae: 0.263 plddt: 0.258 con: 0.632
83	 model: 2 recycle: 0 n_hard: 100 loss: 0.798 pae: 0.244 plddt: 0.234 con: 0.640
87	 model: 1 recycle: 0 n_hard: 100 loss: 0.777 pae: 0.204 plddt: 0.197 con: 0.750


In [12]:
# interactive 3D view of the best design after the extra iterations
design_model.plot_pdb()

# Binder design (binder hallucination)
Design a binder of a given length against a target chain read from a PDB file.

In [13]:
protocol = "binder" #@param ["fixbb", "hallucination", "binder"] 
seq_mode = "softmax" #@param ["softmax","softmax_gumbel","logits"] 
target_pdb = "5TZQ" #@param {type:"string"}
chain = "A" #@param ["A", "B", "C"] {allow-input: true}
binder_length =  26#@param {type:"integer"}

# download the target structure from RCSB (-nc: skip if already present)
!wget -qnc https://files.rcsb.org/view/{target_pdb}.pdb
# free device buffers from any previously built model, then build the binder model
clear_mem()
design_model = mk_design_model(**OPT, protocol=protocol, seq_mode=seq_mode)
# target chain + `binder_length` designed residues
design_model.prep_inputs(pdb_filename=f"{target_pdb}.pdb",
                         chain=chain, binder_len=binder_length)

In [16]:
design_model.design(iters=100, hard=False)

0	 model: 0 recycle: 0 n_hard: 0 loss: 2.148 pae: 0.335 plddt: 0.465 con_intra: 1.304 con_inter: 4.091
1	 model: 1 recycle: 0 n_hard: 0 loss: 1.859 pae: 0.331 plddt: 0.353 con_intra: 0.982 con_inter: 3.718
2	 model: 1 recycle: 0 n_hard: 0 loss: 1.755 pae: 0.338 plddt: 0.343 con_intra: 0.918 con_inter: 3.376
6	 model: 1 recycle: 0 n_hard: 0 loss: 1.513 pae: 0.326 plddt: 0.208 con_intra: 0.939 con_inter: 2.978
8	 model: 0 recycle: 0 n_hard: 0 loss: 1.453 pae: 0.331 plddt: 0.137 con_intra: 1.126 con_inter: 2.814
9	 model: 0 recycle: 0 n_hard: 0 loss: 1.345 pae: 0.319 plddt: 0.222 con_intra: 1.062 con_inter: 2.156
12	 model: 0 recycle: 0 n_hard: 0 loss: 1.249 pae: 0.288 plddt: 0.321 con_intra: 0.897 con_inter: 1.663
14	 model: 1 recycle: 0 n_hard: 0 loss: 1.224 pae: 0.260 plddt: 0.332 con_intra: 0.709 con_inter: 1.816
16	 model: 0 recycle: 0 n_hard: 0 loss: 1.208 pae: 0.284 plddt: 0.245 con_intra: 0.904 con_inter: 1.810
17	 model: 0 recycle: 0 n_hard: 0 loss: 1.200 pae: 0.249 plddt: 0.288 

In [17]:
HTML(data=design_model.animate())

In [20]:
# interactive 3D view of the best target + binder complex
design_model.plot_pdb()

In [18]:
# sequence(s) of the best design; for the binder protocol this includes the target residues
design_model.get_seqs()

['MKDETYYIALNMIQNYIIEYNTNKPRKSFVIDSISYDVLKAACKSVIKTNYNEFDIIISRNIDFNVIVTQVLEDKINWGRIITIIAFCAYYSKKVPQYYDGIISEAITDAILSKYRSWFIDQDYWNGIRIYKFNCLLVEPEMAAMMRAMVAEILRELG']

In [19]:
design_model.save_pdb(filename=f"{target_pdb}.binder.pdb")

In [5]:
protocol = "binder" #@param ["fixbb", "hallucination", "binder"] 
seq_mode = "softmax" #@param ["softmax","softmax_gumbel","logits"] 
target_pdb = "4MZK" #@param {type:"string"}
chain = "A" #@param ["A", "B", "C"] {allow-input: true}
binder_length =  19#@param {type:"integer"}

# download the target structure from RCSB (-nc: skip if already present)
!wget -qnc https://files.rcsb.org/view/{target_pdb}.pdb
# free device buffers from any previously built model, then build the binder model
clear_mem()
design_model = mk_design_model(**OPT, protocol=protocol, seq_mode=seq_mode)
# target chain + `binder_length` designed residues
design_model.prep_inputs(pdb_filename=f"{target_pdb}.pdb",
                         chain=chain, binder_len=binder_length)

In [9]:
design_model.design(iters=100, weight={"con_intra":0.0,"plddt":0.0}, hard=False, restart=True)

0	 model: 0 recycle: 0 n_hard: 0 loss: 0.949 pae: 0.373 plddt: 0.527 con_intra: 0.723 con_inter: 2.302
1	 model: 1 recycle: 0 n_hard: 0 loss: 0.780 pae: 0.347 plddt: 0.547 con_intra: 0.709 con_inter: 1.731
6	 model: 1 recycle: 0 n_hard: 0 loss: 0.746 pae: 0.334 plddt: 0.560 con_intra: 0.591 con_inter: 1.645
8	 model: 1 recycle: 0 n_hard: 0 loss: 0.730 pae: 0.334 plddt: 0.513 con_intra: 0.361 con_inter: 1.586
12	 model: 1 recycle: 0 n_hard: 0 loss: 0.707 pae: 0.331 plddt: 0.530 con_intra: 0.400 con_inter: 1.501
16	 model: 1 recycle: 0 n_hard: 0 loss: 0.703 pae: 0.312 plddt: 0.581 con_intra: 0.359 con_inter: 1.563
18	 model: 1 recycle: 0 n_hard: 0 loss: 0.694 pae: 0.319 plddt: 0.576 con_intra: 0.444 con_inter: 1.499
22	 model: 1 recycle: 0 n_hard: 0 loss: 0.692 pae: 0.331 plddt: 0.570 con_intra: 0.371 con_inter: 1.445
25	 model: 1 recycle: 0 n_hard: 0 loss: 0.648 pae: 0.312 plddt: 0.600 con_intra: 0.437 con_inter: 1.344
32	 model: 0 recycle: 0 n_hard: 0 loss: 0.630 pae: 0.345 plddt: 0.57

In [10]:
HTML(data=design_model.animate())