<a href="https://colab.research.google.com/github/sokrypton/AfDesign_partial/blob/main/notebooks/step1_predesign.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
# AlphaFold-with-backprop code. `git clone` creates the folder `af_backprop`
# (the same folder a later cell appends to sys.path), so that is the folder
# the guard has to test; otherwise the clone is retried on every re-run.
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
fi
# AlphaFold model parameters
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi
# TrRosetta weights (models.zip) and the trDesign helper module
if [ ! -d models ]; then
  wget -qnc https://files.ipd.uw.edu/krypton/TrRosetta/models.zip
  unzip -qqo models.zip
  wget -qnc https://raw.githubusercontent.com/gjoni/trDesign/beta/02-GD/utils.py -O TrD_utils.py
fi

# data (downloaded into the current working directory)
wget -qnc https://raw.githubusercontent.com/sokrypton/AfDesign_partial/main/data/1QJG.pdb
wget -qnc https://github.com/sokrypton/AfDesign_partial/raw/main/data/bkg_100.npy

In [None]:
import os
# Ask XLA/JAX to preallocate only half of the GPU memory; the TensorFlow session
# used for TrRosetta is capped at 0.5 separately (see tf_config).
# NOTE(review): this is set in its own cell before `import jax`, presumably so it
# takes effect before JAX initialises its backend - keep that ordering.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

