<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_pseudo_diffusion_dgram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF_pseudo_diffusion + proteinMPNN
Hacking AlphaFold to act as a diffusion model (for backbone generation) via the distogram. At each step, logits from proteinMPNN are added.


- **WARNING**: This notebook is experimental, designed as a control. Not intended for practical use at this stage.
- This notebook had a **major update** on 26Dec2022. See the [original notebook](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_pseudo_diffusion_dgram_old.ipynb) for the previous version.
- Note: current protocol was optimized for proteins in the length range 100-500.

In [None]:
#@title setup
%%time
import os
from google.colab import files
if not os.path.isdir("params"):
  # One-time bootstrap, skipped when the "params" directory already exists.
  # Commands run in order through the shell; exit codes are not checked.
  setup_commands = (
    # get code
    "pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1",
    # for debugging
    "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign",
    # download params
    "mkdir params",
    "apt-get install aria2 -qq",
    "aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar",
    "tar -xf alphafold_params_2022-03-02.tar -C params",
  )
  for setup_command in setup_commands:
    os.system(setup_command)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_get_cb

from IPython.display import HTML
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax
from scipy.special import softmax
from scipy.special import expit as sigmoid
import tqdm.notebook
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

from colabdesign.af.weights import __file__ as af_path
template_dgram_head = np.load(os.path.join(os.path.dirname(af_path),'template_dgram_head.npy'))

