<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/fixbb_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fixed Backbone Design
Note: this notebook was not optimized or thoroughly tested for protein design. It was designed to test backprop functionality.

install

In [None]:
%%bash
# fetch the af_backprop code and its Python dependencies
# (both steps are skipped when the af_backprop directory already exists)
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install biopython dm-haiku ml-collections py3Dmol
fi
# download the AlphaFold parameter archive and unpack it into ./params
# (skipped when the params directory already exists)
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi
# fetch colabfold.py (quiet, no-clobber: an existing copy is kept)
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

Cloning into 'af_backprop'...


import libraries

In [None]:
import sys
sys.path.append('/content/af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol
from IPython.display import HTML

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from utils import *
import colabfold as cf

In [None]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def make_animation(positions, seq, pos_ref=None, line_w=2.0, dpi=100, interval=60):
  '''
  Render a design trajectory as an HTML5 video.

  Every frame shows the structure in the top panel (drawn with
  cf.plot_pseudo_3D) and the matching sequence array in the bottom panel.

  positions = per-frame coordinates; every frame is centred and superimposed
              onto pos_ref before plotting
              (assumes each frame is an (L,3) array - TODO confirm)
  seq       = per-frame sequence arrays; 2D entries are shown transposed as a
              heatmap ("bwr_r", range -1..1), all other entries are shown as
              their argmax over the last axis ("rainbow")
  pos_ref   = reference coordinates used for the alignment and the axis
              limits (defaults to the last frame of positions)
  line_w    = line width forwarded to cf.plot_pseudo_3D
  dpi       = resolution of the figure
  interval  = frame delay handed to matplotlib's ArtistAnimation

  Returns the string produced by ArtistAnimation.to_html5_video().
  Raises ValueError if positions or seq contains no frames.
  '''
  # an empty trajectory (e.g. do_design(..., save_traj=False)) would otherwise
  # fail further down with an unhelpful IndexError
  if len(positions) == 0 or len(seq) == 0:
    raise ValueError("make_animation needs at least one frame in positions and seq")

  def align(P, Q):
    # centre both point sets, then apply the matrix returned by cf.kabsch
    # (presumably the rotation superimposing p onto q - TODO confirm)
    p = P - P.mean(0,keepdims=True)
    q = Q - Q.mean(0,keepdims=True)
    return p @ cf.kabsch(p,q)

  # centre the reference and re-orient it with the matrix cf.kabsch returns
  # for return_v=True
  if pos_ref is None: pos_ref = positions[-1]
  pos_ref = pos_ref - pos_ref.mean(0,keepdims=True)
  pos_ref = pos_ref @ cf.kabsch(pos_ref,pos_ref,return_v=True)

  # superimpose every frame onto the reference
  pos = np.asarray([align(frame_xyz, pos_ref) for frame_xyz in positions])

  # layout: structure panel on the top three quarters, sequence panel below
  fig = plt.figure()
  gs = GridSpec(4,1, figure=fig)
  ax1,ax2 = fig.add_subplot(gs[:3,:]),fig.add_subplot(gs[3:,:])
  fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 0.9, left = 0.1, hspace = 0, wspace = 0)
  fig.set_figwidth(6); fig.set_figheight(6)
  fig.set_dpi(dpi)

  ax2.set_xlabel("positions")
  ax2.set_yticks([])
  if seq[0].ndim == 3: ax2.set_ylabel("sequences")

  # fixed limits (taken from the reference) so the view does not jump between frames
  z_min,z_max = pos_ref[...,2].min(),pos_ref[...,2].max()
  xy_min,xy_max = pos_ref[...,:2].min() - 5, pos_ref[...,:2].max() + 5
  ax1.set_xlim(xy_min, xy_max); ax1.set_ylim(xy_min, xy_max)
  ax1.axis(False)

  # one list of artists per frame
  ims = []
  for x,s in zip(pos,seq):
    artists = [cf.add_text("colored by N→C", ax1),
               cf.plot_pseudo_3D(x, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max)]
    if s.ndim == 2:
      artists.append(ax2.imshow(s.T, animated=True, cmap="bwr_r",vmin=-1, vmax=1))
    else:
      artists.append(ax2.imshow(s.argmax(-1), animated=True, cmap="rainbow"))
    ims.append(artists)

  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  # close this specific figure so the notebook does not also render it as a static image
  plt.close(fig)
  return ani.to_html5_video()

