<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/Alphafold_single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AlphaFold_single

In [None]:
#@title Install software
#@markdown Please execute this cell by pressing the _Play_ button 
#@markdown on the left.

# setup device
# Sets the global DEVICE to "tpu", "gpu" or "cpu" and prints which one is used.
import os
import tensorflow as tf
import jax

try:
  # check if TPU is available
  # (presumably setup_tpu() raises when no Colab TPU is attached — confirm)
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
  print('Running on TPU')
  DEVICE = "tpu"
except Exception:
  # `except Exception` rather than a bare `except:` so that KeyboardInterrupt
  # and SystemExit still propagate. Fall back to the first device JAX reports.
  if jax.local_devices()[0].platform == 'cpu':
    print("WARNING: no GPU detected, will be using CPU")
    DEVICE = "cpu"
  else:
    print('Running on GPU')
    DEVICE = "gpu"
    # disable GPU on tensorflow
    # (TensorFlow is only used for feature processing; presumably hiding the
    # GPU keeps it from claiming memory that JAX needs — confirm)
    tf.config.set_visible_devices([], 'GPU')
    
from IPython.utils import io
import subprocess
import tqdm.notebook

GIT_REPO = 'https://github.com/deepmind/alphafold'
# AlphaFold parameter archive and the directory it is unpacked into.
SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar'
PARAMS_DIR = './alphafold/data/params'
# NOTE(review): PARAMS_PATH is not used in the visible cells.
PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# if not already installed
# NOTE(review): "installed" is judged only by whether ./alphafold exists as a
# directory. If a run fails after the clone, the next run skips every
# remaining step; delete ./alphafold to force a clean reinstall.
try:
  # The bar counts 50 units in total: 4 (clone + patches) + 4 (pip install)
  # + 14+27 (parameters) + 1 (py3dmol), or all 50 at once when skipping.
  with tqdm.notebook.tqdm(total=50, bar_format=TQDM_BAR_FORMAT) as pbar:
    # Command output is captured (hidden) and only printed on failure below.
    with io.capture_output() as captured:
      if not os.path.isdir("alphafold"):
        # Only matters if `alphafold` exists as a non-directory entry.
        %shell rm -rf alphafold
        %shell git clone {GIT_REPO} alphafold
        # Pin a specific commit (presumably the one the patches below target).
        %shell (cd alphafold; git checkout 1e216f93f06aa04aa699562f504db1d02c3b704c --quiet)

        # colabfold patches
        # colabfold.py / pairmsa.py go to the working directory (imported
        # later as `colabfold`); the .patch files go to tmp/.
        %shell mkdir --parents tmp
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/pairmsa.py
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/protein.patch -P tmp/
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/config.patch -P tmp/
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/model.patch -P tmp/
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/modules.patch -P tmp/

        # Apply multi-chain patch from Lim Heo @huhlim
        %shell patch -u alphafold/alphafold/common/protein.py -i tmp/protein.patch
        
        # Apply patch to dynamically control number of recycles (idea from Ryan Kibler)
        %shell patch -u alphafold/alphafold/model/model.py -i tmp/model.patch
        %shell patch -u alphafold/alphafold/model/modules.py -i tmp/modules.patch
        %shell patch -u alphafold/alphafold/model/config.py -i tmp/config.patch
        pbar.update(4)

        # Install the patched checkout as a package.
        %shell pip3 install ./alphafold
        pbar.update(4)
      
        # speedup from kaczmarj
        # Stream the parameter tarball straight into tar (no temp file on disk).
        %shell mkdir --parents "{PARAMS_DIR}"
        %shell curl -fsSL "{SOURCE_URL}" | tar x -C "{PARAMS_DIR}"
        pbar.update(14+27)

        # Install py3dmol.
        %shell pip install py3dmol
        pbar.update(1)
      else:
        pbar.update(50)

except subprocess.CalledProcessError:
  # Presumably raised by %shell when a command exits non-zero: show the hidden
  # output so the failing step is visible, then re-raise.
  print(captured)
  raise

########################################################################################
# --- Python imports ---
import colabfold as cf
import sys

from google.colab import files
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol

from alphafold.model import model
from alphafold.model import config
from alphafold.model import data
from alphafold.data import pipeline
from alphafold.common import protein

from alphafold.model.tf import shape_placeholders