In [None]:
import sys
sys.path.append('af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from alphafold.data import prep_inputs
from utils import *

In [None]:
import tensorflow.compat.v1 as tf
#from jax.experimental import jax2tf

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v1.keras.backend as K1
import tensorflow.keras.backend as K

# The TrRosetta model below is built as a static graph (it calls tf.gradients
# inside a Lambda layer), so eager execution is switched off globally.
tf1.disable_eager_execution()

# Session config that limits TensorFlow to half of the GPU memory.
tf_config = tf1.ConfigProto()
tf_config.gpu_options.per_process_gpu_memory_fraction=0.5


import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Activation, Dense, Lambda, Layer, Concatenate

from TrD_utils import prep_input, split_feat

In [None]:
def get_TrR_weights(filename):
  """Load a TrRosetta weight file and return it as a list of squeezed arrays.

  The file is read with ``np.load(..., allow_pickle=True)``; every entry is
  passed through ``np.squeeze``. The two entries at positions [-4:-2]
  (beta-beta pairing weights) are dropped because the model built by
  ``get_TrR`` has no corresponding layer.
  """
  stored = np.load(filename, allow_pickle=True)
  weights = list(map(np.squeeze, stored))
  # drop the beta-beta pairing weights
  weights[-4:-2] = []
  return weights

def get_TrR(blocks=12, trainable=False, weights=None, name="TrR"):
  """Build the TrRosetta Keras model.

  Args:
    blocks: number of dilation cycles; each cycle stacks five residual blocks
      with dilation rates 1, 2, 4, 8, 16. One extra residual block with
      dilation 1 is appended after the cycles.
    trainable: forwarded as ``trainable=`` to every Dense / Conv2D /
      instance_norm layer in the trunk and output head.
    weights: optional list passed to ``model.set_weights``. The order must
      match the order in which the layers are created below.
    name: name given to the returned Keras ``Model``.

  Returns:
    A Keras ``Model`` taking an input declared as ``(1, None, None, 21)`` and
    returning the concatenation of the theta (25), phi (13), dist (37) and
    omega (25) softmax outputs, i.e. 100 channels per residue pair.
  """
  ex = {"trainable":trainable}
  # custom layer(s)
  class PSSM(Layer):
    # modified from MRF to only output tiled 1D features
    def __init__(self, diag=0.4, use_entropy=False):
      super(PSSM, self).__init__()
      self.diag = diag
      self.use_entropy = use_entropy
    def call(self, inputs):
      x,y = inputs
      # L and A are read from the last two axes of y
      _,_,L,A = [tf.shape(y)[k] for k in range(4)]
      with tf.name_scope('1d_features'):
        # sequence
        x_i = x[0,0,:,:20]
        # pssm
        f_i = y[0,0]
        # entropy
        if self.use_entropy:
          h_i = K.sum(-f_i * K.log(f_i + 1e-8), axis=-1, keepdims=True)
        else:
          h_i = tf.zeros((L,1))
        # tile and combined 1D features
        feat_1D = tf.concat([x_i,f_i,h_i], axis=-1)
        feat_1D_tile_A = tf.tile(feat_1D[:,None,:], [1,L,1])
        feat_1D_tile_B = tf.tile(feat_1D[None,:,:], [L,1,1])

      with tf.name_scope('2d_features'):
        # constant pair features: a scaled identity reshaped to (L,L,A*A),
        # plus one all-zero channel
        ic = self.diag * tf.eye(L*A)
        ic = tf.reshape(ic,(L,A,L,A))
        ic = tf.transpose(ic,(0,2,1,3))
        ic = tf.reshape(ic,(L,L,A*A))
        i0 = tf.zeros([L,L,1])
        feat_2D = tf.concat([ic,i0], axis=-1)

      feat = tf.concat([feat_1D_tile_A, feat_1D_tile_B, feat_2D],axis=-1)
      # 442 = A*A + 1 with A = 21; 2*42 = the two tiled copies of the
      # 20 + 21 + 1 one-dimensional features
      return tf.reshape(feat, [1,L,L,442+2*42])
      
  class instance_norm(Layer):
    # normalises over `axes` with a learned per-channel shift (beta) and scale (gamma)
    def __init__(self, axes=(1,2),trainable=True):
      super(instance_norm, self).__init__()
      self.axes = axes
      self.trainable = trainable
    def build(self, input_shape):
      self.beta  = self.add_weight(name='beta',shape=(input_shape[-1],),
                                  initializer='zeros',trainable=self.trainable)
      self.gamma = self.add_weight(name='gamma',shape=(input_shape[-1],),
                                  initializer='ones',trainable=self.trainable)
    def call(self, inputs):
      mean, variance = tf.nn.moments(inputs, self.axes, keepdims=True)
      return tf.nn.batch_normalization(inputs, mean, variance, self.beta, self.gamma, 1e-6)

  ## INPUT ##
  inputs = Input((None,None,21),batch_size=1)
  # the same tensor is passed as both the sequence (x) and the profile (y)
  A = PSSM()([inputs,inputs])
  A = Dense(64, **ex)(A)
  A = instance_norm(**ex)(A)
  A = Activation("elu")(A)

  ## RESNET ##
  def resnet(X, dilation=1, filters=64, win=3):
    # residual block: conv -> norm -> elu -> conv -> norm, then elu(X + Y)
    Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(X)
    Y = instance_norm(**ex)(Y)
    Y = Activation("elu")(Y)
    Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(Y)
    Y = instance_norm(**ex)(Y)
    return Activation("elu")(X+Y)

  for _ in range(blocks):
    for dilation in [1,2,4,8,16]:
      A = resnet(A, dilation)
  A = resnet(A, dilation=1)
  
  ## OUTPUT ##
  # theta and phi are predicted from the raw trunk output; dist and omega
  # are predicted from its symmetrised (i,j)/(j,i) average
  A_input   = Input((None,None,64))
  p_theta   = Dense(25, activation="softmax", **ex)(A_input)
  p_phi     = Dense(13, activation="softmax", **ex)(A_input)
  A_sym     = Lambda(lambda x: (x + tf.transpose(x,[0,2,1,3]))/2)(A_input)
  p_dist    = Dense(37, activation="softmax", **ex)(A_sym)
  p_omega   = Dense(25, activation="softmax", **ex)(A_sym)
  A_model   = Model(A_input,Concatenate()([p_theta,p_phi,p_dist,p_omega]))

  ## MODEL ##
  model = Model(inputs, A_model(A),name=name)
  if weights is not None: model.set_weights(weights)
  return model

In [None]:
def get_TrR_model(L=None):
  """Build the TrRosetta loss/gradient graph and return a callable wrapper.

  The graph averages the predictions of five TrRosetta networks
  (``models/model_xaa.npy`` ... ``model_xae.npy``), compares them with
  reference features at selected positions, adds a background term, and
  differentiates the combined loss with respect to the sequence logits.

  Args:
    L: sequence length used for the Keras input shapes; ``None`` leaves the
      length dynamic.

  Returns:
    ``TrR_model(seq, true, bkg, pos_idx, pos_idx_ref=None)``, which adds a
    leading batch axis to each argument, runs ``model.predict`` and returns a
    dict with keys ``"cce_loss"``, ``"bkg_loss"``, ``"grad"`` and ``"pred"``
    (first batch element of each output).
  """
  def gather_idx(x):
    # x = [pair_features, idx]; keep only the rows/columns listed in the
    # first batch entry of idx
    idx = x[1][0]
    return tf.gather(tf.gather(x[0],idx,axis=-2),idx,axis=-3)

  def get_cce_loss(x, eps=1e-8, exclude_theta=True):
    # x = [true, pred]; categorical cross-entropy averaged over residue pairs
    if exclude_theta:
      true_x = split_feat(x[0])
      pred_x = split_feat(x[1])
      true_x = tf.concat([true_x[k] for k in ["phi","dist","omega"]],-1)
      pred_x = tf.concat([pred_x[k] for k in ["phi","dist","omega"]],-1)
      loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2])
      # NOTE(review): presumably rescales for using 3 of the 4 feature groups - confirm
      loss *= 4/3
      return loss
    else:
      return -tf.reduce_mean(tf.reduce_sum(x[0]*tf.math.log(x[1] + eps),-1),[-1,-2])
  
  def get_bkg_loss(x, eps=1e-8):
    # x = [bkg, pred]; negative of sum(pred * (log pred - log bkg)), averaged
    # over residue pairs, so minimising it pushes pred away from bkg
    return -tf.reduce_mean(tf.reduce_sum(x[1]*(tf.math.log(x[1]+eps)-tf.math.log(x[0]+eps)),-1),[-1,-2])      

  def prep_seq(x_logits):
    # straight-through estimator: the forward value is the one-hot argmax,
    # the gradient flows through the softmax
    x_soft = tf.nn.softmax(x_logits,-1)
    x_hard = tf.one_hot(tf.argmax(x_logits,-1),20)
    x = tf.stop_gradient(x_hard - x_soft) + x_soft
    # pad the 20 amino-acid channels with one zero channel (21 expected by get_TrR)
    x = tf.pad(x,[[0,0],[0,0],[0,1]])
    return x[None]

  I_seq_logits = Input((L,20),name="seq_logits")
  seq = Lambda(prep_seq,name="seq")(I_seq_logits)
  I_true = Input((L,L,100),name="true")
  I_bkg = Input((L,L,100),name="bkg")
  I_idx = Input((None,),dtype=tf.int32,name="idx")
  I_idx_true = Input((None,),dtype=tf.int32,name="idx_true")
  
  #nam = "xaa"
  #pred = get_TrR(weights=get_TrR_weights(f"models/model_{nam}.npy"),name=nam)(seq)
  # ensemble: average the output of the five networks
  pred = []
  for nam in ["xaa","xab","xac","xad","xae"]:
    print(nam)
    TrR = get_TrR(weights=get_TrR_weights(f"models/model_{nam}.npy"),name=nam)
    pred.append(TrR(seq))
  pred = sum(pred)/len(pred)

  # restrict prediction and reference to the positions being compared
  pred_sub = Lambda(gather_idx, name="pred_sub")([pred,I_idx])
  true_sub = Lambda(gather_idx, name="true_sub")([I_true,I_idx_true])
  
  cce_loss = Lambda(get_cce_loss,name="cce_loss")([true_sub, pred_sub])
  # the background term is computed on the full (not sub-selected) prediction
  bkg_loss = Lambda(get_bkg_loss,name="bkg_loss")([I_bkg, pred])

  # total loss = cce + 0.1 * bkg; its gradient w.r.t. the logits is a model output
  loss = Lambda(lambda x: x[0]+0.1*x[1])([cce_loss,bkg_loss])
  grad = Lambda(lambda x: tf.gradients(x[0],x[1]), name="grad")([loss,I_seq_logits])
  model = Model([I_seq_logits, I_true, I_bkg, I_idx, I_idx_true], [cce_loss, bkg_loss, grad, pred], name="TrR_model")
  
  def TrR_model(seq, true, bkg, pos_idx, pos_idx_ref=None):
    # pos_idx indexes the prediction, pos_idx_ref indexes `true`
    if pos_idx_ref is None: pos_idx_ref = pos_idx
    cce_loss, bkg_loss, grad, pred = model.predict([seq[None],
                                                    true[None],
                                                    bkg[None],
                                                    pos_idx[None],
                                                    pos_idx_ref[None]])
    return {"cce_loss":cce_loss[0],
            "bkg_loss":bkg_loss[0],
            "grad":grad[0],
            "pred":pred[0]}
  
  return TrR_model

