# ColabDock
Inverting AlphaFold2 structure prediction model for protein-protein docking with experimental restraints.



In [None]:
#@title Download AlphaFold2 params and install the dependencies
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/JeffSHF/ColabDock.git@dev")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

os.system("git clone -b dev https://github.com/JeffSHF/ColabDock.git")
os.system("cp -r ./ColabDock/protein/4HFF ./")
os.system("rm -r ./ColabDock")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import re
import ml_collections
from ipywidgets import widgets, HBox
from IPython.display import display

from colabdock.utils import prep_path
from colabdock.model import ColabDock

In [None]:
#@title Upload input PDB file
from google.colab import files
import os

# Opens the Colab upload dialog; the result is indexed by uploaded file name.
pdb_file  = files.upload()
if not pdb_file:
  # Without this check an empty/cancelled upload fails below with a bare IndexError.
  raise ValueError("No file was uploaded. Re-run this cell and select the input PDB file.")

# Only the first uploaded file is used as the docking template.
file_name = list(pdb_file.keys())[0]
# Working directory at upload time; combined with file_name to build the template path.
wd        = os.getcwd()

In [None]:
#@title Define restraints
#@markdown Provide the positions of Sr50 and AvrSr50 to use as restraints, written as "Sr50 position,AvrSr50 position". For instance, 711,99.

restraint_1 = "824,95" #@param {type:"string"}
restraint_2 = "904,95" #@param {type:"string"}
restraint_3 = "" #@param {type:"string"}
restraint_4 = "" #@param {type:"string"}
restraint_5 = "" #@param {type:"string"}
restraint_6 = "" #@param {type:"string"}
restraint_7 = "" #@param {type:"string"}
restraint_8 = "" #@param {type:"string"}
restraint_9 = "" #@param {type:"string"}
restraint_10 = None #@param {type:"string"}
restraint_11 = None #@param {type:"string"}
restraint_12 = None #@param {type:"string"}

restraints = [restraint_1, restraint_2, restraint_3, restraint_4, restraint_5,
              restraint_6, restraint_7, restraint_8, restraint_9, restraint_10,
              restraint_11, restraint_12]

# Receptor position 428 is reindexed to 1.
RECEPTOR_OFFSET = 427
# Effector position 1 is reindexed to 530, i.e. the length of NBARC latch + LRR is 529.
EFFECTOR_OFFSET = 529


def reindex_restraint(rest, receptor_offset=RECEPTOR_OFFSET, effector_offset=EFFECTOR_OFFSET):
  """Convert one "receptor,effector" restraint string to template numbering.

  Args:
    rest: text of the form "<receptor position>,<effector position>".
    receptor_offset: value subtracted from the receptor position.
    effector_offset: value added to the effector position; also the number of
      receptor residues that precede the effector in the template numbering.

  Returns:
    [receptor_index, effector_index] in the reindexed numbering.

  Raises:
    ValueError: if the text is not exactly two comma-separated integers, or if
      a position falls outside its own chain after reindexing.
  """
  receptor, effector = map(int, rest.split(","))
  receptor_modified = receptor - receptor_offset
  effector_modified = effector + effector_offset

  # A receptor index must land in 1..effector_offset and an effector index
  # above it; anything else would silently restrain the wrong residue.
  if not 1 <= receptor_modified <= effector_offset:
    raise ValueError(f"receptor position {receptor} is outside the receptor segment")
  if effector_modified <= effector_offset:
    raise ValueError(f"effector position {effector} must be at least 1")
  return [receptor_modified, effector_modified]


restraint_list = []
for rest in restraints:
  # Unused form fields are left as "" or None.
  if rest is None or rest == "":
    continue
  try:
    restraint_list.append(reindex_restraint(rest))
  except (ValueError, AttributeError):
    # AttributeError covers a non-string value that has no .split().
    print(f"{rest} is not correctly defined")

rest_1v1 = restraint_list

In [None]:
#@title Define input parameters
config = {}

# Template structure: the PDB file uploaded earlier, addressed by absolute path.
template = f"{wd}/{file_name}"
chains = 'A,B'

config['chains'] = chains
config['template'] = template
config['native'] = None
config['fixed_chains'] = None
config['rest_1vN'] = None
config['rest_MvN'] = None
config['rest_rep'] = None


