<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.0/mpnn/examples/proteinmpnn_in_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN in JAX!

In [None]:
#@title import libraries
import warnings, os
warnings.simplefilter(action='ignore', category=FutureWarning)
# Install ColabDesign on demand: a fresh runtime will not have it, while a
# re-run of this cell should skip straight past the installation.
try:
  import colabdesign
except ImportError:
  # Only a missing package triggers installation; anything else (for example
  # a KeyboardInterrupt) is allowed to propagate instead of being swallowed.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0")
  # NOTE(review): this path is hard-coded for Python 3.7; on a runtime with a
  # different Python version the link target will not exist — confirm whether
  # the symlink is still needed.
  os.system("ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign")

from colabdesign.mpnn import mk_mpnn_model, clear_mem
from IPython.display import HTML

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

from google.colab import files

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, uploading or downloading it first.

  Args:
    pdb_code: identifier of the entry to fetch from files.rcsb.org
      (e.g. "6MRR"). When ``None`` or empty, a Colab upload prompt is shown
      instead.

  Returns:
    "tmp.pdb" for an uploaded file, otherwise "{pdb_code}.pdb". Both paths are
    relative to the current working directory.

  Raises:
    ValueError: if the upload prompt returns no file.
    urllib.error.URLError: if the download fails (HTTPError for an unknown
      identifier).
  """
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      # Previously surfaced as an opaque IndexError on an empty key list.
      raise ValueError("no PDB file was uploaded")
    # Only the first uploaded file is used; any others are ignored.
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  # Imported here so the function does not rely on imports made elsewhere.
  import urllib.parse
  import urllib.request

  pdb_path = f"{pdb_code}.pdb"
  # Keep the no-clobber behaviour of the former `wget -nc`: an existing file
  # is reused rather than downloaded again.
  if not os.path.exists(pdb_path):
    # The identifier is URL-quoted and never reaches a shell, so arbitrary
    # form input cannot be executed as a command.
    url = "https://files.rcsb.org/view/" + urllib.parse.quote(pdb_path, safe="")
    # Download to a side file first so an interrupted transfer cannot leave a
    # truncated "{pdb_code}.pdb" that later calls would silently reuse.
    part_path = pdb_path + ".part"
    urllib.request.urlretrieve(url, part_path)
    os.replace(part_path, pdb_path)
  return pdb_path

In [None]:
import re
#@markdown ### ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown #### Input Options
pdb='6MRR' #@param {type:"string"}
pdb_path = get_pdb(pdb)
#@markdown - leave blank to get an upload prompt
chains = "A" #@param {type:"string"}
# Every run of non-letter characters becomes a single comma separator.
chains = re.sub("[^A-Za-z]+", ",", chains)
#@markdown #### Design constraints
fix_pos = "" #@param {type:"string"}
# An empty form field is passed on as None.
fix_pos = fix_pos or None
#@markdown - specify which positions to keep fixed in the sequence (example: `1,2-10`)
#@markdown - you can also specify multichain constraints (example: `A1-10,B1-20`)
rm_aa = "" #@param {type:"string"}
# Upper-case the field, drop everything except A-Z, then comma-join the
# remaining letters; an empty result is passed on as None.
rm_aa = re.sub("[^A-Z]+", "", rm_aa.upper())
rm_aa = ",".join(rm_aa) or None
#@markdown - specify amino acid(s) to exclude (example: `C,A,T`)

# Build the model named in the form and prepare its inputs from the chosen
# structure, chains and constraints.
clear_mem()
mpnn_model = mk_mpnn_model(model_name)
mpnn_model.prep_inputs(
  pdb_filename=pdb_path,
  chain=chains,
  fix_pos=fix_pos,
  rm_aa=rm_aa,
)
print("length", mpnn_model._len)

In [None]:
%%time
from scipy.special import log_softmax
#@markdown ### Design Options
num_seqs = 32 #@param ["0","1", "2", "4", "8", "16", "32", "64", "128", "256"] {type:"raw"}
sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.

# The scoring mask is derived only from the prepared inputs, so it is built
# once here rather than on every iteration. Positions listed under "fix_pos"
# are zeroed so they do not count towards the identity below.
# NOTE(review): assumes sample() leaves mpnn_model._inputs untouched — confirm.
mask = mpnn_model._inputs["mask"].copy()
if "fix_pos" in mpnn_model._inputs:
  mask[mpnn_model._inputs["fix_pos"]] = 0

for n in range(num_seqs):
  outputs = mpnn_model.sample(temperature=sampling_temp)

  # Index of the largest entry along the last axis of the sampled "S".
  seq = outputs["S"].argmax(-1)
  # Mask-weighted fraction of positions whose sampled index equals the
  # input "S".
  seqid = seq == mpnn_model._inputs["S"]
  seqid = (seqid * mask).sum() / mask.sum()

  print(f'>score:{outputs["score"]:.3f}_seqid:{seqid:.3f}\n{outputs["seq"]}')

In [None]:
#@markdown ### amino acid probabilities (unconditional)
import plotly.express as px
from scipy.special import softmax
from colabdesign.mpnn.model import residue_constants

# Softmax over the last axis turns the unconditional logits into
# probabilities.
pssm = softmax(mpnn_model.get_unconditional_logits(), axis=-1)

# Heat map of the transposed matrix: x axis labelled "positions", y axis
# labelled with the residue types followed by "X".
fig = px.imshow(
  np.array(pssm).T,
  labels={"x": "positions", "y": "amino acids", "color": "probability"},
  y=residue_constants.restypes + ["X"],
  zmin=0,
  zmax=1,
  template="simple_white",
)
fig.update_xaxes(side="top")
fig.show()