<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/alphafold_output_at_each_recycle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%bash
# Environment setup. Every step below is written so the cell can be re-run
# (Restart -> Run All) without failing on artifacts left by an earlier run.
pip -q install biopython
pip -q install dm-haiku
pip -q install ml-collections
pip -q install py3Dmol

# download model (skipped when a checkout already exists, because a second
# `git clone` into a non-empty directory aborts)
# NOTE(review): the clone is not pinned to a commit, so the patches below can
# stop applying if upstream edits model.py / modules.py -- consider pinning.
[ -d alphafold/.git ] || git clone https://github.com/deepmind/alphafold.git

# apply patch to return model at each recycle step
wget -qnc https://raw.githubusercontent.com/sokrypton/af_tests/main/model.patch
wget -qnc https://raw.githubusercontent.com/sokrypton/af_tests/main/modules.patch
# -N: skip a patch that is already applied instead of offering to reverse it
# -r -: do not leave *.rej files behind when a patch is skipped
patch -N -r - -u alphafold/alphafold/model/model.py -i model.patch
patch -N -r - -u alphafold/alphafold/model/modules.py -i modules.patch

# download model params (~1 min)
# The marker file is only written after the archive streamed and unpacked
# successfully, so an interrupted download is retried on the next run while a
# completed one is not repeated.
mkdir -p params
if [ ! -f params/.download_complete ]; then
  if (set -o pipefail; curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params); then
    touch params/.download_complete
  fi
fi

# colabfold
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py


patching file alphafold/alphafold/model/model.py
patching file alphafold/alphafold/model/modules.py


Cloning into 'alphafold'...


In [4]:
# import libraries
import os
import sys
sys.path.append('/content/alphafold')

import numpy as np
import jax.numpy as jnp


from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model

import colabfold as cf

In [5]:
# Pick the weights to load and the config that describes the network.
model_name = "model_2_ptm" # weights we want to run
# the config is deliberately taken from a model that doesn't use templates
model_config = config.model_config("model_5_ptm")

# single-sequence input, so both MSA sizes are reduced to 1
model_config.data.eval.max_msa_clusters = 1 # 512
model_config.data.common.max_extra_msa = 1 # 5120

# number of recycles; kept identical in the model and data sections
model_config.model.num_recycle = model_config.data.common.num_recycle = 24

# load the parameters for `model_name` and build the runner
model_params = data.get_model_haiku_params(data_dir=".", model_name=model_name)
model_runner = model.RunModel(model_config, model_params)

In [6]:
# Build the input features for a single sequence: the "MSA" holds only the
# query itself, paired with an all-zero deletion row of the same length.
query_sequence = "MQDGPGTLDVFVAAGWNTDNTIEITGGATYQLSPYIMVKAGYGWNNSSLNRFEFGGGLQYKVTPDLEPYAWAGATYNTDNTLVPAAGAGFRYKVSPEVKLVVEYGWNNSSLQFLQAGLSYRIQP"

# sequence features first, then MSA features (later keys win on collision)
feature_dict = {}
feature_dict.update(pipeline.make_sequence_features(
    sequence=query_sequence,
    description="none",
    num_res=len(query_sequence)))
feature_dict.update(pipeline.make_msa_features(
    msas=[[query_sequence]],
    deletion_matrices=[[[0] * len(query_sequence)]]))

inputs = model_runner.process_features(feature_dict, random_seed=0)

In [7]:
# Run the model. predict() is unpacked into the final outputs plus a second
# value, `prev_outputs`, holding "prev_plddt" / "prev_pos" (presumably one
# entry per earlier recycle, as provided by the patch applied during setup).
outputs, prev_outputs = model_runner.predict(inputs)

# Append the final result (with a new leading axis) after the earlier entries
# so the first axis enumerates every recycle, then convert to NumPy.
plddts = np.asarray(jnp.concatenate(
    (prev_outputs["prev_plddt"], outputs['plddt'][None]),
    axis=0))
positions = np.asarray(jnp.concatenate(
    (prev_outputs["prev_pos"], outputs['structure_module']["final_atom_positions"][None]),
    axis=0))

## Let's animate: structure and pLDDT at every recycle

In [8]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from IPython.display import HTML

