<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/fixbb_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fixed backbone Design
Note: this notebook was not optimized or thoroughly tested for protein design. It was designed to test backprop functionality.

install

In [1]:
%%bash
# Make a pipeline fail if any stage fails, so a broken curl download is not
# masked by tar's exit status.
set -o pipefail

# Fetch the af_backprop code and the Python dependencies (first run only).
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install biopython dm-haiku ml-collections py3Dmol
fi

# Fetch the AlphaFold parameters (first run only).
# The archive is unpacked into a scratch directory that is renamed to params/
# only on success; otherwise an interrupted download would leave a partial
# params/ behind and the directory check would skip the download on every rerun.
if [ ! -d params ]; then
  rm -rf params.partial
  mkdir params.partial
  if curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params.partial; then
    mv params.partial params
  else
    rm -rf params.partial
    echo "ERROR: failed to download AlphaFold parameters; re-run this cell" >&2
    exit 1
  fi
fi

# colabfold.py is imported later as `cf` (-nc: skip if the file already exists).
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

Cloning into 'af_backprop'...


import libraries

In [2]:
import sys
sys.path.append('/content/af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol
from IPython.display import HTML

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from utils import *
import colabfold as cf

In [3]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def make_animation(positions, seq, pos_ref=None, line_w=2.0, dpi=100, interval=60):
  """Render a trajectory of coordinates and sequence arrays as an HTML5 video.

  Args:
    positions: per-frame coordinate arrays; x, y, z are read from the last axis.
    seq: per-frame 2D arrays, drawn transposed as an image under the structure.
    pos_ref: coordinates every frame is aligned to; defaults to the last frame.
    line_w: line width forwarded to cf.plot_pseudo_3D.
    dpi: figure resolution.
    interval: delay between frames, forwarded to ArtistAnimation.

  Returns:
    The animation encoded by ArtistAnimation.to_html5_video().
  """

  def _superimpose(mobile, target):
    # centre both point sets, then apply the matrix returned by cf.kabsch
    mobile_c = mobile - mobile.mean(0, keepdims=True)
    target_c = target - target.mean(0, keepdims=True)
    rotation = cf.kabsch(mobile_c, target_c)
    return mobile_c @ rotation

  # reference frame: centre it, then re-orient it with the matrix cf.kabsch
  # returns when called on itself with return_v=True
  reference = positions[-1] if pos_ref is None else pos_ref
  reference = reference - reference.mean(0, keepdims=True)
  reference = reference @ cf.kabsch(reference, reference, return_v=True)

  # put every frame into the reference frame
  aligned = np.asarray([_superimpose(frame, reference) for frame in positions])

  # layout: top three quarters for the structure, bottom quarter for the sequence
  fig = plt.figure()
  grid = GridSpec(4, 1, figure=fig)
  ax_struct = fig.add_subplot(grid[:3, :])
  ax_seq = fig.add_subplot(grid[3:, :])
  fig.subplots_adjust(top=0.90, bottom=0.10, right=0.9, left=0.1, hspace=0, wspace=0)
  fig.set_figwidth(5)
  fig.set_figheight(6)
  fig.set_dpi(dpi)

  ax_seq.set_xlabel("positions")
  ax_seq.set_yticks([])

  # fixed axis limits / depth range taken from the reference so frames are comparable
  z_lo, z_hi = reference[..., 2].min(), reference[..., 2].max()
  xy_lo = reference[..., :2].min() - 5
  xy_hi = reference[..., :2].max() + 5
  ax_struct.set_xlim(xy_lo, xy_hi)
  ax_struct.set_ylim(xy_lo, xy_hi)
  ax_struct.axis(False)

  # one list of artists per frame
  frames = [
      [cf.add_text("colored by N→C", ax_struct),
       cf.plot_pseudo_3D(coords, ax=ax_struct, line_w=line_w, zmin=z_lo, zmax=z_hi),
       ax_seq.imshow(frame_seq.T, animated=True, cmap="bwr_r", vmin=-1, vmax=1)]
      for coords, frame_seq in zip(aligned, seq)
  ]

  ani = animation.ArtistAnimation(fig, frames, blit=True, interval=interval)
  plt.close()
  return ani.to_html5_video()

setup model

In [4]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
# NOTE: for denovo-designed proteins we find 1 recycle is enough to predict them
model_config.model.num_recycle = 1
model_config.data.common.num_recycle = 1

# backprop through recycles
# we find adding backprop through all recycles does not seem to help
# but maybe it will in other contexts/problems.
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences
NUM = 1
model_config.data.eval.max_msa_clusters = NUM
model_config.data.common.max_extra_msa = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout (to disable dropout, uncomment the line below)
# model_config = set_dropout(model_config, 0.0)

# setup model
MODEL_PARAMS = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
MODEL_RUNNER = model.RunModel(model_config, MODEL_PARAMS[0], is_training=True)

# load the other model_params (during optimization, we randomly pick which model to use)
# NOTE: a dedicated loop variable is used so that `model_name` keeps naming the
# model whose config was loaded above instead of ending up as "model_5_ptm".
for other_model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name=other_model_name, data_dir=".")
  # keep only the entries MODEL_RUNNER holds, so every dict has the same keys for stacking
  MODEL_PARAMS.append({k: params[k] for k in MODEL_RUNNER.params.keys()})

