<a href="https://colab.research.google.com/github/zephyris/discoba_alphafold/blob/main/DiscobaAlphaFold2HMMERv2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Input protein sequence(s), then hit `Runtime` -> `Run all`
from google.colab import files
import os
import re
import hashlib
import random

from sys import version_info
# "major.minor" of the running interpreter (e.g. "3.10"); handed to the install
# cell so conda installs packages for a matching python.
python_version = ".".join(str(part) for part in version_info[:2])

def add_hash(x,y):
  """Return ``x`` followed by "_" and the first 5 hex digits of SHA-1(``y``)."""
  digest = hashlib.sha1(y.encode()).hexdigest()
  return "_".join([x, digest[:5]])

query_sequence = 'MESKPNHRRVWVTSAVVLTALVTLIVRRRQQKRQLARRTAADAMHEELAEALKQTLLPVLVRPEDLPAGWGVAPPVFYPPYCAAIPLVLPREMQESERDSSGVRSASTSRFQGGGADGLAVAPTAFLFCSYTGYTAGPFATRQDEEVFLTEVLAHHPVLRERLAENPRQWSPLPESAPHASTDAKTAAAVTEGRRQNSFPRTFSHVLVSDDCVLLAGMNGPYTVVAVLTDFAGAWPLSRTARASADGEESGRRRTGSAAPPTIAEETELSDAIEGIACATINVPLATLGSDSAALAFRVPGTCFAAHQGYYRVVCAREGRELELCVPPEWTMRSECVNTSCTTPAGPPVPARKPVSTVGEDAESKGIILTLSFTPSSFMSEGRVDVCISAELFSALYEEPQAAAATLWAASGATDVCNPVAKATLGALTARPLPASSHARTNVVTMVYVQPKFGVLFSVHPRSAVVYEPWMTEQPTILYYPLGDAADDEGSPRMTIEYVVELPKTWEVFARDDEEFVHNVLFHFTSGEAAAISTTLTEISGIRCAMFHETRESRRCRTYVLPRGATLLVIRWETLAESWDKDLPVFQQTLDTLHIDAAAVIQFWVVVTMLVIVFAMGIGVGAENPKDHNKDEHEALCSVLSLAVTLFESGQAGNKLQKALGWALFGSETGESNTASLLAAPP ' #@param {type:"string"}

# remove whitespaces
query_sequence = "".join(query_sequence.split())
if not query_sequence:
  raise ValueError("query_sequence is empty: enter a protein sequence in the form above")

jobname = 'A4IBK2' #@param {type:"string"}
# remove whitespaces, then any non-word characters, so the name is safe to use
# as a directory/file name and inside the command lines built from it later
basejobname = "".join(jobname.split())
basejobname = re.sub(r'\W+', '', basejobname)
jobname = add_hash(basejobname, query_sequence)
queries_path = os.path.join(jobname, f"{jobname}.csv")

# If an earlier run already used this job name, derive a fresh one. The hashed
# text is salted with an attempt counter: unlike reshuffling the sequence, this
# always yields a new hash, even for empty or single-residue-type sequences.
attempt = 0
while os.path.isfile(queries_path):
  attempt += 1
  jobname = add_hash(basejobname, f"{query_sequence}#{attempt}")
  queries_path = os.path.join(jobname, f"{jobname}.csv")
os.makedirs(jobname, exist_ok=True)
fasta_path = os.path.join(jobname, f"{jobname}.fasta")

with open(queries_path, "w") as text_file:
  text_file.write(f"id,sequence\n{jobname},{query_sequence}\n")
with open(fasta_path, "w") as fasta_file:
  fasta_file.write(f">{jobname}\n{query_sequence}\n")

# number of top-ranked models to relax with amber; amber is only installed and
# used when this is > 0
num_relax = 0 #@param [0, 1, 5] {type:"raw"}
use_amber = num_relax > 0
#@markdown - specify how many of the top ranked structures to relax using amber
template_mode = "none" #@param ["none", "pdb70","custom"]
#@markdown - `none` = no template information is used. `pdb70` = detect templates in pdb70. `custom` - upload and search own templates (PDB or mmCIF format)

if template_mode == "pdb70":
  use_templates = True
  custom_template_path = None
