# HDP Gibbs Samplers

## *( 5.1 )* Posterior Sampling in the Chinese Restaurant Franchise

### Brief Overview

The Hierarchical Dirichlet Process mixture model is given by
$$\begin{aligned}
G_0 | \gamma, H &\sim DP(\gamma, H) \\
G_j | \alpha_0, G_0 &\sim DP(\alpha_0, G_0) \\
\theta_{ji} | G_j &\sim G_j \\
x_{ji} | \theta_{ji} &\sim F(\theta_{ji})
\end{aligned} $$

This model is able to non-parametrically cluster each group's data while sharing information both between and within groups.  A Dirichlet process is essentially a discrete distribution with atoms drawn from a (not-necessarily discrete) base measure $H$ and gradually decreasing weights determined by the "stick-breaking process."  In the HDP, each group is a Dirichlet process drawn from another DP $G_0$, so these will contain the same atoms as $G_0$ but with different weights:
$$\begin{aligned}
G_0 &= \sum_{k=1}^{\infty} \beta_k \delta(\phi_k) \\
G_j &= \sum_{k=1}^{\infty} \pi_{jk} \delta(\phi_k) \\
\phi_k | H &\sim H
\end{aligned} $$
Additionally, if we define $\beta, \pi_j$ as the collected weights above, it can be shown that these vectors encode a distribution over $\mathbb{Z}^+$ such that $\beta | \gamma \sim GEM(\gamma)$ and $\pi_j | \alpha_0, \beta \sim DP(\alpha_0, \beta)$.

Successive draws from a DP exhibit clustering behavior, since the probability of taking a certain value is a related to the number of previous draws of that value.  This is shown in the hierarchical sense by the *Chinese restaurant franchise* process.  Imagine a group of Chinese restaurants with a certain number of tables at each restaurant.  Let $\phi_k$ be the global dishes, drawn from $H$; $\psi_{jt}$ be the table-specific dishes, drawn from $G_0$; and $\theta_{ji}$ be the customer-specific dishes, drawn from $G_j$.  Denote $z_{ji}$ as the dish index eaten by customer $ji$; $t_{ji}$ as the table index where customer $ji$ sits; $k_{jt}$ be the dish index served at table $jt$; $n_{jtk}$ be the customer counts; and $m_{jk}$ be the table counts.  Then:

$$\begin{aligned}
\theta_{ji} | \text{other } \theta, \alpha_0, G_0 &\sim
    \sum_{t=1}^{m_{j\cdot}} \frac{n_{jt\cdot}}{i-1+\alpha_0} \delta(\psi_{jt}) +
                            \frac{\alpha_0}{i-1+\alpha_0} G_0 \\
\psi_{jt} | \text{other } \psi, \gamma, H &\sim
    \sum_{k=1}^{K} \frac{m_{\cdot k}}{m_{\cdot k} + \gamma} \delta(\phi_k) +
                            \frac{\gamma}{m_{\cdot k} + \gamma} H
\end{aligned} $$

### Full Conditionals

Choose some base measure $h(\cdot)$ and a conjugate data-generating distribution $f(\cdot | \theta)$.  Important to compute are $f_k^{-x_{ji}}(x_{ji})$, the mixture component of customer $ij$ under $k$, and $f_k^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt})$, the mixture component of table $jt$ under $k$.  This is done by integrating out $\phi_k$ over the joint density of all such points, for example:

$$\begin{aligned}
f_k^{-x_{ji}}(x_{ji}) &= \frac { \int f(x_{ij} | \phi_k) g(k)d\phi_k } { \int g(k)d\phi_k } \\
g(k) &= h(\phi_k) \prod_{j'i' \neq ji, z_{j'i'} = k} f(x_{j'i'} | \phi_k) 
\end{aligned} $$

The corresponding mixture components for a new customer assignment and new table assignment are denoted $f_{k^*}^{-x_{ji}}(x_{ji})$ and $f_{k^*}^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt})$, which are special cases of their the respective $f_k$ component where no data points have $z_{ij} = k^*$.

Using this, we first compute the likelihood of a given point $x_{ji}$ given the current clustering scheme:
$$
p(x_{ji} | t^{-ji}, t_{ji} = t^*, k) =
    \sum_{k=1}^{K} \frac{m_{\cdot k}}{m_{\cdot k} + \gamma} f_k^{-x_{ji}}(x_{ji}) +
                            \frac{\gamma}{m_{\cdot k} + \gamma} f_{k^*}^{-x_{ji}}(x_{ji})
