In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules
# before each cell is executed, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.datasets import fetch_20newsgroups
from time import time

In [3]:
# Problem sizes shared by the whole notebook:
#   n_docs   - how many documents are sliced off the front of the dataset
#   n_words  - upper bound on the vocabulary size handed to the vectorizer
#   n_topics - number of topics, i.e. rows of the topic-word parameter matrix
n_docs, n_words, n_topics = 2000, 1000, 20

In [4]:
# Fetch the 20 Newsgroups corpus with headers, footers and quoted replies
# stripped, shuffled with a fixed random_state, and keep only the first
# `n_docs` documents. The wall-clock time of the whole step is reported.
print("Loading dataset...")
t0 = time()
dataset = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=('headers', 'footers', 'quotes'),
)
data_samples = dataset.data[:n_docs]
print("done in %0.3fs." % (time() - t0,))

Loading dataset...
done in 1.149s.


Vectorize documents and get Count Matrix
----------------------------------------------

In [5]:
from sklearn.feature_extraction.text import CountVectorizer

In [6]:
# Use tf (raw term count) features for LDA.
# The vectorizer removes English stop words, applies the max_df / min_df
# document-frequency cut-offs and caps the vocabulary at `n_words` entries.
# Only the fit_transform call itself is timed.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(
    max_df=0.95,
    min_df=2,
    max_features=n_words,
    stop_words='english',
)
t0 = time()
tf = tf_vectorizer.fit_transform(data_samples)
print("done in %0.3fs." % (time() - t0,))

Extracting tf features for LDA...
done in 0.330s.


In [7]:
import numpy as np
from scipy.sparse import find

In [8]:
# Break the sparse count matrix into three parallel arrays describing its
# non-zero entries: row (document) index, column (word) index and count.
I, J, K = find(tf)

# docs_idxs[d] gathers the word indices present in document d and
# docs_cnts[d] the corresponding counts, stored in the same order.
docs_idxs = [[] for _ in range(n_docs)]
docs_cnts = [[] for _ in range(n_docs)]

for pos, r in enumerate(I) :
    c, n = J[pos], K[pos]
    docs_idxs[r].append(c)
    docs_cnts[r].append(n)

Initialize Parameters
------------------------

In [9]:
# Smoothing constants of the model:
#   alpha - added to every document's topic weights (gamma) in update_doc
#   beta  - added to the rescaled batch statistics in update_lambda_zt
alpha, beta = 0.1, 0.01

In [10]:
from tqdm import tqdm_notebook
from scipy.special import psi

In [11]:
def update_doc(docs, lambda_zt, max_iters=100, tol=0.001) :
    """Run the per-document variational updates for a batch of documents.

    For every document index in ``docs`` the word indices and counts are
    read from the module-level ``docs_idxs`` / ``docs_cnts`` lists. The
    document's topic weights ``gamma`` start from a random Gamma(100, 1/100)
    draw and are refined by alternating two steps:

    1. per-word topic responsibilities, proportional to
       ``exp(psi(gamma) - psi(sum(gamma)))`` times the matching columns of
       ``exp(psi(lambda_zt) - psi(row sums of lambda_zt))``, normalised so
       that each word's responsibilities sum to one;
    2. ``gamma = alpha + responsibilities @ counts`` (``alpha`` is a
       module-level constant).

    The loop stops once the mean absolute change of ``gamma`` falls below
    ``tol`` or after ``max_iters`` rounds. The count-weighted
    responsibilities of all documents are summed into the returned matrix.

    Parameters
    ----------
    docs : iterable of int
        Document indices into ``docs_idxs`` / ``docs_cnts``.
    lambda_zt : numpy.ndarray, 2-D
        Current topic-word parameters; one row per topic, one column per
        vocabulary entry. The output shape is taken from this array rather
        than from the ``n_topics`` / ``n_words`` globals.
    max_iters : int, default 100
        Maximum number of refinement rounds per document (must be >= 1).
    tol : float, default 0.001
        Convergence threshold on the mean absolute change of ``gamma``.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``lambda_zt`` holding the accumulated
        expected topic-word counts of the batch (not rescaled, no prior).

    Raises
    ------
    ValueError
        If ``max_iters`` is smaller than 1.
    """
    if max_iters < 1 :
        raise ValueError("max_iters must be at least 1, got %r" % (max_iters,))

    n_rows, n_cols = lambda_zt.shape
    lambda_psi = np.exp(psi(lambda_zt) - psi(lambda_zt.sum(1))[:, None])

    lambda_hat = np.zeros((n_rows, n_cols))
    for d in docs :
        words = docs_idxs[d]
        if len(words) == 0 :
            # Documents without any in-vocabulary word contribute nothing.
            continue

        # Convert the count list once instead of on every inner iteration.
        cnts = np.asarray(docs_cnts[d])
        gamma = np.random.gamma(100., 1/100., (n_rows))
        e_beta = lambda_psi[:, words]  # (n_topics, words in this document)

        for _ in range(max_iters) :
            prev_gamma = gamma
            e_theta = np.exp(psi(gamma) - psi(gamma.sum()))

            topic_assign = e_theta[:, None] * e_beta
            topic_assign = topic_assign / topic_assign.sum(0)

            gamma = alpha + np.dot(topic_assign, cnts)
            if np.mean(np.abs(prev_gamma - gamma)) < tol :
                break

        # NOTE: fancy-index "+=" applies one update per distinct column, so
        # this assumes a document lists each word index at most once.
        lambda_hat[:, words] += topic_assign * cnts

    return lambda_hat

In [12]:
# Step-size schedule of the stochastic updates: the learning rate used at
# step t is (t + offset) ** (-kappa).
offset, kappa = 2, 0.75