elif template_mode == "custom":
  # move the user's uploaded template files into <jobname>/template
  custom_template_path = os.path.join(jobname, "template")
  os.makedirs(custom_template_path, exist_ok=True)
  uploaded = files.upload()
  use_templates = True
  for fn in uploaded.keys():
    os.rename(fn,os.path.join(custom_template_path,fn))
else:
  custom_template_path = None
  use_templates = False


In [None]:
#@title Install dependencies
%%bash -s $use_amber $use_templates $python_version

# Abort on the first failing command, so a *_READY marker is never written for
# a step that did not complete.
set -e

USE_AMBER=$1
USE_TEMPLATES=$2
PYTHON_VERSION=$3

# Install ColabFold and dependencies
if [ ! -f COLABFOLD_READY ]; then
  echo "installing colabfold..."
  # install dependencies
  # We have to use "--no-warn-conflicts" because colab already has a lot preinstalled with requirements different to ours
  pip install -q --no-warn-conflicts "colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold"
  # high risk high gain
  pip install -q "jax[cuda11_cudnn805]>=0.3.8,<0.4" -f https://storage.googleapis.com/jax-releases/jax_releases.html

  # for debugging
  ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold
  ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold
  touch COLABFOLD_READY
fi

# setup conda (needed for amber and/or template search)
if [ "${USE_AMBER}" == "True" ] || [ "${USE_TEMPLATES}" == "True" ]; then
  if [ ! -f CONDA_READY ]; then
    echo "installing conda..."
    wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
    bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null
    rm Miniconda3-latest-Linux-x86_64.sh
    touch CONDA_READY
  fi
fi
# setup template search
if [ "${USE_TEMPLATES}" == "True" ] && [ ! -f HH_READY ]; then
  echo "installing hhsuite..."
  conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python="${PYTHON_VERSION}" 2>&1 1>/dev/null
  touch HH_READY
fi
# setup openmm for amber refinement
if [ "${USE_AMBER}" == "True" ] && [ ! -f AMBER_READY ]; then
  echo "installing amber..."
  conda install -y -q -c conda-forge openmm=7.5.1 python="${PYTHON_VERSION}" pdbfixer 2>&1 1>/dev/null
  touch AMBER_READY
fi

# Install HMMER (jackhmmer is used for the Discoba search).
# -y is required: there is no terminal to answer apt-get's confirmation prompt.
if [ ! -f HMMER_READY ]; then
  apt-get install -y hmmer > /dev/null 2>&1
  touch HMMER_READY
fi

# Download the custom Discoba database from Zenodo.
# URLs are quoted because "?" is a shell glob character; -L follows redirects.
if [ ! -f DISCOBA_READY ]; then
  rm -rf discoba
  mkdir discoba
  cd discoba
    # show discobaStats.txt in the cell output (informational only)
    curl -s -L "https://zenodo.org/record/5682928/files/discobaStats.txt?download=1"
    # -f: fail on an HTTP error instead of saving the error page as the archive
    curl -s -L -f -o discoba.fasta.gz "https://zenodo.org/record/5682928/files/discoba.fasta.gz?download=1"
    gzip -d discoba.fasta.gz
  cd ..
  touch DISCOBA_READY
fi

# Fetch hh-suite; this notebook uses its scripts/reformat.pl for MSA format
# conversion, so a shallow clone is enough.
if [ ! -f HHSUITE_READY ]; then
  rm -rf hh-suite
  git clone --depth 1 https://github.com/soedinglab/hh-suite
  touch HHSUITE_READY
fi

In [None]:
#@markdown ### Generate custom multiple sequence alignment

msa_use_discoba = True #@param {type:"boolean"}
msa_use_colabfold = True #@param {type:"boolean"}
msa_colabfold_mode = "mmseqs2_uniref_env" #@param ["mmseqs2_uniref_env", "mmseqs2_uniref"]

# discoba MSA: jackhmmer search of the query against the Discoba database, then
# conversion of the Stockholm alignment to a3m with hh-suite's reformat.pl.
# Skipped when the a3m already exists from an earlier run.
import os
import subprocess

