<a href="https://colab.research.google.com/github/sokrypton/algosb_2021/blob/main/Intro_to_MRFs_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import libraries
import jax
import jax.numpy as jnp
import matplotlib.pylab as plt

# new library for iterations!
import itertools as it

## Statistical models

In [2]:
import tensorflow as tf
# toy multiple sequence alignment: one row per sequence, one column per position
msa = jnp.array([
    (0, 1),  # sequence 1
    (1, 0),  # sequence 2
    (0, 1),  # sequence 3
    (1, 0),  # sequence 4
])

# number of sequences (N) and alignment length (L)
N = msa.shape[0]
L = msa.shape[1]

# size of the alphabet (number of states a position can take)
A = 2

print("msa", msa, sep="\n")



msa
[[0 1]
 [1 0]
 [0 1]
 [1 0]]


## PSSM

In [11]:
def model(params, inputs):
  """Negative mean log-likelihood of an alignment under a PSSM.

  Each position is modelled independently: the logits of a position are
  softmax-normalised into a categorical distribution over states.

  Args:
    params: dict with "pssm_logits", an array of shape (positions, states).
    inputs: dict with "msa", an integer array of shape (sequences, positions)
      whose entries index into the state axis.

  Returns:
    (loss, aux) where loss is the scalar -mean(log P(sequence)) and aux is
    {"P": per-position probabilities, shape (sequences, positions),
     "pssm": normalised PSSM, shape (positions, states)}.
  """
  msa = inputs["msa"]
  pssm_logits = params["pssm_logits"]

  # take the number of positions from the parameters instead of relying on
  # a notebook-level global
  positions = jnp.arange(pssm_logits.shape[0])

  # normalize so values sum to one
  # pssm = jnp.exp(pssm_logits)/jnp.exp(pssm_logits).sum(-1,keepdims=True)
  pssm = jax.nn.softmax(pssm_logits)

  # probability of the observed state at every position of every sequence
  P = pssm[positions, msa]

  # log-probability, taken from log_softmax rather than log(softmax(...)) so
  # that very unlikely states give a large finite value instead of -inf
  log_pssm = jax.nn.log_softmax(pssm_logits)
  logP = log_pssm[positions, msa].sum(-1)

  # define loss function
  # note for optimization, loss is "minimized"
  # so to "maximize" the probability, we will "minimize" the -logP
  loss = -logP.mean()

  return loss, {"P":P, "pssm":pssm}

In [15]:
# the model reads the alignment from this dict
inputs = {
    "msa": msa,
}

# seeded Gumbel noise as the starting logits, one row per position
key = jax.random.PRNGKey(1)
params = {
    "pssm_logits": jax.random.gumbel(key, shape=(L, A)),
}

# evaluate once at the random starting point
loss, outputs = model(params, inputs)

# PSSM before any optimization step
print("PSSM BEFORE", outputs["pssm"], sep="\n")

# joint probability of the first sequence = product of its per-position terms
P = outputs["P"][0].prod()
print("P(0,1) BEFORE", P, sep="\n")

PSSM BEFORE
[[0.66715103 0.3328489 ]
 [0.864988   0.13501194]]
P(0,1) BEFORE
0.090073355


In [16]:
# get gradients = ∂loss/∂params
# value_and_grad returns the function value together with the gradient w.r.t.
# the first argument (params); has_aux=True because model returns a
# (loss, aux-dict) pair and only the loss is differentiated
grad_model = jax.value_and_grad(model, has_aux=True)

In [17]:
# one evaluation: loss and aux outputs, plus a gradient dict keyed like params
(loss,outputs), grad = grad_model(params, inputs)

In [18]:
# display the loss at the initial parameters
loss

DeviceArray(1.8261186, dtype=float32)

In [19]:
# display the gradient dict (same keys as params)
grad

{'pssm_logits': DeviceArray([[ 0.16715103, -0.16715111],
              [ 0.36498797, -0.36498806]], dtype=float32)}

In [20]:
# plain gradient descent: 100 steps with step size 0.1
# (the first step uses the gradient computed in the cell above)
for n in range(100):
  params["pssm_logits"] = params["pssm_logits"] - 0.1 * grad["pssm_logits"]
  # re-evaluate loss and gradient at the updated parameters
  (loss, outputs), grad = grad_model(params, inputs)
  # report every 10th step (1-based step counter)
  if n % 10 == 9:
    print(n + 1, loss)

10 1.5813975
20 1.4637878
30 1.4152719
40 1.3968499
50 1.3901005
60 1.3876617
70 1.3867848
80 1.3864703
90 1.3863573
100 1.386317


In [21]:
# PSSM from the last evaluation of the optimization loop
print("PSSM AFTER", outputs["pssm"], sep="\n")

# joint probability of the first sequence under the fitted PSSM
P = outputs["P"][0].prod()
print("P(0,1) AFTER", P, sep="\n")

