<a href="https://colab.research.google.com/github/sokrypton/AM216/blob/main/gLM2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Categorical Jacobian on gLM2

In [None]:
#@markdown ## setup gLM2_650M
# Setup cell: install flash_attn, load the gLM2 masked language model and its
# tokenizer onto DEVICE, and define constants shared by the helpers below.
import os
import torch
# Best-effort install: os.system returns the exit status and does not raise if
# pip fails. --no-dependencies stops pip from pulling in dependencies
# (presumably to leave the preinstalled torch untouched — confirm).
os.system("pip -q install --no-dependencies flash_attn")

# Hugging Face model id; first CUDA device if available, else CPU.
MODEL_NAME = "tattabio/gLM2_650M"
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from transformers import AutoTokenizer, AutoModelForMaskedLM
# trust_remote_code=True executes model/tokenizer code shipped with the hub repo.
# .eval() puts the model in inference mode before moving it to DEVICE.
MODEL = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True).eval().to(DEVICE)
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Token id of '<mask>'; substituted at each position in fast mode of
# get_categorical_jacobian.
MASK_ID = TOKENIZER.convert_tokens_to_ids('<mask>')

import torch  # duplicate of the import above; harmless
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from string import ascii_uppercase, ascii_lowercase
# Chain labels 'A'..'Z', 'a'..'z' used for the y-tick labels in plot_ticks.
alphabet_list = list(ascii_uppercase+ascii_lowercase)


# Progress-bar layout for the tqdm bar in get_categorical_jacobian.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def get_categorical_jacobian(seqs, prepend_seq="<+>", fast=False):
  """Compute the categorical Jacobian (∂out/∂in) of MODEL's amino-acid logits.

  Each non-empty sequence is tokenized as ``prepend_seq + seq`` and all token
  streams are concatenated into one input. For every residue position, the
  input is perturbed at that position and the change in the 20 amino-acid
  logits (token ids 4..23) at every residue position is recorded.

  Args:
    seqs: iterable of sequence strings; empty strings are skipped.
    prepend_seq: text prepended to each sequence before tokenization
      (e.g. the "<+>" / "<->" separator).
    fast: if True, replace each position with the '<mask>' token only
      (one forward pass per position); otherwise substitute each of the 20
      amino-acid tokens (a batch of 20 rows per position).

  Returns:
    float32 array of shape (L, A, L, 20): perturbed logits minus unperturbed
    logits, where L is the total number of residues and A is 1 when ``fast``
    else 20.

  Raises:
    ValueError: if every entry of ``seqs`` is empty.
  """
  xs = []
  masks = []
  for seq in seqs:
    if len(seq) > 0:
      x = TOKENIZER([prepend_seq + seq])["input_ids"][0]
      # Only the trailing len(seq) tokens are treated as residues; the leading
      # (prefix / special) tokens are masked out.
      # NOTE(review): assumes one token per residue and no trailing special
      # token from the tokenizer — confirm for gLM2.
      mask = np.pad(np.full(len(seq), True), [len(x) - len(seq), 0])
      xs.append(x)
      masks.append(mask)

  if not xs:
    raise ValueError("get_categorical_jacobian: all input sequences are empty")

  # All chains are concatenated into a single batch-of-one input.
  x = torch.tensor(np.concatenate(xs, -1)[None]).to(DEVICE)
  mask = np.concatenate(masks, -1)

  # Mixed precision only when running on CUDA (torch.cuda.amp.autocast is
  # deprecated and warns when CUDA is unavailable).
  amp = torch.autocast(device_type="cuda", enabled=DEVICE.type == "cuda")
  with torch.no_grad(), amp:
    # Logits of the 20 amino-acid tokens (ids 4..23) as a float32 numpy array;
    # .float() keeps the numpy conversion valid for half-precision outputs.
    f = lambda t: MODEL(t).logits[:, :, 4:24].float().cpu().numpy()
    fx = f(x)[0][mask]  # unperturbed logits at residue positions: (ln, 20)

    ln = int(mask.sum())
    if fast:
      fx_h = np.zeros([ln, 1, ln, 20], dtype=np.float32)
    else:
      fx_h = np.zeros([ln, 20, ln, 20], dtype=np.float32)
      # One batch row per substituted amino acid.
      x = torch.tile(x, [20, 1])
    aa_ids = torch.arange(4, 24, device=x.device)

    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      # Only residue positions are perturbed; skip prefix/special tokens.
      for i, n in enumerate(np.flatnonzero(mask)):
        x_h = torch.clone(x)
        if fast:
          x_h[:, int(n)] = MASK_ID
        else:
          x_h[:, int(n)] = aa_ids
        fx_h[i] = f(x_h)[:, mask]
        pbar.update(1)

  return fx_h - fx