$$

For efficiency, the Gibbs scheme implemented below only samples the $t$ and $k$ indexes (which can later be reverse-engineered to obtain the actual parameters).  The state space of the $k$ values is technically infinite, and the number of tables/dishes currently associated with the data is undefined.  We keep a running list of active $t$ and $k$ values.  Each update step, each customer is assigned either to one of the existing tables or to a new table, and if a customer is assigned to a new table, a new $k$ corresponding value gets drawn; similarly, each table is assigned a dish, either from the existing dishes or with a new dish.  If a table/dish becomes unrepresented in the current scheme, it gets removed from its respective list.  The update full conditionals are:

$$ \begin{aligned}
p(t_{ji} = t | t^{-ji}, k, ...) &\propto \begin{cases}
    n_{jt\cdot}^{-ji} f_{k_{jt}}^{-x_{ji}}(x_{ji}) & t\text{ used}\\
    \alpha_0 p(x_{ji} | ...) & t\text{ new}
    \end{cases} \\
p(k_{jt} = k | t, k^{-jt}) &\propto \begin{cases}
    m_{\cdot k} f_k^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt}) & k\text{ used}\\
    \gamma f_{k^*}^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt}) & k\text{ new}
    \end{cases} \\
\end{aligned} $$

### Distribution-Specific Mixture Components

The only part of this sampling algorithm that depends on the choice of the measures $H$ and $F$ are the mixture components $f_k$, so this is the only part that needs rewritten for each type of model.  Let
$$ \begin{aligned}
V_{kji} &= \{ j'i' : j'i' \neq ji, z_{j'i'} = k \} \\
W_{kjt} &= \{ j'i' : j't_{j'i'} \neq jt, k_{j't_{j'i'} = k} \} \\
T_{jt} &= \{ j'i': t_{j'i'} = jt \} \\
\end{aligned} $$
$V$ is the set of all customers (excluding customer $ij$) eating dish $k$; $W$ is the set of all customers at tables (excluding table $jt$) eating $k$; these correspond to the product terms in the mixture components.  By conjugacy rules and kernel tricks, each $f_k$ can be expressed as functions of these sets.  Each $f_{k^*}$ can be found by using the corresponding $f_k$ formula where $V$ or $W$ is the empty set.

*F = Poisson, H = Gamma*
$$ \begin{aligned}
f_k^{-x_{ji}}(x_{ji}) &= \frac{1}{x_{ji}!} \cdot
    \frac{\Gamma(x_{ji} + \alpha_v)}{(1 + \beta_v)^{x_{ji} + \alpha_v}} \cdot
    \frac{(\beta_v)^{\alpha_v}}{\Gamma(\alpha_v)} \\
f_k^{-\mathbf{x}_{jt}}(\mathbf{x}_{jt}) &= \frac{1}{\prod_T x_t!} \cdot
    \frac{\Gamma(\sum_T x_t + \alpha_w)}{(1 + \beta_w)^{\sum_T x_t + \alpha_w}} \cdot
    \frac{(\beta_w)^{\alpha_w}}{\Gamma(\alpha_w)} \\
\alpha_v &= \sum_V x_v + \alpha \quad , \quad \beta_v = |V| + \beta \\
\alpha_w &= \sum_W x_w + \alpha \quad , \quad \beta_w = |W| + \beta \\
\end{aligned} $$

*F = Normal (known variance), H = Normal*

In [422]:
import numpy as np
import pandas as pd
from scipy.special import loggamma as logg