In [None]:
# Start from an empty TF graph/Keras session, install a session that uses the
# memory-limited tf_config, then build the TrRosetta loss/gradient callable.
tf1.reset_default_graph()
K.clear_session()
K1.set_session(tf1.Session(config=tf_config))
TrR_model = get_TrR_model()


xaa
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
xab
xac
xad
xae


In [None]:
# N: number of MSA rows (used for max_msa_clusters and the msa_mask length).
N = 1
# MODE selects the design problem. Each branch defines:
#   pos_idx      0-based positions in the design (PDB)
#   pos_idx_ref  0-based positions in the reference (PDB_REF) to be matched
# NOTE(review): the PDB paths point into "afDesign/", while the setup cell
# downloads 1QJG.pdb into the working directory - confirm the files exist there.
MODE = "ksi_test"
if MODE == "cm2":
  pos_idx =     [43, 102,  50, 79, 90, 103, 129, 133, 99]
  pos_idx_ref = [97,  47, 104, 24, 35,  48,  80,  84, 44]

  # shift the numbers above to 0-based array indices
  pos_idx = [i-37 for i in pos_idx]
  pos_idx_ref = [i-1 for i in pos_idx_ref]

  PDB = "afDesign/design_cm2.pdb"
  PDB_REF = "afDesign/native_cm2.pdb"

elif MODE == "hcA":
  pos_idx =     [5,  7,   16,   82,   49]
  pos_idx_ref = [91, 93, 116,   61,  195]

  pos_idx = [i-1 for i in pos_idx]
  pos_idx_ref = [i-1 for i in pos_idx_ref]

  PDB = "afDesign/design_hcA.pdb"
  PDB_REF = "afDesign/native_hcA.pdb"

elif MODE == "hcA_control":
  pos_idx = [5,  7,   16,   82,   49]
  pos_idx = [i-1 for i in pos_idx]
  PDB = "afDesign/design_hcA_model_3_ptm_seed_0_unrelaxed.pdb"

  # control: the design is compared against itself
  pos_idx_ref = pos_idx
  PDB_REF = PDB

elif MODE == "ksi":
  #native 1QJG: A14,A38,A99
  #ksi_149: A85,A48,A71
  #ksi_230: A34,A56,A7
  #ksi2_301: A91,A8,A38

  pos_idx_ref = [13,37,98]
  PDB_REF = "afDesign/1QJG.pdb"
  PDB = PDB_REF
  pos_idx = pos_idx_ref

elif MODE == "ksi_test":
  pos_idx_ref = [13,37,98]
  PDB_REF = "afDesign/1QJG.pdb"

  pos_idx = [20,45,80]
  PDB = "afDesign/tmp_20_45_80_best.pdb"
  # LEN is only defined in this mode
  LEN = 100

else:
  # fail here with a clear message instead of a NameError on PDB_REF below
  raise ValueError(f"Unknown MODE: {MODE!r}")

# TrRosetta features of the reference structure (chain A)
TrR_ref = prep_input(PDB_REF,"A")

In [None]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
model_config.model.num_recycle = 3
model_config.data.common.num_recycle = 3

# backprop through recycles
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences
model_config.data.eval.max_msa_clusters = N
model_config.data.common.max_extra_msa = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout
model_config = set_dropout(model_config, 0.0)

# setup model
# model_params[0] holds the model_3_ptm weights used to build the runner
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# Append the other four ptm models, keeping only the parameter keys the runner
# uses, so any entry of model_params can be passed to model_runner.apply.
# NOTE: this loop rebinds `model_name` (it ends as "model_5_ptm") and `params`.
for model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name, '.')
  model_params.append({k: params[k] for k in model_runner.params.keys()})

In [None]:
# prep reference (native) features
OBJ_REF = protein.from_pdb_string(pdb_to_string(PDB_REF), chain_id="A")
SEQ_REF = jax.nn.one_hot(OBJ_REF.aatype,20)
START_SEQ_REF = "".join([order_restype[a] for a in OBJ_REF.aatype])

batch_ref = {'aatype': OBJ_REF.aatype,
         'all_atom_positions': OBJ_REF.atom_positions,
         'all_atom_mask': OBJ_REF.atom_mask}
batch_ref.update(all_atom.atom37_to_frames(**batch_ref))
batch_ref.update(prep_inputs.make_atom14_positions(batch_ref))

