# HDP Gibbs Samplers

## Implementation of Samplers

Both of the above schemes are implemented in the `HDP` class and callable via the `gibbs_crf` and `gibbs_direct` methods, respectively.

In [136]:
import numpy as np
import pandas as pd
from scipy.special import gammaln as logg

In [137]:
def pois_fk_cust(i, x, k, Kmax, ha, hb, new=False):
    """
    Computes the mixture components for a given customer across all k values.
    MODEL: base measure H ~ Gamma(ha, hb), F(x|phi) ~ Poisson(phi)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: (Kmax,) vector; if new=True, returns a scalar
    """
    
    x = x.flatten()  # reshape to 1D, since gibbs routine passes in a 2D array
    
    # Calculate the case where k has no members
    fknew_cust = np.exp( -logg(x[i] + 1) + logg(x[i] + ha) - logg(ha) -
                         (x[i] + ha)*np.log(1 + hb) + ha*np.log(hb) )
    if new == True: return fknew_cust        
    
    x_kks = [x[k == kk] for kk in range(Kmax)]  # subset of customers eating kk
    xi_in = np.zeros(Kmax)                      # offset if x[i] is in this subset
    xi_in[k[i]] = 1
      
    # Compute (a,b) params from gamma kernel tricks done in fk function
    av = np.array(list(map(np.sum, x_kks))) - xi_in*x[i] + ha
    bv = np.array(list(map(len, x_kks))) - xi_in + hb
    fk_cust = np.exp( -logg(x[i] + 1) + logg(x[i] + av) - logg(av) -
                      (x[i] + av)*np.log(1 + bv) + av*np.log(bv) )
     
    return fk_cust


def pois_fk_tabl(jj, tt, x, j, t, k, Kmax, ha, hb, new=False):
    """
    Computes the mixture components for a given table across all k values.
    MODEL: base measure H ~ Gamma(ha, hb), F(x|phi) ~ Poisson(phi)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: (Kmax,) vector; if new=True, returns a scalar
    """
    
    x = x.flatten()  # reshape to 1D, since gibbs routine passes in a 2D array
    x_jt = x[np.logical_and(j == jj, t == tt)]
    kk = k[np.logical_and(j == jj, t == tt)]
    
    fknew_tabl = np.exp( -np.sum(logg(x_jt + 1)) + logg(np.sum(x_jt) + ha) - logg(ha) -
                         (np.sum(x_jt) + ha)*np.log(len(x_jt) + hb) + ha*np.log(hb) )
    # If table jt doesn't exist, just return the "new" mixture component
    if len(x_jt) == 0:
        print(f"WARNING: table {(jj, tt)} does not exist currently")
        new = True
    if new == True: return np.full(Kmax, fknew_tabl)
    
    x_kks = [x[k == kk] for kk in range(Kmax)]  # subset of customers at tables serving kk
    xjt_in = np.zeros(Kmax)                      # offset if table x_jt is in this subset
    xjt_in[kk[0]] = 1
      
    # Compute (a,b) params from gamma kernel tricks done in fk function
    av = np.array(list(map(np.sum, x_kks))) - xjt_in*np.sum(x_jt) + ha
    bv = np.array(list(map(len, x_kks))) - xjt_in*len(x_jt) + hb
    fk_tabl = np.exp( -np.sum(logg(x_jt + 1)) + logg(np.sum(x_jt) + av) - logg(av) -
                       (np.sum(x_jt) + av)*np.log(len(x_jt) + bv) + ha*np.log(bv) )
     
    return fk_tabl

In [138]:
def mnom_fk_cust(i, x, k, Kmax, L, ha, new=False):
    """
    Computes the mixture components for a given customer across all k values.
    MODEL: base measure H ~ Dirichlet(L, ha_1,...,ha_L),
                        F(x|phi) ~ Multinomial(n_ji, phi_1,...,phi_L)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: (Kmax,) vector; if new=True, returns a scalar
    """
    
    xi, ni = x[i, :], np.sum(x[i, :])
    log_con = logg(ni + 1) - np.sum(logg(xi + 1)) # term constant for all k
    # Calculate the case where k has no members
    fknew_cust = np.exp( log_con + np.sum(logg(xi + ha)) - logg(np.sum(xi + ha)) + 
                         logg(np.sum(ha)) - np.sum(logg(ha)) )
    if new == True: return fknew_cust        
    
    # Get subset of customers eating kk; each entry is a (#, L) matrix
    x_kks = [x[k == kk, :] for kk in range(Kmax)]  
    
    # Compute params from Dirichlet kernel tricks done in fk function
    a_bot = np.vstack([np.sum(x_kk, axis=0) for x_kk in x_kks]) + ha[None, :]    # (Kmax, L)
    a_bot[k[i], :] -= xi                         # offset if xi is in this subset
    a_top = a_bot + xi[None, :]
    fk_cust = np.exp( log_con + np.sum(logg(a_top), axis=1) - logg(np.sum(a_top, axis=1)) +
                      logg(np.sum(a_bot, axis=1)) - np.sum(logg(a_bot), axis=1) )
     
    return fk_cust


