### AF2BIND: Prediction of ligand-binding sites using AlphaFold2

AF2BIND is a simple, fast notebook that predicts ligand-binding residues from the pair representation computed by [AlphaFold2](https://github.com/deepmind/alphafold).

<!--<img src="https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo.png" width="300">.-->

<figure>
<img src='https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo.png'  width="300" height="150"/>
</figure>



For more details, see the preprint:

**AF2BIND: Predicting ligand-binding sites using the pair representation of AlphaFold2**
* Artem Gazizov, Anna Lian, Casper Alexander Goverde, Sergey Ovchinnikov, Nicholas F. Polizzi
* https://doi.org/10.1101/2023.10.15.562410


In [None]:
%%time
#@title Install AlphaFold2 (~2 mins)
#@markdown Please execute this cell by pressing the *Play* button on
#@markdown the left.
import os, time
# Install only once per runtime; the existence of params/ is the "already done" marker.
# NOTE(review): params/ is created at the very start of the background job below,
# so a run interrupted mid-download is skipped on re-execution. Checking for
# params/done.txt instead would be more robust.
if not os.path.isdir("params"):
  # get code
  print("installing ColabDesign")
  # Background (&) subshell: download and unpack the AlphaFold2 weights into params/,
  # fetch and unzip the AF2BIND weights into af2bind_params/, then touch params/done.txt.
  # NOTE(review): the steps are ';'-separated, so done.txt is written even if a
  # download fails; such a failure only surfaces later as a missing file.
  os.system("(mkdir params; apt-get install aria2 -qq; \
  aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \
  mkdir af2bind_params; \
  wget -qnc https://github.com/sokrypton/af2bind/raw/main/attempt_7_2k_lam0-03.zip; unzip attempt_7_2k_lam0-03.zip -d af2bind_params; \
  tar -xf alphafold_params_2021-07-14.tar -C params; touch params/done.txt )&")

  # While the download runs, install ColabDesign (pinned to v1.1.1) in the foreground
  # and symlink its installed package directory into the working directory.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")

  # download params
  # Block until the background job has written params/done.txt, polling every 5 s.
  if not os.path.isfile("params/done.txt"):
    print("downloading params")
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)

import os
from colabdesign import mk_afdesign_model, clear_mem
from IPython.display import HTML
from google.colab import files
import numpy as np

from colabdesign.af.alphafold.common import residue_constants
import pandas as pd
from google.colab import data_table
# Render Python float values in Colab data tables with 3 decimal places.
# NOTE(review): _DEFAULT_FORMATTERS is a private (underscore-prefixed) attribute of
# google.colab.data_table and may change between Colab releases.
data_table._DEFAULT_FORMATTERS[float] = lambda x: f"{x:.3f}"
from IPython.display import display, HTML
import jax, pickle
import jax.numpy as jnp
from scipy.special import expit as sigmoid
import plotly.express as px

import py3Dmol
import matplotlib.pyplot as plt
from scipy.special import softmax
import copy
from colabdesign.shared.protein import renum_pdb_str
from colabdesign.af.alphafold.common import protein


# Inverse of residue_constants.restype_order: integer aatype index -> one-letter code.
aa_order = dict(zip(residue_constants.restype_order.values(),
                    residue_constants.restype_order.keys()))

def get_pdb(pdb_code=""):
  """Return a local path to the target structure, obtaining it if necessary.

  Resolution order:
    * empty string or None -> prompt for a Colab upload, save it as "tmp.pdb";
    * an existing file path -> returned unchanged;
    * a 4-character code -> "<code>.pdb" fetched from RCSB;
    * anything else -> "AF-<code>-F1-model_v4.pdb" fetched from AlphaFold DB
      (presumably a UniProt accession).

  Args:
    pdb_code: PDB code, AlphaFold DB identifier, local path, or empty for upload.

  Returns:
    Path (str) of the local PDB file.

  Raises:
    ValueError: the upload dialog returned no file.
    FileNotFoundError: the download did not produce the expected file.
  """
  import subprocess  # local import keeps this cell self-contained

  if not pdb_code:
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("No file was uploaded.")
    # Only the first uploaded file is used.
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    filename = f"{pdb_code}.pdb"
    url = f"https://files.rcsb.org/view/{filename}"
  else:
    filename = f"AF-{pdb_code}-F1-model_v4.pdb"
    url = f"https://alphafold.ebi.ac.uk/files/{filename}"

  # Argument list rather than a shell string, so user-typed codes cannot inject
  # shell syntax. -nc (no-clobber) keeps an already-downloaded file.
  subprocess.run(["wget", "-qnc", url], check=False)
  if not os.path.isfile(filename):
    raise FileNotFoundError(f"Could not download {filename} from {url}")
  return filename

