<a href="https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/categorical_jacobian/gLM2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Categorical Jacobian on gLM2

In [None]:
#@markdown ## setup gLM2_650M
import os
# Install flash_attn with pip, skipping its dependencies. The exit status of
# the command is not checked, so a failed install is not reported here.
os.system("pip -q install --no-dependencies flash_attn")

# Identifier used to load both the model and its tokenizer.
MODEL_NAME = "tattabio/gLM2_650M"
from transformers import AutoTokenizer, AutoModelForMaskedLM
DEVICE = "cuda"
# trust_remote_code=True lets the model repository's own Python code run.
# The model is put in eval mode and moved to DEVICE.
MODEL = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True).eval().to(DEVICE)
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Id of the '<mask>' token; substituted into the input in the `fast` mode of
# get_categorical_jacobian.
MASK_ID = TOKENIZER.convert_tokens_to_ids('<mask>')

import torch
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from string import ascii_uppercase, ascii_lowercase
# Single-letter labels (A-Z, then a-z) that plot_ticks assigns to the
# sequences in order.
alphabet_list = list(ascii_uppercase+ascii_lowercase)


# Progress-bar layout passed as `bar_format` to the tqdm progress bar.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def get_categorical_jacobian(seqs, prepend_seq="<+>", fast=False):
  """Compute the categorical Jacobian (∂out/∂in) of the model logits.

  Each non-empty sequence is tokenized on its own with `prepend_seq` in front
  of it, and the token ids are concatenated into one model input. Every
  residue position is then perturbed in turn and the change in the logits
  (columns 4..23 of the vocabulary) at all residue positions is recorded.

  NOTE(review): the residue mask assumes every character of a sequence becomes
  exactly one token and that any extra tokens produced by the tokenizer come
  before the residues -- confirm against the tokenizer.

  Args:
    seqs: iterable of sequence strings; empty strings are skipped.
    prepend_seq: string placed in front of every sequence before tokenization.
    fast: if True, each position is replaced by the mask token (one row per
      forward pass); otherwise it is replaced by each of the 20 token ids
      4..23 in a single batch of 20.

  Returns:
    float32 array of shape [ln, 20, ln, 20] (or [ln, 1, ln, 20] when `fast`),
    where ln is the total number of residues. Entry [i, a, j, b] is the change
    in logit b at residue j when residue i is perturbed, relative to the
    unperturbed input.

  Raises:
    ValueError: if `seqs` contains no non-empty sequence.
  """
  token_ids = []
  masks = []
  for seq in seqs:
    if len(seq) > 0:
      ids = TOKENIZER([prepend_seq+seq])["input_ids"][0]
      # True for the trailing len(seq) tokens (the residues), False for the prefix.
      masks.append(np.pad(np.full(len(seq),True),[len(ids)-len(seq),0]))
      token_ids.append(ids)

  if not token_ids:
    raise ValueError("get_categorical_jacobian needs at least one non-empty sequence")

  # Imported here so the submodule is guaranteed to be loaded; a plain
  # `import tqdm` does not import `tqdm.notebook` by itself.
  from tqdm.notebook import tqdm as notebook_tqdm

  x = torch.tensor(np.concatenate(token_ids,-1)[None]).to(DEVICE)
  mask = np.concatenate(masks,-1)
  ln = int(mask.sum())

  def logits_of(tokens):
    # Keep only vocabulary columns 4..23 and move the result to numpy.
    return MODEL(tokens).logits[:, :, 4:24].detach().cpu().numpy()

  with torch.no_grad(), torch.autocast(device_type="cuda", enabled=True):
    # Reference logits of the unperturbed input, residue positions only.
    fx = logits_of(x)[0][mask]

    if fast:
      fx_h = np.zeros([ln, 1 , ln, 20], dtype=np.float32)
      substitute = MASK_ID
    else:
      fx_h = np.zeros([ln, 20, ln, 20], dtype=np.float32)
      # One copy of the input per substituted token id.
      x = torch.tile(x, [20, 1])
      substitute = torch.arange(4, 24, device=x.device)

    with notebook_tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      # Visit residue positions only; prefix tokens are never perturbed.
      for i, n in enumerate(np.flatnonzero(mask).tolist()):
        x_h = torch.clone(x)
        x_h[:, n] = substitute
        fx_h[i] = logits_of(x_h)[:,mask]
        pbar.update(1)

  return fx_h - fx

