<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/beta/examples/AlphaFold_single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AlphaFold - single sequence input
- WARNING - For DEMO and educational purposes only. 
- For natural proteins you often need more than a single sequence to accurately predict the structure. See [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple-sequence-alignment. That being said, this notebook could be useful for evaluating *de novo* designed proteins and learning the idealized principles of proteins.

### Tips and Instructions
- Patience... The first time you run the cell below it will take about 1 minute to set up; after that it should run in seconds (after each change).
- click the little ▶ play icon to the left of each cell below.
- In the 3D display, hover the mouse over an amino acid to see its name and position number.
- Use "/" to specify chain breaks (e.g. sequence="AAA/AAA").


In [None]:
#@title Enter the amino acid sequence to fold ⬇️

###############################################################################
###############################################################################
#@title Setup
# import libraries
import os,sys,re

if "SETUP_DONE" not in dir():
  from IPython.utils import io
  from IPython.display import HTML
  import tensorflow as tf
  import jax
  import jax.numpy as jnp
  import numpy as np
  import matplotlib
  from matplotlib import animation
  import matplotlib.pyplot as plt
  import tqdm.notebook
  TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

  with io.capture_output() as captured:
    if not os.path.isdir("af_backprop"):
      %shell git clone https://github.com/sokrypton/af_backprop.git
      %shell pip -q install biopython dm-haiku ml-collections py3Dmol
      %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py
    if not os.path.isdir("params"):
      %shell mkdir params
      %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params

  # configure which device to use
  try:
    # check if TPU is available
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print('Running on TPU')
    DEVICE = "tpu"
  except:
    if jax.local_devices()[0].platform == 'cpu':
      print("WARNING: no GPU detected, will be using CPU")
      DEVICE = "cpu"
    else:
      print('Running on GPU')
      DEVICE = "gpu"
      # disable GPU on tensorflow
      tf.config.set_visible_devices([], 'GPU')

  # import libraries
  sys.path.append('af_backprop')

  SETUP_DONE = True

if "LIBRARY_IMPORTED" not in dir():
  # Project imports; these rely on the af_backprop checkout and colabfold.py
  # being available on the import path.
  from utils import update_seq, update_aatype, get_plddt, get_pae
  import colabfold as cf
  from alphafold.common import protein
  from alphafold.data import pipeline
  from alphafold.model import data, config, model
  from alphafold.common import residue_constants

  # Record success so the guard above actually skips this block on re-runs
  # (the flag was tested but never assigned).
  LIBRARY_IMPORTED = True

# First execution only: create the state that persists between cell runs.
#   current_seq - last sequence string that was processed ("" = none yet)
#   r           - index of the last completed recycle (-1 = none yet)
#   max_length  - padded length currently in use (-1 = no model built yet)
if "current_seq" not in dir():
  current_seq, r, max_length = "", -1, -1

# collect user inputs
sequence = 'GGGGGGGGGGGGGGGGGGGG' #@param {type:"string"}
recycles = 0 #@param ["0", "1", "2", "3", "6", "12", "24", "48"] {type:"raw"}
# Upper-case the input and keep only letters A-Z plus the chain-break markers
# "/" and ":". The pattern is a raw string: "\/" and "\:" are invalid escape
# sequences in a normal literal and trigger a warning on modern Python.
ori_sequence = re.sub(r"[^A-Z/:]", "", sequence.upper())
# length of each chain; ":" is treated the same as "/"
Ls = [len(s) for s in ori_sequence.replace(":","/").split("/")]
# residues only, with the chain-break markers removed
sequence = re.sub(r"[^A-Z]","",ori_sequence)
length = len(sequence)

# Rebuild the model only when the sequence no longer fits the padded length,
# or when it is more than 25 residues shorter than it; otherwise the existing
# runner is reused.
if not 0 <= max_length - length <= 25:
  max_length = length + 25
  runner, I = setup_model(max_length)

if ori_sequence != current_seq:
  # a different sequence: discard every result cached from earlier recycles
  outs, positions, plddts, paes = [], [], [], []
  r = -1

  # encode each residue letter as an integer id (letters missing from
  # restype_order map to 0), then right-pad with -1 up to max_length
  seq = np.pad(
      np.array([residue_constants.restype_order.get(aa, 0) for aa in sequence]),
      [0, max_length - length],
      constant_values=-1)

  # update inputs, restart recycle from all-zero "prev" arrays
  I.update({
      "seq": seq,
      "length": length,
      "prev": {
          'prev_msa_first_row': np.zeros([max_length, 256]),
          'prev_pair': np.zeros([max_length, max_length, 128]),
          'prev_pos': np.zeros([max_length, 37, 3]),
      },
  })

  # overwrite the residue index in place with the result of cf.chain_break
  # (given the per-chain lengths Ls and length=32)
  I["inputs"]['residue_index'][:] = cf.chain_break(np.arange(max_length), Ls, length=32)

  # remember what was processed so an unchanged sequence skips this reset
  current_seq = ori_sequence

# run for defined number of recycles
with tqdm.notebook.tqdm(total=(recycles+1), bar_format=TQDM_BAR_FORMAT) as pbar:
  # Recycles already computed for the current sequence (r is the index of the
  # last completed one) are credited to the bar up front, not recomputed.
  p = min(r + 1, recycles + 1)
  if p > 0:
    pbar.update(p)
  while r < recycles:
    O = runner(I)
    # Convert every leaf of the output tree to a numpy array.
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias,
    # which has been removed from recent JAX releases.
    O = jax.tree_util.tree_map(np.asarray, O)
    # keep the first `length` entries only, dropping the padded tail
    positions.append(O["final_atom_positions"][:length])
    plddts.append(O["plddt"][:length])
    paes.append(O["pae"][:length,:length])
    # feed this pass's "prev" output back in as input to the next recycle
    I["prev"] = O["prev"]
    outs.append(O)
    r += 1
    pbar.update(1)

#@markdown #### Display options
color = "confidence" #@param ["chain", "confidence", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# the "confidence" form choice is passed on to colabfold as "lDDT"
if color == "confidence":
  color = "lDDT"

# write the output of the requested recycle to disk, then display that file
print(f"plotting prediction at recycle={recycles}")
save_pdb(outs[recycles], "out.pdb")

v = cf.show_pdb(
    "out.pdb", show_sidechains, show_mainchains, color,
    color_HP=True, size=(800,480), Ls=Ls)

# JavaScript hover callbacks: the first adds a "resn:resi" label to the atom
# under the mouse, the second removes that label again
v.setHoverable(
    {}, True,
    '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel("      "+atom.resn+":"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',
    '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')
v.show()

# the legend is only shown for the lDDT colouring
if color == "lDDT":
  cf.plot_plddt_legend().show()

# add confidence plots for the same recycle (pLDDT is multiplied by 100)
cf.plot_confidence(plddts[recycles]*100, paes[recycles], Ls=Ls).show()

In [None]:
#@title Animate
#@markdown - Animate trajectory if more than 0 recycle(s)
# Turn the per-recycle lists into arrays; pLDDT is multiplied by 100
# (presumably rescaling 0-1 to 0-100 -- confirm against make_animation).
traj_positions = np.asarray(positions)
traj_plddts = np.asarray(plddts) * 100.0
# must stay the last expression of the cell so the notebook renders the HTML
HTML(make_animation(traj_positions, traj_plddts, Ls=Ls))