<a href="https://colab.research.google.com/github/sokrypton/af_tests/blob/main/examples/alphafold_output_at_each_recycle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%bash
# One-time setup: everything below is skipped once the ./alphafold checkout exists.
if [ ! -d alphafold ]; then
  pip -q install biopython dm-haiku ml-collections py3Dmol

  # download model code, pinned to a fixed commit
  git clone --quiet https://github.com/deepmind/alphafold.git
  # checkout in a subshell so this script's working directory is left unchanged
  # (works from any cwd instead of relying on hard-coded /content paths)
  (cd alphafold && git checkout --quiet 1d43aaff941c84dc56311076b58795797e49107b)

  # apply patch to return model at each recycle step
  wget -qnc https://raw.githubusercontent.com/sokrypton/af_tests/main/model.patch
  wget -qnc https://raw.githubusercontent.com/sokrypton/af_tests/main/modules.patch
  patch -u alphafold/alphafold/model/model.py -i model.patch
  patch -u alphafold/alphafold/model/modules.py -i modules.patch

  # download model params (~1 min); -p: no error if params/ survives from an earlier run
  mkdir -p params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params

  # colabfold helper module (provides kabsch, plot_pseudo_3D, add_text used below)
  wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py
fi

patching file alphafold/alphafold/model/model.py
patching file alphafold/alphafold/model/modules.py


In [10]:
# import libraries
import os
import sys
sys.path.append('/content/alphafold')

import numpy as np
import jax
import jax.numpy as jnp

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model

import colabfold as cf

In [3]:
# setup which model params to use
model_name = "model_2_ptm" # model we want to use
# configure based on a model that doesn't use templates: the model_2_ptm params
# are run with the model_5_ptm config
model_config = config.model_config("model_5_ptm")

# number of recycling iterations; a single constant so the model setting and the
# data-pipeline setting are always identical
NUM_RECYCLE = 24
model_config.model.num_recycle = NUM_RECYCLE
model_config.data.common.num_recycle = NUM_RECYCLE

# since we'll be using single sequence input, setting size of MSA to 1
model_config.data.common.max_extra_msa = 1 # 5120
model_config.data.eval.max_msa_clusters = 1 # 512

# setup model; data_dir="." presumably resolves to the ./params directory
# extracted by the setup cell — confirm if params are stored elsewhere
model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
model_runner = model.RunModel(model_config, model_params)

In [4]:
# setup inputs: a single query sequence whose "MSA" is just the sequence itself
query_sequence = "MQDGPGTLDVFVAAGWNTDNTIEITGGATYQLSPYIMVKAGYGWNNSSLNRFEFGGGLQYKVTPDLEPYAWAGATYNTDNTLVPAAGAGFRYKVSPEVKLVVEYGWNNSSLQFLQAGLSYRIQP"
query_len = len(query_sequence)

sequence_features = pipeline.make_sequence_features(
    sequence=query_sequence,
    description="none",
    num_res=query_len,
)
# one MSA containing one row (the query) with no deletions
msa_features = pipeline.make_msa_features(
    msas=[[query_sequence]],
    deletion_matrices=[[[0] * query_len]],
)

feature_dict = dict(sequence_features)
feature_dict.update(msa_features)
inputs = model_runner.process_features(feature_dict, random_seed=0)

In [5]:
# get outputs: the patched model returns the final prediction plus `prev_outputs`,
# whose prev_* entries hold the earlier recycles stacked on axis 0 (presumably one
# entry per earlier recycle — confirm against model.patch)
outputs, prev_outputs = model_runner.predict(inputs)

# append the final recycle to the earlier ones along a new leading axis
final_plddt = outputs['plddt'][None]
final_positions = outputs['structure_module']["final_atom_positions"][None]
plddts = np.asarray(jnp.concatenate([prev_outputs["prev_plddt"], final_plddt], axis=0))
positions = np.asarray(jnp.concatenate([prev_outputs["prev_pos"], final_positions], axis=0))