def make_animation(positions, plddts, interval=120, dpi=100, cmin=50, cmax=90):
  """Render the per-recycle structures and pLDDT curves as an HTML5 video.

  Every frame has three panels: the structure drawn by cf.pseudo_3D_plot
  (titled "colored by N->C"), the per-position pLDDT curve, and the same
  structure colored by pLDDT.

  Args:
    positions: array indexed as [frame, residue, atom, xyz]. Only atom index 1
      of each residue is drawn (presumably the CA atom -- confirm against the
      atom ordering used by the model).
    plddts: per-frame sequence of per-position pLDDT arrays, iterated in
      lockstep with `positions`.
    interval: delay between frames in milliseconds.
    dpi: resolution of the figure.
    cmin, cmax: forwarded to cf.pseudo_3D_plot for the pLDDT-colored panel.

  Returns:
    The string produced by `ArtistAnimation.to_html5_video()`.
  """
  def ca_align_to_last(positions):
    """Rotate every frame onto a reference view derived from the last frame."""
    def align(P, Q):
      # center both point sets, then apply the rotation returned by cf.kabsch
      p = P - P.mean(0,keepdims=True)
      q = Q - Q.mean(0,keepdims=True)
      return p @ cf.kabsch(p,q)

    # reference view: the centered last frame multiplied by the matrix that
    # cf.kabsch(..., return_v=True) returns for that frame against itself
    ref = positions[-1,:,1,:] - positions[-1,:,1,:].mean(0,keepdims=True)
    best_2D_view = ref @ cf.kabsch(ref,ref,return_v=True)

    return np.asarray([align(positions[i,:,1,:], best_2D_view)
                       for i in range(len(positions))])

  # align all to last recycle
  pos = ca_align_to_last(positions)

  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(13, 5), dpi=dpi)
  fig.subplots_adjust(top = 0.90, bottom = 0.10,
                      right = 1, left = 0,
                      hspace = 0, wspace = 0)

  # one square window, shared by both structure panels, that contains the
  # x/y extent of every frame (with a margin of 1 on each side)
  xy_min = pos[...,:2].min() - 1
  xy_max = pos[...,:2].max() + 1
  for ax in (ax1, ax3):
    ax.set_xlim(xy_min, xy_max)
    ax.set_ylim(xy_min, xy_max)
    ax.set_aspect('equal')
    ax.axis(False)

  # line width scales with panel width (inches) per unit of data range;
  # computed once from the last structure panel, as the frames all share it
  width = fig.bbox_inches.width * ax3.get_position().width
  line_w = 150 * (width/(xy_max-xy_min))

  # decorations of the pLDDT panel do not change between frames
  ax2.set_xlabel("positions"); ax2.set_ylabel("pLDDT")
  ax2.set_ylim(0,100)

  # titles are placed just above each panel in that panel's axes coordinates;
  # they are created on the owning axes explicitly rather than on whichever
  # axes pyplot currently considers active
  title_kw = dict(horizontalalignment='center', verticalalignment='bottom')

  ims = []
  for k,(xyz,plddt) in enumerate(zip(pos,plddts)):
    (curve,) = ax2.plot(plddt, animated=True, color="black")
    tt1 = ax1.text(0.5, 1.01, "colored by N->C", transform=ax1.transAxes, **title_kw)
    tt2 = ax2.text(0.5, 1.01, f"recycle={k}", transform=ax2.transAxes, **title_kw)
    tt3 = ax3.text(0.5, 1.01, f"pLDDT={plddt.mean():.3f}", transform=ax3.transAxes, **title_kw)
    # every artist listed here is shown for this frame only
    ims.append([
      cf.pseudo_3D_plot(xyz, ax=ax1, line_w=line_w),
      curve, tt1, tt2, tt3,
      cf.pseudo_3D_plot(xyz, c=plddt, cmin=cmin, cmax=cmax, ax=ax3, line_w=line_w),
    ])

  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  # close this specific figure so the notebook does not also display it as a
  # static image; the animation keeps its own reference for rendering
  plt.close(fig)
  return ani.to_html5_video()

In [9]:
# Display the HTML5 video string returned by make_animation inline.
HTML(make_animation(positions, plddts))

In [10]:
# save all recycles as PDB files
# One file per entry along the first axis of plddts/positions is written to
# the current working directory as tmp.<n>.pdb.
for n,(plddt,pos) in enumerate(zip(plddts,positions)):
  # Store pLDDT in the B-factor field: the per-residue values get a trailing
  # axis and are multiplied by the atom mask (presumably a 0/1 mask shaped
  # (residue, atom) -- TODO confirm). The mask comes from the final outputs
  # and is reused for every recycle.
  b_factors = plddt[:,None] * outputs['structure_module']['final_atom_mask']  
  # [0] takes the first entry along the leading axis of the processed inputs
  # (presumably the ensemble/batch dimension -- confirm); residue_index is
  # shifted by +1 (presumably to get 1-based numbering in the PDB -- confirm).
  p = protein.Protein(aatype=inputs['aatype'][0],
                      atom_positions=pos,
                      atom_mask=outputs['structure_module']['final_atom_mask'],
                      residue_index=inputs['residue_index'][0] + 1,
                      b_factors=b_factors)
  pdb_lines = protein.to_pdb(p)
  # mode 'w' overwrites any file left by a previous run
  with open(f"tmp.{n}.pdb", 'w') as f:
    f.write(pdb_lines)