<a href="https://colab.research.google.com/github/sokrypton/AccAdam_TF2/blob/main/afDesign_semigreedy_refinement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
# Fetch the afDesign sources and the extra Python packages on the first run only.
if [ ! -d afDesign ]; then
  # SECURITY: a personal access token used to be hard-coded in this URL. It has
  # been committed with the notebook, so treat it as compromised and revoke it.
  # If the repository needs authentication, export GITHUB_TOKEN before running
  # this cell; when the variable is unset or empty the clone is anonymous.
  git clone "https://${GITHUB_TOKEN:+${GITHUB_TOKEN}@}github.com/sokrypton/afDesign.git"
  pip -q install py3Dmol biopython dm-haiku ml_collections
fi
# Download and unpack the AlphaFold parameters on the first run only.
# NOTE: the directory is created before the download, so a failed download is
# not retried on the next run unless `params` is removed by hand.
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi

Cloning into 'afDesign'...


In [None]:
import os
import sys
sys.path.append('afDesign')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from alphafold.data import prep_inputs
from utils import *

In [None]:
# setup which model params to use
# (the model config and the runner below are both built from this model)
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
# (set in both the model section and the data section of the config)
model_config.model.num_recycle = 3
model_config.data.common.num_recycle = 3

# backprop through recycles
# (both switches are turned off here)
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences
# N is reused later for the MSA features and for the msa_mask length
N = 1
model_config.data.eval.max_msa_clusters = N
model_config.data.common.max_extra_msa = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout
# set_dropout comes from the star-import of utils; called with rate 0.0
model_config = set_dropout(model_config, 0.0)

# setup model
# model_params[0] holds the parameters the runner is built with
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# load the other models to sample during design.
# Only the entries whose keys exist in model_runner.params are kept.
# NOTE(review): the loop reuses `model_name` as its variable, so after this
# cell `model_name` is "model_5_ptm", not the model the runner was built from.
for model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name, '.')
  model_params.append({k: params[k] for k in model_runner.params.keys()})

In [None]:
###############
# USER INPUT
###############
# native structure you want to pull active site from
# pos_idx_ref[k] (reference residue) is paired with pos_idx[k] (design residue)
pos_idx_ref = [13,37,98] # note: zero indexed
PDB_REF = "afDesign/1QJG.pdb"

# starting structure (for random starting sequence, set PDB=None and LEN to desired length)
# MODE is also the filename prefix of the PDB files written by later cells.
# As written, PDB is always derived from MODE; LEN is only read when PDB is None.
pos_idx = [44,9,78]
MODE = "1.05_44_9_78_s524_r3_cce_adam"
PDB = f"{MODE}.pdb"
LEN = 100

In [None]:
# prep reference (native) features
OBJ_REF = protein.from_pdb_string(pdb_to_string(PDB_REF), chain_id="A")
SEQ_REF = jax.nn.one_hot(OBJ_REF.aatype,20)
START_SEQ_REF = "".join([order_restype[a] for a in OBJ_REF.aatype])

# ground-truth dict for the reference; the two helper calls add the derived
# frame and atom14 entries that the structure losses read
batch_ref = {'aatype': OBJ_REF.aatype,
             'all_atom_positions': OBJ_REF.atom_positions,
             'all_atom_mask': OBJ_REF.atom_mask}
batch_ref.update(all_atom.atom37_to_frames(**batch_ref))
batch_ref.update(prep_inputs.make_atom14_positions(batch_ref))

# prep starting (design) features
if PDB is not None:
  OBJ = protein.from_pdb_string(pdb_to_string(PDB), chain_id="A")
  SEQ = jax.nn.one_hot(OBJ.aatype,20)
  START_SEQ = "".join([order_restype[a] for a in OBJ.aatype])

  batch = {'aatype': OBJ.aatype,
          'all_atom_positions': OBJ.atom_positions,
          'all_atom_mask': OBJ.atom_mask}
  batch.update(all_atom.atom37_to_frames(**batch))
  batch.update(prep_inputs.make_atom14_positions(batch))