In [15]:
def get_contacts(preds, logits=None, dist=8.0, seq_sep=6):
  """Contact-probability map from distogram logits.

  p(contact)[i, j] is the softmax probability mass in the distogram bins whose
  lower edge is below `dist` (the first bin's lower edge is taken to be 0).

  Args:
    preds: prediction dict; preds["distogram"]["bin_edges"] is always read, and
      preds["distogram"]["logits"] is used when `logits` is None.
    logits: optional (L, L, num_bins) distogram logits to use instead of
      preds' own (e.g. logits from an earlier recycle).
    dist: contact distance cutoff, in the same units as bin_edges.
    seq_sep: minimum sequence separation |i - j| to keep; pairs closer in
      sequence are zeroed. None or 0 disables the filter.

  Returns:
    (L, L) numpy array of contact probabilities (symmetric when filtered).
  """
  if logits is None:
    logits = preds["distogram"]["logits"]
  # lower edge of every bin: 0 for the first bin, then each bin edge
  lower_edges = np.append(0, preds["distogram"]["bin_edges"])
  probs = np.asarray(jax.nn.softmax(logits, axis=-1))
  contacts = probs[..., lower_edges < dist].sum(-1)
  # `> 0` (not `> 1`) so seq_sep=1 really drops the diagonal; the upper-triangle
  # + mirror construction below never double-counts the diagonal for k >= 1
  if seq_sep is not None and seq_sep > 0:
    # keep only pairs with sequence separation |i - j| >= seq_sep
    i, j = np.triu_indices_from(contacts, seq_sep)
    filtered = np.zeros_like(contacts)
    filtered[i, j] = contacts[i, j]
    contacts = filtered + filtered.T
  return contacts

In [89]:
# per-recycle distogram logits (earlier recycles + final) -> per-recycle contact maps
dgrams = np.asarray(jnp.concatenate([prev_outputs["prev_dgram"], outputs['distogram']["logits"][None]],0))
contacts = np.asarray([get_contacts(outputs,dgram, seq_sep=None) for dgram in dgrams])

# Expected pAE of the earlier recycles, computed from their logits with the same
# bin centers AlphaFold uses for the final `predicted_aligned_error`. This assumes
# the default PAE head config (num_bins=64, max_error_bin=31): 63 breaks on
# [0, 31] (step 0.5), centers at break + step/2, plus one catch-all bin, which
# gives centers 0.25, 0.75, ..., 31.75.
pae_breaks = np.linspace(0., 31., 63)
pae_step = pae_breaks[1] - pae_breaks[0]
pae_bin_centers = np.append(pae_breaks + pae_step / 2, pae_breaks[-1] + 1.5 * pae_step)
prev_paes = (jax.nn.softmax(prev_outputs["prev_pae"], -1) * pae_bin_centers).sum(-1)
paes = np.asarray(jnp.concatenate([prev_paes, outputs["predicted_aligned_error"][None]], 0))

## Let's animate the predictions across recycles

In [97]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from IPython.display import HTML