hmm_sto_path = os.path.join(jobname, f"{jobname}.hmm.sto")
hmm_a3m_path = os.path.join(jobname, f"{jobname}.hmm.a3m")
if msa_use_discoba and not os.path.isfile(hmm_a3m_path):
  print("Doing HMMER search against Discoba")
  # Argument list (no shell) and check=True: a failed search stops here instead
  # of surfacing later as a missing MSA file.
  subprocess.run(
    ["jackhmmer",
     "-A", hmm_sto_path,
     "-o", os.path.join(jobname, f"{jobname}.hmm.out"),
     os.path.join(jobname, f"{jobname}.fasta"),
     os.path.join("discoba", "discoba.fasta")],
    check=True,
  )
  subprocess.run(
    ["perl", "hh-suite/scripts/reformat.pl", "sto", "a3m", hmm_sto_path, hmm_a3m_path]
  )
  if not os.path.isfile(hmm_a3m_path):
    raise RuntimeError(f"reformat.pl did not produce {hmm_a3m_path}")

# mmseqs2 msa
# ColabFold's MMseqs2 search is reached through a "dummy" prediction run with
# num_models=0: the MSA is fetched and written to <jobname>/<jobname>.a3m, and
# the run is then expected to fail because no models are predicted.
mms_a3m_path = os.path.join(jobname, f"{jobname}.mms.a3m")
if msa_use_colabfold and not os.path.isfile(mms_a3m_path):
  print("Fetching MMSeqs2 search")
  msa_mode = msa_colabfold_mode
  pair_mode = "unpaired_paired"

  import subprocess
  import sys
  import warnings
  warnings.simplefilter(action='ignore', category=FutureWarning)
  from Bio import BiopythonDeprecationWarning
  warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)
  from pathlib import Path
  from colabfold.download import download_alphafold_params, default_data_dir
  from colabfold.utils import setup_logging
  from colabfold.batch import get_queries, run, set_model_type
  from colabfold.plot import plot_msa_v2

  import os
  import numpy as np

  from colabfold.colabfold import plot_protein
  import matplotlib.pyplot as plt

  def input_features_callback(input_features):
    """No-op: nothing is displayed during the MSA-only run."""

  def prediction_callback(protein_obj, length, prediction_result, input_features, mode):
    """No-op: the MSA-only run predicts no structures."""

  if 'logging_setup' not in globals():
    setup_logging(Path(os.path.join(jobname,"log.txt")))
    logging_setup = True

  queries, is_complex = get_queries(queries_path)
  model_type = set_model_type(is_complex, "alphafold2_ptm")

  download_alphafold_params(model_type, Path("."))
  try:
    results = run(
      queries=queries,
      result_dir=jobname,
      use_templates=use_templates,
      custom_template_path=custom_template_path,
      num_relax=num_relax,
      msa_mode=msa_mode,
      model_type=model_type,
      num_models=0, # set to zero, results in a3m but prediction failure
      num_recycles=None,
      recycle_early_stop_tolerance=None,
      num_seeds=1,
      use_dropout=None,
      model_order=[1,2,3,4,5],
      is_complex=is_complex,
      data_dir=Path("."),
      keep_existing_results=False,
      rank_by="auto",
      pair_mode=pair_mode,
      stop_at_score=float(100),
      prediction_callback=prediction_callback,
      dpi=200,
      zip_results=False,
      save_all=False,
      max_msa=None,
      use_cluster_profile=True,
      input_features_callback=input_features_callback,
      save_recycles=False,
    )
  except Exception as err:
    # Expected: the prediction step fails once the MSA has been written.
    print("Colabfold exception:", repr(err))
    print("Expected exception: a3m should have been fetched")

  # Post-process the fetched a3m whether or not run() raised.
  fetched_a3m_path = os.path.join(jobname, f"{jobname}.a3m")
  if not os.path.isfile(fetched_a3m_path):
    raise FileNotFoundError(
      f"ColabFold did not write the MMseqs2 MSA to {fetched_a3m_path} "
      "(check the ColabFold exception printed above, if any)")
  with open(fetched_a3m_path, "r") as fetched_a3m:
    a3m_lines = fetched_a3m.read().splitlines()
  # Drop the first line when it is a '#' header line rather than an alignment row.
  if a3m_lines and a3m_lines[0].startswith("#"):
    a3m_lines = a3m_lines[1:]
  tmp_a3m_path = os.path.join(jobname, f"{jobname}.tmp.a3m")
  with open(tmp_a3m_path, "w") as tmp_a3m:
    tmp_a3m.writelines(line + "\n" for line in a3m_lines)
  # Normalise the alignment with hh-suite's reformat.pl (a3m -> a3m).
  subprocess.run(
    ["perl", "hh-suite/scripts/reformat.pl", "a3m", "a3m", tmp_a3m_path, mms_a3m_path]
  )
  if not os.path.isfile(mms_a3m_path):
    raise RuntimeError(f"reformat.pl did not produce {mms_a3m_path}")