def make_fixed_size(feat, shape_schema, num_res=100, msa_cluster_size=1, extra_msa_size=1, num_templates=0):
  """Zero-pad processed features to fixed sizes along their placeholder axes.

  Padding every input to the same residue count (see MAX_LEN in the notebook)
  gives the model identically shaped inputs across different sequences.

  Args:
    feat: mapping of feature name -> array (as returned by
      ``model_runner.process_features``). It is not modified.
    shape_schema: mapping of feature name -> list of per-axis placeholders
      (``shape_placeholders.*`` or ``None``); must have the feature's rank.
    num_res, msa_cluster_size, extra_msa_size, num_templates: target sizes
      for the NUM_RES / NUM_MSA_SEQ / NUM_EXTRA_SEQ / NUM_TEMPLATES axes.
      An axis whose placeholder is not one of these, or whose target is 0,
      keeps its current size.

  Returns:
    A new dict of numpy arrays, zero-padded at the end of each axis.
    ``extra_cluster_assignment`` is passed through unpadded.

  Raises:
    ValueError: if a feature's rank differs from its schema, or a feature is
      already larger than its target size along some axis.
  """
  pad_size_map = {
      shape_placeholders.NUM_RES: num_res,
      shape_placeholders.NUM_MSA_SEQ: msa_cluster_size,
      shape_placeholders.NUM_EXTRA_SEQ: extra_msa_size,
      shape_placeholders.NUM_TEMPLATES: num_templates,
  }
  fixed = {}
  for k, v in feat.items():
    # Don't transfer this to the accelerator: it is left unpadded and
    # returned as-is.
    if k == 'extra_cluster_assignment':
      fixed[k] = np.asarray(v)
      continue
    shape = list(v.shape)
    schema = shape_schema[k]
    # Raise rather than assert so the check survives `python -O`.
    if len(shape) != len(schema):
      raise ValueError(
          f'Rank mismatch between shape and shape schema for {k}: '
          f'{shape} vs {schema}')
    # `or s1`: an unknown placeholder (e.g. None) or a target of 0 keeps
    # the axis at its current size.
    pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)]
    for axis, (target, actual) in enumerate(zip(pad_size, shape)):
      if actual > target:
        # tf.pad would otherwise fail with an opaque negative-padding error.
        raise ValueError(
            f'Feature {k} has size {actual} along axis {axis}, which exceeds '
            f'the fixed size {target}; increase the corresponding size argument.')
    padding = [(0, target - actual) for target, actual in zip(pad_size, shape)]
    if padding:  # empty only for rank-0 features
      v = tf.pad(v, padding, name=f'pad_to_fixed_{k}')
      v.set_shape(pad_size)
    fixed[k] = np.asarray(v)
  return fixed

def parse_results(prediction_result, processed_feature_dict):
  """Summarise one model prediction into structure, confidence and contact data.

  Args:
    prediction_result: output dict of ``model_runner.predict``.
    processed_feature_dict: the features the prediction was run on; used to
      build the unrelaxed protein.

  Returns:
    dict with ``unrelaxed_protein``, ``plddt`` (per residue), ``pLDDT`` (its
    mean), ``dists`` (lower edge of the most likely distogram bin for each
    residue pair), ``adj`` (probability mass in bins whose lower edge is below
    8, presumably Angstrom), plus ``pae`` and ``pTMscore`` when the result
    contains ``ptm``.
  """
  plddt = prediction_result['plddt']

  # B-factors: per-residue pLDDT broadcast across atoms, zeroed where the
  # structure module's atom mask is 0.
  atom_mask = prediction_result['structure_module']['final_atom_mask']
  b_factors = plddt[:, None] * atom_mask

  distogram = prediction_result["distogram"]
  logits = distogram["logits"]
  # Prepending 0 gives one lower edge per bin, matching the logits' last axis.
  lower_edges = jax.numpy.append(0, distogram["bin_edges"])
  dists = lower_edges[logits.argmax(-1)]
  close_bins = lower_edges < 8
  adj = jax.nn.softmax(logits)[:, :, close_bins].sum(-1)

  summary = dict(
      unrelaxed_protein=protein.from_prediction(
          processed_feature_dict, prediction_result, b_factors=b_factors),
      plddt=plddt,
      pLDDT=plddt.mean(),
      dists=dists,
      adj=adj,
  )

  if "ptm" in prediction_result:
    summary["pae"] = prediction_result['predicted_aligned_error']
    summary["pTMscore"] = prediction_result['ptm']
  return summary

