In [78]:
def compute_dirichlet_expectation(dirichlet_parameter):
    """Calculate the expectation of dirichlet parameter. When input is null, output should be a warning that the input is null."""
    if not np.array(dirichlet_parameter).size:
        return ("The dirichlet_parameter is null.")
    if (len(dirichlet_parameter.shape) == 1):
        return (sp.special.psi(dirichlet_parameter)-sp.special.psi(np.sum(dirichlet_parameter)))
    return (sp.special.psi(dirichlet_parameter) - sp.special.psi(np.sum(dirichlet_parameter, 1))[:, np.newaxis])


def e_step(corpus=None,local_parameter_iteration=50,
local_parameter_converge_threshold=1e-6):
    """E step. When input is None, output should be document_log_likelihood, phi_sufficient_statistics, and gamma. Otherwise, it should be words_log_likelihood, gamma_values."""
    if corpus == None:
        word_ids = parsed_corpus[0]
        word_cts = parsed_corpus[1]
    else:
        word_ids = corpus[0]
        word_cts = corpus[1]
    # Initialization 
    number_of_documents = len(word_ids)
    document_log_likelihood = 0
    words_log_likelihood = 0
    phi_sufficient_statistics = np.zeros((number_of_topics, number_of_types))
    gamma_values = np.zeros((number_of_documents, number_of_topics)) + alpha_alpha[np.newaxis, :] + 1.0 * number_of_types / number_of_topics
    E_log_eta = compute_dirichlet_expectation(eta)
    if parsed_corpus != None:
        E_log_prob_eta = E_log_eta - sp.misc.logsumexp(E_log_eta, axis=1)[:, np.newaxis]

    # iterate over all documents
    for doc_id in np.random.permutation(number_of_documents):
        # compute the total number of words
        total_word_count = np.sum(word_cts[doc_id])
        # initialize gamma for this document
        gamma_values[doc_id, :] = alpha_alpha + 1.0 * total_word_count / number_of_topics

        term_ids = word_ids[doc_id]
        term_counts = word_cts[doc_id]

        # update phi and gamma until gamma converges
        for gamma_iteration in range(local_parameter_iteration):
            log_phi = E_log_eta[:, term_ids].T + np.tile(sp.special.psi(gamma_values[doc_id, :]), (word_ids[doc_id].shape[0], 1))
            log_phi -= sp.misc.logsumexp(log_phi, axis=1)[:, np.newaxis]
            gamma_update = alpha_alpha + np.array(np.sum(np.exp(log_phi + np.log(np.repeat(term_counts, number_of_topics, axis=0).T)), axis=0))
            mean_change = np.mean(abs(gamma_update - gamma_values[doc_id, :]))
            gamma_values[doc_id, :] = gamma_update
            if mean_change <= local_parameter_converge_threshold:
                break

        # compute the alpha, gamma, and phi terms
        document_log_likelihood += sp.special.gammaln(np.sum(alpha_alpha)) - np.sum(sp.special.gammaln(alpha_alpha))
        document_log_likelihood += np.sum(sp.special.gammaln(gamma_values[doc_id, :])) - sp.special.gammaln(np.sum(gamma_values[doc_id, :]))
        document_log_likelihood -= np.sum(np.dot(term_counts, np.exp(log_phi) * log_phi))

# compute the p(w_{dn} | z_{dn}, \eta) terms, which will be cancelled during M-step
        words_log_likelihood += np.sum(np.exp(log_phi.T + np.log(term_counts)) * E_log_prob_eta[:, term_ids])      
        phi_sufficient_statistics[:, term_ids] += np.exp(log_phi + np.log(term_counts.transpose())).T
        
    if corpus == None:
        gamma = gamma_values
        return (document_log_likelihood, phi_sufficient_statistics, gamma)
    else:
        return (words_log_likelihood, gamma_values)


In [79]:
%%time
document_log_likelihood, phi_sufficient_statistics, gamma = e_step()

CPU times: user 52.4 s, sys: 583 ms, total: 52.9 s
Wall time: 53.8 s


