<a href="https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/categorical_jacobian/proteinmpnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Get the categorical Jacobian from ProteinMPNN**
## (a.k.a. extract conservation and coevolution for your favorite protein)

In [None]:
%%time
#@markdown ##setup model
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
use_solublempnn = False # @param {type:"boolean"}

import os

# First run only: install ColabDesign and symlink the installed package into the
# working directory; the symlink's presence is what skips this on later runs.
if not os.path.isdir("colabdesign"):
  print("installing ColabDesign...")
  for cmd in ("pip -q install git+https://github.com/sokrypton/ColabDesign.git",
              "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign"):
    os.system(cmd)

# First run only: fetch the local `utils.py` helper module plus system setup.
if not os.path.isfile("utils.py"):
  for cmd in ("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabBio/main/categorical_jacobian/utils.py",
              "apt-get install aria2 -qq",
              "mkdir -p /root/.cache/torch/hub/checkpoints/"):
    os.system(cmd)

from colabdesign.mpnn import mk_mpnn_model, clear_mem
from colabdesign.shared.protein import pdb_to_string

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
import pandas as pd

from google.colab import files
def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to a local PDB file path, fetching it if needed.

  - ``None`` / blank: prompt for an upload and save it as ``tmp.pdb``.
  - path to an existing file: returned unchanged.
  - 4-character code: downloaded from RCSB as ``{code}.pdb``.
  - anything else: downloaded from the AlphaFold DB as
    ``AF-{code}-F1-model_v3.pdb`` (presumably a UniProt accession).

  Returns the local file path.
  Raises FileNotFoundError if nothing was uploaded or the download failed.
  """
  import subprocess

  if pdb_code is None or pdb_code.strip() == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise FileNotFoundError("no file was uploaded")
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  code = pdb_code.strip()
  if len(code) == 4:
    url = f"https://files.rcsb.org/view/{code}.pdb"
    path = f"{code}.pdb"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{code}-F1-model_v3.pdb"
    path = f"AF-{code}-F1-model_v3.pdb"

  # argument list, no shell: the user-supplied code cannot inject commands
  subprocess.run(["wget", "-qnc", url], check=False)
  # wget does not write a file on HTTP errors, so a missing/empty file means failure
  if not os.path.isfile(path) or os.path.getsize(path) == 0:
    raise FileNotFoundError(f"could not download '{code}' from {url}")
  return path

from scipy.special import softmax
import pandas as pd
import numpy as np
import bokeh.plotting
bokeh.io.output_notebook()
from bokeh.models import BasicTicker, PrintfTickFormatter
from bokeh.palettes import viridis, RdBu
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show

from matplotlib.colors import to_hex
# 256-step hex palettes for bokeh's linear_cmap: diverging bwr_r (signed pair
# couplings) and gray_r (contact map).
_levels = np.linspace(0, 1, 256)
bwr_r = list(map(to_hex, map(plt.colormaps["bwr_r"], _levels)))
cmap = plt.colormaps["gray_r"]
gray = list(map(to_hex, map(cmap, _levels)))

def pssm_to_dataframe(pssm, alphabet):
  """Flatten an (L, A) position-probability matrix into long format.

  Returns one row per (position, amino acid) with columns
  'Position' (1-based, as str), 'Amino Acid', and 'Probability'.
  """
  labels = [f"{n}" for n in range(1, len(pssm) + 1)]
  wide = pd.DataFrame(pssm, index=labels, columns=list(alphabet))
  return (wide
          .rename_axis(index='Position', columns='Amino Acid')
          .stack()
          .reset_index(name='Probability'))

def contact_to_dataframe(con):
  """Flatten an (L, L) contact/coevolution matrix into long format.

  Returns one row per (i, j) cell with columns 'i', 'j' (1-based, as str)
  and 'value'.
  """
  labels = [f"{n}" for n in range(1, len(con) + 1)]
  wide = pd.DataFrame(con, index=labels, columns=labels)
  return (wide
          .rename_axis(index='i', columns='j')
          .stack()
          .reset_index(name='value'))

def pair_to_dataframe(pair, alphabet):
  """Flatten an (A, A) amino-acid coupling block into long format.

  Rows of `pair` correspond to residue i's amino acid, columns to residue j's;
  both are labelled by the letters of `alphabet`. Returns columns
  'aa_i', 'aa_j', 'value'.
  """
  df = pd.DataFrame(pair, index=list(alphabet), columns=list(alphabet))
  df = df.stack().reset_index()
  df.columns = ['aa_i', 'aa_j', 'value']
  return df

from utils import *
import tqdm.notebook

# progress-bar layout used by the long-running scoring loop
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# presumably releases memory held by a previously built model before loading a new one
clear_mem()
mpnn_model = mk_mpnn_model(
  model_name,
  weights=("soluble" if use_solublempnn else "original"),
)

# `alphabet` is the model's residue order (logit columns are indexed by it);
# ALPHABET is the display/export order, and ALPHABET_map maps display -> model index.
from colabdesign.mpnn.model import residue_constants
alphabet = "".join(residue_constants.restypes)
ALPHABET = "AFILVMWYDEKRHNQSTGPC"
ALPHABET_map = list(map(alphabet.index, ALPHABET))

In [None]:
#@markdown ##enter structure

#@markdown #### Input Options
pdb='6MRR' #@param {type:"string"}
#@markdown - leave blank to get an upload prompt
chains = "A" #@param {type:"string"}

pdb_path = get_pdb(pdb)
mpnn_model.prep_inputs(pdb_filename=pdb_path, chain=chains)
sequence = "".join([alphabet[x] for x in mpnn_model._inputs["S"]])
L = sum(mpnn_model._lengths)
# later cells use len(sequence) and L interchangeably; fail early if they disagree
assert len(sequence) == L, (len(sequence), L)

os.makedirs("output",exist_ok=True)
# describe the files the later cells actually write (suffixed with the model name)
with open("output/README.txt","w") as handle:
  handle.write("files are suffixed with the ProteinMPNN model name, e.g. _" + model_name + "\n")
  handle.write("conservation_logits_<model>.txt = (L, A) matrix\n")
  handle.write("coevolution_<model>.txt = (L, L) matrix\n")
  handle.write("jac_<model>.npy = ((L*L-L)/2, A, A) tensor (float16)\n")
  handle.write("jac index can be recreated with np.triu_indices(L,1)\n")
  handle.write(f"L: {L}\n")
  handle.write(f"[A]lphabet: {ALPHABET}\n")
  handle.write(f"sequence: {sequence}\n")

In [None]:
#@markdown ##compute conservation
# presumably lets every position see all others but itself (only the diagonal is masked)
ar_mask = 1-np.eye(L)
logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
# keep the 20 amino-acid columns, reordered from model order to ALPHABET order
logits = logits[:,ALPHABET_map]
assert logits.shape == (L, len(ALPHABET)), logits.shape
np.savetxt(f"output/conservation_logits_{model_name}.txt",logits)
pssm = softmax(logits,-1)
df = pssm_to_dataframe(pssm, ALPHABET)

# plot pssm
num_colors = 256  # You can adjust this number
palette = viridis(num_colors)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="CONSERVATION",
           x_range=[str(x) for x in range(1,len(sequence)+1)],
           y_range=list(ALPHABET)[::-1],
           width=900, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('Position', '@Position'), ('Amino Acid', '@{Amino Acid}'), ('Probability', '@Probability')])

p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
       fill_color=linear_cmap('Probability', palette, low=0, high=1),
       line_color=None)
p.xaxis.visible = False  # Hide the x-axis
show(p)

In [None]:
#@markdown ##compute coevolution
# Categorical Jacobian: mutate each position i to each of the 20 amino acids and
# record the change in every output logit.
ar_mask = 1-np.eye(L)
fx = mpnn_model.score(ar_mask=ar_mask)["logits"]  # unperturbed logits, (L, n_out)
# Input-mutation rows are filled directly in ALPHABET order (k-th letter of
# ALPHABET == model index ALPHABET_map[k]), so every row is a real mutation and
# no unfilled row can bias the centering below.
fx_h = np.zeros((L, len(ALPHABET_map), L, fx.shape[-1]))
with tqdm.notebook.tqdm(total=L, bar_format=TQDM_BAR_FORMAT) as pbar:
  for i in range(L):
    S = mpnn_model._inputs["S"].copy()
    for k, a in enumerate(ALPHABET_map):
      S[i] = a
      fx_h[i,k] = mpnn_model.score(S=S,ar_mask=ar_mask)["logits"]
    pbar.update(1)

jac = fx_h-fx
# keep only the 20 amino-acid output columns (ALPHABET order) BEFORE centering,
# so any extra output column does not enter the means
jac = jac[...,ALPHABET_map]
# center every axis, then symmetrize so jac[i,a,j,b] == jac[j,b,i,a]
for axis in range(4): jac -= jac.mean(axis,keepdims=True)
jac = (jac + jac.transpose(2,3,0,1))/2

contacts = get_contacts(jac, rm=0)
np.savetxt(f"output/coevolution_{model_name}.txt",contacts)
# store only the upper triangle (i < j) as float16; index via np.triu_indices(L,1)
i,j = np.triu_indices(L,1)
np.save(f"output/jac_{model_name}.npy",jac[i,:,j,:].astype(np.float16))

df = contact_to_dataframe(contacts)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="COEVOLUTION",
          x_range=[str(x) for x in range(1,len(sequence)+1)],
          y_range=[str(x) for x in range(1,len(sequence)+1)][::-1],
          width=800, height=800,
          tools=TOOLS, toolbar_location='below',
          tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

p.rect(x="i", y="j", width=1, height=1, source=df,
       fill_color=linear_cmap('value', gray, low=df.value.min(), high=df.value.max()),
       line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the y-axis
show(p)

In [None]:
#@markdown ##show table of top covarying positions
from google.colab import data_table

# Rebuild from `contacts` instead of reusing the global `df` (later cells overwrite it),
# and compare positions as integers: the labels are strings, and "9" > "10".
pairs_df = contact_to_dataframe(contacts).astype({"i": int, "j": int})
sub_df = pairs_df[pairs_df["j"] > pairs_df["i"]].sort_values('value', ascending=False)
data_table.DataTable(sub_df, include_index=False, num_rows_per_page=20, min_width=10)

In [None]:
#@markdown ##select pair of residues to investigate
#@markdown Note: 1-indexed (first position is 1)

position_i = 16 # @param {type:"integer"}
position_j = 48 # @param {type:"integer"}
# reject out-of-range input (0 or negatives would otherwise wrap around silently)
for pos in (position_i, position_j):
  if not 1 <= pos <= L:
    raise ValueError(f"position {pos} is out of range; expected 1..{L}")
i = position_i - 1
j = position_j - 1
df = pair_to_dataframe(jac[i,:,j,:], ALPHABET)

# plot the (A x A) coupling block between the two positions
COUPLING_CLIM = 3.0  # symmetric color limit of the diverging palette
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title=f"coevolution between {position_i} {position_j}",
           x_range=list(ALPHABET),
           y_range=list(ALPHABET)[::-1],
           width=400, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('aa_i', '@aa_i'), ('aa_j', '@aa_j'), ('value', '@value')])
p.xaxis.axis_label = f"{sequence[i]}{position_i}"
p.yaxis.axis_label = f"{sequence[j]}{position_j}"

p.rect(x="aa_i", y="aa_j", width=1, height=1, source=df,
       fill_color=linear_cmap('value', bwr_r, low=-COUPLING_CLIM, high=COUPLING_CLIM),
       line_color=None, dilate=True)
show(p)

In [None]:
#@title download results (optional)
import shutil
from google.colab import files
# Build a fresh archive each time (`zip -r` would update an existing output.zip
# and keep entries from earlier runs); entries are stored under output/.
shutil.make_archive("output", "zip", root_dir=".", base_dir="output")
files.download('output.zip')