<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/fixbb_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

install

In [1]:
%%bash
# Install the Python packages imported by the notebook (quiet mode).
pip -q install biopython dm-haiku ml-collections py3Dmol

# Clone the af_backprop sources once; an existing checkout is reused.
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
fi

# Download and unpack the AlphaFold parameters once.
# The archive is unpacked into a scratch directory that is renamed to `params`
# only after both curl and tar succeeded. Creating `params` up front would make
# an interrupted download look complete, and every later run would skip it.
if [ ! -d params ]; then
  rm -rf params.partial
  mkdir params.partial
  # pipefail: a curl failure must not be hidden by tar's exit status
  if (set -o pipefail; curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params.partial); then
    mv params.partial params
  else
    rm -rf params.partial
    echo "ERROR: failed to download the AlphaFold parameters" >&2
    exit 1
  fi
fi

import libraries

In [1]:
import sys
sys.path.append('/content/af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from utils import *

setup model

In [2]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
model_config.model.num_recycle = 1
model_config.data.common.num_recycle = 1

# backprop through recycles
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences
model_config.data.common.max_extra_msa = 1
model_config.data.eval.max_msa_clusters = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout
model_config = set_dropout(model_config, 0.0)

# setup model
# model_params[0] holds the parameters of `model_name`; the runner is built from them
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# load the other model_params so the design can be run with any of the five models.
# A dedicated loop variable is used so that `model_name` keeps naming the model
# whose config and runner were created above.
for other_model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  other_params = data.get_model_haiku_params(model_name=other_model_name, data_dir=".")
  # keep only the parameter groups the runner actually uses
  model_params.append({k: other_params[k] for k in model_runner.params.keys()})

example

In [3]:
# setup inputs
# PDB entry used as the design target
example = "1QYS"
# fetch the PDB file quietly; -nc skips the download when the file already exists
!wget -qnc https://files.rcsb.org/view/{example}.pdb
protein_obj = protein.from_pdb_string(pdb_to_string(f"{example}.pdb"))

# target features read by the loss functions (atom positions/masks of the PDB structure)
batch = {'aatype': protein_obj.aatype,
          'all_atom_positions': protein_obj.atom_positions,
          'all_atom_mask': protein_obj.atom_mask}
# add whatever atom37_to_frames derives from the three entries above
batch.update(all_atom.atom37_to_frames(**batch))

# amino-acid string of the target; it is also used as the starting sequence
query_sequence = "".join([order_restype[a] for a in protein_obj.aatype])
starting_sequence = query_sequence

# one_hot_encode
# model features built from the sequence alone: the "MSA" holds just this one
# sequence, with an all-zero deletion row
feature_dict = {
    **pipeline.make_sequence_features(sequence=starting_sequence,description="none",num_res=len(starting_sequence)),
    **pipeline.make_msa_features(msas=[[starting_sequence]],deletion_matrices=[[[0]*len(starting_sequence)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)
# one-hot (20 classes) of the first entry of the processed aatype feature
wt_seq = jax.nn.one_hot(inputs["aatype"][0],20)

loss functions

In [4]:
def get_dgram_loss(batch, outputs):
  """Distogram loss of the model output against the target structure.

  Args:
    batch: mapping providing "aatype", "all_atom_positions" and "all_atom_mask".
    outputs: model outputs; outputs["distogram"] must provide "logits" and
      "bin_edges".

  Returns:
    The "loss" entry of the dict returned by `_distogram_log_loss`.
  """
  # pseudo-beta positions and their mask, derived from the target atoms
  pseudo_beta, pseudo_beta_mask = model.modules.pseudo_beta_fn(
      batch["aatype"], batch["all_atom_positions"], batch["all_atom_mask"])
  target = {"pseudo_beta":pseudo_beta, "pseudo_beta_mask":pseudo_beta_mask}

  distogram = outputs["distogram"]
  result = model.modules._distogram_log_loss(
      distogram["logits"],
      distogram["bin_edges"],
      batch=target,
      num_bins=model_config.model.heads.distogram.num_bins)
  return result["loss"]

def get_fape_loss(batch, outputs, use_clamped_fape=False):
  """Backbone loss of the structure-module output against the target.

  Args:
    batch: target features handed to `folding.backbone_loss`; the caller's
      mapping is not modified.
    outputs: model outputs; only outputs["structure_module"] is read.
    use_clamped_fape: value stored under "use_clamped_fape" in the copy of
      `batch` that is passed on.

  Returns:
    The "loss" entry of the accumulator dict after `folding.backbone_loss` ran
    (the accumulator starts at 0.0 and is expected to be updated in place).
  """
  # Rebuild the containers so the flag can be added without touching the
  # caller's batch. `jax.tree_util.tree_map` is the stable spelling; the
  # top-level `jax.tree_map` alias is deprecated and absent from newer JAX.
  sub_batch = jax.tree_util.tree_map(lambda x: x, batch)
  sub_batch["use_clamped_fape"] = use_clamped_fape
  loss = {"loss":0.0}
  folding.backbone_loss(loss, sub_batch, outputs["structure_module"], model_config.model.heads.structure_module)
  return loss["loss"]

setup gradient

In [43]:
def get_grad_fn(model_runner, inputs):
  """Build the value-and-gradient function used for sequence design.

  Args:
    model_runner: object whose `apply(model_params, key, inputs)` runs the model.
    inputs: processed feature dict; it is shallow-copied on every call before
      the designed sequence is written into the copy.

  Returns:
    `jax.value_and_grad` of the design loss with respect to its first argument.
    Calling it as fn(params, key, model_params, opt) yields ((loss, aux), grad),
    where aux carries the monitored losses, the predicted atoms and the
    sequence produced by `soft_seq`.
  """
  def _design_loss(params, key, model_params, opt):
    # ---- sequence fed to the model ----
    logits = params["seq"]
    seq = soft_seq(logits)

    # trailing axis so the per-position mask broadcasts over the amino-acid axis
    pos_mask = opt["mask"][:,None]
    # mask == 1: use the soft_seq output; mask == 0: feed the raw logits
    blended_seq = pos_mask * seq + (1-pos_mask) * logits

    inputs_mod = inputs.copy()
    update_seq(blended_seq, inputs_mod)

    # ---- side-chain identity ----
    # mask == 1: residue type from seq; mask == 0: alanine
    N,L = inputs_mod["aatype"].shape[:2]
    ala_one_hot = jax.nn.one_hot(residue_constants.restype_order["A"],21)
    seq_21 = jnp.zeros((N,L,21)).at[...,:20].set(seq)
    ala_21 = jnp.zeros((N,L,21)).at[:].set(ala_one_hot)
    blended_aatype = pos_mask * seq_21 + (1-pos_mask) * ala_21
    update_aatype(blended_aatype, inputs_mod)

    # ---- forward pass ----
    outputs = model_runner.apply(model_params, key, inputs_mod)
    structure = outputs["structure_module"]

    # ---- losses ----
    dgram_loss = get_dgram_loss(batch, outputs)
    fape_loss = get_fape_loss(batch, outputs)
    # rmsd over atom index 1 of every residue (presumably CA - confirm against
    # the atom ordering); monitored only
    rmsd_loss = jnp_rmsd(protein_obj.atom_positions[:,1,:],
                         structure["final_atom_positions"][:,1,:])

    # only the distogram term is optimised; fape is reported but not added
    loss = dgram_loss # + fape_loss

    aux = {
        "losses":{"rmsd":rmsd_loss,
                  "dgram":dgram_loss,
                  "fape":fape_loss},
        "outputs":{"final_atom_positions":structure["final_atom_positions"],
                   "final_atom_mask":structure["final_atom_mask"]},
        "seq":seq,
    }
    return loss, aux

  return jax.value_and_grad(_design_loss, has_aux=True, argnums=0)

where the magic happens

In [44]:
# gradient function
# jit-compiled value-and-gradient function for this model runner and feature dict;
# called as grad_fn(params, key, model_params=..., opt=...)
grad_fn = jax.jit(get_grad_fn(model_runner, inputs))

In [45]:
# Adam optimiser (step size 1e-2) as an (init, update, get_params) triple
init_fun, update_fun, get_params = adam(step_size=1e-2)

def step(i, state, key, model_params, opt):
  """Run one optimisation step on the sequence logits.

  Args:
    i: step index handed to the optimiser update.
    state: optimiser state holding {"seq": logits}.
    key: PRNG key forwarded to the model.
    model_params: model parameters to run the prediction with.
    opt: options forwarded to the loss function (reads opt["mask"]).

  Returns:
    (state, outs): the updated optimiser state and the auxiliary outputs of
    the loss function.
  """
  (_, outs), grad = grad_fn(get_params(state), key, model_params=model_params, opt=opt)
  # Normalise the gradient to unit L2 norm. The norm is floored so that an
  # all-zero gradient produces a zero update instead of 0/0 = NaN, which would
  # permanently corrupt the optimiser state.
  grad_norm = jnp.sqrt(jnp.square(grad["seq"]).sum())
  grad["seq"] = grad["seq"] / jnp.maximum(grad_norm, 1e-8)
  state = update_fun(i, grad, state)
  return state, outs

For complex targets, we find that directly optimizing a single one_hot encoded sequence (even with gumbel-st) is very challenging. Instead, we start with a continuous representation and then, in a second round of optimization, switch to one_hot one residue at a time.

**design adversarial vector**

In [46]:
# L, A: the two dimensions of the wild-type one-hot (A is its 20 one-hot classes)
L,A = wt_seq.shape
key = jax.random.PRNGKey(0)

# the optimised variable starts as all-zero logits
seq = jnp.zeros((L,A))
# per-position mask read by the loss function: 0 feeds the raw logits to the
# model, 1 feeds the soft_seq output; every position starts at 0
mask = jnp.zeros((L,))
state = init_fun({"seq":seq})

In [47]:
# Round 1: optimise the sequence logits; the mask is not modified in this loop.
for i in range(500): # number of iterations (might be overkill)
  key,subkey = jax.random.split(key)
  n = 0 #jax.random.randint(subkey,[],0,5) # select which model to use
  state, outs = step(i, state, subkey, model_params[n], {"mask":mask})

  # argmax sequence and the fraction of positions identical to the wild type
  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()

  # report every 10th iteration only
  if (i+1) % 10 != 0:
    continue
  losses = outs["losses"]
  # columns: iteration, one_hot positions, dgram, fape, rmsd, sequence identity
  print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

10 0 4.130 2.908 21.409 0.011
20 0 3.445 2.268 17.489 0.087
30 0 3.245 1.863 15.562 0.065
40 0 3.118 1.593 10.177 0.043
50 0 2.907 1.535 8.673 0.043
60 0 2.385 0.827 1.950 0.087
70 0 2.174 0.578 1.412 0.076
80 0 2.026 0.562 1.284 0.109
90 0 1.891 0.476 1.088 0.120
100 0 1.814 0.455 1.021 0.130
110 0 1.763 0.433 0.975 0.120
120 0 1.768 0.442 1.013 0.109
130 0 1.676 0.490 0.766 0.109
140 0 1.658 0.439 0.833 0.120
150 0 1.614 0.492 0.788 0.109
160 0 1.605 0.401 0.811 0.109
170 0 1.537 0.523 0.648 0.109
180 0 1.482 0.470 0.598 0.120
190 0 1.480 0.431 0.589 0.098
200 0 1.458 0.389 0.602 0.087
210 0 1.436 0.278 0.551 0.098
220 0 1.437 0.256 0.548 0.098
230 0 1.411 0.266 0.601 0.098
240 0 1.397 0.273 0.586 0.087
250 0 1.376 0.266 0.551 0.098
260 0 1.351 0.246 0.563 0.098
270 0 1.346 0.235 0.563 0.109
280 0 1.314 0.243 0.489 0.109
290 0 1.302 0.245 0.503 0.098
300 0 1.280 0.304 0.488 0.076
310 0 1.289 0.275 0.492 0.076
320 0 1.266 0.326 0.528 0.098
330 0 1.236 0.366 0.506 0.098
340 0 1.220 0.4

In [48]:
# Save the latest prediction and render it; save_pdb is expected to write
# tmp.pdb, which is read back below.
save_pdb(outs)

BB = ['C','O','N']
atom_style = {'colorscheme':"WhiteCarbon",'radius':0.3}
# (selection, style) pairs applied on top of the cartoon, in this order
style_rules = [
  # everything except GLY/PRO: sticks for the atoms not listed in BB
  ({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
   {'stick':dict(atom_style)}),
  # GLY: a sphere on CA
  ({'and':[{'resn':"GLY"},{'atom':'CA'}]},
   {'sphere':dict(atom_style)}),
  # PRO: sticks for every atom except C and O
  ({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
   {'stick':dict(atom_style)}),
]

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})
for selection, style in style_rules:
  view.addStyle(selection, style)
view.zoomTo()
view.show()

design one_hot

In [49]:
# Round 2: keep optimising while one more position is switched to one_hot
# every 10 steps.
RMSD = np.inf       # lowest rmsd seen while every position was one_hot
BEST_outs = None    # auxiliary outputs of the step that achieved it
# seeded so the order in which positions are switched is reproducible
flip_rng = np.random.default_rng(0)

# The step index continues from the 500 iterations of the first round. It is
# driven by this loop rather than by a counter left over from the previous
# cell, so no index is handed to the optimiser twice.
for i in range(500,1500):
  key,subkey = jax.random.split(key)
  n = 0 #jax.random.randint(subkey,[],0,5) # select which model to use
  state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()
  losses = outs["losses"]
  # only steps run with a fully one_hot sequence are candidates for the result
  if mask.sum() == len(seq) and losses["rmsd"] < RMSD:
    RMSD = losses["rmsd"]
    BEST_outs = outs
  if i % 10 == 0:
    # pick random position to flip to one_hot
    if mask.mean() < 1:
      mask = mask.at[flip_rng.choice(np.where(mask == 0)[0])].set(1)
    # columns: iteration, one_hot positions, dgram, fape, rmsd, sequence identity
    print(f'{i} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

# one position is switched per 10 steps, so a long sequence may never become
# fully one_hot within this loop; fail here with a clear message in that case
assert BEST_outs is not None, "mask never became fully one_hot; increase the number of iterations"

500 1 1.149 0.247 0.463 0.120
510 2 1.144 0.235 0.451 0.109
520 3 1.181 0.245 0.440 0.120
530 4 1.178 0.270 0.481 0.120
540 5 1.181 0.213 0.467 0.120
550 6 1.176 0.197 0.462 0.130
560 7 1.205 0.211 0.464 0.130
570 8 1.237 0.216 0.509 0.130
580 9 1.201 0.207 0.470 0.130
590 10 1.195 0.208 0.474 0.141
600 11 1.224 0.206 0.471 0.141
610 12 1.191 0.198 0.422 0.152
620 13 1.183 0.199 0.442 0.152
630 14 1.187 0.203 0.431 0.163
640 15 1.249 0.203 0.462 0.152
650 16 1.191 0.197 0.444 0.152
660 17 1.192 0.193 0.418 0.152
670 18 1.196 0.184 0.416 0.152
680 19 1.183 0.184 0.431 0.152
690 20 1.255 0.189 0.432 0.152
700 21 1.280 0.206 0.454 0.163
710 22 1.300 0.206 0.467 0.152
720 23 2.113 0.500 2.423 0.152
730 24 1.433 0.238 0.553 0.152
740 25 1.623 0.294 0.785 0.141
750 26 1.404 0.254 0.613 0.130
760 27 1.397 0.264 0.571 0.130
770 28 1.417 0.299 0.639 0.120
780 29 1.337 0.309 0.538 0.152
790 30 1.439 0.267 0.611 0.163
800 31 1.398 0.252 0.621 0.163
810 32 1.442 0.272 0.684 0.152
820 33 1.369 0.24

In [50]:
# lowest rmsd recorded while every position was one_hot
RMSD

DeviceArray(0.9135631, dtype=float32)

In [51]:
# Re-evaluate the best design with each of the loaded parameter sets.
# The input is the same for every model, so it is built once.
best_params = {"seq":BEST_outs["seq"]}
for model_param in model_params:
  # grad_fn also returns the loss and gradients; only the monitored rmsd is
  # reported. Dedicated names keep the notebook-level `outs`, `loss`, `grad`
  # and `params` of the earlier cells from being overwritten.
  (eval_loss, eval_outs), eval_grad = grad_fn(best_params, key, model_params=model_param, opt={"mask":mask})
  print(eval_outs["losses"]["rmsd"])

0.9135631
0.92909205
1.1264701
2.209554
2.3352773


In [52]:
# Save the best one_hot design and render it; save_pdb is expected to write
# tmp.pdb, which is read back below.
save_pdb(BEST_outs)

BB = ['C','O','N']
atom_style = {'colorscheme':"WhiteCarbon",'radius':0.3}
# (selection, style) pairs applied on top of the cartoon, in this order
style_rules = [
  # everything except GLY/PRO: sticks for the atoms not listed in BB
  ({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
   {'stick':dict(atom_style)}),
  # GLY: a sphere on CA
  ({'and':[{'resn':"GLY"},{'atom':'CA'}]},
   {'sphere':dict(atom_style)}),
  # PRO: sticks for every atom except C and O
  ({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
   {'stick':dict(atom_style)}),
]

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})
for selection, style in style_rules:
  view.addStyle(selection, style)
view.zoomTo()
view.show()

In [53]:
# Show the sequence of the best fully one_hot design (the structure saved and
# displayed from BEST_outs), not the sequence of whichever iteration ran last.
best_seq = np.asarray(BEST_outs["seq"]).argmax(-1)
# int(): look up with plain Python integers rather than array scalars
"".join([order_restype[int(a)] for a in best_seq])

'QIVIICQFKKRHKQYQFNWTDTSLASMSTPFNICAAIYDAMEYKTFRMTIKAYDKELCAAMSGVIHTICRNIGMTTMKHYNNGNQIIVQCTM'