<a href="https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/categorical_jacobian/esm3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Get the Categorical Jacobian from ESM3**
## (a.k.a. extract conservation and coevolution for your favorite protein)

Notes:
- max length: 1200 (on free colab T4 GPU)

In [None]:
%set_env TOKENIZERS_PARALLELISM=false
#@markdown ##huggingface token
#@markdown In order to download the weights, esm requires users to accept the non-commercial license.
#@markdown
#@markdown - Go here to accept the license: https://huggingface.co/EvolutionaryScale/esm3
#@markdown
#@markdown - Go here to get token:
#@markdown https://huggingface.co/settings/tokens
# The setup cell asserts that `token` exists and is non-empty before loading the model.
# NOTE(review): a token typed into this form field is presumably saved with the
# notebook — clear it before sharing the notebook.
token = "" # @param {type:"string"}
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token entered above.
login(token=token, add_to_git_credential=True)

In [None]:
%%time
#@markdown ##setup model (3m 30s)
# Fail fast with a readable message if the token cell was skipped or left blank.
assert "token" in dir() and token != "", "please set token"
import os
import numpy as np
import torch

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import os
# One-time bootstrap: the presence of utils.py is the only marker checked, so
# both the helper download and the esm3 install are skipped once that file exists.
# The return codes of os.system are not inspected; a failed install surfaces as
# an ImportError on the esm imports below.
if not os.path.isfile("utils.py"):
  os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabBio/main/categorical_jacobian/utils.py")
  os.system("pip install git+https://github.com/sokrypton/esm3.git")


# These imports must come after the install step above.
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.constants.esm3 import SEQUENCE_VOCAB
from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    GenerationConfig,
)
# Load the "esm3_sm_open_v1" checkpoint onto DEVICE and put it in eval mode.
model =  ESM3.from_pretrained("esm3_sm_open_v1", device=torch.device(DEVICE)).eval()

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import softmax

import pandas as pd
import numpy as np
import bokeh.plotting
# Configure bokeh so that show() renders inside the notebook output.
bokeh.io.output_notebook()
from bokeh.models import BasicTicker, PrintfTickFormatter
from bokeh.palettes import viridis, RdBu
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show

from matplotlib.colors import to_hex
# Sample two matplotlib colormaps at 256 evenly spaced points and convert them
# to hex strings; the resulting lists are passed as palettes to bokeh's
# linear_cmap in the plotting cells (`gray` for contacts, `bwr_r` for pairs).
cmap = plt.colormaps["bwr_r"]
bwr_r = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]
cmap = plt.colormaps["gray_r"]
gray = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]

def pssm_to_dataframe(pssm, esm_alphabet):
  """Flatten a (positions x letters) matrix into a long-format DataFrame.

  Rows are labelled with 1-based position numbers rendered as strings and
  columns with the individual characters of `esm_alphabet`. The result has
  one row per matrix cell and the columns 'Position', 'Amino Acid' and
  'Probability'.
  """
  n_positions = pssm.shape[0]
  row_labels = list(map(str, range(1, n_positions + 1)))
  wide = pd.DataFrame(pssm, index=row_labels, columns=list(esm_alphabet))
  # stack() turns every (row, column) cell into its own record.
  long_form = wide.stack().reset_index()
  long_form.columns = ['Position', 'Amino Acid', 'Probability']
  return long_form

def contact_to_dataframe(con):
  """Flatten a square contact matrix into a long-format DataFrame.

  Both axes are labelled with 1-based position numbers rendered as strings
  (the row count of `con` is used for both). The result has one row per
  matrix cell with the columns 'i' (row label), 'j' (column label) and
  'value'.
  """
  n_positions = con.shape[0]
  labels = list(map(str, range(1, n_positions + 1)))
  wide = pd.DataFrame(con, index=labels, columns=labels)
  long_form = wide.stack().reset_index()
  long_form.columns = ['i', 'j', 'value']
  return long_form

def pair_to_dataframe(pair,esm_alphabet):
  """Flatten a (letters x letters) matrix into a long-format DataFrame.

  Args:
    pair: 2-D array-like with one row and one column per character of
      `esm_alphabet` (anything `pd.DataFrame` accepts, e.g. a numpy array or
      nested lists).
    esm_alphabet: iterable of labels used for both the rows and the columns.

  Returns:
    DataFrame with one row per matrix cell and the columns 'aa_i' (row
    label), 'aa_j' (column label) and 'value'.
  """
  # Both axes carry the same alphabet labels; no length lookup is needed.
  labels = list(esm_alphabet)
  df = pd.DataFrame(pair, index=labels, columns=labels)
  df = df.stack().reset_index()
  df.columns = ['aa_i', 'aa_j', 'value']
  return df

from utils import *
import tqdm.notebook