setup model

In [None]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
# NOTE: for denovo-designed proteins we find 1 recycle is enough to predict them
model_config.model.num_recycle = 1
model_config.data.common.num_recycle = 1

# backprop through recycles
# we find adding backprop through all recycles does not seem to help
# but maybe it will in other contexts/problems.
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# dropout (to disable dropout, uncomment the line below)
# model_config = set_dropout(model_config, 0.0)

# setup model
# the runner is built from the model_3_ptm config and parameters
# (data_dir="." presumably resolves to the ./params directory created by the install cell - TODO confirm)
MODEL_PARAMS = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
MODEL_RUNNER = model.RunModel(model_config, MODEL_PARAMS[0], is_training=True)

# load the other model_params (during optimization, we randomly pick which model to use)
# only the keys present in MODEL_RUNNER.params are kept for each of them
# NOTE: this loop rebinds model_name, so after this cell it is "model_5_ptm"
for model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name, '.')
  MODEL_PARAMS.append({k: params[k] for k in MODEL_RUNNER.params.keys()})

# combine model params (for parallel compute)
# every leaf is stacked along a new leading axis, one entry per model in MODEL_PARAMS
MODEL_PARAMS_multi = jax.tree_multimap(lambda *values: jnp.stack(values, axis=0), *MODEL_PARAMS)

setup gradient

In [None]:
def get_grad_fn(inputs, batch):
  '''
  Build the loss-and-gradient function used by do_design().

  inputs = model features that are copied and then overwritten with the
           designed sequence on every call
  batch  = target features that the losses are computed against

  Returns jax.value_and_grad(mod, has_aux=True, argnums=0): a function
  (params, key, model_params, opt) -> ((loss, aux), grad), where grad is
  taken with respect to params.
  '''

  # setup function to get gradients
  def mod(params, key, model_params, opt):
    '''
    params       = {"seq": logits}; 2D for a single sequence, 3D for an MSA
    key          = jax PRNG key
    model_params = parameters passed to MODEL_RUNNER.apply
    opt          = options; "mask" (per position: 1 = use soft_seq(logits),
                   0 = use the raw logits and alanine for aatype) and
                   "ent" (weight of the MSA entropy loss) are both optional

    returns loss, {"losses","outputs","seq","pseudo_seq"}
    '''
    ############################
    # set amino acid sequence
    ############################
    seq_logits = params["seq"]
    if seq_logits.ndim == 3:
      # if MSA, select different sequence to be first at each iteration
      # The row index is drawn from its own key: the original code reused
      # `key` here and again for the model below, which JAX advises against.
      # Single-sequence runs are unaffected because the split only happens here.
      key, swap_key = jax.random.split(key)
      i = jax.random.randint(swap_key,[],0,seq_logits.shape[0])
      seq_logits = seq_logits.at[0].set(seq_logits[i]).at[i].set(seq_logits[0])
    seq = soft_seq(seq_logits)

    if "mask" in opt:
      mask = opt["mask"][:,None]
      pseudo_seq = mask * seq + (1-mask) * seq_logits
      # same mix, but built from the un-swapped logits (this is what gets saved)
      pseudo_seq_save = mask * soft_seq(params["seq"]) + (1-mask) * params["seq"]
    else:
      pseudo_seq = seq
      pseudo_seq_save = soft_seq(params["seq"])

    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod, msa_input=(seq.ndim == 3))

    ####################
    # set sidechains identity
    ####################
    N,L = inputs_mod["aatype"].shape[:2]
    # pad the 20 amino-acid channels to 21; for an MSA only the first row is used
    if seq.ndim == 3:
      aatype = jnp.zeros((N,L,21)).at[...,:20].set(seq[0])
    else:
      aatype = jnp.zeros((N,L,21)).at[...,:20].set(seq)

    if "mask" in opt:
      # positions with mask = 0 are presented to the model as alanine
      ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)
      aatype_ala = jnp.zeros((N,L,21)).at[:].set(ALA)
      aatype_pseudo = mask * aatype + (1-mask) * aatype_ala
    else:
      aatype_pseudo = aatype

    update_aatype(aatype_pseudo, inputs_mod)

    # get output
    outputs = MODEL_RUNNER.apply(model_params, key, inputs_mod)

    # losses
    dgram_loss = get_dgram_loss(batch, outputs, model_config=MODEL_RUNNER.config)
    fape_loss = get_fape_loss(batch, outputs, model_config=MODEL_RUNNER.config)

    # note: rmsd monitored but not used in loss
    # (computed on atom index 1 of every residue - presumably CA, TODO confirm)
    rmsd_loss = jnp_rmsd(batch["all_atom_positions"][:,1,:],
                         outputs["structure_module"]["final_atom_positions"][:,1,:])

    losses = {"rmsd":rmsd_loss,"dgram":dgram_loss,"fape":fape_loss}

    # note: we find dgram loss to be easier to backprop through
    loss = dgram_loss # + fape_loss

    # if MSA add entropy loss
    if seq.ndim == 3 and "ent" in opt:
      # entropy of the column-wise mean profile, averaged over positions
      seq_prf = seq.mean(0)
      ent_loss = -(seq_prf * jnp.log(seq_prf + 1e-8)).sum(-1).mean()
      losses["ent"] = ent_loss
      loss += ent_loss * opt["ent"]

    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    return loss, ({"losses":losses,"outputs":outs,
                   "seq":seq,"pseudo_seq":pseudo_seq_save})

  return jax.value_and_grad(mod, has_aux=True, argnums=0)

