<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/fixbb_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fixed-backbone design
Note: this notebook was not optimized or thoroughly tested for protein design. It was designed to test backprop functionality.

## Install dependencies

In [1]:
%%bash
# Fetch dependencies once. Each step is skipped when its target already exists,
# so re-running this cell is cheap.
# NOTE(review): package versions are unpinned; later releases may not match
# the APIs this notebook uses.
if [ ! -d af_backprop ]; then
  # presumably provides the `alphafold` package and the `utils` module that the
  # next cell imports via sys.path.append('/content/af_backprop')
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install biopython dm-haiku ml-collections py3Dmol
fi
if [ ! -d params ]; then
  # AlphaFold weights (2021-07-14 release), streamed straight into ./params.
  # NOTE(review): if the download fails, ./params already exists (mkdir ran
  # first), so this block is skipped on re-run; delete ./params to retry.
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi
# colabfold.py supplies cf.kabsch / cf.plot_pseudo_3D / cf.add_text used below;
# -nc (no-clobber) skips the download when the file already exists
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

## Import libraries

In [2]:
import sys
sys.path.append('/content/af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol
from IPython.display import HTML

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from utils import *
import colabfold as cf

In [3]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def make_animation(positions, seq, pos_ref=None, line_w=2.0, dpi=100, interval=60):
  """Render a design trajectory as an HTML5 video.

  Each frame shows the structure (top panel, pseudo-3D trace colored N→C) and
  the per-position sequence matrix (bottom panel).

  Args:
    positions: sequence of (L, 3) coordinate arrays, one per frame; each is
      centered and rotated onto `pos_ref` before drawing.
    seq: sequence of per-frame (L, A) arrays; drawn transposed with imshow on
      a fixed [-1, 1] color scale.
    pos_ref: (L, 3) reference coordinates that fix alignment and axis limits.
      Defaults to the last frame of `positions`.
    line_w: line width passed to cf.plot_pseudo_3D.
    dpi: figure resolution.
    interval: delay between frames in milliseconds (matplotlib convention).

  Returns:
    The animation as an HTML5 <video> string (wrap with IPython.display.HTML).

  Raises:
    ValueError: if `positions` is empty.
  """
  if len(positions) == 0:
    raise ValueError("make_animation needs at least one frame in `positions`")

  def align(P, Q):
    # center P and rotate it onto Q via cf.kabsch; the result stays centered,
    # which matches the centered reference built below
    p = P - P.mean(0, keepdims=True)
    q = Q - Q.mean(0, keepdims=True)
    return p @ cf.kabsch(p, q)

  if pos_ref is None:
    pos_ref = positions[-1]
  # center the reference, then orient it with the rotation cf.kabsch returns
  # for the reference against itself (return_v=True)
  pos_ref = pos_ref - pos_ref.mean(0, keepdims=True)
  pos_ref = pos_ref @ cf.kabsch(pos_ref, pos_ref, return_v=True)

  pos = np.asarray([align(frame, pos_ref) for frame in positions])

  fig = plt.figure()
  gs = GridSpec(4, 1, figure=fig)
  ax1, ax2 = fig.add_subplot(gs[:3, :]), fig.add_subplot(gs[3:, :])
  fig.subplots_adjust(top=0.90, bottom=0.10, right=0.9, left=0.1, hspace=0, wspace=0)
  fig.set_figwidth(5); fig.set_figheight(6)
  fig.set_dpi(dpi)

  ax2.set_xlabel("positions")
  ax2.set_yticks([])

  # axis limits and depth range come from the reference so frames don't jump
  z_min, z_max = pos_ref[..., 2].min(), pos_ref[..., 2].max()
  xy_min, xy_max = pos_ref[..., :2].min() - 5, pos_ref[..., :2].max() + 5
  ax1.set_xlim(xy_min, xy_max); ax1.set_ylim(xy_min, xy_max)
  ax1.axis(False)

  ims = []
  for x, s in zip(pos, seq):
    ims.append([cf.add_text("colored by N→C", ax1),
                cf.plot_pseudo_3D(x, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max),
                ax2.imshow(s.T, animated=True, cmap="bwr_r", vmin=-1, vmax=1)])

  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  # close *this* figure (not whatever is current) so the static figure is not
  # also rendered inline below the video
  plt.close(fig)
  return ani.to_html5_video()

## Set up the model

In [4]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing (rematerialization) to cut memory use during backprop
model_config.model.global_config.use_remat = True

# number of recycles
# NOTE: for denovo-designed proteins we find 1 recycle is enough to predict them
model_config.model.num_recycle = 1
model_config.data.common.num_recycle = 1

# backprop through recycles (disabled for both the recycles and the distogram)
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences: a single sequence, with no masked-MSA replacement
model_config.data.common.max_extra_msa = 1
model_config.data.eval.max_msa_clusters = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout (left at the config default; uncomment to disable)
# model_config = set_dropout(model_config, 0.0)

# setup model
# NOTE(review): is_training=True presumably keeps stochastic layers (e.g. dropout) active
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# load the other model_params (during optimization, we randomly pick which model to use).
# A separate loop variable keeps `model_name` naming the model the config was built for.
for other_model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name=other_model_name, data_dir=".")
  # keep only the keys the runner was built with, so all param trees share one structure
  model_params.append({k: params[k] for k in model_runner.params.keys()})


