<a href="https://colab.research.google.com/github/xiergo/GRASP-JAX/blob/main/GRASP-JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GRASP-JAX

Integrating Diverse Experimental Information to Assist Protein Complex Structure Prediction


## GRASP Notebook

GRASP is a tool designed to enhance protein complex structure prediction by incorporating various experimental restraints. This notebook offers a simplified, size-limited version of the pipeline. For optimal results, we recommend running [GRASP](https://github.com/xiergo/GRASP-JAX) locally.

The notebook uses the [MMseqs2](https://github.com/soedinglab/MMseqs2.git) server provided by [ColabFold](https://github.com/sokrypton/ColabFold) for fast homology searches. This step identifies homologous sequences and structural templates, both of which are critical for accurate structure prediction.

If you use this notebook, please cite the following papers:

*   Yuhao Xie, Chengwei Zhang, Shimian Li, Xinyu Du, Min Wang, Yingtong Hu, Sirui Liu, Yi Qin Gao. "[Integrating various Experimental Information to Assist Protein Complex Structure Prediction by GRASP](https://www.biorxiv.org/content/10.1101/2024.09.16.613256v1)" bioRxiv (2024)

*   Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. "[ColabFold: Making protein folding accessible to all.](https://www.nature.com/articles/s41592-022-01488-1)" Nature Methods (2022)


In [None]:
#@title Install dependencies and download GRASP-JAX and GRASP params
%%time
import os
from sys import version_info
PYTHON_VERSION = f"{version_info.major}.{version_info.minor}"


if not os.path.isfile("CONDA_READY"):
  print("installing conda...")
  os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
  os.system("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
  os.system("mamba config --set auto_update_conda false")
  os.system("touch CONDA_READY")
if not os.path.isfile("HH_READY"):
  print("installing hhsuite...")
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
  os.system("touch HH_READY")
if not os.path.isfile("AMBER_READY"):
  print("installing amber...")
  os.system(f"mamba install -y -c conda-forge openmm=8.0.0 python='{PYTHON_VERSION}' pdbfixer")
  os.system("touch AMBER_READY")
if not os.path.isfile("GRASP_READY"):
  print("installing GRASP...")
  os.system("pip install --no-cache-dir --no-warn-conflicts git+https://github.com/xiergo/GRASP-JAX.git")
  os.system("pip install --no-warn-conflicts --upgrade jax[cuda12_pip]==0.4.30 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("git clone https://github.com/xiergo/GRASP-JAX.git")
  os.system("tar -xzf GRASP-JAX/examples/7NXX_dimer_A_B.tar.gz")
  os.system("touch GRASP_READY")
if not os.path.isfile('PARAMS_READY'):
  print('Downloading GRASP parameters...')
  os.system('wget -O params.zip https://files.osf.io/v1/resources/6kjuq/providers/osfstorage/66f420b1ac56e4bcd59e82e5/?zip=')
  os.system('unzip params.zip')
  os.system('mkdir -p data/params')
  os.system('mv params_model_1_multimer_v3_v11*npz data/params/')
  os.system('rm params.zip')
  os.system('touch PARAMS_READY')

installing conda...
installing hhsuite...
installing amber...
installing GRASP...
Downloading GRASP parameters...
CPU times: user 1.37 s, sys: 224 ms, total: 1.59 s
Wall time: 7min 57s
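The install cell above runs each step through `os.system`, which returns the exit status but does not stop on failure, so a failed install or download is still followed by `touch ..._READY` and is not retried on the next run. A minimal sketch of a sentinel-guarded step that writes the marker only after every command succeeds is given below; the helper name `run_once` is hypothetical and not part of GRASP-JAX.

```python
import subprocess
from pathlib import Path


def run_once(marker, commands):
    """Run shell commands once, guarded by a marker file.

    Returns False if the marker already exists (step skipped) and True after
    all commands succeed. If a command exits non-zero, CalledProcessError is
    raised and the marker is not created, so the step is retried next time.
    """
    marker_path = Path(marker)
    if marker_path.exists():
        return False
    for cmd in commands:
        # check=True raises on a non-zero exit status, unlike os.system,
        # which just returns the status and lets the cell continue.
        subprocess.run(cmd, shell=True, check=True)
    marker_path.touch()
    return True
```

For example, `run_once("HH_READY", [f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'"])` would reproduce the hhsuite step with failure checking.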


In [None]:
#@title Set path and hyperparameters
#@markdown - Upload your FASTA file and set its path.
fasta_path = '7NXX_dimer_A_B/7NXX_dimer_A_B.fasta' #@param {type:"string"}
#@markdown - Upload your restraints file and set its path. If not provided (`None`), inference is run without restraints.
restraints_file = '7NXX_dimer_A_B/RPR_restr.txt' #@param ["None"] {type:"string",allow-input:true}
#@markdown - Upload your features.pkl file and set its path. This is the feature dictionary generated with AlphaFold-Multimer's protocol. If not specified (`None`), the MMseqs2 pipeline from ColabFold is used to generate it.
feature_pickle = 'None' #@param ["None"] {type:"string",allow-input:true}
#@markdown - Path to a directory that will store the results.
output_dir = './results' #@param {type:"string"}
#@markdown - How many predictions (each with a different random seed) will be generated per model. E.g. if this is 2 and there are 5 models then there will be 10 predictions per input.
num_multimer_predictions_per_model = 1 #@param {type:"integer"}
#@markdown - Maximum number of iterations for iterative restraint filtering.
iter_num = 2 #@param {type:"integer"}
#@markdown - The metric for ranking models. If set to `plddt` (default), models are ranked by pLDDT. If set to `ptm`, models are ranked by the weighted score 0.2*pTM + 0.8*ipTM. In both cases, models with a recall of at least 0.3 are prioritized.
rank_by = "plddt" #@param ['plddt', 'ptm'] {type:"string"}
#@markdown - The mode for running GRASP: `normal` or `quick`.
mode = "normal" #@param ['normal', 'quick'] {type:"string"}
#@markdown - The models to run the final relaxation step on. If `all`, all models are relaxed, which may be time-consuming. If `best`, only the most confident model is relaxed. If `none`, relaxation is not run. Turning off relaxation may leave distracting stereochemical violations in the predictions, but can help if the relaxation stage is causing problems.
models_to_relax = "best" #@param ['best', 'all', 'none'] {type:"string"}
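The ranking rule described for `rank_by` can be written out as a sort key. The sketch below illustrates the rule under assumptions and is not GRASP's implementation: the `Prediction` fields are hypothetical, pLDDT is taken as a per-model mean, and `recall` is assumed to be a precomputed per-model value (the notebook does not define it beyond the 0.3 threshold).

```python
from dataclasses import dataclass


@dataclass
class Prediction:
    # Hypothetical container; field names are illustrative, not GRASP's.
    name: str
    plddt: float   # mean pLDDT of the model
    ptm: float
    iptm: float
    recall: float  # assumed: precomputed restraint recall for the model


def confidence(pred, rank_by="plddt"):
    """Confidence score used for ranking, as described for `rank_by`."""
    if rank_by == "plddt":
        return pred.plddt
    if rank_by == "ptm":
        return 0.2 * pred.ptm + 0.8 * pred.iptm
    raise ValueError(f"unknown rank_by: {rank_by!r}")


def rank_predictions(preds, rank_by="plddt", recall_cutoff=0.3):
    """Sort best-first: models with recall >= cutoff come first, then by confidence."""
    # Tuples compare element-wise; True > False, so with reverse=True the
    # models that pass the recall cutoff are placed ahead of those that don't.
    return sorted(
        preds,
        key=lambda p: (p.recall >= recall_cutoff, confidence(p, rank_by)),
        reverse=True,
    )
```

Putting the recall flag first in the key means a high-confidence model that misses the recall cutoff is still ranked below every model that meets it.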



## Run and display

In [None]:
#@title Search MSA and template using ColabFold pipeline
%%time
if feature_pickle == "None":
  print('Generating features.pkl using ColabFold pipeline...')
  feature_pickle = f'{output_dir}/colab_search/features.pkl'
  if os.path.isfile(feature_pickle):
    print(f'{feature_pickle} already exists. Skipping ColabFold pipeline.')
  else:
    cmd = f'python GRASP-JAX/search_features_colab.py {fasta_path} {output_dir}/colab_search'
    print(cmd)
    os.system(cmd)

Generating features.pkl using ColabFold pipeline...
python GRASP-JAX/search_features_colab.py 7NXX_dimer_A_B/7NXX_dimer_A_B.fasta ./results/colab_search
CPU times: user 153 ms, sys: 29 ms, total: 182 ms
Wall time: 51.7 s
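To check what the ColabFold search produced before running inference, the pickle can be loaded and each entry summarized. A minimal sketch, assuming `features.pkl` is a pickled dict whose values are mostly NumPy arrays (as in AlphaFold-style feature dictionaries); `summarize_features` is a hypothetical helper. Only unpickle files you trust, since `pickle.load` can execute arbitrary code.

```python
import pickle

import numpy as np


def summarize_features(path):
    """Return {key: (shape, dtype)} for each entry of a pickled feature dict.

    Non-array values are reported as (None, type name).
    """
    with open(path, "rb") as f:
        features = pickle.load(f)  # only unpickle files you trust
    summary = {}
    for key in sorted(features):
        value = features[key]
        if isinstance(value, np.ndarray):
            summary[key] = (value.shape, str(value.dtype))
        else:
            summary[key] = (None, type(value).__name__)
    return summary
```

For example: `for k, (shape, dtype) in summarize_features(feature_pickle).items(): print(k, shape, dtype)`.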


In [None]:
#@title Run inference
%%time
os.makedirs(output_dir, exist_ok=True)
log_file = f'{output_dir}/run_grasp.log'
# Append (>>) to the log so the command line echoed into it is not overwritten by the run's output.
cmd = f'python GRASP-JAX/run_grasp.py --data_dir data --fasta_path {fasta_path} --feature_pickle {feature_pickle} --restraints_file {restraints_file} --output_dir {output_dir} --num_multimer_predictions_per_model {num_multimer_predictions_per_model} --iter_num {iter_num} --rank_by {rank_by} --mode {mode} --models_to_relax {models_to_relax} >> {log_file} 2>&1'
print(cmd)
os.system(f'echo "{cmd}" > {log_file}')
os.system(cmd)

python GRASP-JAX/run_grasp.py --data_dir data --fasta_path 7NXX_dimer_A_B/7NXX_dimer_A_B.fasta --feature_pickle ./results/colab_search/features.pkl --restraints_file 7NXX_dimer_A_B/RPR_restr.txt --output_dir ./results --num_multimer_predictions_per_model 1 --iter_num 2 --rank_by plddt --mode normal --models_to_relax best >> ./results/run_grasp.log 2>&1
CPU times: user 1.47 s, sys: 212 ms, total: 1.68 s
Wall time: 14min 17s


0
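The inference cell builds a single shell string, so a path containing spaces or shell metacharacters would break it, and the exit status is only visible as the bare `0` above. A sketch of the same invocation with `subprocess.run` and an argument list (no shell) follows; the flags are copied from the cell above, while `build_grasp_cmd` and `run_grasp` are hypothetical helper names.

```python
import shlex
import subprocess


def build_grasp_cmd(fasta_path, feature_pickle, restraints_file, output_dir,
                    num_multimer_predictions_per_model=1, iter_num=2,
                    rank_by="plddt", mode="normal", models_to_relax="best",
                    data_dir="data", script="GRASP-JAX/run_grasp.py"):
    """Return the run_grasp.py invocation as an argument list (no shell involved)."""
    return [
        "python", script,
        "--data_dir", data_dir,
        "--fasta_path", fasta_path,
        "--feature_pickle", feature_pickle,
        "--restraints_file", restraints_file,
        "--output_dir", output_dir,
        # subprocess needs strings, so integer options are converted here.
        "--num_multimer_predictions_per_model", str(num_multimer_predictions_per_model),
        "--iter_num", str(iter_num),
        "--rank_by", rank_by,
        "--mode", mode,
        "--models_to_relax", models_to_relax,
    ]


def run_grasp(cmd, log_file):
    """Write the quoted command line to log_file, append the run's output, return the exit code."""
    with open(log_file, "w") as log:
        log.write(shlex.join(cmd) + "\n")
        log.flush()  # make sure the header is written before the child process writes
        result = subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT)
    return result.returncode
```

Usage: `rc = run_grasp(build_grasp_cmd(fasta_path, feature_pickle, restraints_file, output_dir), log_file)`; a non-zero `rc` means `run_grasp.py` failed and the log should be checked.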

In [None]:
#@title Show logs
with open(log_file, 'r') as f:
  print(f.read())

I0303 18:53:16.756769 139244402353984 xla_bridge.py:889] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0303 18:53:16.778076 139244402353984 xla_bridge.py:889] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-03-03 18:53:17.349580: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.8.61). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
I0303 18:53:24.293421 139244402353984 run_grasp.py:595] Have 5 models: ['model_1_multimer_v3_v11_8000_pred_0', 'model_1_multimer_v3_v11_14000_pred_0', 'model_1_multimer_v3_v11_20000_pred_0', 'model_1_multimer_v3_v11_22000_pred_0', 'model_
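The full log can be long. The sketch below keeps only the last `n` lines, optionally filtered by a regular expression. The pattern in the usage note assumes absl-style prefixes (`I0303`, `W0303`, `E0303`), inferred from the log shown above; `tail_log` is a hypothetical helper.

```python
import re
from collections import deque


def tail_log(path, n=20, pattern=None):
    """Return the last n lines of a log file, optionally only lines matching a regex."""
    regex = re.compile(pattern) if pattern else None
    last = deque(maxlen=n)  # keeps at most n lines, dropping the oldest
    with open(path, "r", errors="replace") as f:
        for line in f:
            if regex is None or regex.search(line):
                last.append(line.rstrip("\n"))
    return list(last)
```

For example, `print("\n".join(tail_log(log_file, n=50, pattern=r"^[WEF]\d{4} ")))` shows only warning, error, and fatal lines in that format.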

In [None]:
#@title Display the best structure {run: "auto"}
import sys
sys.path.append(f'/usr/local/lib/python{PYTHON_VERSION}/site-packages/')
import py3Dmol

pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]


color = "chain" #@param ["chain", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
pdb_file = f'{output_dir}/ranked_0.pdb'
def get_chains(pdb_file):
  with open(pdb_file,'r') as f:
    lines = [line.strip() for line in f]
  chain_list = [line[21] for line in lines if line.startswith("ATOM")]
  return sorted(list(set(chain_list)))
def show_pdb(pdb_file, show_sidechains=False, show_mainchains=False, color="chain"):
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
  with open(pdb_file, 'r') as f:
    view.addModel(f.read(), 'pdb')

  if color == "rainbow":
    view.setStyle({'cartoon': {'color': 'spectrum'}})
  elif color == "chain":
    # One PyMOL-style color per chain; zip stops at the shorter of the two lists.
    for chain, chain_color in zip(get_chains(pdb_file), pymol_color_list):
      view.setStyle({'chain': chain}, {'cartoon': {'color': chain_color}})

  if show_sidechains:
    BB = ['C', 'O', 'N']
    # Side chains as sticks (backbone C/O/N excluded; GLY and PRO handled below).
    view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': BB, 'invert': True}]},
                  {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    # Glycine has no side chain: mark its CA with a sphere.
    view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
                  {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    # Proline's ring includes the backbone N, so draw every atom except C and O.
    view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
                  {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
  if show_mainchains:
    BB = ['C', 'O', 'N', 'CA']
    view.addStyle({'atom': BB}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})

  view.zoomTo()
  return view

show_pdb(pdb_file, show_sidechains, show_mainchains, color).show()
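Beyond the 3D view, a per-chain confidence summary can be read directly from the PDB file. A minimal sketch, assuming GRASP follows AlphaFold's convention of writing per-residue pLDDT into the B-factor column of `ranked_*.pdb`; `mean_plddt_per_chain` is a hypothetical helper and uses only CA atoms so that each residue counts once.

```python
from collections import defaultdict


def mean_plddt_per_chain(pdb_file):
    """Mean B-factor of CA atoms per chain (pLDDT, if stored there as in AlphaFold)."""
    values = defaultdict(list)
    with open(pdb_file) as f:
        for line in f:
            # Fixed-column PDB format (1-based): atom name cols 13-16,
            # chain ID col 22, B-factor cols 61-66.
            if line.startswith("ATOM") and line[12:16].strip() == "CA":
                values[line[21]].append(float(line[60:66]))
    return {chain: sum(v) / len(v) for chain, v in sorted(values.items())}
```

For example, `mean_plddt_per_chain(pdb_file)` returns a dict such as `{'A': ..., 'B': ...}` keyed by chain ID, matching the chains returned by `get_chains`.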