design function

In [None]:
def prep_inputs(pdb_filename, chain=None, num_seq=1):
  '''
  Load a target structure and prepare everything do_design() needs.

  pdb_filename = path of the PDB file holding the target
  chain        = chain id forwarded to protein.from_pdb_string
  num_seq      = number of sequences in the input MSA; the target sequence
                 is repeated num_seq times (use > 1 for MSA design)

  Side effect: overwrites the MSA settings in MODEL_RUNNER.config.

  Returns {"seq": one-hot (20 classes) target sequence,
           "inputs": processed model features,
           "batch": target structure features,
           "num_seq": num_seq}
  '''
  # target structure and the frames derived from it
  target = protein.from_pdb_string(pdb_to_string(pdb_filename), chain_id=chain)
  batch = dict(aatype=target.aatype,
               all_atom_positions=target.atom_positions,
               all_atom_mask=target.atom_mask)
  batch.update(all_atom.atom37_to_frames(**batch))

  # sequence of the target, as letters
  starting_sequence = "".join(order_restype[a] for a in target.aatype)
  seq_len = len(starting_sequence)

  # sequence features first, then MSA features (target repeated, no deletions)
  feature_dict = {}
  feature_dict.update(pipeline.make_sequence_features(sequence=starting_sequence,
                                                      description="none",
                                                      num_res=seq_len))
  feature_dict.update(pipeline.make_msa_features(msas=[[starting_sequence] * num_seq],
                                                 deletion_matrices=[[[0] * seq_len] * num_seq]))

  # number of sequences
  MODEL_RUNNER.config.data.eval.max_msa_clusters = num_seq
  MODEL_RUNNER.config.data.common.max_extra_msa = 1
  MODEL_RUNNER.config.data.eval.masked_msa_replace_fraction = 0
  inputs = MODEL_RUNNER.process_features(feature_dict, random_seed=0)

  # for an MSA, mark every row and every entry as present
  if num_seq > 1:
    for mask_name in ("msa_row_mask", "msa_mask"):
      inputs[mask_name] = jnp.ones_like(inputs[mask_name])

  return {"seq":jax.nn.one_hot(target.aatype,20),
          "inputs":inputs,
          "batch":batch,
          "num_seq":num_seq}