def update_lambda_zt(docs, t, lambda_zt) :
    """Blend the current topic-word parameters with a mini-batch estimate.

    The batch statistics returned by ``update_doc`` are scaled by
    ``n_docs / len(docs)`` and shifted by ``beta``; the result is mixed with
    ``lambda_zt`` using the step size ``(t + offset) ** (-kappa)``.

    Parameters
    ----------
    docs : sized iterable of int
        Document indices forming the mini-batch.
    t : int
        Index of the current update step.
    lambda_zt : numpy.ndarray
        Current topic-word parameters.

    Returns
    -------
    numpy.ndarray
        The updated parameters; the input array is not modified here.
    """
    step = (t + offset) ** (-kappa)
    batch_stats = update_doc(docs, lambda_zt)
    batch_estimate = batch_stats * n_docs/len(docs) + beta
    return (1 - step) * lambda_zt + step * batch_estimate

def run_SVI(n_iters, bsize, seed=None) :
    """Fit the topic-word parameters by sweeping over the corpus in mini-batches.

    The parameters start from a random Gamma(100, 1/100) matrix of shape
    ``(n_topics, n_words)``. Each of the ``n_iters`` passes walks over the
    document indices ``0 .. n_docs - 1`` in consecutive chunks of ``bsize``
    (the last chunk may be shorter) and calls ``update_lambda_zt`` once per
    chunk. The step counter handed to ``update_lambda_zt`` keeps increasing
    across passes.

    Parameters
    ----------
    n_iters : int
        Number of full passes over the corpus.
    bsize : int
        Mini-batch size; must be positive.
    seed : int or None, default None
        When given, NumPy's global random generator is seeded with it before
        the initial draw, which makes the run repeatable. ``None`` leaves the
        generator state untouched.

    Returns
    -------
    numpy.ndarray
        The final ``(n_topics, n_words)`` parameter matrix.

    Raises
    ------
    ValueError
        If ``bsize`` is smaller than 1 (a negative value would otherwise
        produce no batches and silently return the random initialisation).
    """
    if bsize < 1 :
        raise ValueError("bsize must be a positive integer, got %r" % (bsize,))
    if seed is not None :
        # update_doc also draws from the global generator, so seed that one.
        np.random.seed(seed)

    t = 0
    lambda_zt = np.random.gamma(100., 1/100., (n_topics, n_words))

    for _ in tqdm_notebook(range(n_iters)) :
        for start in range(0, n_docs, bsize) :
            batch = range(start, min(start + bsize, n_docs))
            lambda_zt = update_lambda_zt(batch, t, lambda_zt)
            t += 1

    return lambda_zt

In [13]:
def get_top_words(word_list, phi, n_top=10) :
    """Print the highest-weighted words of every row of ``phi``.

    One Python list is printed per row. Within a list the words are ordered
    by increasing weight, so the strongest word comes last.

    Parameters
    ----------
    word_list : sequence of str
        Vocabulary; ``word_list[j]`` names column ``j`` of ``phi``.
    phi : array-like, 2-D
        Weight matrix with one row per topic. The number of rows is taken
        from ``phi`` itself rather than from the ``n_topics`` global.
    n_top : int, default 10
        How many words to show per row; values larger than the number of
        columns show every column.

    Returns
    -------
    None
    """
    phi = np.asarray(phi)
    # Compute an explicit start column: a plain "-n_top:" slice would select
    # everything when n_top is 0.
    first = max(phi.shape[1] - n_top, 0)
    max_args = np.argsort(phi, axis=1)[:, first:]
    for row in max_args :
        print([word_list[i] for i in row])

In [14]:
# run_SVI and update_doc draw their random starting values from NumPy's
# global generator; seed it so that Restart & Run All reproduces the topics.
np.random.seed(0)
# 100 passes over the corpus with mini-batches of 32 documents.
lambda_zt = run_SVI(100, 32)




In [15]:
# CountVectorizer.get_feature_names() was deprecated in scikit-learn 1.0 and
# removed in 1.2. Prefer get_feature_names_out() and fall back to the old
# method on installations that predate it. tolist() turns the returned
# array into a plain list of Python strings, matching the old return type.
if hasattr(tf_vectorizer, "get_feature_names_out") :
    word_list = tf_vectorizer.get_feature_names_out().tolist()
else :
    word_list = tf_vectorizer.get_feature_names()
get_top_words(word_list, lambda_zt)

['level', 'feature', 'read', 'guess', 'board', 'memory', 'try', 'post', 'speed', 'time']
['images', 'interface', 'data', 'bit', 'video', 'pc', 'color', 'image', 'graphics', 'software']
['win', 'scsi', 'came', 'drives', 'went', 'true', 'disk', 'hard', 'didn', 'drive']
['sale', 'questions', 'size', 'sort', 'jews', 'points', 'ask', 'israel', 'card', '55']
['built', 'term', 'cost', 'launch', 'orbit', 'soon', 'nasa', 'data', 'bios', 'space']
['exactly', 'division', 'course', 'military', 'hi', 'feel', 'matter', 'reason', 'world', 'left']
['insurance', 'guy', 'rest', 'need', 'family', 'bike', 'health', 'states', 'group', 'told']
['phone', 'cs', 'send', 'received', 'machines', 'type', 'contact', 'mail', 'com', 'edu']
['things', 've', 'going', 'want', 'way', 'know', 'think', 'just', 'don', 'like']
['program', 'hiv', 'number', 'section', 'general', 'new', 'research', 'information', 'government', 'law']
['00', '25', '14', '93', '20', '12', '15', '16', '11', '10']
['help', 'program', 'dos', 'serve