PSSM AFTER
[[0.50105166 0.4989483 ]
 [0.5031935  0.49680653]]
P(0,1) AFTER
0.24892575


## Markov Random Field
Adding two-body term (W)

In [22]:
import itertools as it
# "Z" alignment of all possible sequences!
# enumerate every length-L sequence over A states (A**L rows) with plain
# Python ints; iterating over a jnp array would create one device scalar per
# element before the final array is built
msa_Z = jnp.array(list(it.product(range(A), repeat=L)))
print("msa_Z")
print(msa_Z)

msa_Z
[[0 0]
 [0 1]
 [1 0]
 [1 1]]


In [23]:
def model(params, inputs):
  """Negative mean log-likelihood of an alignment under a pairwise MRF.

  Note: this redefines the `model` name used for the PSSM earlier in the
  notebook.

  Args:
    params: dict with
      "V": one-body term, shape (positions, states)
      "W": two-body term, shape (positions, states, positions, states)
    inputs: dict with
      "msa":   observed sequences, integer array (sequences, positions)
      "msa_Z": the sequences summed over for the normalisation constant,
               integer array (n_enumerated, positions)

  Returns:
    (loss, aux) where loss is the scalar -mean(log-likelihood) and aux is
    {"P": probability of each observed sequence, shape (sequences,)}.
  """
  msa_H = inputs["msa"]
  msa_Z = inputs["msa_Z"]

  # 1-body-term
  V = params["V"]

  # sizes come from the parameters instead of notebook-level globals
  num_pos, num_states = V.shape

  # one-hot encode msa
  eye_states = jnp.eye(num_states)
  oh_msa_H = eye_states[msa_H]
  oh_msa_Z = eye_states[msa_Z]

  # 2-body-term
  W = params["W"]
  W = W * (1 - jnp.eye(num_pos)[:, None, :, None])  # zero same-position couplings
  W = 0.5 * (W + W.transpose((2, 3, 0, 1)))  # symmetrize

  def energy(one_hot):
    # unnormalised log-probability of each sequence: one-body + two-body terms
    return (one_hot * (V + jnp.tensordot(one_hot, W, 2))).sum((1, 2))

  H = energy(oh_msa_H)
  Z = energy(oh_msa_Z)

  # log of the normalisation constant; logsumexp avoids the overflow that
  # exp(Z).sum() hits once any energy is large
  log_Z = jax.nn.logsumexp(Z)

  # can also be thought as
  # L(θ|X) = Likelihood of parameters (θ) given MSA (X)
  # Log-Likelihood
  LL = H - log_Z

  # P(X|θ) = Probability of MSA (X) given parameters (θ)
  # computed as exp(H - log_Z), which equals exp(H)/exp(Z).sum()
  P = jnp.exp(LL)

  loss = -LL.mean()
  return loss, {"P":P}

In [24]:
# observed alignment plus the enumerated alignment used for normalisation
inputs = {
    "msa": msa,
    "msa_Z": msa_Z,
}

# start with all one-body and two-body terms at zero
params = {
    "V": jnp.zeros((L, A)),
    "W": jnp.zeros((L, A, L, A)),
}

# differentiate first, then compile the combined value+gradient function
value_and_grad_fn = jax.value_and_grad(model, has_aux=True)
grad_model = jax.jit(value_and_grad_fn)
(loss, outputs), grad = grad_model(params, inputs)

# probability of the first sequence before any optimization
P = outputs["P"][0]
print("P(0,1) BEFORE", P, sep="\n")

P(0,1) BEFORE
0.25


In [25]:
# instead of using GD, let's try ADAM
# the optimizers module lives in jax.example_libraries on newer JAX releases;
# fall back to the original jax.experimental location on older installs
try:
  from jax.example_libraries.optimizers import adam
except ImportError:
  from jax.experimental.optimizers import adam

In [26]:
# ADAM with step size 0.1: returns the init / update / read-out functions
init_fun, update_fun, get_params = adam(0.1)

# optimizer state wraps the starting parameters; the trained values live in
# `state` (read them with get_params), `params` itself is not overwritten
state = init_fun(params)
for n in range(1000):
  # evaluate at the current parameters, then take one ADAM step
  (loss, outputs), grad = grad_model(get_params(state), inputs)
  state = update_fun(n, grad, state)
  # report every 100th step (1-based step counter)
  if n % 100 == 99:
    print(n + 1, loss)

100 0.6942606
200 0.69378376
300 0.6935549
400 0.69343185
500 0.69335747
600 0.6933098
700 0.69327736
800 0.69325256
900 0.6932354
1000 0.69322205


In [27]:
# probability of the first sequence from the last evaluation of the ADAM loop
P = outputs["P"][0]
print("P(0,1) AFTER", P, sep="\n")

P(0,1) AFTER
0.49996245
