<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/fixbb_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Fixed backbone Design
Note: this notebook was not optimized or thoroughly tested for protein design. It was designed to test backprop functionality.

install

In [1]:
%%bash
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install biopython dm-haiku ml-collections py3Dmol
fi
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

Cloning into 'af_backprop'...


import libraries

In [2]:
import sys
sys.path.append('/content/af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol
from IPython.display import HTML

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from utils import *
import colabfold as cf

In [3]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def make_animation(positions, seq, pos_ref=None, line_w=2.0, dpi=100, interval=60):

  def align(P, Q):
    p = P - P.mean(0,keepdims=True)
    q = Q - Q.mean(0,keepdims=True)
    return p @ cf.kabsch(p,q)

  if pos_ref is None: pos_ref = positions[-1]
  pos_ref = pos_ref - pos_ref.mean(0,keepdims=True)
  pos_ref = pos_ref @ cf.kabsch(pos_ref,pos_ref,return_v=True)

  new_positions = []
  for i in range(len(positions)):
    new_positions.append(align(positions[i],pos_ref))
  pos = np.asarray(new_positions)

  fig = plt.figure()
  gs = GridSpec(4,1, figure=fig)
  ax1,ax2 = fig.add_subplot(gs[:3,:]),fig.add_subplot(gs[3:,:])
  fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 0.9, left = 0.1, hspace = 0, wspace = 0)
  fig.set_figwidth(5);fig.set_figheight(6)
  fig.set_dpi(dpi)

  ax2.set_xlabel("positions")
  ax2.set_yticks([])

  z_min,z_max = pos_ref[...,2].min(),pos_ref[...,2].max()
  xy_min,xy_max = pos_ref[...,:2].min() - 5, pos_ref[...,:2].max() + 5
  for ax in [ax1]:
    ax.set_xlim(xy_min, xy_max); ax.set_ylim(xy_min, xy_max)
    ax.axis(False)

  ims = []
  for k,(x,s) in enumerate(zip(pos,seq)):
    ims.append([cf.add_text("colored by N→C", ax1),
                cf.plot_pseudo_3D(x, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max),
                ax2.imshow(s.T, animated=True, cmap="bwr_r",vmin=-1, vmax=1)])
    
  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  plt.close()
  return ani.to_html5_video()

setup model

In [4]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
# NOTE: for denovo-designed proteins we find 1 recycle is enough to predict them
model_config.model.num_recycle = 1
model_config.data.common.num_recycle = 1

# backprop through recycles
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences
model_config.data.common.max_extra_msa = 1
model_config.data.eval.max_msa_clusters = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout
# model_config = set_dropout(model_config, 0.0)

# setup model
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# load the other model_params (during optimization, we randomly pick which model to use)
for model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name, '.')
  model_params.append({k: params[k] for k in model_runner.params.keys()})


example

In [5]:
# setup inputs
example = "1QYS"
!wget -qnc https://files.rcsb.org/view/{example}.pdb
protein_obj = protein.from_pdb_string(pdb_to_string(f"{example}.pdb"))

batch = {'aatype': protein_obj.aatype,
          'all_atom_positions': protein_obj.atom_positions,
          'all_atom_mask': protein_obj.atom_mask}
batch.update(all_atom.atom37_to_frames(**batch))

query_sequence = "".join([order_restype[a] for a in protein_obj.aatype])
starting_sequence = query_sequence

