<a href="https://colab.research.google.com/github/sokrypton/ws2023/blob/main/day2/notebook_part2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import libraries
import jax
import jax.numpy as jnp
import matplotlib.pylab as plt

# new library for iterators (used below to enumerate all possible sequences)!
import itertools as it

## Statistical models

In [None]:
# toy msa: 4 sequences of length 2, states encoded as integers 0..A-1
msa = jnp.array([
    [0,1], # sequence 1
    [1,0], # sequence 2
    [0,1], # sequence 3
    [1,0], # sequence 4
])

# define size (N = number of sequences, L = sequence length)
N, L = msa.shape

# define number of states
A = 2

print("msa")
print(msa)

# one hot encode MSA: integer (N, L) -> one-hot (N, L, A)
msa = jnp.eye(A)[msa]

## PSSM

In [None]:
def model(params, inputs):
  """Negative mean log-likelihood of an MSA under an independent-site PSSM.

  Args:
    params: dict with "pssm_logits", unnormalized scores of shape (L, A).
    inputs: dict with "msa", a one-hot encoded alignment of shape (N, L, A).
      Rows are assumed to be one-hot; an all-zero row contributes nothing
      to the log-likelihood.

  Returns:
    (loss, aux): loss = -mean over sequences of sum_i log P(x_i), and aux
    with "P" (probability of the observed state at each position, (N, L))
    and "pssm" (per-position probabilities summing to one, (L, A)).
  """
  msa = inputs["msa"]
  pssm_logits = params["pssm_logits"]

  # normalize so values at each position sum to one
  pssm = jax.nn.softmax(pssm_logits)

  # probability of the observed state at each position (one-hot selects it)
  P = (msa * pssm).sum(-1)

  # log-probability of each sequence. log_softmax is used instead of
  # log(softmax): if a softmax entry underflows to 0, log() gives -inf and the
  # loss/gradients become inf/nan, whereas log_softmax stays finite.
  logP = (msa * jax.nn.log_softmax(pssm_logits)).sum((-2, -1))

  # define loss function
  # note for optimization, loss is "minimized"
  # so to "maximize" the probability, we will "minimize" the -logP
  loss = -logP.mean()

  return loss, {"P": P, "pssm": pssm}

In [None]:
inputs = {"msa":msa}

# random initialization of the (L, A) logits from a Gumbel distribution
key = jax.random.PRNGKey(1)
params = {"pssm_logits":jax.random.gumbel(key, shape=(L,A))}

loss, outputs = model(params, inputs)

# let's look at the PSSM BEFORE optimization
print("PSSM BEFORE")
print(outputs["pssm"])

# probability of the first sequence, (0,1): product of its per-position probabilities
print("P(0,1) BEFORE")
P = outputs["P"][0].prod()
print(P)

In [None]:
# get gradients = ∂loss/∂params
# value_and_grad(..., has_aux=True) returns ((loss, aux_outputs), grads),
# differentiating only the first output (loss) w.r.t. the first argument (params)
grad_model = jax.value_and_grad(model, has_aux=True)

In [None]:
# evaluate loss, auxiliary outputs and the gradient dict at the current params
(loss,outputs), grad = grad_model(params, inputs)

In [None]:
# display the current loss value
loss

In [None]:
# display the gradients: a dict keyed like params ("pssm_logits")
grad

In [None]:
# plain gradient descent: step size 0.1 for 100 steps
for n in range(100):
  params["pssm_logits"] -= 0.1 * grad["pssm_logits"]
  # re-evaluate loss, outputs and gradients at the updated parameters
  (loss,outputs), grad = grad_model(params, inputs)
  if (n+1) % 10 == 0:
    print(n+1,loss)

In [None]:
# PSSM after optimization
print("PSSM AFTER")
print(outputs["pssm"])

# probability of the first sequence, (0,1), after optimization
print("P(0,1) AFTER")
P = outputs["P"][0].prod()
print(P)

## Markov Random Field
Adding a two-body (pairwise) term (W)

In [None]:
import itertools as it
# "Z" alignment of all possible sequences!
# A**L rows: every combination of the A states at each of the L positions.
# Plain Python ints (range) avoid creating a device scalar per element.
msa_Z = jnp.array(list(it.product(range(A), repeat=L)))
print("msa_Z")
print(msa_Z)

# one hot encode
msa_Z = jnp.eye(A)[msa_Z]

