In [1]:
import numpy as np
import jax

from scipy.optimize import linear_sum_assignment

from lda_jax.utils.generator import generate_lda_corpus
from lda_jax.models.lda import LDAModel
from lda_jax.inference.sampler import GibbsSampler, SamplerConfig

In [2]:
def _distance(beta: np.ndarray, beta_hat: np.ndarray) -> float:
    """Minimal matched TV distance"""
    K = beta.shape[0]
    cost_matrix = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            cost_matrix[i, j] = 0.5 * np.linalg.norm(beta[i] - beta_hat[j])

    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    avg_tv = cost_matrix[row_ind, col_ind].mean()
    return avg_tv

In [3]:
# ------------------------------------------------------------------
# 1. Generate a synthetic corpus
# ------------------------------------------------------------------
key = jax.random.PRNGKey(0)
synth = generate_lda_corpus(key, num_docs=100, num_topics=5, vocab_size=30, doc_length=35)

# Keep the corpus for fitting and the generator's beta as the reference to compare against.
corpus = synth.corpus
true_beta = np.asarray(synth.beta)  # (K, V)

In [4]:
# ------------------------------------------------------------------
# 2. Fit model with Gibbs sampler
# ------------------------------------------------------------------
model = LDAModel(num_topics=5, vocab_size=corpus.vocab_size)

# One keyword per line so the sampler settings are easy to scan and tweak.
config = SamplerConfig(
    num_iters=1000,
    burn_in=500,
    thin=20,
    rng_key=jax.random.PRNGKey(1),
    show_progress=True,
)

sampler = GibbsSampler(corpus, model, config)

In [5]:
# Run the configured sampler, then convert the result of `posterior_phi()`
# to a NumPy array so it can be compared against `true_beta`.
sampler.run()
est_phi = np.asarray(sampler.posterior_phi())  # (K, V)

In [6]:
# Score the fit: `_distance` pairs rows of the generator's beta with rows of the
# sampler's phi via an optimal one-to-one assignment and averages the matched costs.
dist = _distance(true_beta, est_phi)

In [7]:
# Show the matched distance as the cell output.
dist

0.014288789592683315

In [8]:
# directly use run_gibbs for much faster lax.scan gibbs

from lda_jax.models.lda import run_gibbs

# Use a fresh PRNG key here: `key` (PRNGKey(0)) was already consumed by
# generate_lda_corpus, and JAX keys must never be reused for a second,
# independent random computation -- doing so ties the sampler's randomness
# to the stream that produced the synthetic data.
gibbs_key = jax.random.PRNGKey(2)

final_state, _ = run_gibbs(
    corpus,
    model,
    num_iters=1000,
    key=gibbs_key,
)

In [11]:
# Inspect the shape of `n_kw` on the final state; the recorded output (5, 30)
# matches (num_topics, vocab_size) used in this notebook.
final_state.n_kw.shape

(5, 30)