<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/AlphaFold_wJackhmmer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AlphaFold2 w/ Jackhmmer (or MMseqs2)

---------
**UPDATE** (Aug. 13, 2021)

This notebook is being retired and no longer updated. The functionality to search using jackhmmer/mmseqs has been integrated in our [new advanced notebook](https://github.com/sokrypton/ColabFold/blob/main/beta/AlphaFold2_advanced.ipynb).

---------

This notebook modifies DeepMind's [original notebook](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb) to add experimental homooligomer support and an option to run MMseqs2 instead of Jackhmmer for MSA generation.

See [ColabFold](https://github.com/sokrypton/ColabFold/) for other related notebooks.

**Limitations**
- This notebook does NOT use Templates.
- For a typical Google-Colab-GPU (16G) session, the max total length is **1400 residues**.

In [None]:
#@title Install software
#@markdown Please execute this cell by pressing the _Play_ button 
#@markdown on the left.
use_amber_relax = True #@param {type:"boolean"}

import os
# Enable unified memory and let the XLA client claim up to 2x the device
# memory -- presumably so long sequences can spill beyond GPU RAM.
os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'

from IPython.utils import io
import subprocess
import requests
import hashlib
import tarfile
import time
import tqdm.notebook

from sys import version_info 
# "major.minor" of this runtime; used to pin conda's python and to locate
# conda's site-packages for the OpenMM patch and sys.path.
python_version = f"{version_info.major}.{version_info.minor}"

GIT_REPO = 'https://github.com/deepmind/alphafold'
SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar'
PARAMS_DIR = './alphafold/data/params'
PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# if not already installed
# (the ./alphafold checkout is the only marker; delete it to force a reinstall)
if not os.path.isdir("alphafold"):
  try:
    # Progress units: the pbar.update() calls below add up to 100 when the
    # Amber/OpenMM steps run and to 55 when they are skipped.
    total = 100 if use_amber_relax else 55
    with tqdm.notebook.tqdm(total=total, bar_format=TQDM_BAR_FORMAT) as pbar:
      # Shell output is captured and only shown if a CalledProcessError
      # propagates out (see the except clause at the end of this cell).
      with io.capture_output() as captured:
        #######################################################################
        # Fresh clone of the AlphaFold source.
        %shell rm -rf alphafold
        %shell git clone {GIT_REPO} alphafold

        # Apply multi-chain patch from Lim Heo @huhlim
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/protein.patch
        %shell (patch -u alphafold/alphafold/common/protein.py -i /content/protein.patch)

        pbar.update(4)
        %shell pip3 install ./alphafold
        pbar.update(5)
            
        # Download and unpack the model parameters, then delete the tarball.
        %shell mkdir --parents "{PARAMS_DIR}"
        %shell wget -O "{PARAMS_PATH}" "{SOURCE_URL}"
        pbar.update(14)

        %shell tar --extract --verbose --file="{PARAMS_PATH}" \
          --directory="{PARAMS_DIR}" --preserve-permissions
        %shell rm "{PARAMS_PATH}"
        pbar.update(27)

        #######################################################################
        # jackhmmer binary used for the MSA search.
        %shell sudo apt install --quiet --yes hmmer
        pbar.update(3)

        # Install py3dmol.
        %shell pip install py3dmol
        pbar.update(1)

        if use_amber_relax:
          # Install OpenMM and pdbfixer.
          # A fresh Miniconda in /opt/conda hosts OpenMM/pdbfixer for the
          # same python version as this runtime.
          %shell rm -rf /opt/conda
          %shell wget -q -P /tmp \
            https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
              && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
              && rm /tmp/Miniconda3-latest-Linux-x86_64.sh
          pbar.update(4)

          # Put conda's binaries first on PATH for the commands below.
          PATH=%env PATH
          %env PATH=/opt/conda/bin:{PATH}
          %shell conda update -qy conda \
              && conda install -qy -c conda-forge \
                python={python_version} \
                openmm=7.5.1 \
                pdbfixer
          pbar.update(40)

          # Stereochemical parameters file, copied into alphafold/common.
          %shell wget -q -P /content \
            https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
          pbar.update(1)
          %shell mkdir -p /content/alphafold/common
          %shell cp -f /content/stereo_chemical_props.txt /content/alphafold/common

          # Apply OpenMM patch.
          %shell pushd /opt/conda/lib/python{python_version}/site-packages/ && \
              patch -p0 < /content/alphafold/docker/openmm.patch && \
              popd

        # Create a ramdisk to store a database chunk to make Jackhmmer run fast.
        %shell sudo mkdir -m 777 --parents /tmp/ramdisk
        %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk
        pbar.update(1)

  except subprocess.CalledProcessError:
    # Show the captured shell output so the failing command can be diagnosed.
    print(captured)
    raise

  ########################################################################################
  ########################################################################################

# This notebook needs a GPU runtime; fail fast on TPU or CPU runtimes.
import jax
if jax.local_devices()[0].platform == 'tpu':
  raise RuntimeError('Colab TPU runtime not supported. Change it to GPU via Runtime -> Change Runtime Type -> Hardware accelerator -> GPU.')
elif jax.local_devices()[0].platform == 'cpu':
  raise RuntimeError('Colab CPU runtime not supported. Change it to GPU via Runtime -> Change Runtime Type -> Hardware accelerator -> GPU.')

# --- Python imports ---
import sys
import pickle
if use_amber_relax:
  sys.path.append(f"/opt/conda/lib/python{python_version}/site-packages")

from urllib import request
from concurrent import futures
from google.colab import files
import json
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol

from alphafold.model import model
from alphafold.model import config
from alphafold.model import data

from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.data.tools import jackhmmer

from alphafold.common import protein

if use_amber_relax:
  from alphafold.relax import relax
  from alphafold.relax import utils

#######################################################################################
#######################################################################################
def run_mmseqs2(query_sequence, prefix, use_env=True, filter=False):
  """Fetch an MSA for `query_sequence` from the ColabFold MMseqs2 server.

  Two on-disk caches are used: the raw server archive
  ``{prefix}_{mode}/out.tar.gz`` (skips the network round trip) and the merged
  alignment ``{prefix}_{mode}.a3m`` (skips extraction entirely).

  Args:
    query_sequence: amino-acid sequence to search.
    prefix: path prefix for the cache files.
    use_env: also append the environmental hits
      (``bfd.mgnify30.metaeuk30.smag30.a3m``) after the uniref hits.
    filter: request the server's filtered modes ("env"/"all") instead of the
      "nofilter" variants. (Shadows the builtin; name kept for callers.)

  Returns:
    The merged a3m text as a single string.

  Raises:
    RuntimeError: if the server reports the job status "ERROR".
  """
  def submit(query_sequence, mode):
    res = requests.post('https://a3m.mmseqs.com/ticket/msa', data={'q':f">1\n{query_sequence}", 'mode': mode})
    return res.json()

  def status(ticket_id):
    res = requests.get(f'https://a3m.mmseqs.com/ticket/{ticket_id}')
    return res.json()

  def download(ticket_id, path):
    res = requests.get(f'https://a3m.mmseqs.com/result/download/{ticket_id}')
    with open(path, "wb") as out:
      out.write(res.content)

  if filter:
    mode = "env" if use_env else "all"
  else:
    mode = "env-nofilter" if use_env else "nofilter"

  path = f"{prefix}_{mode}"
  os.makedirs(path, exist_ok=True)

  # call mmseqs2 api (skipped when a previous run already downloaded the archive)
  tar_gz_file = f'{path}/out.tar.gz'
  if not os.path.isfile(tar_gz_file):
    out = submit(query_sequence, mode)
    while out["status"] in ["RUNNING", "PENDING"]:
      time.sleep(1)
      out = status(out["id"])
    if out["status"] == "ERROR":
      # Without this check the code would try to download a result that does not exist.
      raise RuntimeError("MMseqs2 API returned status ERROR for this query; "
                         "try again later or choose another msa_method.")
    download(out["id"], tar_gz_file)

  a3m = f"{prefix}_{mode}.a3m"
  if os.path.isfile(a3m):
    with open(a3m) as cached:
      return cached.read()

  # NOTE(review): the archive comes from a remote server and extractall
  # trusts its member paths.
  with tarfile.open(tar_gz_file) as tar_gz:
    tar_gz.extractall(path)
  a3m_files = [f"{path}/uniref.a3m"]
  if use_env:
    a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

  a3m_lines = []
  for a3m_file in a3m_files:
    with open(a3m_file, "r") as a3m_in:
      for line in a3m_in:
        # strip NUL bytes present in the server output
        line = line.replace("\x00", "")
        if line:
          a3m_lines.append(line)
  a3m_text = "".join(a3m_lines)

  # Write the cache only after every input was read, so a failure above
  # cannot leave a truncated/empty .a3m that later calls would trust.
  with open(a3m, "w") as a3m_out:
    a3m_out.write(a3m_text)
  return a3m_text

def _find_closest_source():
  """Return the suffix ('', '-europe' or '-asia') of the first reachable database mirror.

  The same test chunk is downloaded from every mirror in parallel; the first
  successful download wins and mirrors whose download fails are skipped.

  Raises:
    RuntimeError: if no mirror could be reached.
  """
  test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'

  def fetch(source):
    request.urlretrieve(test_url_pattern.format(source))
    return source

  ex = futures.ThreadPoolExecutor(3)
  fs = [ex.submit(fetch, source) for source in ['', '-europe', '-asia']]
  source = None
  last_error = None
  for f in futures.as_completed(fs):
    try:
      source = f.result()
    except OSError as err:  # urllib's URLError/HTTPError derive from OSError
      last_error = err
      continue
    break
  # As before, this waits for the remaining test downloads to finish.
  ex.shutdown()
  if source is None:
    raise RuntimeError('Could not reach any AlphaFold database mirror.') from last_error
  return source


def _search_jackhmmer_dbs(source, fasta_path):
  """Stream jackhmmer over the chunked uniref90, small BFD and MGnify databases.

  Returns:
    A list of (db_name, results) pairs in search order, where `results` is
    whatever jackhmmer.Jackhmmer.query returns for `fasta_path`.
  """
  jackhmmer_binary_path = '/usr/bin/jackhmmer'
  base_url = f'https://storage.googleapis.com/alphafold-colab{source}/latest'
  # (db name, fasta file, number of streamed chunks, z_value passed to jackhmmer)
  db_specs = [
      ('uniref90', 'uniref90_2021_03.fasta', 59, 135301051),
      ('smallbfd', 'bfd-first_non_consensus_sequences.fasta', 17, 65984053),
      ('mgnify', 'mgy_clusters_2019_05.fasta', 71, 304820129),
  ]
  dbs = []
  total_jackhmmer_chunks = sum(spec[2] for spec in db_specs)
  with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:
    def jackhmmer_chunk_callback(i):
      pbar.update(n=1)

    for db_name, fasta_name, num_chunks, z_value in db_specs:
      pbar.set_description(f'Searching {db_name}')
      runner = jackhmmer.Jackhmmer(
          binary_path=jackhmmer_binary_path,
          database_path=f'{base_url}/{fasta_name}',
          get_tblout=True,
          num_streamed_chunks=num_chunks,
          streaming_callback=jackhmmer_chunk_callback,
          z_value=z_value)
      dbs.append((db_name, runner.query(fasta_path)))
  return dbs


def _merge_jackhmmer_results(dbs, mgnify_max_hits=501):
  """Merge per-chunk jackhmmer hits into one E-value-sorted MSA per database.

  The 'query' row is kept only from each database's first chunk and MGnify is
  truncated to `mgnify_max_hits` rows. Deduplication is left to
  pipeline.make_msa_features.

  Returns:
    (msas, deletion_matrices): parallel lists with one entry per database
    that produced any rows.
  """
  msas = []
  deletion_matrices = []
  for db_name, db_results in dbs:
    unsorted_results = []
    for i, result in enumerate(db_results):
      msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])
      e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])
      e_values = [e_values_dict[t.split('/')[0]] for t in target_names]
      zipped_results = zip(msa, deletion_matrix, target_names, e_values)
      if i != 0:
        # Only take query from the first chunk
        zipped_results = [x for x in zipped_results if x[2] != 'query']
      unsorted_results.extend(zipped_results)
    if not unsorted_results:
      # zip(*[]) cannot be unpacked; a database without rows is skipped.
      continue
    sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])
    db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)
    if db_name == 'mgnify':
      db_msas = db_msas[:mgnify_max_hits]
      db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]
    msas.append(db_msas)
    deletion_matrices.append(db_deletion_matrices)
    msa_size = len(set(db_msas))
    print(f'{msa_size} Sequences Found in {db_name}')
  return msas, deletion_matrices