# join outputs
# Concatenate the selected MSAs into the single custom a3m used by the
# prediction cell: the MMseqs2 (ColabFold) MSA first, then the HMMER/Discoba MSA.
# Each flag selects the file produced by its own search.
msa_part_paths = []
if msa_use_colabfold:
  msa_part_paths.append(os.path.join(jobname, f"{jobname}.mms.a3m"))
if msa_use_discoba:
  msa_part_paths.append(os.path.join(jobname, f"{jobname}.hmm.a3m"))
if not msa_part_paths:
  raise ValueError("Enable msa_use_discoba and/or msa_use_colabfold to build an MSA")
with open(os.path.join(jobname, f"{jobname}.custom.a3m"), "w") as merged_a3m:
  for part_path in msa_part_paths:
    with open(part_path, "r") as part_a3m:
      merged_a3m.write(part_a3m.read() + "\n")

In [None]:
#@markdown ### Advanced settings
model_type = "alphafold2_ptm"
num_recycles = "auto" #@param ["auto", "0", "1", "3", "6", "12", "24", "48"]
recycle_early_stop_tolerance = "auto" #@param ["auto", "0.0", "0.5", "1.0"]
#@markdown - if `auto` selected, will use 20 recycles if `model_type=alphafold2_multimer_v3` (with tol=0.5), all else 3 recycles (with tol=0.0).

#@markdown #### Sample settings
#@markdown -  enable dropouts and increase number of seeds to sample predictions from uncertainty of the model.
#@markdown -  decrease `max_msa` to increase uncertainty
max_msa = "auto" #@param ["auto", "512:1024", "256:512", "64:128", "32:64", "16:32"]
num_seeds = 1 #@param [1,2,4,8,16] {type:"raw"}
use_dropout = False #@param {type:"boolean"}

# Map the form's "auto" choices to None (ColabFold then picks its defaults) and
# the remaining string choices to numbers; max_msa keeps its string value.
if num_recycles == "auto":
  num_recycles = None
else:
  num_recycles = int(num_recycles)
if recycle_early_stop_tolerance == "auto":
  recycle_early_stop_tolerance = None
else:
  recycle_early_stop_tolerance = float(recycle_early_stop_tolerance)
max_msa = None if max_msa == "auto" else max_msa

#@markdown #### Save settings
save_all = False #@param {type:"boolean"}
save_recycles = False #@param {type:"boolean"}
save_to_google_drive = False #@param {type:"boolean"}
#@markdown -  if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive
dpi = 200 #@param {type:"integer"}
#@markdown - set dpi for image resolution

# Log into Google Drive only when requested; `drive` is used by the download cell.
if save_to_google_drive:
  from google.colab import auth
  from oauth2client.client import GoogleCredentials
  from pydrive.auth import GoogleAuth
  from pydrive.drive import GoogleDrive
  auth.authenticate_user()
  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)
  print("You are logged into Google Drive and are good to go!")

#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.

In [None]:
#@title Run Prediction
display_images = True #@param {type:"boolean"}

# Real prediction, using custom a3m
msa_mode = "custom"
pair_mode = "unpaired"
a3m_file = os.path.join(jobname,f"{jobname}.custom.a3m")
# the queries for the real run are read from the merged custom a3m
queries_path=a3m_file

import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio import BiopythonDeprecationWarning
warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)
from pathlib import Path
from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.utils import setup_logging
from colabfold.batch import get_queries, run, set_model_type
from colabfold.plot import plot_msa_v2