def make_animation(positions, plddts, paes, contacts, line_w=2.0, dpi=100, interval=120):
  """Render the prediction at every recycle as an HTML5 video.

  Each frame is a 2x2 panel for one recycle:
    top-left:     C-alpha trace colored N->C
    top-right:    C-alpha trace colored by pLDDT (color range 50-90)
    bottom-left:  p(contact) matrix (0-1, "Greys")
    bottom-right: pAE matrix (0-30, "bwr")

  Args:
    positions: per-recycle atom positions; positions[r, :, 1, :] is used as the
      C-alpha trace (assumes atom37 ordering with CA at index 1 — TODO confirm).
    plddts: per-recycle per-residue pLDDT.
    paes: per-recycle pAE matrices.
    contacts: per-recycle contact-probability matrices.
    line_w: line width of the pseudo-3D traces.
    dpi: figure resolution.
    interval: delay between frames, passed to ArtistAnimation (milliseconds).

  Returns:
    The animation as an HTML5 <video> string (from Animation.to_html5_video).
  """

  def ca_align_to_last(positions):
    # Superimpose every recycle's C-alpha trace onto the last recycle. The last
    # recycle is first rotated by cf.kabsch(..., return_v=True) into what the
    # original author calls the "best 2D view".
    def align(P, Q):
      p = P - P.mean(0,keepdims=True)
      q = Q - Q.mean(0,keepdims=True)
      return p @ cf.kabsch(p,q)

    pos = positions[-1,:,1,:] - positions[-1,:,1,:].mean(0,keepdims=True)
    best_2D_view = pos @ cf.kabsch(pos,pos,return_v=True)
    return np.asarray([align(frame[:,1,:], best_2D_view) for frame in positions])

  # align all to last recycle
  pos = ca_align_to_last(positions)

  fig, [[ax1, ax3],
        [ax4, ax6]] = plt.subplots(2,2)
  fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 1, left = 0, hspace = 0, wspace = 0)
  fig.set_figwidth(10)
  fig.set_figheight(10)
  fig.set_dpi(dpi)

  # shared square x/y limits so both structure panels use the same scale
  xy_min = pos[...,:2].min() - 1
  xy_max = pos[...,:2].max() + 1
  for ax in [ax1, ax3]:
    ax.set_xlim(xy_min, xy_max)
    ax.set_ylim(xy_min, xy_max)
  for ax in [ax1, ax3, ax4, ax6]:
    ax.axis(False)

  # one list of artists per recycle; ArtistAnimation shows one list per frame
  ims = []
  for k, (xyz, plddt, pae, contact) in enumerate(zip(pos, plddts, paes, contacts)):
    tt1 = cf.add_text(f"recycle={k} (colored by N→C)", ax1)
    tt3 = cf.add_text(f"(colored by pLDDT={plddt.mean():.3f})", ax3)
    tt4 = cf.add_text("p(contact)", ax4)
    tt6 = cf.add_text("pAE", ax6)
    frame = [cf.plot_pseudo_3D(xyz, ax=ax1, line_w=line_w)]
    frame += [tt1, tt3, tt4, tt6]
    frame += [cf.plot_pseudo_3D(xyz, c=plddt, cmin=50, cmax=90, ax=ax3, line_w=line_w)]
    frame += [ax4.imshow(contact, vmin=0, vmax=1, cmap="Greys")]
    frame += [ax6.imshow(pae, vmin=0, vmax=30, cmap="bwr")]
    ims.append(frame)

  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval)
  # close this figure explicitly (not whichever one is "current") so it is not
  # left open or displayed as a static image next to the video
  plt.close(fig)
  return ani.to_html5_video()

In [100]:
# render every recycle as an inline HTML5 video
animation_html = make_animation(positions, plddts, paes, contacts, dpi=100)
HTML(animation_html)

In [None]:
# save all recycles as PDB files (tmp.<recycle>.pdb); the B-factor column carries
# each residue's pLDDT, masked by the final atom mask
atom_mask = outputs['structure_module']['final_atom_mask']
for n, (recycle_plddt, recycle_pos) in enumerate(zip(plddts, positions)):
  # [0]: first entry along the leading axis of the processed features
  # (presumably the per-recycle/ensemble copy axis — confirm)
  prot = protein.Protein(
      aatype=inputs['aatype'][0],
      atom_positions=recycle_pos,
      atom_mask=atom_mask,
      residue_index=inputs['residue_index'][0] + 1,
      b_factors=recycle_plddt[:, None] * atom_mask,
  )
  pdb_text = protein.to_pdb(prot)
  with open(f"tmp.{n}.pdb", 'w') as pdb_file:
    pdb_file.write(pdb_text)