<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.0/mpnn/examples/proteinmpnn_in_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN in JAX!

---

fixbb monomer design:
 - `pdb="6MRR" chains="A"`

fixbb homooligomer design:
 - `pdb="5XZK" chains="A,B,C" homooligomer=True`

binder design:
 - `pdb="1SSC" chains="A,B" fix_pos="A"`

---


In [None]:
#@title Install colabdesign
import os
import sys
import importlib.util

# Install ColabDesign (pinned to v1.1.0) only when it is not importable yet.
# Only ImportError is caught so that unrelated problems (e.g. an interrupt)
# are not silently turned into a reinstall.
try:
  import colabdesign
except ImportError:
  # run pip through the kernel's interpreter so the package is installed into
  # the environment this notebook actually imports from
  os.system(f'"{sys.executable}" -m pip install -q git+https://github.com/sokrypton/ColabDesign.git@v1.1.0')
  importlib.invalidate_caches()
  # Link the installed package into the working directory (presumably so the
  # sources can be browsed from the Colab file pane). The location is looked
  # up instead of hard-coding a python3.7 path, which dangles on newer runtimes.
  spec = importlib.util.find_spec("colabdesign")
  if spec is not None and spec.origin is not None and not os.path.lexists("colabdesign"):
    os.symlink(os.path.dirname(spec.origin), "colabdesign")