def af2bind(outputs, mask_sidechains=True, seed=0):
  """Score every target residue for ligand binding from an AF2 pair representation.

  The last 20 positions of the pair representation are treated as the binder
  and every preceding position as a target residue. For each target residue,
  the target->binder and binder->target pair features are concatenated,
  standardized with the stored mean/std, and passed through the stored linear
  head.

  Args:
    outputs: AF2 outputs dict. Only outputs["representations"]["pair"] is read;
      it is sliced as a square [N, N, C] array (assumed from the indexing below).
    mask_sidechains: selects the "nosc" parameter set (True; presumably trained
      without target side chains) or the standard one (False).
    seed: index of the trained parameter file to load.

  Returns:
    dict with
      "p_bind": sigmoid of the total logit, shape [N-20] (one per target residue).
      "p_bind_aa": logit contributions split over the 20 binder positions,
        shape [N-20, 20].
  """
  pair = outputs["representations"]["pair"]
  # Target rows x binder columns, and binder rows x target columns transposed
  # so that axis 0 indexes target residues in both.
  pair_A = pair[:-20, -20:]
  pair_B = pair[-20:, :-20].swapaxes(0, 1)
  pair_A = pair_A.reshape(pair_A.shape[0], -1)
  pair_B = pair_B.reshape(pair_B.shape[0], -1)
  x = np.concatenate([pair_A, pair_B], -1)

  # get params
  if mask_sidechains:
    model_type = f"split_nosc_pair_A_split_nosc_pair_B_{seed}"
  else:
    model_type = f"split_pair_A_split_pair_B_{seed}"
  # NOTE: pickle can execute arbitrary code; only load parameter files from a trusted source.
  with open(f"af2bind_params/attempt_7_2k_lam0-03/{model_type}.pickle", "rb") as handle:
    params_ = pickle.load(handle)
  params_ = dict(**params_["~"], **params_["linear"])
  # jax.tree_map is deprecated/removed in recent JAX; jax.tree_util.tree_map is the stable API.
  p = jax.tree_util.tree_map(np.asarray, params_)

  # get predictions
  x = (x - p["mean"]) / p["std"]
  # Apply the linear weights feature-wise and spread the bias evenly over the
  # features, so the per-feature terms sum to the full logit.
  x = (x * p["w"][:, 0]) + (p["b"] / x.shape[-1])
  # [target, 2 (A/B halves), 20 binder positions, channels] -> per-binder-position logits.
  p_bind_aa = x.reshape(x.shape[0], 2, 20, -1).sum((1, 3))
  p_bind = sigmoid(p_bind_aa.sum(-1))
  return {"p_bind": p_bind, "p_bind_aa": p_bind_aa}

In [None]:
#@title **Run AF2BIND** 🔬
target_pdb = "6w70" #@param {type:"string"}
target_chain = "A" #@param {type:"string"}
#@markdown - Please indicate target pdb (or uniprot ID to download from AlphaFoldDB) and chain.
#@markdown - Leave pdb blank for custom upload prompt.
# Fixed settings: forwarded to prep_inputs as rm_target_sc / rm_target_seq.
# mask_sidechains also selects the matching AF2BIND parameter set.
mask_sidechains = True
mask_sequence = False

# Strip spaces from the form inputs; default to chain A when left blank.
target_pdb = target_pdb.replace(" ","")
target_chain = target_chain.replace(" ","")
if target_chain == "":
  target_chain = "A"

# Local path to the target (upload, existing file, RCSB or AlphaFold DB download).
pdb_filename = get_pdb(target_pdb)

# Fresh ColabDesign binder-protocol model: target chain(s) plus a 20-residue binder.
# debug=True is presumably what makes af_model.aux["debug"]["outputs"] available below.
clear_mem()
af_model = mk_afdesign_model(protocol="binder", debug=True)
af_model.prep_inputs(pdb_filename=pdb_filename,
                     chain=target_chain,
                     binder_len=20,
                     rm_target_sc=mask_sidechains,
                     rm_target_seq=mask_sequence)

# split
# Renumber the 20 binder residues so that consecutive ones are 50 apart, starting
# 50 after the first binder residue's original index. Presumably this makes AF2
# treat each binder residue as a separate, unconnected fragment.
r_idx = af_model._inputs["residue_index"][-20] + (1 + np.arange(20)) * 50
af_model._inputs["residue_index"][-20:] = r_idx.flatten()

# Binder sequence: one copy of each of the 20 amino acids, alphabetical one-letter order.
af_model.set_seq("ACDEFGHIKLMNPQRSTVWY")
af_model.predict(verbose=False)

# Score the target residues from the pair representation. Keep copies of
# p_bind (per target residue) and p_bind_aa (per target residue x binder
# position) for the later cells.
o = af2bind(af_model.aux["debug"]["outputs"],
            mask_sidechains=mask_sidechains)
