# LDA with Variational Bayes



## train & test データロード



In [1]:
def load_train_test():
    """
    @return train list 学習用の文書集合
    @return test list テスト用の文書集合
    """
    read_dir = './data/ldcourpas/'
    train_doc_name = 'train_doclist.list'
    test_doc_name = 'test_doclist.list'
    
    with open(read_dir + train_doc_name, mode='rb') as f:
        train = pickle.load(f)
    with open(read_dir + test_doc_name, mode='rb') as f:
        test = pickle.load(f)
    
    return train, test

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
import pickle

train, test = load_train_test()
analyzer = lambda words: words
vect = CountVectorizer(max_features=10000, min_df=.01, max_df=.40, analyzer=analyzer)
X = vect.fit_transform(train)
id2word = {v: k for k, v in vect.vocabulary_.items()}

## LDA


In [2]:
import scipy.special as special


def _dirichlet_expectation_1d(arr):
    """
    calcureta E[log(theta)]. theta ~ Dir(theta|arr)
    """
    sum_arr = arr.sum()
    return special.psi(arr) - special.psi(sum_arr)

def _dirichlet_expectation_2d(arr):
    """
    calcurate E[log(theta)]. theta ~ Dir(theta|arr)
    """
    sum_arr_ax1 = arr.sum(axis=1).reshape(-1, 1)
    return special.psi(arr) - special.psi(sum_arr_ax1)
    
def mean_change(arr1, arr2):
    size = arr1.shape[0]
    return np.abs(arr1 - arr2).sum()/size


In [3]:
def _get_n_jobs(n_jobs):
    if n_jobs < 0:
        return max(cpu_count() - 1, 1)
    elif n_jobs == 0:
        ValueError('n_jobs == 0 doesn\'t meaning')
    else:
        return n_jobs
    