def _validated_threshold(name, value, low=2.0, high=22.0):
  """Return `value` as a float after checking it is a number in [low, high].

  Integers typed into the raw form field (e.g. 8) are accepted; booleans,
  strings, None, NaN and out-of-range values are rejected.

  Raises:
    ValueError: if `value` is not a real number within the documented range.
  """
  is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
  if not is_number or not (low <= value <= high):
    raise ValueError(f'Please set {name} according to the descriptive information!')
  return float(value)


#@markdown - Threshold of the restraints, between 2Å and 22Å.
res_thres = 8.0 #@param [8.0] {type:"raw",allow-input:true}

#@markdown - Threshold of the repulsive restraints, between 2Å and 22Å.<br />
#@markdown Repulsive restraints mean that the distance between two residues is above the given threshold.
rep_thres = 12.0 #@param [12.0] {type:"raw",allow-input:true}

# check the inputs
config['res_thres'] = _validated_threshold('res_thres', res_thres)
config['rep_thres'] = _validated_threshold('rep_thres', rep_thres)


#@markdown - path to save the results
save_path = './results3' #@param {type:"string"}
config['save_path'] = save_path

#@markdown - Segment based optimization
#@markdown -- Setting to None is suggested. If out of memory error is encountered in the generation stage, consider setting it to 200.
#@markdown But this may lead to degenerated performance. For more details, please refer to the paper.
crop_len = None #@param ["None", 200] {type:"raw",allow-input:true}
config['crop_len'] = crop_len

#@markdown - Rounds
#@markdown -- More rounds can achieve better performance but take longer.
rounds = 1 #@param [1,5,10] {type:"raw",allow-input:true}
config['rounds'] = rounds

#@markdown - Steps
#@markdown -- The number of backpropagations in each round.
#@markdown -- If in segment based optimization, set to larger value, for example 150. Otherwise, setting to 50 is enough.
steps = 50 #@param [50, 150] {type:"raw",allow-input:true}
config['steps'] = steps

#@markdown - save_every_n_step
#@markdown -- Save one conformation in every save_every_n_step step.
#@markdown Useful in segment based optimization, since the number of steps is larger
#@markdown and saving conformations in every step will take too much time.
#@markdown If in segment based optimization, set to larger value, for example 3. Otherwise, setting to 1 is OK.
save_every_n_step = 1 #@param [1, 3] {type:"raw",allow-input:true}
config['save_every_n_step'] = save_every_n_step

#@markdown - bfloat
#@markdown -- Use AF2 in bfloat mode. Turning this on can save GPU memory and time.
bfloat = True #@param ["True", "False"] {type:"raw"}
config['bfloat'] = bfloat

# Directory holding the downloaded AlphaFold2 parameters.
config['data_dir'] = './params'

# 1v1 restraints as reindexed in the "Define restraints" cell.
config['rest_1v1'] = restraint_list

In [None]:
#@title Advanced settings
#@markdown - The weights of each chain in the complex. Run this cell and set using the
#@markdown displayed sliders.
#@markdown -- If you allow the structures of certain chains in the final docking structure
#@markdown different from those in the input template, to better satisfy the given restraints,
#@markdown you can set this parameter.
#@markdown -- Each chain has a value between 0 and 1. With this value increasing,
#@markdown the structure of the chain in the generation stage is getting similar
#@markdown to that in the input template.
#@markdown -- Normally, if your input template is accurate, leave it as the default value.

# Chain identifiers, taken from the comma-separated `chains` string.
chains_lst = [name.strip() for name in chains.split(",")]

# Settings shared by every slider; only the label differs per chain.
slider_options = dict(
  value=1.00,
  min=0.00,
  max=1.00,
  step=0.01,
  disabled=False,
  continuous_update=False,
  orientation='horizontal',
  readout=True,
  readout_format='.2f',
)

# One slider per chain, in the same order as chains_lst.
slider_lst = []
for ichain in chains_lst:
  islider = widgets.FloatSlider(description=f'Chain {ichain}', **slider_options)
  slider_lst.append(islider)

# Lay the sliders out side by side and render them in the cell output.
ui = widgets.HBox(slider_lst)
display(ui)

