In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from sklearn.datasets import fetch_20newsgroups
from time import time

In [12]:
# Model size: cap on the vocabulary handed to CountVectorizer (max_features)
# and the number of LDA topics to fit.
n_words, n_topics = 2000, 20

In [13]:
print("Loading dataset...")
t0 = time()
# Strip headers/footers/quoted replies so topics are learned from message bodies.
dataset = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=('headers', 'footers', 'quotes'),
)
data_samples = dataset.data
n_docs = len(data_samples)
print("done in %0.3fs." % (time() - t0))

Loading dataset...
done in 1.189s.


Vectorize the documents into a document-term count matrix
-----------------------------------------------------------

In [14]:
from sklearn.feature_extraction.text import CountVectorizer

In [15]:
# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
                                max_features=n_words,
                                stop_words='english')
t0 = time()
tf = tf_vectorizer.fit_transform(data_samples)
print("done in %0.3fs." % (time() - t0))

# max_features is only an upper bound: if fewer terms survive the
# max_df / min_df / stop-word filters, the learned vocabulary is smaller than
# the requested n_words. The model code sizes its topic-word arrays with
# n_words and maps column indices back to words, so keep n_words equal to the
# number of columns actually produced. (Re-running this cell is stable: the
# cap then equals the surviving vocabulary size.)
n_words = tf.shape[1]

Extracting tf features for LDA...
done in 1.530s.


In [16]:
import numpy as np
from scipy.sparse import find

In [17]:
# Nonzero entries of the sparse doc-term matrix as parallel arrays:
# I = document (row) index, J = word (column) index, K = count.
(I, J, K) = find(tf)

# Per-document lists of word indices and matching counts, filled in the order
# find() yields the entries; a document with no retained words keeps [] / [].
docs_idxs = [[] for _ in range(n_docs)]
docs_cnts = [[] for _ in range(n_docs)]
for pos in range(len(I)) :
    doc_id = I[pos]
    docs_idxs[doc_id].append(J[pos])
    docs_cnts[doc_id].append(K[pos])

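Set the LDA hyperparameters (alpha: document-topic smoothing, beta: topic-word smoothing)
------------------------------------------------------------------------------------------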

In [18]:
# Scalar smoothing constants: alpha is added to every entry of a document's
# topic parameters (gamma, in update_doc), beta to every topic-word entry of
# the minibatch estimate (in update_lambda_zt).
alpha, beta = 0.1, 0.01

In [19]:
from tqdm import tqdm_notebook
from scipy.special import psi

In [20]:
def update_doc(docs, lambda_zt, max_iter=100, tol=0.001) :
    """Per-document variational step of stochastic variational LDA.

    For every document index in ``docs``, alternately updates the document's
    topic parameters ``gamma`` and the per-word topic responsibilities until
    the mean absolute change in ``gamma`` drops below ``tol`` (or ``max_iter``
    iterations have run). Then it accumulates count-weighted responsibilities
    into a topic-word statistics array.

    Reads the module-level ``docs_idxs``, ``docs_cnts`` and ``alpha``, and draws
    the initial ``gamma`` of each non-empty document from ``np.random.gamma``.

    Parameters
    ----------
    docs : iterable of int
        Row indices into ``docs_idxs`` / ``docs_cnts``.
    lambda_zt : ndarray, shape (n_topics, n_words)
        Current topic-word variational parameters; not modified.
    max_iter : int, default 100
        Maximum inner iterations per document; must be >= 1.
    tol : float, default 0.001
        Convergence threshold on ``mean(|gamma_prev - gamma|)``.

    Returns
    -------
    ndarray, same shape as ``lambda_zt``
        ``lambda_hat[k, w]`` = sum over the minibatch of
        count(w) * responsibility of topic k for w; zero for unseen words.

    Raises
    ------
    ValueError
        If ``max_iter`` is less than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    n_topics_local = lambda_zt.shape[0]
    # exp of the Dirichlet expectation E[log beta_kw] for each topic row.
    lambda_psi = np.exp(psi(lambda_zt) - psi(lambda_zt.sum(1))[:, None])

    lambda_hat = np.zeros(lambda_zt.shape)
    for d in docs :
        words = docs_idxs[d]
        if len(words) == 0:
            # Empty documents contribute nothing (and draw no random numbers).
            continue

        # Convert once per document rather than on every inner iteration.
        cnts = np.asarray(docs_cnts[d])
        e_beta = lambda_psi[:, words]  # (K, N_d)
        gamma = np.random.gamma(100., 1/100., n_topics_local)

        for _ in range(max_iter) :
            prev_gamma = gamma
            e_theta = np.exp(psi(gamma) - psi(gamma.sum()))

            # Topic responsibilities for each word, normalised over topics.
            topic_assign = e_theta[:, None] * e_beta
            topic_assign = topic_assign / topic_assign.sum(0)

            gamma = alpha + np.dot(topic_assign, cnts)
            if np.mean(np.abs(prev_gamma - gamma)) < tol :
                break

        # Assumes each word index appears once per document (find() on the
        # count matrix yields one entry per nonzero cell); otherwise
        # fancy-indexed += would drop repeated updates.
        lambda_hat[:, words] += topic_assign * cnts

    return lambda_hat

In [21]:
# Step-size schedule used by update_lambda_zt: rho_t = (t + offset) ** (-kappa).
offset, kappa = 2, 0.75

def update_lambda_zt(docs, t, lambda_zt) :
    """Blend the topic-word parameters with one minibatch's estimate.

    The statistics returned by ``update_doc`` for ``docs`` are scaled by
    ``n_docs / len(docs)``, smoothed by ``beta``, and mixed into ``lambda_zt``
    with step size ``(t + offset) ** (-kappa)``. A new array is returned;
    ``lambda_zt`` itself is not modified.
    """
    batch_stats = update_doc(docs, lambda_zt)
    step_size = (t + offset) ** (-kappa)
    # Keep the original evaluation order ((stats * n_docs) / len(docs)) so the
    # floating-point result is unchanged.
    noisy_estimate = batch_stats * n_docs/len(docs) + beta
    return (1 - step_size) * (lambda_zt) + step_size * (noisy_estimate)

def run_SVI(n_iters, bsize) :
    """Fit topic-word parameters by stochastic variational inference.

    Makes ``n_iters`` passes over the corpus. Each pass walks the documents in
    index order in contiguous minibatches of at most ``bsize`` documents and
    calls ``update_lambda_zt`` once per minibatch, with a global update counter
    that keeps increasing across passes. ``lambda_zt`` starts from a
    Gamma(100, 1/100) draw of shape (n_topics, n_words).

    Returns the final ``lambda_zt`` array.
    """
    lambda_zt = np.random.gamma(100., 1/100., (n_topics, n_words))
    step = 0
    batch_starts = range(0, n_docs, bsize)

    for _ in tqdm_notebook(range(n_iters)) :
        for start in batch_starts :
            stop = min(start + bsize, n_docs)
            lambda_zt = update_lambda_zt(range(start, stop), step, lambda_zt)
            step += 1

    return lambda_zt

In [22]:
def get_top_words(word_list, phi, n_top=10) :
    """Print the ``n_top`` highest-weighted words for every topic.

    Each row of ``phi`` holds one topic's weights over the vocabulary. Words are
    printed in ascending weight order, so the last word in each printed list
    is that topic's highest-weighted word.

    Parameters
    ----------
    word_list : sequence of str
        Vocabulary; ``word_list[j]`` labels column ``j`` of ``phi``.
    phi : ndarray, shape (n_topics, n_words)
        Topic-word weights (e.g. ``lambda_zt``).
    n_top : int, default 10
        Words per topic; values <= 0 print empty lists, and values larger
        than the vocabulary print every word.
    """
    n_cols = phi.shape[1]
    # argsort is ascending: keep the trailing n_top columns. Compute the start
    # explicitly because a slice of [-0:] would return every column.
    start = n_cols - min(max(n_top, 0), n_cols)
    max_args = np.argsort(phi, axis=1)[:, start:]
    # Iterate over phi's own rows instead of the global n_topics so any
    # number of topics is handled.
    for row in max_args :
        print([word_list[i] for i in row])

In [23]:
# Seed NumPy's global RNG so the Gamma initialisations drawn inside run_SVI /
# update_doc (np.random.gamma) are reproducible whenever this cell is re-run.
np.random.seed(1)
lambda_zt = run_SVI(100, 32)




In [24]:
# CountVectorizer.get_feature_names() was deprecated in scikit-learn 1.0 and
# removed in 1.2; use get_feature_names_out() and fall back only on older
# releases that do not provide it.
if hasattr(tf_vectorizer, "get_feature_names_out"):
    word_list = tf_vectorizer.get_feature_names_out()
else:
    word_list = tf_vectorizer.get_feature_names()
get_top_words(word_list, lambda_zt)

['true', 'home', 'life', 'say', 'does', 'people', 'believe', 'pl', 'jesus', 'god']
['bus', 'drives', 'hard', 'faith', 'disk', 'db', 'scsi', 'card', 'drive', 'max']
['policy', 'number', 'research', 'stop', 'hockey', 'points', 'center', 'sort', 'bad', 'israel']
['anti', 'israeli', 'killed', 'children', 'turkish', 'armenians', 'jews', 'armenian', 'said', 'people']
['effect', 'data', 'mode', 'use', 'clipper', 'keys', 'encryption', 'bit', 'chip', 'key']
['la', 'st', 'cx', 'players', 'vs', 'period', '28', '26', 'season', 'game']
['road', 'water', 'bike', 'april', 'health', 'university', 'value', 'american', 'states', 'new']
['right', 'didn', 've', 'going', 'like', 'just', 'know', 'people', 'think', 'don']
['data', 'use', 'program', 'version', 'graphics', 'software', 'available', 'code', 'window', 'image']
['heard', 'won', 'old', 'little', 'new', 'actually', 'time', 'probably', 'good', 'year']
['files', 'list', 'information', 'send', 'email', 'ftp', 'mail', 'com', 'file', 'edu']
['try', 'let'