In [None]:
def do_design(inputs, outputs=None, iters=200, seed=None,
              ent_weight=0.01, one_hot=True, one_hot_switch=5,
              sample=1, save_traj=True, restart=False):
  '''
  Optimise a sequence (or MSA) against the target described by inputs.

  inputs   = dict returned by prep_inputs()
  outputs  = state dict returned by an earlier call; pass it to continue
             that run instead of starting a new one
  iters    = number of optimisation steps performed by this call
  seed     = seed used when the state is initialised (drawn at random if None)
  ent_weight = value handed to the gradient function as opt["ent"]
  one_hot  = True : initialise with small random logits and mask = 1 at every
                    position; if positions with mask = 0 remain (continued
                    run), one of them is switched to 1 every one_hot_switch steps
             False: initialise with zeros and mask = 0 at every position
  one_hot_switch = number of steps between two switches (see one_hot)
  sample = number of alphafold models to sample gradients from
  (if more than 1, gradients and losses are averaged across models specified)
  save_traj = append the per-step coordinates and pseudo sequence to
              o["xyz"] / o["seq"] (used by make_animation)
  restart  = with outputs given: keep its compiled functions and optimiser,
             but reset the optimisation state

  Returns the state dict. Besides the trajectory it holds "rmsd"/"best_outs",
  the lowest rmsd (and its outputs) seen while mask = 1 at every position.
  Raises ValueError if sample is not between 1 and len(MODEL_PARAMS).
  '''
  # number of available models (leading axis of MODEL_PARAMS_multi)
  num_models = len(MODEL_PARAMS)
  if not 1 <= sample <= num_models:
    raise ValueError(f"sample must be between 1 and {num_models}, got {sample}")

  if seed is None:
    seed = np.random.randint(10000)

  def init_state(o, init_fun):
    o["L"],o["A"] = inputs["seq"].shape
    o["key"] = jax.random.PRNGKey(seed)
    # generator used to pick the position switched to mask = 1; seeding it
    # makes a run with a fixed seed reproducible
    o["rng"] = np.random.RandomState(seed)
    if inputs["num_seq"] > 1:
      seq_shape = (inputs["num_seq"],o["L"],o["A"])
    else:
      seq_shape = (o["L"],o["A"])
    if one_hot:
      seq = 0.01 * jax.random.normal(o["key"],seq_shape)
      o["mask"] = jnp.ones((o["L"],))
    else:
      seq = jnp.zeros(seq_shape)
      o["mask"] = jnp.zeros((o["L"],))
    o["state"] = init_fun({"seq":seq})
    o["k"] = 0
    o["xyz"],o["seq"] = [],[]
    o["rmsd"],o["best_outs"] = np.inf, None
    return o

  if outputs is None:
    o = {}
    # gradient function
    grad_fn = get_grad_fn(inputs["inputs"], inputs['batch'])
    o["grad_fn"] = jax.jit(grad_fn)
    # same function, mapped over a stack of model parameters
    o["grad_fn_multi"] = jax.jit(jax.vmap(grad_fn,(None,None,0,None)))

    # setup optimizer
    o["adam"] = init_fun, update_fun, get_params = adam(step_size=1e-2)
    o = init_state(o, init_fun)

  else:
    o = outputs
    init_fun, update_fun, get_params = o["adam"]
    if restart: o = init_state(o, init_fun)

  # state dicts that do not carry a generator yet get one here
  if "rng" not in o:
    o["rng"] = np.random.RandomState(seed)

  def step(k, state, key, model_params, opt, multi=False):
    if multi:
      (loss, outs), grad = o["grad_fn_multi"](get_params(state), key, model_params, opt)
      # take the mean of gradients and loss
      grad = jax.tree_map(lambda x: x.mean(0), grad)
      outs["losses"] = jax.tree_map(lambda x: x.mean(0), outs["losses"])
      loss = loss.mean(0)

      # use first model
      outs["outputs"] = jax.tree_map(lambda x:x[0], outs["outputs"])
      outs["seq"] = outs["seq"][0]
      outs["pseudo_seq"] = outs["pseudo_seq"][0]
    else:
      (loss, outs), grad = o["grad_fn"](get_params(state), key, model_params, opt)

    # normalise the gradient; the small constant keeps an all-zero gradient
    # from producing NaN, which would otherwise end up in the optimiser state
    grad_norm = jnp.sqrt(jnp.square(grad["seq"]).sum([-1,-2],keepdims=True))
    grad["seq"] = grad["seq"] / (grad_norm + 1e-8)
    state = update_fun(k, grad, state)
    return state, outs

  k = o["k"]
  while k < o["k"] + iters: # number of iterations
    o["key"], subkey = jax.random.split(o["key"])
    # separate keys for choosing the model(s) and for the step itself
    pick_key, step_key = jax.random.split(subkey)
    opt = {"mask":o["mask"],"ent":ent_weight}
    ## sample gradients
    if sample == 1:
      n = int(jax.random.randint(pick_key,[],0,num_models))
      o["state"], outs = step(k, o["state"], step_key, MODEL_PARAMS[n], opt)
    else:
      if sample == num_models:
        model_params_multi = MODEL_PARAMS_multi
      else:
        n = jax.random.choice(pick_key,jnp.arange(num_models),(sample,),replace=False)
        model_params_multi = jax.tree_map(lambda x:x[n],MODEL_PARAMS_multi)

      o["state"], outs = step(k, o["state"], step_key, model_params_multi, opt, multi=True)

    # fraction of positions whose argmax matches the target sequence
    seq = outs["seq"].argmax(-1)
    seq_id = (seq == inputs["seq"].argmax(-1)).mean()
    losses = outs["losses"]

    if (k+1) % 10 == 0:
      losses_print = f'dgram: {losses["dgram"]:.3f} fape: {losses["fape"]:.3f} rmsd: {losses["rmsd"]:.3f}'
      if "ent" in losses: losses_print += f' ent: {losses["ent"]:.3f}'
      print(f'{k+1} {int(o["mask"].sum())} {losses_print} seqid: {seq_id:.3f}')

    # save for animation
    if save_traj:
      o["xyz"].append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
      o["seq"].append(np.asarray(outs["pseudo_seq"]))

    # the best result is only tracked once mask = 1 at every position
    if o["mask"].sum() == o["L"] and losses["rmsd"] < o["rmsd"]:
      o["rmsd"] = losses["rmsd"]
      o["best_outs"] = outs

    if one_hot and (k+1) % one_hot_switch == 0 and o["mask"].sum() < o["L"]:
      # pick random position to flip to one_hot
      open_positions = np.where(np.asarray(o["mask"]) == 0)[0]
      o["mask"] = o["mask"].at[o["rng"].choice(open_positions)].set(1)

    k += 1
  o["k"] = k
  return o

