# Gibbs Sampler for LDA

Latent Dirichlet Allocation (LDA) is a mixed-membership model for topic modelling. Given a set of documents in a bag-of-words representation, we want to infer the underlying topics those documents represent. To build intuition, we first look at LDA's generative story.

With $i \in \{1,\ldots,N_D\}$ indexing documents, $v \in \{1,\ldots,N_W\}$ indexing vocabulary words, $k \in \{1,\ldots,N_K\}$ indexing topics, and $w$ denoting the position of a word within a document, LDA assumes:

$$
\begin{array}{rcl}
\pi_i & \sim & \mathrm{Dir}(\pi_i|\alpha)\\
z_{iw} & \sim & \mathrm{Cat}(z_{iw}|\pi_i)\\
\mathbf{b}_k & \sim & \mathrm{Dir}(\mathbf{b}_k|\gamma)\\
y_{iw} & \sim & \mathrm{Cat}(y_{iw}|z_{iw} = k, \mathbf{B})
\end{array}
$$
where $\alpha$ and $\gamma$ are the parameters of the Dirichlet priors. They control how concentrated or spread out the document-topic and topic-word distributions are.
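
To make the role of the Dirichlet parameter concrete, here is a minimal sketch that compares draws under a small and a large parameter value. The number of topics, the two parameter values, and the helper name `mean_largest_component` are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
N_K = 3  # hypothetical number of topics, for illustration only

def mean_largest_component(concentration, n_samples=2000):
    """Average size of the largest entry across symmetric Dirichlet draws."""
    samples = rng.dirichlet(concentration * np.ones(N_K), size=n_samples)
    return samples.max(axis=1).mean()

sparse = mean_largest_component(0.1)    # small parameter: mass piles onto one topic
uniform = mean_largest_component(10.0)  # large parameter: draws stay close to uniform
print(sparse, uniform)
```

With a small parameter most of the probability mass tends to land on a single component, whereas a large parameter keeps the draws near the uniform distribution.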

In words, the generative process above reads:
1. Each document is a mixture of topics, so we first sample $\pi_i$, the topic distribution for the $i$-th document.
2. Each word in the $i$-th document comes from one of the topics, so we sample $z_{iw}$, the topic of the word at position $w$ in document $i$.
3. Each topic is composed of words, e.g. the topic 'computer' consists of words such as 'cpu' and 'gpu'. We therefore sample $\mathbf{b}_k$, the distribution over words for topic $k$.
4. Finally, to generate the word itself, given that it comes from topic $k$, we sample $y_{iw}$ from the $k$-th topic's word distribution.
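
The four steps above can be sketched as a forward simulation. This is only an illustration under assumed sizes (`n_docs`, `n_vocab`, `n_topics`, `doc_len` are hypothetical names and values), not part of the inference code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes and priors, chosen only for illustration
n_docs, n_vocab, n_topics, doc_len = 4, 5, 2, 6
alpha, gamma = 1.0, 1.0

# Step 1: one topic distribution pi_i per document
Pi = rng.dirichlet(alpha * np.ones(n_topics), size=n_docs)
# Step 3: one word distribution b_k per topic
B = rng.dirichlet(gamma * np.ones(n_vocab), size=n_topics)

Z = np.zeros((n_docs, doc_len), dtype=int)
Y = np.zeros((n_docs, doc_len), dtype=int)
for i in range(n_docs):
    for w in range(doc_len):
        # Step 2: pick the topic of the word at position w
        Z[i, w] = rng.choice(n_topics, p=Pi[i])
        # Step 4: pick the word itself from that topic's word distribution
        Y[i, w] = rng.choice(n_vocab, p=B[Z[i, w]])

print(Y)
```

Each row of `Y` is one synthetic document in the same format as the corpus used later: a sequence of vocabulary indices.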

## Inference
The goal of inference in LDA is, given a corpus, to infer the underlying topics that explain its documents according to the generative process above. In other words, given the observed words $y_{iw}$, we invert that process to find $z_{iw}$, $\pi_i$ and $\mathbf{b}_k$.

We will infer those variables with the Gibbs sampling algorithm. In short, it works by repeatedly sampling each variable from its distribution given all the other variables (its full conditional distribution). Thanks to Dirichlet-categorical conjugacy, the full conditionals are as follows:

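$$
\begin{array}{rcl}
p(z_{iw} = k|\pi_i, \mathbf{b}_k) & \propto & \exp(\log \pi_{ik} + \log b_{k, y_{iw}}) \\
p(\pi_i | z_{iw} = k, \mathbf{b}_k) & = & \mathrm{Dir}(\alpha + \sum_l\mathbb{I}(z_{il} = k )) \\
p(\mathbf{b}_k|z_{iw} = k, \pi_i) & = & \mathrm{Dir}(\gamma + \sum_i\sum_l\mathbb{I}(y_{il}=v, z_{il}=k))
\end{array}
$$

Here the Dirichlet parameters are written component-wise: the count in the second line is the $k$-th component, the count in the third line is the $v$-th component, and $l$ runs over the word positions in a document.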

In essence, we count how many words in each document are assigned to each topic, and how often each vocabulary word is assigned to each topic. Those counts are the sufficient statistics for the full conditionals.
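
As a small worked example of those counts, the sketch below computes them for a hand-made toy corpus. The arrays `Y` and `Z` and the names `m` and `n` are assumptions chosen for illustration; they are not the data used in the implementation.

```python
import numpy as np

# Hypothetical toy data: 2 documents, 3 word positions, 3-word vocabulary, 2 topics
Y = np.array([[0, 0, 2],
              [1, 2, 2]])   # observed words (vocabulary indices)
Z = np.array([[0, 0, 1],
              [1, 1, 0]])   # current topic assignment of each word
n_topics, n_vocab = 2, 3

# m[i, k]: number of words in document i assigned to topic k
m = np.stack([np.bincount(z_i, minlength=n_topics) for z_i in Z])

# n[k, v]: number of times vocabulary word v is assigned to topic k, over all documents
n = np.zeros((n_topics, n_vocab), dtype=int)
np.add.at(n, (Z.ravel(), Y.ravel()), 1)

print(m)  # [[2 1]
          #  [1 2]]
print(n)  # [[2 0 1]
          #  [0 1 2]]
```

Adding $\alpha$ to a row of `m` gives the Dirichlet parameter for that document's $\pi_i$, and adding $\gamma$ to a row of `n` gives the parameter for that topic's $\mathbf{b}_k$.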

Given those full conditionals, all that remains is to plug them into the Gibbs sampling framework, as we shall do in the next section.

## Implementation

In [1]:
import numpy as np

In [2]:
# W := vocabulary (word indices)
W = np.array([0, 1, 2, 3, 4])

# X := documents; X[i, l] is the vocabulary index of the l-th word in document i
X = np.array([
    [0, 0, 1, 2, 2],
    [0, 0, 1, 1, 1],
    [0, 1, 2, 2, 2],
    [4, 4, 4, 4, 4],
    [3, 3, 4, 4, 4],
    [3, 4, 4, 4, 4]
])

N_D = X.shape[0] # num of docs
N_W = W.shape[0] # vocabulary size (num of distinct words)
N_K = 2 # num of topics

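We begin by randomly initializing the topic assignment matrix $\mathbf{Z}$, which has one entry per word position and therefore the same shape as `X` (in this toy corpus every document has five words, which happens to equal $N_W$). We also sample the initial values of $\boldsymbol{\Pi}_{N_D\times N_K}$ and $\mathbf{B}_{N_K\times N_W}$.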

In [4]:
# Dirichlet priors
alpha = 1
gamma = 1

# --------------
# Initialization
# --------------

# Z := word topic assignment, one entry per word position in X
Z = np.zeros(shape=X.shape)
for i in range(N_D):
    for l in range(X.shape[1]):
        Z[i, l] = np.random.randint(N_K) # randomly assign word's topic
        
# Pi := document topic distribution
Pi = np.zeros([N_D, N_K])
for i in range(N_D):
    Pi[i] = np.random.dirichlet(alpha*np.ones(N_K))
    
# B := topic word distribution
B = np.zeros([N_K, N_W])
for k in range(N_K):
    B[k] = np.random.dirichlet(gamma*np.ones(N_W))
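
The same initialization can also be written without explicit loops and with a seeded generator so that runs are reproducible. This is a sketch of an alternative, not the code used in the rest of this notebook; the seed value and the name `N_L` (words per document) are assumptions introduced here.

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed, only for reproducibility

X = np.array([
    [0, 0, 1, 2, 2],
    [0, 0, 1, 1, 1],
    [0, 1, 2, 2, 2],
    [4, 4, 4, 4, 4],
    [3, 3, 4, 4, 4],
    [3, 4, 4, 4, 4]
])
N_D, N_L = X.shape  # N_L: words per document (name introduced for this sketch)
N_W, N_K = 5, 2
alpha, gamma = 1, 1

# One random topic per word position
Z = rng.integers(N_K, size=(N_D, N_L))
# One Dirichlet draw per document / per topic, stacked row-wise
Pi = rng.dirichlet(alpha * np.ones(N_K), size=N_D)
B = rng.dirichlet(gamma * np.ones(N_W), size=N_K)
```

Passing `size` to `dirichlet` returns one draw per row, which replaces the per-document and per-topic loops.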

We sample the new values for each of those variables from the full conditionals in the previous section and iterate:

In [5]:
# --------------
# Gibbs sampling
# --------------
for it in range(1000):
    # Sample from full conditional of Z
    # ---------------------------------
    for i in range(N_D):
        for l in range(X.shape[1]):  # word positions in document i
            # Unnormalized topic probabilities for the word at position l
            p_bar_il = np.exp(np.log(Pi[i]) + np.log(B[:, X[i, l]]))
            p_il = p_bar_il / np.sum(p_bar_il)
            
            # Resample word topic assignment Z
            z_il = np.random.multinomial(1, p_il)
            Z[i, l] = np.argmax(z_il)
            
    # Sample from full conditional of Pi
    # ----------------------------------
    for i in range(N_D):
        m = np.zeros(N_K)
        
        # Gather sufficient statistics
        for k in range(N_K):
            m[k] = np.sum(Z[i] == k)
            
        # Resample doc topic distribution.
        Pi[i, :] = np.random.dirichlet(alpha + m)
        
    # Sample from full conditional of B
    # ---------------------------------
    for k in range(N_K):
        n = np.zeros(N_W)
        
        # Gather sufficient statistics: n[v] counts occurrences of
        # vocabulary word v that are currently assigned to topic k
        for v in range(N_W):
            for i in range(N_D):
                for l in range(X.shape[1]):
                    n[v] += (X[i, l] == v) and (Z[i, l] == k)
                    
        # Resample word topic distribution
        B[k, :] = np.random.dirichlet(gamma + n)
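
For larger corpora the nested Python loops become slow. One possible vectorized variant of a single sweep is sketched below; the function name `gibbs_sweep`, the local name `N_L`, the seed, and the number of sweeps are all assumptions of this sketch rather than part of the code above. It samples the same three full conditionals in the same order.

```python
import numpy as np

def gibbs_sweep(X, Pi, B, alpha, gamma, rng):
    """One pass over Z, Pi and B (hypothetical helper, not from the code above)."""
    N_D, N_L = X.shape  # N_L: words per document
    N_K, N_W = B.shape

    # Z | Pi, B: p[i, l, k] is proportional to Pi[i, k] * B[k, X[i, l]]
    p = Pi[:, None, :] * B.T[X]
    p /= p.sum(axis=-1, keepdims=True)
    # Draw one topic per word by inverting the cumulative distribution
    u = rng.random((N_D, N_L, 1))
    Z = np.minimum((p.cumsum(axis=-1) < u).sum(axis=-1), N_K - 1)

    # Pi | Z: m[i, k] counts the words of document i assigned to topic k
    m = (Z[:, :, None] == np.arange(N_K)).sum(axis=1)
    Pi = np.stack([rng.dirichlet(alpha + m_i) for m_i in m])

    # B | Z, X: n[k, v] counts how often vocabulary word v is assigned to topic k
    n = np.zeros((N_K, N_W))
    np.add.at(n, (Z.ravel(), X.ravel()), 1)
    B = np.stack([rng.dirichlet(gamma + n_k) for n_k in n])
    return Z, Pi, B

rng = np.random.default_rng(0)
X = np.array([
    [0, 0, 1, 2, 2],
    [0, 0, 1, 1, 1],
    [0, 1, 2, 2, 2],
    [4, 4, 4, 4, 4],
    [3, 3, 4, 4, 4],
    [3, 4, 4, 4, 4]
])
N_K, N_W = 2, 5
Pi = rng.dirichlet(np.ones(N_K), size=X.shape[0])
B = rng.dirichlet(np.ones(N_W), size=N_K)
for _ in range(200):
    Z, Pi, B = gibbs_sweep(X, Pi, B, alpha=1, gamma=1, rng=rng)
```

The final `np.minimum` only guards against floating-point rounding in the cumulative sum; it does not change the sampling distribution.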

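That completes the sampler. We can inspect the result by printing the final sample of each variable. Note that topic labels are arbitrary and no random seed is set, so another run may swap topics 0 and 1 and will produce different numbers.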

In [6]:
print('Documents:')
print('----------')
print(X)

Documents:
----------
[[0 0 1 2 2]
 [0 0 1 1 1]
 [0 1 2 2 2]
 [4 4 4 4 4]
 [3 3 4 4 4]
 [3 4 4 4 4]]


In [7]:
print('Document topic distribution:')
print('----------------------------')
print(Pi)

Document topic distribution:
----------------------------
[[0.17379786 0.82620214]
 [0.33210815 0.66789185]
 [0.0322541  0.9677459 ]
 [0.75891264 0.24108736]
 [0.84967238 0.15032762]
 [0.81500144 0.18499856]]


In [8]:
print('Topic\'s word distribution:')
print('---------------------------')
print(B)

Topic's word distribution:
---------------------------
[[0.11478112 0.07575037 0.03618259 0.06698306 0.70630286]
 [0.2382687  0.3871144  0.1468959  0.17844455 0.04927644]]


In [9]:
print('Word topic assignment:')
print('----------------------')
print(Z)

Word topic assignment:
----------------------
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
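
To read these results more directly, one option is to extract the dominant topic of each document and the most probable words of each topic. The sketch below does this using the values printed above, copied in by hand; the names `doc_topic` and `top_words` and the choice of two words per topic are assumptions for illustration.

```python
import numpy as np

# Values copied from the printed output above
Pi = np.array([
    [0.17379786, 0.82620214],
    [0.33210815, 0.66789185],
    [0.0322541,  0.9677459 ],
    [0.75891264, 0.24108736],
    [0.84967238, 0.15032762],
    [0.81500144, 0.18499856]
])
B = np.array([
    [0.11478112, 0.07575037, 0.03618259, 0.06698306, 0.70630286],
    [0.2382687,  0.3871144,  0.1468959,  0.17844455, 0.04927644]
])

doc_topic = Pi.argmax(axis=1)              # dominant topic per document
top_words = np.argsort(-B, axis=1)[:, :2]  # two most probable words per topic

print(doc_topic)  # [1 1 1 0 0 0]
print(top_words)  # [[4 0]
                  #  [1 0]]
```

The first three documents are dominated by topic 1 and the last three by topic 0, which is dominated by word 4. This agrees with the word topic assignment matrix shown above.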