def gen_slices(length, n):
    """
    divide idx[0:length] into n
    """
    idx = np.arange(length)
    nums = [(length + i) // n for i in range(n)]
    
    start = 0
    for num in nums:
        end = start + num
        yield idx[start: end]
        start = end


def _update_doc_topic_distrb(X, max_iter, nd,
                             exp_topic_word,
                             alpha, beta,
                             mean_change_tol):
    
    n_docs, n_features = X.shape
    n_topics = exp_topic_word.shape[0]
    
    doc_topic_distrb = np.zeros((n_docs, n_topics))\
                        + nd.reshape(-1, 1)\
                        / n_topics
    exp_doc_topic = np.exp(_dirichlet_expectation_2d(doc_topic_distrb))
    sstats = np.zeros(exp_topic_word.shape)
    q = np.zeros((n_docs, n_features, n_topics))
    
    indices = X.indices
    indptr = X.indptr
    data = X.data
    
    for d in range(n_docs):
        ids = indices[indptr[d]:indptr[d + 1]]
        nds = data[indptr[d]:indptr[d + 1]]
        
        doc_topic_d = doc_topic_distrb[d]
        for _ in range(max_iter):
            last_d = doc_topic_d
            for idx_i in ids:
                qd = exp_topic_word[:, idx_i] * exp_doc_topic[d]
                norm_qd = qd.sum()
                q[d, idx_i] = qd / norm_qd
            doc_topic_d = q[d].sum(axis=0) + alpha
            exp_doc_topic[d] = np.exp(_dirichlet_expectation_1d(doc_topic_d))
            if mean_change(last_d, doc_topic_d) < mean_change_tol:
                break
        doc_topic_distrb[d] = doc_topic_d
        sstats += q[d].T
    
    return doc_topic_distrb, sstats

In [6]:
import logging
import time
import scipy.special as special
import scipy.stats as stats
from joblib import Parallel, delayed, cpu_count


class LDA(object):
    def __init__(self, max_iter=10, max_update_iter=100,
                 n_topics=10, print_every=20, n_jobs=-1,
                 verbose=1, mean_change_tol=1e-4, logger=None):
        self.max_iter = max_iter
        self.max_update_iter = max_update_iter
        self.n_topics = n_topics
        self.print_every = print_every
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.mean_change_tol = mean_change_tol
        self.logger = logger
        self.random_state = np.random.mtrand._rand
    
    def _initialize(self, X):
        n_docs, n_features = X.shape
        
        self.alpha = np.ones(self.n_topics) * 0.1
        self.beta = 100 / n_docs
        
        init_gamma = 100.
        init_var = 1. / init_gamma
        self.topic_word_ = self.random_state.gamma(init_gamma,
                                                        init_var,
                                                        (self.n_topics,
                                                         n_features))
        self.exp_topic_word = np.exp(
            _dirichlet_expectation_2d(self.topic_word_))
        
        self.nd = np.zeros(n_docs)
        indices = X.indices
        indptr = X.indptr
        data = X.data
        for d in range(n_docs):
            nds = data[indptr[d]:indptr[d + 1]]
            self.nd[d] = nds.sum()
    
    def _update_dirichlet_param(self, doc_topic):
        """
        update alpha

        @param ave_ndz ndarray ndzのサンプル平均
        """
        e_ndk = doc_topic - self.alpha
        n_docs = e_ndk.shape[0]
        sum_alpha = self.alpha.sum()
        
        numes = (special.psi(e_ndk + self.alpha).sum(axis=0)\
                    - n_docs*special.psi(self.alpha))*self.alpha
        
        denom = special.psi(self.nd + sum_alpha).sum()\
                        - n_docs*special.psi(sum_alpha)
        
        self.alpha = numes / denom
    
    def _e_step(self, X, parallel=None):
        
        n_jobs = _get_n_jobs(self.n_jobs)
        if parallel is None:
            parallel = Parallel(n_jobs, verbose=max(0, self.verbose - 1))
        
        results = parallel(
            delayed(_update_doc_topic_distrb)(X[idx_slice, :],
                                              self.max_update_iter,
                                              self.nd[idx_slice],
                                              self.exp_topic_word,
                                              self.alpha,
                                              self.beta,
                                              self.mean_change_tol)
            for idx_slice in gen_slices(X.shape[0], n_jobs))
                                             
        doc_topics, sstats_list = zip(*results)
        doc_topic = np.vstack(doc_topics)
        
        sstats = np.zeros(self.exp_topic_word.shape)
        for suff_stats in sstats_list:
            sstats += suff_stats
        
        return doc_topic, sstats
    
    def _em_step(self, X):
        
        doc_topic, sstats = self._e_step(X)
        self.topic_word_ = self.beta + sstats
        self.exp_topic_word = np.exp(
            _dirichlet_expectation_2d(self.topic_word_))
        self._update_dirichlet_param(doc_topic)
    
    def fit(self, X):
        n_docs, n_features = X.shape
        self._initialize(X)
        
        if self.logger:
            self.logger.info('train start!')
        for s in range(self.max_iter):
            start = time.time()
            self._em_step(X)
            elapsed_time = time.time() - start
            
            if self.logger:
                self.logger.info('elapsed: {:.3f} [sec]'.format(elapsed_time))
            if (s+1) % self.print_every == 0:
                self.logger.info('{} iterations finished.'.format(s+1))
    
    def _sampling_phi(self):
        self.phi = np.zeros(self.exp_topic_word.shape)
        for k in range(self.n_topics):
            self.phi[k, :] = stats.dirichlet.rvs(self.topic_word_[k], size=1)[0]

    def print_topn_words(self, n, id2word):
        index = np.arange(n) + 1
        df = pd.DataFrame(data=[], index=index)
        for k in range(self.n_topics):
            idx_descend = self.topic_word_[k].argsort()[::-1]
            top_n = [id2word[idx] for idx in idx_descend[:n]]
            df['topic{}'.format(k)] = top_n
        display(df)
    
    def print_topn_pertopic(self, n=5, vocab=None):
        index_to_words = {v: k for k, v in vocab.items()}
        for k in range(self.n_topics):
            print('-----topic {}-----'.format(k))
            index_phi_k = self.topic_word_[k].argsort()[::-1]
            for print_num, v in enumerate(index_phi_k):
                if print_num >= n:
                    break
                
                print('{}, pdf:{}'.format(index_to_words[v],
                                          self.topic_word_[k, v]))


def main():
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger('LDA')
    lda = LDA(max_iter=5, max_update_iter=100, logger=logger)
    lda.fit(X)
    lda.print_topn_pertopic(n=10, vocab=vect.vocabulary_)

main()

INFO:LDA:train start!
INFO:LDA:elapsed: 96.909 [sec]
INFO:LDA:elapsed: 63.840 [sec]
INFO:LDA:elapsed: 48.654 [sec]
INFO:LDA:elapsed: 41.167 [sec]
INFO:LDA:elapsed: 36.075 [sec]


-----topic 0-----
自分, pdf:485.2424224462455
それ, pdf:428.71597595165167
さん, pdf:424.46462906114886
いい, pdf:407.4955474022321
今, pdf:394.6907466214214
もの, pdf:391.51075399945404
時, pdf:373.542639552101
何, pdf:370.8836047398745
私, pdf:357.575289847065
年, pdf:336.76383644733943
-----topic 1-----
MAX, pdf:141.36132271548607
関連リンク, pdf:140.00544861634972
MAXsmaxjponTwitter, pdf:138.3446791453985
エスマックス, pdf:138.13716596652216
S, pdf:138.05553606637218
執筆, pdf:128.98313968333116
発表, pdf:126.40365710133787
発売, pdf:108.65174656511633
搭載, pdf:101.75930245073198
Android, pdf:101.48749777367156
-----topic 2-----
対応, pdf:210.5648022613041
機能, pdf:201.62908220198878
ため, pdf:196.8616532282202
発売, pdf:194.12609095777674
登場, pdf:191.9183135844302
関連, pdf:189.48883857137642
搭載, pdf:177.81930655055547
これ, pdf:175.3500503928074
iPhone, pdf:167.28298986415294
製品, pdf:156.17040239028807
-----topic 3-----
多い, pdf:194.68104343491618
女性, pdf:192.85535757936117
もの, pdf:185.13542503086106
ため, pdf:180.89633030734