<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/alphafold_output_at_each_recycle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%bash
# fetch the AlphaFold source tree
# NOTE(review): the clone is not pinned to a commit; confirm the patches below still apply to the current HEAD
git clone https://github.com/deepmind/alphafold.git

# fetch the two patches that make the model return its output at each recycle step
for name in model modules; do
  wget -qnc "https://raw.githubusercontent.com/sokrypton/af_tests/main/${name}.patch"
done

# apply each patch to the matching file under alphafold/alphafold/model/
for name in model modules; do
  patch -u "alphafold/alphafold/model/${name}.py" -i "${name}.patch"
done

patching file alphafold/alphafold/model/model.py
patching file alphafold/alphafold/model/modules.py


Cloning into 'alphafold'...


In [2]:
%%bash
# download model params (~1 min)
wget -qnc https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar
# -p: do not fail when the cell is re-run and params/ already exists
mkdir -p params
# unpack into params/ and drop the archive afterwards
tar -xf alphafold_params_2021-07-14.tar -C params/
rm alphafold_params_2021-07-14.tar

In [3]:
%%bash
# install the extra Python packages, one quiet pip call per package
for pkg in biopython dm-haiku ml-collections; do
  pip -q install "$pkg"
done

In [4]:
# import libraries
import os
import sys
sys.path.append('/content/alphafold')

import numpy as np
import matplotlib.pyplot as plt

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model

In [5]:
# setup which model params to use
model_name = "model_2_ptm" # weights that get loaded below
# the config comes from model_5_ptm (a model that doesn't use templates),
# while the weights come from `model_name`
model_config = config.model_config("model_5_ptm")

# number of recycling iterations; defined once so the two config fields
# that carry it are always set to the same value
num_recycle = 24
model_config.model.num_recycle = num_recycle
model_config.data.common.num_recycle = num_recycle

# since we'll be using single sequence input, setting size of MSA to 1
model_config.data.common.max_extra_msa = 1 # 5120
model_config.data.eval.max_msa_clusters = 1 # 512

# setup model: load the weights for `model_name` (data_dir is the current
# directory) and build the runner from config + weights
model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
model_runner = model.RunModel(model_config, model_params)

In [6]:
# setup inputs: features for a single sequence (the MSA holds only the query itself)
query_sequence = "MQDGPGTLDVFVAAGWNTDNTIEITGGATYQLSPYIMVKAGYGWNNSSLNRFEFGGGLQYKVTPDLEPYAWAGATYNTDNTLVPAAGAGFRYKVSPEVKLVVEYGWNNSSLQFLQAGLSYRIQP"

feature_dict = {}
# features derived from the query sequence alone
feature_dict.update(
    pipeline.make_sequence_features(sequence=query_sequence,
                                    description="none",
                                    num_res=len(query_sequence)))
# MSA features: one alignment with one row (the query) and an all-zero deletion row
feature_dict.update(
    pipeline.make_msa_features(msas=[[query_sequence]],
                               deletion_matrices=[[[0]*len(query_sequence)]]))

# convert the raw feature dict into model inputs with a fixed seed
inputs = model_runner.process_features(feature_dict, random_seed=0)

In [7]:
# Run the model. Two values are unpacked: `outputs` and `prev_outputs`.
# NOTE(review): the two-value return presumably comes from the patches applied
# in the first cell (per-recycle outputs) — confirm against the patched predict().
outputs, prev_outputs = model_runner.predict(inputs)

In [8]:
import jax.numpy as jnp

def ca_align_to_last(positions):
  """Superimpose every frame's atom-1 trace onto a principal-axis view of the last frame.

  Args:
    positions: array indexed as [frame, residue, atom, xyz]. Atom index 1 is
      the atom that gets aligned (assumed to be C-alpha — TODO confirm
      against the atom ordering used by the model).

  Returns:
    Array of shape [frame, residue, 3]. Every frame is centred on the origin
    and rigidly rotated (never mirrored) onto the reference. The reference is
    the last frame rotated so that its largest-variance axis is x, the next is
    y and the smallest is z, i.e. the x/y plane shows the widest 2D view.
  """
  def fix_handedness(m, det):
    # Negate the last column of the 3x3 matrix `m` when `det` is negative,
    # turning a reflection into a proper rotation.
    last_col = jnp.arange(3) == 2
    return jnp.where(last_col & (det < 0), -m, m)

  def align(P, Q):
    # Kabsch superposition of P onto Q; both are centred first.
    p = P - P.mean(0, keepdims=True)
    q = Q - Q.mean(0, keepdims=True)
    u, _, vh = jnp.linalg.svd(p.T @ q, full_matrices=False)
    # Without this correction u @ vh may have determinant -1 and would
    # mirror the structure instead of rotating it.
    u = fix_handedness(u, jnp.linalg.det(u @ vh))
    return p @ (u @ vh)

  # Principal axes of the last frame from the 3x3 scatter matrix of its
  # centred coordinates (coordinates are the variables, residues the samples).
  ref = positions[-1, :, 1, :]
  ref = ref - ref.mean(0, keepdims=True)
  _, axes = jnp.linalg.eigh(ref.T @ ref)
  # eigh returns eigenvalues in ascending order; reverse so x carries the
  # largest variance and z the smallest.
  axes = axes[:, ::-1]
  axes = fix_handedness(axes, jnp.linalg.det(axes))
  best_2D_view = ref @ axes

  return jnp.asarray([align(frame[:, 1, :], best_2D_view) for frame in positions])

# Append the final prediction to the per-recycle ones along the leading axis,
# so index -1 of each stacked array is the final output.
plddts = jnp.concatenate(
    [prev_outputs["prev_plddt"], outputs['plddt'][None]],
    axis=0)