def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to a local PDB file path, fetching the file if needed.

  Resolution order:
    1. `None` or "": prompt for an upload via `files.upload()` and write the
       first uploaded file to "tmp.pdb".
    2. Path of an existing file: returned unchanged.
    3. 4-character code: fetched from files.rcsb.org as "<code>.pdb".
    4. Anything else: used as the identifier in an alphafold.ebi.ac.uk URL
       and fetched as "AF-<code>-F1-model_v3.pdb".

  Downloads are best-effort: wget's exit status is not checked, so the
  returned path may not exist if the download failed.

  Returns:
    Path (str) of the local PDB file.
  """
  import shlex

  def _fetch(url):
    # Quote the URL so characters in `pdb_code` can never be interpreted as
    # shell syntax; `-nc` skips the download when the file already exists.
    os.system(f"wget -qnc {shlex.quote(url)}")

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    _fetch(f"https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

  _fetch(f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
  return f"AF-{pdb_code}-F1-model_v3.pdb"

def sample_gumbel(shape, sym=False, eps=1e-20):
  """Draw Gumbel(0, 1) noise of the given shape.

  Uses the inverse-CDF transform -log(-log(U)) on uniform samples from the
  global numpy RNG; `eps` keeps both logarithms away from log(0).
  With `sym=True` the strict upper triangle of the first two axes is copied
  onto the lower triangle, so the result is symmetric across those axes.
  """
  uniform = np.random.uniform(size=shape)
  neg_log_u = -np.log(uniform + eps)
  gumbel = -np.log(neg_log_u + eps)
  if not sym:
    return gumbel
  # mirror upper triangle -> lower triangle (diagonal left untouched)
  rows, cols = np.triu_indices(gumbel.shape[0], k=1)
  gumbel[cols, rows] = gumbel[rows, cols]
  return gumbel

def get_dgram(positions=None, dist=None, num_bins=39, min_bin=3.25, max_bin=50.75):
  """One-hot encode pairwise distances into distance bins (a distogram).

  Args:
    positions: atom coordinates indexed as [..., atom, xyz] with atoms ordered
      per `residue_constants.atom_order`. Only read when `dist` is None, in
      which case distances are measured between CB positions built from the
      N/CA/C atoms by `_np_get_cb`.
    dist: precomputed distances. Bins are broadcast against the last axis, so
      pass a trailing axis of size 1 (the internally computed distances have
      shape (L, L, 1)).
    num_bins: number of bins.
    min_bin: lower edge of the first bin.
    max_bin: lower edge of the last bin (its upper edge is 1e8).

  Returns:
    Float array with `num_bins` entries on the last axis; an entry is 1.0
    where the distance lies strictly inside that bin and 0.0 otherwise.
    Distances at or below `min_bin`, or exactly on a bin edge, match no bin.

  Raises:
    ValueError: if neither `positions` nor `dist` is provided.
  """
  if dist is None:
    if positions is None:
      raise ValueError("get_dgram requires either `positions` or `dist`")
    atom_idx = residue_constants.atom_order
    atoms = {k:positions[...,atom_idx[k],:] for k in ["N","CA","C"]}
    cb = _np_get_cb(**atoms, use_jax=False)
    # keepdims leaves a trailing axis of size 1 to broadcast against the bins
    dist = np.sqrt(np.square(cb[None,:] - cb[:,None]).sum(-1,keepdims=True))

  lower_breaks = np.linspace(min_bin, max_bin, num_bins)
  # the last bin is open-ended in practice: its upper edge is a huge sentinel
  upper_breaks = np.append(lower_breaks[1:], 1e8)
  return ((dist > lower_breaks) * (dist < upper_breaks)).astype(float)

def af_diffusion(af_model, mpnn_model,
                iterations=100, dgram_noise=0.5, seqsep_mask=6,
                use_dropout=True, sample_models=True, num_recycles=1,
                mpnn_mode="conditional", mpnn_mask="cmap",
                cmap_dist=8.0, cmap_ss=6, cmap_num=1,
                starting_seq="", out_pdb="init.pdb", verbose=False):
  """Iteratively refine a backbone by feeding AlphaFold its own (noised) distogram.

  Each iteration: add annealed Gumbel noise to the current distogram logits,
  softmax them into the `batch["dgram"]` input, run `af_model.predict`,
  recompute distogram logits from the returned `prev_pair` features, and
  (unless `mpnn_mode == "none"`) blend proteinMPNN logits into the
  sequence parameters.

  Args:
    af_model: design model; it is restarted and its `_inputs`, `_params`
      and `_k` are modified in place.
    mpnn_model: proteinMPNN model, used through `get_af_inputs` and `score`.
    iterations: number of predict/update steps.
    dgram_noise: scale of the Gumbel noise added to the distogram logits;
      additionally multiplied by `(1 - k/iterations)` at step `k`.
    seqsep_mask: pairs with `|offset| <= seqsep_mask` are zeroed in the
      distogram input.
    use_dropout, sample_models, num_recycles: forwarded to `af_model.predict`.
    mpnn_mode: "none" skips proteinMPNN; "unconditional" scores with an
      all-zero `ar_mask`; "conditional" scores without overriding `ar_mask`.
    mpnn_mask: source of the per-position confidence ("cmap", "plddt" or
      "exp_res") used as the proteinMPNN mask and blending weight.
    cmap_dist, cmap_ss, cmap_num: only used when `mpnn_mask == "cmap"`:
      distance cutoff, minimum sequence separation, and number of top
      contacts averaged per position.
    starting_seq: initial sequence; if it has fewer than 2 characters the
      sequence parameters are initialized with Gumbel noise instead.
    out_pdb: path passed to `af_model.save_pdb` at the end.
    verbose: forwarded to `af_model._save_results`.

  Returns:
    The `aux` dict from the final iteration (with "dgram_logits" added).
  """

  # NOTE: asserts are skipped when Python runs with -O
  assert mpnn_mode in ["none","unconditional","conditional"]
  assert mpnn_mask in ["cmap","plddt","exp_res"]

  # restart model
  af_model.restart()
  # l: af_model._len (rows of the sequence parameters); L: summed chain lengths
  l, L = af_model._len, sum(af_model._lengths)
  copies = af_model._args["copies"]
  af_model.set_opt(alpha=1.0, weights=dict(helix=1e-8))

  # gather info about inputs
  if "offset" in af_model._inputs:
    offset = af_model._inputs["offset"]
  else:
    # fall back to pairwise differences of residue indices
    idx = af_model._inputs["residue_index"]
    offset = idx[:,None] - idx[None,:]

  # initialize sequence
  if len(starting_seq) > 1:
    af_model.set_seq(seq=starting_seq)
  else:
    af_model.set_seq(sample_gumbel((l,20)))

  # initialize coordinates/dgram
  af_model._inputs["batch"] = {"aatype":np.zeros(L).astype(int),
                               "all_atom_mask":np.zeros((L,37)),
                               "all_atom_positions":np.zeros((L,37,3)),
                               "dgram":np.zeros((L,L,39))}

  # start from all-zero logits (uniform distribution after softmax)
  aux = {"dgram_logits":np.zeros((L,L,39))}
  save_best = False
  for k in range(iterations):
    # disable stochastic part for the last 10 steps
    # NOTE(review): `k > iterations - 10` holds for k = iterations-9 .. iterations-1,
    # i.e. 9 steps rather than 10 - confirm which is intended.
    if k > (iterations - 10):
      use_dropout = False
      sample_models = False
      save_best = True
      dgram_noise = 0.0

    # noise
    # symmetric Gumbel noise, linearly annealed towards zero over the run
    noise = sample_gumbel(aux["dgram_logits"].shape, sym=True)
    noise = noise * dgram_noise * (1 - k/iterations)
    dgram = softmax(aux["dgram_logits"] + noise, -1)

    # add mask to avoid local contacts being fixed (otherwise there is a bias toward helix)
    dgram_mask = np.abs(offset) > seqsep_mask
    af_model._inputs["batch"]["dgram"] = dgram * dgram_mask[...,None]

    # denoise
    aux.update(af_model.predict(return_aux=True,
                                verbose=False,
                                sample_models=sample_models,
                                dropout=use_dropout,
                                num_recycles=num_recycles))

    # gather features
    plddt = aux["plddt"]
    xyz = aux["atom_positions"]
    seq = aux["seq"]["hard"][0].argmax(-1)
    # repeat the single-copy sequence so it covers all L positions
    if copies > 1: seq = np.tile(seq, copies)

    # update inputs
    af_model._inputs["batch"]["aatype"] = seq
    af_model._inputs["batch"]["all_atom_positions"] = xyz

    # project pair features through the dgram head of the model that was used,
    # then symmetrize by adding the transpose over the two residue axes
    model_num = aux["log"]["models"][0]
    dgram_logits = aux["prev"]["prev_pair"] @ template_dgram_head[model_num]
    dgram_logits += dgram_logits.swapaxes(0,1)
    aux["dgram_logits"] = dgram_logits

    # per position confidence
    if mpnn_mask == "cmap":
      dgram_bins = np.append(0,np.linspace(2.3125,21.6875,63))
      dgram_probs = softmax(aux["debug"]["outputs"]["distogram"]["logits"],-1)
      # contact probability: probability mass in bins below cmap_dist
      cmap = (dgram_probs * (dgram_bins < cmap_dist)).sum(-1)
      # confidence: mean of the cmap_num largest contact probabilities per row,
      # counting only pairs with |offset| > cmap_ss
      conf = np.sort(cmap * (np.abs(offset) > cmap_ss))[:,-cmap_num:].mean(-1)

      # WARNING: option under development
      if copies > 1:
        chain_id = af_model._inputs["asym_id"]
        inter_mask = chain_id[:,None] != chain_id[None,:]
        # inter-chain term: probability mass outside the last distogram bin
        i_cmap = dgram_probs[...,:-1].sum(-1)
        i_conf = np.sort(i_cmap * inter_mask)[:,-cmap_num:].mean(-1)
        conf = 0.5 * conf + 0.5 * i_conf

    if mpnn_mask == "plddt":
      conf = aux["plddt"]

    if mpnn_mask == "exp_res":
      exp_res = sigmoid(aux["debug"]["outputs"]["experimentally_resolved"]["logits"])
      conf = exp_res[:,1]

    # add logits from proteinmpnn at each stage
    if mpnn_mode != "none":
      mpnn_model.get_af_inputs(af_model)
      opt = {"mask":np.sqrt(conf)}
      if mpnn_mode == "unconditional":
        opt["ar_mask"] = np.zeros((L,L))
      mpnn_out = mpnn_model.score(**opt)
      # keep the first l positions and the first 20 logit columns
      mpnn_logits = mpnn_out["logits"][:l,:20]
      aux["log"]["mpnn"] = mpnn_out["score"]

      # accumulate sequence
      # c: confidence averaged over copies; low confidence -> mostly Gumbel noise
      c = conf.reshape(copies,-1).mean(0)[:,None]
      new_logits = (1 - c) * sample_gumbel((l,20)) + c * mpnn_logits
      # moving average: keep 90% of the previous sequence parameters
      af_model._params["seq"] = 0.9 * af_model._params["seq"] + 0.1 * new_logits

    # save results
    af_model._save_results(aux, save_best=save_best, verbose=verbose)
    af_model._k += 1

  af_model.save_pdb(out_pdb)
  return aux

def designability_test(af_model_test, mpnn_model_test,
                       num_seqs=16, sampling_temp=0.1, num_recycles=3,
                       model_num=4, best_metric="dgram_cce",
                       in_pdb="init.pdb", out_pdb="final.pdb",
                       out_dir=None, verbose=False):
  """Sample sequences for a backbone with proteinMPNN and rescore them with AlphaFold.

  Args:
    af_model_test: design model used for scoring; its inputs are re-prepared
      from `in_pdb` and its `_args["best_metric"]` is overwritten.
    mpnn_model_test: proteinMPNN model used for sampling.
    num_seqs: number of sequences to score. Sampling runs in batches of 8,
      rounded up, so any positive value is accepted.
    sampling_temp: proteinMPNN sampling temperature.
    num_recycles: forwarded to `af_model_test.predict`.
    model_num: selects the AlphaFold model "model_<model_num>_ptm".
    best_metric: metric name stored in `af_model_test._args["best_metric"]`.
    in_pdb: backbone to design sequences for.
    out_pdb: path passed to `af_model_test.save_pdb` after all predictions.
    out_dir: if given, each prediction is also written to
      "<out_dir>/n<index>.pdb" (the directory is created if missing).
    verbose: forwarded to `af_model_test._save_results`.

  Returns:
    pandas.DataFrame with one row per sequence and columns
    ["mpnn", "plddt", "ptm", "pae", "rmsd", "dgram_cce", "seq"].
  """
  alphafold_model = f"model_{model_num}_ptm"
  mpnn_batch = 8

  # NOTE(review): prep_inputs is called with the pdb only, so chain/copies
  # settings applied by the caller beforehand are presumably replaced - confirm
  # before relying on this for multi-copy designs.
  af_model_test.prep_inputs(in_pdb)
  af_model_test._args["best_metric"] = best_metric
  mpnn_model_test.get_af_inputs(af_model_test)

  # ceil-divide so a num_seqs that is not a multiple of the batch size still
  # yields at least num_seqs samples (floor division gave zero batches for <8)
  num_batches = -(-num_seqs // mpnn_batch)
  out = mpnn_model_test.sample(num=num_batches, batch=mpnn_batch,
                               temperature=sampling_temp)

  af_terms = ["plddt","ptm","pae","rmsd","dgram_cce"]
  for k in af_terms: out[k] = []

  if out_dir is not None:
    os.makedirs(out_dir, exist_ok=True)

  with tqdm.notebook.tqdm(total=num_seqs, bar_format=TQDM_BAR_FORMAT) as pbar:
    for n in range(num_seqs):
      seq = out["seq"][n]
      af_model_test.predict(seq=seq,
                            num_recycles=num_recycles,
                            num_models=1,
                            verbose=False,
                            models=alphafold_model)

      for k in af_terms: out[k].append(af_model_test.aux["log"][k])
      # presumably rescales a normalized PAE value - TODO confirm the factor 31
      out["pae"][-1] = out["pae"][-1] * 31
      af_model_test._save_results(save_best=True, verbose=verbose)
      af_model_test._k += 1
      if out_dir is not None:
        # write inside out_dir itself (dirname(out_dir) pointed at its parent)
        out_pdb_tmp = os.path.join(out_dir, f"n{n}.pdb")
        af_model_test.save_current_pdb(out_pdb_tmp)
      pbar.update(1)

  af_model_test.save_pdb(out_pdb)
  # "score" is the proteinMPNN score; it is reported under the column "mpnn"
  labels = ["score"] + af_terms + ["seq"]
  data = [[out[k][n] for k in labels] for n in range(num_seqs)]
  labels[0] = "mpnn"
  return pd.DataFrame(data, columns=labels)

In [None]:
#@title initialize the model
length = 100 #@param {type:"integer"}

# WARNING: option under development
copies = 1 # param {type:"integer"}

#@markdown Provide a starting point (optional)
starting_seq = "" #@param {type:"string"}
# keep uppercase letters only (drops whitespace, digits and punctuation)
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())
#@markdown - if `starting_seq` provided the `length` option will be overwritten by length of starting sequence.

# a non-empty starting sequence dictates the length
length = len(starting_seq) or length

# initialize the model
clear_mem()
af_model = mk_afdesign_model(
  protocol="hallucination",
  use_templates=True,
  debug=True,
)
af_model.prep_inputs(length=length, copies=copies)
mpnn_model = mk_mpnn_model()

# separate model for scoring
af_model_test = mk_afdesign_model(protocol="fixbb", best_metric="dgram_cce")
mpnn_model_test = mk_mpnn_model()

print("lengths",af_model._lengths)

In [None]:
#@title run protocol
#@markdown Optimization options
iterations = 100 #@param ["50", "100", "200"] {type:"raw"}
dgram_noise = 0.5 #@param ["0.1", "0.2", "0.5", "1.0"] {type:"raw"}
seqsep_mask = 6 #@param ["0", "6", "12"] {type:"raw"}

#@markdown AlphaFold options
use_dropout = True #@param {type:"boolean"}
sample_models = True #@param {type:"boolean"}
num_recycles = 1 #@param ["0", "1", "2", "3"] {type:"raw"}

#@markdown proteinMPNN options (set to `none` to disable)
mpnn_mode = "conditional" #@param ["none", "unconditional", "conditional"]
mpnn_mask = "cmap" #@param ["cmap", "plddt", "exp_res"]

# collect the form values (plus the fixed contact-map settings) in one place
diffusion_options = {
  "iterations": iterations,
  "dgram_noise": dgram_noise,
  "seqsep_mask": seqsep_mask,
  "use_dropout": use_dropout,
  "sample_models": sample_models,
  "num_recycles": num_recycles,
  "mpnn_mode": mpnn_mode,
  "mpnn_mask": mpnn_mask,
  "cmap_dist": 8.0,
  "cmap_ss": 6,
  "cmap_num": 1,
  "starting_seq": starting_seq,
  "out_pdb": "init.pdb",
  "verbose": True,
}
aux = af_diffusion(af_model, mpnn_model, **diffusion_options)


In [None]:
# Inspect the design model after the diffusion run: plot its structure, then
# leave get_seqs() as the last expression so its value becomes the cell output.
af_model.plot_pdb()
af_model.get_seqs()

In [None]:
# Wrap the result of af_model.animate() in HTML so the notebook renders it
# (presumably an animation of the steps saved during the run).
HTML(af_model.animate(dpi=100))

In [None]:
#@title sample new sequences using proteinMPNN and rescore with alphafold (w/o template)
import pandas as pd
from google.colab import data_table
from colabdesign.shared.protein import alphabet_list as chain_list
data_table.enable_dataframe_formatter()
# per-sequence PDB files are written here by designability_test
os.makedirs("output/all_pdb", exist_ok=True)

#@markdown #### Design Options
num_seqs = 16 #@param ["8", "16", "32"] {type:"raw"}
sampling_temp = 0.1 
num_recycles = 3 #@param ["0", "1", "2", "3"] {type:"raw"}
model_num = 4 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
alphafold_model = f"model_{model_num}_ptm"

af_model_test.prep_inputs(f"init.pdb",
                          chain=",".join(chain_list[:copies]),
                          copies=copies,
                          homooligomer=copies>1)

# designability_test does its own proteinMPNN sampling, so no separate
# (discarded) sampling run is made here. The form's num_recycles is passed
# through instead of a hard-coded 3.
df = designability_test(af_model_test, mpnn_model_test,
                        num_seqs=num_seqs, sampling_temp=sampling_temp,
                        num_recycles=num_recycles,
                        model_num=model_num, best_metric="dgram_cce",
                        in_pdb="init.pdb", out_pdb="final.pdb",
                        out_dir="output/all_pdb", verbose=False)
df.to_csv('output/mpnn_results.csv')
# interactive table, best (lowest) dgram_cce first
data_table.DataTable(df.round(3).sort_values("dgram_cce"))

In [None]:
# Inspect the scoring model after rescoring: plot its structure, then leave
# get_seqs() as the last expression so its value becomes the cell output.
af_model_test.plot_pdb()
af_model_test.get_seqs()