def jac_to_con(jac, symm=True, center=True, diag="remove", apc=True):
  """Reduce a categorical Jacobian to a residue-residue contact map.

  Args:
    jac: 4-D array (Lx, Ax, Ly, Ay). Any amino-acid axis of size 20 is
      reordered with the module-level ALPHABET_map.
    symm: symmetrize. When both amino-acid axes have size 20 the whole tensor
      is symmetrized (i,a,j,b) <-> (j,b,i,a); otherwise only the final map is.
    center: subtract the mean along every axis of size > 1, in axis order.
    diag: "remove" zeroes the diagonal (before and after APC); "normalize"
      divides entry (i, j) by sqrt(c_ii * c_jj).
    apc: apply the average product correction.

  Returns:
    dict with "jac" (processed tensor) and "contacts" (the (Lx, Ly) map of
    Frobenius norms over the two amino-acid axes).
  """
  _, n_a_in, _, n_a_out = jac.shape
  X = jac.copy()

  if n_a_in == 20:
    X = X[:, ALPHABET_map, :, :]
  if n_a_out == 20:
    X = X[:, :, :, ALPHABET_map]
    if symm and n_a_in == 20:
      X = (X + X.transpose(2, 3, 0, 1)) / 2

  if center:
    for axis, size in enumerate(X.shape):
      if size > 1:
        X -= X.mean(axis, keepdims=True)

  contacts = np.sqrt(np.square(X).sum(axis=(1, 3)))

  # The tensor was not symmetrized above, so symmetrize the map instead.
  if symm and not (n_a_in == 20 and n_a_out == 20):
    contacts = (contacts + contacts.T) / 2

  if diag == "remove":
    np.fill_diagonal(contacts, 0)
  elif diag == "normalize":
    d = np.diag(contacts)
    contacts = contacts / np.sqrt(d[:, None] * d[None, :])

  if apc:
    row_sum = contacts.sum(0, keepdims=True)
    col_sum = contacts.sum(1, keepdims=True)
    contacts = contacts - row_sum * col_sum / contacts.sum()

  # APC reintroduces diagonal values; clear them again.
  if diag == "remove":
    np.fill_diagonal(contacts, 0)

  return {"jac": X, "contacts": contacts}

def plot_ticks(Ls, axes=None):
  """Draw chain-boundary lines and label each chain on the y-axis.

  Args:
    Ls: per-chain lengths, in the order the chains were concatenated.
    axes: matplotlib Axes to draw on; defaults to the current axes. Both the
      boundary lines and the tick labels go to this Axes.
  """
  if axes is None:
    axes = plt.gca()
  Ln = sum(Ls)
  # Cumulative chain boundaries: 0, L_A, L_A+L_B, ..., Ln.
  bounds = np.cumsum([0] + list(Ls))
  # Separator lines between consecutive chains (outer edges are skipped).
  for b in bounds[1:-1]:
    axes.plot([0, Ln], [b, b], color="black")
    axes.plot([b, b], [0, Ln], color="black")
  # One tick per chain at the middle of its span, labelled A, B, C, ...
  ticks = (bounds[1:] + bounds[:-1]) / 2
  axes.set_yticks(ticks)
  axes.set_yticklabels(alphabet_list[:len(ticks)])


####################################
####################################

from scipy.special import softmax
import pandas as pd
import bokeh.plotting
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show
from bokeh.palettes import viridis
from transformers import AutoTokenizer, AutoModelForMaskedLM
from matplotlib.colors import to_hex
import tqdm.notebook
bokeh.io.output_notebook()

