<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.0.9-alpha/mpnn/examples/proteinmpnn_in_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN in JAX!

In [None]:
#@title import libraries
import warnings, os
warnings.simplefilter(action='ignore', category=FutureWarning)
try:
  import colabdesign
except ImportError:
  # Fresh runtime without ColabDesign: install the pinned release.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.0.9-alpha")
  # Symlink the installed package into the working directory. The
  # dist-packages path contains the running Python's major.minor version,
  # so derive it rather than hard-coding python3.7 (which no longer exists
  # on current runtimes and would leave a dangling link).
  import sys
  os.system(f"ln -s /usr/local/lib/python{sys.version_info.major}.{sys.version_info.minor}/dist-packages/colabdesign colabdesign")

from colabdesign.af import mk_af_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.af.prep import order_aa
from colabdesign.mpnn import mk_mpnn_model
from IPython.display import HTML

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

from google.colab import files

def get_pdb(pdb_code=""):
  """Return a local path to a PDB file, downloading or uploading it as needed.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "6MRR"). Surrounding whitespace is
      ignored. If None or blank, the user is prompted to upload a file via
      the Colab file picker instead.

  Returns:
    "tmp.pdb" for an uploaded file, otherwise f"{pdb_code}.pdb" in the
    current working directory.

  Raises:
    ValueError: if the upload prompt returns no file.
    FileNotFoundError: if the expected PDB file is still missing after the
      download attempt (e.g. unknown code or no network).
  """
  pdb_code = "" if pdb_code is None else pdb_code.strip()
  if pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    # Only the first uploaded file is used.
    pdb_bytes = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_bytes)
    return "tmp.pdb"

  pdb_path = f"{pdb_code}.pdb"
  if not os.path.isfile(pdb_path):
    # Pass arguments as a list (no shell) so the user-supplied code cannot
    # inject shell commands. The exit status is not trusted; the file check
    # below is the success criterion.
    import subprocess
    subprocess.run(["wget", "-qnc", f"https://files.rcsb.org/view/{pdb_code}.pdb"], check=False)
  if not os.path.isfile(pdb_path):
    raise FileNotFoundError(f"could not download {pdb_path} from RCSB")
  return pdb_path

def af2mpnn(self):
  """Assemble ProteinMPNN keyword inputs from a prepared AF-style model.

  Reads backbone coordinates, the per-residue CA mask, the sequence, residue
  and chain indices, and the bias from ``self._inputs``, and draws one new
  key via ``self.key()``.

  Returns:
    dict with keys "X", "S", "mask", "residue_idx", "chain_idx", "key",
    "bias" (in that insertion order).
  """
  batch = self._inputs["batch"]
  atom_order = residue_constants.atom_order
  # Backbone atom indices, in N, CA, C, O order.
  backbone = tuple(atom_order[name] for name in ["N", "CA", "C", "O"])
  return {
      "X": batch["all_atom_positions"][:, backbone],
      "S": batch["aatype"],
      # A residue is masked in/out according to its CA atom's mask.
      "mask": batch["all_atom_mask"][:, atom_order["CA"]],
      "residue_idx": self._inputs["residue_index"],
      "chain_idx": self._inputs["asym_id"],
      "key": self.key(),
      "bias": self._inputs["bias"],
  }

def seq2aa(seq):
  """Decode an array of per-position amino-acid scores into a string.

  Takes the argmax over the last axis of ``seq`` and maps each resulting
  index to a one-letter code through ``order_aa``.
  """
  return "".join(map(order_aa.__getitem__, seq.argmax(-1)))

In [None]:
import re
#@markdown ### ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown ### Input Options
pdb='6MRR' #@param {type:"string"}
#@markdown - pdb code (leave blank to get an upload prompt)
chains = "A" #@param {type:"string"}
# Collapse every run of non-letters (spaces, commas, semicolons, ...) into a
# single comma, then drop leading/trailing commas: " A, B " -> "A,B" instead
# of ",A,B,", which would presumably produce empty chain IDs downstream.
chains = re.sub("[^A-Za-z]+",",", chains).strip(",")

# Fetch (or upload) the structure only after all form inputs are read.
pdb_path = get_pdb(pdb)

clear_mem()
af_model = mk_af_model(protocol="fixbb", use_alphafold=False)
af_model.prep_inputs(pdb_filename=pdb_path, chain=chains)
mpnn_model = mk_mpnn_model(model_name)
print("length",sum(af_model._lengths))

In [None]:
%%time
#@markdown ### Design Options
num_seqs = 32 #@param ["1", "2", "4", "8", "16", "32", "64"] {type:"raw"}

sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5"] {type:"raw"}
#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.

# runtime
print(f"score\tseqid\tsequence")
mpnn_inputs = af2mpnn(af_model)
for n in range(num_seqs+1):
  mpnn_inputs["key"] = af_model.key()
  mpnn_inputs["temperature"] = sampling_temp
  if n == 0:
    tmp = mpnn_model.score(**mpnn_inputs)
  else:
    tmp = mpnn_model.sample(**mpnn_inputs)
  tmp = jax.tree_map(np.array, tmp)
  seqid = (tmp["seq"].argmax(-1) == af_model._wt_aatype).mean(-1)
  print(f'{tmp["score"]:.3f}\t{seqid:.3f}\t{seq2aa(tmp["seq"])}')

In [None]:
#@markdown ### amino acid probabilities (unconditional)
import plotly.express as px

# Score the prepared inputs once and softmax the logits over the last
# (amino-acid) axis to get per-position probabilities.
mpnn_inputs = af2mpnn(af_model)
pssm = jax.nn.softmax(mpnn_model.score(**mpnn_inputs)["logits"], axis=-1)

# Heatmap: columns are residue positions, rows are amino acids, colour scale
# fixed to [0, 1].
fig = px.imshow(
    np.array(pssm).T,
    labels={"x": "positions", "y": "amino acids", "color": "probability"},
    y=residue_constants.restypes,
    zmin=0,
    zmax=1,
    template="simple_white",
)
fig.update_xaxes(side="top")
fig.show()