def run_jackhmmer(sequence, prefix):
  """Build MSAs for `sequence` by streaming jackhmmer over uniref90, small BFD and MGnify.

  Results are cached in ``{prefix}.jackhmmer.pickle``; when that file exists it
  is loaded and returned without searching.

  Args:
    sequence: query amino-acid sequence; written to ``{prefix}.query.fasta``
      under the header 'query' (the merge step relies on that name).
    prefix: path prefix for the cache files.

  Returns:
    (msas, deletion_matrices): parallel lists with one entry per database,
    as produced from parsers.parse_stockholm output.
  """
  pickled_msa_path = f"{prefix}.jackhmmer.pickle"
  if os.path.isfile(pickled_msa_path):
    # Local cache written by a previous call of this function.
    with open(pickled_msa_path, "rb") as handle:
      msas_dict = pickle.load(handle)
    return msas_dict['msas'], msas_dict['deletion_matrices']

  # Search with the sequence actually passed in, not whatever file happens to be on disk.
  fasta_path = f"{prefix}.query.fasta"
  with open(fasta_path, "wt") as handle:
    handle.write(f'>query\n{sequence}')

  source = _find_closest_source()
  dbs = _search_jackhmmer_dbs(source, fasta_path)
  msas, deletion_matrices = _merge_jackhmmer_results(dbs)

  with open(pickled_msa_path, "wb") as handle:
    pickle.dump({"msas": msas, "deletion_matrices": deletion_matrices}, handle)
  return msas, deletion_matrices