# Progress-bar layout shared by get_logits and get_categorical_jacobian:
# shows the bar, completed/total counts, and elapsed/remaining time.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def get_logits(seq, batch_size=1):
  """Masked-position logits for every residue of `seq`.

  Each position is replaced in turn by the "<mask>" token and the model's
  logits at that same position are recorded. `batch_size` masked copies of
  the sequence are evaluated per forward pass.

  Returns a float32 numpy array of shape (len(seq), 20); the 20 columns are
  logit channels 4..23, i.e. the order of SEQUENCE_VOCAB[4:24].
  """
  tokens = model.encode(ESMProtein(sequence=seq)).sequence
  length = len(seq)
  mask_id = SEQUENCE_VOCAB.index("<mask>")

  def predict(batch):
    # Residue k of `seq` sits at token index k+1 (presumably after a start
    # token — TODO confirm); keep only the residue positions and 20 channels.
    out = model(sequence_tokens=batch).sequence_logits
    return out[:, 1:(length+1), 4:24].to(torch.float).detach().cpu().numpy()

  result = np.zeros((length, 20), dtype=np.float32)
  with torch.no_grad():
    with tqdm.notebook.tqdm(total=length, bar_format=TQDM_BAR_FORMAT) as bar:
      for start in range(0, length, batch_size):
        stop = min(start + batch_size, length)
        count = stop - start
        # Row r of the batch is the full sequence with residue start+r masked.
        batch = torch.clone(tokens).unsqueeze(0).repeat(count, 1)
        for row, pos in enumerate(range(start, stop)):
          batch[row, pos + 1] = mask_id
        out = predict(batch.to(DEVICE))
        for row, pos in enumerate(range(start, stop)):
          result[pos] = out[row, pos]
        bar.update(count)
  return result

def get_categorical_jacobian(seq, batch_size=1):
  """Categorical Jacobian of the model's sequence logits for `seq`.

  For every input position n and every one of the 20 tokens a (vocabulary
  ids 4..23) the token at n is substituted and the model is re-evaluated.
  Entry [n, a, m, b] of the result is the logit of token b at output
  position m under that substitution, minus the same logit for the
  unmodified sequence. `batch_size` substitutions of one position are
  evaluated per forward pass.

  Returns a float32 numpy array of shape (len(seq), 20, len(seq), 20).
  """
  # ∂out/∂in: change of the output logits per categorical change of the input
  tokens = model.encode(ESMProtein(sequence=seq)).sequence
  length = len(seq)

  def predict(batch):
    # Residue k of `seq` sits at token index k+1 (presumably after a start
    # token — TODO confirm); keep only the residue positions and 20 channels.
    out = model(sequence_tokens=batch).sequence_logits
    return out[..., 1:(length+1), 4:24].to(torch.float).detach().cpu().numpy()

  with torch.no_grad():
    # Reference logits of the unperturbed sequence, shape (length, 20).
    baseline = predict(tokens[None].to(DEVICE))[0]
    perturbed = np.zeros([length, 20, length, 20], dtype=np.float32)

    with tqdm.notebook.tqdm(total=length, bar_format=TQDM_BAR_FORMAT) as bar:
      for pos in range(length):  # position being substituted
        for lo in range(0, 20, batch_size):
          hi = min(lo + batch_size, 20)
          batch = torch.clone(tokens).unsqueeze(0).repeat(hi - lo, 1).to(DEVICE)
          # Row r of the batch carries token id 4 + lo + r at this position.
          batch[:, pos+1] = torch.arange(4 + lo, 4 + hi)
          perturbed[pos, lo:hi] = predict(batch)
        bar.update(1)

  # `baseline` broadcasts over the two leading (input position, input token) axes.
  return perturbed - baseline

# Vocabulary entries 4..23 — the same 20 channels that get_logits and
# get_categorical_jacobian slice out of the model's logits.
esm_alphabet = SEQUENCE_VOCAB[4:24]
# Amino-acid order used for all saved outputs and plots.
ALPHABET = "AFILVMWYDEKRHNQSTGPC"
# ALPHABET_map[k] is the index within esm_alphabet of the letter ALPHABET[k];
# indexing an axis with it reorders that axis from model order to ALPHABET order.
ALPHABET_map = list(map(esm_alphabet.index, ALPHABET))

