<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/beta/af/examples/af_single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AlphaFold - single sequence input
- WARNING - For DEMO and educational purposes only.
- For natural proteins you often need more than a single sequence to accurately predict the structure. See [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple-sequence-alignment. That being said, this notebook could be useful for evaluating *de novo* designed proteins and learning the idealized principles of proteins.

### Tips and Instructions
- Patience... The first time you run the cell below it will take about 1 minute to set up; after that it should run in seconds (after each change).
- click the little ▶ play icon to the left of each cell below.
- In the 3D display, hover the mouse over an amino acid to see its name and position number.
- use "/" to specify chain breaks (e.g. sequence="AAA/AAA")


In [None]:
#@title Enter the amino acid sequence to fold ⬇️

###############################################################################
###############################################################################
#@title Setup
# import libraries
import os,sys,re,time

if "SETUP_DONE" not in dir():
  # One-time environment setup. SETUP_DONE is set at the end of this block so
  # that re-running the cell skips the imports, downloads and device probing.
  from IPython.utils import io
  from IPython.display import HTML
  import numpy as np
  import matplotlib
  from matplotlib import animation
  import matplotlib.pyplot as plt
  import tqdm.notebook
  TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

  # the "params" directory doubles as the marker that installation already ran
  if not os.path.isdir("params"):
    os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py")
    # get code
    print("installing ColabDesign...")
    # the trailing "&" runs the parameter download in the background so that
    # the pip install below can proceed at the same time; done.txt is touched
    # once the download/extract command chain has finished
    os.system("(mkdir params; apt-get install aria2 -qq; \
    aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \
    tar -xf alphafold_params_2021-07-14.tar -C params; \
    touch params/done.txt )&")
    #aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \
    #tar -xf alphafold_params_2022-12-06.tar -C params; \

    os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@beta")
    os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")

    # download params
    # block (polling every 5 seconds) until the background job signals completion
    if not os.path.isfile("params/done.txt"):
      print("downloading AlphaFold params...")
      while not os.path.isfile("params/done.txt"):
        time.sleep(5)

  # configure which device to use
  import jax
  # disable triton_gemm for jax versions > 0.3
  # (the check looks only at the second component of the version string)
  if int(jax.__version__.split(".")[1]) > 3:
    os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false"
  import jax.numpy as jnp
  try:
    # check if TPU is available
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print('Running on TPU')
    DEVICE = "tpu"
  except Exception:
    # Any failure to import or initialise the TPU helper means "no TPU": fall
    # back to whatever device JAX reports. Catching Exception rather than using
    # a bare except lets KeyboardInterrupt/SystemExit still stop the cell.
    if jax.local_devices()[0].platform == 'cpu':
      print("WARNING: no GPU detected, will be using CPU")
      DEVICE = "cpu"
    else:
      print('Running on GPU')
      DEVICE = "gpu"

  # import libraries
  sys.path.append('af_backprop')

  SETUP_DONE = True

if "LIBRARY_IMPORTED" not in dir():
  from colabdesign.af.loss import get_plddt, get_pae
  from colabdesign.af.prep import prep_input_features
  from colabdesign.af.inputs import update_seq, update_aatype
  from colabdesign.af.alphafold.common import protein
  from colabdesign.af.alphafold.model import data, config, model
  from colabdesign.af.alphafold.common import residue_constants
  from colabdesign.rf.utils import make_animation
  import py3Dmol
  import colabfold as cf

  # setup model
  # Build the network config and load the weights once; setup_model() below
  # reuses both model_params and model_runner.
  # NOTE(review): the config is built from "model_5_ptm" while the weights are
  # loaded for "model_2_ptm" -- presumably intentional; confirm before changing.
  cfg = config.model_config("model_5_ptm")
  # the model itself does no recycling; the notebook's run loop feeds each
  # output "prev" back in as the next input instead
  cfg.model.num_recycle = 0
  cfg.model.global_config.subbatch_size = None
  model_name="model_2_ptm"
  # NOTE(review): data_dir="." presumably resolves to the ./params directory
  # populated during setup -- confirm against get_model_haiku_params
  model_params = data.get_model_haiku_params(model_name=model_name,
                                             data_dir=".",
                                             fuse=True)
  model_runner = model.RunModel(cfg, model_params)

  def setup_model(max_len):
    '''Build a jit-compiled single-sequence predictor padded to max_len.

    Args:
      max_len: padded sequence length the input features are prepared for.

    Returns:
      (runner, state): `runner` is the jitted prediction function and `state`
      is the dict it consumes, pre-filled with "inputs", "params" and
      "length". The caller adds "seq" and "prev" before calling `runner`.
    '''
    inputs = prep_input_features(max_len)

    def runner(I):
      '''Run one forward pass.

      Reads I["inputs"], I["prev"], I["seq"], I["length"] and I["params"];
      returns a dict with the predicted atom positions and mask, plddt, pae,
      the new "prev" state, and the sequence/length/residue index used.
      '''
      # update sequence
      inputs = I["inputs"]
      inputs["prev"] = I["prev"]

      # one-hot encode over 20 classes and add a leading axis
      seq_oh = jax.nn.one_hot(I["seq"],20)[None]
      update_seq(seq_oh, inputs)
      update_aatype(seq_oh, inputs)

      # mask prediction
      # positions at or beyond I["length"] are padding: mask them out and
      # reset their residue index to 0
      mask = jnp.arange(inputs["residue_index"].shape[0]) < I["length"]
      inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask)
      inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask)
      inputs["residue_index"] = jnp.where(mask, inputs["residue_index"], 0)

      # get prediction
      # fixed PRNG key: the same key is passed to the model on every call
      key = jax.random.PRNGKey(0)
      outputs = model_runner.apply(I["params"], key, inputs)

      aux = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
             "final_atom_mask":outputs["structure_module"]["final_atom_mask"],
             "plddt":get_plddt(outputs),"pae":get_pae(outputs),
             "length":I["length"], "seq":I["seq"],
             "prev":outputs["prev"],
             "residue_idx":inputs["residue_index"]}
      return aux

    # "length" starts at the padded length. Use the max_len argument here, not
    # the notebook-level max_length global, so the function does not depend on
    # state defined outside it.
    return jax.jit(runner), {"inputs":inputs, "params":model_params, "length":max_len}

  def save_pdb(outs, filename):
    '''Write one prediction to `filename` in PDB format.

    Args:
      outs: output dict of a prediction run; the keys read are "residue_idx",
        "seq", "final_atom_positions", "final_atom_mask", "plddt" and "length".
      filename: path of the PDB file to (over)write.
    '''
    # residue index is shifted by +1 for the PDB record
    p = {"residue_index":outs["residue_idx"] + 1,
        "aatype":outs["seq"],
        "atom_positions":outs["final_atom_positions"],
        "atom_mask":outs["final_atom_mask"],
        "plddt":outs["plddt"]}
    # drop the padding: keep only the first outs["length"] entries of each array.
    # jax.tree_util.tree_map is used because the jax.tree_map alias is
    # deprecated and no longer available in recent JAX releases.
    p = jax.tree_util.tree_map(lambda x:x[:outs["length"]], p)
    # store plddt (scaled by 100) as the b-factor of every atom present in the mask
    b_factors = 100 * p.pop("plddt")[:,None] * p["atom_mask"]
    p = protein.Protein(**p,b_factors=b_factors)
    pdb_lines = protein.to_pdb(p)
    with open(filename, 'w') as f: f.write(pdb_lines)

  LIBRARY_IMPORTED = True

