In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from tqdm import tqdm
import os
from lib.models.forward_model import UniformForward
from flax import linen as nn
from lib.utils import utils

In [4]:
# Assemble one mask along the key axis from three pieces: a single column of
# ones, a left-to-right mask and a right-to-left mask.
seq_len = 10
idx = jnp.arange(seq_len, dtype=jnp.int32)

# greater_equal(query, key): row i is 1 for keys j <= i (lower triangle).
att_l2r_mask = nn.attention.make_attention_mask(idx, idx, jnp.greater_equal)
print(att_l2r_mask.shape, att_l2r_mask)

# less_equal(query, key): row i is 1 for keys j >= i (upper triangle).
att_r2l_mask = nn.attention.make_attention_mask(idx, idx, jnp.less_equal)

# One extra all-ones column per query row, placed ahead of the two masks.
att_t = jnp.ones((1, seq_len, 1))
joint_mask = jnp.concatenate((att_t, att_l2r_mask, att_r2l_mask), axis=-1)
print(joint_mask.shape, joint_mask)

# Add a leading axis of size 1.
joint_mask = joint_mask[None, ...]
print(joint_mask.shape, joint_mask)

(1, 10, 10) [[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
(1, 10, 21) [[[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1

In [5]:
# Positions 0..9 with a leading axis of size 1. Only a cell's last expression
# is displayed, so this single expression is all the cell needs.
np.expand_dims(jnp.arange(10, dtype=jnp.int32), 0)

array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=int32)

In [3]:
# Toy sizes shared by the cells below.
rate_const = 1
S = 4    # passed to UniformForward and used as the exclusive upper bound of xt
B = 256  # leading (batch) dimension of t and xt
D = 4    # second dimension of xt
uni = UniformForward(S, rate_const)

# Separate keys for the time draw and the state draw.
rng = jax.random.PRNGKey(1008)
t_rng, sample_rng = jax.random.split(rng)

t = jax.random.uniform(t_rng, (B,))
print(t)
xt = jax.random.randint(sample_rng, (B, D), 0, S, dtype=jnp.int32)

# Transition matrices at time t and the same matrices with the last two axes swapped.
qt0 = uni.transition(t)
qt0_y2x = jnp.swapaxes(qt0, 1, 2)
print(qt0, qt0.shape)
print(qt0_y2x, qt0_y2x.shape)
print(qt0 == qt0_y2x)

# Batch indices with trailing singleton axes so they broadcast against xt.
b = jnp.arange(xt.shape[0]).reshape((-1,) + (1,) * (xt.ndim - 1))


[0.16477692 0.6628156  0.5799111  0.91668844 0.38052666 0.22999239
 0.744532   0.703349   0.68598914 0.95723104 0.7563106  0.33415103
 0.9805318  0.9204687  0.7621286  0.70780146 0.7601353  0.11362052
 0.18291497 0.4065224  0.6326307  0.30837262 0.38905704 0.8518803
 0.8026649  0.47867346 0.66204286 0.3761475  0.3353703  0.63315034
 0.64062595 0.5985664  0.2490884  0.22583091 0.13790691 0.7293955
 0.68318963 0.00314939 0.35922694 0.4672613  0.5603125  0.37063622
 0.79032135 0.26515067 0.0705471  0.37446618 0.96311915 0.22508585
 0.1628896  0.4007709  0.5419551  0.24647427 0.28652573 0.38016844
 0.00712883 0.27092147 0.1842525  0.12204552 0.67230463 0.6193998
 0.09896588 0.66243434 0.68301713 0.4518553  0.55643034 0.5431613
 0.95571244 0.949721   0.42050564 0.06308067 0.95509624 0.07910037
 0.9015987  0.32073045 0.11115897 0.09862614 0.7475505  0.08981788
 0.00303638 0.4186268  0.48477638 0.48392737 0.43903816 0.24950516
 0.39431524 0.24700534 0.1004256  0.58200276 0.773296   0.66488135

In [8]:
# Prepend a conditioning token to the sequence (dropping the last position so
# the length is unchanged) and build the matching left-to-right attention mask.
x = jax.random.randint(sample_rng, shape=(B, 1024, 10), minval=0, maxval=S, dtype=jnp.int32)
temb = jax.random.uniform(t_rng, (B, 10))
print(temb.shape)
temb = jnp.expand_dims(temb, axis=1)  # add a length-1 sequence axis
print(temb.shape)
conditioner = temb
print(conditioner.shape)
cond_dim = conditioner.shape[1]
print("cond_dim", cond_dim)
concat_dim = x.shape[1] + cond_dim - 1
print("conc dim", concat_dim)
# Debug override: shrink the mask to 5x5 so it can be read when printed.
# NOTE: with this override the mask no longer matches the sequence length of x.
concat_dim = 5
pos_idx = jnp.expand_dims(jnp.arange(concat_dim, dtype=jnp.int32), 0)
print("pos", pos_idx.shape)
x = jnp.concatenate([conditioner, x[:, :-1]], axis=1)
print("conditioner", conditioner.shape)
print("x", x.shape)
mask = nn.attention.make_attention_mask(pos_idx, pos_idx,
                                        jnp.greater_equal)
# The conditioning positions can always see each other.
mask = mask.at[:, :, :cond_dim, :cond_dim].set(1.0)
print(mask, mask.shape)

(256, 10)
(256, 1, 10)
(256, 1, 10)
cond_dim 1
conc dim 1024
pos (1, 5)
conditioner (256, 1, 10)
x (256, 1024, 10)
[[[[1. 0. 0. 0. 0.]
   [1. 1. 0. 0. 0.]
   [1. 1. 1. 0. 0.]
   [1. 1. 1. 1. 0.]
   [1. 1. 1. 1. 1.]]]] (1, 1, 5, 5)


In [9]:
# Right-to-left variant: drop the first position, append the conditioner at the
# end, and flip the mask comparison so each query sees keys at or after it.
x = jnp.concatenate((x[:, 1:], conditioner), axis=1)
mask = nn.attention.make_attention_mask(pos_idx, pos_idx, jnp.less_equal)
# The trailing conditioning positions can always see each other.
mask = mask.at[:, :, -cond_dim:, -cond_dim:].set(1.0)
print(mask)

[[[[1. 1. 1. 1. 1.]
   [0. 1. 1. 1. 1.]
   [0. 0. 1. 1. 1.]
   [0. 0. 0. 1. 1.]
   [0. 0. 0. 0. 1.]]]]


In [10]:
# Transition matrices for one small step back in time: 0 -> t_eps and t_eps -> t.
t = jax.random.uniform(t_rng, (B,))
xt = jax.random.randint(sample_rng, shape=(B, D), minval=0, maxval=S, dtype=jnp.int32)
# t is drawn from [0, 1), so t - 0.01 is negative whenever t < 0.01. Clamp at 0
# so the forward process is never queried at a negative time.
# NOTE(review): assumes UniformForward is only meaningful for t >= 0 — confirm.
t_eps = jnp.maximum(t - 0.01, 0.0)
q_teps_0 = uni.transition(t_eps)
print(q_teps_0, q_teps_0.shape)
# Insert singleton axes after the batch axis, one per non-batch axis of xt.
q_teps_0 = jnp.expand_dims(q_teps_0, axis=tuple(range(1, xt.ndim)))
print(q_teps_0, q_teps_0.shape)
q_t_teps = uni.transit_between(t_eps, t)
print(q_t_teps, q_t_teps.shape)
# Swap the last two axes of every per-sample matrix.
q_t_teps = jnp.transpose(q_t_teps, (0, 2, 1))
print(q_t_teps, q_t_teps.shape)



[[[0.27067754 0.2431075  0.2431075  0.2431075 ]
  [0.2431075  0.2706775  0.2431075  0.2431075 ]
  [0.2431075  0.2431075  0.27067754 0.2431075 ]
  [0.2431075  0.2431075  0.2431075  0.2706775 ]]

 [[0.27584603 0.24138466 0.24138464 0.24138466]
  [0.24138466 0.27584606 0.24138466 0.24138466]
  [0.24138464 0.24138466 0.27584603 0.24138466]
  [0.24138466 0.24138466 0.24138466 0.27584606]]] (2, 4, 4)
[[[[0.27067754 0.2431075  0.2431075  0.2431075 ]
   [0.2431075  0.2706775  0.2431075  0.2431075 ]
   [0.2431075  0.2431075  0.27067754 0.2431075 ]
   [0.2431075  0.2431075  0.2431075  0.2706775 ]]]


 [[[0.27584603 0.24138466 0.24138464 0.24138466]
   [0.24138466 0.27584606 0.24138466 0.24138466]
   [0.24138464 0.24138466 0.27584603 0.24138466]
   [0.24138466 0.24138466 0.24138466 0.27584606]]]] (2, 1, 4, 4)
[[[0.97059214 0.00980262 0.00980264 0.00980261]
  [0.00980263 0.9705922  0.00980261 0.00980261]
  [0.00980264 0.00980259 0.97059214 0.00980263]
  [0.00980264 0.00980258 0.00980263 0.97059214

In [11]:
# Batch indices with trailing singleton axes, so (b, xt) selects one row of the
# per-sample matrix for every entry of xt.
b = jnp.arange(xt.shape[0]).reshape((-1,) + (1,) * (xt.ndim - 1))
print(b, b.shape)
# Gather the rows, then insert a singleton axis just before the last axis.
q_t_teps = q_t_teps[b, xt][..., None, :]
print(q_t_teps, q_t_teps.shape)


[[0]
 [1]] (2, 1)
[[[[0.97059214 0.00980263 0.00980264 0.00980264]]

  [[0.00980264 0.00980261 0.97059214 0.00980263]]

  [[0.00980262 0.9705922  0.00980259 0.00980258]]

  [[0.97059214 0.00980263 0.00980264 0.00980264]]]


 [[[0.00980262 0.9705922  0.00980259 0.00980258]]

  [[0.00980264 0.00980261 0.97059214 0.00980262]]

  [[0.00980264 0.00980261 0.97059214 0.00980262]]

  [[0.97059214 0.00980263 0.00980264 0.00980264]]]] (2, 4, 1, 4)


In [None]:
# Combine log-softmaxed logits with the log transition matrix in log space.
# qt0_y2x is reused as stand-in logits here.
logits = qt0_y2x
log_p0t = nn.log_softmax(logits, axis=-1)
print(log_p0t, log_p0t.shape)

# Guard the log: entries at or below 1e-35 are replaced by a large negative value.
log_qt0 = jnp.where(qt0 <= 1e-35, -1e9, jnp.log(qt0))
print(log_qt0, log_qt0.shape)

# Singleton axes after the batch axis, one per non-batch axis of xt.
log_qt0 = jnp.expand_dims(log_qt0, axis=tuple(range(1, xt.ndim)))
print(log_qt0, log_qt0.shape)

# Trailing singleton axis so the two terms broadcast against each other.
log_p0t = log_p0t[..., None]
print(log_p0t, log_p0t.shape)

# Reduce over the second-to-last axis in log space.
log_prob = jax.nn.logsumexp(log_p0t + log_qt0, axis=-2)
print(log_prob, log_prob.shape)

In [None]:
# Same quantity computed in probability space: softmax, matmul with the
# transition matrix, then log.
qt0 = uni.transition(t)
xt_onehot = jax.nn.one_hot(xt, S)
print(xt_onehot.shape)

p0t = jax.nn.softmax(logits, axis=-1)
print(p0t, p0t.shape)

# Singleton axes after the batch axis; for a 2-D xt the axis tuple is empty
# and qt0 is left unchanged.
qt0 = jnp.expand_dims(qt0, axis=tuple(range(1, xt.ndim - 1)))
print(qt0, qt0.shape)

prob_all = jnp.matmul(p0t, qt0)
print(prob_all.shape)

# Small offset keeps the log finite where the probability is zero.
log_prob = jnp.log(prob_all + 1e-35)
print(log_prob, log_prob.shape)

# Pick out the entry at xt via the one-hot mask.
log_xt = (log_prob * xt_onehot).sum(axis=-1)
print(log_xt, log_xt.shape)

In [None]:
# Sample noisy states xt from the transition rows selected by the clean states.
#
# The clean states are rebuilt here from the same key and shape the earlier
# cells use for `xt`, instead of reading `xt` and then overwriting it. That
# keeps the cell idempotent: re-running it no longer feeds its own output back in.
x0 = jax.random.randint(sample_rng, shape=(B, D), minval=0, maxval=S, dtype=jnp.int32)
qt = uni.transition(t)
# Batch indices with trailing singleton axes so (b, x0) selects one row per entry.
b = jnp.expand_dims(jnp.arange(x0.shape[0]), tuple(range(1, x0.ndim)))
qt0 = qt[b, x0]
print(qt0, qt0.shape)
# Guard the log: non-positive probabilities map to a large negative logit.
logits = jnp.where(qt0 <= 0.0, -1e9, jnp.log(qt0))
print(logits, logits.shape)
# Derive a fresh key for this draw; sample_rng already produced the clean
# states, and reusing a JAX key for a second draw gives correlated samples.
noise_rng = jax.random.fold_in(sample_rng, 1)
xt = jax.random.categorical(noise_rng, logits)
print(xt, xt.shape)

In [None]:
# Per-position loss built from the log-likelihoods of all states (ll_all) and
# of the observed state (ll_xt).
ll_all = logits  # B, D, S
# ll_xt must be a log-likelihood, not the integer state index itself: gather
# the entry of ll_all at xt along the last axis.
ll_xt = jnp.take_along_axis(ll_all, xt[..., None], axis=-1)[..., 0]  # B, D
# NOTE(review): utils.log1mexp presumably computes log(1 - exp(x)) and so
# expects non-positive inputs — confirm against lib.utils.
loss = -(
    (S - 1) * ll_xt
    + jnp.sum(utils.log1mexp(ll_all), axis=-1)
    - utils.log1mexp(ll_xt)
)
print(loss, loss.shape)

In [None]:
# Backward term: exp(ll_all - ll_xt) weighted by the transposed transition
# rows selected at xt.
ll_all = logits  # B, D, S
# ll_xt must be a log-likelihood, not the integer state index itself: gather
# the entry of ll_all at xt along the last axis.
ll_xt = jnp.take_along_axis(ll_all, xt[..., None], axis=-1)[..., 0]  # B, D
xt_onehot = jax.nn.one_hot(xt, S)
# Batch indices with trailing singleton axes so (b, xt) selects one row per entry.
b = jnp.expand_dims(jnp.arange(xt.shape[0]), tuple(range(1, xt.ndim)))
print(b, b.shape)
qt0_x2y = uni.transition(t)
print(qt0_x2y, qt0_x2y.shape)
qt0_y2x = jnp.transpose(qt0_x2y, (0, 2, 1))
# Show the array that was just computed (qt0_y2x, not qt0_x2y).
print(qt0_y2x, qt0_y2x.shape)
qt0_y2x = qt0_y2x[b, xt]
print(qt0_y2x, qt0_y2x.shape)
# Trailing singleton axis so ll_xt broadcasts against ll_all.
ll_xt = jnp.expand_dims(ll_xt, axis=-1)
print("ll", ll_xt, ll_xt.shape)
backwd = jnp.exp(ll_all - ll_xt) * qt0_y2x
print(backwd , backwd.shape)


In [None]:
# Both terms sum over the last axis with the entry at xt itself masked out.
first_term = (backwd * (1 - xt_onehot)).sum(axis=-1)
print(first_term , first_term.shape)

# Forward transition rows selected at xt (overwrites qt0_x2y).
qt0_x2y = qt0_x2y[b, xt]
print(qt0_x2y, qt0_x2y.shape)

fwd = (ll_xt - ll_all) * qt0_x2y
print(fwd, fwd.shape)

second_term = (fwd * (1 - xt_onehot)).sum(axis=-1)
print(second_term, second_term.shape)

loss = first_term - second_term
print(loss, loss.shape)

In [None]:
# Apply a per-sample weight (all ones here), then reduce to a scalar.
weight = jnp.ones((B,))
# Trailing singleton axes so the weight broadcasts over the non-batch axes of loss.
weight = jnp.expand_dims(weight, axis=tuple(range(1, loss.ndim)))
print(weight, weight.shape)

loss = loss * weight
print(loss, loss.shape)

# Sum over everything and divide by the leading dimension of xt.
loss = loss.sum() / xt.shape[0]
print(loss, loss.shape)

In [None]:
"""
The main bottleneck is the design of the conditional marginal parameterization, which requires non-trivial trade-offs between computational cost 
and flexibility of the architectures; score matching for general categorical discrete variables does not benefit from prior knowledge about ordinal 
discrete data; and finally unifying score matching between continuous and discrete spaces would be needed to handle data in mixed spaces.
"""