In [None]:
def model(params, inputs):
  """Negative mean log-likelihood of an MSA under a pairwise Markov random field.

  Args:
    params: dict with "V" (one-body term, (L, A)) and "W" (two-body term,
      (L, A, L, A)).
    inputs: dict with "msa" (one-hot alignment to score, (N, L, A)) and
      "msa_Z" (one-hot enumeration of all possible sequences, used to compute
      the partition function exactly).

  Returns:
    (loss, aux): loss = -mean log P(X|θ) over the sequences in "msa", and
    aux["P"] the probability of each of those sequences.
  """
  msa_H = inputs["msa"]
  msa_Z = inputs["msa_Z"]

  # 1-body-term
  V = params["V"]

  # 2-body-term; the length is read from W itself instead of a notebook
  # global, so the function is self-contained
  W = params["W"]
  L = W.shape[0]
  W = W * (1 - jnp.eye(L)[:, None, :, None])  # set diagonal (i == j) to zero
  W = 0.5 * (W + W.transpose((2, 3, 0, 1)))    # symmetrize

  def energy(x):
    # sum_i x_i·V_i + sum_ij x_i·W_ij·x_j for each one-hot sequence in x
    return (x * (V + jnp.tensordot(x, W, 2))).sum((1, 2))

  H = energy(msa_H)
  Z = energy(msa_Z)

  # log of the partition function, computed stably: exp(Z).sum() overflows
  # (inf/inf -> nan) once energies grow large
  logZ = jax.nn.logsumexp(Z)

  # L(θ|X) = Likelihood of parameters (θ) given MSA (X)
  # Log-Likelihood
  LL = H - logZ

  # P(X|θ) = Probability of MSA (X) given parameters (θ)
  P = jnp.exp(LL)

  loss = -LL.mean()
  return loss, {"P": P}

In [None]:
inputs = {"msa":msa,"msa_Z":msa_Z}

# zero initialization: every sequence gets energy 0, so all A**L sequences
# start equally likely
params = {"V":jnp.zeros((L,A)),
          "W":jnp.zeros((L,A,L,A))}

# jit-compiled function returning ((loss, outputs), gradients)
grad_model = jax.jit(jax.value_and_grad(model, has_aux=True))
(loss, outputs), grad = grad_model(params, inputs)


# probability of the first sequence, (0,1), before optimization
print("P(0,1) BEFORE")
P = outputs["P"][0]
print(P)

In [None]:
# instead of using GD, lets try ADAM
from jax.example_libraries.optimizers import adam

In [None]:
# initialize optimizer
init_fun, update_fun, get_params = adam(1e-1)

# initialize state
state = init_fun(params)
for n in range(100):
  (loss,outputs), grad = grad_model(get_params(state), inputs)
  state = update_fun(n, grad, state)
  if (n+1) % 10 == 0:
    print(n+1,loss)

In [None]:
# probability of the first sequence, (0,1), after optimization
print("P(0,1) AFTER")
P = outputs["P"][0]
print(P)

# Coevolution Approximation (DCA - Inverse Covariance)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from matplotlib import colors
# discrete colormaps for plotting integer alignments: 2 states (gold, blue)
# and 3 states (gold, blue, green)
YlBu = colors.ListedColormap(['gold','blue'])
YlBuGr = colors.ListedColormap(['gold','blue','green'])

import jax
import jax.numpy as jnp

@jax.jit
def inv_cov(Y, lam=None):
  """Inverse of the ridge-shrunk covariance of an alignment, per position/state.

  Args:
    Y: alignment array of shape (N, L, A) (one-hot in this notebook).
    lam: value added to the covariance diagonal before inversion; when None
      it defaults to 4.5 / sqrt(N).

  Returns:
    Array of shape (L, A, L, A) holding inv(cov + lam * I).
  """
  arr = jnp.asarray(Y)
  n_seq, n_pos, n_states = arr.shape
  # one row per sequence, one column per (position, state) pair;
  # transposed so that each variable is a row for jnp.cov
  features = arr.reshape(n_seq, -1)
  cov = jnp.cov(features.T)
  ridge = 4.5 / jnp.sqrt(n_seq) if lam is None else lam
  dim = cov.shape[0]
  precision = jnp.linalg.inv(cov + ridge * jnp.eye(dim))
  return precision.reshape(n_pos, n_states, n_pos, n_states)

def get_mtx(raw, apc=False, exclude_last=False, return_all=False):
  """Collapse an (L, A, L, A) coupling tensor into an (L, L) coupling-strength map.

  Args:
    raw: coupling tensor of shape (L, A, L, A), e.g. the output of inv_cov.
    apc: subtract the average product correction (row sum * column sum / total).
    exclude_last: drop the last state from both state axes before taking
      norms (with mk_msa's alphabet this is the gap '-').
    return_all: also return the (possibly sliced) tensor.

  Returns:
    The (L, L) matrix with zeroed diagonal, or {"W": raw, "mtx": mtx} when
    return_all is True.
  """
  raw = np.asarray(raw)
  if exclude_last:
    raw = raw[:, :-1, :, :-1]
  # Frobenius norm of each A x A block of couplings between positions i and j
  mtx = np.sqrt(np.square(raw).sum((1, 3)))
  np.fill_diagonal(mtx, 0)
  if apc:
    # apc (average product correction)
    total = mtx.sum()
    # with all-zero couplings the correction is 0/0 (nan); the background is
    # zero in that case, so the matrix is left unchanged
    if total != 0:
      ap = mtx.sum(0, keepdims=True) * mtx.sum(1, keepdims=True) / total
      mtx = mtx - ap
      np.fill_diagonal(mtx, 0)
  if return_all:
    return {"W": raw, "mtx": mtx}
  return mtx

