<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/AlphaFold_single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AlphaFold Single
(For DEMO purposes only: runs AlphaFold with single-sequence input.)

In [None]:
#@title Setup
# Fixed length of the placeholder sequence the features are built from; input
# sequences are padded up to this length in predict().
MAX_LEN =  100#@param {type:"integer"}
# Written to both cfg.model.num_recycle and cfg.data.common.num_recycle below.
NUM_RECYCLES = 0#@param {type:"integer"}

from IPython.utils import io
import os,sys
import tensorflow as tf
import jax
import jax.numpy as jnp

# One-time environment bootstrap. Everything printed inside the `with` body is
# captured into `captured` instead of being shown in the cell output.
with io.capture_output() as captured:
  # The af_backprop checkout doubles as the "already set up" marker: the pip
  # install and the colabfold.py download are skipped when it exists.
  if not os.path.isdir("af_backprop"):
    %shell git clone https://github.com/sokrypton/af_backprop.git
    %shell pip -q install biopython dm-haiku==0.0.5 ml-collections py3Dmol
    %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py
  # The parameter archive is streamed from curl straight into tar and unpacked
  # into ./params; skipped when that directory already exists.
  if not os.path.isdir("params"):
    %shell mkdir params
    %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params

# Pick the accelerator: try the Colab TPU first, then fall back to whatever
# platform JAX reports for its first local device. Sets the global DEVICE to
# "tpu", "gpu" or "cpu".
try:
  # check if TPU is available
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
except Exception:
  # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
  # SystemExit still propagate instead of being treated as "no TPU".
  if jax.local_devices()[0].platform == 'cpu':
    print("WARNING: no GPU detected, will be using CPU")
    DEVICE = "cpu"
  else:
    print('Running on GPU')
    DEVICE = "gpu"
    # disable GPU on tensorflow
    tf.config.set_visible_devices([], 'GPU')
else:
  # only reached when TPU setup completed without raising
  print('Running on TPU')
  DEVICE = "tpu"

import sys
sys.path.append('/content/af_backprop')
from utils import *

# import libraries
import colabfold as cf
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model
from alphafold.common import residue_constants

def clear_mem():
  """Delete every live buffer reported by the current JAX XLA backend."""
  for live_buf in jax.lib.xla_bridge.get_backend().live_buffers():
    live_buf.delete()

# start from a clean device before the model is set up
clear_mem()

# setup model
model_name = "model_5_ptm"
cfg = config.model_config(model_name)
# number of recycles comes from the NUM_RECYCLES form parameter
cfg.model.num_recycle = NUM_RECYCLES
cfg.data.common.num_recycle = NUM_RECYCLES
# single-sequence mode: one MSA cluster, one extra-MSA row, and no masked-MSA
# replacement
cfg.data.eval.max_msa_clusters = 1
cfg.data.common.max_extra_msa = 1
cfg.data.eval.masked_msa_replace_fraction = 0
# NOTE(review): presumably disables sub-batching inside the model - confirm
cfg.model.global_config.subbatch_size = None
# The weights were unpacked into ./params by the setup step above.
# NOTE(review): presumably the loader looks under <data_dir>/params - confirm
model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
model_runner = model.RunModel(cfg, model_params, is_training=False)

# Build the input features once from a placeholder poly-alanine sequence of
# length MAX_LEN, paired with a one-row MSA holding that same sequence and an
# all-zero deletion matrix. runner() later passes the actual sequence to
# update_seq/update_aatype together with these inputs.
seq = "A" * MAX_LEN
length = len(seq)
feature_dict = {
    **pipeline.make_sequence_features(sequence=seq, description="none", num_res=length),
    **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])
}
inputs = model_runner.process_features(feature_dict,random_seed=0)