###############################################################################
###############################################################################

# initialize
# First run of the cell in this session: no sequence folded yet, no recycle
# computed yet (r = -1) and no model compiled yet (max_length = -1).
if "current_seq" not in dir():
  current_seq = ""
  r = -1
  max_length = -1

# collect user inputs
sequence = 'GGGGGGGGGG' #@param {type:"string"}
recycles = 0 #@param ["0", "1", "2", "3", "6", "12", "24", "48"] {type:"raw"}
# Keep only uppercase letters plus the chain-break markers "/" and ":".
# Written as a raw string without backslashes: "\/" and "\:" are invalid
# escape sequences in a normal string literal and trigger a warning.
ori_sequence = re.sub(r"[^A-Z/:]", "", sequence.upper())
# length of each chain, treating ":" the same as "/"
Ls = [len(s) for s in ori_sequence.replace(":","/").split("/")]
# residues only, with the chain-break markers removed
sequence = re.sub("[^A-Z]","",ori_sequence)
length = len(sequence)

# Avoid recompiling on small edits: keep the current compiled model while the
# sequence still fits in the padded length and is at most 20 residues shorter
# than it. Otherwise rebuild, padding the new length by 10 residues.
if length > max_length or (max_length - length) > 20:
  max_length = length + 10
  runner, I = setup_model(max_length)