else:
  # No starting structure: residue index 0 everywhere, with the reference
  # residues copied onto the design positions.
  # The indices are built as NumPy integers. The previous float JAX array
  # yielded 0-d float arrays when iterated, which cannot be used to look up
  # order_restype (neither as a dict key nor as a sequence index).
  # NOTE: `batch` is not defined in this branch, so the "bb_weight" term of
  # the design loss (which reads `batch`) cannot be used without a starting PDB.
  SEQ = np.zeros(LEN, dtype=int)
  SEQ[pos_idx] = [OBJ_REF.aatype[i] for i in pos_idx_ref]
  START_SEQ = "".join([order_restype[a] for a in SEQ])
  SEQ = jax.nn.one_hot(SEQ,20)

# prep input features
# the MSA is N copies of the starting sequence with all-zero deletion rows
feature_dict = {
    **pipeline.make_sequence_features(sequence=START_SEQ,description="none",num_res=len(START_SEQ)),
    **pipeline.make_msa_features(msas=[N*[START_SEQ]], deletion_matrices=[N*[[0]*len(START_SEQ)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)

# with more than one sequence, overwrite both MSA masks with ones
if N > 1:
  inputs["msa_row_mask"] = jnp.ones_like(inputs["msa_row_mask"])
  inputs["msa_mask"] = jnp.ones_like(inputs["msa_mask"])

In [None]:
# sanity check: print the residues at the design positions and at the paired
# reference positions so the two lists can be compared by eye
print([START_SEQ[i] for i in pos_idx])
print([START_SEQ_REF[i] for i in pos_idx_ref])

['Y', 'N', 'D']
['Y', 'N', 'D']


In [None]:
########################################
# losses to constrain backbone to starting design
########################################
def get_dgram_loss_(batch, outputs):
  """Distogram log-loss of the prediction against the full structure in `batch`.

  Pseudo-beta positions and their mask are computed from the "aatype",
  "all_atom_positions" and "all_atom_mask" entries of `batch` and compared
  with the distogram logits / bin edges in `outputs`. Returns the "loss"
  entry of the result.
  """
  pseudo_beta, pseudo_beta_mask = model.modules.pseudo_beta_fn(
      batch["aatype"],
      batch["all_atom_positions"],
      batch["all_atom_mask"])

  predicted = outputs["distogram"]
  target = {"pseudo_beta":pseudo_beta,"pseudo_beta_mask":pseudo_beta_mask}
  result = model.modules._distogram_log_loss(
      predicted["logits"],
      predicted["bin_edges"],
      batch=target,
      num_bins=model_config.model.heads.distogram.num_bins)
  return result["loss"]

def get_fape_loss_(batch, outputs, use_clamped_fape=False):
  """Backbone loss of the predicted structure against the full `batch`.

  `use_clamped_fape` is stored in the copy of `batch` handed to
  `folding.backbone_loss`. Returns the "loss" entry of the accumulator dict
  after that call.
  """
  # identity tree_map: rebuilds the containers, so the flag is not written
  # into the caller's dict
  target = jax.tree_map(lambda leaf: leaf, batch)
  target["use_clamped_fape"] = use_clamped_fape

  # backbone_loss is called for its effect on the accumulator it is given
  accumulator = {"loss":0.0}
  folding.backbone_loss(accumulator,
                        target,
                        outputs["structure_module"],
                        model_config.model.heads.structure_module)
  return accumulator["loss"]

#########################################
# losses to constrain sidechains to active site
#########################################
def get_dgram_loss(batch, outputs, pos_idx, pos_idx_ref=None):
  """Distogram log-loss restricted to the selected positions.

  `pos_idx` selects rows/columns of the predicted distogram logits and
  `pos_idx_ref` selects the residues of `batch` they are compared with
  (defaults to `pos_idx`). Returns the "loss" entry of the result.
  """
  if pos_idx_ref is None:
    pos_idx_ref = pos_idx

  pseudo_beta, pseudo_beta_mask = model.modules.pseudo_beta_fn(
      batch["aatype"][pos_idx_ref],
      batch["all_atom_positions"][pos_idx_ref],
      batch["all_atom_mask"][pos_idx_ref])

  # sub-block of the pairwise logits: select on the second axis, then the first
  predicted = outputs["distogram"]
  site_logits = predicted["logits"][:,pos_idx][pos_idx,:]

  target = {"pseudo_beta":pseudo_beta,"pseudo_beta_mask":pseudo_beta_mask}
  result = model.modules._distogram_log_loss(
      site_logits,
      predicted["bin_edges"],
      batch=target,
      num_bins=model_config.model.heads.distogram.num_bins)
  return result["loss"]

def get_fape_loss(batch, outputs, pos_idx, pos_idx_ref=None, backbone=True, sidechain=True, use_clamped_fape=False):
  """Structure-module loss restricted to the selected positions.

  `pos_idx` indexes the prediction and `pos_idx_ref` indexes `batch`
  (defaults to `pos_idx`). With `sidechain` the result of
  `folding.sidechain_loss` is merged into the accumulator; with `backbone`,
  `folding.backbone_loss` is then called on the same accumulator. Returns the
  accumulator's "loss" entry (0.0 when both switches are off).
  """
  if pos_idx_ref is None:
    pos_idx_ref = pos_idx

  # ground truth cut down to the reference residues (leading axis of every leaf)
  target = jax.tree_map(lambda leaf: leaf[pos_idx_ref,...], batch)
  target["use_clamped_fape"] = use_clamped_fape

  # identity tree_map: fresh containers, so the edits below stay local
  pred = jax.tree_map(lambda leaf: leaf, outputs["structure_module"])
  accumulator = {"loss":0.0}

  if sidechain:
    site_atom14 = pred['final_atom14_positions'][pos_idx,...]
    pred.update(folding.compute_renamed_ground_truth(target, site_atom14))
    side = pred['sidechains']
    # keep only the selected residues (second axis) of the side-chain tensors
    for entry in ('frames', 'atom_pos'):
      side[entry] = jax.tree_map(lambda leaf: leaf[:,pos_idx,:], side[entry])
    accumulator.update(
        folding.sidechain_loss(target, pred, model_config.model.heads.structure_module))

  if backbone:
    # keep only the selected residues (second-to-last axis) of the trajectory
    pred["traj"] = pred["traj"][...,pos_idx,:]
    folding.backbone_loss(accumulator, target, pred, model_config.model.heads.structure_module)

  return accumulator["loss"]

def get_sidechain_rmsd_fix(batch, outputs, pos_idx, pos_idx_ref=None, include_CA=False):
  """RMSD between selected reference atoms and the matching predicted atoms.

  Args:
    batch: ground-truth dict; "aatype" and "all_atom_positions" are read here
      and the whole dict is passed to `all_atom.atom37_to_atom14`.
    outputs: model outputs; reads
      outputs["structure_module"]["final_atom14_positions"].
    pos_idx: residue indices into the prediction.
    pos_idx_ref: residue indices into `batch`; defaults to `pos_idx`.
    include_CA: when True, "CA" is not in the list of excluded atom names.

  Returns:
    sqrt(mean(per-atom squared distance) + 1e-8), where for every atom the
    smaller of the distance to its own predicted slot and to its swapped
    (renamed) slot is used.
  """

  if pos_idx_ref is None: pos_idx_ref = pos_idx
  # NOTE(review): the backbone "C" atom is in neither list, so it is always
  # part of the comparison - confirm this is intended.
  bb_atoms_to_exclude = ["N","O"] if include_CA else ["N","CA","O"]

  def kabsch(P, Q):
    # rotation estimated from the SVD of P^T Q
    V, S, W = jnp.linalg.svd(P.T @ Q, full_matrices=False)
    # smooth switch: close to 1 when det(V)*det(W) is negative, close to 0
    # when it is positive
    flip = jax.nn.sigmoid(-10 * jnp.linalg.det(V) * jnp.linalg.det(W))
    # NOTE(review): S is updated here but does not enter the returned value.
    S = flip * S.at[-1].set(-S[-1]) + (1-flip) * S
    # blend V with a copy whose last column is negated, weighted by `flip`
    V = flip * V.at[:,-1].set(-V[:,-1]) + (1-flip) * V
    return V@W

  true_aa_idx = batch["aatype"][pos_idx_ref]
  true_pos = all_atom.atom37_to_atom14(batch["all_atom_positions"],batch)[pos_idx_ref,:,:]
  pred_pos = outputs["structure_module"]["final_atom14_positions"][pos_idx,:,:]

  # Index lists, one entry per compared atom:
  #   i      position of the residue within the selection
  #   j      atom14 slot of the atom
  #   j_alt  atom14 slot of its renaming-swap partner (same as j if it has none)
  #   i_non / j_non  the subset of (i, j) for atoms without a swap partner
  i,j,j_alt = [],[],[]
  i_non,j_non = [],[]
  for n,aa_idx in enumerate(true_aa_idx):
    aa = idx_to_resname[aa_idx]
    atoms = residue_constants.residue_atoms[aa].copy()
    for atom in atoms:
      if atom not in bb_atoms_to_exclude:
        i.append(n)
        j.append(residue_constants.restype_name_to_atom14_names[aa].index(atom))
        if aa in residue_constants.residue_atom_renaming_swaps:
          swaps = residue_constants.residue_atom_renaming_swaps[aa]
          # reverse mapping, so an atom is found on either side of a swap pair
          swaps_rev = {v:k for k,v in swaps.items()}
          if atom in swaps:
            j_alt.append(residue_constants.restype_name_to_atom14_names[aa].index(swaps[atom]))
          elif atom in swaps_rev:
            j_alt.append(residue_constants.restype_name_to_atom14_names[aa].index(swaps_rev[atom]))
          else:
            # residue has swaps, but this atom is not part of one
            j_alt.append(j[-1])
            i_non.append(i[-1])
            j_non.append(j[-1])
        else:
          # residue type has no renaming swaps at all
          j_alt.append(j[-1])
          i_non.append(i[-1])
          j_non.append(j[-1])

  # align non-ambiguous atoms
  # (both sets are centred on the mean of their non-ambiguous atoms; only the
  # reference positions are multiplied by the kabsch matrix)
  true_pos_non = true_pos[i_non,j_non,:]  
  pred_pos_non = pred_pos[i_non,j_non,:]
  true_pos = (true_pos - true_pos_non.mean(0)) @ kabsch(true_pos_non - true_pos_non.mean(0), pred_pos_non - pred_pos_non.mean(0))
  pred_pos = pred_pos - pred_pos_non.mean(0)

  true_pos_a = true_pos[i,j,:]
  pred_pos_a = pred_pos[i,j,:]
  pred_pos_b = pred_pos[i,j_alt,:]

  # per-atom squared distances for the original and the swapped naming
  rms_a = jnp.square(true_pos_a - pred_pos_a).sum(-1)
  rms_b = jnp.square(true_pos_a - pred_pos_b).sum(-1)

  # 1e-8 is added under the square root
  return jnp.sqrt(jnp.minimum(rms_a,rms_b).mean() + 1e-8)

In [None]:
def get_grad_fn(model_runner, inputs, pos_idx_ref, inc_backbone=False):
  """Build the design loss function and its value-and-gradient version.

  Args:
    model_runner: object whose `apply(model_params, key, inputs)` returns the
      model outputs.
    inputs: feature dict; a shallow copy is modified on every call.
    pos_idx_ref: residue indices into the reference (`SEQ_REF`, `batch_ref`).
    inc_backbone: passed as `backbone=` to `get_fape_loss`.

  Returns:
    (loss_fn, grad_fn): `loss_fn` is the inner `mod`, `grad_fn` is
    `jax.value_and_grad(mod, has_aux=True, argnums=0)`.
  """
  
  def mod(params, key, model_params, opt):
    """Return (loss, aux) with aux = {"losses", "outputs", "seq"}.

    `opt` must contain "pos_idx", "oh_mask", "ala_mask" and "sc_weight";
    every other key read below is optional.
    """
    pos_idx = opt["pos_idx"]
    ############################
    # set amino acid sequence
    ############################
    # params["msa"] is shuffled with `key` (jax.random.permutation shuffles
    # the leading axis by default)
    seq_logits = jax.random.permutation(key, params["msa"])
    seq_soft = jax.nn.softmax(seq_logits)
    # straight-through: the value is the one-hot argmax, the gradient is the
    # gradient of seq_soft
    seq = jax.lax.stop_gradient(jax.nn.one_hot(seq_soft.argmax(-1),20) - seq_soft) + seq_soft
    # overwrite the design positions with the reference residues
    seq = seq.at[:,pos_idx,:].set(SEQ_REF[pos_idx_ref,:])

    # per-position blend: 1 -> discrete `seq`, 0 -> raw `seq_logits`
    oh_mask = opt["oh_mask"][:,None]
    pseudo_seq = oh_mask * seq + (1-oh_mask) * seq_logits

    # update_seq comes from the star-import of utils; presumably it writes the
    # sequence into inputs_mod in place - TODO confirm
    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod, msa_input=("msa" in params))

    if "msa_mask" in opt:
      inputs_mod["msa_mask"] = inputs_mod["msa_mask"] * opt["msa_mask"][None,:,None]
      inputs_mod["msa_row_mask"] = inputs_mod["msa_row_mask"] * opt["msa_mask"][None,:]
    
    ####################
    # set sidechains identity
    ####################
    B,L = inputs_mod["aatype"].shape[:2]
    ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)

    # 21-wide one-hot aatype; with an MSA only its first row is used
    if "msa" in params:
      aatype = jnp.zeros((B,L,21)).at[...,:20].set(seq[0])
    else:
      aatype = jnp.zeros((B,L,21)).at[...,:20].set(seq)

    # per-position blend: 1 -> designed residue, 0 -> alanine
    # (the design positions keep the reference residues in the alanine copy)
    ala_mask = opt["ala_mask"][:,None]
    aatype_ala = jnp.zeros((B,L,21)).at[:].set(ALA)
    aatype_ala = aatype_ala.at[:,pos_idx,:20].set(SEQ_REF[pos_idx_ref,:])
    aatype_pseudo = ala_mask * aatype + (1-ala_mask) * aatype_ala
    # update_aatype also comes from utils; presumably it writes aatype_pseudo
    # into inputs_mod - TODO confirm
    update_aatype(aatype_pseudo, inputs_mod)
    
    # get output
    outputs = model_runner.apply(model_params, key, inputs_mod)

    ###################
    # structure loss
    ###################
    fape_loss = get_fape_loss(batch_ref, outputs, pos_idx, pos_idx_ref, backbone=inc_backbone, sidechain=True)
    rmsd_loss = get_sidechain_rmsd_fix(batch_ref, outputs, pos_idx, pos_idx_ref)
    dgram_loss = get_dgram_loss(batch_ref, outputs, pos_idx, pos_idx_ref)

    losses = {"fape":fape_loss,
              "rmsd":rmsd_loss,
              "dgram":dgram_loss}

    # `losses` was filled above, before these multiplications, so it reports
    # the unweighted values (`*=` rebinds the local name)
    if "sc_weight_fape" in opt: fape_loss *= opt["sc_weight_fape"]
    if "sc_weight_rmsd" in opt: rmsd_loss *= opt["sc_weight_rmsd"]
    if "sc_weight_dgram" in opt: dgram_loss *= opt["sc_weight_dgram"]

    loss = (rmsd_loss + fape_loss + dgram_loss) * opt["sc_weight"]
  
    ###################
    # background loss
    ###################
    if "conf_weight" in opt:
      # expected bin index of the predicted aligned error distribution
      pae = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"])
      pae_loss = (pae * jnp.arange(pae.shape[-1])).sum(-1).mean()
      # expected bin index of the pLDDT distribution with the bin order reversed
      plddt = jax.nn.softmax(outputs['predicted_lddt']['logits'])
      plddt_loss = (plddt * jnp.arange(plddt.shape[-1])[::-1]).sum(-1).mean()

      loss = loss + (pae_loss + plddt_loss) * opt["conf_weight"]
      losses["pae"] = pae_loss
      losses["plddt"] = plddt_loss

    if "rg_weight" in opt:
      # atom index 1 of every residue (CA, going by the variable name)
      ca_coords = outputs["structure_module"]["final_atom_positions"][:,1,:]
      # root-mean-square distance of those coordinates from their centroid
      rg_loss = jnp.sqrt(jnp.square(ca_coords - ca_coords.mean(0)).sum(-1).mean() + 1e-8)
      loss = loss + rg_loss * opt["rg_weight"]
      losses["rg"] = rg_loss
      
    if "bb_weight" in opt:
      # `batch` is the module-level starting-structure dict, not an argument
      fape_start_loss = get_fape_loss_(batch, outputs)      
      dgram_start_loss = get_dgram_loss_(batch, outputs)
      loss = loss + (dgram_start_loss + fape_start_loss) * opt["bb_weight"]
      losses["dgram_start"] = dgram_start_loss
      losses["fape_start"] = fape_start_loss
    
    if "msa" in params and "ent_weight" in opt:
      # entropy of the sequence profile averaged over the leading (MSA) axis
      seq_prf = seq.mean(0)
      ent_loss = -(seq_prf * jnp.log(seq_prf + 1e-8)).sum(-1).mean()
      loss = loss + ent_loss * opt["ent_weight"]
      losses["ent"] = ent_loss
    else:
      # NOTE(review): this value is not read afterwards
      ent_loss = 0

    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    # report the first MSA row as the designed sequence
    seq_ = seq[0] if "msa" in params else seq

    return loss, ({"losses":losses, "outputs":outs, "seq":seq_})
  loss_fn = mod
  # differentiate with respect to `params` only (argnums=0)
  grad_fn = jax.value_and_grad(mod, has_aux=True, argnums=0)
  return loss_fn, grad_fn

In [None]:
# gradient function (note for greedy search we won't be using grad_fn, only loss_fn)
# inc_backbone keeps its default (False); only loss_fn is jit-compiled here
loss_fn, grad_fn = get_grad_fn(model_runner, inputs, pos_idx_ref=pos_idx_ref)
loss_fn = jax.jit(loss_fn)

In [None]:
# NOTE(review): RMSD_min is not read by any of the visible cells
RMSD_min = np.inf
key = jax.random.PRNGKey(0)
# L = sequence length, A = alphabet size; both are read as globals by mut()
L,A = len(START_SEQ),20

# array versions of the position lists, for use inside the jitted loss
pos_idx_ = jnp.asarray(pos_idx)
pos_idx_ref_ = jnp.asarray(pos_idx_ref)

# start from the one-hot starting sequence with a leading axis of length 1
msa = SEQ[None]
# alternative starts, kept for reference:
#   all zeros with the reference residues at the design positions
#msa = jnp.zeros_like(SEQ[None]).at[:,pos_idx_].set(SEQ_REF[pos_idx_ref_])
#   random one-hot residues with the reference residues at the design positions
#msa = jax.nn.one_hot(jax.random.randint(key, (N,L), 0, A),A).at[:,pos_idx_].set(SEQ_REF[pos_idx_ref_])
params = {"msa":msa}

In [None]:
# all-ones masks: every position uses the discrete sequence (oh_mask) and the
# designed residue identity (ala_mask); every MSA row is kept (msa_mask)
oh_mask = jnp.ones((L,))
ala_mask = jnp.ones((L,))
msa_mask = jnp.ones((N,))
# Options read by the loss function. A loss term whose key is present is
# computed and added even when its weight is 0.0; the commented-out
# "bb_weight" entry leaves that term out entirely.
opt={"oh_mask":oh_mask,
    "msa_mask":msa_mask,
    "ala_mask":ala_mask,
    #"bb_weight":0.1, 
    "sc_weight":1.0,
    "sc_weight_rmsd":1.0,
    "sc_weight_fape":1.0,
    "sc_weight_dgram":0.0,
    "ent_weight":0.0,
    "rg_weight":0.0,
    "conf_weight":0.01,
    "pos_idx":pos_idx_,
    }
# evaluate the starting sequence once with the first parameter set and save it
loss, outs = loss_fn(params, key, model_params[0], opt=opt)
print(loss,outs["losses"]["rmsd"],outs["losses"]["fape"])
save_pdb(outs,f"{MODE}_starting.pdb")

1.966159 1.0131341 0.43879992


In [None]:
def mut(params):
  """Return a copy of `params` with one random single-residue substitution.

  A position and a residue type are drawn with `np.random` until the position
  is not in the global `pos_idx` and the first MSA row does not already have
  that residue there. The one-hot row for the new residue is then written to
  that position in every MSA row. `params` itself is not modified.
  """
  msa = params["msa"]
  while True:
    i = np.random.randint(L)
    a = np.random.randint(A)
    # reject fixed positions and substitutions that would change nothing
    if i in pos_idx or msa[0,i,a] != 0:
      continue
    mutated = params.copy()
    mutated["msa"] = msa.at[:,i,:].set(jnp.eye(A)[a])
    return mutated

In [None]:
# Semi-greedy refinement.
# Each round tries up to 20 single mutations of the current sequence and stops
# early at the first one that beats the current loss. The best candidate of the
# round is then accepted unconditionally, so when none of the 20 tries improved
# the loss, the least-bad mutation is still taken.
LOSS = loss
# best values seen so far; a PDB file is written whenever one of them improves
OVERALL_RMSD = outs["losses"]["rmsd"]
OVERALL_FAPE = outs["losses"]["fape"]
OVERALL_LOSS = LOSS
key = jax.random.PRNGKey(0)
for n in range(1000):
  buff_p,buff_l,buff_o = [],[],[]
  for _ in range(20):
    key,subkey = jax.random.split(key)
    p = mut(params)
    l,o = loss_fn(p, subkey, model_params[0], opt=opt)
    buff_p.append(p); buff_l.append(l); buff_o.append(o)
    if l < LOSS: break
  best = jnp.argmin(jnp.asarray(buff_l))
  params, LOSS, outs = buff_p[best], buff_l[best], buff_o[best]
  RMSD = outs["losses"]["rmsd"]
  FAPE = outs["losses"]["fape"]
  if RMSD < OVERALL_RMSD:
    OVERALL_RMSD = RMSD
    save_pdb(outs,f"{MODE}_best_rmsd.pdb")
  if FAPE < OVERALL_FAPE:
    OVERALL_FAPE = FAPE
    save_pdb(outs,f"{MODE}_best_fape.pdb")
  if LOSS < OVERALL_LOSS:
    OVERALL_LOSS = LOSS
    save_pdb(outs,f"{MODE}_best_loss.pdb")
  # last column: number of mutations tried in this round
  print(n, LOSS, RMSD, FAPE, len(buff_l))
  # (the former trailing `n += 1` was removed: the for statement reassigns n
  # on every iteration, so the increment never affected the loop)