# prep starting (design) features
if PDB is not None:
  OBJ = protein.from_pdb_string(pdb_to_string(PDB), chain_id="A")
  SEQ = jax.nn.one_hot(OBJ.aatype,20)
  START_SEQ = "".join([order_restype[a] for a in OBJ.aatype])

  batch = {'aatype': OBJ.aatype,
          'all_atom_positions': OBJ.atom_positions,
          'all_atom_mask': OBJ.atom_mask}
  batch.update(all_atom.atom37_to_frames(**batch))
  batch.update(prep_inputs.make_atom14_positions(batch))
else:
  # No starting structure: residue index 0 everywhere, with the reference
  # residue types copied onto the design positions. The index array has to be
  # an integer array, and each element is converted to a Python int before it
  # is used as a lookup key (a float array / JAX scalar cannot be used there).
  # NOTE: `batch` is not defined on this path, so the "bb_weight" loss term
  # cannot be used without a starting structure.
  aatype_init = jnp.zeros(LEN, dtype=int).at[jnp.asarray(pos_idx)].set(
      jnp.asarray(OBJ_REF.aatype)[jnp.asarray(pos_idx_ref)])
  START_SEQ = "".join([order_restype[int(a)] for a in aatype_init])
  SEQ = jax.nn.one_hot(aatype_init,20)

# prep input features
# NOTE(review): "inputs_4.npz" is not downloaded by the setup cell; it is
# presumably a saved copy of the features produced by the disabled code kept
# in the string below - confirm where this file comes from.
inputs = jax.tree_map(lambda x:jnp.asarray(x), dict(np.load("inputs_4.npz")))

# disabled feature-generation code, kept as a string literal (not executed)
'''
feature_dict = {
    **pipeline.make_sequence_features(sequence=START_SEQ,description="none",num_res=len(START_SEQ)),
    **pipeline.make_msa_features(msas=[N*[START_SEQ]], deletion_matrices=[N*[[0]*len(START_SEQ)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)

if N > 1:
  inputs["msa_row_mask"] = jnp.ones_like(inputs["msa_row_mask"])
  inputs["msa_mask"] = jnp.ones_like(inputs["msa_mask"])
'''

'\nfeature_dict = {\n    **pipeline.make_sequence_features(sequence=START_SEQ,description="none",num_res=len(START_SEQ)),\n    **pipeline.make_msa_features(msas=[N*[START_SEQ]], deletion_matrices=[N*[[0]*len(START_SEQ)]]),\n}\ninputs = model_runner.process_features(feature_dict, random_seed=0)\n\nif N > 1:\n  inputs["msa_row_mask"] = jnp.ones_like(inputs["msa_row_mask"])\n  inputs["msa_mask"] = jnp.ones_like(inputs["msa_mask"])\n'

In [None]:
# sanity check: residue letters at the design positions (first line) and at
# the reference positions (second line), printed for visual comparison
print([START_SEQ[i] for i in pos_idx])
print([START_SEQ_REF[i] for i in pos_idx_ref])

['Y', 'N', 'D']
['Y', 'N', 'D']


In [None]:
def get_dgram_loss_(batch, outputs):
  """Distogram log-loss of the prediction against the full structure in `batch`.

  Pseudo-beta coordinates and mask are derived from `batch` ("aatype",
  "all_atom_positions", "all_atom_mask") and compared with
  ``outputs["distogram"]``. Returns the "loss" entry of the result.
  """
  pseudo_beta, pseudo_beta_mask = model.modules.pseudo_beta_fn(
      batch["aatype"], batch["all_atom_positions"], batch["all_atom_mask"])

  distogram = outputs["distogram"]
  target = {"pseudo_beta":pseudo_beta, "pseudo_beta_mask":pseudo_beta_mask}
  n_bins = model_config.model.heads.distogram.num_bins

  result = model.modules._distogram_log_loss(distogram["logits"],
                                             distogram["bin_edges"],
                                             batch=target,
                                             num_bins=n_bins)
  return result["loss"]

def get_fape_loss_(batch, outputs, use_clamped_fape=False):
  """Backbone FAPE loss of the prediction against the full structure in `batch`.

  `batch` is copied before the "use_clamped_fape" flag is added, so the
  caller's dict is left untouched. Returns the "loss" value accumulated by
  ``folding.backbone_loss``.
  """
  targets = jax.tree_map(lambda leaf: leaf, batch)
  targets["use_clamped_fape"] = use_clamped_fape

  # backbone_loss writes its result into the dict it is given
  accum = {"loss":0.0}
  sm_config = model_config.model.heads.structure_module
  folding.backbone_loss(accum, targets, outputs["structure_module"], sm_config)
  return accum["loss"]

#########################################
# loss restricted to specific amino acids
#########################################
def get_dgram_loss(batch, outputs, pos_idx, pos_idx_ref=None):
  """Distogram log-loss restricted to selected residues.

  `pos_idx` selects rows/columns of the predicted distogram logits;
  `pos_idx_ref` (defaults to `pos_idx`) selects the matching residues of the
  reference `batch`. Returns the "loss" entry of the result.
  """
  ref_idx = pos_idx if pos_idx_ref is None else pos_idx_ref

  pseudo_beta, pseudo_beta_mask = model.modules.pseudo_beta_fn(
      batch["aatype"][ref_idx],
      batch["all_atom_positions"][ref_idx],
      batch["all_atom_mask"][ref_idx])

  distogram = outputs["distogram"]
  # keep only the pair block spanned by the selected positions
  sub_logits = distogram["logits"][:,pos_idx][pos_idx,:]
  target = {"pseudo_beta":pseudo_beta, "pseudo_beta_mask":pseudo_beta_mask}

  result = model.modules._distogram_log_loss(sub_logits,
                                             distogram["bin_edges"],
                                             batch=target,
                                             num_bins=model_config.model.heads.distogram.num_bins)
  return result["loss"]