import os
import numpy as np

# Count Tesla K80 GPUs. On a K80, warn about the length limit and clear the
# unified-memory / memory-fraction overrides if they are set.
try:
  K80_chk = os.popen('nvidia-smi | grep "Tesla K80" | wc -l').read()
except OSError:
  K80_chk = "0"
try:
  k80_count = int(K80_chk.strip() or "0")
except ValueError:
  k80_count = 0
if k80_count > 0:
  print("WARNING: found GPU Tesla K80: limited to total length < 1000")
  os.environ.pop("TF_FORCE_UNIFIED_MEMORY", None)
  os.environ.pop("XLA_PYTHON_CLIENT_MEM_FRACTION", None)

from colabfold.colabfold import plot_protein
from pathlib import Path
import matplotlib.pyplot as plt

# For some reason we need that to get pdbfixer to import: the install cell puts
# conda (and pdbfixer) under /usr/local, so add its site-packages to sys.path.
if use_amber and f"/usr/local/lib/python{python_version}/site-packages/" not in sys.path:
    sys.path.insert(0, f"/usr/local/lib/python{python_version}/site-packages/")

def input_features_callback(input_features):
  """Plot the MSA of *input_features* with plot_msa_v2 when display_images is set."""
  if not display_images:
    return
  plot_msa_v2(input_features)
  plt.show()
  plt.close()

def prediction_callback(protein_obj, length,
                        prediction_result, input_features, mode):
  """Plot each unrelaxed predicted structure when display_images is set.

  ``mode`` is unpacked as a ``(model_name, relaxed)`` pair; relaxed structures
  are not plotted.
  """
  _model_name, relaxed = mode
  if display_images and not relaxed:
    plot_protein(protein_obj, Ls=length, dpi=150)
    plt.show()
    plt.close()

result_dir = jobname
if 'logging_setup' not in globals():
  setup_logging(Path(os.path.join(jobname,"log.txt")))
  logging_setup = True

queries, is_complex = get_queries(queries_path)
model_type = set_model_type(is_complex, model_type)

# cluster profiles are disabled only for multimer models with an explicit max_msa
use_cluster_profile = not ("multimer" in model_type and max_msa is not None)

download_alphafold_params(model_type, Path("."))
results = run(
    queries=queries,
    result_dir=result_dir,
    use_templates=use_templates,
    custom_template_path=custom_template_path,
    num_relax=num_relax,
    msa_mode=msa_mode,
    model_type=model_type,
    num_models=5,
    num_recycles=num_recycles,
    recycle_early_stop_tolerance=recycle_early_stop_tolerance,
    num_seeds=num_seeds,
    use_dropout=use_dropout,
    model_order=[1,2,3,4,5],
    is_complex=is_complex,
    data_dir=Path("."),
    keep_existing_results=False,
    rank_by="auto",
    pair_mode=pair_mode,
    stop_at_score=float(100),
    prediction_callback=prediction_callback,
    dpi=dpi,
    zip_results=False,
    save_all=save_all,
    max_msa=max_msa,
    use_cluster_profile=use_cluster_profile,
    input_features_callback=input_features_callback,
    save_recycles=save_recycles,
)

# Archive the whole job directory for the download cell. Argument list (no
# shell) and check=True so a failed zip is reported here, not at download time.
import subprocess
results_zip = f"{jobname}.result.zip"
subprocess.run(["zip", "-r", results_zip, jobname], check=True)

In [None]:
#@title Display 3D structure {run: "auto"}
import py3Dmol
import glob
import matplotlib.pyplot as plt
from colabfold.colabfold import plot_plddt_legend
from colabfold.colabfold import pymol_color_list, alphabet_list
rank_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# model tag at position rank_num for the first query (presumably in rank order)
tag = results["rank"][0][rank_num - 1]
# Output files of the custom-a3m run carry an extra ".custom" after the job name.
jobname_prefix = ".custom" if msa_mode == "custom" else ""
pdb_filename = f"{jobname}/{jobname}{jobname_prefix}_unrelaxed_{tag}.pdb"
pdb_file = glob.glob(pdb_filename)
if not pdb_file:
  raise FileNotFoundError(f"Predicted structure not found: {pdb_filename}")