# Build the model only once per runtime: if `model_runner` already exists
# (this cell was run before), rebuilding it and reloading parameters is
# skipped. MAX_LEN and crop_feats are likewise only set on the first run.
if "model_runner" not in dir():
  name = "model_5_ptm"
  model_config = config.model_config(name)
  # Single-sequence input: one MSA cluster and one extra-MSA row.
  model_config.data.eval.max_msa_clusters = 1
  model_config.data.common.max_extra_msa = 1
  # Recycle count, set on both the model and the data config.
  model_config.model.num_recycle = 2
  model_config.data.common.num_recycle = 2

  # Parameters are read from ./alphafold/data (presumably the params/ subdir
  # filled by the install cell — confirm).
  model_runner = model.RunModel(model_config, data.get_model_haiku_params(name,'./alphafold/data'))

  eval_cfg = model_runner.config.data.eval
  # Shape schema for make_fixed_size: each feature's placeholder list with a
  # leading None, which make_fixed_size leaves at its current size
  # (presumably the leading axis added by process_features — confirm).
  crop_feats = {k:[None]+v for k,v in dict(eval_cfg.feat).items()}   

  # Residue count that features are padded to; the sequence cell raises it
  # when a longer sequence is entered.
  MAX_LEN = 100  

In [None]:
#@title Enter the amino acid sequence to fold ⬇️
import re

# define sequence
sequence = 'GGGGGGGGGGGGGGGGGGG' #@param {type:"string"}
# Keep letters only; whitespace, digits, gaps etc. are dropped.
sequence = re.sub("[^A-Z]", "", sequence.upper())
if not sequence:
  raise ValueError("sequence is empty after removing non-letter characters")

# Single-sequence mode: the "MSA" is just the query, with no deletions.
msas, deletion_matrices = [],[]
msas.append([sequence])
deletion_matrices.append([[0]*len(sequence)])

num_res = len(sequence)
feature_dict = {**pipeline.make_sequence_features(sequence, 'test', num_res),
                **pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices)}

# Features are padded to MAX_LEN residues; grow it for longer sequences.
MAX_LEN = max(MAX_LEN, num_res)

# process features
processed_feature_dict = model_runner.process_features(feature_dict, random_seed=0)
processed_feature_dict = make_fixed_size(processed_feature_dict, crop_feats, num_res=MAX_LEN)

# predict
prediction_result, (r, t) = model_runner.predict(processed_feature_dict)

# save results
outs = parse_results(prediction_result, processed_feature_dict)
# The prediction ran on inputs padded to MAX_LEN residues, so the per-residue
# outputs include padding. Trim them to the real sequence so that the mean
# pLDDT is not diluted by padded positions.
outs["plddt"] = outs["plddt"][:num_res]
outs["pLDDT"] = outs["plddt"].mean()
outs["dists"] = outs["dists"][:num_res, :num_res]
outs["adj"] = outs["adj"][:num_res, :num_res]
if "pae" in outs:
  outs["pae"] = outs["pae"][:num_res, :num_res]
# NOTE(review): pTMscore comes straight from the model, which ran on the
# padded inputs — confirm whether it accounts for padding.
outs.update({"recycles":r, "tol":t})
with open("out.pdb", 'w') as f:
  f.write(protein.to_pdb(outs["unrelaxed_protein"]))

In [None]:
#@title Display 3D structure {run: "auto"}
color = "chain" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# 3D view of the saved model.
viewer = cf.show_pdb("out.pdb", show_sidechains, show_mainchains, color,
                     color_HP=True, size=(800,480))
viewer.show()

# Colouring by lDDT comes with a legend.
if color == "lDDT":
  cf.plot_plddt_legend().show()

# Plot confidence for the real residues only (positions past num_res are
# padding added before prediction).
plddt_real = outs["plddt"][:num_res]
if "pae" not in outs:
  confidence_plot = cf.plot_confidence(plddt_real)
else:
  confidence_plot = cf.plot_confidence(plddt_real, outs["pae"][:num_res, :num_res])
confidence_plot.show()