def get_fape_loss(batch, outputs, pos_idx, pos_idx_ref=None, backbone=True, sidechain=True, use_clamped_fape=False):
  """FAPE loss restricted to selected residues.

  `pos_idx` selects residues of the prediction, `pos_idx_ref` (defaults to
  `pos_idx`) the matching residues of the reference `batch`. The side-chain
  term is evaluated first (when `sidechain`), then the backbone term is
  added on top (when `backbone`). Returns the accumulated "loss" value.
  """
  ref_idx = pos_idx if pos_idx_ref is None else pos_idx_ref
  sm_config = model_config.model.heads.structure_module

  # reference features for the selected residues only
  targets = jax.tree_map(lambda leaf: leaf[ref_idx,...], batch)
  targets["use_clamped_fape"] = use_clamped_fape

  # copy of the structure-module outputs, so the edits below stay local
  pred = jax.tree_map(lambda leaf: leaf, outputs["structure_module"])
  total = {"loss":0.0}

  if sidechain:
    renamed = folding.compute_renamed_ground_truth(targets, pred['final_atom14_positions'][pos_idx,...])
    pred.update(renamed)
    sc = pred['sidechains']
    sc['frames'] = jax.tree_map(lambda leaf: leaf[:,pos_idx,:], sc["frames"])
    sc['atom_pos'] = jax.tree_map(lambda leaf: leaf[:,pos_idx,:], sc["atom_pos"])
    total.update(folding.sidechain_loss(targets, pred, sm_config))

  if backbone:
    pred["traj"] = pred["traj"][...,pos_idx,:]
    folding.backbone_loss(total, targets, pred, sm_config)

  return total["loss"]

def get_sidechain_rmsd_fix(batch, outputs, pos_idx, pos_idx_ref=None, include_CA=True):
  """RMSD between reference and predicted atoms of selected residues.

  Args:
    batch: reference features; reads "aatype" and "all_atom_positions" and is
      also passed to ``all_atom.atom37_to_atom14``.
    outputs: model outputs; reads
      ``outputs["structure_module"]["final_atom14_positions"]``.
    pos_idx: residue positions in the prediction.
    pos_idx_ref: matching positions in the reference (defaults to `pos_idx`).
    include_CA: if True only the backbone N and O atoms are excluded;
      otherwise CA is excluded as well.

  Returns:
    sqrt(mean squared deviation + 1e-8) over the selected atoms, after
    superposing the reference onto the prediction using the atoms that have
    no naming ambiguity. For atoms with a symmetric naming partner, the
    smaller of the two possible squared deviations is used per atom.
  """
  if pos_idx_ref is None: pos_idx_ref = pos_idx
  bb_atoms_to_exclude = ["N","O"] if include_CA else ["N","CA","O"]

  def kabsch(P, Q):
    # Rotation from the SVD of the covariance P^T Q. The reflection fix is
    # blended in with a sigmoid of the determinant product (close to 1 when
    # the product is negative). The singular values are not needed.
    V, _, W = jnp.linalg.svd(P.T @ Q, full_matrices=False)
    flip = jax.nn.sigmoid(-10 * jnp.linalg.det(V) * jnp.linalg.det(W))
    V = flip * V.at[:,-1].set(-V[:,-1]) + (1-flip) * V
    return V@W

  true_aa_idx = batch["aatype"][pos_idx_ref]
  true_pos = all_atom.atom37_to_atom14(batch["all_atom_positions"],batch)[pos_idx_ref,:,:]
  pred_pos = outputs["structure_module"]["final_atom14_positions"][pos_idx,:,:]

  # (i, j): residue / atom14 index of every atom entering the RMSD
  # j_alt: atom14 index of the naming partner (equal to j if there is none)
  # (i_non, j_non): subset of atoms without a partner, used for the alignment
  i,j,j_alt = [],[],[]
  i_non,j_non = [],[]
  for n,aa_idx in enumerate(true_aa_idx):
    aa = idx_to_resname[aa_idx]
    atom14_names = residue_constants.restype_name_to_atom14_names[aa]
    swaps = residue_constants.residue_atom_renaming_swaps.get(aa, {})
    # look-up in both directions; forward entries take priority
    partner = {**{v:k for k,v in swaps.items()}, **swaps}
    for atom in residue_constants.residue_atoms[aa]:
      if atom in bb_atoms_to_exclude:
        continue
      col = atom14_names.index(atom)
      i.append(n)
      j.append(col)
      if atom in partner:
        j_alt.append(atom14_names.index(partner[atom]))
      else:
        j_alt.append(col)
        i_non.append(n)
        j_non.append(col)

  # align non-ambiguous atoms
  true_pos_non = true_pos[i_non,j_non,:]
  pred_pos_non = pred_pos[i_non,j_non,:]
  true_pos = (true_pos - true_pos_non.mean(0)) @ kabsch(true_pos_non - true_pos_non.mean(0), pred_pos_non - pred_pos_non.mean(0))
  pred_pos = pred_pos - pred_pos_non.mean(0)

  true_pos_a = true_pos[i,j,:]
  pred_pos_a = pred_pos[i,j,:]
  pred_pos_b = pred_pos[i,j_alt,:]

  # squared deviation under the original and under the swapped atom naming
  rms_a = jnp.square(true_pos_a - pred_pos_a).sum(-1)
  rms_b = jnp.square(true_pos_a - pred_pos_b).sum(-1)

  return jnp.sqrt(jnp.minimum(rms_a,rms_b).mean() + 1e-8)