def pois_fk(x, t, k, tk, Kmax, ha, hb):
    """
    Computes in one sweep the mixture components for each customer/table for each k.
    MODEL: base measure H ~ Gamma(ha, hb), F(x|phi) ~ Poisson(phi)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: fk_cust = (N, Kmax) matrix
             fk_tabl = (T, Kmax) matrix
             fknew_cust = (N,) vector
             fknew_tabl = (T,) vector
    """
    
    x = x.flatten()  # reshape to 1D, since gibbs routine passes in a 2D array
    
    fk_cust = np.zeros((len(k), Kmax))
    fknew_cust = np.zeros(len(k))
    
    # FOR k WITH NO MEMBERS
    fknew_cust = np.exp( -logg(x + 1) + logg(x + ha) - logg(ha) -
                         (x + ha) * np.log(1 + hb) + ha * np.log(hb) )
    
    # FOR k WITH MEMBERS
    for kk in range(Kmax):
        x_kk = x[k == kk]               # subset of x values with value kk
        x_in = (k == kk).astype('int')  # offset for x values in subset
        # If a value for k is not used, same as the prior
        if len(x_kk) == 0:
            fk_cust[:, kk] = fknew_cust
            continue
        
        # Compute (a,b) params from gamma kernel tricks done in fk function
        a_denom = (np.sum(x_kk) - x_in*x) + ha
        b_denom = (len(x_kk) - x_in) + hb
        a_numer = x + a_denom
        b_numer = 1 + b_denom
        #print(f"kk = {kk}, subset size = {len(x_kk)}")
        #print(f"{np.c_[x,k,a_numer,a_denom,b_numer,b_denom]}")
        fk_cust[:, kk] = np.exp( -logg(x + 1) + logg(a_numer) - logg(a_denom) -
                                 a_numer * np.log(b_numer) + a_denom * np.log(b_denom) )
    
    return [fk_cust, fknew_cust]     

In [555]:
class CFRP:
    """
    Model implementing the Chinese Franchise Restaurant Process.
    
    CONSTRUCTOR PARAMETERS
    - gamma, alpha0: scaling parameters > 0 for base measures H and G0
    - f: string representing distribution of data; h is chosen to be conjugate
    - hypers: tuple of hyperparameter values specific to f/h scheme chosen
    
    PRIVATE ATTRIBUTES (volatile)
    - t_: set of active t values; formatted as {j: set(t...), ...}
    - k_: set of active k values
    - tk_map_: (J x Tmax) matrix of k values for each (j,t) pair
    - n_: (J x Tmax) matrix specifying counts of customers
    - m_: (J x Kmax) matrix specifying counts of tables
    - f_, h_: distribution functions
    - fk_routine_: a function to compute mixing components for Gibbs sampling
    
    PUBLIC ATTRIBUTES
    post_samples: (S x 4) matrix of (j, t, k, phi) values for each data point i;
                  exists only after gibbs() has been called
    """
    
    def __init__(self, gamma=1, alpha0=1, f='poisson', hypers=None):
        self.g_ = gamma
        self.a0_ = alpha0
        self.set_priors(f, hypers)
        
    def set_priors(self, f, hypers):
        """
        Initializes the type of base measure h_ and data-generation function f_.
        Also sets hypers_, the relevelant hyperparameters and
                  fk_routine_, the function to compute mixing components.
        """
        if f == 'poisson':
            # Specify parameters of H ~ Gamma(a,b)
            if hypers is None:
                self.hypers_ = (1,1)
            else: self.hypers_ = hypers
            self.fk_routine_ = pois_fk
    
    
    def tally_up(self, it, which=None):
        """
        Helper function for computing maps and counts in gibbs().
        Given a current iteration in the post_samples attribute, does a full
        recount of customer/table allocations, updating n_ and m_.
        Set which = 'n' or 'm' to only tally up that portion
        """
        
        jt_pairs = self.post_samples[it,:,0:2]
        
        if which != 'm':
            # Count customers at each table (jt)
            cust_counts = pd.Series(map(tuple, jt_pairs)).value_counts()
            j_idx, t_idx = tuple(map(np.array, zip(*cust_counts.index)))
            self.n_[j_idx, t_idx] += cust_counts
            
        if which != 'n':
            # First filter by unique tables (jt), then count tables with each k value
            jt_unique, k_idx = np.unique(jt_pairs, axis=0, return_index=True)
            jk_pairs = np.c_[self.post_samples[it, k_idx, 0],
                             self.post_samples[it, k_idx, 2]]
            #print(jk_pairs)
            tabl_counts = pd.Series(map(tuple, jk_pairs)).value_counts()
            #print(tabl_counts)
            j_idx, k_idx = tuple(map(np.array, zip(*tabl_counts.index)))
            self.m_[j_idx, k_idx] += tabl_counts
        
        
    def gibbs(self, x, j, iters=1, Kmax=10):
        """
        Runs the Gibbs sampler to generate posterior estimates of t and k.
        x: data matrix, stored row-wise if multidimensional
        j: vector of group labels; must have same #rows as x
        iters: number of iterations to run
        Kmax: maximum number of atoms to draw from base measure H
        
        results: creation of post_samples attribute
        """
            
        group_counts = pd.Series(j).value_counts()
        # number of tables cannot exceed size of max group
        J, Tmax, N = np.max(j) + 1, np.max(group_counts) + 1, len(j)
        self.n_ = np.zeros((J, Tmax))
        self.m_ = np.zeros((J, Kmax))
        self.post_samples = np.zeros((iters+1, N, 4), dtype='int')
        self.post_samples[:,:,0] = j
        
        # Set random initial values for t and k assignments
        t0, k0 = self.post_samples[0,:,1], self.post_samples[0,:,2]
        t0[:] = np.random.randint(1, Tmax, size=N)
        self.tk_map_ = np.random.randint(1, Kmax//2, (J, Tmax))
        self.tally_up(it=0, which='n')
        for jj in range(J):
            for tt in np.where(self.n_[jj, :] > 0)[0]:
                #print(f"mapping: {(jj, tt)} -> {self.tk_map_[jj, tt]}")
                k0[np.logical_and(j == jj, t0 == tt)] = self.tk_map_[jj, tt]
        self.tally_up(it=0, which='m')
        
        for s in range(iters):
            t_prev, k_prev = self.post_samples[s,:,1], self.post_samples[s,:,2]
            t_next, k_next = self.post_samples[s+1,:,1], self.post_samples[s+1,:,2]
            
            # Get matrices of f_k values (dependent on model specification)
            mixes = self.fk_routine_(x, t_prev, k_prev, self.tk_map_, Kmax, *self.hypers_) 
            self.fk_ = mixes
            # Calculate pointwise likelihood 
            Mk = np.sum(self.m_, axis=0)   # number of tables serving k
            lik = ( mixes[0] @ (Mk / (Mk + self.g_)) + 
                    np.tile(mixes[1][:,None], Kmax) @ (self.g_ / (Mk + self.g_)) )
            self.lik_ = lik
            
            t_next = t_prev
            # Cycle through each t value of each customer, conditioning on everything
            # Randomize the order in which updates occur
            for i in np.random.permutation(N):
                continue
                
            k_next = k_prev
            # Similarly, cycle through the k values of each table
            for i in np.random.permutation(N):
                continue            
        
        
        return self

In [556]:
# Simulated data
N = 25
np.random.seed(0)
j = np.random.randint(1, 10, N)
x = np.random.poisson(j, N)
data = np.c_[x, j]

c = CFRP(hypers=(1,10)).gibbs(x[:,None], j)
iter0 = c.post_samples[0,:,:]
#print(c.fk_)
#print(np.c_[c.post_samples[0,:,2], np.argmax(c.fk_, axis=1)])

In [560]:
# t, k, log(customer mixtures)
#print(np.log(c.fk_[0]).round(2))
#print(c.n_)
print(c.m_)
print(iter0)
mix = c.fk_[0]
lik = c.lik_

lik

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 3. 1. 0. 1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 3. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 0. 0. 0. 0. 0.]
 [0. 2. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 2. 0. 0. 0. 0. 0.]]