def jac_to_con(jac, center=True, diag="remove", apc=True, symm=None):
  """Post-process a categorical Jacobian into a position-by-position contact map.

  Args:
    jac: array of shape (Lx, Ax, Ly, Ay). Any axis of size 20 among Ax/Ay is
      reordered from model order to ALPHABET order via ALPHABET_map.
    center: if True, subtract the mean along every axis of size > 1, one
      axis after another.
    diag: "remove" zeroes the diagonal of the contact map; "normalize"
      divides entry (i, j) by sqrt(c_ii * c_jj); any other value leaves the
      diagonal untouched.
    apc: if True, apply the average product correction
      (row sums * column sums / total) to the contact map.
    symm: if True, symmetrize. When both Ax and Ay are 20 the tensor itself
      is averaged with its (2,3,0,1) transpose; otherwise the contact map is
      averaged with its transpose. Either way this requires Lx == Ly.
      If None (default), a notebook-level variable named `symm` is used when
      one exists, otherwise True. This keeps existing calls that relied on
      the global working without raising NameError when it is absent.

  Returns:
    dict with "jac" (the reordered / symmetrized / centered tensor) and
    "contacts" (an (Lx, Ly) array of Frobenius norms over the two alphabet
    axes, after the diag / apc steps). `jac` itself is not modified.
  """
  if symm is None:
    # Backward compatibility: the function used to read `symm` as a global.
    symm = globals().get("symm", True)

  X = jac.copy()
  Lx,Ax,Ly,Ay = X.shape
  if Ax == 20:
    X = X[:,ALPHABET_map,:,:]

  if Ay == 20:
    X = X[:,:,:,ALPHABET_map]
    if symm and Ax == 20:
      X = (X + X.transpose(2,3,0,1))/2

  if center:
    for axis in range(4):
      if X.shape[axis] > 1:
        X -= X.mean(axis,keepdims=True)

  # Collapse the two alphabet axes: Frobenius norm of each (Ax, Ay) block.
  contacts = np.sqrt(np.square(X).sum((1,3)))

  if symm and (Ax != 20 or Ay != 20):
    contacts = (contacts + contacts.T)/2

  if diag == "remove":
    np.fill_diagonal(contacts,0)

  if diag == "normalize":
    contacts_diag = np.diag(contacts)
    contacts = contacts / np.sqrt(contacts_diag[:,None] * contacts_diag[None,:])

  if apc:
    total = contacts.sum()
    # An all-zero map has nothing to correct; dividing by 0 would only yield NaN.
    if total != 0:
      ap = contacts.sum(0,keepdims=True) * contacts.sum(1, keepdims=True) / total
      contacts = contacts - ap

  # APC re-introduces non-zero diagonal entries, so clear them again.
  if diag == "remove":
    np.fill_diagonal(contacts,0)

  return {"jac":X, "contacts":contacts}

In [None]:
#@markdown ##enter sequence

sequence = "MKAKELREKSVEELNTELLNLLREQFNLRMQAASGQLQQSHLLKQVRRDVARVKTLLNEKAGA" # @param {type:"string"}

# Normalise the input: upper-case it, then keep alphabetic characters only
# (whitespace, digits, gap symbols etc. are dropped).
sequence = "".join(ch for ch in sequence.upper() if ch.isalpha())

# Longer sequences are processed in smaller batches
# (presumably to stay within GPU memory — TODO confirm the thresholds).
if len(sequence) > 800:
  BATCH_SIZE = 1
elif len(sequence) > 500:
  BATCH_SIZE = 2
elif len(sequence) > 300:
  BATCH_SIZE = 5
else:
  BATCH_SIZE = 20

# Describe the files that the following cells write into output/.
os.makedirs("output",exist_ok=True)
readme_lines = [
  "conservation_logits.txt = (L, A) matrix\n",
  "coevolution.txt = (L, L) matrix\n",
  "jac.npy = ((L*L-L)/2, A, A) tensor\n",
  "jac index can be recreated with np.triu_indices(L,1)\n",
  f"[A]lphabet: {ALPHABET}\n",
  f"sequence: {sequence}\n",
]
with open("output/README.txt","w") as handle:
  handle.writelines(readme_lines)

In [None]:
#@markdown ##compute conservation

# Masked-position logits, with columns reordered from model order to ALPHABET order.
logits = get_logits(sequence, batch_size=BATCH_SIZE)[:, ALPHABET_map]
np.savetxt("output/conservation_logits.txt", logits)
# Per-position softmax over the 20 amino acids.
pssm = softmax(logits, axis=-1)
df = pssm_to_dataframe(pssm, ALPHABET)

# plot pssm
num_colors = 256  # You can adjust this number
palette = viridis(num_colors)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
# Categorical axis labels: 1-based positions as strings, matching df['Position'].
position_labels = [str(pos) for pos in range(1, len(sequence) + 1)]
hover_fields = [('Position', '@Position'),
                ('Amino Acid', '@{Amino Acid}'),
                ('Probability', '@Probability')]
p = figure(title="CONSERVATION",
           x_range=position_labels,
           y_range=list(ALPHABET)[::-1],
           width=900, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=hover_fields)

# One unit-sized rectangle per (position, amino acid), coloured by probability in [0, 1].
r = p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
           fill_color=linear_cmap('Probability', palette, low=0, high=1),
           line_color=None)
p.xaxis.visible = False  # Hide the x-axis
show(p)