# combine model params (for parallel compute): stack each leaf along a new leading axis
MODEL_PARAMS_multi = jax.tree_multimap(lambda *values: jnp.stack(values, axis=0), *MODEL_PARAMS)

setup gradient

In [32]:
def get_grad_fn(inputs, batch):
  """Build the function that evaluates the design loss and its gradient.

  Args:
    inputs: feature dict for MODEL_RUNNER.apply; shallow-copied on every call
      before the sequence features are overwritten.
    batch: reference structure dict; passed to the loss helpers and read for
      "all_atom_positions".

  Returns:
    jax.value_and_grad of the inner loss function (has_aux=True, argnums=0).
    Called as fn(params, key, model_params, opt) it returns
    ((loss, aux), grads), where aux holds "losses", "outputs", "seq" and
    "pseudo_seq".
  """

  def _loss_fn(params, key, model_params, opt):
    has_mask = "mask" in opt

    ############################
    # amino acid sequence
    ############################
    logits = params["seq"]
    seq = soft_seq(logits)

    if has_mask:
      # per-position blend: mask=1 uses soft_seq output, mask=0 uses the raw logits
      mask = opt["mask"][:, None]
      pseudo_seq = mask * seq + (1 - mask) * logits
    else:
      pseudo_seq = seq

    feats = inputs.copy()
    update_seq(pseudo_seq, feats)

    ####################
    # sidechain identity
    ####################
    n_seq, n_res = feats["aatype"].shape[:2]
    # 21 channels: the first 20 are filled from seq, the last stays zero
    aatype = jnp.zeros((n_seq, n_res, 21)).at[..., :20].set(seq)

    if has_mask:
      # positions with mask=0 are given alanine as their residue type
      ala = jax.nn.one_hot(residue_constants.restype_order["A"], 21)
      aatype_ala = jnp.zeros((n_seq, n_res, 21)).at[:].set(ala)
      aatype_pseudo = mask * aatype + (1 - mask) * aatype_ala
    else:
      aatype_pseudo = aatype

    update_aatype(aatype_pseudo, feats)

    # run the model
    outputs = MODEL_RUNNER.apply(model_params, key, feats)
    structure = outputs["structure_module"]

    # losses
    dgram_loss = get_dgram_loss(batch, outputs, model_config=MODEL_RUNNER.config)
    fape_loss = get_fape_loss(batch, outputs, model_config=MODEL_RUNNER.config)

    # rmsd is monitored only, it is not part of the optimized loss
    # (atom index 1 is compared -- presumably CA in the atom37 layout; confirm)
    rmsd_loss = jnp_rmsd(batch["all_atom_positions"][:, 1, :],
                         structure["final_atom_positions"][:, 1, :])

    # note: we find the dgram loss to be easier to backprop through
    loss = dgram_loss # + fape_loss

    aux = {
        "losses": {"rmsd": rmsd_loss, "dgram": dgram_loss, "fape": fape_loss},
        "outputs": {"final_atom_positions": structure["final_atom_positions"],
                    "final_atom_mask": structure["final_atom_mask"]},
        "seq": seq,
        "pseudo_seq": pseudo_seq,
    }
    return loss, aux

  return jax.value_and_grad(_loss_fn, has_aux=True, argnums=0)

