# In [6]:
from variational_inference_utils import *
from scipy.special import polygamma
import math

# In [8]:
import multiprocessing

# In [9]:
# Number of CPUs visible to this process; in the notebook the value is shown as
# this cell's output. Presumably checked to size a parallel E-step — TODO confirm.
multiprocessing.cpu_count()

# Out[9]: 4

# In [7]:
class VI_sLDA_E_Step:
    """Local (E-step) coordinate-ascent variational inference for supervised LDA.

    For a minibatch of D documents, fits the per-document variational
    parameters ``gamma`` (Dirichlet parameters of the topic proportions) and
    ``phi`` (per-word topic responsibilities) while ``Lambda``, ``eta`` and
    ``delta`` are held fixed.

    Parameters
    ----------
    K : int
        Number of topics.
    bow : dict
        Maps document index ``d`` (0 .. D-1) to an integer array of word ids;
        its length is N_d.
    y : sequence of length D, or None
        Observed responses; only read in supervised (training) mode.
    alpha : array-like, shape (K,)
        Dirichlet prior on the topic proportions.
    eta : array-like, shape (K,)
        Response-model coefficients.
    delta : float
        Response variance (enters the updates as ``1/delta``).
    Lambda : array-like, shape (K, V)
        Variational Dirichlet parameters of the topics.
    epsilon : float, default 1e-5
        Convergence tolerance on the mean absolute change of ``gamma``.
    """

    def __init__(self, K, bow, y, alpha, eta, delta, Lambda, epsilon=1e-5):
        self.K = K  # number of topics
        self.bow = bow  # d -> array of word ids of document d (length N_d)
        self.doc_len = {d: len(v) for d, v in bow.items()}  # N_d for every document
        self.D = len(self.bow)  # minibatch size
        self.y = y  # length-D responses (not used in prediction mode)
        # Coerce to arrays: with plain lists, ``scalar * eta`` would raise.
        self.alpha = np.asarray(alpha, dtype=float)  # shape (K,)
        self.eta = np.asarray(eta, dtype=float)  # shape (K,)
        self.delta = delta  # scalar
        self.Lambda = np.asarray(Lambda)  # shape (K, V); no copy if already an ndarray
        self.Lambda_rowsum = np.sum(self.Lambda, axis=1)  # shape (K,)
        self.epsilon = epsilon  # stopping tolerance read by coordinate_ascent_training
        self.gamma = np.ones(shape=(self.D, K))  # local parameter gamma, shape (D, K)
        # Uniform 1/K start: the supervised update reads the other words' phi
        # before they are ever written, so phi must already hold valid
        # distributions (np.empty would hand it uninitialised memory).
        self.phi = {d: np.full((self.doc_len[d], K), 1.0 / K) for d in range(self.D)}

    @staticmethod
    def _normalized_exp(log_p):
        """Return ``exp(log_p)`` normalised to sum to 1 along the last axis.

        The maximum is subtracted before exponentiating (log-sum-exp trick):
        the supervised term ``y/(N_d*delta)*eta`` can be large enough for a
        plain ``exp`` to overflow to inf, which would turn phi into NaN.
        """
        p = np.exp(log_p - np.max(log_p, axis=-1, keepdims=True))
        return p / np.sum(p, axis=-1, keepdims=True)

    def update_gamma(self):
        """Set ``gamma_d = alpha + sum_n phi_{d,n}`` for every document d."""
        sum_phi = np.vstack([self.phi[d].sum(axis=0) for d in range(self.D)])  # (D, K)
        self.gamma = self.alpha + sum_phi  # alpha broadcast over the D rows

    def update_phi_unsupervised(self):
        """Update phi when y is unobserved (prediction mode): same as plain LDA.

        ``log phi_{d,n,k} = psi(Lambda_{k,w_dn}) + psi(gamma_dk) - psi(sum_v Lambda_kv)``
        up to a constant; the ``-psi(sum_k gamma_dk)`` part of E[log theta] does
        not depend on k and cancels in the normalisation.
        """
        digamma_rowsum = polygamma(0, self.Lambda_rowsum)  # same for every document
        for d in range(self.D):
            log_phi = (polygamma(0, self.Lambda[:, self.bow[d]]).T  # (N_d, K)
                       + polygamma(0, self.gamma[d, :])              # (K,), broadcast
                       - digamma_rowsum)                             # (K,), broadcast
            self.phi[d] = self._normalized_exp(log_phi)

    def update_phi_supervised(self):
        """Update phi when y is observed (training mode): Eq (31) of the SVI paper.

        Words of a document are updated one at a time (Eq (33) of the sLDA
        paper); each word's update uses the current phi of all *other* words
        in the same document through ``phi_minus_j``.
        """
        digamma_rowsum = polygamma(0, self.Lambda_rowsum)
        eta_sq = self.eta ** 2
        for d in range(self.D):
            N_d = self.doc_len[d]
            if N_d == 0:
                continue  # empty document: no words to update, and y[d]/N_d is undefined
            phi_d = self.phi[d]  # updated in place
            # LDA part for all words at once; gamma is fixed during this sweep.
            log_phi_lda = (polygamma(0, self.Lambda[:, self.bow[d]]).T
                           + polygamma(0, self.gamma[d, :])
                           - digamma_rowsum)  # (N_d, K)
            linear_term = (self.y[d] / N_d / self.delta) * self.eta
            quad_scale = 1 / (2 * N_d ** 2 * self.delta)
            # Running column sum of phi_d, so phi_minus_j costs O(K) per word
            # instead of re-summing all N_d rows each time.
            phi_total = phi_d.sum(axis=0)
            for j in range(N_d):
                phi_minus_j = phi_total - phi_d[j, :]
                log_phi_j = (log_phi_lda[j]
                             + linear_term
                             - quad_scale * (2 * np.dot(self.eta, phi_minus_j) * self.eta + eta_sq))
                phi_d[j, :] = self._normalized_exp(log_phi_j)
                phi_total = phi_minus_j + phi_d[j, :]

    def coordinate_ascent_training(self, prediction=False, max_iter=None):
        """Alternate phi and gamma updates until gamma stops changing.

        Parameters
        ----------
        prediction : bool
            If true, y is treated as unobserved and the plain-LDA phi update is
            used; otherwise the supervised update is used.
        max_iter : int or None
            Optional cap on the number of phi/gamma sweeps. ``None`` (default)
            iterates until the mean absolute change of gamma is below
            ``self.epsilon``.

        Returns
        -------
        tuple
            ``(gamma, phi)``: the (D, K) array and the dict ``d -> (N_d, K)`` array.
        """
        update_phi = self.update_phi_unsupervised if prediction else self.update_phi_supervised
        change_in_gamma = math.inf
        n_iter = 0
        while change_in_gamma >= self.epsilon:  # stopping criterion
            if max_iter is not None and n_iter >= max_iter:
                break
            update_phi()
            # Copy so the comparison stays correct even if update_gamma becomes in-place.
            previous_gamma = self.gamma.copy()
            self.update_gamma()
            change_in_gamma = np.mean(np.abs(self.gamma - previous_gamma))
            n_iter += 1
        return self.gamma, self.phi