# Single Sequence Design

In [None]:
# target: PDB entry 1QYS, downloaded from RCSB unless the file is already present
example = "1QYS"
!wget -qnc https://files.rcsb.org/view/{example}.pdb

# single-sequence setup (num_seq defaults to 1) for chain A
inputs = prep_inputs(f"{example}.pdb", chain="A")

In [None]:
# run the design with the default settings (iters=200, one_hot=True)
outputs = do_design(inputs)

10 92 dgram: 3.941 fape: 2.406 rmsd: 17.963 seqid: 0.054
20 92 dgram: 3.336 fape: 1.897 rmsd: 9.270 seqid: 0.087
30 92 dgram: 3.336 fape: 1.851 rmsd: 10.994 seqid: 0.065
40 92 dgram: 3.251 fape: 1.855 rmsd: 11.923 seqid: 0.065
50 92 dgram: 3.203 fape: 1.722 rmsd: 11.401 seqid: 0.065
60 92 dgram: 3.201 fape: 1.764 rmsd: 10.564 seqid: 0.054
70 92 dgram: 3.110 fape: 1.654 rmsd: 9.644 seqid: 0.033
80 92 dgram: 2.854 fape: 1.250 rmsd: 3.142 seqid: 0.076
90 92 dgram: 2.829 fape: 1.277 rmsd: 3.355 seqid: 0.076
100 92 dgram: 2.723 fape: 1.324 rmsd: 2.661 seqid: 0.054
110 92 dgram: 2.742 fape: 1.233 rmsd: 3.362 seqid: 0.054
120 92 dgram: 2.712 fape: 1.205 rmsd: 2.760 seqid: 0.065
130 92 dgram: 2.686 fape: 1.265 rmsd: 2.637 seqid: 0.076
140 92 dgram: 2.687 fape: 1.156 rmsd: 2.517 seqid: 0.076
150 92 dgram: 2.714 fape: 1.222 rmsd: 2.478 seqid: 0.065
160 92 dgram: 2.809 fape: 1.057 rmsd: 2.867 seqid: 0.076
170 92 dgram: 2.619 fape: 1.103 rmsd: 2.600 seqid: 0.076
180 92 dgram: 2.602 fape: 1.080 rms

In [None]:
# let's run for another 200 iterations (if rmsd is still high)
# outputs = do_design(inputs, outputs, iters=200)

