In [44]:
import numpy as np
from scipy.special import polygamma, logsumexp

In [57]:
def normalize_vector(arr):
    """Scale ``arr`` so that all of its entries sum to 1.

    The sum is taken over every element, regardless of the array's shape.
    """
    return np.divide(arr, np.sum(arr))

def normalize_by_row(arr):
    """Scale each row of the 2-D array ``arr`` so that the row sums to 1."""
    row_totals = arr.sum(axis=1)
    # Turn the totals into a column so they broadcast across each row.
    return arr / row_totals[:, np.newaxis]

def exp_normalize(arr):
    """Return ``exp(arr)`` scaled to sum to 1 over all elements (a softmax).

    The global maximum is subtracted before exponentiating. This does not
    change the normalised result, but it keeps ``np.exp`` from overflowing
    on large inputs (the log-sum-exp trick).
    """
    weights = np.exp(np.subtract(arr, np.max(arr)))
    return weights / weights.sum()

In [2]:
class minibatch_VI_sLDA:
    """Local (per-document) variational updates for supervised LDA on a minibatch.

    Holds the document-level variational parameters and updates them by
    coordinate ascent, given the topic parameters ``Lambda``:

    - ``gamma``: one row of topic weights per document.
    - ``phi``: per-token topic responsibilities.

    Parameters
    ----------
    K : int
        Number of topics.
    bow : list of dict
        One bag of words per document, mapping integer word index -> count.
    y : sequence of float or None
        Per-document response (length D). Only the supervised phi update
        reads it.
    alpha : array_like, shape (K,)
        Prior parameter added to the per-document topic counts in
        ``update_gamma``.
    eta : ndarray, shape (K,)
        Response coefficients used by the supervised phi update.
    delta : float
        Scalar dividing the response terms in the supervised update.
        Presumably the response variance; confirm against the model spec.
    Lambda : ndarray, shape (K, V)
        Topic variational parameters, indexed as ``Lambda[:, word]``.
    epsilon : float
        Convergence tolerance used by ``coordinate_ascent``.
    """

    def __init__(self, K, bow, y, alpha, eta, delta, Lambda, epsilon):
        self.K = K  # number of topics
        self.bow = bow  # list of dictionaries, with length D
        self.y = y  # D-dimensional vector (None in prediction mode)
        # Expand each bag into one entry per token, e.g. {0: 2, 5: 1} -> [0, 0, 5].
        # Explicit int dtypes keep empty documents valid as index arrays.
        self.word_indices = {
            d: np.repeat(np.array(list(bag.keys()), dtype=int),
                         np.array(list(bag.values()), dtype=int))
            for d, bag in enumerate(bow)
        }
        self.alpha = alpha  # K-dimensional vector
        self.eta = eta  # K-dimensional vector
        self.delta = delta  # scalar
        self.Lambda = Lambda  # size: K x V
        # Cached at construction; recompute it if Lambda is replaced later.
        self.Lambda_rowsum = np.sum(Lambda, axis=1)
        self.epsilon = epsilon  # convergence tolerance for coordinate_ascent
        self.D = len(self.bow)  # batch_size: number of documents in the minibatch
        self.doc_len = [sum(bag.values()) for bag in bow]  # tokens per document
        self.gamma = np.ones(shape=(self.D, K))  # local variational parameter gamma
        # Start phi uniform rather than uninitialised: the supervised update
        # reads the other tokens' phi through phi_minus_j.
        self.phi = {d: np.full((self.doc_len[d], K), 1.0 / K) for d in range(self.D)}

    def update_gamma(self):
        """Set ``gamma[d] = alpha + sum_n phi[d][n]`` for every document."""
        # reshape keeps the (D, K) shape even for an empty minibatch.
        phi_totals = np.array(
            [self.phi[d].sum(axis=0) for d in range(self.D)]
        ).reshape(self.D, self.K)
        self.gamma = self.alpha + phi_totals

    def _log_phi_base(self, d, digamma_rowsum):
        """Return the y-independent part of log phi for document ``d``, shape (N_d, K).

        This is digamma(Lambda[:, w]) + digamma(gamma[d]) - digamma(rowsum(Lambda)).
        """
        return (polygamma(0, self.Lambda[:, self.word_indices[d]]).T
                + polygamma(0, self.gamma[d, :])
                - digamma_rowsum)

    def update_phi_unsupervised(self):
        """Update phi without the response (y not observed; prediction mode)."""
        digamma_rowsum = polygamma(0, self.Lambda_rowsum)
        for d in range(self.D):
            if self.doc_len[d] == 0:
                continue  # phi[d] is already an empty (0, K) array
            log_phi = self._log_phi_base(d, digamma_rowsum)
            # Normalise each row in log space so that very negative digamma
            # values cannot underflow to 0/0.
            self.phi[d] = np.exp(log_phi - logsumexp(log_phi, axis=1, keepdims=True))

    def update_phi_supervised(self):
        """Update phi token by token using the observed response (training mode).

        Each token's update depends on the current phi of the document's other
        tokens, so tokens are swept sequentially and updated in place.
        """
        digamma_rowsum = polygamma(0, self.Lambda_rowsum)
        eta_sq = self.eta ** 2
        for d in range(self.D):
            N_d = self.doc_len[d]
            if N_d == 0:
                continue
            base = self._log_phi_base(d, digamma_rowsum)
            linear = (self.y[d] / N_d / self.delta) * self.eta
            quad_scale = 1.0 / (2 * N_d ** 2 * self.delta)
            phi_d = self.phi[d]
            # Running column sum avoids re-summing all N_d rows for each token.
            phi_total = phi_d.sum(axis=0)
            for j in range(N_d):
                phi_minus_j = phi_total - phi_d[j, :]
                log_phi_j = base[j] + linear - quad_scale * (
                    2 * np.dot(self.eta, phi_minus_j) * self.eta + eta_sq)
                new_row = np.exp(log_phi_j - logsumexp(log_phi_j))
                phi_d[j, :] = new_row
                phi_total = phi_minus_j + new_row

    def coordinate_ascent(self, max_iter=100, supervised=None):
        """Alternate phi and gamma updates until gamma stabilises.

        Parameters
        ----------
        max_iter : int
            Maximum number of (phi, gamma) sweeps.
        supervised : bool or None
            Whether to use the supervised phi update. ``None`` chooses it
            automatically whenever ``y`` is not None.

        Returns
        -------
        (gamma, phi)
            The updated local variational parameters.

        Stops when the largest absolute change in gamma falls below
        ``epsilon``. That criterion is a design choice here; adjust it if the
        model calls for an ELBO-based test.
        """
        if supervised is None:
            supervised = self.y is not None
        update_phi = self.update_phi_supervised if supervised else self.update_phi_unsupervised
        for _ in range(max_iter):
            old_gamma = self.gamma.copy()
            update_phi()
            self.update_gamma()
            if np.abs(self.gamma - old_gamma).max(initial=0.0) < self.epsilon:
                break
        return self.gamma, self.phi