def mnom_fk_tabl(jj, tt, x, j, t, k, Kmax, L, ha, new=False):
    """
    Computes the mixture components for a given customer across all k values.
    MODEL: base measure H ~ Dirichlet(L, ha_1,...,ha_L),
                        F(x|phi) ~ Multinomial(n_ji, phi_1,...,phi_L)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: (Kmax,) vector; if new=True, returns a scalar
    """
    
    x_jt = x[np.logical_and(j == jj, t == tt), :]                                # (|T|, L)
    kk = k[np.logical_and(j == jj, t == tt)]
    n_jt = np.sum(x_jt, axis=1)                                                  # (|T|,)
    sum_jt = np.sum(x_jt, axis=0)                                                # (L,)
    log_con = np.sum(logg(n_jt + 1)) - np.sum(logg(x_jt + 1))    # term constant for all k
    
    fknew_tabl = np.exp( log_con + np.sum(logg(sum_jt + ha)) - logg(np.sum(sum_jt + ha)) + 
                         logg(np.sum(ha)) - np.sum(logg(ha)) )
    # If table jt doesn't exist, just return the "new" mixture component
    if x_jt.shape[0] == 0:
        print(f"WARNING: table {(jj, tt)} does not exist currently")
        new = True
    if new == True: return fknew_tabl       
    
    # Get subset of customers eating kk; each entry is a (#, L) matrix
    x_kks = [x[k == kk, :] for kk in range(Kmax)]
      
    # Compute params from Dirichlet kernel tricks done in fk function
    a_bot = np.vstack([np.sum(x_kk, axis=0) for x_kk in x_kks]) + ha[None, :]    # (Kmax, L)
    a_bot[kk[0], :] -= sum_jt                       # offset if table x_jt is in this subset
    a_top = a_bot + sum_jt[None, :]
    fk_tabl = np.exp( log_con + np.sum(logg(a_top), axis=1) - logg(np.sum(a_top, axis=1)) +
                      logg(np.sum(a_bot, axis=1)) - np.sum(logg(a_bot), axis=1) )

    return fk_tabl

In [139]:
def cat_fk_cust(i, x, k, Kmax, L, ha, new=False):
    """
    Computes the mixture components for a given customer across all k values.
    MODEL: base measure H ~ Dirichlet(L, ha_1,...,ha_L),
                        F(x|phi) ~ Categorical(L, phi_1,...,phi_L)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: (Kmax,) vector; if new=True, returns a scalar
    """
    
    xi, ni = x[i, :], np.sum(x[i, :])
    log_con = logg(ni + 1) - np.sum(logg(xi + 1)) # term constant for all k
    # Calculate the case where k has no members
    fknew_cust = np.exp( log_con + np.sum(logg(xi + ha)) - logg(np.sum(xi + ha)) + 
                         logg(np.sum(ha)) - np.sum(logg(ha)) )
    if new == True: return fknew_cust        
    
    # Get subset of customers eating kk; each entry is a (#, L) matrix
    x_kks = [x[k == kk, :] for kk in range(Kmax)]  
    
    # Compute params from Dirichlet kernel tricks done in fk function
    a_bot = np.vstack([np.sum(x_kk, axis=0) for x_kk in x_kks]) + ha[None, :]    # (Kmax, L)
    a_bot[k[i], :] -= xi                         # offset if xi is in this subset
    a_top = a_bot + xi[None, :]
    fk_cust = np.exp( log_con + np.sum(logg(a_top), axis=1) - logg(np.sum(a_top, axis=1)) +
                      logg(np.sum(a_bot, axis=1)) - np.sum(logg(a_bot), axis=1) )
     
    return fk_cust