def pssm_to_dataframe(pssm, esm_alphabet):
  """Flatten an (L, A) matrix into long format (one row per position/letter).

  Columns: 'Position' (1-based, as a string), 'Amino Acid', 'Probability';
  rows are ordered position-major.
  """
  positions = [str(k) for k in range(1, pssm.shape[0] + 1)]
  wide = pd.DataFrame(pssm, index=positions, columns=list(esm_alphabet))
  stacked = wide.rename_axis(index="Position", columns="Amino Acid").stack()
  return stacked.reset_index(name="Probability")

def contact_to_dataframe(con):
  """Flatten an (L, L) contact matrix into long format for bokeh.

  Columns: 'i', 'j' (1-based positions, as strings) and 'value'; rows are
  ordered by i, then j.
  """
  labels = [str(k) for k in range(1, con.shape[0] + 1)]
  wide = pd.DataFrame(con, index=labels, columns=labels)
  stacked = wide.rename_axis(index="i", columns="j").stack()
  return stacked.reset_index(name="value")

def pair_to_dataframe(pair, esm_alphabet):
  """Flatten an (A, A) amino-acid coupling block into long format.

  Columns: 'aa_i' (row letter), 'aa_j' (column letter), 'value'.
  """
  letters = list(esm_alphabet)
  wide = pd.DataFrame(pair, index=letters, columns=letters)
  stacked = wide.rename_axis(index="aa_i", columns="aa_j").stack()
  return stacked.reset_index(name="value")

# Bokeh palettes: 256 evenly spaced samples of matplotlib colormaps, as hex.
cmap = plt.colormaps["bwr_r"]
bwr_r = list(map(to_hex, map(cmap, np.linspace(0, 1, 256))))
cmap = plt.colormaps["gray_r"]
gray = list(map(to_hex, map(cmap, np.linspace(0, 1, 256))))

# Tokenizer vocabulary for ids 4..24 (21 tokens; the logits sliced in
# get_categorical_jacobian are ids 4..23) and the display order of the 20
# amino acids.
esm_alphabet = TOKENIZER.convert_ids_to_tokens(range(4, 25))
ALPHABET = "AFILVMWYDEKRHNQSTGPC"
# ALPHABET_map[k] is the index of ALPHABET[k] within esm_alphabet; used to
# reorder logit axes into ALPHABET order.
ALPHABET_map = list(map(esm_alphabet.index, ALPHABET))

In [None]:
#@markdown # **RUN**
#@markdown ---
#@markdown ## settings
seperator = "<+>" # @param ["<+>","<->",""]
fast = False # @param {type:"boolean"}
#@markdown - only perturb the `mask` token
#@markdown ---
#@markdown ## input sequence(s)
seq_A = "MRILPISTIKGKLNEFVDAVSSTQDQITITKNGAPAAVLVGADEWESLQETLYWLAQPGIRESIAEADADIASGRTYGEDEIRAEFGVPRRPHDYKDDDDK" # @param {type:"string"}
seq_B = "PYTVRFTTTARRDLHKLPPRILAAVVEFAFGDLSREPLRVGKPLRRELAGTFSARRGTYRLLYRIDDEHTTVVILRVDHRADIYRR" # @param {type:"string"}
seq_C = "" # @param {type:"string"}
seq_D = "" # @param {type:"string"}
seq_E = "" # @param {type:"string"}
seq_F = "" # @param {type:"string"}

# Clean each input first: keep letters only (drops spaces, digits, punctuation,
# line breaks) and upper-case them. Lengths are recorded AFTER cleaning so the
# chain boundaries drawn by plot_ticks match the residues fed to the model.
seqs = []
Ls = []
for seq in [seq_A, seq_B, seq_C, seq_D, seq_E, seq_F]:
  seq = "".join(c for c in seq.upper() if c.isalpha())
  if seq:
    seqs.append(seq)
    Ls.append(len(seq))

if not seqs:
  raise ValueError("Enter at least one non-empty sequence (seq_A ... seq_F).")

# Categorical Jacobian -> contact map.
jac = get_categorical_jacobian(seqs,
                               prepend_seq=seperator,
                               fast=fast)

con = jac_to_con(jac)

# Contact map with chain-boundary lines and per-chain labels.
L = con["contacts"].shape[0]
plt.figure(figsize=(5, 5), dpi=200)
plt.imshow(con["contacts"], cmap="Blues", interpolation='none',
           extent=(0, L, L, 0))
plot_ticks(Ls)
plt.show()