def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color="lDDT"):
  """Return a py3Dmol view of the structure in the module-level ``pdb_file[0]``.

  Args:
    rank_num: kept for interface compatibility. NOTE(review): the file shown is
      the module-level ``pdb_file`` (chosen from the form's rank_num), so this
      argument does not change which model is displayed.
    show_sidechains: also draw side chains as sticks.
    show_mainchains: also draw backbone atoms (C, O, N, CA) as sticks.
    color: "lDDT" (B-factor gradient, 50-90), "rainbow" (spectrum) or
      "chain" (one colour per chain).
  """
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  with open(pdb_file[0], 'r') as pdb_handle:
    view.addModel(pdb_handle.read(), 'pdb')

  if color == "lDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    chains = len(queries[0][1]) + 1 if is_complex else 1
    for _, chain, chain_color in zip(range(chains), alphabet_list, pymol_color_list):
      view.setStyle({'chain':chain},{'cartoon': {'color':chain_color}})

  if show_sidechains:
    backbone = ['C','O','N']
    # side chains of every residue except GLY/PRO (backbone C, O, N excluded)
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':backbone,'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # GLY: mark its CA with a sphere
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # PRO: every atom except C and O, so the ring is drawn together with N and CA
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    backbone = ['C','O','N','CA']
    view.addStyle({'atom':backbone},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view

structure_view = show_pdb(rank_num, show_sidechains, show_mainchains, color)
structure_view.show()
if color == "lDDT":
  # legend for the lDDT colour scale
  plot_plddt_legend().show()

In [None]:
#@title Plots {run: "auto"}
from IPython.display import display, HTML
import base64
from html import escape

# see: https://stackoverflow.com/a/53688522
def image_to_data_url(filename):
  """Return the contents of the file *filename* as a base64 ``data:`` URL.

  The MIME type is guessed from the file extension (``.png`` -> ``image/png``,
  ``.jpg`` -> ``image/jpeg``); an unrecognised extension falls back to
  ``image/<extension>``.
  """
  import mimetypes
  mime_type, _ = mimetypes.guess_type(filename)
  if mime_type is None:
    # splitext only inspects the last path component, so dotted directory
    # names (e.g. "run.v2/plot") are not mistaken for the extension
    mime_type = f"image/{os.path.splitext(filename)[1][1:]}"
  with open(filename, 'rb') as f:
    img = f.read()
  return f"data:{mime_type};base64," + base64.b64encode(img).decode('utf-8')

# Load the PAE, coverage and pLDDT plot images from the job directory as data
# URLs so they can be embedded directly in the HTML below.
plot_urls = {}
for plot_kind in ("pae", "coverage", "plddt"):
  plot_path = os.path.join(jobname, f"{jobname}{jobname_prefix}_{plot_kind}.png")
  plot_urls[plot_kind] = image_to_data_url(plot_path)
pae, cov, plddt = plot_urls["pae"], plot_urls["coverage"], plot_urls["plddt"]
display(HTML(f"""
<style>
  img {{
    float:left;
  }}
  .full {{
    max-width:100%;
  }}
  .half {{
    max-width:50%;
  }}
  @media (max-width:640px) {{
    .half {{
      max-width:100%;
    }}
  }}
</style>
<div style="max-width:90%; padding:2em;">
  <h1>Plots for {escape(jobname)}</h1>
  <img src="{pae}" class="full" />
  <img src="{cov}" class="half" />
  <img src="{plddt}" class="half" />
</div>
"""))


In [None]:
#@title Package and download results
#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails, click the folder icon on the left, navigate to the file `jobname.result.zip`, right-click it and select "Download".

if msa_mode == "custom":
  print("Don't forget to cite your custom MSA generation method.")

results_zip = f"{jobname}.result.zip"
files.download(results_zip)

# `drive` is only created by the settings cell when save_to_google_drive is set;
# the short-circuit keeps it from being looked up otherwise.
if save_to_google_drive and drive:
  uploaded = drive.CreateFile({'title': results_zip})
  uploaded.SetContentFile(results_zip)
  uploaded.Upload()
  print(f"Uploaded {results_zip} to Google Drive with ID {uploaded.get('id')}")