<a href="https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/categorical_jacobian_esm2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Get the Categorical Jacobian from ESM2**
## (a.k.a. extract conservation and coevolution for your favorite protein)

In [None]:
%%time
#@markdown ##setup model
# Name of the ESM2 checkpoint to use; handed to load_model() at the end of this cell.
model_name = "esm2_t33_650M_UR50D" # @param ["esm2_t48_15B_UR50D","esm2_t36_3B_UR50D","esm2_t33_650M_UR50D","esm2_t30_150M_UR50D","esm2_t12_35M_UR50D","esm2_t6_8M_UR50D"]
# this step will take ~3mins
import torch
import os
# One-time environment setup. The presence of utils.py is used as the marker
# that setup already ran, so the aria2 install and the checkpoint directory
# creation are skipped too whenever utils.py exists.
# NOTE(review): os.system() return codes are ignored, so a failed download or
# install only surfaces later (e.g. at `from utils import *`).
if not os.path.isfile("utils.py"):
  os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/algosb_2021/main/utils.py")
  os.system("apt-get install aria2 -qq")
  os.system("mkdir -p /root/.cache/torch/hub/checkpoints/")

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import softmax

import pandas as pd
import numpy as np
import bokeh.plotting
# Route bokeh output to the notebook so show() renders inline.
bokeh.io.output_notebook()
from bokeh.models import BasicTicker, PrintfTickFormatter
from bokeh.palettes import viridis, RdBu
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show

from matplotlib.colors import to_hex
# Sample matplotlib's "bwr_r" colormap at 256 evenly spaced points and convert
# each colour to a hex string, giving a list usable as a bokeh palette
# (consumed by the pairwise coevolution plot).
cmap = plt.colormaps["bwr_r"]
bwr_r = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]

def pssm_to_dataframe(pssm, esm_alphabet):
  """Convert a per-position score matrix into a long-format DataFrame.

  Args:
    pssm: 2-D array with one row per sequence position and one column per
      entry of ``esm_alphabet``.
    esm_alphabet: iterable of column labels (one per matrix column).

  Returns:
    DataFrame with columns ``Position`` (1-based row number),
    ``Amino Acid`` (column label) and ``Probability`` (matrix value),
    one row per matrix cell in row-major order.
  """
  n_positions = pssm.shape[0]
  # Naming both axes up front lets stack()/reset_index() emit the final
  # column names directly instead of renaming them afterwards.
  row_labels = pd.Index(np.arange(1, n_positions + 1), name='Position')
  col_labels = pd.Index(list(esm_alphabet), name='Amino Acid')
  stacked = pd.DataFrame(pssm, index=row_labels, columns=col_labels).stack()
  return stacked.rename('Probability').reset_index()

def contact_to_dataframe(con):
  """Convert a square position-by-position matrix into a long-format DataFrame.

  Args:
    con: 2-D array whose rows and columns both index sequence positions.

  Returns:
    DataFrame with columns ``i`` and ``j`` (1-based row / column numbers)
    and ``value`` (matrix entry), one row per matrix cell in row-major order.
  """
  n_positions = con.shape[0]
  one_based = np.arange(1, n_positions + 1)
  # Named axes become the 'i' / 'j' columns after stack() + reset_index().
  wide = pd.DataFrame(
    con,
    index=pd.Index(one_based, name='i'),
    columns=pd.Index(one_based, name='j'),
  )
  return wide.stack().rename('value').reset_index()

def pair_to_dataframe(pair,esm_alphabet):
  """Convert an alphabet-by-alphabet matrix into a long-format DataFrame.

  Args:
    pair: 2-D array-like (NumPy array or nested sequence) with one row and
      one column per entry of ``esm_alphabet``.
    esm_alphabet: iterable of labels used for both rows and columns.

  Returns:
    DataFrame with columns ``aa_i`` (row label), ``aa_j`` (column label)
    and ``value`` (matrix entry), one row per matrix cell in row-major order.
  """
  # The previous version computed an unused `pair.shape[0]`, whose only effect
  # was to reject inputs without a `.shape` attribute; it has been dropped.
  labels = list(esm_alphabet)
  df = pd.DataFrame(pair, index=labels, columns=labels)
  df = df.stack().reset_index()
  df.columns = ['aa_i', 'aa_j', 'value']
  return df

# Star import from the utils.py fetched during setup. Later cells call
# get_contacts(), which is not defined in this notebook and presumably comes
# from here -- verify against utils.py.
from utils import *
import tqdm.notebook

