<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.0/mpnn/examples/proteinmpnn_in_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN in JAX!

In [None]:
#@title import libraries
import warnings, os
warnings.simplefilter(action='ignore', category=FutureWarning)
# Install ColabDesign on first run. Only a missing package should trigger the
# install: a bare `except:` would also swallow KeyboardInterrupt and any
# unrelated error raised while importing an existing install.
try:
  import colabdesign
except ImportError:
  import sys
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0")
  # Link the freshly installed package into the working directory. Build the
  # site-packages path from the running interpreter's version instead of
  # hard-coding python3.7, which no longer matches current Colab runtimes.
  py_ver = "python{}.{}".format(*sys.version_info[:2])
  os.system(f"ln -s /usr/local/lib/{py_ver}/dist-packages/colabdesign colabdesign")

from colabdesign.mpnn import mk_mpnn_model, clear_mem
from IPython.display import HTML

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

from google.colab import files

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file to design against.

  If `pdb_code` is blank (None, "" or only whitespace), prompt for an upload
  via `google.colab.files.upload()` and save the first uploaded file as
  "tmp.pdb". Otherwise fetch "{pdb_code}.pdb" from RCSB with wget; `-nc`
  leaves a copy already present in the working directory untouched.

  Raises:
    ValueError: the upload prompt returned no files.
    FileNotFoundError: the download did not produce "{pdb_code}.pdb"
      (for example an unknown code or no network access).
  """
  # Surrounding whitespace would otherwise end up in the URL and file name.
  pdb_code = "" if pdb_code is None else str(pdb_code).strip()
  if not pdb_code:
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    # Only the first uploaded file is used; its contents are written as bytes.
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"
  pdb_path = f"{pdb_code}.pdb"
  os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_path}")
  # os.system does not raise on failure, so check the result explicitly
  # instead of handing a missing path to the caller.
  if not os.path.isfile(pdb_path):
    raise FileNotFoundError(f"could not download {pdb_path} from https://files.rcsb.org")
  return pdb_path

In [None]:
import re
#@markdown ### ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown #### Input Options
pdb='6MRR' #@param {type:"string"}
pdb_path = get_pdb(pdb)
#@markdown - leave blank to get an upload prompt
chains = "A" #@param {type:"string"}
# Collapse any run of non-letter separators (spaces, semicolons, ...) into one
# comma and drop separators at either end, so " A, B " becomes "A,B" rather
# than ",A,B,".
chains = re.sub(r"[^A-Za-z]+", ",", chains).strip(",")
homooligomer = False #@param {type:"boolean"}
#@markdown #### Design constraints
fix_pos = "" #@param {type:"string"}
# A blank or whitespace-only field means "no fixed positions".
fix_pos = fix_pos.strip() or None
#@markdown - specify which positions to keep fixed in the sequence (example: `1,2-10`)
#@markdown - you can also specify chain specific constraints (example: `A1-10,B1-20`)
#@markdown - you can also specify to fix entire chain(s) (example: `A`)
inverse = False #@param {type:"boolean"}
#@markdown - invert the `fix_pos` selection (define the positions to "free" [i.e. design] instead of "fix")

rm_aa = "" #@param {type:"string"}
# Keep only upper-cased letters and join them with commas ("c a,t" -> "C,A,T");
# an empty result means "exclude nothing".
rm_aa = ",".join(re.sub(r"[^A-Z]+", "", rm_aa.upper())) or None
#@markdown - specify amino acid(s) to exclude (example: `C,A,T`)

clear_mem()
mpnn_model = mk_mpnn_model(model_name)
mpnn_model.prep_inputs(pdb_filename=pdb_path,
                       chain=chains, homooligomer=homooligomer,
                       fix_pos=fix_pos, inverse=inverse,
                       rm_aa=rm_aa)
print("length", mpnn_model._len)
if "fix_pos" in mpnn_model._inputs:
  print("the following positions will be fixed:")
  print(mpnn_model._inputs["fix_pos"])

In [None]:
%%time
from scipy.special import log_softmax
#@markdown ### Design Options
# num_seqs: number of sequences this cell samples and prints.
num_seqs = 32 #@param ["0","1", "2", "4", "8", "16", "32", "64", "128", "256"] {type:"raw"}
# sampling_temp: passed as `temperature=` to mpnn_model.sample / sample_parallel.
sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.
# parallel: True -> a single sample_parallel(batch=num_seqs) call;
#           False -> num_seqs sequential sample() calls.
parallel = False #@param {type:"boolean"}
#@markdown - sample sequences in parallel (same result, but may speed up runtime on fancy GPUs)

def compute_seqid(S, inputs=None):
  """Return the sequence identity of a design over its non-fixed positions.

  Args:
    S: per-position residue scores of a sampled design; the chosen residue
      at each position is `S.argmax(-1)`.
    inputs: dict with "mask", "S" (the reference sequence it is compared
      against, presumably integer residue indices) and optionally
      "fix_pos". Defaults to `mpnn_model._inputs`.

  Returns:
    The mask-weighted fraction of positions whose argmax residue equals
    `inputs["S"]`. Positions listed in "fix_pos" are excluded. Returns nan
    when no position remains, i.e. everything is fixed or masked out.
  """
  if inputs is None:
    inputs = mpnn_model._inputs
  # Copy so that zeroing the fixed positions does not modify the stored mask.
  mask = inputs["mask"].copy()
  if "fix_pos" in inputs:
    mask[inputs["fix_pos"]] = 0
  n_free = mask.sum()
  if n_free == 0:
    # Identity over zero designable positions is undefined; avoid 0/0.
    return float("nan")
  seqid = S.argmax(-1) == inputs["S"]
  return (seqid * mask).sum() / n_free

def split_seq(seq, lengths=None):
  """Insert "/" between chains of a concatenated multi-chain sequence.

  Args:
    seq: the full sequence string, chains back to back.
    lengths: per-chain lengths, defaulting to `mpnn_model._lengths`.

  Returns:
    `seq` unchanged when there is at most one chain, otherwise the chain
    pieces joined with "/". Any residues beyond the last boundary stay in
    the final piece.
  """
  if lengths is None:
    lengths = mpnn_model._lengths
  if len(lengths) <= 1:
    return seq
  # Chain boundaries: 0, the running totals of all but the last chain, end.
  bounds = [0, *np.cumsum(lengths[:-1]).tolist(), len(seq)]
  return "/".join(seq[start:stop] for start, stop in zip(bounds, bounds[1:]))

def _print_design(score, seqid, seq):
  # One FASTA-style record per design: header with score and identity, then
  # the sequence with "/" between chains.
  print(f'>score:{score:.3f}_seqid:{seqid:.3f}\n{split_seq(seq)}')

if not parallel:
  # Draw one design per call, scoring each as it arrives.
  for n in range(num_seqs):
    outputs = mpnn_model.sample(temperature=sampling_temp)
    outputs["seqid"] = compute_seqid(outputs["S"])
    _print_design(outputs["score"], outputs["seqid"], outputs["seq"])
else:
  # Draw the whole batch at once, then report each member.
  outputs = mpnn_model.sample_parallel(batch=num_seqs, temperature=sampling_temp)
  outputs["seqid"] = [compute_seqid(S) for S in outputs["S"]]
  for n in range(num_seqs):
    _print_design(outputs["score"][n], outputs["seqid"][n], outputs["seq"][n])

In [None]:
#@markdown ### amino acid probabilities
mode = "unconditional" #@param ["unconditional", "conditional", "conditional_fix_pos"]
#@markdown - `unconditional` - P(sequence | structure)
#@markdown - `conditional` - P(sequence | structure, sequence)
#@markdown - `conditional_fix_pos` - P(sequence | structure, sequence[fix_pos])
show = "all" #@param ["all", "fix_pos", "free_pos"]
import plotly.express as px
from scipy.special import softmax
from colabdesign.mpnn.model import residue_constants
L = mpnn_model._len
fix_pos = mpnn_model._inputs.get("fix_pos",[])
free_pos = np.delete(np.arange(L),fix_pos)

# Build the L x L ar_mask passed to score(). Per the mode descriptions above,
# a 1 presumably lets one position condition on another's residue identity
# (exact row/column orientation: confirm against colabdesign's score()).
if mode == "unconditional":
  # no sequence information is visible anywhere
  ar_mask = np.zeros((L,L))
elif mode == "conditional":
  # every position sees all other positions, but not itself
  ar_mask = 1-np.eye(L)
elif mode == "conditional_fix_pos":
  # as "conditional", but designable positions are hidden from one another,
  # leaving only the fixed positions' residues visible to them
  ar_mask = 1-np.eye(L)
  ar_mask[free_pos[:,None],free_pos[None,:]] = 0
else:
  # Fail loudly instead of silently plotting `logits` left over from a
  # previous run of this cell.
  raise ValueError(f"unknown mode: {mode!r}")
logits = mpnn_model.score(ar_mask=ar_mask)["logits"]

# "<residue>_<chain>" label per position
pdb_labels = np.array([f"{i}_{c}" for c,i in zip(mpnn_model.pdb["idx"]["chain"], mpnn_model.pdb["idx"]["residue"])])
pssm = softmax(logits,-1)
if show == "fix_pos":
  pdb_labels = pdb_labels[fix_pos]
  pssm = pssm[fix_pos]
elif show == "free_pos":
  pdb_labels = pdb_labels[free_pos]
  pssm = pssm[free_pos]
fig = px.imshow(np.array(pssm).T,
               labels=dict(x="positions", y="amino acids", color="probability"),
               y=residue_constants.restypes + ["X"],
               # NOTE(review): pdb_labels is computed above but not passed to
               # the plot; uncomment to label the x-axis by residue.
               #x=pdb_labels,
               zmin=0,
               zmax=1,
               template="simple_white",
              )
fig.update_xaxes(side="top")
fig.show()