pred_bind = o["p_bind"].copy()
pred_bind_aa = o["p_bind_aa"].copy()

#######################################################
# Per-residue results: one row per target residue, saved to results.csv.
labels = ["chain","resi","resn","p(bind)"]
pdb_idx = af_model._pdb["idx"]
target_aatype = af_model._pdb["batch"]["aatype"]
data = [[pdb_idx["chain"][k],
         pdb_idx["residue"][k],
         aa_order.get(target_aatype[k], "X"),
         pred_bind[k]]
        for k in range(af_model._target_len)]

df = pd.DataFrame(data, columns=labels)
df.to_csv('results.csv')

# Interactive table, highest p(bind) first; 'rank' is the 0-based sorted position.
data_table.enable_dataframe_formatter()
df_sorted = (df.sort_values("p(bind)", ascending=False, ignore_index=True)
               .rename_axis('rank')
               .reset_index())
display(data_table.DataTable(df_sorted, min_width=100,
                             num_rows_per_page=15, include_index=False))

# Build a PyMOL selection command for the top_n residues with the highest p(bind).
top_n = 15
# argsort is ascending, so reverse it. Slicing by top_n (not a literal) keeps the
# count consistent, and yields fewer entries when the target is shorter than top_n.
top_n_idx = pred_bind.argsort()[::-1][:top_n]
# Joining the terms means no trailing " +" regardless of how many residues there are.
# NOTE(review): residues are selected by number only, without a chain qualifier,
# so multi-chain targets that reuse residue numbers may over-select.
resi_terms = [f" resi {af_model._pdb['idx']['residue'][i]}" for i in top_n_idx]
pymol_cmd = "select ch" + str(target_chain) + "," + " +".join(resi_terms)

print("\n🧪Pymol Selection Cmd:")
print(pymol_cmd)

In [None]:
#@title **Display Structure** (Colored by Confidence)
# Fixed display options (plain assignments, not form fields).
rescale_by_max_pbind = False
use_native_coordinates = True
show_ligand = False

# Optionally normalise scores so the best residue gets 1.0.
if rescale_by_max_pbind:
  preds_adj = pred_bind.copy() / pred_bind.max()
else:
  preds_adj = pred_bind.copy()

# replace plddt and coordinates of prediction
# Work on a deep copy so the model's own aux dict is left untouched; the first L
# (target) pLDDT entries are overwritten with the (possibly rescaled) p(bind).
L = af_model._target_len
aux = copy.deepcopy(af_model.aux["all"])
aux["plddt"][:,:L] = preds_adj
if show_ligand:
  # Let ColabDesign write the whole complex (target + 20 binder residues).
  af_model.save_pdb("output.pdb",aux={"all":aux})
else:
  # Target only: hide binder atoms and build a protein.Protein by hand so that
  # native residue numbers and chain IDs can be restored.
  aux["atom_mask"][:,L:] = 0
  x = {k:[] for k in ["aatype",
                      "residue_index",
                      "atom_positions",
                      "atom_mask",
                      "b_factors"]}
  asym_id = []
  for i in range(af_model._target_len):
    for k in ["aatype","atom_mask"]: x[k].append(aux[k][0,i])
    if use_native_coordinates:
      x["atom_positions"].append(af_model._pdb["batch"]["all_atom_positions"][i])
    else:
      x["atom_positions"].append(aux["atom_positions"][0,i])
    x["residue_index"].append(af_model._pdb["idx"]["residue"][i])
    # B-factor = p(bind) * 100 on every present atom of this residue
    # (x["atom_mask"][-1] is the mask appended just above).
    x["b_factors"].append(x["atom_mask"][-1] * aux["plddt"][0,i] * 100.0)
    asym_id.append(af_model._pdb["idx"]["chain"][i])
  x = {k:np.array(v) for k,v in x.items()}

  # fix the chains
  # protein.to_pdb is given no chain information here, so rewrite column 22
  # (chain ID) of each ATOM line from asym_id. Advance to the next residue
  # whenever the residue number changes; only ATOM lines are kept.
  # NOTE(review): two consecutive residues sharing a number (insertion codes, or a
  # chain ending and the next starting at the same number) do not advance n, so
  # later chain labels can lag by one residue.
  (n,resnum_) = (0,None)
  pdb_lines = []
  for line in protein.to_pdb(protein.Protein(**x)).splitlines():
    if line[:4] == "ATOM":
      resnum = int(line[22:22+5])
      if resnum_ is None: resnum_ = resnum
      if resnum != resnum_:
        n += 1
        resnum_ = resnum
      pdb_lines.append("%s%s%4i%s" % (line[:21],asym_id[n],resnum,line[26:]))
  with open("output.pdb","w") as handle:
    handle.write("\n".join(pdb_lines))

