<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/RoseTTAFold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RoseTTAFold w/ PyRosetta

**Limitations**
- This notebook disables a few aspects (templates) of the full rosettafold pipeline.
- For best results, use the [full pipeline](https://github.com/RosettaCommons/RoseTTAFold) or the [Robetta webserver](https://robetta.bakerlab.org/)!
- For a typical Google-Colab session, with a `16G-GPU`, the max total length is **700 residues**. Sometimes a `12G-GPU` is assigned, in which case the max length is lower.

For other related notebooks see [ColabFold](https://github.com/sokrypton/ColabFold)

In [None]:
#@title ##Install and import libraries
# Form values are only assigned when the "PAPERMILL_INPUT_PATH" name is absent
# from the namespace; otherwise they are expected to be defined already.
if "PAPERMILL_INPUT_PATH" not in dir():
  #@markdown This step will take 2+ mins (6min PyRosetta, 2min RoseTTAFold)
  use_pyrosetta = False #@param {type:"boolean"}
  #@markdown - Use PyRosetta for structure prediction. To do so you'll need to get a [PyRosetta License](https://els2.comotion.uw.edu/product/pyrosetta) (free for academic use).
  #@markdown   Once you obtain license, enter the username and password below.
  username = ''  #@param {type:"string"}
  # str methods return a new string; assign the result so that stray
  # whitespace or capitals do not change the hash checked further down.
  username = username.strip().lower()
  password = ''  #@param {type:"string"}

import os
import subprocess
import hashlib

import sys
from IPython.utils import io
try:
  from google.colab import files
except ImportError:
  # google.colab is not importable here: register a no-op
  # `%tensorflow_version` line magic so cells that use it still run.
  from IPython.core.magic import register_line_magic
  @register_line_magic
  def tensorflow_version(line):
    """No-op stand-in for the `%tensorflow_version` line magic."""
    pass

if use_pyrosetta:
  # Thanks Matteo Ferla for password check
  hashed_username = hashlib.sha256(username.encode()).hexdigest()
  hashed_password = hashlib.sha256(password.encode()).hexdigest()
  expected_hashed_username = 'cf6f296b8145262b22721e52e2edec13ce57af8c6fc990c8ae1a4aa3e50ae40e'
  expected_hashed_password = '45066dd976d8bf0c05dc8dd4d58727945c3437e6eb361ba9870097968db7a0da'
  msg = 'Error: username or password is incorrect.'
  assert hashed_username == expected_hashed_username, msg
  assert hashed_password == expected_hashed_password, msg

  dist = subprocess.check_output(['lsb_release', '-is']).strip()
  if dist == "Ubuntu":
    dist = "ubuntu"
  else:
    dist = "linux"
  pyrosetta_path = f"PyRosetta4.Release.python37.{dist}.release-306"
  if not os.path.isdir(pyrosetta_path):
    print("installing pyRosetta - 6 mins")
    with io.capture_output() as captured:
      !wget --show-progress --user {username} --password {password} https://graylab.jhu.edu/download/PyRosetta4/archive/release/PyRosetta4.Release.python37.{dist}/{pyrosetta_path}.tar.bz2 -O - | tar -xj
      !pip install -e {pyrosetta_path}/setup/

# One-time setup: skipped when a "RoseTTAFold" directory already exists in the
# working directory. Output of the commands below is captured, not displayed.
if not os.path.isdir("RoseTTAFold"):
  print("installing RoseTTAFold - 2 mins")
  with io.capture_output() as captured:
    # extra functionality (colabfold.py helper module, imported later as `cf`)
    !wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

    # download model code and apply the ColabFold patch to Refine_module.py
    !git clone https://github.com/RosettaCommons/RoseTTAFold.git
    !wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/RoseTTAFold__network__Refine_module.patch
    !patch -u RoseTTAFold/network/Refine_module.py -i RoseTTAFold__network__Refine_module.patch

    # download model params; the archive is removed once unpacked
    !wget -qnc https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
    !tar -xf weights.tar.gz
    !rm weights.tar.gz

    # download scwrl4 (for adding sidechains)
    # http://dunbrack.fccc.edu/SCWRL3.php
    # Thanks Roland Dunbrack!
    !wget -qnc https://files.ipd.uw.edu/krypton/TrRosetta/scwrl4.zip
    !unzip -qqo scwrl4.zip

    # install libraries
    # NOTE(review): the wheel indexes below are tied to CUDA 11.3 / torch 1.11.0;
    # confirm they match the runtime's installed torch build.
    !pip install -q dgl-cu113 -f https://data.dgl.ai/wheels/repo.html
    !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cu113.html
    !pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.11.0+cu113.html
    !pip install -q torch-geometric
    !pip install -q py3Dmol

    # download the lddt archive and unpack it into RoseTTAFold/lddt
    # (-j drops the directory structure stored in the zip)
    !wget https://openstructure.org/static/lddt-linux.zip -O lddt.zip
    !unzip -d ./RoseTTAFold/lddt -j lddt.zip

# Put the cloned network directory on sys.path and import its modules,
# capturing anything they print at import time.
with io.capture_output() as captured:
  sys.path.append('./RoseTTAFold/network')
  import predict_e2e, predict_pyRosetta
  from parsers import parse_a3m
  
import colabfold as cf
import py3Dmol
import numpy as np
import matplotlib.pyplot as plt

def get_bfactor(pdb_filename):
  """Read the B-factor column of every ATOM record in a PDB file.

  Args:
    pdb_filename: path to a PDB-format text file.

  Returns:
    1-D numpy array of floats, one per line whose first four characters
    are "ATOM", in file order, parsed from ``line[60:66]``.

  Raises:
    ValueError: if that slice of an ATOM line is not a valid float.
  """
  # Context manager so the handle is closed even if parsing fails.
  with open(pdb_filename, "r") as handle:
    return np.array([float(line[60:66]) for line in handle if line[:4] == "ATOM"])

def set_bfactor(pdb_filename, bfac):
  """Rewrite a PDB file in place with new B-factor values.

  Only lines starting with "ATOM  " are written back; every other line of
  the original file is dropped. Each kept line gets ``line[60:66]`` replaced
  by ``bfac[n - 1]`` formatted as ``6.2f``, where ``n`` is the integer read
  from ``line[22:26]``.

  Args:
    pdb_filename: path of the PDB file to rewrite.
    bfac: indexable sequence of numbers, looked up by ``n - 1``.

  Raises:
    ValueError: if ``line[22:26]`` of an ATOM line is not an integer.
    IndexError: if ``n - 1`` is outside ``bfac``.
  """
  # Read everything first: the same path is truncated when reopened for writing.
  with open(pdb_filename, "r") as src:
    records = src.readlines()
  # Context manager closes the output even if a lookup below raises.
  with open(pdb_filename, "w") as dst:
    for line in records:
      if line[0:6] == "ATOM  ":
        seq_id = int(line[22:26].strip()) - 1
        dst.write(f"{line[:60]}{bfac[seq_id]:6.2f}{line[66:]}")

def do_scwrl(inputs, outputs, exe="./scwrl4/Scwrl4"):
  """Run the Scwrl4 executable and carry the input B-factors to its output.

  Args:
    inputs: path of the PDB file given to Scwrl4 via ``-i``.
    outputs: path Scwrl4 writes via ``-o``; rewritten afterwards by
      ``set_bfactor``.
    exe: path of the Scwrl4 executable.

  Returns:
    The B-factor array read from ``inputs`` by ``get_bfactor``.
  """
  command = [exe, "-i", inputs, "-o", outputs, "-h"]
  # Scwrl4's own stdout/stderr are discarded; its exit status is not checked.
  subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
  input_bfactors = get_bfactor(inputs)
  set_bfactor(outputs, input_bfactors)
  return input_bfactors

In [None]:
#@markdown ##Input Sequence
if "PAPERMILL_INPUT_PATH" not in dir():
  sequence = "PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK" #@param {type:"string"}
  jobname = "test" #@param {type:"string"}

# Drop spaces, newlines and tabs from the sequence, then upper-case it.
_strip_table = str.maketrans('', '', ' \n\t')
sequence = sequence.translate(_strip_table).upper()
# Suffix the job name with the first five characters of the sequence hash.
jobname = f"{jobname}_{cf.get_hash(sequence)[:5]}"

In [None]:
#@title Search against genetic databases
if "PAPERMILL_INPUT_PATH" not in dir():
  #@markdown ---
  msa_method = "mmseqs2" #@param ["mmseqs2","single_sequence","custom_a3m"]
  #@markdown - `mmseqs2` - FAST method from [ColabFold](https://github.com/sokrypton/ColabFold)
  #@markdown - `single_sequence` - use single sequence input (not recommended, unless a *denovo* design and you dont expect to find any homologous sequences)
  #@markdown - `custom_a3m` Upload custom MSA (a3m format)

# Scratch prefix: tmp/<hash of the sequence>
_sequence_hash = cf.get_hash(sequence)
os.makedirs('tmp', exist_ok=True)
prefix = os.path.join('tmp', _sequence_hash)

# Per-job output directory
os.makedirs(jobname, exist_ok=True)


# Produce {jobname}/msa.a3m according to the selected MSA method.
if msa_method == "mmseqs2":
  # Use the default server unless the "PAPERMILL_INPUT_PATH" name is present,
  # in which case host_url is expected to be defined already.
  if "PAPERMILL_INPUT_PATH" not in dir():
    host_url="https://a3m.mmseqs.com"
  a3m_lines = cf.run_mmseqs2(sequence, prefix, filter=True, host_url=host_url)
  with open(f"{jobname}/msa.a3m","w") as msa_out:
    msa_out.write(a3m_lines)

elif msa_method == "single_sequence":
  # The "alignment" is just the query itself.
  single_record = f">{jobname}\n{sequence}\n"
  with open(f"{jobname}/msa.a3m","w") as msa_out:
    msa_out.write(single_record)

elif msa_method == "custom_a3m":
  if "PAPERMILL_INPUT_PATH" not in dir():
    print("upload custom a3m")
    msa_dict = files.upload()
    # Only the first uploaded file is used.
    first_upload = msa_dict[list(msa_dict.keys())[0]]
    lines = first_upload.decode().splitlines()
    # Strip NUL characters, then keep non-empty lines that are not '#' comments.
    cleaned = (raw.replace("\x00","") for raw in lines)
    a3m_lines = [entry for entry in cleaned
                 if len(entry) > 0 and not entry.startswith('#')]
    with open(f"{jobname}/msa.a3m","w") as msa_out:
      msa_out.write("\n".join(a3m_lines))
  else:
    # a3m_file is expected to be defined already in this mode.
    import shutil
    shutil.copy(a3m_file, f"{jobname}/msa.a3m")

# Parse the alignment and collapse duplicate rows.
msa_all = parse_a3m(f"{jobname}/msa.a3m")
msa_arr = np.unique(msa_all,axis=0)
total_msa_size = len(msa_arr)
# Only the mmseqs2 path reports its count as "after filtering".
_count_note = " (after filtering)" if msa_method == "mmseqs2" else ""
print(f'\n{total_msa_size} Sequences Found in Total{_count_note}\n')

# Coverage plot; skipped when the alignment has a single unique row.
if total_msa_size > 1:
  plt.figure(figsize=(8,5),dpi=100)
  plt.title("Sequence coverage")
  n_rows = msa_arr.shape[0]
  n_cols = msa_arr.shape[1]
  # Fraction of positions in each unique row equal to the first parsed row.
  identity = (msa_all[0] == msa_arr).mean(-1)
  order = identity.argsort()
  # Code 20 is treated as a gap here; gap cells become NaN in the image data.
  is_residue = (msa_arr != 20)
  coverage = is_residue.astype(float)
  coverage[coverage == 0] = np.nan
  shaded = coverage[order]*identity[order,None]
  plt.imshow(shaded,
            interpolation='nearest', aspect='auto',
            cmap="rainbow_r", vmin=0, vmax=1, origin='lower',
            extent=(0, n_cols, 0, n_rows))
  # Black line: number of non-gap entries per column.
  plt.plot(is_residue.sum(0), color='black')
  plt.xlim(0,n_cols)
  plt.ylim(0,n_rows)
  plt.colorbar(label="Sequence identity to query",)
  plt.xlabel("Positions")
  plt.ylabel("Sequences")
  plt.savefig(f"{jobname}/msa_coverage.png", bbox_inches = 'tight')
  plt.show()

In [None]:
%tensorflow_version 1.x
#@title ## Run RoseTTAFold
# load model
if "rosettafold" not in dir():
  if use_pyrosetta:
    rosettafold = predict_pyRosetta.Predictor(model_dir="weights")
  else:
    rosettafold = predict_e2e.Predictor(model_dir="weights")

# make prediction using model
# The file that marks a finished run differs by mode: pred.npz when PyRosetta
# is used, pred.pdb otherwise. The predict call itself is the same.
_pred_done = f"{jobname}/pred.npz" if use_pyrosetta else f"{jobname}/pred.pdb"
if not os.path.isfile(_pred_done):
  print("running RoseTTAFold")
  rosettafold.predict(f"{jobname}/msa.a3m",f"{jobname}/pred")

# Write the query as a one-record FASTA file unless it is already there.
_fasta_path = f"{jobname}/pred.fasta"
if not os.path.isfile(_fasta_path):
  with open(_fasta_path,"w") as out:
    out.write(f">pred\n{sequence}\n")

if use_pyrosetta:
  # Number of CPUs this process may run on (at least 1).
  CPU = max(1, len(os.sched_getaffinity(0)))
  pyrosetta_script = "./RoseTTAFold/folding/RosettaTR.py"
  final_model = f"{jobname}/model/model_1.crderr.pdb"
  # Everything below is skipped once the final selected model exists.
  if not os.path.isfile(final_model):
    print("running PyRosetta")
    # One command line per (-m, -pd) combination whose output PDB is missing.
    with open(f"{jobname}/job_list","w") as job_list:
      for mode in [0,1,2]:
        for pcut in [0.05,0.15,0.25,0.35,0.45]:
          pdb_out = f"{jobname}/model_{mode}_{pcut}.pdb"
          if os.path.isfile(pdb_out):
            continue
          opt = f"--roll -r 3 -pd {pcut} -m {mode} -sg 7,3"
          job_list.write(f"python -u {pyrosetta_script} {opt} {jobname}/pred.npz {jobname}/pred.fasta {pdb_out}\n")
    # Run the listed commands through xargs, up to CPU at a time.
    run_jobs_cmd = f"cat {jobname}/job_list | tr '\\n' '\\0' | xargs -0 -L 1 -P {CPU} -I % sh -c '%'"
    os.system(run_jobs_cmd)

    print("ranking models using DAN")
    dan_script = "./RoseTTAFold/DAN-msa/ErrorPredictorMSA.py"
    dan_pick_script = "./RoseTTAFold/DAN-msa/pick_final_models.div.py"
    os.system(f"python {dan_script} --roll -p {CPU} {jobname}/pred.npz {jobname} {jobname}")
    os.system(f"python {dan_pick_script} {jobname} {jobname}/model {CPU}")

  # B-factor column of the selected model is used as the plddt values.
  plddt = get_bfactor(final_model)

else:
  # pack sidechains using Scwrl4
  plddt = do_scwrl(f"{jobname}/pred.pdb",f"{jobname}/pred.scwrl.pdb")

# Report the mean and save/show the per-entry plddt curve.
_mean_plddt = plddt.mean()
print(f"Predicted LDDT: {_mean_plddt}")
_plddt_png = f"{jobname}/plddt.png"
plt.figure(figsize=(8,5),dpi=100)
plt.plot(plddt)
plt.xlabel("positions")
plt.ylabel("plddt")
plt.ylim(0,1)
plt.savefig(_plddt_png, bbox_inches = 'tight')
plt.show()

In [None]:
#@title Display 3D structure {run: "auto"}
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
# Show the PyRosetta-selected model when enabled, else the Scwrl4 output.
if use_pyrosetta:
  _pdb_to_show = f"{jobname}/model/model_1.crderr.pdb"
else:
  _pdb_to_show = f"{jobname}/pred.scwrl.pdb"
cf.show_pdb(_pdb_to_show, show_sidechains, show_mainchains, color, chains=1, vmin=0.5, vmax=0.9).show()

# The legend is only drawn for the lDDT colouring.
if color == "lDDT":
  cf.plot_plddt_legend().show()

In [None]:
#@title Download prediction

#@markdown Once this cell has been executed, a zip-archive with 
#@markdown the obtained prediction will be automatically downloaded 
#@markdown to your computer.

# add settings file
settings_path = f"{jobname}/settings.txt"
with open(settings_path, "w") as text_file:
  text_file.write(f"method=RoseTTAFold\n")
  text_file.write(f"sequence={sequence}\n")
  text_file.write(f"msa_method={msa_method}\n")
  text_file.write(f"use_templates=False\n")

# --- Download the predictions ---
!zip -q -r {jobname}.zip {jobname}
try:
  files.download(f'{jobname}.zip')
except:
  pass