from colabdesign.mpnn import mk_mpnn_model, clear_mem

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
from google.colab import files

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, uploading or downloading it first.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "6MRR"). When None or blank, the user
      is prompted to upload a file, which is stored as "tmp.pdb".

  Returns:
    Path of the PDB file, relative to the current working directory.

  Raises:
    ValueError: the upload prompt returned no file.
    FileNotFoundError: the requested entry could not be downloaded.
  """
  pdb_code = (pdb_code or "").strip()
  if pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no PDB file was uploaded")
    # only the first uploaded file is used
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"

  pdb_path = f"{pdb_code}.pdb"
  if not os.path.isfile(pdb_path):
    import shlex
    # quote the URL: pdb_code comes straight from a free-text form field
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    os.system(f"wget -qnc {shlex.quote(url)}")
  # fail here with a clear message rather than later inside the model code
  if not os.path.isfile(pdb_path) or os.path.getsize(pdb_path) == 0:
    raise FileNotFoundError(f"could not download PDB entry '{pdb_code}' from files.rcsb.org")
  return pdb_path

In [None]:
%%time
#@title Run ProteinMPNN to design new sequences for given backbone

import warnings, os, re
# suppress FutureWarning messages
warnings.simplefilter(action='ignore', category=FutureWarning)

# USER OPTIONS
# The `#@param` / `#@markdown` annotations below drive the Colab form for this
# cell; each assignment is the value currently selected in the form.
#@markdown #### ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown #### Input Options
pdb='6MRR' #@param {type:"string"}
#@markdown - leave blank to get an upload prompt
chains = "A" #@param {type:"string"}
homooligomer = False #@param {type:"boolean"}
#@markdown #### Design constraints
fix_pos = "" #@param {type:"string"}
#@markdown - specify which positions to keep fixed in the sequence (example: `1,2-10`)
#@markdown - you can also specify chain specific constraints (example: `A1-10,B1-20`)
#@markdown - you can also specify to fix entire chain(s) (example: `A`)
inverse = False #@param {type:"boolean"}
#@markdown - inverse the `fix_pos` selection (define position to "free" [or design] instead of "fix")
rm_aa = "" #@param {type:"string"}
#@markdown - specify amino acid(s) to exclude (example: `C,A,T`)

#@markdown #### Design Options
# num_seqs is later split into batches of 32, hence the multiples of 32 offered here
num_seqs = 32 #@param ["32", "64", "128", "256", "512", "1024"] {type:"raw"}
sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.

#@markdown Note: designed sequences are saved to `design.fasta`

# cleaning user options
# Collapse every run of non-letters in `chains` into a single comma, then drop
# separators left at either end so input such as "A, B " becomes "A,B" rather
# than "A,B," (which would carry an empty chain id).
chains = re.sub("[^A-Za-z]+",",", chains).strip(",")
# a blank (or whitespace-only) field means "no fixed positions"
if fix_pos.strip() == "": fix_pos = None
# keep only letters from `rm_aa`, upper-cased and comma-separated: "c a,t" -> "C,A,T"
rm_aa = ",".join(list(re.sub("[^A-Z]+","",rm_aa.upper())))
if rm_aa == "": rm_aa = None

pdb_path = get_pdb(pdb)

# (Re)build the ProteinMPNN model only when an option that affects it changed
# since the last run of this cell.
# - `model_name` is part of the key, otherwise switching weights in the form
#   would silently keep using the previously loaded model.
# - an uploaded structure is always stored under the same file name, so the
#   key cannot detect a new upload; rebuild unconditionally in that case.
mpnn_args = [model_name, pdb_path, chains, homooligomer, fix_pos, inverse, rm_aa]
is_upload = not (pdb or "").strip()
if is_upload or "mpnn_args_current" not in dir() or mpnn_args != mpnn_args_current:
  mpnn_model = mk_mpnn_model(model_name)
  mpnn_model.prep_inputs(pdb_filename=pdb_path,
                        chain=chains, homooligomer=homooligomer,
                        fix_pos=fix_pos, inverse=inverse,
                        rm_aa=rm_aa, verbose=True)
  # store a copy so later edits to `mpnn_args` cannot alter the cached key
  mpnn_args_current = list(mpnn_args)

# Sample designs in batches of 32; the number of batches is derived from num_seqs.
n_batches = num_seqs//32
out = mpnn_model.sample(
  num=n_batches,
  batch=32,
  temperature=sampling_temp,
  rescore=homooligomer,
)

# Save each design as a two-line FASTA record (header with score and sequence
# identity, then the sequence) and echo the same record to the cell output.
with open("design.fasta","w") as fasta:
  for idx in range(num_seqs):
    header = f'>score:{out["score"][idx]:.3f}_seqid:{out["seqid"][idx]:.3f}'
    record = f'{header}\n{out["seq"][idx]}'
    fasta.write(record+"\n")
    print(record)

In [None]:
#@title ### Get amino acid probabilities from ProteinMPNN (optional)
mode = "unconditional" #@param ["unconditional", "conditional", "conditional_fix_pos"]
#@markdown - `unconditional` - P(sequence | structure)
#@markdown - `conditional` - P(sequence | structure, sequence)
#@markdown - `conditional_fix_pos` - P(sequence[not_fixed] | structure, sequence[fix_pos])
show = "all"  # not referenced by the code in this cell
import plotly.express as px
from scipy.special import softmax
from colabdesign.mpnn.model import residue_constants

L = sum(mpnn_model._lengths)
# Use a dedicated name for the fixed indices: assigning to `fix_pos` here would
# overwrite the user option of the same name set by the design cell.
fixed_pos = mpnn_model._inputs.get("fix_pos",[])
free_pos = np.delete(np.arange(L),fixed_pos)

# NOTE(review): ar_mask[i, j] = 1 presumably lets position i condition on the
# residue at position j — confirm against mpnn_model.score.
# The branches must be mutually exclusive (if / elif / else); with two separate
# `if` statements the final `else` also ran in "conditional" mode and replaced
# its logits with the unconditional ones.
pdb_labels = None
if mode == "conditional":
  # every position sees all other positions, but not itself
  ar_mask = 1-np.eye(L)
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
elif mode == "conditional_fix_pos":
  assert "fix_pos" in mpnn_model._inputs, "no positions fixed"
  ar_mask = 1-np.eye(L)
  # free positions must not see each other; only the fixed positions stay visible
  ar_mask[free_pos[:,None],free_pos[None,:]] = 0
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  # report the free (designable) positions only, labelled "<residue>_<chain>"
  logits = logits[free_pos]
  pdb_labels = np.array([f"{i}_{c}" for c,i in zip(mpnn_model.pdb["idx"]["chain"], mpnn_model.pdb["idx"]["residue"])])
  pdb_labels = pdb_labels[free_pos]
else:
  # unconditional: no sequence context at all
  ar_mask = np.zeros((L,L))
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]

# normalise the logits over the last axis into per-position probabilities
pssm = softmax(logits,-1)

fig = px.imshow(np.array(pssm).T,
               labels=dict(x="positions", y="amino acids", color="probability"),
               y=residue_constants.restypes + ["X"],
               x=pdb_labels,
               zmin=0,
               zmax=1,
               template="simple_white",
              )
fig.update_xaxes(side="top")
fig.show()

In [None]:
#@title Run AlphaFold Prediction on ProteinMPNN sequences (optional)
# The `#@param` / `#@markdown` annotations below drive the Colab form for this
# cell; each assignment is the value currently selected in the form.
#@markdown ###AlphaFold Options
num_models = 1 #@param ["1","2","3","4","5"] {type:"raw"}
num_recycles = 1 #@param ["0","1","2","3"] {type:"raw"}
use_multimer = False #@param {type:"boolean"}
#@markdown ###AF2Rank Options (WIP)
# use_AF2Rank also switches on template input when the AlphaFold model is built
use_AF2Rank = False #@param {type:"boolean"}
#@markdown - AF2Rank uses native structure as input template and assess the 
#@markdown agreement between sequence and structure using AlphaFold's confidence metrics.
#@markdown - The "composite" metric is defined as pLDDT * pTMscore. (WIP: TMscore between input/output not yet implemented.)
# the two options below are only read when use_AF2Rank is enabled
rm_template_interchain = False #@param {type:"boolean"}
#@markdown - Remove interface template info. (Recommended for evaluating redesigned interfaces).
constrain_fix_pos = False #@param {type:"boolean"}
#@markdown - constrain fixed position (aka do not remove template sequence/sidechain on for fixed positions)
# Download the AlphaFold parameters once. An empty `params/` directory (left
# behind by an interrupted or failed download) counts as "not downloaded", so
# the download is retried instead of failing later when the model is built.
if not os.path.isdir("params") or not os.listdir("params"):
  os.makedirs("params", exist_ok=True)
  os.system("curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params")
  if not os.listdir("params"):
    raise RuntimeError("failed to download the AlphaFold parameters into params/")

# where pdb files will be saved (emptied at the start of every run):
if not os.path.isdir("all_pdb"): os.system("mkdir all_pdb")
else: os.system("rm all_pdb/*")

from colabdesign.af import mk_af_model
# Key describing everything the AlphaFold model setup depends on. The file's
# modification time is included because an uploaded structure is always stored
# under the same file name, so the path alone cannot reveal a new upload.
pdb_mtime = os.path.getmtime(pdb_path) if os.path.isfile(pdb_path) else None
af_args = [pdb_path, pdb_mtime, chains, homooligomer,
           use_multimer, use_AF2Rank]

# Rebuild the model only when the key changed. The cached key must be stored
# under the same name that is looked up here; a mismatch between the two names
# makes the lookup fail every time and forces a rebuild on every run.
if "af_args_current" not in dir() or af_args != af_args_current:
  af_model = mk_af_model(use_multimer=use_multimer,
                         use_templates=use_AF2Rank,
                         best_metric="dgram_cce")
  af_model.prep_inputs(pdb_path,chains,homooligomer=homooligomer)
  af_args_current = list(af_args)

af_model.restart()
if use_AF2Rank:
  af_model.set_opt("template", rm_ic=rm_template_interchain)
  # Start every run from "remove template sequence/sidechains everywhere":
  # the model can be reused between runs, so masks cleared for an earlier
  # fix_pos selection must not leak into this one.
  af_model._inputs["rm_template_seq"][:] = True
  af_model._inputs["rm_template_sc"][:] = True
  if constrain_fix_pos and "fix_pos" in mpnn_model._inputs:
    # keep template sequence/sidechain information at the fixed positions
    p = mpnn_model._inputs["fix_pos"]
    af_model._inputs["rm_template_seq"][p] = False
    af_model._inputs["rm_template_sc"][p] = False

# Run an AlphaFold prediction for every sampled design, save each predicted
# structure and let the model keep track of the best result.
for sampled in out["S"]:
  # trim to the AlphaFold model length, then take the argmax over the last axis
  seq = sampled[:af_model._len].argmax(-1)
  af_model.predict(
    seq=seq,
    num_recycles=num_recycles,
    num_models=num_models,
    verbose=False,
  )

  rmsd = af_model.aux["log"]["rmsd"]
  ptm = af_model.aux["log"]["ptm"]
  plddt = af_model.aux["log"]["plddt"]

  if use_AF2Rank:
    # rank by the composite score (pTM * pLDDT); higher is better
    af_model.aux["log"]["composite"] = ptm * plddt
    save_opts = dict(save_best=True,
                     best_metric="composite",
                     metric_higher_better=True)
  else:
    save_opts = dict(save_best=True)
  af_model._save_results(**save_opts)

  # advance the model's counter; its new value is part of the file name
  af_model._k += 1
  pdb_name = f"all_pdb/ptm{ptm:.3f}_plddt{plddt:.3f}_rmsd{rmsd:.3f}_n{af_model._k}.pdb"
  af_model.save_current_pdb(pdb_name)

af_model.save_pdb("best.pdb")
#@markdown Note: designed pdbs are saved to `all_pdb/`


In [None]:
#@title animate
# Requires `af_model` from the AlphaFold prediction cell.
color_by = "plddt" #@param ["chain", "plddt", "rainbow"]
dpi = 100 #@param {type:"integer"}
# wrap the result of af_model.animate in HTML so it is rendered as the cell output
HTML(af_model.animate(color_by=color_by, dpi=dpi))

In [None]:
#@title display best protein {run: "auto"}
# Requires `af_model` from the AlphaFold prediction cell; the form values
# below are passed straight through to af_model.plot_pdb.
color = "pLDDT" #@param ["chain", "pLDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
color_HP = False #@param {type:"boolean"}
animate = False #@param {type:"boolean"}
#@markdown - if `num_models` > 1, will iterate through the models when `animate` is enabled.
af_model.plot_pdb(show_sidechains=show_sidechains,
                  show_mainchains=show_mainchains,
                  color=color, color_HP=color_HP, animate=animate)

In [None]:
# get stats about best sequence
# Requires `af_model` from the AlphaFold prediction cell.
print(af_model.get_seq())
# last expression of the cell: shown as the cell output.
# NOTE(review): presumably the logged metrics of the result kept by
# `_save_results(save_best=True)` — confirm against the model implementation.
af_model._tmp["best"]["aux"]["log"]