# one_hot_encode
feature_dict = {
    **pipeline.make_sequence_features(sequence=starting_sequence,description="none",num_res=len(starting_sequence)),
    **pipeline.make_msa_features(msas=[[starting_sequence]],deletion_matrices=[[[0]*len(starting_sequence)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)
wt_seq = jax.nn.one_hot(inputs["aatype"][0],20)

setup gradient

In [6]:
def get_grad_fn(model_runner, inputs):

  def mod(params, key, model_params, opt):
    ############################
    # set amino acid sequence
    ############################
    seq_logits = params["seq"]
    seq = soft_seq(seq_logits)
    
    mask = opt["mask"][:,None]
    pseudo_seq = mask * seq + (1-mask) * seq_logits

    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod)
    
    ####################
    # set sidechains identity
    ####################
    N,L = inputs_mod["aatype"].shape[:2]
    ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)

    aatype = jnp.zeros((N,L,21)).at[...,:20].set(seq)
    aatype_ala = jnp.zeros((N,L,21)).at[:].set(ALA)
    aatype_pseudo = mask * aatype + (1-mask) * aatype_ala
    update_aatype(aatype_pseudo, inputs_mod)
    
    # get output
    outputs = model_runner.apply(model_params, key, inputs_mod)
            
    # losses
    dgram_loss = get_dgram_loss(batch, outputs, model_config=model_runner.config)
    fape_loss = get_fape_loss(batch, outputs, model_config=model_runner.config)

    # we are just monitoring rmsd, but it's not used in loss
    rmsd_loss = jnp_rmsd(protein_obj.atom_positions[:,1,:],
                         outputs["structure_module"]["final_atom_positions"][:,1,:])

    loss = dgram_loss # + fape_loss
    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    return loss, ({"losses":{"rmsd":rmsd_loss,
                             "dgram":dgram_loss,
                             "fape":fape_loss},
                   "outputs":outs,"seq":seq,"pseudo_seq":pseudo_seq})
  
  return jax.value_and_grad(mod, has_aux=True, argnums=0)

where the magic happens

In [7]:
# gradient function
grad_fn = jax.jit(get_grad_fn(model_runner, inputs))

# stack model params
model_params_multi = jax.tree_multimap(lambda *values: jnp.stack(values, axis=0), *model_params)
grad_fn_multi = jax.jit(jax.vmap(grad_fn,(None,None,0,None)))

In [8]:
init_fun, update_fun, get_params = adam(step_size=1e-2)
def step(i, state, key, model_params, opt, multi=False):

  if multi:
    (loss, outs), grad = grad_fn_multi(get_params(state), key, model_params, opt)
    grad = jax.tree_map(lambda x: x.mean(0), grad)
    outs["losses"] = jax.tree_map(lambda x: x.mean(0), outs["losses"])
    loss = loss.mean(0)
  else:
    (loss, outs), grad = grad_fn(get_params(state), key, model_params, opt)

  grad["seq"] = grad["seq"] / jnp.sqrt(jnp.square(grad["seq"]).sum())
  state = update_fun(i, grad, state)
  return state, outs

# design

In [10]:
# - if True, during optimization gradients are sampled from the 5 models
# - if False, all given models are used and gradients are averaged
SAMPLE_GRADIENT = True

In [13]:
L,A = wt_seq.shape
key = jax.random.PRNGKey(0)
mask = jnp.ones((L,))
seq = 0.01 * jax.random.normal(key,(L,A))
state = init_fun({"seq":seq})

In [14]:
XYZ,SEQ = [],[]
RMSD,BEST_outs = np.inf,None
for i in range(200): # number of iterations

  key,subkey = jax.random.split(key)
  if SAMPLE_GRADIENT:
    ## sample gradients from one of the 5 models
    n = jax.random.randint(subkey,[],0,5)
    state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  else:
    ## take mean gradient across 5 models
    state, outs = step(i, state, subkey, model_params_multi, {"mask":mask}, multi=True)

  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()
  losses = outs["losses"]
  
  if (i+1) % 10 == 0:
    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

  # save for animation
  if SAMPLE_GRADIENT:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))
  else:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][0,:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"][0]))

  if mask.sum() == len(seq) and losses["rmsd"] < RMSD:
    RMSD = losses["rmsd"]
    BEST_outs = outs

10 92 3.536 1.939 13.393 0.043
20 92 3.318 1.914 15.627 0.054
30 92 3.499 1.902 9.705 0.043
40 92 3.435 1.933 12.587 0.054
50 92 3.180 1.654 10.581 0.065
60 92 3.117 1.525 9.741 0.130
70 92 2.963 1.247 5.060 0.130
80 92 2.948 1.143 4.174 0.141
90 92 2.995 1.210 5.295 0.163
100 92 2.898 1.135 4.259 0.109
110 92 2.946 1.285 7.070 0.076
120 92 2.748 0.978 3.184 0.054
130 92 2.668 0.943 2.671 0.065
140 92 2.841 1.108 3.344 0.098
150 92 2.821 1.222 3.682 0.087
160 92 2.574 0.798 2.125 0.065
170 92 2.687 1.066 7.358 0.087
180 92 2.824 1.023 3.313 0.087
190 92 2.566 0.824 2.260 0.109
200 92 2.482 0.764 2.212 0.109


In [15]:
HTML(make_animation(XYZ,SEQ,pos_ref=batch["all_atom_positions"][:,1,:]))

In [16]:
if SAMPLE_GRADIENT:
  save_pdb(BEST_outs)
else:
  outs_ = BEST_outs.copy()
  outs_["outputs"] = jax.tree_map(lambda x:x[0], outs_["outputs"])
  outs_["seq"] = outs_["seq"][0]
  save_pdb(outs_)

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})
BB = ['C','O','N']
view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                    {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
              {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
              {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})  
view.zoomTo()
view.show()

In [None]:
if SAMPLE_GRADIENT:
  print("".join([order_restype[a] for a in BEST_outs["seq"].argmax(-1)]))
else:
  print("".join([order_restype[a] for a in BEST_outs["seq"][0].argmax(-1)]))

#experimental 
For more complex targets, we find directly optimizing a single one_hot encoded sequence (even with gumbel-st) to be very challenging. Instead, we start with a continious representation and then one residue at a time switch to one_hot in second round of optimization.

**design adversarial vector**

In [17]:
L,A = wt_seq.shape
key = jax.random.PRNGKey(0)
mask = jnp.zeros((L,))
seq = jnp.zeros((L,A))
state = init_fun({"seq":seq})

In [18]:
XYZ,SEQ = [],[]
for i in range(200): # number of iterations

  key,subkey = jax.random.split(key)
  
  if SAMPLE_GRADIENT:
    ## sample gradients from one of the 5 models
    n = jax.random.randint(subkey,[],0,5)
    state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  else:
    ## take mean gradient across 5 models
    state, outs = step(i, state, subkey, model_params_multi, {"mask":mask}, multi=True)

  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()

  # save for animation
  if SAMPLE_GRADIENT:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))
  else:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][0,:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"][0]))

  if (i+1) % 10 == 0:
    losses = outs["losses"]
    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

10 0 4.737 4.361 35.908 0.043
20 0 4.119 3.736 29.720 0.043
30 0 3.726 1.997 10.962 0.033
40 0 3.448 1.986 11.737 0.033
50 0 3.318 1.687 6.805 0.054
60 0 3.185 1.648 9.247 0.087
70 0 3.010 1.488 6.746 0.098
80 0 2.904 1.333 7.721 0.098
90 0 2.498 0.861 2.216 0.098
100 0 2.347 0.686 1.817 0.130
110 0 2.195 0.537 1.532 0.120
120 0 2.230 0.529 1.583 0.120
130 0 2.009 0.470 1.296 0.120
140 0 1.872 0.386 0.925 0.109
150 0 1.902 0.417 1.341 0.141
160 0 1.762 0.314 0.860 0.174
170 0 1.679 0.290 0.783 0.174
180 0 1.700 0.304 0.844 0.163
190 0 1.686 0.307 0.820 0.185
200 0 1.557 0.297 0.743 0.185


In [19]:
HTML(make_animation(XYZ,SEQ,pos_ref=batch["all_atom_positions"][:,1,:]))

design one_hot

In [20]:
XYZ,SEQ = [],[]
RMSD,BEST_outs = np.inf,None
for i in range(200,1000): # number of iterations (might be overkill)

  key,subkey = jax.random.split(key)
  if SAMPLE_GRADIENT:
    ## sample gradients from one of the 5 models
    n = jax.random.randint(subkey,[],0,5)
    state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  else:
    ## take mean gradient across 5 models
    state, outs = step(i, state, subkey, model_params_multi, {"mask":mask}, multi=True)

  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()

  # save for animation
  if SAMPLE_GRADIENT:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))
  else:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][0,:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"][0]))

  if (i+1) % 5 == 0:
    # pick random position to flip to one_hot
    if mask.mean() < 1:
      mask = mask.at[np.random.choice(np.where(mask == 0)[0])].set(1)    

    losses = outs["losses"]
    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

  if mask.sum() == len(seq) and losses["rmsd"] < RMSD:
    RMSD = losses["rmsd"]
    BEST_outs = outs

205 1 1.656 0.290 0.799 0.185
210 2 1.583 0.303 0.794 0.185
215 3 1.514 0.273 0.578 0.185
220 4 1.558 0.273 0.759 0.185
225 5 1.514 0.270 0.637 0.196
230 6 1.475 0.254 0.615 0.207
235 7 1.493 0.278 0.670 0.196
240 8 1.509 0.271 0.620 0.196
245 9 1.501 0.289 0.611 0.196
250 10 1.513 0.276 0.638 0.196
255 11 2.176 0.460 2.613 0.185
260 12 1.427 0.277 0.563 0.174
265 13 2.159 0.447 2.604 0.174
270 14 1.502 0.311 0.594 0.174
275 15 1.402 0.286 0.539 0.174
280 16 1.449 0.262 0.561 0.174
285 17 1.558 0.290 0.596 0.163
290 18 1.422 0.269 0.546 0.152
295 19 1.421 0.348 0.596 0.163
300 20 1.475 0.253 0.622 0.174
305 21 1.385 0.251 0.529 0.185
310 22 1.395 0.255 0.557 0.185
315 23 1.376 0.247 0.511 0.185
320 24 1.475 0.239 0.473 0.196
325 25 1.337 0.252 0.496 0.174
330 26 1.375 0.249 0.540 0.174
335 27 1.316 0.239 0.506 0.174
340 28 2.444 0.637 2.812 0.174
345 29 1.330 0.255 0.516 0.174
350 30 1.351 0.243 0.538 0.196
355 31 1.347 0.247 0.537 0.196
360 32 1.526 0.271 0.607 0.196
365 33 1.396 0.25

In [21]:
HTML(make_animation(XYZ,SEQ,pos_ref=batch["all_atom_positions"][:,1,:]))

In [24]:
if SAMPLE_GRADIENT:
  save_pdb(BEST_outs)
else:
  outs_ = BEST_outs.copy()
  outs_["outputs"] = jax.tree_map(lambda x:x[0], outs_["outputs"])
  outs_["seq"] = outs_["seq"][0]
  save_pdb(outs_)

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})
BB = ['C','O','N']
view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                    {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
              {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
              {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})  
view.zoomTo()
view.show()

In [28]:
if SAMPLE_GRADIENT:
  print("".join([order_restype[a] for a in BEST_outs["seq"].argmax(-1)]))
else:
  print("".join([order_restype[a] for a in BEST_outs["seq"][0].argmax(-1)]))

AIIIYVVMINDGTVHHLAWVTSDWRMANRISRFISRWVKAMICPFVCIAMVMRTRAMARVMADYWTRWCHRNGCTNPKVSFVGNVVIVCGVM
