In [17]:
import time

import numpy as np
from numpy.testing import assert_allclose

import jax
import jax.numpy as jnp

In [2]:
# Dimensions of the synthetic corpus.
vocab_size = 100  # number of distinct token ids
word_size = 40  # tokens per document
doc_size = 100  # number of documents
topic_size = 10  # number of topics

# Seed NumPy's global RNG so a fresh-kernel run reproduces the same corpus.
np.random.seed(0)

# Every document is a row of `word_size` token ids drawn from [0, vocab_size).
data_tokens = np.random.randint(0, vocab_size, size=(doc_size, word_size))

# Bag-of-words counts. np.add.at accumulates repeated (doc, token) pairs,
# which a plain fancy-indexed `+=` would count only once.
data_bow = np.zeros((doc_size, vocab_size))
np.add.at(data_bow, (np.arange(doc_size)[:, None], data_tokens), 1)
# Convert counts to per-document frequencies (each row sums to 1).
data_bow /= data_bow.sum(axis=1, keepdims=True)

data_bow.shape, data_tokens.shape

((100, 100), (100, 40))

In [3]:
def norm(x: jax.Array) -> jax.Array:
    """Clip ``x`` at zero and rescale it to sum to 1 along axis 0.

    Negative entries are replaced by 0. Every slice along axis 0 (each column
    of a matrix, or the whole vector for 1-D input) is then divided by its
    sum. Slices whose sum does not exceed 1e-12 come back as all zeros.

    Args:
        x: array to normalize.

    Returns:
        Non-negative array with the same shape as ``x``.
    """
    # take x+ = max(x, 0) element-wise
    x = jnp.maximum(x, jnp.zeros_like(x))
    total = x.sum(axis=0)
    has_mass = total > 1e-12
    # Divide by 1 wherever the sum vanishes. The forward result is unchanged
    # (those entries are masked out below), but no 0/0 is ever evaluated, so
    # NaN cannot leak into gradients through the unselected jnp.where branch.
    safe_total = jnp.where(has_mass, total, jnp.ones_like(total))
    return jnp.where(has_mass, x / safe_total, jnp.zeros_like(x))


def test_norm_vector():
    """Check ``norm`` on a random vector: output is non-negative and sums to 1."""
    sample = np.random.rand(1000) * 100 - 50  # values in [-50, 50)
    normalized = norm(sample)
    non_negative = jnp.sign(normalized) >= 0
    assert jnp.all(non_negative)
    total = jnp.sum(normalized)
    assert jnp.isclose(total, 1)

def test_norm_matrix():
    """Check ``norm`` on a random matrix: non-negative output, columns sum to 1."""
    sample = np.random.rand(100, 100) * 100 - 50  # values in [-50, 50)
    normalized = norm(sample)
    non_negative = jnp.sign(normalized) >= 0
    assert jnp.all(non_negative)
    column_sums = jnp.sum(normalized, axis=0)
    assert_allclose(column_sums, 1, rtol=1e-6)

In [4]:
# Run both norm() checks; a failed assertion inside either one aborts the cell.
test_norm_vector()
test_norm_matrix()

In [5]:
def prepare_n_t():
    """Randomly split the total token count across topics.

    Returns:
        Float array of shape ``(topic_size,)`` holding non-negative whole
        numbers that sum to ``word_size * doc_size``.
    """
    n = word_size * doc_size

    n_t = np.zeros(topic_size)
    # Tokens not yet assigned. Kept as a Python int so randint gets an integer
    # bound, and so the array need not be re-summed on every iteration.
    remaining = n
    for t in range(topic_size - 1):
        drawn = int(np.random.randint(0, remaining))
        n_t[t] = drawn
        remaining -= drawn
    # the last topic absorbs whatever is left, so the counts add up exactly
    n_t[-1] = remaining
    assert np.sum(n_t) == n
    return n_t


def prepare_phi():
    """Draw a random ``(vocab_size, topic_size)`` matrix and pass it through ``norm``.

    Returns:
        The normalized matrix; the assertions verify that every column sums to 1.
    """
    raw = np.random.rand(vocab_size, topic_size)
    phi = norm(raw)
    column_sums = np.sum(phi, axis=0)
    assert column_sums.shape == (topic_size, )
    assert_allclose(column_sums, 1.0, rtol=1e-6)
    return phi

In [6]:
# Random inputs kept as notebook-level variables and reused by the cells below.
n_t = prepare_n_t()
phi = prepare_phi()

In [7]:
def calc_naive_phi_hatch(phi, n_t):
    """Loop-based reference for phi_hatch.

    Each entry ``phi[w][t]`` is multiplied by ``n_t[t]``; every row is then
    passed through ``norm``.
    """
    result = np.zeros_like(phi)
    for word_idx in range(vocab_size):
        row = result[word_idx]  # view into `result`, filled in place
        for topic_idx in range(topic_size):
            row[topic_idx] = phi[word_idx][topic_idx] * n_t[topic_idx]
        result[word_idx] = norm(row)
    return result


