<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_pseudo_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF_pseudo_diffusion + ProteinMPNN
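Hacking AlphaFold to act as a diffusion model (for backbone generation): the backbone coordinates are repeatedly noised and then denoised by AlphaFold. At each step, unconditional logits from ProteinMPNN are mixed into the sequence logits.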

In [1]:
#@title setup
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  # for debugging
  os.system("ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar")
  os.system("tar -xf alphafold_params_2022-03-02.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from IPython.display import HTML
from google.colab import files
import numpy as np

CPU times: user 1.57 s, sys: 323 ms, total: 1.89 s
Wall time: 1min 47s


In [72]:
# Reset state before building fresh models (clear_mem presumably releases
# memory held by earlier models -- confirm against ColabDesign).
clear_mem()

# AlphaFold design model: hallucination protocol with template inputs
# enabled, prepared for a 200-residue chain.
af_model = mk_afdesign_model(use_templates=True, protocol="hallucination")
af_model.prep_inputs(length=200)

# ProteinMPNN model with its default settings.
mpnn_model = mk_mpnn_model()

# Report the prepared length and the configured weights.
for label, value in (("length", af_model._len),
                     ("weights", af_model.opt["weights"])):
  print(label, value)

length 200
weights {'con': 1.0, 'exp_res': 0.0, 'helix': 0.0, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0}


In [73]:
# Pseudo-diffusion settings (previously hard-coded inline).
NUM_STEPS = 100     # number of noise -> denoise iterations
INIT_STEP = 5.5     # residue i starts with every coordinate equal to INIT_STEP * i
SEQ_KEEP = 0.9      # weight kept on the current sequence logits each step
MPNN_WEIGHT = 0.1   # weight given to ProteinMPNN's unconditional logits
NOISE_SEED = 0      # seeds the coordinate noise only, so it is repeatable

af_model.restart(mode="gumbel")
seq_len = af_model._len

# Starting "batch" features: aatype 0 everywhere, only the first 4 of the
# 37 atom slots unmasked, and all atoms of residue i placed at the same
# point (INIT_STEP*i, INIT_STEP*i, INIT_STEP*i).
af_model._inputs["batch"] = {"aatype":np.full(seq_len,0),
                             "all_atom_mask":np.tile(np.arange(37)[None] < 4, (seq_len,1)).astype(float),
                             "all_atom_positions":np.tile(INIT_STEP * np.arange(seq_len)[:,None,None], (1,37,3))}

noise_rng = np.random.default_rng(NOISE_SEED)
for k in range(NUM_STEPS):
  # add noise; its scale decays linearly from 1 at the first step
  # to 1/NUM_STEPS at the last one
  w = 1 - k/NUM_STEPS
  random_noise_xyz = noise_rng.normal(size=(seq_len,37,3)) * w
  # Out-of-place add: after the first step this entry is the very array taken
  # from af_model.aux below, so an in-place "+=" would also modify that
  # prediction for anything else still holding a reference to it.
  af_model._inputs["batch"]["all_atom_positions"] = (
      af_model._inputs["batch"]["all_atom_positions"] + random_noise_xyz)

  # denoise: predict from the noised coordinates and feed the predicted
  # positions back in as the next step's starting coordinates
  aux = af_model.predict(return_aux=True, verbose=False)
  af_model._inputs["batch"]["all_atom_positions"] = af_model.aux["atom_positions"]

  # add unconditional logits from proteinmpnn at each stage
  # (only the first 20 entries of the last axis are used)
  mpnn_model.get_af_inputs(af_model)
  af_model._params["seq"] = (SEQ_KEEP * af_model._params["seq"]
                             + MPNN_WEIGHT * mpnn_model.get_unconditional_logits()[...,:20])

  # save results and advance the model's step counter by hand
  af_model._save_results(aux)
  af_model._k += 1

1 models [0] recycles 0 hard 1 soft 0 temp 1 loss 3.89 con 3.89 plddt 0.30 ptm 0.17
2 models [0] recycles 0 hard 1 soft 0 temp 1 loss 3.57 con 3.57 plddt 0.36 ptm 0.26
3 models [0] recycles 0 hard 1 soft 0 temp 1 loss 3.04 con 3.04 plddt 0.46 ptm 0.40
4 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.78 con 2.78 plddt 0.59 ptm 0.55
5 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.56 con 2.56 plddt 0.64 ptm 0.62
6 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.49 con 2.49 plddt 0.65 ptm 0.65
7 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.40 con 2.40 plddt 0.69 ptm 0.67
8 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.35 con 2.35 plddt 0.72 ptm 0.71
9 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.35 con 2.35 plddt 0.72 ptm 0.71
10 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.29 con 2.29 plddt 0.74 ptm 0.73
11 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.21 con 2.21 plddt 0.76 ptm 0.75
12 models [0] recycles 0 hard 1 soft 0 temp 1 loss 2.16 con 2.16 plddt 0.7

In [74]:
# Plot the model's current structure (presumably the result of the last
# loop iteration -- confirm against ColabDesign's plot_pdb).
af_model.plot_pdb()

In [75]:
# Show the designed sequence(s) as the cell output; in the recorded run this
# is a list holding a single amino-acid string.
af_model.get_seqs()

['EPPIKVPKWLLEYLADPPDPNDEEARAALRARFPNLPEDLRDKLLDLLLGKGTISDEEYEALMASPELGLIPGTDRYRNPRAVAALTLLYIITKANNGKPHKGKSVIFIDYSDPSNPKPLILNPSGLLPPPPPEPLPIGKIIVSTPDGSFKKEAKLVDGQWVVLLSPEELAALKALFNGAPLSECLKGLEFEAIPASELL']

In [76]:
# Write the current structure to tmp.pdb in the working directory.
# (Plain string literal: the original f-string had no placeholders.)
af_model.save_pdb("tmp.pdb")

In [77]:
# Build the animation at 100 dpi, then wrap it in HTML as the cell's final
# expression so the notebook renders it.
animation_markup = af_model.animate(dpi=100)
HTML(animation_markup)