design function

In [12]:
def prep_inputs(pdb_filename, chain=None):
  """Load a PDB file and build everything the design loop needs.

  Args:
    pdb_filename: path of the PDB file to read.
    chain: chain id forwarded to protein.from_pdb_string (None by default).

  Returns:
    dict with
      "seq":    one-hot (20 classes) encoding of the first row of the processed "aatype",
      "inputs": features returned by MODEL_RUNNER.process_features,
      "batch":  aatype / atom positions / atom mask of the structure plus the
                entries added by all_atom.atom37_to_frames.
  """
  # parse the structure
  pdb_str = pdb_to_string(pdb_filename)
  protein_obj = protein.from_pdb_string(pdb_str, chain_id=chain)

  # reference batch (used by the losses)
  batch = {
      'aatype': protein_obj.aatype,
      'all_atom_positions': protein_obj.atom_positions,
      'all_atom_mask': protein_obj.atom_mask,
  }
  batch.update(all_atom.atom37_to_frames(**batch))

  # model features from the sequence found in the structure (single-sequence MSA)
  starting_sequence = "".join(order_restype[a] for a in protein_obj.aatype)
  n_res = len(starting_sequence)
  seq_features = pipeline.make_sequence_features(sequence=starting_sequence,
                                                 description="none",
                                                 num_res=n_res)
  msa_features = pipeline.make_msa_features(msas=[[starting_sequence]],
                                            deletion_matrices=[[[0] * n_res]])
  feature_dict = {**seq_features, **msa_features}
  inputs = MODEL_RUNNER.process_features(feature_dict, random_seed=0)

  # one-hot encode the input sequence
  wt_seq = jax.nn.one_hot(inputs["aatype"][0], 20)
  return {"seq": wt_seq, "inputs": inputs, "batch": batch}