## Making a prediction

Please paste the sequence of your protein in the text box below, then run the remaining cells via _Runtime_ > _Run after_. You can also run the cells individually by pressing the _Play_ button on the left.

Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below).

In [None]:
#@title Enter the amino acid sequence to fold ⬇️
sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK'  #@param {type:"string"}

MIN_SEQUENCE_LENGTH = 16
MAX_SEQUENCE_LENGTH = 2500

# Remove all whitespace (spaces, tabs, newlines, carriage returns, ...) and
# upper-case, so pasted multi-line sequences are accepted.
sequence = ''.join(sequence.split()).upper()
#@markdown ### Experimental options
homooligomer = 1 #@param [1,2,3,4,5,6,7,8] {type:"raw"}

# The complex is modelled as `homooligomer` copies of the chain concatenated.
full_sequence = sequence * homooligomer
full_length = len(full_sequence)

aatypes = set('ACDEFGHIKLMNPQRSTVWY')  # 20 standard aatypes
if not set(sequence).issubset(aatypes):
  raise ValueError(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')
if full_length < MIN_SEQUENCE_LENGTH:
  raise ValueError(f'Input sequence is too short: {full_length} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')
if full_length > MAX_SEQUENCE_LENGTH:
  raise ValueError(f'Input sequence is too long: {full_length} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')

# Soft limit: longer inputs are allowed but may run out of GPU memory.
if full_length > 1400:
  print(f"WARNING: For a typical Google-Colab-GPU (16G) session, the max total length is 1400 residues. You are at {full_length}!")

In [None]:
#@title Search against genetic databases
#@markdown Once this cell has been executed, you will see
#@markdown statistics about the multiple sequence alignment
#@markdown (MSA) that will be used by AlphaFold. In particular,
#@markdown you’ll see how well each residue is covered by similar
#@markdown sequences in the MSA.

#@markdown ---
msa_method = "jackhmmer" #@param ["jackhmmer","mmseqs2","single_sequence","custom_a3m","precomputed"]
#@markdown - `jackhmmer` - default approach from DeepMind
#@markdown - `mmseqs2` - fast method from [ColabFold](https://github.com/sokrypton/ColabFold)
#@markdown - `single_sequence` - use single sequence input (not recommended, unless a *de novo* design and you don't expect to find any homologous sequences)
#@markdown - `custom_a3m` Upload custom MSA (a3m format)
#@markdown - `precomputed` If you have previously run this notebook and saved the results,
#@markdown you can skip this step by uploading
#@markdown the previously generated `prediction/msa.pickle`

# prediction directory
output_dir = 'prediction'
os.makedirs(output_dir, exist_ok=True)

# tmp directory; cache files are keyed by the SHA-1 of the query sequence
os.makedirs('tmp', exist_ok=True)
prefix = os.path.join('tmp', hashlib.sha1(sequence.encode()).hexdigest())

# --- Search against genetic databases ---
# Query written as FASTA (header 'query') for tools that read it from disk.
with open('target.fasta', 'wt') as f:
  f.write(f'>query\n{sequence}')

if msa_method == "precomputed":
  print("upload precomputed pickled msa from previous run")
  pickled_msa_dict = files.upload()
  # NOTE(review): unpickling can execute code embedded in the file; only
  # upload msa.pickle files produced by your own earlier runs.
  msas_dict = pickle.loads(pickled_msa_dict[list(pickled_msa_dict.keys())[0]])
  msas, deletion_matrices = (msas_dict[k] for k in ['msas', 'deletion_matrices'])

elif msa_method == "mmseqs2":
  msa, deletion_matrix = parsers.parse_a3m(run_mmseqs2(sequence, prefix, filter=True))
  msas, deletion_matrices = [msa], [deletion_matrix]

elif msa_method == "single_sequence":
  msas = [[sequence]]
  deletion_matrices = [[[0]*len(sequence)]]

elif msa_method == "custom_a3m":
  print("upload custom a3m")
  msa_dict = files.upload()
  lines = msa_dict[list(msa_dict.keys())[0]].decode().splitlines()
  # drop NUL bytes, empty lines and '#' comment lines before parsing
  a3m_lines = []
  for line in lines:
    line = line.replace("\x00", "")
    if len(line) > 0 and not line.startswith('#'):
      a3m_lines.append(line)
  msa, deletion_matrix = parsers.parse_a3m("\n".join(a3m_lines))
  msas, deletion_matrices = [msa], [deletion_matrix]

  # The first a3m entry must be the query. A mismatch would otherwise only
  # surface later as a confusing shape error, so stop here.
  if not msa or len(msa[0]) != len(sequence):
    first_len = len(msa[0]) if msa else 0
    raise ValueError(f"ERROR: the length of msa ({first_len}) does not match input sequence ({len(sequence)})")

else:
  # Run the jackhmmer search against chunks of the genetic databases (the
  # databases don't fit in the Colab ramdisk).
  msas, deletion_matrices = run_jackhmmer(sequence, prefix)

full_msa = [seq for db_msa in msas for seq in db_msa]

# save MSA as pickle (re-usable later via msa_method="precomputed")
with open(os.path.join(output_dir, "msa.pickle"), "wb") as handle:
  pickle.dump({"msas": msas, "deletion_matrices": deletion_matrices}, handle)

# deduplicate (keeps first occurrence order)
deduped_full_msa = list(dict.fromkeys(full_msa))
total_msa_size = len(deduped_full_msa)
print(f'\n{total_msa_size} Sequences Found in Total\n')

msa_arr = np.array([list(seq) for seq in deduped_full_msa])
num_alignments, num_res = msa_arr.shape

plt.figure(figsize=(8,5),dpi=100)
plt.title("Sequence coverage")
# per-sequence identity to the query
seqid = (np.array(list(sequence)) == msa_arr).mean(-1)
# ascending identity + origin='lower' puts the most similar sequences on top
seqid_sort = seqid.argsort()
non_gaps = (msa_arr != "-").astype(float)
non_gaps[non_gaps == 0] = np.nan
plt.imshow(non_gaps[seqid_sort]*seqid[seqid_sort,None],
           interpolation='nearest', aspect='auto',
           cmap="rainbow_r", vmin=0, vmax=1, origin='lower')
# black line: number of non-gap sequences at each position
plt.plot((msa_arr != "-").sum(0), color='black')
plt.xlim(-0.5,msa_arr.shape[1]-0.5)
plt.ylim(-0.5,msa_arr.shape[0]-0.5)
plt.colorbar(label="Sequence identity to query",)
plt.xlabel("Positions")
plt.ylabel("Sequences")
plt.show()

In [None]:
#@title Run AlphaFold 

#@markdown ---
relax_all = False #@param {type:"boolean"}
turbo = False #@param {type:"boolean"}
#@markdown - `relax_all` If disabled, only top ranked model is amber-relaxed.
#@markdown - `turbo` mode compiles a single model then swaps out params to speedup calculation. Warning, this option is experimental, and not extensively tested.


# --- Run the model ---
# NOTE(review): ascii_uppercase is not referenced in the visible notebook; possibly unused.
from string import ascii_uppercase

def _placeholder_template_feats(num_templates_, num_res_):
  return {
      'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
      'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
      'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37], np.float32),
      'template_domain_names': np.zeros([num_templates_], np.float32),
      'template_sum_probs': np.zeros([num_templates_], np.float32),
  }

num_templates = 0
num_res = len(sequence)

# For homooligomers each MSA is repeated once per chain copy: the copy sits at
# that chain's offset and is padded with gaps elsewhere, so every row spans
# the concatenated sequence.
if homooligomer == 1:
  msas_mod = msas
  deletion_matrices_mod = deletion_matrices
else:
  msas_mod = []
  deletion_matrices_mod = []
  for o in range(homooligomer):
    L = num_res * o
    R = num_res * (homooligomer-(o+1))
    for msa, deletion_matrix in zip(msas, deletion_matrices):
      msas_mod.append(["-"*L+seq+"-"*R for seq in msa])
      deletion_matrices_mod.append([[0]*L+mtx+[0]*R for mtx in deletion_matrix])


feature_dict = {}
feature_dict.update(pipeline.make_sequence_features(sequence*homooligomer, 'test', num_res*homooligomer))
feature_dict.update(pipeline.make_msa_features(msas_mod, deletion_matrices=deletion_matrices_mod))
feature_dict.update(_placeholder_template_feats(num_templates, num_res*homooligomer))

# Minkyung's code
# add big enough number to residue index to indicate chain breaks
idx_res = feature_dict['residue_index']
L_prev = 0
Ls = [num_res]*homooligomer
for L_i in Ls[:-1]:
  idx_res[L_prev+L_i:] += 200
  L_prev += L_i
feature_dict['residue_index'] = idx_res

plddts = {}
pae_outputs = {}
unrelaxed_proteins = {}

# model_4 runs first: in turbo mode it is the model that gets compiled.
model_names = ['model_4', 'model_1', 'model_2', 'model_3', 'model_5']

# Progress units: one per model, plus one per relaxation actually performed.
total = len(model_names)
if use_amber_relax:
  total += len(model_names) if relax_all else 1

with tqdm.notebook.tqdm(total=total, bar_format=TQDM_BAR_FORMAT) as pbar:
  for model_name in model_names:
    # Show the model actually running (ranked names are assigned later).
    pbar.set_description(f'Running {model_name}')

    cfg = config.model_config(model_name+"_ptm")
    params = data.get_model_haiku_params(model_name+"_ptm", './alphafold/data')
    if turbo:
      #####################################################
      # load models or params
      if model_name == "model_4":
        # define model and process features (compiled once, reused below)
        model_runner = model.RunModel(cfg, params)
        processed_feature_dict = model_runner.process_features(feature_dict,random_seed=0)
      else:
        # swap params; NOTE(review): only keys present in model_4's params are copied
        for k in model_runner.params.keys():
          model_runner.params[k] = params[k]
      prediction_result = model_runner.predict(processed_feature_dict)
      # cleanup to save memory
      if model_name == "model_5": del model_runner
      del params
      #####################################################
    else:
      #####################################################
      model_runner = model.RunModel(cfg, params)

      processed_feature_dict = model_runner.process_features(feature_dict,random_seed=0)
      prediction_result = model_runner.predict(processed_feature_dict)

      # cleanup to save memory
      del params
      del model_runner
      #####################################################

    # Keep the PAE outputs and the per-residue pLDDT confidence metrics.
    pae_outputs[model_name] = (
        prediction_result['predicted_aligned_error'],
        prediction_result['max_predicted_aligned_error']
    )
    plddts[model_name] = prediction_result['plddt']

    # Set the b-factors to the per-residue plddt.
    final_atom_mask = prediction_result['structure_module']['final_atom_mask']
    b_factors = prediction_result['plddt'][:, None] * final_atom_mask
    unrelaxed_protein = protein.from_prediction(processed_feature_dict,
                                                prediction_result,
                                                b_factors=b_factors)
    unrelaxed_proteins[model_name] = unrelaxed_protein

    # Delete unused outputs to save memory.
    del prediction_result
    pbar.update(n=1)

  # Rank models by mean pLDDT, best first.
  model_rank = list(plddts.keys())
  model_rank = [model_rank[i] for i in np.argsort([plddts[x].mean() for x in model_rank])[::-1]]

  if use_amber_relax:
    pbar.set_description('AMBER relaxation')
    amber_relaxer = relax.AmberRelaxation(
        max_iterations=0,
        tolerance=2.39,
        stiffness=10.0,
        exclude_residues=[],
        max_outer_iterations=20)

  # Write out the predictions, named by rank (model_1 = best).
  for n, name in enumerate(model_rank):
    pred_output_path = os.path.join(output_dir, f'model_{n+1}_unrelaxed.pdb')
    unrelaxed_protein = unrelaxed_proteins[name]
    pdb_lines = protein.to_pdb(unrelaxed_protein)
    with open(pred_output_path, 'w') as f:
      f.write(pdb_lines)

    if use_amber_relax and (relax_all or n == 0):
      relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
      pred_output_path = os.path.join(output_dir, f'model_{n+1}_relaxed.pdb')
      with open(pred_output_path, 'w') as f:
        f.write(relaxed_pdb_lines)
      # Count only relaxations that actually ran, matching `total` above.
      pbar.update(n=1)

############################################################
############################################################

for n, rank in enumerate(model_rank):
  print(f"model_{n+1} {plddts[rank].mean()}")
  pae, max_pae = pae_outputs[rank]
  # Save predicted aligned error in the same format as the AF EMBL DB
  # (every model run above stores a PAE entry, so no existence check is needed).
  pae_output_path = os.path.join(output_dir, f'predicted_aligned_error_{n+1}.json')
  rounded_errors = np.round(pae.astype(np.float64), decimals=1)
  indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1
  indices_1 = indices[0].flatten().tolist()
  indices_2 = indices[1].flatten().tolist()
  pae_data = json.dumps([{
      'residue1': indices_1,
      'residue2': indices_2,
      'distance': rounded_errors.flatten().tolist(),
      'max_predicted_aligned_error': max_pae.item()
  }],
                        indent=None,
                        separators=(',', ':'))
  with open(pae_output_path, 'w') as f:
    f.write(pae_data)

In [None]:
#@title Display 3D structure {run: "auto"}
# Colab form fields (run: "auto" should re-run this cell when a field changes).
# model_num is the 1-based rank by mean pLDDT (1 = best model).
model_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

def plot_plddt_legend():
  """Draw a standalone legend for the pLDDT colour bands used by show_pdb.

  Returns:
    The `matplotlib.pyplot` module, so callers can chain `.show()`.
  """
  labels = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
  swatches = ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]
  plt.figure(figsize=(1,0.1),dpi=100)
  # Zero-size bars exist only to give the legend one coloured handle per label.
  for swatch in swatches:
    plt.bar(0, 0, color=swatch)
  legend_style = dict(frameon=False, loc='center', ncol=6,
                      handletextpad=1, columnspacing=1, markerscale=0.5)
  plt.legend(labels, **legend_style)
  plt.axis(False)
  return plt