positions = jnp.concatenate(
    [prev_outputs["prev_pos"], outputs['structure_module']["final_atom_positions"][None]],
    axis=0)

# aligned per-frame traces used by the animation
ca_positions = ca_align_to_last(positions)

In [9]:
def make_segments(x, y):
  """Turn a polyline given by coordinate sequences x and y into line segments.

  Returns an array of shape (len(x) - 1, 2, 2): entry i holds the (x, y)
  start point and the (x, y) end point of the segment from point i to i + 1.
  """
  # one row per point, with a singleton axis so starts and ends can be joined
  pts = np.transpose(np.asarray([x, y])).reshape(-1, 1, 2)
  starts, ends = pts[:-1], pts[1:]
  return np.hstack((starts, ends))

def get_color(x, alpha=None, tint=None, shade=None, vmin=50, vmax=90):
  """Map a value to a colour from the notebook-level `cmap`.

  `x` is clamped to [vmin, vmax] and rescaled to a 0..1 fraction; `cmap` is
  then sampled at 0.8 times that fraction, with `alpha` passed through.
  If `tint` is given, the first three channels are blended toward 1 by that
  amount; if `shade` is given, they are multiplied by it.
  Returns the colour as a numpy array.
  """
  clamped = min(max(x, vmin), vmax)
  fraction = (clamped - vmin) / (vmax - vmin)
  rgba = np.array(cmap(fraction * 0.8, alpha=alpha))
  if tint is not None:
    # move each colour channel toward 1 (white) by `tint`
    rgba[:3] = rgba[:3] + (1 - rgba[:3]) * tint
  if shade is not None:
    # scale each colour channel by `shade`
    rgba[:3] = rgba[:3] * shade
  return rgba

In [10]:
from matplotlib import animation
import matplotlib
from matplotlib import collections as mcoll
# colormap read by get_color(); pyplot.get_cmap is used because
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9
cmap = plt.get_cmap('gist_rainbow')
from IPython.display import HTML

In [14]:
# Three panels: trace coloured by residue index, pLDDT per position,
# and the same trace coloured by pLDDT. One animation frame per recycle.
fig, (ax1, ax2, ax3) = plt.subplots(1,3)
fig.subplots_adjust(top = 0.9, bottom = 0.1, right = 1, left = 0, hspace = 0.1, wspace = 0.1)
fig.set_figwidth(13)
fig.set_figheight(5)
fig.set_dpi(100)

# Plot limits are measured from frame 3 onward (presumably because the earliest
# recycles are far from the final structure — TODO confirm). The skip is clamped
# so that runs with fewer than 4 frames still have data to measure.
first_frame = min(3, len(ca_positions) - 1)
range_xy = (ca_positions[first_frame:,:,:2].min(), ca_positions[first_frame:,:,:2].max())
range_z = (ca_positions[first_frame:,:,-1].min(), ca_positions[first_frame:,:,-1].max())

# depth of every residue rescaled by the z range; only used to order segments when drawing
z = (ca_positions[:,:,-1] - range_z[0]) / (range_z[1] - range_z[0])

# left panel colours follow the residue index, so they are the same for every frame
num_res = len(plddts[0])
colors1 = np.array([get_color(i, vmin=0, vmax=num_res) for i in np.arange(num_res)])

# static axis setup (identical for every frame)
for ax in (ax1, ax3):
  ax.axis('scaled')
  ax.set_xlim(*range_xy)
  ax.set_ylim(*range_xy)
  ax.axis(False)
ax2.set_xlabel("positions")
ax2.set_ylabel("plddt")
ax2.set_ylim(0,100)

# static title of the right panel, shared by all frames
ttl3 = ax3.text(0.5, 1.01, f"colored by pLDDT", horizontalalignment='center', verticalalignment='bottom', transform=ax3.transAxes)

ims = []
for k, xyz in enumerate(ca_positions):

  # right panel colours follow the pLDDT of this frame
  colors3 = np.array([get_color(p) for p in plddts[k]])

  # order segments by the summed depth of their two end points, largest first
  srt = (z[k,:-1]+z[k,1:]).argsort()[::-1]
  seg = make_segments(xyz[:,0],xyz[:,1])

  im1 = ax1.add_collection(mcoll.LineCollection(seg[srt], colors=colors1[srt], animated=True, linewidths=5))

  im2 = ax2.plot(plddts[k], animated=True, color="black")
  ttl2 = ax2.text(0.5, 1.01, f"recycle={k}", horizontalalignment='center', verticalalignment='bottom', transform=ax2.transAxes)

  im3 = ax3.add_collection(mcoll.LineCollection(seg[srt], colors=colors3[srt], animated=True, linewidths=5))

  # artists that are visible in this frame
  ims.append([im1,ttl2,im2[0],ttl3,im3])

ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120)
# close the figure so only the video is displayed, not a static plot
plt.close(fig)
HTML(ani.to_html5_video())

In [12]:
# save all recycles: one PDB file per frame, named tmp.<frame index>.pdb,
# with the per-residue pLDDT written to the B-factor field of every masked atom
atom_mask = outputs['structure_module']['final_atom_mask']
aatype = inputs['aatype'][0]
residue_index = inputs['residue_index'][0] + 1

for n, (plddt, pos) in enumerate(zip(plddts, positions)):
  # broadcast the per-residue score over the atoms of each residue
  b_factors = plddt[:,None] * atom_mask
  pdb_lines = protein.to_pdb(
      protein.Protein(aatype=aatype,
                      atom_positions=pos,
                      atom_mask=atom_mask,
                      residue_index=residue_index,
                      b_factors=b_factors))
  with open(f"tmp.{n}.pdb", 'w') as f:
    f.write(pdb_lines)