In [77]:
def do_design(inputs, outputs=None, iters=200, seed=None, one_hot=True, one_hot_switch=5,
              sample_gradients=True, save_traj=True, restart=False):
  '''
  Optimize the sequence parameters against the target structure.

  inputs: dict returned by prep_inputs ("seq", "inputs", "batch").
  outputs: state dict returned by a previous call; when given, the compiled
    gradient functions and optimizer are reused and (unless restart=True)
    optimization continues from the stored state.
  iters: number of optimization steps to run in this call.
  seed: seed for the JAX PRNG key; only used when a new state is initialized
    (first call or restart=True). Drawn at random when None.
  one_hot: on initialization, True starts with an all-ones mask and small random
    logits, False with an all-zeros mask and zero logits. While True, one masked-out
    position is switched on every `one_hot_switch` steps until none are left.
  one_hot_switch: number of steps between switching positions to one_hot.
  sample_gradients=True, during optimization gradients are sampled from one of the
    models in MODEL_PARAMS (picked at random each step)
  sample_gradients=False, all models are used and gradients are averaged
  save_traj: append per-step coordinates and pseudo-sequence to o["xyz"], o["seq"].
  restart: re-initialize the optimization state of `outputs` (keeps compiled functions).

  Returns the state dict. o["best_outs"] holds the lowest-rmsd step seen while the
  mask was all ones, and stays None if the mask never became all ones.
  '''
  if seed is None:
    seed = np.random.randint(10000)

  def init_state(o, init_fun):
    o["L"],o["A"] = inputs["seq"].shape
    # separate keys for parameter initialization and for the optimization loop
    o["key"], init_key = jax.random.split(jax.random.PRNGKey(seed))
    if one_hot:
      o["mask"] = jnp.ones((o["L"],))
      o["state"] = init_fun({"seq":0.01*jax.random.normal(init_key,(o["L"],o["A"]))})
    else:
      o["mask"] = jnp.zeros((o["L"],))
      o["state"] = init_fun({"seq":jnp.zeros((o["L"],o["A"]))})
    o["k"] = 0
    o["xyz"],o["seq"] = [],[]
    o["rmsd"],o["best_outs"] = np.inf, None
    return o

  if outputs is None:
    o = {}
    # gradient function (single model, and vmapped over the stacked model params)
    grad_fn = get_grad_fn(inputs["inputs"], inputs['batch'])
    o["grad_fn"] = jax.jit(grad_fn)
    o["grad_fn_multi"] = jax.jit(jax.vmap(grad_fn,(None,None,0,None)))

    # setup optimizer
    o["adam"] = init_fun, update_fun, get_params = adam(step_size=1e-2)
    o = init_state(o, init_fun)

  else:
    o = outputs
    init_fun, update_fun, get_params = o["adam"]
    if restart: o = init_state(o, init_fun)

  def step(k, state, key, model_params, opt, multi=False):
    if multi:
      (loss, outs), grad = o["grad_fn_multi"](get_params(state), key, model_params, opt)
      # take the mean of gradients and loss
      grad = jax.tree_map(lambda x: x.mean(0), grad)
      outs["losses"] = jax.tree_map(lambda x: x.mean(0), outs["losses"])
      loss = loss.mean(0)

      # use first model
      outs["outputs"] = jax.tree_map(lambda x:x[0], outs["outputs"])
      outs["seq"] = outs["seq"][0]
      outs["pseudo_seq"] = outs["pseudo_seq"][0]
    else:
      (loss, outs), grad = o["grad_fn"](get_params(state), key, model_params, opt)

    # normalize the gradient to unit norm; the small constant keeps an all-zero
    # gradient from turning into NaN and poisoning the optimizer state
    grad_norm = jnp.sqrt(jnp.square(grad["seq"]).sum())
    grad["seq"] = grad["seq"] / (grad_norm + 1e-7)
    state = update_fun(k, grad, state)
    return state, outs

  num_models = len(MODEL_PARAMS)
  k = o["k"]
  last_k = o["k"] + iters
  while k < last_k: # number of iterations
    # independent keys for picking the model and for running it
    o["key"], model_key, run_key = jax.random.split(o["key"], 3)
    opt = {"mask":o["mask"]}
    if sample_gradients:
      ## sample gradients from one of the models
      n = int(jax.random.randint(model_key,[],0,num_models))
      o["state"], outs = step(k, o["state"], run_key, MODEL_PARAMS[n], opt)
    else:
      ## take mean gradient across all models
      o["state"], outs = step(k, o["state"], run_key, MODEL_PARAMS_multi, opt, multi=True)

    seq = outs["seq"].argmax(-1)
    seq_id = (seq == inputs["seq"].argmax(-1)).mean()
    losses = outs["losses"]

    if (k+1) % 10 == 0:
      print(f'{k+1} {int(o["mask"].sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

    # save for animation
    if save_traj:
      o["xyz"].append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
      o["seq"].append(np.asarray(outs["pseudo_seq"]))

    # track the best step, but only once every position is switched on
    if o["mask"].sum() == o["L"] and losses["rmsd"] < o["rmsd"]:
      o["rmsd"] = losses["rmsd"]
      o["best_outs"] = outs

    if one_hot and (k+1) % one_hot_switch == 0 and o["mask"].sum() < o["L"]:
      # pick random position to flip to one_hot; drawn from the JAX key stored in
      # the state (not numpy's global RNG) so a fixed `seed` reproduces the run
      candidates = np.where(np.asarray(o["mask"]) == 0)[0]
      o["key"], flip_key = jax.random.split(o["key"])
      pick = int(jax.random.randint(flip_key,[],0,len(candidates)))
      o["mask"] = o["mask"].at[int(candidates[pick])].set(1)

    k += 1
  o["k"] = k
  return o

example

In [13]:
# PDB id of the example target; the file is fetched from RCSB (-nc: skip if it already exists)
example = "1QYS"
!wget -qnc https://files.rcsb.org/view/{example}.pdb
# build the model features and the reference batch from chain A
inputs = prep_inputs(f"{example}.pdb", chain="A")

In [78]:
# do it: run the design with the default settings (iters=200, one_hot=True)
# the printed columns are: step, number of mask positions set, dgram, fape, rmsd, sequence identity
outputs = do_design(inputs)

10 92 3.391 1.791 11.168 0.043
20 92 3.658 2.253 15.590 0.065
30 92 3.340 1.925 13.170 0.065
40 92 3.408 1.964 11.945 0.054
50 92 3.322 1.730 7.258 0.087
60 92 3.242 1.644 10.188 0.087
70 92 3.194 1.627 6.153 0.098
80 92 3.190 1.672 9.086 0.065
90 92 3.065 1.563 9.236 0.076
100 92 3.048 1.387 5.568 0.076
110 92 2.961 1.470 7.827 0.076
120 92 2.842 1.437 2.965 0.141
130 92 2.838 1.250 3.069 0.152
140 92 2.790 1.124 7.794 0.109
150 92 2.810 1.166 3.108 0.120
160 92 2.876 1.209 3.214 0.109
170 92 2.732 1.215 2.783 0.109
180 92 2.933 1.078 4.245 0.120
190 92 2.716 0.812 2.741 0.130
200 92 2.582 0.834 2.678 0.141


In [None]:
# let's run for another 200 iterations (if rmsd is still high)
# outputs = do_design(inputs, outputs, iters=200)

In [83]:
# animate the saved trajectory, aligned to atom index 1 of the target structure
HTML(make_animation(outputs["xyz"],
                    outputs["seq"],
                    pos_ref=inputs["batch"]["all_atom_positions"][:,1,:]))

In [84]:
# export the best design (the viewer below reads it back from tmp.pdb --
# presumably the file save_pdb writes by default; confirm in utils)
save_pdb(outputs["best_outs"])

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"), 'pdb')
view.setStyle({'cartoon': {}})

BB = ['C', 'O', 'N']

# residues other than GLY/PRO: sticks for every atom not listed in BB
view.addStyle(
    {'and': [{'resn': ["GLY", "PRO"], 'invert': True},
             {'atom': BB, 'invert': True}]},
    {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}},
)
# GLY: a sphere on the CA atom
view.addStyle(
    {'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
    {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}},
)
# PRO: sticks for every atom except C and O
view.addStyle(
    {'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
    {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}},
)