HBox(children=(FloatSlider(value=1.0, continuous_update=False, description='Chain A', max=1.0, step=0.01), Flo…

In [None]:
#@title Run Docking
# Collect the chain weights from the sliders. When every slider is at 1.0 no
# weights are passed on (chain_weights stays None).
if all(islider.value == 1.0 for islider in slider_lst):
  chain_weights = None
else:
  chain_weights = {}
  for ith in range(len(chains_lst)):
    chain_weights[chains_lst[ith]] = slider_lst[ith].value
config['chain_weights'] = chain_weights

config_ml = ml_collections.ConfigDict(config)
save_path = config_ml.save_path
prep_path(save_path)
######################################################################################
# template and native structure
######################################################################################
template_r = config_ml.template
native_r = config_ml.native
chains = config_ml.chains
template = {'pdb_path': template_r,
            'chains': chains}
# NOTE(review): config['native'] is None in this notebook, so 'pdb_path' is None
# here; presumably ColabDock accepts that — confirm against the library.
native = {'pdb_path': native_r,
          'chains': chains}
fixed_chains = config_ml.fixed_chains

######################################################################################
# experimental restraints
######################################################################################
def _none_if_empty(rest):
  """Return None for a missing or empty restraint collection, else `rest` unchanged.

  An empty list (e.g. every restraint field left blank) would otherwise be
  reported as "provided" and then fail on `rest[0]` with an IndexError.
  """
  if rest is None or len(rest) == 0:
    return None
  return rest


rest_MvN_r = _none_if_empty(config_ml.rest_MvN)
rest_non_r = _none_if_empty(config_ml.rest_rep)
rest_1vN_r = _none_if_empty(config_ml.rest_1vN)
rest_1v1_r = _none_if_empty(config_ml.rest_1v1)
# print restraints
print_str = 'restraints:\n'
if rest_1v1_r is None:
  print_str += '\tno 1v1 restraints provided.\n'
else:
  print_str += f'\t1v1 restraints:\n\t\t{rest_1v1_r}\n'

if rest_1vN_r is None:
  print_str += '\tno 1vN restraints provided.\n'
else:
  print_str += f'\t1vN restraints:\n\t\t{rest_1vN_r}\n'

if rest_MvN_r is None:
  print_str += '\tno MvN restraints provided.\n'
else:
  print_str += f'\tMvN restraints:\n\t\t{rest_MvN_r}\n'

if rest_non_r is None:
  print_str += '\tno repulsive restraints provided.\n'
else:
  print_str += f'\trepulsive restraints:\n\t\t{rest_non_r}\n'

# In each section below a single restraint is first wrapped into a list of
# restraints, then every residue index is shifted by -1 (presumably from the
# 1-based numbering used in the form to 0-based indices — confirm with ColabDock).

# 1v1
if rest_1v1_r is not None:
  if type(rest_1v1_r[0]) is not list:
    rest_1v1_r = [rest_1v1_r]
  rest_1v1 = np.array(rest_1v1_r) - 1
else:
  rest_1v1 = None

# 1vN
if rest_1vN_r is not None:
  if type(rest_1vN_r[0]) is not list:
    rest_1vN_r = [rest_1vN_r]
  rest_1vN = []
  for irest_1vN in rest_1vN_r:
    rest_1vN.append([irest_1vN[0] - 1, np.array(irest_1vN[1]) - 1])
else:
  rest_1vN = None

# MvN
if rest_MvN_r is not None:
  if type(rest_MvN_r[-1]) is not list:
    rest_MvN_r = [rest_MvN_r]
  rest_MvN = []
  for irest_MvN in rest_MvN_r:
    irest = []
    for irest_1vN in irest_MvN[:-1]:
      irest.append([irest_1vN[0] - 1, np.array(irest_1vN[1]) - 1])
    # The last element of each MvN restraint is passed through unchanged.
    irest.append(irest_MvN[-1])
    rest_MvN.append(irest)
else:
  rest_MvN = None

# repulsive
if rest_non_r is not None:
  if type(rest_non_r[0]) is not list:
    rest_non_r = [rest_non_r]
  rest_non = np.array(rest_non_r) - 1
else:
  rest_non = None

restraints = {'1v1': rest_1v1,
              '1vN': rest_1vN,
              'MvN': rest_MvN,
              'non': rest_non}

res_thres = config_ml.res_thres
non_thres = config_ml.rep_thres

######################################################################################
# optimization parameters
######################################################################################
rounds = config_ml.rounds
crop_len = config_ml.crop_len
step_num = config_ml.steps
save_every_n_step = config_ml.save_every_n_step
data_dir = config_ml.data_dir
bfloat = config_ml.bfloat

######################################################################################
# chain weights
######################################################################################
chain_weights = config_ml.chain_weights

######################################################################################
# print setting
######################################################################################
print_str += '\nOptimization losses include:\n\t'
if rest_1v1 is not None:
    print_str += '1v1 restraint loss, '
if rest_1vN is not None:
    print_str += '1vN restraint loss, '
if rest_MvN is not None:
    print_str += 'MvN restraint loss, '
if rest_non is not None:
    print_str += 'repulsive restraint loss, '
print_str += 'distogram loss, pLDDT, and ipAE.\n'

if chain_weights:
  print_str += '\nChain weights:\n\t'
  for ik, iv in chain_weights.items():
    print_str += f'{ik}:{iv:.2f}\t'

######################################################################################
# start docking
######################################################################################
dock_model = ColabDock(template,
                       restraints,
                       save_path,
                       data_dir,
                       structure_gt=native,
                       crop_len=crop_len,
                       fixed_chains=fixed_chains,
                       chain_weights=chain_weights,
                       round_num=rounds,
                       step_num=step_num,
                       bfloat=bfloat,
                       res_thres=res_thres,
                       non_thres=non_thres,
                       save_every_n_step=save_every_n_step)
dock_model.setup()
# crop_len is read back from the model after setup() rather than from the config.
if dock_model.crop_len is not None:
    print_str += 'Colabdock will work in segment based mode.'
print(print_str)
print('\nStart optimization')
dock_model.dock_rank()

In [None]:
#@title Display the best structure {run: "auto"}

from string import ascii_uppercase,ascii_lowercase
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects
import py3Dmol

# Hex colours assigned to chains, in order, when colouring by chain.
pymol_color_list = [
    "#33ff33", "#00ffff", "#ff33cc", "#ffff00",
    "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
    "#7fff7f", "#199999", "#ff007f", "#ffdd5e",
    "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
    "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5",
    "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
    "#00ff7f", "#337fcc", "#d8337f", "#bfff3f",
    "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
    "#339933", "#66b2b2", "#ba8c84", "#84bf00",
    "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

# Candidate chain identifiers: 'A'..'Z' followed by 'a'..'z'.
alphabet_list = [*ascii_uppercase, *ascii_lowercase]


# Viewer options exposed through the Colab form.
rank_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "rainbow" #@param ["chain", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}


def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color="chain"):
  """Build a py3Dmol view of a docked structure from the results directory.

  Args:
    rank_num: rank of the docked structure to display (1 = best).
    show_sidechains: if True, draw side-chain atoms as sticks (CA sphere for GLY).
    show_mainchains: if True, draw backbone atoms (C, O, N, CA) as sticks.
    color: "rainbow" for a spectrum cartoon, or "chain" for one colour per chain.

  Returns:
    The configured py3Dmol view; call .show() on it to render.

  Raises:
    FileNotFoundError: if no docked structure exists under
      `<save_path>/docked`.
  """
  import os

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)

  docked_dir = f'{config_ml.save_path}/docked'
  best_file = f'{docked_dir}/1st_best.pdb'
  # Assumes ranked outputs are named '<ordinal>_best.pdb', by analogy with the
  # only name visible in this notebook ('1st_best.pdb') — TODO confirm.
  ordinal = {1: '1st', 2: '2nd', 3: '3rd'}.get(rank_num, f'{rank_num}th')
  pdb_file = f'{docked_dir}/{ordinal}_best.pdb'
  if pdb_file != best_file and not os.path.isfile(pdb_file):
    # Fall back to the top-ranked structure rather than failing outright.
    print(f'{pdb_file} not found; showing {best_file} instead.')
    pdb_file = best_file
  with open(pdb_file, 'r') as handle:
    view.addModel(handle.read(), 'pdb')

  if color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # The chain list is the comma-separated string stored under 'chains'.
    n_chains = len(config_ml.chains.split(','))
    # Chains are assumed to be labelled A, B, ... in the docked PDB — TODO confirm.
    for chain_id, chain_color in zip(alphabet_list[:n_chains], pymol_color_list):
      view.setStyle({'chain':chain_id},{'cartoon': {'color':chain_color}})

  if show_sidechains:
    BB = ['C','O','N']
    # Side chains of every residue except GLY and PRO.
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # GLY has no side chain: mark its CA with a sphere.
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # PRO: keep N in the selection (only C and O are excluded).
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view

# Build the view from the form values set earlier in this cell and render it in the output.
show_pdb(rank_num, show_sidechains, show_mainchains, color).show()


In [None]:
from google.colab import files

!zip -r {save_path}.zip {save_path}
files.download(f'{save_path}.zip')