In [80]:
@jit
def compute_dirichlet_expectation(dirichlet_parameter):
    """Calculate the expectation of dirichlet parameter. When input is null, output should be a warning that the input is null."""
    if not np.array(dirichlet_parameter).size:
        return ("The dirichlet_parameter is null.")
    if (len(dirichlet_parameter.shape) == 1):
        return (sp.special.psi(dirichlet_parameter)-sp.special.psi(np.sum(dirichlet_parameter)))
    return (sp.special.psi(dirichlet_parameter) - sp.special.psi(np.sum(dirichlet_parameter, 1))[:, np.newaxis])


def e_step(corpus=None,local_parameter_iteration=50,
local_parameter_converge_threshold=1e-6):
    """E step. When input is None, output should be document_log_likelihood, phi_sufficient_statistics, and gamma. Otherwise, it should be words_log_likelihood, gamma_values."""
    if corpus == None:
        word_ids = parsed_corpus[0]
        word_cts = parsed_corpus[1]
    else:
        word_ids = corpus[0]
        word_cts = corpus[1]
    # Initialization 
    number_of_documents = len(word_ids)
    document_log_likelihood = 0
    words_log_likelihood = 0
    phi_sufficient_statistics = np.zeros((number_of_topics, number_of_types))
    gamma_values = np.zeros((number_of_documents, number_of_topics)) + alpha_alpha[np.newaxis, :] + 1.0 * number_of_types / number_of_topics
    E_log_eta = compute_dirichlet_expectation(eta)
    if parsed_corpus != None:
        E_log_prob_eta = E_log_eta - sp.misc.logsumexp(E_log_eta, axis=1)[:, np.newaxis]

    # iterate over all documents
    for doc_id in np.random.permutation(number_of_documents):
        # compute the total number of words
        total_word_count = np.sum(word_cts[doc_id])
        # initialize gamma for this document
        gamma_values[doc_id, :] = alpha_alpha + 1.0 * total_word_count / number_of_topics

        term_ids = word_ids[doc_id]
        term_counts = word_cts[doc_id]

        # update phi and gamma until gamma converges
        for gamma_iteration in range(local_parameter_iteration):
            log_phi = E_log_eta[:, term_ids].T + np.tile(sp.special.psi(gamma_values[doc_id, :]), (word_ids[doc_id].shape[0], 1))
            log_phi -= sp.misc.logsumexp(log_phi, axis=1)[:, np.newaxis]
            gamma_update = alpha_alpha + np.array(np.sum(np.exp(log_phi + np.log(np.repeat(term_counts, number_of_topics, axis=0).T)), axis=0))
            mean_change = np.mean(abs(gamma_update - gamma_values[doc_id, :]))
            gamma_values[doc_id, :] = gamma_update
            if mean_change <= local_parameter_converge_threshold:
                break

        # compute the alpha, gamma, and phi terms
        document_log_likelihood += sp.special.gammaln(np.sum(alpha_alpha)) - np.sum(sp.special.gammaln(alpha_alpha))
        document_log_likelihood += np.sum(sp.special.gammaln(gamma_values[doc_id, :])) - sp.special.gammaln(np.sum(gamma_values[doc_id, :]))
        document_log_likelihood -= np.sum(np.dot(term_counts, np.exp(log_phi) * log_phi))

# compute the p(w_{dn} | z_{dn}, \eta) terms, which will be cancelled during M-step
        words_log_likelihood += np.sum(np.exp(log_phi.T + np.log(term_counts)) * E_log_prob_eta[:, term_ids])      
        phi_sufficient_statistics[:, term_ids] += np.exp(log_phi + np.log(term_counts.transpose())).T
        
    if corpus == None:
        gamma = gamma_values
        return (document_log_likelihood, phi_sufficient_statistics, gamma)
    else:
        return (words_log_likelihood, gamma_values)

In [81]:
%%time
document_log_likelihood, phi_sufficient_statistics, gamma = e_step()

CPU times: user 47.7 s, sys: 310 ms, total: 48 s
Wall time: 48.2 s