In [None]:
# animate the saved trajectory, superimposed onto the target coordinates
# (atom index 1 of every residue, the same atoms do_design stores in outputs["xyz"])
HTML(make_animation(outputs["xyz"],
                    outputs["seq"],
                    pos_ref=inputs["batch"]["all_atom_positions"][:,1,:]))

In [None]:
# write out the best design and load it back for display
# (assumes save_pdb writes "tmp.pdb" - TODO confirm against utils)
save_pdb(outputs["best_outs"])
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})

BB = ['C','O','N']
atom_look = {'colorscheme':"WhiteCarbon",'radius':0.3}
# (selection, style) pairs, added on top of the cartoon in this order
extra_styles = [
  # every residue except GLY/PRO: sticks for all atoms outside BB
  ({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
   {'stick':dict(atom_look)}),
  # GLY: a sphere on CA
  ({'and':[{'resn':"GLY"},{'atom':'CA'}]},
   {'sphere':dict(atom_look)}),
  # PRO: sticks for everything except C and O
  ({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
   {'stick':dict(atom_look)}),
]
for selection, style in extra_styles:
  view.addStyle(selection, style)
view.zoomTo()
view.show()

In [None]:
# decode the best design: argmax over the last axis, then map indices to letters
seq = outputs["best_outs"]["seq"].argmax(-1)
# a single sequence is handled as a one-row stack so both cases share the loop
rows = seq[None] if seq.ndim == 1 else seq
for row in rows:
  print("".join(order_restype[a] for a in row))

MNIIIIYYHTGTKKVHKQHHFPTRQQNIRYIREFYRMYREKSKTHHVICVNTNSPGEGQQIMQRIVALHKGYIRTSMKQHNSGNNIHLHFTT


# MSA design
It is EASY to design an adversarial MSA to reproduce any fold or constraint. To avoid the adversarial mode, we add an entropy term to push the sequences to be similar.

In [None]:
# target: PDB entry 6WVS, downloaded from RCSB unless the file is already present
example = "6WVS"
!wget -qnc https://files.rcsb.org/view/{example}.pdb

# to design MSA set num_seq > 1
# (prep_inputs repeats the target sequence num_seq times)
inputs = prep_inputs(f"{example}.pdb", chain="A", num_seq=32)

In [None]:
# entropy-weight schedule: 50 iterations per weight, the weight doubles every round
ent_schedule = [0.01,0.02,0.04,0.08,0.16,0.32,0.64]
# None makes the first round start from a fresh state; later rounds continue it
outputs = None
for ent_weight in ent_schedule:
  outputs = do_design(inputs, outputs, iters=50, ent_weight=ent_weight)

# Experimental
For more complex targets, we find directly optimizing a single one_hot encoded sequence (even with gumbel-st) to be very challenging. Instead, we start with a continuous representation and then, in a second round of optimization, switch one residue at a time to one_hot.

In [None]:
# target: PDB entry 1QJG, downloaded from RCSB unless the file is already present
example = "1QJG"
!wget -qnc https://files.rcsb.org/view/{example}.pdb

# single-sequence setup (num_seq defaults to 1) for chain A
inputs = prep_inputs(f"{example}.pdb", chain="A")

**design adversarial vector**

In [None]:
# one_hot=False: the mask starts at 0 everywhere, so the raw sequence
# parameters are fed to the model and every residue type is set to alanine
outputs = do_design(inputs, one_hot=False)
#
# NOTE: if you are designing proteins of the same length and number of sequences,
# you can speed up the calculation by reusing the previously compiled model
# and adding a "restart" flag
#
# outputs = do_design(inputs, outputs, one_hot=False, restart=True)

10 0 3.889 2.846 19.710 0.056
20 0 3.742 3.433 26.450 0.080
30 0 3.321 2.694 19.280 0.080
40 0 3.140 2.108 15.102 0.096
50 0 2.934 2.109 17.310 0.104
60 0 2.758 1.826 15.434 0.112
70 0 2.608 1.364 5.182 0.080
80 0 2.564 1.464 7.075 0.080
90 0 2.376 1.229 3.618 0.096
100 0 2.219 1.104 3.450 0.112
110 0 2.227 1.005 3.368 0.112
120 0 2.044 0.880 2.862 0.120
130 0 1.959 0.763 2.613 0.136
140 0 1.972 0.839 2.980 0.128
150 0 1.884 0.671 2.007 0.144
160 0 1.660 0.593 1.522 0.160
170 0 1.644 0.699 1.588 0.152
180 0 1.628 0.574 1.398 0.184
190 0 1.621 0.586 1.386 0.184
200 0 1.553 0.524 1.270 0.200


**design one_hot**

In [None]:
# one_hot_switch=5, every 5 iterations switch one position to one_hot
# iters: 5 iterations for each of the L positions, plus 100 more at the end
# passing `outputs` continues the run from the previous cell
outputs = do_design(inputs, outputs, iters=outputs["L"]*5+100, one_hot=True, one_hot_switch=5)

210 1 1.430 0.520 1.066 0.224
220 3 1.381 0.567 1.226 0.216
230 5 1.352 0.416 1.082 0.216
240 7 1.454 0.473 1.272 0.224
250 9 1.314 0.548 1.049 0.240
260 11 1.294 0.402 1.231 0.264
270 13 1.257 0.371 0.925 0.256
280 15 1.272 0.399 1.037 0.264
290 17 1.251 0.480 0.942 0.272
300 19 1.158 0.378 0.749 0.288
310 21 1.244 0.376 0.816 0.304
320 23 1.185 0.334 0.805 0.304
330 25 1.180 0.305 0.810 0.304
340 27 1.163 0.407 0.905 0.312
350 29 1.217 0.289 0.679 0.296
360 31 2.657 1.390 13.724 0.296
370 33 1.151 0.358 0.917 0.304
380 35 1.127 0.365 0.739 0.304
390 37 1.131 0.350 0.676 0.304
400 39 1.179 0.278 0.607 0.296
410 41 1.648 0.458 1.806 0.304
420 43 1.277 0.297 0.723 0.312
430 45 1.149 0.343 0.749 0.312
440 47 1.956 0.569 2.314 0.304
450 49 1.148 0.481 0.745 0.296
460 51 1.210 0.320 0.796 0.280
470 53 2.235 1.172 10.291 0.296
480 55 1.094 0.279 0.642 0.288
490 57 1.110 0.269 0.647 0.296
500 59 1.155 0.470 0.526 0.312
510 61 1.152 0.284 0.453 0.320
520 63 1.093 0.295 0.510 0.320
530 65 1.09

In [None]:
# animate the saved trajectory, superimposed onto the target coordinates
# (atom index 1 of every residue, the same atoms do_design stores in outputs["xyz"])
HTML(make_animation(outputs["xyz"],
                    outputs["seq"],
                    pos_ref=inputs["batch"]["all_atom_positions"][:,1,:]))

In [None]:
# lowest rmsd that do_design recorded while the mask was 1 at every position
outputs["rmsd"]

DeviceArray(2.0402088, dtype=float32)

In [None]:
# write out the best design and load it back for display
# (assumes save_pdb writes "tmp.pdb" - TODO confirm against utils)
save_pdb(outputs["best_outs"])
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})

