<a href="https://colab.research.google.com/github/sokrypton/AM216/blob/main/AF2BIND_prod.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### AF2BIND: Prediction of ligand-binding sites using AlphaFold2

AF2BIND is a simple and fast notebook that runs inference on the output obtained from [AlphaFold](https://github.com/deepmind/alphafold).


The method uses the [ColabDesign](https://github.com/sokrypton/ColabDesign) binder protocol framework, which facilitates the identification of binding sites for protein-peptide and protein-ligand complexes.

Authors/Collaborators :

*   Artem Gazizov (agazizov@fas.harvard.edu)
*   Sergey Ovchinnikov (so@fas.harvard.edu)
*   Nicholas Polizzi (nicholasf_polizzi@dfci.harvard.edu)

![](https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo_300.png)





In [None]:
%%time
#@title Install AlphaFold2 (~2 mins)
#@markdown Please execute this cell by pressing the *Play* button on
#@markdown the left.

#@markdown **Note**: This installs the ColabDesign framework and downloads the AlphaFold2 and AF2BIND parameters.
import os, time
# One-time setup, keyed on the existence of the "params" directory.
# NOTE(review): the guard checks the directory, not params/done.txt, so if a
# previous run was interrupted after `mkdir params` this whole block (including
# the pip install) is skipped on re-run.
if not os.path.isdir("params"):
  # Start the parameter download in the background (trailing "&") so it runs
  # while ColabDesign is being installed below.
  print("installing ColabDesign")
  # The shell commands are chained with ";", so `touch params/done.txt` runs
  # even if an earlier download/extract step failed; done.txt therefore signals
  # "pipeline finished", not "pipeline succeeded".
  os.system("(mkdir params; apt-get install aria2 -qq; \
  aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar; \
  aria2c -q -x 16 https://files.ipd.uw.edu/krypton/af2bind_params.zip; \
  tar -xf alphafold_params_2022-03-02.tar -C params; unzip af2bind_params.zip; touch params/done.txt )&")

  # Install the pinned ColabDesign release and symlink the installed package
  # into the working directory.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")

  # Block until the background job has written its sentinel file, polling
  # every 5 seconds. There is no timeout.
  if not os.path.isfile("params/done.txt"):
    print("downloading params")
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)

import os
from colabdesign import mk_afdesign_model
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Resolve ``pdb_code`` to the path of a local PDB file, fetching it if needed.

  Resolution order:
    1. ``None`` / empty string -> prompt for an upload (Colab) and save it
       as ``tmp.pdb``.
    2. An existing file path   -> returned unchanged.
    3. A 4-character code      -> downloaded from RCSB as ``<code>.pdb``.
    4. Anything else           -> treated as an AlphaFold DB accession and
       downloaded as ``AF-<code>-F1-model_v4.pdb``.

  Leading/trailing whitespace in a string ``pdb_code`` is ignored, since the
  value usually comes from a free-text form field.

  Args:
    pdb_code: PDB id, AlphaFold DB accession, local file path, or empty/None.

  Returns:
    Path (str) of the PDB file on disk.

  Raises:
    ValueError: the upload dialog was dismissed without providing a file.
    FileNotFoundError: the download did not produce the expected file
      (unknown id, no network, ...).
  """
  import subprocess  # local import keeps this definition self-contained

  if isinstance(pdb_code, str):
    pdb_code = pdb_code.strip()

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    # only the first uploaded file is used
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    filename = f"{pdb_code}.pdb"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v4.pdb"
    filename = f"AF-{pdb_code}-F1-model_v4.pdb"

  # Argument list instead of a shell string: the user-supplied code can no
  # longer be interpreted by the shell. "-nc" skips the download when the
  # file is already present.
  subprocess.run(["wget", "-qnc", url], check=False)

  # wget runs quietly, so verify the result here instead of handing a
  # non-existent path to the caller.
  if not os.path.isfile(filename):
    raise FileNotFoundError(f"could not download '{pdb_code}' from {url}")
  return filename

In [None]:
#@title **Run AF2BIND** 🔬
# Colab form inputs: each "#@param" annotation binds the assignment on its
# line to a form widget, so those lines must keep their exact shape.
# target_pdb is passed to get_pdb(), which accepts a 4-character RCSB id, an
# existing file path, any other string (fetched as an AlphaFold DB
# "AF-<code>-F1-model_v4" entry), or blank to upload a file.
target_pdb = "6W70" #@param {type:"string"}
target_chain = "A" #@param {type:"string"}
#@markdown - Please indicate target pdb and chain (leave pdb blank for custom upload)
pdb_filename = get_pdb(target_pdb)
# number of highest-scoring residues listed in the report below
top_n = 15
# selects which parameter file is loaded later in this cell: "<model_type>.pkl"
model_type = 'ligand_model' #@param ["ligand_model", "peptide_model"]
import jax, pickle
def af2bind(inputs,outputs,params,aux):
  """Loss callback: score each target residue with the AF2BIND MLP head.

  Takes the pair representation block ``[:-20, -20:]`` (all rows except the
  last 20, against the last 20 columns -- presumably target residues vs. the
  20-residue binder configured with ``binder_len=20``; TODO confirm), flattens
  it per row, standardises it with the stored mean/std and runs it through a
  5-layer MLP (ReLU after every layer except the last).

  Side effect:
    ``aux["af2bind"]`` receives the sigmoid of the per-row output.

  Returns:
    ``{"af2bind": value}`` where ``value`` is the raw (pre-sigmoid) output at
    index ``inputs["opt"]["af2bind_site"]``.
  """
  head = params["af2bind"]
  scaler, mlp = head["scale"], head["mlp"]

  pair_block = outputs["representations"]["pair"][:-20,-20:]
  features = pair_block.reshape(pair_block.shape[0],-1)
  hidden = (features - scaler["mean"])/scaler["std"]

  # four hidden layers with ReLU ...
  for layer in range(4):
    hidden = jax.nn.relu(hidden @ mlp["weights"][layer] + mlp["bias"][layer])
  # ... followed by a linear output layer; only its first column is used
  logits = (hidden @ mlp["weights"][4] + mlp["bias"][4])[:,0]

  aux["af2bind"] = jax.nn.sigmoid(logits)

  # TODO (figure out if we want to do sigmoid here)
  return {"af2bind": logits[inputs["opt"]["af2bind_site"]]}

# Build the model only once per kernel session: re-running this cell reuses
# the existing `af_model` and just re-prepares the inputs below.
if "af_model" not in dir():
  af_model = mk_afdesign_model(protocol="binder", debug=True, loss_callback=af2bind)
  af_model.opt["weights"]["af2bind"] = 1.0
  af_model.opt["af2bind_site"] = 0
# binder_len=20 must match the hard-coded "20" slices in af2bind().
af_model.prep_inputs(pdb_filename=pdb_filename, chain=target_chain, binder_len=20)
# The 20-residue binder is one copy of each amino acid, in alphabetical
# one-letter order.
af_model.set_seq("ACDEFGHIKLMNPQRSTVWY")
# NOTE(review): pickle.load executes arbitrary code from the file; this is
# only safe as long as the "<model_type>.pkl" file comes from a trusted source.
with open(f"{model_type}.pkl",'rb') as handle:
  af_model._params["af2bind"] = pickle.load(handle)

print("target_length",af_model._target_len)
# presumably weights=0 resets every loss weight so that only the af2bind term
# enabled on the next line contributes -- TODO confirm against ColabDesign
af_model.set_opt(weights=0, af2bind_site=0)
af_model.set_weights(af2bind=1.0)
af_model.predict(verbose=False)

# Copy the values that af2bind() stored in aux, so later edits to
# af_model.aux do not alter `preds`.
preds = af_model.aux["af2bind"].copy()

# Indices of the highest-scoring residues, best first. The slice yields fewer
# than `top_n` entries when the target is shorter than `top_n` residues.
top_n_idx = preds.argsort()[::-1][:top_n]

# Report the number of residues actually listed (equals top_n unless the
# target is shorter than top_n).
print("\n 🧪 Top",len(top_n_idx), "binding residues sorted by confidence: ")
residues_dict = {}   # "<chain>_<residue number>" -> predicted score
resi_terms = []      # one " resi <number>" term per reported residue
for i in top_n_idx:
  p = preds[i]
  c = af_model._pdb["idx"]["chain"][i]
  r = af_model._pdb["idx"]["residue"][i]
  residues_dict[f"{c}_{r}"] = p
  resi_terms.append(" resi " + str(r))
  print((c,r),p)

# Joining the terms (instead of appending " + " while n < top_n-1) avoids a
# dangling " + " at the end of the command when fewer than top_n residues
# are available; the output is unchanged otherwise.
pymol_cmd = "select ch"+str(target_chain)+", " + " + ".join(resi_terms)

print("\n🧪Pymol Selection Cmd:")
print(pymol_cmd)

In [None]:
import py3Dmol
import matplotlib.pyplot as plt
from colabdesign.shared.protein import pdb_to_string

#@title **Color the structure by confidence**
#partly inspired by OpenFold - https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb#scrollTo=rowN0bVYLe9n
#color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
rescale_by_max_conf = True #@param {type:"boolean"}
show_ligand = False #@param {type:"boolean"}

# Optionally rescale the scores so that the best residue has value 1.
preds_adj = preds.copy()
if rescale_by_max_conf:
  preds_adj = preds_adj / preds.max()

# Overwrite the stored pLDDT of every position except the last 20 with the
# binding scores, then write the structure out so the scores are saved with it.
af_model.aux["all"]["plddt"][:,:-20] = preds_adj
af_model.save_current_pdb("color_by_conf.pdb")

# Only chain "A" of the saved file is shown unless show_ligand is ticked.
shown_chains = None if show_ligand else ["A"]
cartoon_style = {
  'cartoon': {
    'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}
  }
}

view = py3Dmol.view(width=800, height=400)
view.addModel(pdb_to_string("color_by_conf.pdb",chains=shown_chains))
view.setStyle(cartoon_style)
view.zoomTo()
view.show()

def plot_plddt_legend(dpi=100):
  """Draw a stand-alone colour legend for the confidence colouring.

  Args:
    dpi: resolution of the (tiny) legend figure.

  Returns:
    The ``matplotlib.pyplot`` module, so the caller can chain ``.show()``.
  """
  labels = ['confidence:','<50','60','70','80','>90']
  swatches = ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]
  legend_style = dict(
    frameon=False,
    loc='center',
    ncol=6,
    handletextpad=1,
    columnspacing=1,
    markerscale=0.5,
  )

  plt.figure(figsize=(1,0.1),dpi=dpi)
  # Zero-height bars draw nothing visible; they only create one legend
  # handle per colour (the white one pairs with the "confidence:" caption).
  for swatch in swatches:
    plt.bar(0, 0, color=swatch)
  plt.legend(labels, **legend_style)
  plt.axis(False)
  return plt

plot_plddt_legend().show()

In [None]:
#@title **Run Saliency**
import matplotlib.pyplot as plt
from scipy.special import softmax
# Colab form inputs ("#@param" binds each assignment to a widget).
# `top` is the rank (0 = highest score) of the residue to analyse.
top = 0 #@param ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12","13","14"] {type:"raw"}
# hard / soft / alpha are forwarded unchanged to af_model.set_opt() below;
# their exact effect is defined by ColabDesign -- TODO document once confirmed.
hard = True #@param {type:"boolean"}
soft = False #@param {type:"boolean"}
alpha = 2.0 #@param {type:"raw"}
show_distogram = False #@param {type:"boolean"}
# Index, chain id and residue number of the residue ranked `top` by score.
i = preds.argsort()[::-1][top]
c = af_model._pdb["idx"]["chain"][i]
r = af_model._pdb["idx"]["residue"][i]

# Point the af2bind loss at the selected residue (af2bind() returns the
# output at index opt["af2bind_site"]) and run the model.
af_model.set_opt(af2bind_site=i,
                 dropout=False,
                 soft=soft,
                 hard=hard,
                 alpha=alpha)
af_model.run(models=[0])
# Amino-acid orderings:
#   blosum_map    - display order used on both plot axes
#   cs_label_list - order of the binder sequence passed to set_seq()
#   af_label_list - presumably AlphaFold's residue-type order, i.e. the
#                   second axis of the gradient (TODO confirm)
blosum_map = list("CSTAGPDEQNHRKMILVWYF")
cs_label_list = list("ACDEFGHIKLMNPQRSTVWY")
af_label_list = list("ARNDCQEGHILKMFPSTWYV")

# For every letter in display order, its position in each source ordering.
indices_A_Y_mapping = np.array(list(map(cs_label_list.index, blosum_map)))
indices_A_R_mapping = np.array(list(map(af_label_list.index, blosum_map)))

# Gradient w.r.t. the sequence, with both axes reordered into display order
# (first axis via the binder order, second via the AlphaFold order).
saliency_map = af_model.aux["grad"]["seq"][0]
saliency_map = saliency_map[indices_A_Y_mapping,:]
saliency_map = saliency_map[:,indices_A_R_mapping]

# Symmetric colour limits keep zero at the centre of the diverging colormap.
max_val = np.abs(saliency_map).max()
tick_positions = np.arange(20)

plt.title(f"chain={c}, residue={r}")
plt.imshow(saliency_map.T, cmap="bwr_r", vmin=-max_val, vmax=max_val)
plt.xticks(tick_positions,blosum_map)
plt.yticks(tick_positions,blosum_map)
plt.xlabel("inputs")
plt.ylabel("gradient of aminoacids")
plt.colorbar()
plt.show()

# Optional: distance distribution between the selected residue and each of
# the last 20 positions (labelled with the binder sequence letters).
if show_distogram:
  distogram = af_model.aux["debug"]["outputs"]["distogram"]
  # softmax over the bin axis, then drop the final bin
  bin_probs = softmax(distogram["logits"][i,-20:],-1)[...,:-1]
  # prepend 0 to the bin edges and drop the last value so that there is one
  # tick value per plotted bin
  bin_starts = np.append(0,distogram["bin_edges"])[:-1]

  plt.figure(figsize=(18,5))
  plt.imshow(bin_probs)
  plt.yticks(np.arange(20),cs_label_list)
  plt.xticks(np.arange(63)[::5],np.round(bin_starts[::5],2))
  plt.xlabel("distances (angstroms)")
  plt.colorbar()
  plt.show()