In [None]:
def get_grad_fn(model_runner, inputs, pos_idx_ref, inc_backbone=False, TrR_only=False):
  """Build the AlphaFold design loss and its gradient function.

  Args:
    model_runner: object whose ``apply(model_params, key, inputs)`` returns
      the model outputs.
    inputs: dict of input features; a shallow copy is modified on every call.
    pos_idx_ref: positions in the reference (SEQ_REF / batch_ref) whose
      residues are imposed on, and compared with, the design positions.
    inc_backbone: include the backbone term in the FAPE loss.
    TrR_only: accepted for compatibility; not used here.

  Returns:
    ``(loss_fn, grad_fn)``. ``loss_fn(params, key, model_params, opt)``
    returns ``(loss, aux)`` where ``aux`` has the keys "losses" (individual
    terms, before the sc_weight_* factors are applied), "outputs" (final atom
    positions and mask) and "seq". ``grad_fn`` is ``jax.value_and_grad`` of
    that function with respect to ``params``.

    ``params`` must contain "msa" (sequence logits).
    ``opt`` must contain "pos_idx", "oh_mask", "ala_mask" and "sc_weight";
    optional keys: "msa_mask", "sc_weight_fape", "sc_weight_rmsd",
    "sc_weight_dgram", "conf_weight", "rg_weight", "bb_weight", "ent_weight".

  Raises:
    ValueError: if ``params`` has no "msa" entry.
  """
  def mod(params, key, model_params, opt):
    # Only the "msa" parameterisation is implemented. Without this check the
    # code below fails on the unbound names `seq` / `seq_logits`.
    if "msa" not in params:
      raise ValueError('params must contain an "msa" entry with the sequence logits')

    pos_idx = opt["pos_idx"]
    ############################
    # set amino acid sequence
    ############################
    seq_logits = jax.random.permutation(key, params["msa"])
    seq_soft = jax.nn.softmax(seq_logits)
    # straight-through: one-hot in the forward pass, softmax gradient
    seq = jax.lax.stop_gradient(jax.nn.one_hot(seq_soft.argmax(-1),20) - seq_soft) + seq_soft
    # impose the reference residues at the design positions
    seq = seq.at[:,pos_idx,:].set(SEQ_REF[pos_idx_ref,:])

    # oh_mask = 1: feed the one-hot sequence, 0: feed the raw logits
    oh_mask = opt["oh_mask"][:,None]
    pseudo_seq = oh_mask * seq + (1-oh_mask) * seq_logits

    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod, msa_input=True)

    if "msa_mask" in opt:
      inputs_mod["msa_mask"] = inputs_mod["msa_mask"] * opt["msa_mask"][None,:,None]
      inputs_mod["msa_row_mask"] = inputs_mod["msa_row_mask"] * opt["msa_mask"][None,:]
    
    ####################
    # set sidechains identity
    ####################
    B,L = inputs_mod["aatype"].shape[:2]
    ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)

    # residue types taken from the first MSA row
    aatype = jnp.zeros((B,L,21)).at[...,:20].set(seq[0])

    # ala_mask = 1: use the designed residue type, 0: use alanine
    # (the design positions keep the reference residue either way)
    ala_mask = opt["ala_mask"][:,None]
    aatype_ala = jnp.zeros((B,L,21)).at[:].set(ALA)
    aatype_ala = aatype_ala.at[:,pos_idx,:20].set(SEQ_REF[pos_idx_ref,:])
    aatype_pseudo = ala_mask * aatype + (1-ala_mask) * aatype_ala
    update_aatype(aatype_pseudo, inputs_mod)
    
    # get output
    outputs = model_runner.apply(model_params, key, inputs_mod)

    ###################
    # structure loss
    ###################
    fape_loss = get_fape_loss(batch_ref, outputs, pos_idx, pos_idx_ref, backbone=inc_backbone, sidechain=True)
    rmsd_loss = get_sidechain_rmsd_fix(batch_ref, outputs, pos_idx, pos_idx_ref)
    dgram_loss = get_dgram_loss(batch_ref, outputs, pos_idx, pos_idx_ref)

    # recorded before the per-term weights below are applied
    losses = {"fape":fape_loss,
              "rmsd":rmsd_loss,
              "dgram":dgram_loss}

    if "sc_weight_fape" in opt: fape_loss *= opt["sc_weight_fape"]
    if "sc_weight_rmsd" in opt: rmsd_loss *= opt["sc_weight_rmsd"]
    if "sc_weight_dgram" in opt: dgram_loss *= opt["sc_weight_dgram"]

    loss = (jnp.log(rmsd_loss + 1.0) + fape_loss + dgram_loss) * opt["sc_weight"]
  
    ################### 
    # background loss
    ###################
    if "conf_weight" in opt:
      # expected bin index of the predicted aligned error
      pae = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"])
      pae_loss = (pae * jnp.arange(pae.shape[-1])).sum(-1).mean()
      # expected reversed bin index of the predicted lDDT
      plddt = jax.nn.softmax(outputs['predicted_lddt']['logits'])
      plddt_loss = (plddt * jnp.arange(plddt.shape[-1])[::-1]).sum(-1).mean()

      loss = loss + (pae_loss + plddt_loss) * opt["conf_weight"]
      losses["pae"] = pae_loss
      losses["plddt"] = plddt_loss

    if "rg_weight" in opt:
      # root-mean-square distance of atom index 1 (CA) from its centroid
      ca_coords = outputs["structure_module"]["final_atom_positions"][:,1,:]
      rg_loss = jnp.sqrt(jnp.square(ca_coords - ca_coords.mean(0)).sum(-1).mean() + 1e-8)
      loss = loss + rg_loss * opt["rg_weight"]
      losses["rg"] = rg_loss
      
    if "bb_weight" in opt:
      # keep the whole backbone close to the starting structure (`batch`)
      fape_start_loss = get_fape_loss_(batch, outputs)      
      dgram_start_loss = get_dgram_loss_(batch, outputs)
      loss = loss + (dgram_start_loss + fape_start_loss) * opt["bb_weight"]
      losses["dgram_start"] = dgram_start_loss
      losses["fape_start"] = fape_start_loss
    
    if "ent_weight" in opt:
      # entropy of the sequence profile averaged over MSA rows
      seq_prf = seq.mean(0)
      ent_loss = -(seq_prf * jnp.log(seq_prf + 1e-8)).sum(-1).mean()
      loss = loss + ent_loss * opt["ent_weight"]
      losses["ent"] = ent_loss

    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    return loss, ({"losses":losses, "outputs":outs, "seq":seq[0]})
  loss_fn = mod
  grad_fn = jax.value_and_grad(mod, has_aux=True, argnums=0)
  return loss_fn, grad_fn

