<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/alphafold_output_at_each_recycle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%bash
# Environment setup. Every step is written so the cell can be re-run on an
# already-initialised runtime without failing.

# install python dependencies
pip -q install biopython dm-haiku ml-collections

# download model (skipped when a previous run already cloned the repository)
[ -d alphafold ] || git clone https://github.com/deepmind/alphafold.git

# apply patch to return model at each recycle step
# -N (--forward) makes patch skip a patch that is already applied instead of
# offering to reverse it, so re-running the cell leaves the patched files intact
wget -qnc https://raw.githubusercontent.com/sokrypton/af_tests/main/model.patch
wget -qnc https://raw.githubusercontent.com/sokrypton/af_tests/main/modules.patch
patch -N -u alphafold/alphafold/model/model.py -i model.patch
patch -N -u alphafold/alphafold/model/modules.py -i modules.patch

# download model params (~1 min)
# -p: do not fail when the directory already exists
mkdir -p params
curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params

patching file alphafold/alphafold/model/model.py
patching file alphafold/alphafold/model/modules.py


Cloning into 'alphafold'...


In [2]:
# import libraries
import os
import sys
sys.path.append('/content/alphafold')

import numpy as np
import jax.numpy as jnp


from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model

In [118]:
# setup which model params to use
model_name = "model_2_ptm" # model we want to use
model_config = config.model_config("model_5_ptm") # configure based on model that doesn't use templates

# number of recycles, defined once so the model setting and the data-pipeline
# setting below cannot drift apart
NUM_RECYCLE = 24
model_config.model.num_recycle = NUM_RECYCLE
model_config.data.common.num_recycle = NUM_RECYCLE

# since we'll be using single sequence input, setting size of MSA to 1
model_config.data.common.max_extra_msa = 1 # 5120
model_config.data.eval.max_msa_clusters = 1 # 512

# setup model: load the weights named by model_name from the current directory
# and pair them with the configuration built above
model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
model_runner = model.RunModel(model_config, model_params)

In [119]:
# setup inputs
query_sequence = "MQDGPGTLDVFVAAGWNTDNTIEITGGATYQLSPYIMVKAGYGWNNSSLNRFEFGGGLQYKVTPDLEPYAWAGATYNTDNTLVPAAGAGFRYKVSPEVKLVVEYGWNNSSLQFLQAGLSYRIQP"
num_res = len(query_sequence)

# per-sequence features for the query
sequence_features = pipeline.make_sequence_features(
    sequence=query_sequence,
    description="none",
    num_res=num_res,
)

# single-sequence "MSA": the only row is the query itself, with a deletion
# matrix of zeros of the same length
msa_features = pipeline.make_msa_features(
    msas=[[query_sequence]],
    deletion_matrices=[[[0] * num_res]],
)

# merge both feature sets and run the model's feature processing with a fixed seed
feature_dict = {**sequence_features, **msa_features}
inputs = model_runner.process_features(feature_dict, random_seed=0)

In [120]:
# get outputs
outputs, prev_outputs = model_runner.predict(inputs)

# the final prediction gets a leading axis of length 1 so it can be appended
# to the stacked per-recycle history as the last frame
final_plddt = outputs['plddt'][None]
final_positions = outputs['structure_module']["final_atom_positions"][None]

plddts = np.asarray(jnp.concatenate([prev_outputs["prev_plddt"], final_plddt], 0))
positions = np.asarray(jnp.concatenate([prev_outputs["prev_pos"], final_positions], 0))

LET'S ANIMATE

In [121]:
#@title
##@title
import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib
from matplotlib import collections as mcoll
import matplotlib.patheffects as path_effects
from IPython.display import HTML