[[6 3 1 0]
 [1 5 1 0]
 [4 5 1 0]
 [4 3 1 0]
 [8 4 1 0]
 [4 4 4 0]
 [6 2 1 0]
 [3 3 4 0]
 [5 5 3 0]
 [8 4 1 0]
 [7 4 4 0]
 [9 2 4 0]
 [9 2 4 0]
 [2 2 2 0]
 [7 2 3 0]
 [8 3 1 0]
 [8 1 2 0]
 [9 4 4 0]
 [2 2 2 0]
 [6 5 1 0]
 [9 4 4 0]
 [5 2 1 0]
 [4 2 1 0]
 [1 3 4 0]
 [4 1 2 0]]


array([4.26990434e-04, 1.31224069e+00, 5.23941875e-01, 7.44624745e-01,
       4.26990434e-04, 5.23733255e-01, 1.56038830e-03, 1.31140398e+00,
       9.83782403e-02, 1.52812166e-02, 4.16253425e-02, 2.22162572e-03,
       1.71709570e-05, 5.12980035e-01, 1.89963906e-02, 9.35155726e-02,
       9.79996891e-02, 4.16253425e-02, 1.31846192e+00, 3.36867777e-01,
       1.86824208e-01, 4.02646250e-02, 4.02646250e-02, 6.68880021e+00,
       1.89591123e-02])

## *( 5.2 )* Posterior Sampling with Augmented Representation

## *( 5.3 )* Posterior Sampling by Direct Assignment