def get_coevo(X, reg=None, apc=False, exclude_last=False):
  '''
  ---------------------------------
  input:
  ---------------------------------
   X = multiple sequence alignment, either integer-encoded with shape (N, L)
       or already one-hot encoded with shape (N, L, A); lists are accepted
   reg = shrinkage passed to inv_cov as lam (None -> 4.5/sqrt(N))
   apc = apply average product correction (see get_mtx)
   exclude_last = drop the last state before taking norms (see get_mtx)
  ---------------------------------
  output:
  ---------------------------------
   (L, L) coevolution matrix: norm of each pairwise block of the
   regularized inverse covariance, with the diagonal set to zero
  ---------------------------------
  '''
  X = np.asarray(X)
  # integer-coded alignments are one-hot encoded with A = largest state + 1
  Y = np.eye(X.max() + 1)[X] if X.ndim == 2 else X
  return get_mtx(inv_cov(Y, lam=reg), apc=apc, exclude_last=exclude_last)

In [None]:
# toy binary alignment: 5 sequences (rows) x 6 positions (columns)
example_seqs = np.array([
  [0, 1, 1, 0, 0, 1],
  [1, 1, 0, 1, 1, 1],
  [0, 0, 1, 1, 1, 1],
  [0, 0, 1, 0, 0, 0],
  [1, 1, 0, 0, 0, 1]
])
plt.imshow(example_seqs, cmap=YlBu)
plt.xlabel("features (positions)")
plt.ylabel("samples (sequences)")
plt.show()

In [None]:
# reg=0.001 is a much weaker ridge than get_coevo's default (4.5/sqrt(N))
example_coevo = get_coevo(example_seqs, reg=0.001)
plt.imshow(example_coevo, cmap="Blues")

## EXERCISE 1
Try changing the sequences (or making up new ones) to see whether you can change the co-evolution pattern!

What happens if you expand the alphabet (e.g. add a third state, 2)?

In [None]:
# exercise alignment: 6 sequences (rows) x 8 positions (columns); edit me!
seqs = np.array([
  [0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0],
])
plt.imshow(seqs, cmap=YlBu)
plt.xlabel("features (positions)")
plt.ylabel("samples (sequences)")
plt.show()

In [None]:
# coevolution map for the exercise alignment (same weak ridge as the example)
coevo = get_coevo(seqs, reg=0.001)
plt.imshow(coevo, cmap="Blues")

# REAL DATA

In [None]:
#@title helper functions <- RUN THIS
import string
def parse_fasta(filename, a3m=False):
  '''Parse a FASTA (or A3M) file into parallel lists of headers and sequences.

  Blank lines and lines starting with "#" are skipped, and a sequence may
  span several lines. Without a3m, residues are upper-cased; with a3m,
  lowercase letters are deleted instead.

  Returns:
    (header, sequence): header lines without the leading ">" and the
    concatenated sequences, in file order.
  '''
  if a3m:
    # for a3m files the lowercase letters are removed
    # as these do not align to the query sequence
    rm_lc = str.maketrans(dict.fromkeys(string.ascii_lowercase))

  header, sequence = [], []
  # context manager guarantees the file is closed even if parsing raises
  with open(filename, "r") as lines:
    for line in lines:
      line = line.rstrip()
      if not line or line[0] == "#":
        continue
      if line[0] == ">":
        header.append(line[1:])
        sequence.append([])
      else:
        line = line.translate(rm_lc) if a3m else line.upper()
        sequence[-1].append(line)
  sequence = ["".join(seq) for seq in sequence]
  return header, sequence

def mk_msa(seqs):
  '''one hot encode msa

  Args:
    seqs: equal-length sequence strings.

  Returns:
    Float array of shape (N, L, 21), one-hot over "ARNDCQEGHILKMFPSTWYV-".
    Any character outside that alphabet (e.g. "X", "B", lowercase) is
    mapped to the last state, "-".
  '''
  alphabet = list("ARNDCQEGHILKMFPSTWYV-")
  states = len(alphabet)

  alpha = np.array(alphabet, dtype='|S1').view(np.uint8)
  msa = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)

  # byte -> state lookup table covering all 256 byte values, so no unknown
  # byte can leak through as a raw index; unknowns default to the last state
  lut = np.full(256, states - 1, dtype=np.int64)
  lut[alpha] = np.arange(states)
  return np.eye(states)[lut[msa]]

In [None]:
# NOTE(review): assumes example_filt.fasta exists in the working directory
nams, seqs = parse_fasta("example_filt.fasta")
msa = mk_msa(seqs)

In [None]:
# msa is already one-hot (3-D), so get_coevo uses it as-is; exclude_last drops
# the final alphabet state ("-") and apc applies the average product correction
coevo = get_coevo(msa, apc=True, exclude_last=True)

In [None]:
# color scale clipped to [0, 0.3] so the strongest couplings stand out
plt.imshow(coevo,vmin=0,vmax=0.3, cmap="Blues")