@jax.jit
def runner(seq, inputs, model_params):
  """Run one jitted model prediction for a one-hot encoded sequence.

  Args:
    seq: one-hot sequence array; predict() passes
      `jax.nn.one_hot(padded_indices, 20)`, where padded positions carry
      index -1.
    inputs: feature dict produced by `model_runner.process_features`.
    model_params: model parameters handed to `model_runner.apply`.

  Returns:
    dict with "final_atom_positions", "final_atom_mask", "plddt", "pae" and
    the updated "inputs".
  """
  # update_seq / update_aatype come from the `utils` star import.
  # NOTE(review): presumably they write the sequence features into `inputs`
  # in place - confirm against utils.
  update_seq(seq, inputs)
  update_aatype(inputs["target_feat"][...,1:], inputs)

  # Per-position sum of the one-hot rows; used as the mask for both the
  # sequence and the MSA, and to zero residue_index wherever it is not 1.
  residue_mask = seq.sum(-1)
  for mask_name in ("seq_mask", "msa_mask"):
    inputs[mask_name] = inputs[mask_name].at[:].set(residue_mask)
  inputs["residue_index"] = jnp.where(residue_mask==1,inputs["residue_index"],0)

  # fixed PRNG key for the forward pass
  prediction = model_runner.apply(model_params, jax.random.PRNGKey(0), inputs)
  structure = prediction["structure_module"]

  return {"final_atom_positions":structure["final_atom_positions"],
          "final_atom_mask":structure["final_atom_mask"],
          "plddt":get_plddt(prediction),
          "pae":get_pae(prediction),
          "inputs":inputs}

In [None]:
#@title Enter the amino acid sequence to fold ⬇️
import re
# define sequence
sequence = 'AAAAAAAAAAAAAAAAAAAAAAAAAA' #@param {type:"string"}
# upper-case the input, then drop every character that is not A-Z
# (whitespace, digits, punctuation)
sequence = re.sub("[^A-Z]", "", sequence.upper())
# number of residues actually submitted; used later to crop the padded outputs
LEN = len(sequence)

def predict(sequence):
  """Fold `sequence` and return the model outputs as NumPy arrays.

  Each letter is mapped to its index in `residue_constants.restype_order`;
  letters missing from that table become -1. The index vector is right-padded
  with -1 up to MAX_LEN and one-hot encoded over 20 classes before being
  handed to runner().

  Args:
    sequence: amino-acid string of at most MAX_LEN letters.

  Returns:
    The pytree returned by runner(), with every leaf converted via np.asarray.

  Raises:
    ValueError: if `sequence` is longer than MAX_LEN.
  """
  # Fail early with an actionable message; np.pad would otherwise raise an
  # opaque ValueError about a negative pad width.
  if len(sequence) > MAX_LEN:
    raise ValueError(
        f"sequence length ({len(sequence)}) exceeds MAX_LEN ({MAX_LEN}); "
        "increase MAX_LEN in the Setup cell and re-run it")
  seq = np.array([residue_constants.restype_order.get(aa,-1) for aa in sequence])
  seq = np.pad(seq,[0,MAX_LEN-seq.shape[0]],constant_values=-1)
  outs = runner(jax.nn.one_hot(seq,20), inputs, model_params)
  # jax.tree_util.tree_map is the stable API; jax.tree_map is a deprecated alias
  return jax.tree_util.tree_map(np.asarray, outs)

# run the prediction for the sequence entered above; the display cell reads `outs`
outs = predict(sequence)

In [None]:
#@title Display 3D structure {run: "auto"}
# viewer options, passed to cf.show_pdb below
color = "chain" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

def save_pdb(outs, filename):
  '''Write the first LEN residues of a prediction to `filename` as PDB text.

  `outs` is the dict returned by predict(); the global LEN crops away the
  padded positions. The b-factor column holds 100 * plddt for every atom
  flagged in the atom mask.
  '''
  model_inputs = outs["inputs"]
  atom_mask = outs["final_atom_mask"][:LEN]
  prot = protein.Protein(
      residue_index=model_inputs["residue_index"][0][:LEN],
      aatype=model_inputs["aatype"].argmax(-1)[0][:LEN],
      atom_positions=outs["final_atom_positions"][:LEN],
      atom_mask=atom_mask,
      b_factors=100.0 * outs["plddt"][:LEN,None] * atom_mask)
  # render the PDB text before the output file is opened
  pdb_text = protein.to_pdb(prot)
  with open(filename, 'w') as pdb_file:
    pdb_file.write(pdb_text)

# write the cropped prediction to disk so the viewer can load it
save_pdb(outs,"out.pdb")
# NOTE(review): num_res is not used anywhere in the visible code
num_res = int(outs["inputs"]["aatype"][0].sum())

cf.show_pdb("out.pdb", show_sidechains, show_mainchains, color,
            color_HP=True, size=(800,480)).show()
# the colour legend is only shown for the lDDT colouring
if color == "lDDT":
  cf.plot_plddt_legend().show()  
# runner() always returns a "pae" entry, so the else branch is only a fallback
# for outputs that lack it
if "pae" in outs:
  cf.plot_confidence(outs["plddt"][:LEN]*100, outs["pae"][:LEN,:LEN]).show()
else:
  cf.plot_confidence(outs["plddt"][:LEN]*100).show()