# Progress-bar layout passed as `bar_format` to the tqdm bars in get_logits()
# and get_categorical_jacobian().
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def load_model(model_name="esm2_t36_3B_UR50D"):
  """Pre-fetch the ESM2 checkpoint files with aria2c and load the model.

  Args:
    model_name: ESM2 checkpoint name, e.g. ``"esm2_t33_650M_UR50D"``.

  Returns:
    ``(model, alphabet)`` as returned by ``torch.hub.load``; the model is
    moved to ``DEVICE`` and put in eval mode.
  """
  ckpt_dir = "/root/.cache/torch/hub/checkpoints/"
  base_url = "https://dl.fbaipublicfiles.com/fair-esm"
  downloads = (
    (f"{model_name}.pt", f"{base_url}/models/{model_name}.pt"),
    (f"{model_name}-contact-regression.pt",
     f"{base_url}/regression/{model_name}-contact-regression.pt"),
  )
  for filename, url in downloads:
    local_path = os.path.join(ckpt_dir, filename)
    # The cache test must look at the local checkpoint path. The previous code
    # called os.path.isfile() on the https URL, which is never a local file,
    # so both downloads were re-run on every call.
    # A leftover "<file>.aria2" control file marks an interrupted download;
    # re-run aria2c in that case so it can finish the partial file.
    if not os.path.isfile(local_path) or os.path.isfile(local_path + ".aria2"):
      # Best-effort, as before: the exit status is not checked here.
      os.system(f"aria2c -q -x 16 -d {ckpt_dir} {url}")
  model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
  model = model.to(DEVICE)
  model = model.eval()
  return model, alphabet

def get_logits(seq, p=1):
  """Return the logits predicted at each position when that position is masked.

  For every residue index ``k`` a copy of the tokenised sequence is made with
  token ``k+1`` replaced by ``alphabet.mask_idx`` (the ``+1`` skips the token
  the batch converter puts in front of the sequence), and the model's logits
  at that masked position are recorded.

  Args:
    seq: sequence string given to the alphabet's batch converter.
    p: number of masked copies evaluated per forward pass. ``None`` evaluates
      all positions in a single batch.

  Returns:
    ``np.ndarray`` of shape ``(len(seq), 20)`` holding vocabulary columns
    4..23 of the logits (assumed to be the 20 amino acids, consistent with
    how ``esm_alphabet`` is built at module level).
  """
  x,ln = alphabet.get_batch_converter()([(None,seq)])[-1],len(seq)
  # max(..., 1) keeps range() below from getting a zero step for an empty seq.
  if p is None: p = max(ln, 1)
  logits = np.zeros((ln,20))
  with torch.no_grad():
    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      for n in range(0,ln,p):
        m = min(n+p,ln)
        rows = torch.arange(m-n)
        # Row r of the batch is the sequence with residue n+r masked.
        x_h = torch.tile(torch.clone(x),[m-n,1])
        x_h[rows, rows+n+1] = alphabet.mask_idx
        fx_h = model(x_h.to(DEVICE))["logits"][:,1:(ln+1),4:24]
        # From row r keep only the logits at its own masked position, and move
        # the whole chunk to the host in one transfer.
        rows = rows.to(fx_h.device)
        logits[n:m] = fx_h[rows, rows+n].cpu().numpy()
        # Advance by the actual chunk size; the last chunk may be shorter
        # than p, and update(p) would push the bar past its total.
        pbar.update(m-n)
  return logits

def get_categorical_jacobian(seq):
  """Compute the categorical Jacobian of the model's logits for ``seq``.

  Each position is substituted in turn with every one of the 20 tokens at
  vocabulary indices 4..23, and the resulting change in the output logits
  (same 20 columns, all sequence positions) is recorded.

  Args:
    seq: sequence string given to the alphabet's batch converter.

  Returns:
    ``np.ndarray`` of shape ``(L, 20, L, 20)`` with ``L = len(seq)``.
    Before centering and symmetrization, entry ``[i, a, j, b]`` is the change
    in the logit for token ``b`` at position ``j`` caused by setting position
    ``i`` to token ``a`` (i.e. output change per input change).
  """
  x,ln = alphabet.get_batch_converter()([("seq",seq)])[-1],len(seq)
  with torch.no_grad():
    # Logits for the sequence positions only (index 0 is skipped) and for
    # vocabulary columns 4..23, returned as a NumPy array.
    f = lambda x: model(x)["logits"][...,1:(ln+1),4:24].cpu().numpy()
    fx = f(x.to(DEVICE))[0]
    x = torch.tile(x,[20,1]).to(DEVICE)
    # Built once on the batch's device; previously re-created on the CPU and
    # copied across on every loop iteration.
    aa_tokens = torch.arange(4,24,device=x.device)
    fx_h = np.zeros((ln,20,ln,20))
    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      for n in range(ln): # for each position
        x_h = torch.clone(x)
        x_h[:,n+1] = aa_tokens # mutate to all 20 aa
        fx_h[n] = f(x_h)
        pbar.update(1)
    # note: direction here differs from manuscript
    # positive = good
    # negative = bad
    jac = fx_h-fx
  # center: subtract the mean along each of the four axes in turn
  for i in range(4): jac -= jac.mean(i,keepdims=True)
  # symmetrize: average with the array obtained by swapping (i, a) and (j, b)
  jac = (jac + jac.transpose(2,3,0,1))/2

  return jac