def plot_confidence(model_num=1):
  """Plot per-residue pLDDT and the PAE matrix for one ranked model.

  Args:
    model_num: 1-based rank by mean pLDDT (1 = best model).

  Returns:
    The `matplotlib.pyplot` module, so callers can chain `.show()`.
  """
  name = model_rank[model_num-1]
  plt.figure(figsize=(10,3),dpi=100)
  #########################################
  plt.subplot(1,2,1); plt.title('Predicted lDDT')
  plt.plot(plddts[name])
  # vertical lines at chain boundaries (each copy spans len(sequence) residues)
  for n in range(homooligomer+1):
    x = n*(len(sequence))
    plt.plot([x,x],[0,100],color="black")
  plt.ylabel('plDDT')
  plt.xlabel('position')
  #########################################
  plt.subplot(1,2,2);plt.title('Predicted Aligned Error')
  pae, _ = pae_outputs[name]
  plt.imshow(pae, cmap="bwr",vmin=0,vmax=30)
  plt.colorbar()
  plt.xlabel('Scored residue')
  plt.ylabel('Aligned residue')
  #########################################
  return plt

def show_pdb(model_num=1, show_sidechains=False, show_mainchains=False, color="lDDT"):
  """Build a py3Dmol view of a ranked, unrelaxed prediction.

  Args:
    model_num: 1-based rank; reads ``{output_dir}/model_{model_num}_unrelaxed.pdb``.
    show_sidechains: draw side chains as sticks (a CA sphere for GLY).
    show_mainchains: draw backbone atoms C, O, N, CA as sticks.
    color: "lDDT" colours by B-factor (the pLDDT written into the PDB),
      "rainbow" uses the spectrum colouring, "chain" one colour per copy.

  Returns:
    The py3Dmol view; call `.show()` to render it.
  """
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  pred_output_path = os.path.join(output_dir, f'model_{model_num}_unrelaxed.pdb')
  with open(pred_output_path, 'r') as handle:
    view.addModel(handle.read(), 'pdb')
  if color == "lDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # NOTE(review): assumes chain IDs A, B, ... per copy, as written by the
    # multi-chain protein.py patch -- confirm.
    chain_colors = ["lime","cyan","magenta","yellow","salmon","white","blue","orange"]
    for chain, chain_color in zip("ABCDEFGH"[:homooligomer], chain_colors):
      view.setStyle({'chain':chain},{'cartoon': {'color':chain_color}})
  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view