view.zoomTo()
view.show()

In [85]:
# decode the best design: highest-scoring residue index per position -> string
seq = outputs["best_outs"]["seq"].argmax(-1)
print("".join(order_restype[a] for a in seq))

MKYIHICNSNTTKVVCKHYQCTTTAQKWTILRKMVKKFKTMGGSCVHIIIYTTTTTTVTNIMEIMKKLFNKHNLTHKHKYPTNPHPYIVWCT


# Experimental
For more complex targets, we find directly optimizing a single one_hot encoded sequence (even with gumbel-st) to be very challenging. Instead, we start with a continuous representation and then, in a second round of optimization, switch one residue at a time to one_hot.

**design adversarial vector**

In [65]:
# first round: continuous optimization (one_hot=False initializes the mask to all zeros)
# fresh-start alternative: outputs = do_design(inputs, one_hot=False)
outputs = do_design(inputs, outputs, one_hot=False, restart=True) # restart=True resets the state; we'll reuse the previously compiled model

10 0 4.990 5.077 41.531 0.011
20 0 5.711 3.746 27.049 0.011
30 0 3.465 2.238 18.453 0.022
40 0 3.809 2.333 16.223 0.022
50 0 3.298 2.328 17.460 0.033
60 0 3.300 1.846 10.243 0.043
70 0 3.256 1.697 6.598 0.033
80 0 3.082 1.717 8.262 0.033
90 0 3.064 1.552 7.098 0.033
100 0 2.805 1.183 5.143 0.065
110 0 2.478 0.853 1.837 0.065
120 0 2.325 0.804 1.816 0.054
130 0 2.149 0.611 1.457 0.065
140 0 2.033 0.407 1.342 0.087
150 0 1.849 0.358 1.014 0.120
160 0 1.838 0.359 1.140 0.141
170 0 1.764 0.334 0.957 0.174
180 0 1.991 0.362 0.970 0.174
190 0 1.686 0.337 0.780 0.185
200 0 1.660 0.320 0.819 0.185