BB = ['C','O','N']
atom_look = {'colorscheme':"WhiteCarbon",'radius':0.3}
# (selection, style) pairs, added on top of the cartoon in this order
extra_styles = [
  # every residue except GLY/PRO: sticks for all atoms outside BB
  ({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
   {'stick':dict(atom_look)}),
  # GLY: a sphere on CA
  ({'and':[{'resn':"GLY"},{'atom':'CA'}]},
   {'sphere':dict(atom_look)}),
  # PRO: sticks for everything except C and O
  ({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
   {'stick':dict(atom_look)}),
]
for selection, style in extra_styles:
  view.addStyle(selection, style)
view.zoomTo()
view.show()

In [None]:
# decode the best design: argmax over the last axis, then map indices to letters
seq = outputs["best_outs"]["seq"].argmax(-1)
# a single sequence is handled as a one-row stack so both cases share the loop
rows = seq[None] if seq.ndim == 1 else seq
for row in rows:
  print("".join(order_restype[a] for a in row))

MVTAERMLRVVRRFVRYMNRFDVDAIVSLFRPDAKINPHAGTTPVETRDQIRDYWAMMLMYPYQWAITWPPEATNNHATACCTMTHAVAGDVYQYCYTMTMLFNATGRVDYMNWYYTPESIHPGE
