# HDP Gibbs Samplers

## *( 5.1 )* Posterior Sampling in the Chinese Restaurant Franchise

### Brief Overview

The Hierarchical Dirichlet Process mixture model is given by
$$\begin{aligned}
G_0 | \gamma, H &\sim DP(\gamma, H) \\
G_j | \alpha_0, G_0 &\sim DP(\alpha_0, G_0) \\
\theta_{ji} | G_j &\sim G_j \\
x_{ji} | \theta_{ji} &\sim F(\theta_{ji})
\end{aligned} $$

This model non-parametrically clusters each group's data while sharing information both between and within groups.  A draw from a Dirichlet process is a discrete distribution whose atoms are drawn from a (not necessarily discrete) base measure $H$ and whose weights, generated by the "stick-breaking process," decrease on average.  In the HDP, each group's distribution $G_j$ is drawn from a DP whose base measure $G_0$ is itself a draw from a DP, so every $G_j$ contains the same atoms as $G_0$ but with different weights:
$$\begin{aligned}
G_0 &= \sum_{k=1}^{\infty} \beta_k \delta(\phi_k) \\
G_j &= \sum_{k=1}^{\infty} \pi_{jk} \delta(\phi_k) \\
\phi_k | H &\sim H
\end{aligned} $$
Additionally, if we define $\beta, \pi_j$ as the collected weights above, it can be shown that these vectors encode a distribution over $\mathbb{Z}^+$ such that $\beta | \gamma \sim GEM(\gamma)$ and $\pi_j | \alpha_0, \beta \sim DP(\alpha_0, \beta)$.
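
The weight vectors above can be simulated with a finite truncation. The following is a minimal sketch, not part of the sampler: the function names, the truncation level `K`, and the choice to close the stick at the last atom are all assumptions made for illustration. On a finite set of $K$ atoms, $DP(\alpha_0, \beta)$ reduces to a Dirichlet distribution with parameter $\alpha_0 \beta$, which is what `group_weights` draws.

```python
import numpy as np

def gem_weights(gamma, K, rng):
    """Truncated stick-breaking draw of beta ~ GEM(gamma) on K atoms."""
    v = rng.beta(1.0, gamma, size=K)      # fraction broken off the remaining stick
    v[-1] = 1.0                           # assumption: last atom takes what is left
    remaining = np.concatenate(([1.0], np.cumprod(1.0 - v[:-1])))
    return v * remaining

def group_weights(alpha0, beta, rng):
    """Finite analogue of pi_j ~ DP(alpha0, beta): Dirichlet(alpha0 * beta)."""
    return rng.dirichlet(alpha0 * beta)

rng = np.random.default_rng(0)
beta = gem_weights(gamma=1.0, K=20, rng=rng)
pi = np.vstack([group_weights(5.0, beta, rng) for _ in range(3)])
```

Each row of `pi` places its mass on the same 20 atoms as `beta` but with group-specific weights; larger `alpha0` pulls the rows closer to `beta`.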

Successive draws from a DP exhibit clustering behavior, since the probability of taking a certain value increases with the number of previous draws of that value.  The *Chinese restaurant franchise* process describes this behavior in the hierarchical setting.  Imagine a franchise of Chinese restaurants, one per group, each with its own tables and all sharing a single menu.  Let $\phi_k$ be the global dishes, drawn from $H$; $\psi_{jt}$ be the table-specific dishes, drawn from $G_0$; and $\theta_{ji}$ be the customer-specific dishes, drawn from $G_j$.  Let $z_{ji}$ denote the index of the dish eaten by customer $ji$; $t_{ji}$ the index of the table where customer $ji$ sits; $k_{jt}$ the index of the dish served at table $jt$; $n_{jtk}$ the number of customers in restaurant $j$ at table $t$ eating dish $k$; and $m_{jk}$ the number of tables in restaurant $j$ serving dish $k$.  A dot in place of an index denotes a sum over that index.  Then:

$$\begin{aligned}
\theta_{ji} | \text{other } \theta, \alpha_0, G_0 &\sim
    \sum_{t=1}^{m_{j\cdot}} \frac{n_{jt\cdot}}{i-1+\alpha_0} \delta(\psi_{jt}) +
                            \frac{\alpha_0}{i-1+\alpha_0} G_0 \\
\psi_{jt} | \text{other } \psi, \gamma, H &\sim
    \sum_{k=1}^{K} \frac{m_{\cdot k}}{m_{\cdot\cdot} + \gamma} \delta(\phi_k) +
                            \frac{\gamma}{m_{\cdot\cdot} + \gamma} H
\end{aligned} $$
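
The two seating rules above can be simulated directly from the prior. The following is a minimal sketch under stated assumptions: `simulate_crf` and its return layout are hypothetical names chosen for illustration, customers are indexed from 0 (so `i` customers are already seated when customer `i` arrives), and no data likelihood is involved.

```python
import numpy as np

def simulate_crf(group_sizes, alpha0, gamma, rng):
    """Seats customers according to the Chinese restaurant franchise prior.

    Returns per-group table assignments, per-group dish index of each table,
    and m, the number of tables serving each dish across the franchise.
    """
    m = []                                  # m[k]: tables serving dish k, all groups
    t_all, k_all = [], []
    for n_j in group_sizes:
        n_jt, k_jt, t_j = [], [], []        # table counts, table dishes, seatings
        for i in range(n_j):
            # Existing table t with weight n_jt, new table with weight alpha0
            w = np.array(n_jt + [alpha0], dtype=float)
            t = rng.choice(len(w), p=w / w.sum())
            if t == len(n_jt):
                # New table: existing dish k with weight m_k, new dish with weight gamma
                wk = np.array(m + [gamma], dtype=float)
                k = rng.choice(len(wk), p=wk / wk.sum())
                if k == len(m):
                    m.append(0)
                m[k] += 1
                n_jt.append(0)
                k_jt.append(k)
            n_jt[t] += 1
            t_j.append(t)
        t_all.append(np.array(t_j))
        k_all.append(np.array(k_jt))
    return t_all, k_all, np.array(m)

rng = np.random.default_rng(1)
t_all, k_all, m = simulate_crf([30, 30, 30], alpha0=1.0, gamma=1.0, rng=rng)
```

Because dishes are drawn from the shared counts `m`, popular dishes tend to reappear across restaurants, which is how the groups share clusters.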

### Full Conditionals

Choose a base measure with density $h(\cdot)$ and a data-generating distribution $f(\cdot | \theta)$ to which it is conjugate.  The key quantity is $f_k^{-x_{ji}}(x_{ji})$, the conditional density of a point $x_{ji}$ under mixture component $k$ given all other points currently assigned to $k$, with $\phi_k$ integrated out:

$$\begin{aligned}
f_k^{-x_{ji}}(x_{ji}) &= \frac { \int f(x_{ji} | \phi_k) g(k)d\phi_k } { \int g(k)d\phi_k } \\
g(k) &= h(\phi_k) \prod_{j'i' \neq ji, z_{j'i'} = k} f(x_{j'i'} | \phi_k)
\end{aligned} $$

Using this, we first compute the likelihood of a point $x_{ji}$ seated at a new table $t^*$, given the current clustering scheme:
$$
p(x_{ji} | t^{-ji}, t_{ji} = t^*, k) =
    \sum_{k=1}^{K} \frac{m_{\cdot k}}{m_{\cdot\cdot} + \gamma} f_k^{-x_{ji}}(x_{ji}) +
                            \frac{\gamma}{m_{\cdot\cdot} + \gamma} f_{k^*}^{-x_{ji}}(x_{ji})
$$
where $f_k^{-x_{ji}}$ is the conditional density defined above, $m_{\cdot\cdot}$ is the total number of tables, and $f_{k^*}^{-x_{ji}}(x_{ji}) = \int f(x_{ji} | \phi) h(\phi) d\phi$ is the prior predictive density of $x_{ji}$, found by integrating a new atom $\phi$ over the base measure.
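
As a worked case, take the Poisson likelihood with a Gamma base measure used by the code later in this section, assuming here the shape-rate form $H = \text{Gamma}(a, b)$. Write $S_k^{-ji}$ for the sum of the other points assigned to dish $k$ and $n_{\cdot\cdot k}^{-ji}$ for their count (the symbol $S_k^{-ji}$ is introduced only for this sketch). Conjugacy then gives the integrals in closed form:

$$\begin{aligned}
\phi_k | \{x_{j'i'} : j'i' \neq ji, z_{j'i'} = k\} &\sim \text{Gamma}\left(a + S_k^{-ji},\; b + n_{\cdot\cdot k}^{-ji}\right) \\
f_k^{-x_{ji}}(x_{ji}) &= \frac{\Gamma(a + S_k^{-ji} + x_{ji})}{\Gamma(a + S_k^{-ji})\, x_{ji}!} \cdot
    \frac{(b + n_{\cdot\cdot k}^{-ji})^{a + S_k^{-ji}}}{(b + n_{\cdot\cdot k}^{-ji} + 1)^{a + S_k^{-ji} + x_{ji}}} \\
f_{k^*}^{-x_{ji}}(x_{ji}) &= \frac{\Gamma(a + x_{ji})}{\Gamma(a)\, x_{ji}!} \cdot \frac{b^{a}}{(b + 1)^{a + x_{ji}}}
\end{aligned} $$

The last line is the second with $S_k^{-ji} = n_{\cdot\cdot k}^{-ji} = 0$; both are negative binomial densities.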

For efficiency, the Gibbs scheme implemented below samples only the $t$ and $k$ indexes; the actual parameters can be recovered from these afterward.  The state space of the $k$ values is technically infinite, and the number of tables and dishes in use is not fixed in advance, so we keep a running list of active $t$ and $k$ values.  In each update step, each customer is assigned either to an existing table or to a new table; a customer assigned to a new table also draws a $k$ value for that table.  Similarly, each table is assigned either an existing dish or a new dish.  A table left with no customers, or a dish left with no tables, is removed from its respective list.  The full conditionals for these updates are:

$$ \begin{aligned}
p(t_{ji} = t | t^{-ji}, k, ...) &\propto \begin{cases}
    n_{jt\cdot}^{-ji} f_{k_{jt}}^{-x_{ji}}(x_{ji}) & t\text{ used}\\
    \alpha_0 p(x_{ji} | t^{-ji}, t_{ji} = t^*, k) & t\text{ new}
    \end{cases} \\
p(k_{jt} = k | t, k^{-jt}) &\propto \begin{cases}
    m_{\cdot k}^{-jt} f_k^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt}) & k\text{ used}\\
    \gamma f_{k^*}^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt}) & k\text{ new}
    \end{cases}
\end{aligned} $$
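
Each conditional above amounts to normalizing a short vector of unnormalized weights and drawing one index. The following is a minimal sketch of the table update in log-space; the helper names `log_p_new_table` and `sample_table`, their argument layout, and the toy numbers are hypothetical and not part of the class below. It assumes the counts passed in already exclude customer $ji$ and cover only occupied tables and used dishes.

```python
import numpy as np
from scipy.special import logsumexp

def log_p_new_table(m_k, gamma, log_fk_x, log_fnew_x):
    """log p(x_ji | t_ji = t*): mixture over used dishes plus one new dish."""
    m_k = np.asarray(m_k, dtype=float)
    log_w = np.log(np.append(m_k, gamma)) - np.log(m_k.sum() + gamma)
    return logsumexp(log_w + np.append(log_fk_x, log_fnew_x))

def sample_table(n_jt, k_jt, alpha0, log_fk_x, log_p_new, rng):
    """Draws t_ji; a returned index equal to len(n_jt) means 'open a new table'."""
    n_jt = np.asarray(n_jt, dtype=float)
    k_jt = np.asarray(k_jt, dtype=int)
    log_fk_x = np.asarray(log_fk_x, dtype=float)
    # Used tables: n_jt * f_{k_jt}(x); new table: alpha0 * p(x | new table)
    log_w = np.append(np.log(n_jt) + log_fk_x[k_jt], np.log(alpha0) + log_p_new)
    probs = np.exp(log_w - logsumexp(log_w))
    return rng.choice(len(probs), p=probs), probs

rng = np.random.default_rng(0)
log_fk_x = np.log([0.1, 0.3])        # toy values of f_k(x_ji) for two used dishes
lp_new = log_p_new_table(m_k=[2, 1], gamma=1.0,
                         log_fk_x=log_fk_x, log_fnew_x=np.log(0.2))
t, probs = sample_table(n_jt=[3, 1], k_jt=[0, 1], alpha0=1.0,
                        log_fk_x=log_fk_x, log_p_new=lp_new, rng=rng)
```

The dish update has the same shape: replace the table counts with $m_{\cdot k}$, the per-point densities with the joint density of all points at the table, and $\alpha_0$ with $\gamma$.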

In [279]:
import numpy as np
import pandas as pd
from scipy.special import gammaln

def pois_fk(x, k, Kmax, ha, hb):
    """
    Computes in one sweep the log mixture components log f_k(x_ji) for each x and each k.
    MODEL: base measure H ~ Gamma(ha, hb) (shape, rate), F(x|phi) ~ Poisson(phi)
    Each x_ji is excluded from the sufficient statistics of its own cluster.
    All components are calculated in log-space and returned as logs.
    
    returns: (N, Kmax) matrix of log densities
    """
    
    x = np.asarray(x).flatten()  # reshape to 1D, since gibbs routine passes in a 2D array
    k = np.asarray(k)
    log_fk_vals = np.zeros((len(k), Kmax))
    for kk in range(Kmax):
        in_kk = (k == kk)               # mask of x values assigned to kk
        x_in = in_kk.astype('int')      # 1 where x_ji must be removed from cluster kk
        
        # Posterior Gamma(a, b) for phi_kk given the cluster's data without x_ji.
        # An unused k keeps the prior (ha, hb), giving the prior predictive density.
        a_denom = ha + np.sum(x[in_kk]) - x_in * x
        b_denom = hb + np.sum(in_kk) - x_in
        a_numer = x + a_denom
        b_numer = 1 + b_denom
        # Log ratio of Gamma normalizing constants, Gamma(a) / b^a, minus log(x!)
        log_fk_vals[:, kk] = (gammaln(a_numer) - gammaln(a_denom) - gammaln(x + 1) +
                              a_denom * np.log(b_denom) - a_numer * np.log(b_numer) )
    
    return log_fk_vals        

In [280]:
class CFRP:
    """
    Model implementing the Chinese restaurant franchise process.
    
    CONSTRUCTOR PARAMETERS
    - gamma, alpha0: concentration parameters > 0 for G0 ~ DP(gamma, H) and Gj ~ DP(alpha0, G0)
    - f: string naming the distribution of the data; h is chosen to be conjugate
    - hypers: tuple of hyperparameter values specific to f/h scheme chosen
    
    PRIVATE ATTRIBUTES (volatile)
    - t_: active t values; formatted as {j: set(t...), ...}
    - k_: set of active k values
    - tk_: (J x Tmax) matrix of corresponding k values for each t
    - n_: (J x Tmax) matrix specifying counts of customers
    - m_: (J x Kmax) matrix specifying counts of tables
    
    PUBLIC ATTRIBUTES
    post_samples: (iters+1 x N x 4) array of (j, t, k, phi) values for each data point i
                  at each iteration; exists only after gibbs() has been called
    """
    
    def __init__(self, gamma=1, alpha0=1, f='poisson', hypers=None):
        self.g_ = gamma
        self.a0_ = alpha0
        self.set_priors(f, hypers)
        
    def set_priors(self, f, hypers):
        """Initializes the type of base measure and data-generation function."""
        if f == 'poisson':
            # Specify parameters of H ~ Gamma(a,b)
            self.hypers_ = (1, 1) if hypers is None else hypers
        else:
            # Fail early instead of leaving hypers_ unset for an unknown model
            raise ValueError(f"unsupported data distribution: {f!r}")
        self.f_ = f
        
        
    def gibbs(self, x, j, iters=1, Kmax=10):
        """
        Runs the Gibbs sampler to generate posterior estimates of t and k.
        x: data matrix, stored row-wise if multidimensional
        j: vector of group labels; must have same #rows as x
        iters: number of iterations to run
        Kmax: maximum number of atoms to draw from base measure H
        
        results: creation of post_samples attribute
        """
        
        # Set the distribution-specific function for fk (conditional density)
        fk_routine = None
        if self.f_ == 'poisson':
            fk_routine = pois_fk
            
        j = np.asarray(j)
        group_counts = pd.Series(j).value_counts()
        # Group labels index rows directly, so allocate max(j) + 1 rows;
        # number of tables cannot exceed size of max group
        J, Tmax, N = int(np.max(j)) + 1, int(np.max(group_counts)), len(j)
        self.n_ = np.zeros((J, Tmax))
        self.m_ = np.zeros((J, Kmax))
        self.post_samples = np.zeros((iters+1, N, 4), dtype='int')
        self.post_samples[:,:,0] = j
        
        # Set random initial values for t and k assignments
        t0, k0 = self.post_samples[0,:,1], self.post_samples[0,:,2]   # define two views
        t0[:] = np.random.randint(0, Tmax, size=N)                    # valid table indexes are 0..Tmax-1
        self.t_ = {jj: set(t0[j == jj]) for jj in range(J)}
        self.tk_ = np.random.randint(0, max(Kmax//2, 1), (J, Tmax))   # one table => one dish
        for jj in range(J):
            for tt in self.t_[jj]:
                k0[np.logical_and(j == jj, t0 == tt)] = self.tk_[jj, tt]
        self.k_ = set(k0)
        
        for s in range(iters):
            t_prev, k_prev = self.post_samples[s,:,1], self.post_samples[s,:,2]
            t_next, k_next = self.post_samples[s+1,:,1], self.post_samples[s+1,:,2]
            
            # Get a matrix of log fk(x_ji) values (dependent on model specification)
            log_fk = fk_routine(x, k_prev, Kmax, *self.hypers_) 
            self.fk_ = log_fk
            
            # Copy in place; plain assignment would only rebind the local name
            # and leave post_samples[s+1] untouched
            t_next[:] = t_prev
            # Cycle through each t value of each customer, conditioning on everything
            # Randomize the order in which updates occur
            for i in np.random.permutation(N):
                continue   # t update not implemented yet
                
            k_next[:] = k_prev
            # Similarly, cycle through the k values of each table
            for i in np.random.permutation(N):
                continue   # k update not implemented yet
            
        return self

In [284]:
# Simulated data
N = 25
np.random.seed(3)
j = np.random.randint(1, 10, N)
x = np.random.poisson(j, N)
data = np.c_[x, j]

c = CFRP(hypers=(1,10)).gibbs(x[:,None], j)
#print(c.post_samples[0,:,:])
#print(c.fk_.astype("int"))
#print(np.c_[c.post_samples[0,:,2], np.argmax(c.fk_, axis=1)])
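
Since the sampler tracks only the $t$ and $k$ indexes, the dish parameters have to be recovered afterward. The following is a minimal sketch of one way to do that for the Poisson/Gamma model, assuming a shape-rate Gamma prior: conditional on the assignments, each $\phi_k$ has a Gamma posterior that can be sampled directly. The function `sample_phi` and the toy data are hypothetical and not part of `CFRP`.

```python
import numpy as np

def sample_phi(x, k, ha, hb, rng):
    """Draws phi_k | x, k ~ Gamma(ha + sum of x in k, rate hb + n_k) per active dish."""
    x, k = np.asarray(x).ravel(), np.asarray(k).ravel()
    phi = {}
    for kk in np.unique(k):
        x_kk = x[k == kk]
        shape = ha + x_kk.sum()
        rate = hb + len(x_kk)
        phi[int(kk)] = rng.gamma(shape, 1.0 / rate)   # NumPy takes scale = 1 / rate
    return phi

rng = np.random.default_rng(0)
x_toy = np.array([1, 2, 3, 10, 12])
k_toy = np.array([0, 0, 0, 1, 1])
phi = sample_phi(x_toy, k_toy, ha=1.0, hb=1.0, rng=rng)
```

The result maps each active dish index to one posterior draw of its Poisson rate.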

## *( 5.2 )* Posterior Sampling with Augmented Representation

## *( 5.3 )* Posterior Sampling by Direct Assignment