<a href="https://colab.research.google.com/github/xiergo/GRASP-JAX/blob/main/GRASP-JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GRASP-JAX

Integrating Diverse Experimental Information to Assist Protein Complex Structure Prediction


## GRASP Notebook

GRASP is a tool designed to enhance protein complex structure prediction by incorporating various experimental restraints. This notebook offers a simplified, size-limited version of the pipeline. For optimal results, we recommend running [GRASP](https://github.com/xiergo/GRASP-JAX) locally.

The notebook uses the [MMseqs2](https://github.com/soedinglab/MMseqs2.git) server, provided by [ColabFold](https://github.com/sokrypton/ColabFold), to perform fast and efficient homology searches. This step identifies homologous sequences and templates, which are critical for accurate structure prediction.

If you use this notebook, please cite the following papers:

*   Yuhao Xie, Chengwei Zhang, Shimian Li, Xinyu Du, Yanjiao Lu, Min Wang, Yingtong Hu, Zhenyu Chen, Sirui Liu, Yi Qin Gao. "[Integrating diverse experimental information to assist protein complex structure prediction by GRASP](https://www.nature.com/articles/s41592-025-02820-1)" Nature Methods (2025)

*   Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. "[ColabFold: Making protein folding accessible to all.](https://www.nature.com/articles/s41592-022-01488-1)" Nature Methods (2022)


In [3]:
#@title Download GRASP-JAX and GRASP params
%%time
import os
import subprocess
# Python version of the conda environment that mamba installs into /usr/local
# (the hhsuite and openmm installs below pin python to this value). It is
# hard-coded rather than read from sys.version_info, presumably because the
# notebook kernel's own interpreter can differ from the conda one. The display
# cell later appends /usr/local/lib/python{PYTHON_VERSION}/site-packages to
# sys.path, so this value must match the installed environment.
PYTHON_VERSION='3.9'

def run_command(cmd, verbose=False):
  """
  Run a command, capturing its combined stdout/stderr.

  :param cmd: list or str, command. A str is executed through the shell;
    a list is executed directly without a shell.
  :param verbose: bool, whether to print output line by line as it is produced
  :return: str, the combined stdout/stderr text of the command
  :raises subprocess.CalledProcessError: if the command exits with a non-zero
    status; the captured text is attached as ``output``
  """
  print(f"Running: {cmd}")
  # stderr is merged into stdout so messages keep their relative order;
  # line-buffered text mode lets verbose output stream while the command runs.
  output_lines = []
  with subprocess.Popen(cmd, shell=isinstance(cmd, str),
      stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
      text=True, bufsize=1) as process:
    for line in process.stdout:
      if verbose:
        print(line, end="")
      output_lines.append(line)
  # Leaving the `with` block closes the pipe and waits for the process to exit.
  print(f"Return code: {process.returncode}")
  output = "".join(output_lines)
  if process.returncode != 0:
    if not verbose:
      # CalledProcessError's message does not include the output, so show it
      # here; otherwise a failed quiet step gives no hint about what went wrong.
      print(output, end="")
    raise subprocess.CalledProcessError(process.returncode, cmd, output=output)
  return output

# Each step is guarded by a marker file so re-running this cell skips work that
# already finished. A marker is touched only after every command in its step
# succeeded (run_command raises otherwise), so a failed step is retried on re-run;
# the commands below are therefore written to tolerate a partially completed step.
if not os.path.isfile("CONDA_READY"):
  print("installing conda...")
  run_command("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
  run_command("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
  run_command("touch CONDA_READY")
if not os.path.isfile("HH_READY"):
  print("installing hhsuite...")
  run_command(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
  run_command("touch HH_READY")
if not os.path.isfile("AMBER_READY"):
  print("installing amber...")
  run_command(f"mamba install -y -c conda-forge openmm=8.0.0 python='{PYTHON_VERSION}' pdbfixer")
  run_command("touch AMBER_READY")
if not os.path.isfile("GRASP_READY"):
  print("installing GRASP...")
  run_command("pip install --no-cache-dir --no-warn-conflicts git+https://github.com/xiergo/GRASP-JAX.git")
  # Quote the requirement so the shell never treats "[cuda12_pip]" as a glob pattern.
  run_command("pip install --no-warn-conflicts --upgrade 'jax[cuda12_pip]==0.4.30' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  # git refuses to clone into an existing non-empty directory, which would make
  # a retry after a partial run fail here.
  if not os.path.isdir("GRASP-JAX"):
    run_command("git clone https://github.com/xiergo/GRASP-JAX.git")
  run_command("tar -xzf GRASP-JAX/examples/7NXX_dimer_A_B.tar.gz")
  run_command("touch GRASP_READY")
if not os.path.isfile('PARAMS_READY'):
  print('Downloading GRASP parameters...')
  run_command('wget -O params.zip https://files.osf.io/v1/resources/6kjuq/providers/osfstorage/66f420b1ac56e4bcd59e82e5/?zip=')
  # -o overwrites files left by an earlier partial run instead of prompting,
  # since an interactive prompt cannot be answered from this cell.
  run_command('unzip -o params.zip')
  run_command('mkdir -p data/params')
  run_command('mv params_model_1_multimer_v3_v11*npz data/params/')
  run_command('rm params.zip')
  run_command('touch PARAMS_READY')


installing conda...
Running: wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh
Return code: 0
Running: bash Miniforge3-Linux-x86_64.sh -bfp /usr/local
Return code: 0
Running: touch CONDA_READY
Return code: 0
installing hhsuite...
Running: mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='3.9'
Return code: 0
Running: touch HH_READY
Return code: 0
installing amber...
Running: mamba install -y -c conda-forge openmm=8.0.0 python='3.9' pdbfixer
Return code: 0
Running: touch AMBER_READY
Return code: 0
installing GRASP...
Running: pip install --no-cache-dir --no-warn-conflicts git+https://github.com/xiergo/GRASP-JAX.git
Return code: 0
Running: pip install --no-warn-conflicts --upgrade jax[cuda12_pip]==0.4.30 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Return code: 0
Running: git clone https://github.com/xiergo/GRASP-JAX.git
Return code: 0
Running: tar -xzf GRASP-JAX/examples/7NXX_dimer_

In [4]:
#@title Set path and hyperparameters
#@markdown - Upload your fasta file and set the path
fasta_path = '7NXX_dimer_A_B/7NXX_dimer_A_B.fasta' #@param {type:"string"}
#@markdown - Upload your restraint file and set the path. If set to `None`, inference will be done without restraints.
restraints_file = '7NXX_dimer_A_B/RPR_restr.txt' #@param ["None"] {type:"string",allow-input:true}
#@markdown - Upload your features.pkl file and set the path: the feature dictionary generated using AlphaFold-Multimer's protocol. If set to `None`, the MMseqs2 pipeline from ColabFold will be used to generate this file.
feature_pickle = 'None' #@param ["None"] {type:"string",allow-input:true}
#@markdown - Path to a directory that will store the results.
output_dir = './results' #@param {type:"string"}
#@markdown - How many predictions (each with a different random seed) will be generated per model. E.g. if this is 2 and there are 5 models then there will be 10 predictions per input.
num_multimer_predictions_per_model = 1 #@param {type:"integer"}
#@markdown - Maximum number of iterations for iterative restraint filtering.
iter_num = 2 #@param {type:"integer"}
#@markdown - Specifies the metric for ranking models. If set to `plddt` (default), models are ranked by pLDDT. If set to `ptm`, models are ranked using the weighted score 0.2*pTM + 0.8*ipTM. In both cases, models with a recall value not less than 0.3 are prioritized.
rank_by = "plddt" #@param {type:"string"} ['plddt', 'ptm']
#@markdown - The mode of running GRASP: "normal" or "quick".
mode = "normal" #@param {type:"string"} ['normal', 'quick']
#@markdown - The models to run the final relaxation step on. If `all`, all models are relaxed, which may be time-consuming. If `best`, only the most confident model is relaxed. If `none`, relaxation is not run. Turning off relaxation might result in predictions with distracting stereochemical violations, but it may help if you are having issues with the relaxation stage.
models_to_relax = "best" #@param {type:"string"} ['best', 'all', 'none']



## Run and display

In [5]:
#@title Search MSA and template using ColabFold pipeline
%%time
import shlex

# "None" (the form default) means no precomputed features were supplied, so they
# are generated with the ColabFold MMseqs2 pipeline and cached under output_dir.
if feature_pickle == "None":
  print('Generating features.pkl using ColabFold pipeline...')
  search_dir = f'{output_dir}/colab_search'
  feature_pickle = f'{search_dir}/features.pkl'
  if os.path.isfile(feature_pickle):
    print(f'{feature_pickle} already exists. Skipping ColabFold pipeline.')
  else:
    # Quote user-supplied paths: uploaded file names may contain spaces or
    # parentheses (e.g. "seq (1).fasta"), which would otherwise split the command.
    cmd = f'python GRASP-JAX/search_features_colab.py {shlex.quote(fasta_path)} {shlex.quote(search_dir)}'
    run_command(cmd, verbose=True)
    # The inference cell reads feature_pickle from this exact location, so
    # stop here rather than fail later with a less obvious error.
    if not os.path.isfile(feature_pickle):
      raise FileNotFoundError(f'ColabFold pipeline finished but {feature_pickle} was not created.')

Generating features.pkl using ColabFold pipeline...
Running: python GRASP-JAX/search_features_colab.py 7NXX_dimer_A_B/7NXX_dimer_A_B.fasta ./results/colab_search
outdir ./results/colab_search
sequence ATKAVCVLKGDGPVQGIINFEQKESNGPVKVWGSIKGLTEGLHGFHVHEFGDNTAGCTSAGPHFNPLSRKHGGPKDEERHVGDLGNVTADKDGVADVSIEDSVISLSGDHCIIGRTLVVHEKADDLGKGGNEESTKTGNAGSRLACGVIGIAQ:MAQVQLQESGGGSVQAGGSLRLACVASGGDTRPYITYWMGWYRQAPGKEREGVATIYTGGSGTYYSDSVEGRFTISQDKAQRTVYLQMNDLKPEDTAMYYCAAGNGALPPGRRLSPQNMDTWGPGTQVTVSSHHHH
length 289
2025-10-08 03:48:41,330 Found 7 citations for tools or databases
2025-10-08 03:48:41,335 Query 1/1: 7NXX_dimer_A_B (length 289)

  0%|          | 0/300 [elapsed: 00:00 remaining: ?]
SUBMIT:   0%|          | 0/300 [elapsed: 00:00 remaining: ?]
COMPLETE:   0%|          | 0/300 [elapsed: 00:00 remaining: ?]
COMPLETE: 100%|██████████| 300/300 [elapsed: 00:00 remaining: 00:00]
COMPLETE: 100%|██████████| 300/300 [elapsed: 00:01 remaining: 00:00]
2025-10-08 03:49:00,929 Sequence 0 found templates: [

In [6]:
#@title Run inference
%%time
import shlex

os.makedirs(output_dir, exist_ok=True)
# Build the argument list explicitly; shlex.join quotes only the arguments that
# need it, so uploaded paths with spaces or parentheses reach run_grasp.py intact.
grasp_args = [
  'python', 'GRASP-JAX/run_grasp.py',
  '--data_dir', 'data',
  '--fasta_path', fasta_path,
  '--feature_pickle', feature_pickle,
  '--restraints_file', restraints_file,
  '--output_dir', output_dir,
  '--num_multimer_predictions_per_model', num_multimer_predictions_per_model,
  '--iter_num', iter_num,
  '--rank_by', rank_by,
  '--mode', mode,
  '--models_to_relax', models_to_relax,
]
cmd = shlex.join([str(arg) for arg in grasp_args])
run_command(cmd, verbose=True)

Running: mkdir -p ./results
Return code: 0
Running: python GRASP-JAX/run_grasp.py --data_dir data --fasta_path 7NXX_dimer_A_B/7NXX_dimer_A_B.fasta --feature_pickle ./results/colab_search/features.pkl --restraints_file 7NXX_dimer_A_B/RPR_restr.txt --output_dir ./results --num_multimer_predictions_per_model 1 --iter_num 2 --rank_by plddt --mode normal --models_to_relax best
  from pkg_resources import resource_filename
I1008 03:49:40.747050 138161499800640 xla_bridge.py:889] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1008 03:49:40.771734 138161499800640 xla_bridge.py:889] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-10-08 03:49:41.277038: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas

In [7]:
#@title Display the best structure {run: "auto"}
import sys
sys.path.append(f'/usr/local/lib/python{PYTHON_VERSION}/site-packages/')
from string import ascii_uppercase,ascii_lowercase
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects
import py3Dmol

# PyMOL-style colour palette; show_pdb hands these out to chains in sorted
# chain-ID order.
pymol_color_list = [
    "#33ff33", "#00ffff", "#ff33cc", "#ffff00", "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
    "#7fff7f", "#199999", "#ff007f", "#ffdd5e", "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
    "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5", "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
    "#00ff7f", "#337fcc", "#d8337f", "#bfff3f", "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
    "#339933", "#66b2b2", "#ba8c84", "#84bf00", "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

# Single-letter identifiers: A-Z followed by a-z.
alphabet_list = [*ascii_uppercase, *ascii_lowercase]


# Display options (Colab form fields): colour the cartoon per chain or as a
# rainbow spectrum, and optionally draw side-chain / backbone atoms as sticks.
color = "chain" #@param ["chain", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
# Presumably the top-ranked model written by run_grasp.py into output_dir
# (ranked_0) — confirm against run_grasp.py's output naming.
pdb_file = f'{output_dir}/ranked_0.pdb'
def get_chains(pdb_file):
  """Return the sorted, de-duplicated chain IDs used by ATOM records in a PDB file.

  The chain ID is taken as the character at 0-based index 21 (PDB column 22)
  of each whitespace-stripped line starting with "ATOM"; HETATM records are
  not considered.
  """
  found = set()
  with open(pdb_file, 'r') as handle:
    for raw_line in handle:
      record = raw_line.strip()
      if record.startswith("ATOM"):
        found.add(record[21])
  return sorted(found)
def show_pdb(pdb_file, show_sidechains=False, show_mainchains=False, color="chain"):
  """Build a py3Dmol view of a PDB structure.

  :param pdb_file: str, path to the PDB file to display
  :param show_sidechains: bool, draw side-chain atoms as sticks (Gly's CA as a sphere)
  :param show_mainchains: bool, draw backbone atoms N, CA, C, O as sticks
  :param color: str, "rainbow" colours the cartoon with a spectrum; "chain"
    gives each chain its own colour from pymol_color_list. Any other value
    applies no cartoon style.
  :return: the configured py3Dmol view; call .show() on it to render
  """
  with open(pdb_file, 'r') as handle:
    pdb_text = handle.read()
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
  view.addModel(pdb_text, 'pdb')

  if color == "rainbow":
    view.setStyle({'cartoon': {'color': 'spectrum'}})
  elif color == "chain":
    chains = get_chains(pdb_file)
    for i, chain in enumerate(chains):
      # Cycle through the palette so structures with more chains than colours
      # still get a cartoon for every chain (zip() would silently drop them).
      chain_color = pymol_color_list[i % len(pymol_color_list)]
      view.setStyle({'chain': chain}, {'cartoon': {'color': chain_color}})

  if show_sidechains:
    backbone = ['C', 'O', 'N']
    # Residues other than Gly/Pro: every atom except C, O, N (CA is kept).
    view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': backbone, 'invert': True}]},
                  {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    # Gly has no side chain, so mark its CA with a sphere instead.
    view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
                  {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    # Pro's side chain closes a ring onto the backbone N, so N is drawn too.
    view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
                  {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
  if show_mainchains:
    backbone = ['C', 'O', 'N', 'CA']
    view.addStyle({'atom': backbone}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})

  view.zoomTo()
  return view

# Render the selected model with the display options chosen in the form above.
structure_view = show_pdb(pdb_file, show_sidechains=show_sidechains,
                          show_mainchains=show_mainchains, color=color)
structure_view.show()