In [None]:
# gradient function
# Build the loss / gradient functions for the chosen reference positions
# (backbone FAPE term disabled), then JIT-compile the gradient function.
loss_fn, grad_fn = get_grad_fn(model_runner, inputs, pos_idx_ref=pos_idx_ref, inc_backbone=False)
grad_fn = jax.jit(grad_fn)

In [None]:
# Inputs passed to TrR_model in the optimisation steps: bkg_ is used as the
# `bkg` argument and feat_ (features of the reference structure) as `true`.
bkg_ = np.load("bkg_100.npy")
feat_ = TrR_ref["feat"]

In [None]:
# Adam optimizer (step size 5e-3): init_fun creates the optimizer state,
# update_fun applies one gradient step, get_params reads the parameters back.
init_fun, update_fun, get_params = adam(step_size=5e-3)

In [None]:
#@jax.jit
def step_ADAM(i, state, key, model_params, opt):
  """One Adam step on the combined AlphaFold + TrRosetta gradient.

  The AlphaFold gradient of `params["msa"]` is summed with the TrRosetta
  gradient evaluated on the first MSA row, normalised to unit norm over the
  last two axes, and applied with `update_fun`. Returns `(state, outs)`;
  `outs` is extended with the TrRosetta losses and prediction.
  """
  current = get_params(state)
  (_, outs), grad = grad_fn(current, key, model_params=model_params, opt=opt)

  # TrRosetta runs outside JAX, so it is fed NumPy copies
  TrR_out = TrR_model(np.asarray(current["msa"])[0],
                      feat_,
                      bkg_,
                      np.asarray(opt["pos_idx"]),
                      np.asarray(pos_idx_ref))

  combined = grad["msa"] + jnp.asarray(TrR_out["grad"])

  outs["losses"].update({"TrR_cce": float(TrR_out["cce_loss"]),
                         "TrR_bkg": float(TrR_out["bkg_loss"])})
  outs["TrR"] = TrR_out["pred"]

  # normalise the gradient before handing it to the optimizer
  norm = jnp.sqrt(jnp.square(combined).sum([-1,-2],keepdims=True))
  grad["msa"] = combined / (norm + 1e-8)
  return update_fun(i, grad, state), outs

In [None]:
def step_GD(params, key, model_params, opt):
  """One normalised gradient-descent step on the combined gradient.

  The AlphaFold and TrRosetta gradients are summed, normalised to unit norm
  over the last two axes, and `params["msa"]` is updated in the given dict
  with step 0.5 * sqrt(L), L being the size of the second-to-last axis.
  Returns `(params, outs)`.
  """
  (_, outs), af_grad = grad_fn(params, key, model_params=model_params, opt=opt)

  TrR_out = TrR_model(np.asarray(params["msa"])[0],
                      feat_,
                      bkg_,
                      np.asarray(opt["pos_idx"]),
                      np.asarray(pos_idx_ref))

  outs["losses"].update({"TrR_cce": float(TrR_out["cce_loss"]),
                         "TrR_bkg": float(TrR_out["bkg_loss"])})
  outs["TrR"] = TrR_out["pred"]

  total = af_grad["msa"] + jnp.asarray(TrR_out["grad"])
  norm = jnp.sqrt(jnp.square(total).sum([-1,-2],keepdims=True))
  direction = total / (norm + 1e-8)
  params["msa"] = params["msa"] - 0.5 * jnp.sqrt(direction.shape[-2]) * direction
  return params, outs

In [None]:
#@jax.jit
def step_TrR(params, opt):
  """One normalised gradient-descent step using the TrRosetta gradient only.

  Args:
    params: dict with "msa" (sequence logits); its "msa" entry is updated.
    opt: dict with "pos_idx" (design positions).

  Returns:
    `(params, outs)`, where `outs` holds the TrRosetta losses
    ("TrR_cce", "TrR_bkg") and the prediction ("TrR").
  """
  seq_ = np.asarray(params["msa"])[0]                     
  pos_idx_ = np.asarray(opt["pos_idx"])
  pos_idx_ref_ = np.asarray(pos_idx_ref)
  # pass the position arrays computed above; the previous code referenced
  # `idx_` / `idx_ref_`, which are not defined anywhere in the notebook
  TrR_out = TrR_model(seq_,feat_,bkg_,pos_idx_,pos_idx_ref_)

  outs = {"losses":{}}
  outs["losses"]["TrR_cce"] = float(TrR_out["cce_loss"])
  outs["losses"]["TrR_bkg"] = float(TrR_out["bkg_loss"])
  outs["TrR"] = TrR_out["pred"]
  
  # unit-norm gradient, step size sqrt(L)
  grad = jnp.asarray(TrR_out["grad"])
  grad /= jnp.sqrt(jnp.square(grad).sum([-1,-2],keepdims=True)) + 1e-8
  params["msa"] -= jnp.sqrt(grad.shape[-2]) * grad  
  return params, outs

In [None]:
# Preview: draw three distinct random design positions.
# The sequence length is taken from START_SEQ directly, because `L` is only
# defined in a later cell and is not available here on a fresh kernel.
ia,ib,ic = np.random.choice(np.arange(len(START_SEQ)),size=3, replace=False)
ia,ib,ic

(83, 75, 30)

In [None]:
# Random-restart design loop. Each outer iteration draws three design
# positions and a seed, then runs 200 Adam steps; the structures with the
# lowest rmsd / fape seen so far in that run are kept on disk.
# The outer loop never terminates by itself: interrupt the kernel to stop.
L,A = len(START_SEQ),20