# Render the chosen model, its pLDDT legend (lDDT colouring only), then the confidence plots.
show_pdb(model_num=model_num, show_sidechains=show_sidechains,
         show_mainchains=show_mainchains, color=color).show()
if color == "lDDT":
  plot_plddt_legend().show()
plot_confidence(model_num=model_num).show()

In [None]:
#@title Download prediction

#@markdown Once this cell has been executed, a zip-archive with
#@markdown the obtained prediction will be automatically downloaded
#@markdown to your computer.

# add settings file, recording every option that affects the archived results
settings_path = os.path.join(output_dir,"settings.txt")
with open(settings_path, "w") as text_file:
  text_file.write(f"sequence={sequence}\n")
  text_file.write(f"msa_method={msa_method}\n")
  text_file.write(f"homooligomer={homooligomer}\n")
  text_file.write(f"use_amber_relax={use_amber_relax}\n")
  text_file.write(f"turbo={turbo}\n")
  text_file.write(f"use_templates=False\n")
  # relax_all decides whether only the top model or all models were relaxed.
  text_file.write(f"relax_all={relax_all}\n")

# --- Download the predictions ---
# Zip the whole prediction directory (PDBs, PAE json, msa.pickle, settings.txt)
# and hand the archive to the browser via Colab's files.download.
!zip -q -r {output_dir}.zip {output_dir}
files.download(f'{output_dir}.zip')