def make_animation(positions, plddts, interval=120):
  """Render one animation frame per recycle and return it as an HTML5 video.

  Each frame has three panels: the backbone trace coloured by residue index,
  the pLDDT profile along the sequence, and the backbone trace coloured by
  pLDDT. All frames are superimposed onto the last one before drawing.

  Args:
    positions: array indexed as positions[frame, residue, atom, xyz]. Only
      atom index 1 is drawn (assumed to be the C-alpha atom, as the helper
      name suggests - confirm against the atom ordering used by the model).
    plddts: per-frame, per-residue pLDDT values; iterated in lockstep with
      `positions`, so it needs one row per frame.
    interval: delay between frames in milliseconds (default 120).

  Returns:
    The string produced by `ArtistAnimation.to_html5_video()`.
  """

  def kabsch(P, Q, return_v=False):
    """Rotation from the SVD of P.T @ Q, or only the left vectors V if return_v."""
    # code borrowed from: https://github.com/charnley/rmsd/blob/master/rmsd/calculate_rmsd.py
    V, _, W = np.linalg.svd(P.T @ Q)
    # flip the last column when the determinants disagree in sign, so that
    # V @ W is a proper rotation rather than a reflection
    if (np.linalg.det(V) * np.linalg.det(W)) < 0.0:
      V[:,-1] = -V[:,-1]
    if return_v:
      return V
    return np.dot(V, W)

  def ca_align_to_last(positions):
    """Superimpose atom 1 of every frame onto the re-oriented last frame."""

    def align(P,Q):
      # centre both point sets, then rotate P onto Q
      p = P - P.mean(0,keepdims=True)
      q = Q - Q.mean(0,keepdims=True)
      return p @ kabsch(p,q)

    # express the centred last frame in the axes given by the SVD of its own
    # pos.T @ pos; singular values come sorted in descending order, so the
    # largest spread ends up in x/y (the drawn plane) and the smallest in z
    pos = positions[-1,:,1,:] - positions[-1,:,1,:].mean(0,keepdims=True)
    best_2D_view = pos @ kabsch(pos,pos,return_v=True)

    return np.asarray([align(frame[:,1,:], best_2D_view) for frame in positions])

  def make_segments(x, y):
    """Pair consecutive (x, y) points into an array of line segments."""
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    return np.concatenate([points[:-1], points[1:]], axis=1)

  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
  # is available in both old and new releases
  cmap = plt.get_cmap('gist_rainbow')

  def get_color(x, alpha=None, tint=None, shade=None, vmin=50, vmax=90):
    """Map x, clamped to [vmin, vmax], onto the first 80% of the colormap."""
    x = min(max(x, vmin), vmax)
    x = (x - vmin)/(vmax - vmin)
    color = np.array(cmap(x * 0.8, alpha=alpha))
    if tint is not None:
      # blend the RGB channels towards white
      color[:3] = color[:3] + (1 - color[:3]) * tint
    if shade is not None:
      # scale the RGB channels towards black
      color[:3] = color[:3] * shade
    return color

  # align all to last recycle
  ca_positions = ca_align_to_last(positions)

  fig, (ax1, ax2, ax3) = plt.subplots(1,3)
  fig.subplots_adjust(top = 0.9, bottom = 0.1, right = 1, left = 0, hspace = 0.1, wspace = 0.1)
  fig.set_figwidth(13)
  fig.set_figheight(5)
  fig.set_dpi(100)

  # plot limits come from every frame except the first one; with a single
  # frame that slice would be empty and .min() would raise, so use it as is
  ref = ca_positions[1:] if len(ca_positions) > 1 else ca_positions
  range_xy = (ref[:,:,:2].min(), ref[:,:,:2].max())
  range_z = (ref[:,:,-1].min(), ref[:,:,-1].max())

  # z is only used to decide the drawing order of the segments
  z = (ca_positions[:,:,-1] - range_z[0]) / (range_z[1] - range_z[0])

  # axes setup does not depend on the frame, so it is done once up front
  for ax in (ax1, ax3):
    ax.axis('scaled')
    ax.set_xlim(*range_xy)
    ax.set_ylim(*range_xy)
    ax.axis(False)
  ax2.set_xlabel("positions")
  ax2.set_ylabel("pLDDT")
  ax2.set_ylim(0,100)

  # residue-index colours are identical in every frame: compute them once
  L = ca_positions.shape[1]
  index_colors = np.array([get_color(v, vmin=0, vmax=L) for v in np.arange(L)[::-1]])

  ims = []
  for k,(xyz,plddt) in enumerate(zip(ca_positions,plddts)):
    # order segments by the summed z of their two end points so that they
    # are added to the collection in ascending z
    srt = (z[k,:-1]+z[k,1:]).argsort()
    seg = make_segments(xyz[:,0],xyz[:,1])
    plddt_colors = np.array([get_color(v, vmin=50, vmax=90) for v in plddt])

    frame = []
    for ax,c in ((ax1, index_colors), (ax3, plddt_colors)):
      frame.append(ax.add_collection(mcoll.LineCollection(seg[srt], colors=c[srt], animated=True,
                                                          linewidths=5, path_effects=[path_effects.Stroke(capstyle="round")])))

    line, = ax2.plot(plddt, animated=True, color="black")
    # create each label on the axes it annotates instead of relying on
    # pyplot's notion of the current axes
    ttl2 = ax2.text(0.5, 1.01, f"recycle={k}", horizontalalignment='center', verticalalignment='bottom', transform=ax2.transAxes)
    ttl3 = ax3.text(0.5, 1.01, f"pLDDT={plddt.mean():.3f}", horizontalalignment='center', verticalalignment='bottom', transform=ax3.transAxes)
    frame += [line, ttl2, ttl3]
    ims.append(frame)

  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  # close this specific figure so the notebook does not also show a static copy
  plt.close(fig)
  return ani.to_html5_video()

In [122]:
# build the video markup first, then leave the HTML object as the cell's
# last expression so the notebook renders it
animation_html = make_animation(positions, plddts)
HTML(animation_html)

In [None]:
# save all recycles as PDB files
# values shared by every frame are looked up once, outside the loop
atom_mask = outputs['structure_module']['final_atom_mask']
aatype = inputs['aatype'][0]
residue_index = inputs['residue_index'][0] + 1

for recycle, (recycle_plddt, recycle_pos) in enumerate(zip(plddts, positions)):
  # the per-residue pLDDT is broadcast across the atom mask and stored in
  # the B-factor field
  structure = protein.Protein(aatype=aatype,
                              atom_positions=recycle_pos,
                              atom_mask=atom_mask,
                              residue_index=residue_index,
                              b_factors=recycle_plddt[:,None] * atom_mask)
  # one file per recycle, numbered in frame order
  with open(f"tmp.{recycle}.pdb", 'w') as handle:
    handle.write(protein.to_pdb(structure))