def cat_fk_tabl(jj, tt, x, j, t, k, Kmax, L, ha, new=False):
    """
    Computes the mixture components for a given customer across all k values.
    MODEL: base measure H ~ Dirichlet(L, ha_1,...,ha_L),
                        F(x|phi) ~ Categorical(L, phi_1,...,phi_L)
    All components are calculated exactly in log-space and then exponentiated.
    
    returns: (Kmax,) vector; if new=True, returns a scalar
    """
    
    x_jt = x[np.logical_and(j == jj, t == tt), :]                                # (|T|, L)
    kk = k[np.logical_and(j == jj, t == tt)]
    n_jt = np.sum(x_jt, axis=1)                                                  # (|T|,)
    sum_jt = np.sum(x_jt, axis=0)                                                # (L,)
    log_con = np.sum(logg(n_jt + 1)) - np.sum(logg(x_jt + 1))    # term constant for all k
    
    fknew_tabl = np.exp( log_con + np.sum(logg(sum_jt + ha)) - logg(np.sum(sum_jt + ha)) + 
                         logg(np.sum(ha)) - np.sum(logg(ha)) )
    # If table jt doesn't exist, just return the "new" mixture component
    if x_jt.shape[0] == 0:
        print(f"WARNING: table {(jj, tt)} does not exist currently")
        new = True
    if new == True: return fknew_tabl       
    
    # Get subset of customers eating kk; each entry is a (#, L) matrix
    x_kks = [x[k == kk, :] for kk in range(Kmax)]
      
    # Compute params from Dirichlet kernel tricks done in fk function
    a_bot = np.vstack([np.sum(x_kk, axis=0) for x_kk in x_kks]) + ha[None, :]    # (Kmax, L)
    a_bot[kk[0], :] -= sum_jt                       # offset if table x_jt is in this subset
    a_top = a_bot + sum_jt[None, :]
    fk_tabl = np.exp( log_con + np.sum(logg(a_top), axis=1) - logg(np.sum(a_top, axis=1)) +
                      logg(np.sum(a_bot, axis=1)) - np.sum(logg(a_bot), axis=1) )

    return fk_tabl

In [174]:
class StirlingEngine:
    """
    Numerically efficient engine for computing and storing computed Stirling numbers.
    
    CONSTRUCTOR PARAMETERS
    - Nmax: largest integer n for which s(n,m) will need to be computed
    
    PRIVATE ATTRIBUTES
    - s_memo_, slog_memo_: running tables of previously computed values
    """
    
    def __init__(self, Nmax):
        self.s_memo_ = np.full((Nmax, Nmax), np.nan)
        self.slog_memo_ = np.full((Nmax, Nmax), np.nan)
        
    def stirling(self, n, m):
        """
        Computes an unsigned Stirling number of the first kind.
        Uses dynamic programming to store previously computed s(n,m) values,
        as this is a repeatedly-called recursive algorithm.
        """

        # If this has already been computed, return stored value
        if not np.isnan(self.s_memo_[n, m]):
            return self.s_memo_[n, m]
        else:
            return_val = np.nan

            # Base cases
            if (n == 0 and m == 0) or (n == 1 and m == 1):
                return_val = 1
            elif (n > 0 and m == 0) or m > n:
                return_val = 0
            # Recursion relation
            else:
                return_val = stirling(n-1, m-1) + (n-1)*stirling(n-1, m)

            self.s_memo_[n, m] = return_val
            return return_val
    
    
    def stirlog(self, n, m):
        """
        Computes the natural logarithm of an unsigned Stirling number,
        using the same dynamic programming approach as above.
        If s(n,m) = 0, this gets returned as -inf (np.exp(-inf) == 0.0)
        
        This is the preferred function, as stirling() can encounter overflow errors.
        """

        # If this has already been computed, return stored value
        if not np.isnan(self.slog_memo_[n, m]):
            return self.slog_memo_[n, m]
        else:
            return_val = np.nan

            # Base cases
            if (n == 0 and m == 0) or (n == 1 and m == 1):
                return_val = 0
            elif (n > 0 and m == 0) or m > n:
                return_val = -np.inf
            # Recursion relation
            else:
                log_s1, log_s2 = stirlog(n-1, m-1), stirlog(n-1, m)
                # If s1 == 0 (log_s1 == -inf), just return (n-1)*log_s2
                # By definition, must have s2 > s1, so only need to check s1
                if np.isfinite(log_s1):
                    val = (n-1) * np.exp(log_s2 - log_s1)
                    # If there is overflow/underflow in `val`, approximate log(1+x) = log(x)
                    if np.isfinite(val):
                        return_val = log_s1 + np.log1p(val)
                    else:
                        return_val = log_s2 + np.log(n-1)
                else:
                    return_val = log_s2 + np.log(n-1)

            self.slog_memo_[n, m] = return_val
            return return_val