# Load the selected checkpoint; `model` and `alphabet` are read as globals by
# get_logits() and get_categorical_jacobian().
model, alphabet = load_model(model_name)
# Total number of tokens in the model vocabulary.
esm_alphabet_len = len(alphabet.all_toks)
# Characters of vocabulary tokens 4..23, used as axis labels in the plots.
# Assumes these are 20 single-character amino-acid tokens, matching the 4:24
# logit slice used in the functions above -- TODO confirm against the alphabet.
esm_alphabet = list("".join(alphabet.all_toks[4:24]))

In [None]:
#@markdown ##enter sequence

# Input sequence analysed by the cells below; it is passed unchanged to
# get_logits() and get_categorical_jacobian().
sequence = "MKAKELREKSVEELNTELLNLLREQFNLRMQAASGQLQQSHLLKQVRRDVARVKTLLNEKAGA" # @param {type:"string"}


In [None]:
#@markdown ##compute conservation

# Masked-position logits -> per-position probabilities (softmax over the
# last axis); the raw logits are also written to disk.
logits = get_logits(sequence)
np.savetxt("conservation_logits.txt",logits)
pssm = softmax(logits,-1)
df = pssm_to_dataframe(pssm, esm_alphabet)

# plot pssm
num_colors = 256  # You can adjust this number
# Build the palette from num_colors so that adjusting it actually has an
# effect (it used to be ignored in favour of a hard-coded 256).
# NOTE(review): bokeh's viridis() is expected to accept at most 256 -- confirm.
palette = viridis(num_colors)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
# NOTE(review): the x range is categorical with string factors "0"..str(L),
# while the Position column is numeric (1..L); the extra leading "0" factor
# appears to be what lines the rects up -- confirm before changing either side.
p = figure(title="CONSERVATION",
           x_range=[str(x) for x in range(len(sequence)+1)],
           y_range=list(esm_alphabet),
           width=900, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('Position', '@Position'), ('Amino Acid', '@{Amino Acid}'), ('Probability', '@Probability')])

# One unit-sized rect per (position, amino acid), coloured by probability in [0, 1].
r = p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
           fill_color=linear_cmap('Probability', palette, low=0, high=1),
           line_color=None)
p.xaxis.visible = False  # Hide the x-axis
show(p)

In [None]:
#@markdown ##compute coevolution

# Categorical Jacobian -> L x L matrix via get_contacts() (not defined in this
# notebook; presumably provided by `from utils import *`), saved to disk and
# reshaped to long format for plotting.
jac = get_categorical_jacobian(sequence)
contacts = get_contacts(jac)
np.savetxt("coevolution.txt",contacts)
df = contact_to_dataframe(contacts)

# plot contact map
num_colors = 256  # You can adjust this number
# Build the palette from num_colors so that adjusting it actually has an
# effect (it used to be ignored in favour of a hard-coded 256).
# NOTE(review): bokeh's viridis() is expected to accept at most 256 -- confirm.
palette = viridis(num_colors)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
# NOTE(review): both ranges are categorical with string factors "0"..str(L),
# while the i / j columns are numeric (1..L); the extra leading "0" factor
# appears to be what lines the rects up -- confirm before changing either side.
p = figure(title="COEVOLUTION",
           x_range=[str(x) for x in range(len(sequence)+1)],
           y_range=[str(x) for x in range(len(sequence)+1)],
           width=900, height=900,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

# One unit-sized rect per (i, j); the colour scale spans the observed value range.
r = p.rect(x="i", y="j", width=1, height=1, source=df,
           fill_color=linear_cmap('value', palette, low=df.value.min(), high=df.value.max()),
           line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the y-axis
show(p)

In [None]:
position_i = 10 # @param {type:"integer"}
position_j = 13 # @param {type:"integer"}
# Positions are 1-based. Reject anything outside 1..L up front: after the
# conversion to 0-based indices below, 0 or a negative position would
# otherwise wrap around and silently select a residue counted from the end.
for _label, _pos in (("position_i", position_i), ("position_j", position_j)):
  if not 1 <= _pos <= jac.shape[0]:
    raise ValueError(f"{_label}={_pos} is outside the valid range 1..{jac.shape[0]}")
i = position_i - 1
j = position_j - 1
# 20 x 20 slice of the Jacobian for this pair of positions, in long format.
df = pair_to_dataframe(jac[i,:,j,:], esm_alphabet)

# plot the amino-acid by amino-acid matrix for the selected pair
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title=f"coevolution between {position_i} {position_j}",
           x_range=list(esm_alphabet),
           y_range=list(esm_alphabet),
           width=400, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('aa_i', '@aa_i'), ('aa_j', '@aa_j'), ('value', '@value')])
p.xaxis.axis_label = f"{position_i}"
p.yaxis.axis_label = f"{position_j}"

# Diverging palette with a fixed colour range of -3..3.
r = p.rect(x="aa_i", y="aa_j", width=1, height=1, source=df,
           fill_color=linear_cmap('value', bwr_r, low=-3.0, high=3.0),
           line_color=None)
show(p)