# The PDB files are written into these folders, so make sure they exist
# (assumes save_pdb does not create missing directories itself).
for out_dir in ("tmp_rmsd","tmp_fape"):
  os.makedirs(out_dir, exist_ok=True)

while True:
  ia,ib,ic = np.random.choice(np.arange(L),size=3, replace=False)
  seed = np.random.randint(1000)
  
  RMSD_min = np.inf
  FAPE_min = np.inf
  key = jax.random.PRNGKey(seed)
  pos_idx_ = jnp.asarray([ia,ib,ic])
  pos_idx_ref_ = jnp.asarray(pos_idx_ref)
  # small random logits; very large logits pin the reference residues
  # at the design positions
  msa = 0.01 * jax.random.normal(key,shape=(1,L,A))
  msa = msa.at[:,pos_idx_].set(1e10 * SEQ_REF[pos_idx_ref_])
  params = {"msa":msa}
  state = init_fun(params)
  i = 0      
  oh_mask = jnp.ones((L,))
  ala_mask = jnp.zeros((L,))
  msa_mask = jnp.ones((N,))
  # the options do not change during a run, so build them once per restart
  opt = {"oh_mask":oh_mask,
         "msa_mask":msa_mask,
         "ala_mask":ala_mask,
         #"bb_weight":0.1,
         "sc_weight":1.0,
         "sc_weight_rmsd":1.0,
         "sc_weight_fape":1.0,
         "sc_weight_dgram":1.0,
         "sc_weight_TrR":1.0,
         "ent_weight":0.0,
         "rg_weight":0.0,
         "conf_weight":0.01,
         "pos_idx":pos_idx_,
         }
  #######################################
  while i < 200:
    key,subkey = jax.random.split(key)
    #oh_mask = jax.random.bernoulli(subkey,0.5,(L,)).astype(jnp.float32)
    n = 0 #np.random.randint(0,5)
    state, outs = step_ADAM(i, state, subkey, model_params[n], opt=opt)

    i += 1
    PRINT = False
    # new best rmsd: replace the previously saved file for this run
    if outs["losses"]["rmsd"] < RMSD_min:
      OLD_PDB = f"tmp_rmsd/{RMSD_min:.2f}_{ia}_{ib}_{ic}_s{seed}_r3_cce_adam.pdb"
      if os.path.isfile(OLD_PDB): os.remove(OLD_PDB)
      RMSD_min = outs["losses"]["rmsd"]
      NEW_PDB = f"tmp_rmsd/{RMSD_min:.2f}_{ia}_{ib}_{ic}_s{seed}_r3_cce_adam.pdb"
      save_pdb(outs,NEW_PDB)
      PRINT = True

    # new best fape: same bookkeeping in its own folder
    if outs["losses"]["fape"] < FAPE_min:
      OLD_PDB = f"tmp_fape/{FAPE_min:.2f}_{ia}_{ib}_{ic}_s{seed}_r3_cce_adam.pdb"
      if os.path.isfile(OLD_PDB): os.remove(OLD_PDB)
      FAPE_min = outs["losses"]["fape"]
      NEW_PDB = f"tmp_fape/{FAPE_min:.2f}_{ia}_{ib}_{ic}_s{seed}_r3_cce_adam.pdb"
      save_pdb(outs,NEW_PDB)
      PRINT = True

    if PRINT:
      print(f'[{ia} {ib} {ic}] {seed} {i} {int(oh_mask.sum())}\
| rmsd: {outs["losses"]["rmsd"]:.3f} fape: {outs["losses"]["fape"]:.3f} dgram: {outs["losses"]["dgram"]:.3f}\
| rg: {outs["losses"]["rg"]:.3f} ent: {outs["losses"]["ent"]:.3f}\
| pae: {outs["losses"]["pae"]:.3f} plddt: {outs["losses"]["plddt"]:.3f}\
| TrR_cce: {outs["losses"]["TrR_cce"]:.3f} TrR_bkg: {outs["losses"]["TrR_bkg"]:.3f}')
  #######################################

[27 47 86] 589 1 100| rmsd: 10.146 fape: 0.724 dgram: 3.381| rg: 14.714 ent: -0.000| pae: 36.923 plddt: 31.733| TrR_cce: 10.088 TrR_bkg: -0.766
[27 47 86] 589 2 100| rmsd: 4.340 fape: 0.677 dgram: 2.746| rg: 15.231 ent: -0.000| pae: 33.973 plddt: 27.431| TrR_cce: 9.419 TrR_bkg: -0.674
[27 47 86] 589 4 100| rmsd: 6.738 fape: 0.675 dgram: 2.986| rg: 16.792 ent: -0.000| pae: 37.804 plddt: 28.937| TrR_cce: 8.119 TrR_bkg: -1.025
[27 47 86] 589 14 100| rmsd: 3.231 fape: 0.660 dgram: 2.716| rg: 14.566 ent: -0.000| pae: 33.616 plddt: 28.041| TrR_cce: 7.127 TrR_bkg: -1.600
[27 47 86] 589 22 100| rmsd: 2.756 fape: 0.655 dgram: 2.599| rg: 14.635 ent: -0.000| pae: 26.652 plddt: 22.450| TrR_cce: 7.445 TrR_bkg: -2.019
[27 47 86] 589 23 100| rmsd: 3.560 fape: 0.644 dgram: 2.961| rg: 15.335 ent: -0.000| pae: 31.197 plddt: 22.334| TrR_cce: 7.280 TrR_bkg: -2.270
[27 47 86] 589 31 100| rmsd: 2.527 fape: 0.677 dgram: 2.543| rg: 14.661 ent: -0.000| pae: 34.584 plddt: 28.762| TrR_cce: 6.988 TrR_bkg: -1.909


KeyboardInterrupt: 