In [185]:
class HDP:
    """
    Model implementing the Chinese Restaurant Franchise Process formulation of the HDP.
    
    CONSTRUCTOR PARAMETERS
    - gamma, alpha0: scaling parameters > 0 for base measures H and G0
    - f: string representing distribution of data; h is chosen to be conjugate
    - hypers: tuple of hyperparameter values specific to f/h scheme chosen
    
    PRIVATE ATTRIBUTES (volatile)
    - tk_map_: (J x Tmax) matrix of k values for each (j,t) pair
    - beta_: (Kmax + 1,) vector of beta values for each k
    - n_: (J x Tmax) matrix specifying counts of customers (gibbs_cfr)
    - q_: (J x Kmax) matrix specifying counts of customers (gibbs_direct)
    - m_: (J x Kmax) matrix specifying counts of tables
    - fk_cust_, fk_tabl_: functions to compute mixing components for Gibbs sampling
    - stir_: an object of class StirlingEngine which computes Stirling numbers
    
    PUBLIC ATTRIBUTES
    cfr_samples: (S x 3) matrix of (j, t, k) values for each data point i;
                 exists only after gibbs_cfr() has been called
    direct_samples: (S x 2) matrix of (j, k) values for each data point i;
                    exists only after gibbs_direct() has been called
    """
    
    def __init__(self, gamma=1, alpha0=1, f='multinomial', hypers=None):
        self.g_ = gamma
        self.a0_ = alpha0
        self.set_priors(f, hypers)
        
    def set_priors(self, f, hypers):
        """
        Initializes the type of base measure h_ and data-generation function f_.
        Also sets hypers_, the relevelant hyperparameters and
                  fk_routine_, the function to compute mixing components.
        """
        if f == 'poisson':
            # Specify parameters of H ~ Gamma(a,b)
            if hypers is None:
                self.hypers_ = (1,1)
            else: self.hypers_ = hypers
            self.fk_cust_ = pois_fk_cust
            self.fk_tabl_ = pois_fk_tabl
        
        elif f == 'multinomial':
            if hypers is None:
                L = 2
                self.hypers_ = (L, np.full(L, 1/L))
            else: self.hypers_ = hypers
            self.fk_cust_ = mnom_fk_cust
            self.fk_tabl_ = mnom_fk_tabl
            
        elif f == 'categorical':
            # Identical to multinomial, but with some efficiency upgrades
            if hypers is None:
                L = 2
                self.hypers_ = (L, np.full(L, 1/L))
            else: self.hypers_ = hypers
            self.fk_cust_ = cat_fk_cust
            self.fk_tabl_ = cat_fk_tabl
    
    
    def tally_up(self, it, which=None):
        """
        Helper function for computing maps and counts in gibbs().
        Given a current iteration in the cfr_samples attribute, does a full
        recount of customer/table allocations, updating n_ and m_.
        Set which = 'n' or 'm' to only tally up that portion
        """    
        
        if which == 'n':
            jt_pairs = self.cfr_samples[it,:,0:2]
            # Count customers at each table (jt)
            cust_counts = pd.Series(map(tuple, jt_pairs)).value_counts()
            j_idx, t_idx = tuple(map(np.array, zip(*cust_counts.index)))
            self.n_ *= 0
            self.n_[j_idx, t_idx] = cust_counts
            
        elif which == 'm':
            jt_pairs = self.cfr_samples[it,:,0:2]
            # First filter by unique tables (jt), then count tables with each k value
            jt_unique, k_idx = np.unique(jt_pairs, axis=0, return_index=True)
            jk_pairs = np.c_[self.cfr_samples[it, k_idx, 0],
                             self.cfr_samples[it, k_idx, 2]]
            #print(jk_pairs)
            tabl_counts = pd.Series(map(tuple, jk_pairs)).value_counts()
            #print(tabl_counts)
            j_idx, k_idx = tuple(map(np.array, zip(*tabl_counts.index)))
            self.m_ *= 0
            self.m_[j_idx, k_idx] = tabl_counts
            
        elif which == 'q':
            jk_pairs = self.direct_samples[it,:,:]
            # Counts customers at each j eating k
            cust_counts = pd.Series(map(tuple, jk_pairs)).value_counts()
            j_idx, k_idx = tuple(map(np.array, zip(*cust_counts.index)))
            self.q_ *= 0
            self.q_[j_idx, k_idx] = cust_counts
    
    
    def draw_t(self, it, x, j, Tmax, Kmax, verbose):
        """
        Helper function which does the draws from the t_ij full conditional.
        Updates the counts and the samples matrices at iteration `it`.
        Called by gibbs_cfr()
        """
        
        t_next, k_next = self.cfr_samples[it,:,1], self.cfr_samples[it,:,2]
        # Cycle through the t value of each customer, conditioning on everything
        # Randomize the order in which updates occur
        for i in np.random.permutation(len(j)):
            jj, tt0, kk0 = j[i], t_next[i], k_next[i]

            # Get vector of customer f_k values (dependent on model specification)
            old_mixes = self.fk_cust_(i, x, k_next, Kmax, *self.hypers_) 
            new_mixes = self.fk_cust_(i, x, k_next, Kmax, *self.hypers_, new=True) 
            # Calculate pointwise likelihoods p(x_ji | ...)
            M = np.sum(self.m_)
            Mk = np.sum(self.m_, axis=0)   # number of tables serving k
            lik = old_mixes @ (Mk / (M + self.g_)) + new_mixes * (self.g_ / (M + self.g_))

            cust_offset = np.zeros(Tmax)
            cust_offset[tt0] = 1
            old_t = (self.n_[jj, :] - cust_offset) * old_mixes[self.tk_map_[jj, :]]      
            new_t = self.a0_ * lik
            # If a table is in use, prob comes from old_t; otherwise, from new_t
            t_used = self.n_[jj, :] > 0
            t_dist = old_t * t_used.astype('int') + new_t * np.logical_not(t_used).astype('int')
            """TEMPORARY FIX (bug should be found later):
               Remove nans and add epsilon so that distribution is all positive"""
            t_dist[np.isnan(t_dist)] = 0
            t_dist += 1e-6

            tt1 = np.random.choice(Tmax, p=t_dist/np.sum(t_dist))
            t_next[i] = tt1
            self.tally_up(it, which='n')

            # If this table was previously unoccupied, we need to select a k
            if self.n_[jj, tt1] == 1 and tt0 != tt1:
                old_k = np.sum(self.m_, axis=0) * old_mixes
                new_k = self.g_ * new_mixes
                k_used = np.sum(self.m_, axis=0) > 0
                k_dist = old_k * k_used.astype('int') + new_k * np.logical_not(k_used).astype('int')
                """TEMPORARY FIX (bug should be found later):
                   Remove nans and add epsilon so that distribution is all positive"""
                k_dist[np.isnan(k_dist)] = 0
                k_dist += 1e-6

                kk1 = np.random.choice(Kmax, p=k_dist/np.sum(k_dist))
                self.tk_map_[jj, tt1] = kk1
                k_next[i] = self.tk_map_[jj, tt1]
            self.tally_up(it, which='m')

            if verbose: print(f"~ customer (j,i) = {(jj,i)}" +
                              f" moves table: {tt0} -> {t_next[i]}, k: {kk0} -> {k_next[i]}")
    
    
    def draw_k(self, it, x, j, Kmax, verbose):
        """
        Helper function which does the draws from the t_ij full conditional.
        Updates the counts and the samples matrices at iteration `it`.
        Called by gibbs_cfr()
        """
        
        t_next, k_next = self.cfr_samples[it,:,1], self.cfr_samples[it,:,2]
        # Cycle through the k values of each table
        j_idx, t_idx = np.where(self.n_ > 0)   # find the occupied tables
        for i in np.random.permutation(len(j_idx)):
            jj, tt = j_idx[i], t_idx[i]
            kk0 = self.tk_map_[jj, tt]

            # Get vector of table f_k values (dependent on model specification)
            old_mixes = self.fk_tabl_(jj, tt, x, j, t_next, k_next, Kmax, *self.hypers_) 
            new_mixes = self.fk_tabl_(jj, tt, x, j, t_next, k_next, Kmax, *self.hypers_, new=True) 

            tabl_offset = np.zeros(Kmax)
            tabl_offset[kk0] = 1
            old_k = (np.sum(self.m_, axis=0) - tabl_offset) * old_mixes
            new_k = self.g_ * new_mixes
            k_used = np.sum(self.m_, axis=0) > 0
            k_dist = old_k * k_used.astype('int') + new_k * np.logical_not(k_used).astype('int')
            """TEMPORARY FIX (bug should be found later):
               Remove nans and add epsilon so that distribution is all positive"""
            k_dist[np.isnan(k_dist)] = 0
            k_dist += 1e-6

            #print(f"{old_k.round(3)}\n{new_k.round(3)}\n{k_used}\n{k_dist.round(3)}")
            kk1 = np.random.choice(Kmax, p=k_dist/np.sum(k_dist))
            self.tk_map_[jj, tt] = kk1
            k_next[np.logical_and(j == jj, t_next == tt)] = kk1
            self.tally_up(it, which='m')

            if verbose: print(f"~~ table (j,t) = {(jj,tt)} changes dish: {kk0} -> {kk1}")
    
    
    def draw_z(self, it, x, j, Kmax, verbose):
        """
        Helper function which does the draws from the z_ij full conditional.
        Updates the counts and the samples matrices at iteration `it`.
        Called by gibbs_direct()
        """
        
        k_next = self.direct_samples[it,:,1]
        # Cycle through the k values of each customer
        for i in np.random.permutation(len(j)):
            jj, kk0 = j[i], k_next[i]
            
            # Get vector of customer f_k values (dependent on model specification)
            old_mixes = self.fk_cust_(i, x, k_next, Kmax, *self.hypers_) 
            new_mixes = self.fk_cust_(i, x, k_next, Kmax, *self.hypers_, new=True) 
            
            cust_offset = np.zeros(Kmax)
            cust_offset[kk0] = 1
            old_k = (self.q_[jj, :] - cust_offset +
                     self.a0_ + self.beta_[:-1]) * old_mixes      
            new_k = self.a0_ * self.beta_[-1] * new_mixes
            k_used = np.sum(self.m_, axis=0) > 0
            k_dist = old_k * k_used + new_k * np.logical_not(k_used)
            """TEMPORARY FIX (bug should be found later):
               Remove nans and add epsilon so that distribution is all positive"""
            k_dist[np.isnan(k_dist)] = 0
            k_dist += 1e-6
            

            kk1 = np.random.choice(Kmax, p=k_dist/np.sum(k_dist))
            k_next[i] = kk1
            self.tally_up(it, which='q')
            
            if verbose:
                print(f"~ customer (j,i) = {(jj,i)}" +
                      f"changes dish: {kk0} -> {kk1}")
                print(f"  k_dist: {k_dist.round(3)} (sum {np.sum(k_dist)})")
                
    
    def draw_m(self, it, x, j, Kmax, verbose):
        """
        Helper function which does the draws from the z_ij full conditional.
        Updates the counts and the samples matrices at iteration `it`.
        Called by gibbs_direct()
        """
        
        k_next = self.direct_samples[it,:,1]
        self.m_ *= 0                           # reset the m counts
        # Cycle through the k values of each restaurant
        j_idx, k_idx = np.where(self.q_ > 0)   # find the consumed dishes
        for i in np.random.permutation(len(j_idx)):
            jj, kk = j_idx[i], k_idx[i]
            max_m = self.q_[jj, kk]
            
            abk = self.a0_ * self.beta_[kk]
            m_range = np.arange(max_m) + 1
            log_s = np.array([self.stir_.stirlog(max_m, m) for m in m_range])
            m_dist = np.exp( logg(abk) - logg(abk + max_m) +
                             log_s + m_range * np.log(abk) )
            """TEMPORARY FIX (bug should be found later):
               Remove nans and add epsilon so that distribution is all positive"""
            m_dist[np.isnan(m_dist)] = 0
            m_dist += 1e-6
            
            mm1 = np.random.choice(m_range, p=m_dist/np.sum(m_dist))
            self.m_[jj, kk] = mm1

            if verbose:
                print(f"~~ restaraunt {jj}: {mm1} tables / {max_m} customers eating {kk}")
                print(f"m_dist: {m_dist.round(3)}")
                
    
    def gibbs_cfr(self, x, j, iters, Tmax=5, Kmax=10, verbose=False):
        """
        Runs the Gibbs sampler to generate posterior estimates of t and k.
        x: data matrix, stored row-wise if multidimensional
        j: vector of group labels; must have same #rows as x
        iters: number of iterations to run
        Tmax: maximum number of clusters for each group
        Kmax: maximum number of atoms to draw from base measure H
        
        returns: this HDP object with cfr_samples attribute
        """
            
        group_counts = pd.Series(j).value_counts()
        J, N = np.max(j) + 1, len(j)
        self.n_ = np.zeros((J, Tmax), dtype='int')
        self.m_ = np.zeros((J, Kmax), dtype='int')
        self.cfr_samples = np.zeros((iters+1, N, 3), dtype='int')
        self.cfr_samples[:,:,0] = j
        
        # Set random initial values for t and k assignments
        t0, k0 = self.cfr_samples[0,:,1], self.cfr_samples[0,:,2]
        t0[:] = np.random.randint(0, Tmax, size=N)
        self.tk_map_ = np.random.randint(0, Kmax//2, (J, Tmax))
        self.tally_up(it=0, which='n')
        for jj in range(J):
            for tt in np.where(self.n_[jj, :] > 0)[0]:
                #print(f"mapping: {(jj, tt)} -> {self.tk_map_[jj, tt]}")
                k0[np.logical_and(j == jj, t0 == tt)] = self.tk_map_[jj, tt]
        self.tally_up(it=0, which='m')
        
        for s in range(iters):
            if verbose: print(f"----------------\n ITERATION {s+1}\n----------------")
            t_prev, k_prev = self.cfr_samples[s,:,1], self.cfr_samples[s,:,2]
            t_next, k_next = self.cfr_samples[s+1,:,1], self.cfr_samples[s+1,:,2]
            # Copy over the previous iteration as a starting point
            t_next[:], k_next[:] = t_prev, k_prev
            
            self.draw_t(s+1, x, j, Tmax, Kmax, verbose)
            self.draw_k(s+1, x, j, Kmax, verbose)
        
        return self  
    
    
    def gibbs_direct(self, x, j, iters, Kmax=10, verbose=False):
        """
        Runs the Gibbs sampler to generate posterior estimates of k.
        x: data matrix, stored row-wise if multidimensional
        j: vector of group labels; must have same #rows as x
        iters: number of iterations to run
        Kmax: maximum number of atoms to draw from base measure H
        
        returns: this HDP object with direct_samples attribute
        """
        
        group_counts = pd.Series(j).value_counts()
        J, N = np.max(j) + 1, len(j)
        self.q_ = np.zeros((J, Kmax), dtype='int')   # performs the same function as n_
        self.m_ = np.zeros((J, Kmax), dtype='int')
        self.beta_ = np.ones(Kmax + 1) * (1 / Kmax + 1)
        self.direct_samples = np.zeros((iters+1, N, 2), dtype='int')
        self.direct_samples[:,:,0] = j
        
        self.stir_ = StirlingEngine(N)
        np.seterr('ignore')
        
        # Set random initial values for k assignments
        k0 = self.direct_samples[0,:,1]
        k0[:] = np.random.randint(0, Kmax, size=N)
        self.tally_up(it=0, which='q')
        # Implicitly set random t assignments by drawing possible m counts (m_jk <= q_jk)
        for jj in range(J):
            for kk in range(Kmax):
                max_m = self.q_[jj, kk]
                if max_m == 1:
                    self.m_[jj, kk] = 1
                elif max_m > 1:
                    self.m_[jj, kk] = np.random.randint(1, max_m)
        
        for s in range(iters):
            if verbose: print(f"----------------\n ITERATION {s+1}\n----------------")
            k_prev, k_next = self.direct_samples[s,:,1], self.direct_samples[s+1,:,1]
            # Copy over the previous iteration as a starting point
            k_next[:] = k_prev
            
            self.draw_z(s+1, x, j, Kmax, verbose)
            self.draw_m(s+1, x, j, Kmax, verbose)
            
            Mk = np.sum(self.m_, axis=0)
            # Dirichlet weights must be > 0, so in case some k is unused, add epsilon
            self.beta_ = np.random.dirichlet(np.append(Mk, self.g_) + 1e-6)
        
        return self
        

In [193]:
# Simulated data (Poisson example)
N = 200
np.random.seed(0)
j = np.random.randint(0, 9, N)
x = np.random.poisson(j, N)
data = np.c_[x, j]

%time c = HDP(f='poisson', hypers=(1,10)).gibbs_direct(x[:,None], j, iters=100)

Wall time: 29.6 s


In [192]:
# Simulated data (Multinomial example)
N, L = 200, 100
np.random.seed(1)
X = np.random.randint(0, 10, (N, L))
j = np.random.randint(0, 10, N)

%time c = HDP(f='multinomial', hypers=(L, np.full(L, 1/L))).gibbs_direct(X, j, iters=100)

Wall time: 31.1 s


## Latent Topic Modeling Application

Below is an application of the above sampler using a multinomial data model.  The data is `final_project_data.csv`, produced by the modified preprocessing code in this directory, which contains a `(J, L)` matrix in which entry `(j,l)` contains the count of word `l` in document `j`, with the corresponding words given in the column names.

For the Dirichlet prior here, we use the observed distribution of the corpus vocabulary over all documents.  Customers could be encoded in four different ways to compare performance:
+ As a single word (such that `f` is categorical)
+ As a set of all identical words within a given document (each row of the data matrix has one entry, but the value can vary)
+ As a set of all words in a single sentence
+ As the entire document (essentially making this a non-hierarchical DP)

Since this algorithm has not been optimized, only a subset of the full dataframe is used for now.

In [194]:
import pandas as pd

full_df = pd.read_csv('final_project_data.csv', index_col=0, dtype='int')
vocab = full_df.columns

In [248]:
def expand_doc(doc_in):
    """Expands a row (passed in a series) into a dataframe in which each
       row contains only the counts of one unique word."""
    
    doc_in = doc_in.drop(J_ID, axis=1)
    doc_in = doc_in.iloc[0, :]
    words_used = doc_in[doc_in > 0]
    doc_out = pd.DataFrame(np.zeros((len(words_used), len(vocab)), dtype='int'),
                           columns=vocab, dtype='int')
    for i, word in enumerate(words_used.index):
        doc_out.loc[i, word] = words_used[word]
    return doc_out

J_ID = 'document#'
if J_ID not in full_df.columns:
    full_df.insert(0, column=J_ID, value=full_df.index)

Jmax = 10
wordset_df = full_df.iloc[:Jmax,:].groupby(J_ID).apply(expand_doc)

In [252]:
ji_indices = wordset_df.index.to_frame()
j = np.array(ji_indices[J_ID])
X = np.array(wordset_df)
# Get the corresponding word each ji is associated with
ji_words = vocab[np.where(X > 0)[1]]

Tmax, Kmax = np.max(ji_indices[1]), 20
# Get a prior distribution over the vocabulary from selected documents
L, h_alpha = X.shape[1], np.sum(X, axis=0)
h_alpha = h_alpha / np.sum(h_alpha)
iters = 10

In [258]:
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    %time hdp = HDP(f='multinomial', hypers=(L, h_alpha)).gibbs_direct(X, j, 100)

Wall time: 7min 55s


In [259]:
# Given seating assignment in final iteration, cluster customers with same k value
k_final = hdp.direct_samples[-1,:,1]
clusters = pd.DataFrame({'doc': j, 'word': ji_words, 'cluster': k_final})
for k in set(k_final):
    clusters_k = clusters[clusters['cluster'] == k].sort_values('word')
    print(f"-----------\nk = {k} (size {clusters_k.shape[0]})")
    print(clusters_k)

-----------
k = 0 (size 40)
     doc          word  cluster
316    7      addition        0
251    5        animal        0
176    4       animals        0
254    5        appear        0
234    4     attempted        0
444    9    attractive        0
384    8       closely        0
455    9      colonies        0
355    7   consumption        0
171    4   development        0
417    9       elegans        0
166    4       elegans        0
66     1      envelope        0
84     2         first        0
6      0         found        0
49     1         gonad        0
105    2      granules        0
425    9        growth        0
194    4    individual        0
97     2     localized        0
135    3     mechanism        0
61     1  mitochondria        0
213    4       mitosis        0
4      0       mutants        0
307    7      nematode        0
423    9             p        0
125    2   perinuclear        0
114    2         place        0
193    4     precursor        0
454    9    

In [232]:
len(ji_words)

456