def J_to_contact_map(J):
  """Turn a 4-D categorical Jacobian into position-by-position coupling maps.

  Args:
    J: array of shape [L, A, L, B]; it is not modified.

  Returns:
    (raw, apc): two [L, L] arrays. `raw` is the L2 norm over axes 1 and 3 of
    the centered Jacobian with the diagonal zeroed. `apc` is `raw` after
    average-product correction, with the diagonal zeroed and symmetrized.
  """
  J_centered = np.array(J)  # work on a copy
  # center: remove the mean along every axis that has more than one entry
  for axis in range(4):
    if J_centered.shape[axis] > 1:
      J_centered -= J_centered.mean(axis,keepdims=True)

  # l2norm
  raw = np.sqrt(np.square(J_centered).sum((1,3)))
  np.fill_diagonal(raw, 0)

  # apc
  total = raw.sum()
  if total > 0:
    apc = raw - (raw.sum(0,keepdims=True) * raw.sum(1,keepdims=True)) / total
  else:
    # Single residue or all-zero signal: there is no background to remove and
    # dividing by the zero total would only produce NaNs.
    apc = raw.copy()
  np.fill_diagonal(apc, 0)

  # symm
  apc = (apc + apc.T)/2

  return raw, apc

def plot_ticks(Ls, axes=None):
  """Draw sequence boundaries and per-sequence y tick labels on a square map.

  Args:
    Ls: lengths of the concatenated sequences, in order.
    axes: matplotlib axes to draw on; defaults to the current axes.

  A black horizontal and vertical line is drawn at every boundary between
  consecutive sequences, and the y axis gets one tick at the middle of each
  sequence, labelled with successive entries of `alphabet_list`.
  """
  if axes is None: axes = plt.gca()
  # Cumulative offsets: edges[k] is where sequence k starts, edges[-1] the total.
  edges = np.cumsum([0] + list(Ls))
  total = edges[-1]
  # Interior edges only; draw on `axes` so a caller-supplied axes is honoured.
  for edge in edges[1:-1]:
    axes.plot([0,total],[edge,edge],color="black")
    axes.plot([edge,edge],[0,total],color="black")
  centers = (edges[1:] + edges[:-1])/2
  axes.set_yticks(centers)
  axes.set_yticklabels(alphabet_list[:len(centers)])

In [None]:
#@markdown ## input sequence(s)

# Up to six input sequences; fields left empty are skipped when the inputs are
# collected further down in this cell.
seq_A = "MRILPISTIKGKLNEFVDAVSSTQDQITITKNGAPAAVLVGADEWESLQETLYWLAQPGIRESIAEADADIASGRTYGEDEIRAEFGVPRRPHDYKDDDDK" # @param {type:"string"}
seq_B = "PYTVRFTTTARRDLHKLPPRILAAVVEFAFGDLSREPLRVGKPLRRELAGTFSARRGTYRLLYRIDDEHTTVVILRVDHRADIYRR" # @param {type:"string"}
seq_C = "" # @param {type:"string"}
seq_D = "" # @param {type:"string"}
seq_E = "" # @param {type:"string"}
seq_F = "" # @param {type:"string"}
#@markdown settings
#@markdown ---
# String placed in front of every sequence before tokenization (passed to
# get_categorical_jacobian as `prepend_seq`).
seperator = "<->" # @param ["<+>","<->",""]
# Forwarded to get_categorical_jacobian as `fast`.
fast = False # @param {type:"boolean"}
#@markdown - `fast`=`True` - only perturb the `mask` token


# Collect the non-empty inputs. Each one is upper-cased and stripped of every
# non-letter character; its length is recorded *after* cleaning so that the
# tick and divider positions agree with the rows of the contact map.
seqs = []
Ls = []
for seq in [seq_A,seq_B,seq_C,seq_D,seq_E,seq_F]:
  seq = ''.join([i for i in seq.upper() if i.isalpha()])
  if len(seq) > 0:
    seqs.append(seq)
    Ls.append(len(seq))

if not seqs:
  raise ValueError("Please provide at least one non-empty sequence.")

J = get_categorical_jacobian(seqs,
                             prepend_seq=seperator,
                             fast=fast)

# Contact maps: `raw` norms and the APC-corrected, symmetrized `apc`.
raw, apc = J_to_contact_map(J)
L = apc.shape[0]
plt.figure(figsize=(5,5),dpi=200)
# extent puts residue 0 in the top-left corner and spans 0..L on both axes,
# matching the coordinates used by plot_ticks.
plt.imshow(apc,cmap="Blues", interpolation='none',
           extent=(0, L, L, 0))
plot_ticks(Ls)
plt.show()