**design one_hot**

In [67]:
# second round: continue from the state in `outputs` (restart is left at its default, False)
# one_hot_switch=5, every 5 iterations switch one position to one_hot
outputs = do_design(inputs, outputs, iters=500, one_hot=True, one_hot_switch=5)

210 1 1.618 0.304 0.750 0.207
220 3 1.543 0.307 0.736 0.207
230 5 1.491 0.353 0.654 0.207
240 7 1.470 0.310 0.667 0.228
250 9 1.451 0.370 0.611 0.217
260 11 1.423 0.265 0.660 0.239
270 13 1.413 0.263 0.659 0.250
280 15 1.365 0.247 0.592 0.228
290 17 1.377 0.337 0.565 0.239
300 19 1.302 0.235 0.535 0.228
310 21 1.344 0.218 0.498 0.207
320 23 1.346 0.230 0.545 0.196
330 25 1.319 0.245 0.526 0.196
340 27 1.415 0.291 0.597 0.196
350 29 1.455 0.246 0.561 0.174
360 31 1.295 0.234 0.542 0.196
370 33 1.337 0.233 0.563 0.196
380 35 1.382 0.235 0.625 0.207
390 37 1.345 0.235 0.552 0.207
400 39 1.989 0.402 1.519 0.207
410 41 2.259 0.395 1.353 0.217
420 43 1.309 0.228 0.568 0.239
430 45 1.755 0.284 0.654 0.239
440 47 1.376 0.251 0.578 0.228
450 49 1.914 0.558 1.259 0.228
460 51 1.352 0.280 0.504 0.228
470 53 1.494 0.304 0.741 0.228
480 55 1.370 0.316 0.562 0.217
490 57 1.356 0.247 0.657 0.239
500 59 1.386 0.245 0.574 0.228
510 61 1.315 0.243 0.547 0.228
520 63 1.387 0.279 0.606 0.217
530 65 1.395 

In [68]:
# animate the saved trajectory, aligned to atom index 1 of the target structure
HTML(make_animation(outputs["xyz"],
                    outputs["seq"],
                    pos_ref=inputs["batch"]["all_atom_positions"][:,1,:]))

In [69]:
# export the best design (the viewer below reads it back from tmp.pdb --
# presumably the file save_pdb writes by default; confirm in utils)
save_pdb(outputs["best_outs"])

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"), 'pdb')
view.setStyle({'cartoon': {}})

BB = ['C', 'O', 'N']

# residues other than GLY/PRO: sticks for every atom not listed in BB
view.addStyle(
    {'and': [{'resn': ["GLY", "PRO"], 'invert': True},
             {'atom': BB, 'invert': True}]},
    {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}},
)
# GLY: a sphere on the CA atom
view.addStyle(
    {'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
    {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}},
)
# PRO: sticks for every atom except C and O
view.addStyle(
    {'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
    {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}},
)

view.zoomTo()
view.show()

In [70]:
# decode the best design: highest-scoring residue index per position -> string
seq = outputs["best_outs"]["seq"].argmax(-1)
print("".join(order_restype[a] for a in seq))

RTIIRLHLEDSGMMLYMLHVVDSWAAWDELMKAYYAMVMAMNCSNVTISVTTFMQAEAREIAEMLLAFIAAAGYTETNVNFLGPVVVVSSTK