def calc_fast_phi_hatch(phi, n_t):
    """Vectorized phi_hatch: scale column ``t`` of ``phi`` by ``n_t[t]``, then
    normalize every row with ``norm``."""
    # work on the transpose so that norm's axis-0 normalization runs over topics
    weighted = phi.T * n_t[:, None]
    normalized = norm(weighted)
    return normalized.T


def test_phi_hatch(phi, n_t):
    """Assert that the vectorized and loop-based phi_hatch agree in shape and value."""
    fast = calc_fast_phi_hatch(phi, n_t)
    reference = calc_naive_phi_hatch(phi, n_t)

    assert fast.shape == reference.shape
    assert_allclose(fast, reference)

In [8]:
# Compare the vectorized phi_hatch with the loop-based reference on the random inputs.
test_phi_hatch(phi, n_t)

In [9]:
def create_context_coeff_matrix(ctx_len, seq_len):
    """Build the ``(seq_len, seq_len)`` matrix of context weights.

    Before normalization, the entry for positions ``i`` and ``j`` is
    ``gamma * (1 - gamma) ** |i - j|`` with ``gamma = 1 / ctx_len`` whenever
    ``|i - j| <= ctx_len``, and 0 otherwise. Each row of the returned matrix
    is rescaled to sum to 1.

    Args:
        ctx_len: number of neighbours taken into account on each side.
        seq_len: sequence length (size of the square matrix).

    Returns:
        The row-normalized weights as a ``jnp`` array.
    """
    gamma = 1 / ctx_len

    # strictly lower-triangular part: one sub-diagonal per context offset
    lower = np.zeros((seq_len, seq_len))
    for offset in np.arange(1, ctx_len + 1):
        lower += np.eye(seq_len, k=-offset) * (gamma * (1 - gamma) ** offset)

    # construct the full symmetric matrix (diagonal + both triangles)
    symmetric = np.eye(seq_len) * gamma + lower + lower.T

    # normalize the columns, then transpose so that rows sum to 1
    column_totals = symmetric.sum(axis=0)
    row_normalized = (symmetric / column_totals).T
    return jnp.array(row_normalized)


def calc_naive_theta(data, phi_hatch, ctx_len):
    """Loop-based reference for theta.

    For every token position the result is the weighted sum of the
    ``phi_hatch`` rows of the token itself and of up to ``ctx_len`` neighbours
    on each side, with weights taken from ``create_context_coeff_matrix``.

    Args:
        data: token ids, indexed as ``data[d][w]``.
        phi_hatch: per-token topic rows, indexed by token id.
        ctx_len: number of neighbours considered on each side.

    Returns:
        Array with one row per (document, position) pair, in document-major order.
    """
    theta = []
    context_coeffs = create_context_coeff_matrix(ctx_len=ctx_len, seq_len=word_size)
    for d in range(doc_size):
        for w in range(word_size):
            # left = preceding tokens (w - i), right = following tokens (w + i)
            left_context_vec = np.zeros(topic_size)
            right_context_vec = np.zeros(topic_size)
            for i in range(1, ctx_len + 1):
                if w - i >= 0:
                    left_context_vec += context_coeffs[w][w - i] * phi_hatch[data[d][w - i]]
                if w + i < word_size:
                    right_context_vec += context_coeffs[w][w + i] * phi_hatch[data[d][w + i]]
            self_vec = context_coeffs[w][w] * phi_hatch[data[d][w]]
            theta.append(left_context_vec + right_context_vec + self_vec)
    return np.array(theta)


def calc_fast_theta(data, phi_hatch, ctx_len):
    """Vectorized theta: context-weighted sum of ``phi_hatch`` rows per token position.

    Returns:
        Array of shape ``(I, T)`` with one row per (document, position) pair.
    """
    coeffs = create_context_coeff_matrix(ctx_len=ctx_len, seq_len=word_size)
    # look up the phi_hatch row of every token
    token_topics = jnp.take_along_axis(
        phi_hatch[None, ...], indices=data[..., None], axis=1
    )  # (D, W_d, T)
    # weighted[d, i, j, t] = coeffs[i, j] * token_topics[d, j, t]
    weighted = token_topics[:, None, :, :] * coeffs[None, :, :, None]
    per_position = jnp.sum(weighted, axis=2)  # (D, W_d, T)
    num_topics = per_position.shape[-1]
    return per_position.reshape(-1, num_topics)  # (I, T)


def test_theta(data, phi_hatch, ctx_len=8):
    """Assert that the vectorized theta matches the loop-based reference."""
    reference = calc_naive_theta(data, phi_hatch, ctx_len)
    fast = calc_fast_theta(data, phi_hatch, ctx_len)
    assert reference.shape == fast.shape
    assert_allclose(fast, reference, rtol=1e-6)

In [10]:
# phi_hatch is kept as a notebook-level variable; later cells reuse it.
phi_hatch = calc_fast_phi_hatch(phi, n_t)
test_theta(data_tokens, phi_hatch)