In [None]:
#@markdown ## Interactive Coevolution Plot (Optional)
sequence = "".join(seqs)
model_name = os.path.basename(MODEL_NAME)

os.makedirs("output", exist_ok=True)

# (L, L) contact map.
np.savetxt(f"output/coevolution_{model_name}.txt", con["contacts"])

# Amino-acid coupling block for every residue pair i < j, stored as float16 to
# keep the file small; row k corresponds to the pair (i[k], j[k]).
i, j = np.triu_indices(len(sequence), 1)
jac_pairs = con["jac"][i, :, j, :].astype(np.float16)
np.save(f"output/jac_{model_name}.npy", jac_pairs)

# README describing the files actually written above.
with open("output/README.txt", "w") as handle:
  handle.write(f"coevolution_{model_name}.txt = (L, L) contact matrix\n")
  handle.write(f"jac_{model_name}.npy = ((L*L-L)/2, A, A) tensor, shape {jac_pairs.shape}\n")
  handle.write("jac index can be recreated with np.triu_indices(L,1)\n")
  handle.write(f"[A]lphabet: {ALPHABET}\n")
  handle.write(f"sequence: {sequence}\n")

df = contact_to_dataframe(con["contacts"])
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="COEVOLUTION",
          x_range=[str(x) for x in range(1,len(sequence)+1)],
          y_range=[str(x) for x in range(1,len(sequence)+1)][::-1],
          width=800, height=800,
          tools=TOOLS, toolbar_location='below',
          tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

r = p.rect(x="i", y="j", width=1, height=1, source=df,
          fill_color=linear_cmap('value', gray, low=df.value.min(), high=df.value.max()),
          line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the y-axis
show(p)

In [None]:
#@markdown ## show table of top covarying positions (optional)
from google.colab import data_table

df = contact_to_dataframe(con["contacts"])
# contact_to_dataframe returns 1-based positions as *strings* (for bokeh's
# categorical axes). Compare them numerically so the upper triangle (i < j) is
# selected correctly for multi-digit positions ("10" < "9" as strings).
sub_df = df[df["j"].astype(int) > df["i"].astype(int)].sort_values('value', ascending=False)
data_table.DataTable(sub_df, include_index=False, num_rows_per_page=20, min_width=10)

In [None]:
#@markdown ## select pair of residues to investigate (optional)
#@markdown Note: 1-indexed (first position is 1)

position_i = 152 # @param {type:"integer"}
position_j = 68 # @param {type:"integer"}

# Rebuilt here so this cell does not depend on the optional export cell.
sequence = "".join(seqs)
n_res = len(sequence)

if fast:
  # fast=True records only the '<mask>' perturbation, so there is no 20x20
  # amino-acid coupling block to display.
  print("this function is only supported when `fast=False`")
elif not (1 <= position_i <= n_res and 1 <= position_j <= n_res):
  # Reject 0/negative/too-large positions instead of letting negative
  # indexing silently pick residues from the end.
  print(f"positions must be between 1 and {n_res}")
else:
  i = position_i - 1
  j = position_j - 1
  df = pair_to_dataframe(con["jac"][i,:,j,:], ALPHABET)

  # plot the 20x20 coupling block between the two positions
  TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
  p = figure(title=f"coevolution between {position_i} {position_j}",
            x_range=list(ALPHABET),
            y_range=list(ALPHABET)[::-1],
            width=400, height=400,
            tools=TOOLS, toolbar_location='below',
            tooltips=[('aa_i', '@aa_i'), ('aa_j', '@aa_j'), ('value', '@value')])
  p.xaxis.axis_label = f"{sequence[i]}{position_i}"
  p.yaxis.axis_label = f"{sequence[j]}{position_j}"

  r = p.rect(x="aa_i", y="aa_j", width=1, height=1, source=df,
              fill_color=linear_cmap('value', bwr_r, low=-3.0, high=3.0),
              line_color=None, dilate=True)
  show(p)

In [None]:
#@title download results (optional)
import shutil
from google.colab import files
# Archive output/ (entries stored as output/...) without relying on an
# external `zip` binary. make_archive raises on failure instead of silently
# continuing to a failed download, and it overwrites any stale output.zip.
shutil.make_archive("output", "zip", root_dir=".", base_dir="output")
files.download("output.zip")