# A changed sequence (or chain layout) invalidates everything accumulated so
# far: start a fresh trajectory and rebuild the padded model inputs.
if ori_sequence != current_seq:
  outs, positions, plddts, paes = [], [], [], []
  r = -1

  # pad sequence to max length
  # residues are integer-encoded via restype_order (letters it does not know
  # map to 0) and the padding positions are filled with -1
  seq = np.array([residue_constants.restype_order.get(aa,0) for aa in sequence])
  seq = np.pad(seq, [0, max_length - length], constant_values=-1)

  # update inputs, restart recycle
  # the recycling state starts from all-zero arrays
  I["seq"] = seq
  I["length"] = length
  I["prev"] = {
      'prev_msa_first_row': np.zeros([max_length, 256]),
      'prev_pair': np.zeros([max_length, max_length, 128]),
      'prev_pos': np.zeros([max_length, 37, 3]),
  }

  I["inputs"]["use_dropout"] = False
  # overwrite the residue index in place with one built by cf.chain_break
  # from the per-chain lengths in Ls
  I["inputs"]['residue_index'][:] = cf.chain_break(np.arange(max_length), Ls, length=32)
  current_seq = ori_sequence

# run for defined number of recycles
with tqdm.notebook.tqdm(total=(recycles+1), bar_format=TQDM_BAR_FORMAT) as pbar:
  # recycles already computed for this sequence (e.g. the cell was re-run):
  # advance the bar past them in one step instead of recomputing
  p = min(r+1, recycles+1)
  if p > 0:
    pbar.update(p)
  # compute only the missing recycles
  while r < recycles:
    O = runner(I)
    # Convert every leaf of the output to a NumPy array.
    # jax.tree_util.tree_map is used because the jax.tree_map alias is
    # deprecated and no longer available in recent JAX releases.
    O = jax.tree_util.tree_map(np.asarray, O)
    # keep only the real (unpadded) part of each per-residue output
    positions.append(O["final_atom_positions"][:length])
    plddts.append(O["plddt"][:length])
    paes.append(O["pae"][:length,:length])
    # feed this pass's "prev" state into the next pass
    I["prev"] = O["prev"]
    outs.append(O)
    r += 1
    pbar.update(1)

#@markdown #### Display options
color = "confidence" #@param ["chain", "confidence", "rainbow"]
# the form option "confidence" is passed to the viewer as "lDDT"
if color == "confidence":
  color = "lDDT"
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

print(f"plotting prediction at recycle={recycles}")
save_pdb(outs[recycles], "out.pdb")

# JavaScript callbacks for the viewer: add a "resn:resi" label to the atom
# under the mouse, and remove that label again when the mouse leaves it
hover_js = '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel("      "+atom.resn+":"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}'''
unhover_js = '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}'''

v = cf.show_pdb("out.pdb", show_sidechains, show_mainchains, color,
                color_HP=True, size=(800,480), Ls=Ls)
v.setHoverable({}, True, hover_js, unhover_js)
v.show()

# the colour legend is only shown for the confidence colouring
if color == "lDDT":
  legend = cf.plot_plddt_legend()
  legend.show()

# add confidence plots
# plddt is scaled by 100 before plotting; pae is passed through unchanged
confidence_fig = cf.plot_confidence(plddts[recycles]*100, paes[recycles], Ls=Ls)
confidence_fig.show()

In [None]:
#@title Animate
#@markdown - Animate trajectory if more than 0 recycle(s)
# Stack the per-recycle lists into arrays. Index 1 on the second-to-last axis
# selects a single atom per residue (presumably CA -- confirm against the
# atom ordering in residue_constants).
traj_xyz = np.asarray(positions)[...,1,:]
# plddt is scaled by 100, matching the scaling used for the confidence plot
traj_plddt = np.asarray(plddts) * 100.0
# ref=-1 with align_to_ref=True: presumably aligns every frame to the last
# recycle -- confirm against make_animation
HTML(make_animation(traj_xyz,
                    traj_plddt,
                    Ls=Ls,
                    ref=-1, align_to_ref=True,
                    verbose=True))