In [11]:
def calc_fast_p_ti(phi, data, theta_new):
    """Vectorized p_ti: multiply each token's ``phi`` row by its ``theta_new`` row
    and normalize every resulting row with ``norm``."""
    token_topics = jnp.take_along_axis(
        phi[None, ...], indices=data[..., None], axis=1
    )  # (D, W_d, T)
    flat_topics = token_topics.reshape(-1, token_topics.shape[-1])  # (I, T)
    unnormalized = flat_topics * theta_new
    # norm() normalizes along axis 0, so transpose around the call
    return norm(unnormalized.T).T  # (I, T)


def calc_naive_p_ti(phi, data, theta_new):
    """Loop-based reference for p_ti.

    Row ``i = d * word_size + w`` holds ``phi[token][t] * theta_new[i][t]`` for
    the token at ``data[d][w]``, passed through ``norm``.
    """
    result = np.zeros((word_size * doc_size, topic_size))
    for doc_idx in range(doc_size):
        for pos in range(word_size):
            flat_idx = doc_idx * word_size + pos
            token = data[doc_idx][pos]
            row = result[flat_idx]  # view into `result`, filled in place
            for topic in range(topic_size):
                row[topic] = phi[token][topic] * theta_new[flat_idx][topic]
            result[flat_idx] = norm(row)
    return result


def test_p_ti(phi, data, theta_new):
    """Assert that the vectorized p_ti matches the loop-based reference."""
    fast = calc_fast_p_ti(phi, data, theta_new)
    reference = calc_naive_p_ti(phi, data, theta_new)

    assert fast.shape == reference.shape
    assert_allclose(fast, reference)

In [12]:
# theta_new is kept as a notebook-level variable; later cells reuse it.
theta_new = calc_fast_theta(data_tokens, phi_hatch, ctx_len=8)
test_p_ti(phi, data_tokens, theta_new)

In [13]:
def calc_naive_n_t(p_ti):
    """Loop-based reference for n_t: per-topic sum of ``p_ti`` over all rows."""
    totals = np.zeros(topic_size)
    num_items = word_size * doc_size
    for topic in range(topic_size):
        for item in range(num_items):
            totals[topic] += p_ti[item][topic]
    return totals


def calc_fast_n_t(p_ti):
    """Vectorized n_t: sum ``p_ti`` over its first axis."""
    topic_totals = jnp.sum(p_ti, axis=0)  # (T, )
    return topic_totals


def test_n_t(p_ti):
    """Assert that the vectorized n_t matches the loop-based reference."""
    reference = calc_naive_n_t(p_ti)
    fast = calc_fast_n_t(p_ti)

    assert reference.shape == fast.shape
    assert_allclose(fast, reference, rtol=1e-5)

In [14]:
# p_ti is kept as a notebook-level variable; the phi check further down reuses it.
p_ti = calc_fast_p_ti(phi, data_tokens, theta_new)
test_n_t(p_ti)

In [15]:
def calc_naive_phi(data, p_ti):
    """Loop-based reference for the updated phi.

    For every topic ``t``, the weight ``p_ti[i][t]`` of each token occurrence
    ``i = d * word_size + w`` is added to the slot of the token observed at
    ``data[d][w]``; the per-topic vector is then passed through ``norm``.

    Args:
        data: token ids, indexed as ``data[d][w]``.
        p_ti: per-occurrence topic weights, indexed as ``p_ti[i][t]``.

    Returns:
        Array of shape ``(vocab_size, topic_size)``.
    """
    phi_new = np.zeros((topic_size, vocab_size))
    for t in range(topic_size):
        for d in range(doc_size):
            for w in range(word_size):
                i = d * word_size + w
                token = data[d][w]
                # Add straight to the observed token's slot instead of scanning
                # the whole vocabulary for a match. Ids outside the vocabulary
                # are skipped, as they never matched any word before.
                if 0 <= token < vocab_size:
                    phi_new[t][token] += p_ti[i][t]
        phi_new[t] = norm(phi_new[t])
    return phi_new.T


def calc_fast_phi(data, p_ti):
    """Vectorized phi update: scatter-add the rows of ``p_ti`` by token id, then
    normalize the result with ``norm``."""
    token_ids = data.flatten()  # (I, )
    accumulator = jnp.zeros((vocab_size, topic_size))
    # rows of p_ti that share a token id are summed into that token's row
    summed = accumulator.at[token_ids].add(p_ti)  # (W, T)
    return norm(summed)  # (W, T)


def test_phi(data, p_ti):
    """Assert that the vectorized phi update matches the loop-based reference."""
    reference = calc_naive_phi(data, p_ti)
    fast = calc_fast_phi(data, p_ti)

    assert reference.shape == fast.shape
    assert_allclose(fast, reference)

In [16]:
# Compare the vectorized phi update with the loop-based reference.
test_phi(data_tokens, p_ti)

100%|██████████| 10/10 [04:36<00:00, 27.66s/it]