# 3Dmol.js viewer: cartoon coloured by the B-factor column on a red-white-blue
# gradient over 0..100. The hover label reports b/100.
hbondCutoff = 4.0
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',width=800,height=600)
# Context manager closes the file handle promptly instead of leaving it to the GC.
with open("output.pdb",'r') as handle:
  pdb_str = handle.read()
view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
color_scheme = {'prop':'b','gradient': 'rwb','min':0,'max':100}
view.setStyle({'cartoon': {'colorscheme': color_scheme}})

# add sidechains for target residues with p(bind) > 0.5:
# sticks for ordinary side chains (backbone C/O/N hidden), a CA sphere for
# glycine, and sticks for proline excluding C/O (so the ring's N is shown).
for i in range(af_model._target_len):
  c = af_model._pdb["idx"]["chain"][i]
  r = int(af_model._pdb["idx"]["residue"][i])
  p = pred_bind[i]
  if p > 0.5:
    view.addStyle({'and':[{'chain':c},{'resi':r},{'resn':["GLY","PRO"],'invert':True},{'atom':['C','O','N'],'invert':True}]},
                  {'stick':{'colorscheme':color_scheme,'radius':0.3}})
    view.addStyle({'and':[{'chain':c},{'resi':r},{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':color_scheme,'radius':0.3}})
    view.addStyle({'and':[{'chain':c},{'resi':r},{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':color_scheme,'radius':0.3}})

# Hover shows "chain/resi/resn score", where score is the B-factor divided by 100.
view.setHoverable({}, True,
               '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(atom.chain+"/"+atom.resi+"/"+atom.resn+" "+(atom.b/100.0).toFixed(3),{position:atom,backgroundColor:'white',backgroundOpacity:0.75,borderColor:'black',borderThickness:2.0,fontColor:'black'});}}''',
               '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')

view.zoomTo()
view.show()

def plot_plddt_legend(dpi=100):
  """Draw a one-row colour legend for p(bind) and return the pyplot module.

  The first entry is a text label with a white swatch; the rest pair the
  thresholds 0.00-1.00 with red -> white -> blue swatches.
  """
  legend_labels = ['p(bind):', '0.00', '0.25', '0.50', '0.75', '1.00']
  swatches = ["white", "#FF0000", "#FF8080", "#FFFFFF", "#8080FF", "#0000FF"]
  plt.figure(figsize=(1, 0.1), dpi=dpi)
  # Zero-height bars exist only to supply coloured legend handles.
  for swatch in swatches:
    plt.bar(0, 0, color=swatch)
  plt.legend(legend_labels, frameon=False, loc='center', ncol=len(legend_labels),
             handletextpad=1, columnspacing=1, markerscale=0.5)
  plt.axis(False)
  return plt

plot_plddt_legend().show()

In [None]:
#@title **Download Predictions**
from google.colab import files
import zipfile
# Bundle the structure and the per-residue table with zipfile. This avoids
# depending on the `zip` binary, and mode "w" rebuilds the archive rather than
# updating (and keeping stale entries in) an existing one. As with `zip`,
# missing inputs are skipped with a warning instead of aborting.
with zipfile.ZipFile("output.zip", "w", compression=zipfile.ZIP_DEFLATED) as archive:
  for name in ("output.pdb", "results.csv"):
    if os.path.isfile(name):
      archive.write(name)
    else:
      print(f"warning: {name} not found; run the earlier cells first")
files.download('output.zip')

In [None]:
#@title **Activation analysis** (optional)
pbind_cutoff = 0.5 # @param ["0.0", "0.5", "0.9"] {type:"raw"}
# Heatmap row order (amino-acid grouping; the name suggests a BLOSUM-derived order).
blosum_map = list("CSTAGPDEQNHRKMILVWYF")
# Column order of pred_bind_aa: the 20 amino acids in alphabetical one-letter order.
cs_label_list = list("ACDEFGHIKLMNPQRSTVWY")

# Column index in pred_bind_aa for each amino acid of blosum_map.
indices_A_Y_mapping = np.array([cs_label_list.index(aa) for aa in blosum_map])
# Keep only residues above the cutoff, then reorder the amino-acid columns.
filt = pred_bind > pbind_cutoff
pred_bind_aa_blosum = pred_bind_aa[filt][:, indices_A_Y_mapping]
res_labels = np.array(af_model._pdb["idx"]["residue"])[filt]
chain_labels = np.array(af_model._pdb["idx"]["chain"])[filt]
# Axis labels of the form "<chain>_<resi>".
position_labels = [f"{chain}_{resi}" for chain, resi in zip(chain_labels, res_labels)]

fig = px.imshow(pred_bind_aa_blosum.T,
                labels=dict(x="positions", y="amino acids", color="pref"),
                y=blosum_map,
                x=position_labels,
                zmin=-1,
                zmax=1,
                template="simple_white",
                color_continuous_scale=["red", "white", "blue"])
fig.show()