## Example target

In [5]:
# setup inputs: download the target PDB (-nc skips it if already present) and parse it
example = "1QYS"
!wget -qnc https://files.rcsb.org/view/{example}.pdb
protein_obj = protein.from_pdb_string(pdb_to_string(f"{example}.pdb"))

# target structure used by the losses: atom37 coordinates/mask of the PDB plus
# the frame features produced from them by all_atom.atom37_to_frames
batch = {'aatype': protein_obj.aatype,
          'all_atom_positions': protein_obj.atom_positions,
          'all_atom_mask': protein_obj.atom_mask}
batch.update(all_atom.atom37_to_frames(**batch))

# native sequence of the target as a one-letter string; it only seeds the input features
query_sequence = "".join([order_restype[a] for a in protein_obj.aatype])
starting_sequence = query_sequence

# single-sequence input features: the "MSA" is just the sequence itself, with no deletions
feature_dict = {
    **pipeline.make_sequence_features(sequence=starting_sequence,description="none",num_res=len(starting_sequence)),
    **pipeline.make_msa_features(msas=[[starting_sequence]],deletion_matrices=[[[0]*len(starting_sequence)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)
# 20-class one-hot of the native sequence; used below for the design length and
# for reporting sequence identity
wt_seq = jax.nn.one_hot(inputs["aatype"][0],20)

## Set up the gradient function

In [6]:
def get_grad_fn(model_runner, inputs, target=None):
  """Build a value-and-grad function for fixed-backbone sequence design.

  Args:
    model_runner: AlphaFold RunModel; its `apply` and `config` are used.
    inputs: processed feature dict. It is shallow-copied on every call and the
      copy is updated with the designed sequence.
    target: dict describing the target structure (`all_atom_positions` plus
      whatever get_dgram_loss / get_fape_loss read). Defaults to the notebook
      global `batch`, looked up when get_grad_fn is called.

  Returns:
    jax.value_and_grad(mod, has_aux=True, argnums=0), where
    mod(params, key, model_params, opt) -> (loss, aux):
      params: {"seq": (L, 20) logits}, the variable being optimized.
      key: PRNG key forwarded to model_runner.apply.
      model_params: haiku params of one model.
      opt: {"mask": (L,)}. Where mask == 1 the model sees soft_seq(logits) and
        the designed side chains; where mask == 0 it sees the raw logits and
        alanine side chains.
    The loss is the distogram loss only. aux["losses"] also reports FAPE and
    CA RMSD, which are monitored but not optimized.
  """
  if target is None:
    target = batch  # notebook-global target built from the example PDB
  # target CA coordinates (atom37 slot 1); only used to monitor RMSD
  target_ca = target["all_atom_positions"][:,1,:]
  # constant one-hot alanine over the 21 aatype classes
  ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)

  def mod(params, key, model_params, opt):
    ############################
    # set amino acid sequence
    ############################
    seq_logits = params["seq"]
    seq = soft_seq(seq_logits)

    # per-position blend: soft sequence where mask == 1, raw logits where mask == 0
    mask = opt["mask"][:,None]
    pseudo_seq = mask * seq + (1-mask) * seq_logits

    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod)

    ####################
    # set sidechains identity
    ####################
    # designed residues where mask == 1, alanine where mask == 0
    N,L = inputs_mod["aatype"].shape[:2]
    aatype = jnp.zeros((N,L,21)).at[...,:20].set(seq)
    aatype_ala = jnp.zeros((N,L,21)).at[:].set(ALA)
    aatype_pseudo = mask * aatype + (1-mask) * aatype_ala
    update_aatype(aatype_pseudo, inputs_mod)

    # get output
    outputs = model_runner.apply(model_params, key, inputs_mod)

    # losses
    dgram_loss = get_dgram_loss(target, outputs, model_config=model_runner.config)
    fape_loss = get_fape_loss(target, outputs, model_config=model_runner.config)

    # CA RMSD is only monitored; it is not part of the loss
    rmsd_loss = jnp_rmsd(target_ca,
                         outputs["structure_module"]["final_atom_positions"][:,1,:])

    loss = dgram_loss # + fape_loss
    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    return loss, ({"losses":{"rmsd":rmsd_loss,
                             "dgram":dgram_loss,
                             "fape":fape_loss},
                   "outputs":outs,"seq":seq,"pseudo_seq":pseudo_seq})

  return jax.value_and_grad(mod, has_aux=True, argnums=0)

## Where the magic happens

In [7]:
# gradient function
grad_fn = jax.jit(get_grad_fn(model_runner, inputs))

def stack_pytrees(trees):
  """Stack structurally identical pytrees leaf-by-leaf along a new leading axis.

  Built only on tree_flatten/tree_unflatten, which exist across JAX versions
  (`jax.tree_multimap` has been removed from newer JAX releases).

  Raises:
    ValueError: if `trees` is empty or the trees do not share one structure.
  """
  if not trees:
    raise ValueError("need at least one pytree to stack")
  flat = [jax.tree_util.tree_flatten(t) for t in trees]
  treedef = flat[0][1]
  if any(other_def != treedef for _, other_def in flat[1:]):
    raise ValueError("all pytrees must share the same structure")
  stacked = [jnp.stack(group, axis=0) for group in zip(*(lv for lv, _ in flat))]
  return jax.tree_util.tree_unflatten(treedef, stacked)

# stack model params: every leaf gains a leading axis of size len(model_params)
model_params_multi = stack_pytrees(model_params)
# vmap over that model axis only; params, key and opt are shared by all models
grad_fn_multi = jax.jit(jax.vmap(grad_fn,(None,None,0,None)))

In [8]:
init_fun, update_fun, get_params = adam(step_size=1e-2)

def step(i, state, key, model_params, opt, multi=False, eps=1e-8):
  """Run one Adam update of the sequence logits.

  Args:
    i: step index passed to update_fun.
    state: optimizer state from init_fun / a previous step.
    key: PRNG key forwarded to the model.
    model_params: params of one model, or the stacked params when multi=True.
    opt: dict with "mask" (see get_grad_fn).
    multi: if True, evaluate every stacked model and average the gradients and
      the scalar losses over the model axis. All other entries of `outs` keep
      their leading model axis.
    eps: lower bound on the gradient norm used for normalisation, which avoids
      0/0 -> NaN when the gradient vanishes.

  Returns:
    (new_state, outs), where outs is the aux dict from get_grad_fn.
  """
  params = get_params(state)
  if multi:
    (_, outs), grad = grad_fn_multi(params, key, model_params, opt)
    grad = jax.tree_util.tree_map(lambda x: x.mean(0), grad)
    outs["losses"] = jax.tree_util.tree_map(lambda x: x.mean(0), outs["losses"])
  else:
    (_, outs), grad = grad_fn(params, key, model_params, opt)

  # normalise the sequence gradient to unit L2 norm (floored at eps)
  g = grad["seq"]
  grad["seq"] = g / jnp.maximum(jnp.sqrt(jnp.square(g).sum()), eps)
  state = update_fun(i, grad, state)
  return state, outs

# design

In [7]:
# fresh design run: small random logits, and mask = 1 everywhere so every
# position uses soft_seq(logits)
L, A = wt_seq.shape  # designed length and alphabet size (20)
key = jax.random.PRNGKey(0)
mask = jnp.ones((L,))
seq = jax.random.normal(key, (L, A)) * 0.01
state = init_fun({"seq": seq})

In [8]:
SAMPLE_GRADIENT = False  # True: gradient from one random model per step; False: mean over all 5
XYZ,SEQ = [],[]
RMSD,BEST_outs = np.inf,None
for i in range(200): # number of iterations

  key,subkey = jax.random.split(key)
  if SAMPLE_GRADIENT:
    ## sample gradients from one of the 5 models
    n = jax.random.randint(subkey,[],0,5)
    state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  else:
    ## take mean gradient across 5 models
    state, outs = step(i, state, subkey, model_params_multi, {"mask":mask}, multi=True)

  # losses of *this* step (already averaged over models in multi mode). They are
  # read every iteration so the best-RMSD check below never sees stale or
  # undefined values.
  losses = outs["losses"]
  # argmax sequence; in multi mode it has a leading model axis: (n_models, L)
  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()

  if (i+1) % 10 == 0:
    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

  # save for animation (first model's copy in multi mode)
  if SAMPLE_GRADIENT:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))
  else:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][0,:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"][0]))

  # keep the lowest-RMSD fully one-hot design. Compare with the number of
  # positions; len(seq) is the number of models in multi mode.
  if mask.sum() == mask.shape[0] and losses["rmsd"] < RMSD:
    RMSD = losses["rmsd"]
    BEST_outs = outs

10 92 3.284 1.807 10.728 0.087
20 92 2.895 1.166 4.774 0.109
30 92 2.882 1.131 6.468 0.098
40 92 2.654 0.857 2.426 0.098
50 92 2.631 0.900 3.674 0.087
60 92 2.593 0.826 2.522 0.098
70 92 2.596 0.833 3.223 0.098
80 92 2.769 1.092 6.206 0.109
90 92 2.629 0.755 2.529 0.120
100 92 2.572 0.829 2.733 0.130
110 92 2.557 0.750 2.657 0.109
120 92 3.028 1.297 8.663 0.098
130 92 2.270 0.582 1.692 0.130
140 92 2.379 0.607 1.882 0.109
150 92 2.358 0.639 1.756 0.076
160 92 2.262 0.558 1.705 0.076
170 92 2.389 0.668 2.122 0.087
180 92 2.360 0.564 1.761 0.076
190 92 2.963 1.066 6.312 0.087
200 92 2.296 0.561 1.805 0.054


In [11]:
# animate the trajectory, aligned to the target's CA trace
HTML(make_animation(positions=XYZ, seq=SEQ, pos_ref=batch["all_atom_positions"][:, 1, :]))

In [14]:
# write the last iteration's structure to tmp.pdb; in multi-model mode keep
# only the first model's copy
if not SAMPLE_GRADIENT:
  outs_ = {**outs,
           "outputs": jax.tree_util.tree_map(lambda x: x[0], outs["outputs"]),
           "seq": outs["seq"][0]}
  save_pdb(outs_)
else:
  save_pdb(outs)

# 3D view: cartoon backbone; sticks for CA + side chains (GLY shown as a CA
# sphere, PRO sticks keep N so the ring closes)
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"), 'pdb')
view.setStyle({'cartoon': {}})
BB = ['C', 'O', 'N']
view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': BB, 'invert': True}]},
              {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
              {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
              {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
view.zoomTo()
view.show()

# Experimental
For more complex targets, we find that directly optimizing a single one-hot encoded sequence (even with the Gumbel straight-through estimator, "gumbel-st") is very challenging. Instead, we start from a continuous representation and then, in a second round of optimization, switch it to one-hot one residue at a time.

**design adversarial vector**

In [9]:
# continuous warm start: all-zero logits and mask = 0 everywhere, so the model
# sees the raw logits at every position
L, A = wt_seq.shape
key = jax.random.PRNGKey(0)
seq = jnp.zeros((L, A))
mask = jnp.zeros(L)
state = init_fun({"seq": seq})

In [10]:
SAMPLE_GRADIENT = False
XYZ, SEQ = [], []
for i in range(200):  # number of iterations
  key, subkey = jax.random.split(key)
  opt = {"mask": mask}
  if SAMPLE_GRADIENT:
    # gradient from one randomly chosen model out of the 5
    n = jax.random.randint(subkey, [], 0, 5)
    state, outs = step(i, state, subkey, model_params[n], opt)
  else:
    # mean gradient over all 5 stacked models
    state, outs = step(i, state, subkey, model_params_multi, opt, multi=True)

  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()

  # record CA trace and pseudo sequence for the animation
  # (multi mode carries a leading model axis; keep the first model's copy)
  frame_xyz = outs["outputs"]["final_atom_positions"]
  frame_seq = outs["pseudo_seq"]
  if not SAMPLE_GRADIENT:
    frame_xyz, frame_seq = frame_xyz[0], frame_seq[0]
  XYZ.append(np.asarray(frame_xyz[:, 1, :]))
  SEQ.append(np.asarray(frame_seq))

  if (i + 1) % 10 == 0:
    losses = outs["losses"]
    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

10 0 4.880 3.752 28.739 0.043
20 0 3.770 2.349 18.316 0.033
30 0 3.333 1.913 12.688 0.054
40 0 3.062 1.754 5.756 0.043
50 0 2.859 1.570 3.406 0.043
60 0 2.767 1.446 3.977 0.054
70 0 2.641 1.296 2.842 0.054
80 0 2.519 0.966 5.351 0.076
90 0 2.396 0.882 4.202 0.087
100 0 2.265 0.745 2.208 0.120
110 0 2.098 0.517 1.609 0.130
120 0 1.947 0.401 1.183 0.120
130 0 1.876 0.396 1.122 0.109
140 0 1.910 0.409 1.416 0.120
150 0 1.761 0.350 0.918 0.120
160 0 1.654 0.308 0.816 0.120
170 0 1.565 0.286 0.686 0.130
180 0 1.522 0.275 0.649 0.141
190 0 1.412 0.269 0.558 0.152
200 0 1.370 0.256 0.551 0.152


In [11]:
# animate the continuous-stage trajectory, aligned to the target's CA trace
HTML(make_animation(positions=XYZ, seq=SEQ, pos_ref=batch["all_atom_positions"][:, 1, :]))

**design one_hot** (switch one position to one-hot every 5 steps)

In [12]:
SAMPLE_GRADIENT = False
XYZ,SEQ = [],[]
RMSD,BEST_outs = np.inf,None
# seeded generator for choosing which position to flip next (reproducible schedule)
flip_rng = np.random.default_rng(0)
for i in range(200,1000): # number of iterations (continues the previous round's step count)

  key,subkey = jax.random.split(key)
  # whether this step runs fully one-hot; `mask` may be updated further down,
  # after the step has already used it
  fully_one_hot = bool(mask.sum() == mask.shape[0])
  if SAMPLE_GRADIENT:
    ## sample gradients from one of the 5 models
    n = jax.random.randint(subkey,[],0,5)
    state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  else:
    ## take mean gradient across 5 models
    state, outs = step(i, state, subkey, model_params_multi, {"mask":mask}, multi=True)

  # losses of *this* step (averaged over models in multi mode)
  losses = outs["losses"]
  # in multi mode seq has a leading model axis: (n_models, L)
  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()

  # save for animation (first model's copy in multi mode)
  if SAMPLE_GRADIENT:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))
  else:
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][0,:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"][0]))

  if (i+1) % 5 == 0:
    # pick a random still-continuous position (mask == 0) and flip it to one_hot
    if mask.mean() < 1:
      mask = mask.at[flip_rng.choice(np.where(mask == 0)[0])].set(1)

    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

  # keep the lowest-RMSD design among steps that ran fully one-hot
  if fully_one_hot and losses["rmsd"] < RMSD:
    RMSD = losses["rmsd"]
    BEST_outs = outs

205 1 1.366 0.252 0.569 0.163
210 2 1.325 0.238 0.487 0.152
215 3 1.344 0.236 0.508 0.152
220 4 1.289 0.240 0.486 0.163
225 5 1.291 0.237 0.497 0.163
230 6 1.305 0.279 0.498 0.152
235 7 1.281 0.251 0.488 0.141
240 8 1.271 0.245 0.493 0.152
245 9 1.230 0.264 0.451 0.152
250 10 1.256 0.232 0.477 0.152
255 11 1.249 0.256 0.453 0.141
260 12 1.230 0.260 0.459 0.141
265 13 1.225 0.284 0.440 0.141
270 14 1.184 0.230 0.417 0.152
275 15 1.193 0.233 0.430 0.141
280 16 1.203 0.251 0.438 0.130
285 17 1.188 0.248 0.426 0.141
290 18 1.192 0.232 0.408 0.141
295 19 1.229 0.291 0.456 0.141
300 20 1.321 0.282 0.622 0.141
305 21 1.252 0.263 0.476 0.141
310 22 1.226 0.243 0.466 0.141
315 23 1.446 0.287 1.002 0.141
320 24 1.239 0.227 0.484 0.141
325 25 1.242 0.229 0.451 0.152
330 26 1.246 0.218 0.435 0.152
335 27 1.328 0.239 0.535 0.152
340 28 1.284 0.224 0.501 0.152
345 29 1.366 0.266 0.616 0.163
350 30 1.267 0.217 0.466 0.130
355 31 1.243 0.203 0.459 0.130
360 32 1.269 0.206 0.445 0.130
365 33 1.254 0.21

In [13]:
# animate the one-hot stage, aligned to the target's CA trace
HTML(make_animation(positions=XYZ, seq=SEQ, pos_ref=batch["all_atom_positions"][:, 1, :]))

In [15]:
# write the last iteration's structure to tmp.pdb; in multi-model mode keep
# only the first model's copy
if not SAMPLE_GRADIENT:
  outs_ = {**outs,
           "outputs": jax.tree_util.tree_map(lambda x: x[0], outs["outputs"]),
           "seq": outs["seq"][0]}
  save_pdb(outs_)
else:
  save_pdb(outs)

# 3D view: cartoon backbone; sticks for CA + side chains (GLY shown as a CA
# sphere, PRO sticks keep N so the ring closes)
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"), 'pdb')
view.setStyle({'cartoon': {}})
BB = ['C', 'O', 'N']
view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': BB, 'invert': True}]},
              {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
              {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
              {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
view.zoomTo()
view.show()

In [18]:
# decode the last iteration's argmax sequence to one-letter codes
# (multi mode carries a leading model axis; use the first model's copy)
print("".join(order_restype[a] for a in (seq if SAMPLE_GRADIENT else seq[0])))

KIRIWVIMPRRNSMIHYHFVTEDWQSFNAIIRAIMGVIAAWKPRHACIWVDCPDTEECSRIGAYMTRIFTAAGYTNCRVYFYGNHVYIQCTP
