<a href="https://colab.research.google.com/github/xiergo/GRASP-JAX/blob/main/GRASP-JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GRASP-JAX
Integrating Diverse Experimental Information to Assist Protein Complex Structure Prediction


In [1]:
#@title Download GRASP-JAX and GRASP params
%%time
import os
import subprocess
from sys import version_info
PYTHON_VERSION = f"{version_info.major}.{version_info.minor}"

def run_cmd(cmd):
  res = subprocess.run(cmd.split(), capture_output=True, text=True)
  if res.returncode != 0:
    print(res.stdout)
    print(res.stderr)
    raise Exception(f'Command failed {cmd}')

if not os.path.isfile("CONDA_READY"):
  print("installing conda...")
  os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
  os.system("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
  os.system("mamba config --set auto_update_conda false")
  os.system("touch CONDA_READY")
if not os.path.isfile("HH_READY"):
  print("installing hhsuite...")
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
  os.system("touch HH_READY")
if not os.path.isfile("AMBER_READY"):
  print("installing amber...")
  os.system(f"mamba install -y -c conda-forge openmm=8.0.0 python='{PYTHON_VERSION}' pdbfixer")
  os.system("touch AMBER_READY")
if not os.path.isfile("GRASP_READY"):
  print("installing GRASP...")
  run_cmd("pip install --no-warn-conflicts git+https://github.com/xiergo/GRASP-JAX.git")
  from Bio.PDB import PDBParser
  os.system("pip install --no-warn-conflicts --upgrade jax[cuda12_pip]==0.4.30 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("git clone https://github.com/xiergo/GRASP-JAX.git")
  os.system("tar -xzf GRASP-JAX/examples/7NXX_dimer_A_B.tar.gz")
  os.system("touch GRASP_READY")
if not os.path.isfile('PARAMS_READY'):
  print('Downloading GRASP parameters...')
  os.system('wget -O params.zip https://files.osf.io/v1/resources/6kjuq/providers/osfstorage/66f420b1ac56e4bcd59e82e5/?zip=')
  os.system('unzip params.zip')
  os.system('mkdir -p data/params')
  os.system('mv params_model_1_multimer_v3_v11*npz data/params/')
  os.system('rm params.zip')
  os.system('touch PARAMS_READY')

installing GRASP...
installing conda...
installing hhsuite...
installing amber...
Downloading GRASP parameters...
CPU times: user 608 ms, sys: 104 ms, total: 712 ms
Wall time: 3min 30s


In [None]:
#@title Set path and hyperparameters
#@markdown - Upload your fasta file and set the path
fasta_path = '7NXX_dimer_A_B/7NXX_dimer_A_B.fasta' #@param {type:"string"}
#@markdown - Upload your restraint file and set the path. If not provided, inference will be done without restraints.
restraints_file = '7NXX_dimer_A_B/RPR_restr.txt' #@param ["None"] {type:"string",allow-input:true}
#@markdown - Upload your features.pkl file and set the path. Path to the feature dictionary generated using AlphaFold-multimer\'s protocal. If not specified, other arguments used for generating features will be required.
feature_pickle = 'None' #@param ["None"] {type:"string",allow-input:true}
#@markdown - Path to a directory that will store the results.
output_dir = './results' #@param {type:"string"}
#@markdown - How many predictions (each with a different random seed) will be generated per model. E.g. if this is 2 and there are 5 models then there will be 10 predictions per input.
num_multimer_predictions_per_model = 1 #@param {type:"integer"}
#@markdown - Maximum iteration for iterative restraint filtering.
iter_num = 2 #@param {type:"integer"}
#@markdown - Maximum iteration for iterative restraint filtering.
rank_by = "plddt" #@param {type:"string"} ['plddt', 'ptm']
#@markdown - The mode of running GRASP, "normal" or "quick".
mode = "normal" #@param {type:"string"} ['normal', 'quick']
#@markdown - The models to run the final relaxation step on. If `all`, all models are relaxed, which may be time consuming. If `best`, only the most confident model is relaxed. If `none`, relaxation is not run. Turning off relaxation might result in predictions with distracting stereochemical violations but might help in case you are having issues with the relaxation stage.
models_to_relax = "best" #@param {type:"string"} ['best', 'all', 'none']



## Run and display

In [None]:
#@title Search MSA and template using ColabFold pipeline
%%time
if feature_pickle == "None":
  print('Generating features.pkl using ColabFold pipeline...')
  feature_pickle = f'{output_dir}/colab_search/features.pkl'
  if os.path.isfile(feature_pickle):
    print(f'{feature_pickle} already exists. Skipping ColabFold pipeline.')
  else:
    cmd = f'python GRASP-JAX/search_features_colab.py {fasta_path} {output_dir}/colab_search'
    print(cmd)
    import subprocess
    res = subprocess.run(cmd.split(), capture_output=True, text=True)
    print(res.stdout)
    print(res.stderr)
    if res.returncode != 0:
      feature_pickle='None'
      raise Exception('ColabFold pipeline failed')


Generating features.pkl using ColabFold pipeline...
python GRASP-JAX/search_features_colab.py 7NXX_dimer_A_B/7NXX_dimer_A_B.fasta ./results/colab_search
outdir ./results/colab_search
sequence ATKAVCVLKGDGPVQGIINFEQKESNGPVKVWGSIKGLTEGLHGFHVHEFGDNTAGCTSAGPHFNPLSRKHGGPKDEERHVGDLGNVTADKDGVADVSIEDSVISLSGDHCIIGRTLVVHEKADDLGKGGNEESTKTGNAGSRLACGVIGIAQ:MAQVQLQESGGGSVQAGGSLRLACVASGGDTRPYITYWMGWYRQAPGKEREGVATIYTGGSGTYYSDSVEGRFTISQDKAQRTVYLQMNDLKPEDTAMYYCAAGNGALPPGRRLSPQNMDTWGPGTQVTVSSHHHH
length 289
2025-03-03 13:26:00,651 Found 7 citations for tools or databases
2025-03-03 13:26:00,658 Query 1/1: 7NXX_dimer_A_B (length 289)
2025-03-03 13:26:18,436 Sequence 0 found templates: ['2lu5_A', '1n18_H', '6dtk_G', '1fun_D', '1n19_B', '6dtk_C', '1uxl_J', '6dtk_G', '6dtk_A', '5k02_U', '3ecw_B', '2wz0_F', '1ptz_B', '3gtv_K', '6spa_C', '6dtk_A', '2zky_I', '6dtk_C']
2025-03-03 13:26:36,334 Sequence 1 found templates: ['7orb_E', '7ben_G', '6xkp_H', '6xkp_M', '7czu_J', '5ggu_A', '7rks_I', '7rks_H', '5ggv_H', '7

In [None]:
#@title Run inference
%%time
os.system(f'mkdir -p {output_dir}')
log_file = f'{output_dir}/run_grasp.log'
cmd = f'python GRASP-JAX/run_grasp.py --data_dir data --fasta_path {fasta_path} --feature_pickle {feature_pickle} --restraints_file {restraints_file} --output_dir {output_dir} --num_multimer_predictions_per_model {num_multimer_predictions_per_model} --iter_num {iter_num} --rank_by {rank_by} --mode {mode} --models_to_relax {models_to_relax} > {log_file} 2>&1'
os.system(f'echo "{cmd}" > {log_file}')
os.system(cmd)

CPU times: user 254 ms, sys: 35.8 ms, total: 290 ms
Wall time: 3min 30s


15

In [None]:
#@title Show logs
with open(log_file, 'r') as f:
  print(f.read())

I0303 03:48:16.010916 133720207927104 xla_bridge.py:889] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0303 03:48:16.012708 133720207927104 xla_bridge.py:889] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-03-03 03:48:16.057161: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.8.61). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
I0303 03:48:22.892105 133720207927104 run_grasp.py:595] Have 5 models: ['model_1_multimer_v3_v11_8000_pred_0', 'model_1_multimer_v3_v11_14000_pred_0', 'model_1_multimer_v3_v11_20000_pred_0', 'model_1_multimer_v3_v11_22000_pred_0', 'model_

In [None]:
#@title Display the best structure {run: "auto"}

from string import ascii_uppercase,ascii_lowercase
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects
import py3Dmol

pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

alphabet_list = list(ascii_uppercase+ascii_lowercase)


color = "chain" #@param ["chain", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
pdb_file = f'{output_dir}/ranked_0.pdb'
def get_chains(pdb_file):
  with open(pdb_file,'r') as f:
    lines = [line.strip() for line in f]
  chain_list = [line[21] for line in lines if line.startswith("ATOM")]
  return sorted(list(set(chain_list)))
def show_pdb(pdb_file, show_sidechains=False, show_mainchains=False, color="chain"):
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  view.addModel(open(pdb_file,'r').read(),'pdb')

  if color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})

  elif color == "chain":
    chains = get_chains(pdb_file)
    for chain,color in zip(chains,pymol_color_list):
       view.setStyle({'chain':chain},{'cartoon': {'color':color}})

  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view

show_pdb(pdb_file, show_sidechains, show_mainchains, color).show()