In [None]:
#@markdown ##compute coevolution
#@markdown Set postprocessing options such as to `center`, [`symm`]etrize, remove [`diag`]onal and to perform average product correction `apc`.
center = True # @param {type:"boolean"}
symm = True # @param {type:"boolean"}
diag = "remove" # @param ["remove", "normalize", "none"]
apc = True # @param {type:"boolean"}

# The Jacobian is expensive: recompute it only when none exists yet or the
# sequence differs from the one it was computed for. Changing the
# post-processing options above does not trigger a recompute.
settings = {"sequence": sequence}
if "jac" not in dir() or settings != settings_:
  jac = get_categorical_jacobian(sequence, batch_size=BATCH_SIZE)
  settings_ = settings.copy()

# `symm` is not passed here; jac_to_con picks it up from the notebook namespace.
con = jac_to_con(jac, center=center, diag=diag, apc=apc)

np.savetxt("output/coevolution.txt", con["contacts"])
# Save only the position pairs i < j, as a float16 array with one (A, A) block per pair.
i,j = np.triu_indices(len(sequence),1)
np.save("output/jac.npy", con["jac"][i,:,j,:].astype(np.float16))

df = contact_to_dataframe(con["contacts"])
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
# Categorical axis labels: 1-based positions as strings, matching df['i'] / df['j'].
position_labels = [str(pos) for pos in range(1, len(sequence) + 1)]
p = figure(title="COEVOLUTION",
          x_range=position_labels,
          y_range=position_labels[::-1],
          width=800, height=800,
          tools=TOOLS, toolbar_location='below',
          tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

# Grey scale spans the full range of observed contact values.
value_min = df["value"].min()
value_max = df["value"].max()
r = p.rect(x="i", y="j", width=1, height=1, source=df,
          fill_color=linear_cmap('value', gray, low=value_min, high=value_max),
          line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the y-axis
show(p)

In [None]:
#@markdown ##show table of top covarying positions
from google.colab import data_table

# Rebuild the contact table from `con` instead of reusing the global `df`:
# the pair-inspection cell overwrites `df` with a different schema, which
# would break this cell when it is re-run afterwards.
contacts_df = contact_to_dataframe(con["contacts"])

# 'i' and 'j' are string labels ("1", "2", ..., "10"). Compare them as
# integers so that exactly the upper triangle (j > i) is kept; a plain string
# comparison is lexicographic ("10" < "2") and mixes upper- and lower-triangle
# entries, which matters whenever the contact map is not symmetric.
is_upper = contacts_df["j"].astype(int) > contacts_df["i"].astype(int)
sub_df = contacts_df[is_upper].sort_values('value',ascending=False)
data_table.DataTable(sub_df, include_index=False, num_rows_per_page=20, min_width=10)

In [None]:
#@markdown ##select pair of residues to investigate
#@markdown Note: 1-indexed (first position is 1)

position_i = 15 # @param {type:"integer"}
position_j = 57 # @param {type:"integer"}

# Reject out-of-range positions up front: 0 or negative values would otherwise
# wrap around silently to the end of the sequence via negative indexing, and
# too-large values would fail later with an unhelpful IndexError.
if not (1 <= position_i <= len(sequence) and 1 <= position_j <= len(sequence)):
  raise ValueError(
    f"position_i={position_i} and position_j={position_j} must both be "
    f"between 1 and {len(sequence)} (the sequence length)")

# Convert the 1-based form values to 0-based array indices.
i = position_i - 1
j = position_j - 1
# (A, A) block of the processed Jacobian: rows = amino acid placed at i,
# columns = amino acid scored at j.
df = pair_to_dataframe(con["jac"][i,:,j,:], ALPHABET)

# plot the pairwise amino-acid coupling matrix
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title=f"coevolution between {position_i} {position_j}",
          x_range=list(ALPHABET),
          y_range=list(ALPHABET)[::-1],
          width=400, height=400,
          tools=TOOLS, toolbar_location='below',
          tooltips=[('aa_i', '@aa_i'), ('aa_j', '@aa_j'), ('value', '@value')])
# Axis labels show the wild-type residue and its 1-based position.
p.xaxis.axis_label = f"{sequence[i]}{position_i}"
p.yaxis.axis_label = f"{sequence[j]}{position_j}"

# Fixed symmetric colour range so that colours are comparable between pairs.
r = p.rect(x="aa_i", y="aa_j", width=1, height=1, source=df,
            fill_color=linear_cmap('value', bwr_r, low=-3.0, high=3.0),
            line_color=None, dilate=True)
show(p)


In [None]:
#@title download results (optional)
from google.colab import files

# Bundle everything under output/ into a single archive, then hand the
# archive to the browser for download.
archive_name = "output.zip"
os.system("zip -r output.zip output/")
files.download(archive_name)