In [17]:
# Load the autoreload extension and set it to mode 2 so imported modules
# are refreshed before each cell executes.
%load_ext autoreload
%autoreload 2
# Select the Qt matplotlib backend for the figures produced below.
%matplotlib qt

In [18]:
from sklearn.datasets import fetch_20newsgroups
from time import time
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

import numpy as np
from scipy.sparse import find

import pickle

In [3]:
# Vocabulary size: passed as `max_features` to the CountVectorizer and used
# as the column count of the topic-word arrays allocated further down.
n_words = 2000

In [4]:
print("Loading dataset...")
t0 = time()

# Both splits are fetched with the same options; only `subset` differs.
fetch_opts = dict(shuffle=False, random_state=1,
                  remove=('headers', 'footers', 'quotes'))

# Training split.
dataset = fetch_20newsgroups(subset='train', **fetch_opts)
data_samples = dataset.data
n_docs = len(data_samples)
print(n_docs)

# Held-out split.
dataset_test = fetch_20newsgroups(subset='test', **fetch_opts)
data_test = dataset_test.data
n_docs_test = len(data_test)
print(n_docs_test)

print("done in %0.3fs." % (time() - t0))

Loading dataset...
11314
7532
done in 1.976s.


Vectorize documents and get Count Matrix
----------------------------------------------

In [5]:
# LDA is fitted on raw term counts (tf), not tf-idf weights.
print("Extracting tf features for LDA...")
vectorizer_opts = {
    "max_df": 0.95,
    "min_df": 2,
    "max_features": n_words,
    "stop_words": 'english',
}
tf_vectorizer = CountVectorizer(**vectorizer_opts)

t0 = time()
tf = tf_vectorizer.fit_transform(data_samples)
elapsed = time() - t0
print("done in %0.3fs." % elapsed)

Extracting tf features for LDA...
done in 1.616s.


In [6]:
# Count matrix for the test documents, using the vectorizer already fitted
# on the training documents (transform only, no refit).
tf_test = tf_vectorizer.transform(data_test)

In [7]:
# Unpack the non-zero entries of the training count matrix into
# per-document lists: word indices in docs_idxs, matching counts in docs_cnts.
rows, cols, vals = find(tf)

docs_idxs = [[] for _ in range(n_docs)]
docs_cnts = [[] for _ in range(n_docs)]

for pos in range(len(rows)) :
    doc = rows[pos]
    docs_idxs[doc].append(cols[pos])
    docs_cnts[doc].append(vals[pos])

In [8]:
# Same per-document unpacking as for the training matrix, applied to the
# test count matrix.
rows_test, cols_test, vals_test = find(tf_test)

docs_idxs_test = [[] for _ in range(n_docs_test)]
docs_cnts_test = [[] for _ in range(n_docs_test)]

for pos in range(len(rows_test)) :
    doc = rows_test[pos]
    docs_idxs_test[doc].append(cols_test[pos])
    docs_cnts_test[doc].append(vals_test[pos])

Initialize Parameters
------------------------

In [9]:
# Seed NumPy's global RNG: the variational parameters below are initialised
# with np.random.gamma, so without a seed every run gives different results.
SEED = 0
np.random.seed(SEED)

alpha = 0.1    # added to the expected topic counts in the per-document gamma update
beta = 0.01    # added to the rescaled batch statistics in the lambda update
offset = 2     # learning-rate schedule: lr_t = (t + offset) ** (-kappa)
kappa = 0.75   # decay exponent of the learning-rate schedule

In [10]:
from tqdm import tqdm_notebook
from scipy.special import psi

In [11]:
def get_top_words(word_list, phi, nums=False, tnum=10) :
    """Return the `tnum` highest-weighted words of every topic.

    Parameters
    ----------
    word_list : sequence mapping a vocabulary index to its word.
    phi : 2-D array with one row of word weights per topic.
    nums : if True, return vocabulary indices instead of words.
    tnum : number of words to keep per topic.

    Returns
    -------
    A list with one entry per row of `phi`; each entry lists that topic's
    top words in ascending order of weight (heaviest word last).
    """
    # Take the topic count from `phi` itself rather than the module-level
    # `n_topics`, which may describe a different model than the one passed in.
    top_idx = np.argsort(np.asarray(phi), axis=1)[:, -tnum:]
    if nums :
        return [list(row) for row in top_idx]
    return [[word_list[i] for i in row] for row in top_idx]

# Vocabulary in column order of the count matrix.
# get_feature_names() no longer exists in recent scikit-learn releases;
# prefer get_feature_names_out() and fall back to the old name otherwise.
if hasattr(tf_vectorizer, "get_feature_names_out") :
    # Convert to plain Python strings so printed/pickled word lists look the
    # same as with the old API.
    word_list = [str(w) for w in tf_vectorizer.get_feature_names_out()]
else :
    word_list = tf_vectorizer.get_feature_names()

In [12]:
from scipy.special import gammaln, logsumexp

def perp(lambda_zt, didxs, dcnts) :
    """Perplexity of a corpus computed from the variational bound.

    Parameters
    ----------
    lambda_zt : array (n_topics, n_words) of topic-word variational parameters.
    didxs : per-document lists of word indices.
    dcnts : per-document lists of counts aligned with `didxs`.

    Returns
    -------
    exp(-bound / total number of tokens).

    Documents without any word are dropped first. The per-document gamma is
    obtained by calling `update_doc`; the priors are the module-level
    `alpha` and `beta`.
    """
    n_k, n_v = lambda_zt.shape

    keep = [i for i in range(len(didxs)) if len(didxs[i]) > 0]
    didxs = [didxs[i] for i in keep]
    dcnts = [dcnts[i] for i in keep]

    _, gamma = update_doc(range(len(didxs)), lambda_zt, didxs, dcnts)

    lambda_psi = psi(lambda_zt) - psi(lambda_zt.sum(1))[:, None]
    gamma_psi = psi(gamma) - psi(gamma.sum(1))[:, None]

    # Word term: for every word, log-sum-exp over topics of the two
    # log-expectation terms, weighted by the word's count.
    score = 0
    Nd = 0
    for d, (words, cnts) in enumerate(zip(didxs, dcnts)) :
        Nd += sum(cnts)
        log_mix = logsumexp(gamma_psi[d][:, None] + lambda_psi[:, words], axis=0)
        score += np.dot(log_mix, cnts)

    # Document-topic term: Dirichlet of dimension n_topics.
    score += np.sum((alpha - gamma) * gamma_psi)
    score += np.sum(gammaln(gamma) - gammaln(alpha))
    score += np.sum(gammaln(alpha * n_k) - gammaln(gamma.sum(1)))

    # Topic-word term: this Dirichlet lives over the vocabulary, so its
    # normaliser uses the number of words (the original used n_topics here).
    score += np.sum((beta - lambda_zt) * lambda_psi)
    score += np.sum(gammaln(lambda_zt) - gammaln(beta))
    score += np.sum(gammaln(beta * n_v) - gammaln(lambda_zt.sum(1)))

    score /= Nd
    return np.exp(-score)

In [13]:
def update_doc(docs, lambda_zt, docs_idxs, docs_cnts) :
    """Fit per-document variational parameters for a set of documents.

    Parameters
    ----------
    docs : iterable of document positions into `docs_idxs` / `docs_cnts`
        (must support len()).
    lambda_zt : array (n_topics, n_words) of topic-word variational parameters.
    docs_idxs : per-document lists of word indices.
    docs_cnts : per-document lists of counts aligned with `docs_idxs`.

    Returns
    -------
    lambda_hat : array shaped like `lambda_zt`, the summed expected
        topic-word counts of the given documents.
    gamma_ret : array (len(docs), n_topics), one gamma row per document.

    Uses the module-level `alpha` and draws the initial gamma from NumPy's
    global RNG.
    """
    MAX_INNER_ITERS = 100   # cap on fixed-point iterations per document
    GAMMA_TOL = 0.001       # stop when the mean absolute change drops below this

    # Sizes come from lambda_zt so the function stays correct even when the
    # module-level n_topics / n_words describe a different model.
    n_k, n_v = lambda_zt.shape
    lambda_psi = np.exp(psi(lambda_zt) - psi(lambda_zt.sum(1))[:, None])

    lambda_hat = np.zeros((n_k, n_v))
    gamma_ret = np.zeros((len(docs), n_k))
    for j, d in enumerate(docs) :
        words = docs_idxs[d]
        cnts = docs_cnts[d]

        if len(words) == 0 :
            # A document with no words adds nothing to the counts; its gamma
            # is just alpha. Leaving the row at zero would make psi() of it
            # blow up for any caller that uses it.
            gamma_ret[j, :] = alpha
            continue

        gamma = np.random.gamma(100., 1/100., (n_k))
        e_beta = lambda_psi[:, words]  # (n_topics, number of distinct words)

        for _ in range(MAX_INNER_ITERS) :
            prev_gamma = gamma
            e_theta = np.exp(psi(gamma) - psi(gamma.sum()))

            # Responsibilities, normalised over topics for each word.
            topic_assign = e_theta[:, None] * e_beta
            topic_assign = topic_assign / topic_assign.sum(0)

            gamma = alpha + np.dot(topic_assign, cnts)
            if np.mean(np.abs(prev_gamma - gamma)) < GAMMA_TOL :
                break

        gamma_ret[j, :] = gamma
        lambda_hat[:, words] += (topic_assign * cnts)

    return lambda_hat, gamma_ret

In [14]:
def update_lambda_zt(docs, t, lambda_zt) :
    """Apply one stochastic update of lambda_zt from a mini-batch.

    Parameters
    ----------
    docs : positions of the mini-batch documents in the training lists.
    t : zero-based update counter, used for the learning rate.
    lambda_zt : current topic-word variational parameters.

    Returns
    -------
    The blended parameters (1 - lr) * lambda_zt + lr * lambda_hat.

    Reads the module-level docs_idxs, docs_cnts, beta, offset and kappa.
    """
    lr_t = (t + offset) ** (-kappa)
    lambda_hat, _ = update_doc(docs, lambda_zt, docs_idxs, docs_cnts)
    # Scale the batch statistics up to the full corpus. The corpus size is
    # taken from docs_idxs directly because the module-level n_docs is
    # rebound to a filtered count in a later cell.
    corpus_size = len(docs_idxs)
    lambda_hat = lambda_hat * corpus_size/len(docs) + beta
    lambda_zt = (1 - lr_t) * (lambda_zt) + lr_t * (lambda_hat)

    return lambda_zt

def run_SVI(n_iters, bsize, nt=10) :
    """Train with mini-batch updates and record diagnostics per epoch.

    Parameters
    ----------
    n_iters : number of passes over the training documents.
    bsize : mini-batch size.
    nt : number of topics; also written to the module-level `n_topics`.

    Returns
    -------
    lambda_zt : final (nt, n_words) topic-word parameters.
    perps : training perplexity after each pass.
    perps_test : test perplexity after each pass.
    top_words : per pass, the 50 top words of every topic.
    """
    global n_topics

    t = 0
    n_topics = nt
    lambda_zt = np.random.gamma(100., 1/100., (n_topics, n_words))

    perps = np.zeros(n_iters)
    perps_test = np.zeros(n_iters)
    top_words = []

    # Iterate over the actual training lists; the module-level n_docs is
    # rebound to a filtered count in a later cell and would otherwise
    # silently shorten the epoch on a re-run.
    corpus_size = len(docs_idxs)

    for i in tqdm_notebook(range(n_iters)) :
        for start in range(0, corpus_size, bsize) :
            batch = range(start, min(start + bsize, corpus_size))
            lambda_zt = update_lambda_zt(batch, t, lambda_zt)
            t += 1
        top_words.append(get_top_words(word_list, lambda_zt, tnum=50))
        perps[i] = perp(lambda_zt, docs_idxs, docs_cnts)
        perps_test[i] = perp(lambda_zt, docs_idxs_test, docs_cnts_test)

    return lambda_zt, perps, perps_test, top_words

In [None]:
n_topics = 10
lambda_zt, perps, perps_test, top_words = run_SVI(300, 32, nt=10)
# Persist the run: the next cell reloads exactly this file, so it has to be
# written whenever training is executed (otherwise a fresh run has nothing
# to load, or loads results from an older run).
with open("svi-10-300.p", "wb") as fh :
    pickle.dump([lambda_zt, perps, perps_test, top_words], fh)

In [6]:
n_topics = 10
# Reload the cached 10-topic run. The context manager closes the file handle,
# which the bare open() call never did.
# NOTE: pickle.load executes arbitrary code; only load files you produced.
with open("svi-10-300.p", "rb") as fh :
    lambda_zt, perps, perps_test, top_words = pickle.load(fh)

In [85]:
# Training-set perplexity recorded after each epoch.
plt.plot(perps, label="Training Set")
for set_axis_label, text in ((plt.xlabel, "Epochs"), (plt.ylabel, "Perplexity")) :
    set_axis_label(text, fontsize=20)
for set_ticks in (plt.xticks, plt.yticks) :
    set_ticks(fontsize=18)
plt.legend(fontsize=18)
plt.tight_layout()
plt.savefig("perpsvi.pdf")
plt.show()

Plot Topic Distribution by True Topic
=====================================

In [16]:
# Infer gamma for every non-empty training document.
# `index` keeps the original positions so labels can be aligned later.
didxs = docs_idxs
dcnts = docs_cnts
index = [i for i in range(len(didxs)) if len(didxs[i]) > 0]
didxs = [didxs[i] for i in index]
dcnts = [dcnts[i] for i in index]
# Do not rebind the module-level n_docs here: the training code uses it as
# the full corpus size, and overwriting it with the filtered count would
# corrupt any later training run.
lambda_hat, gamma = update_doc(range(len(didxs)), lambda_zt, didxs, dcnts)

In [25]:
# Row-normalise gamma and take each document's dominant inferred topic.
gamma = gamma / gamma.sum(axis=1, keepdims=True)
gamma_max = np.argmax(gamma, axis=1)

# Maps each of the 20 newsgroup labels to one of 9 coarser groups.
maptopic = [0, 1, 1, 1, 1, 1, 2, 3, 3, 4, 4, 5, 1, 6, 7, 0, 8, 8, 8, 0]

targets = dataset.target
targets_filt = [targets[pos] for pos in index]
targets_mapped = [maptopic[label] for label in targets_filt]

# Count (group, inferred topic) pairs; np.add.at accumulates repeated pairs.
countmat = np.zeros((9, 10))
np.add.at(countmat, (targets_mapped, gamma_max), 1)

# Turn each row into proportions.
countmat = countmat / countmat.sum(axis=1, keepdims=True)

sns.heatmap(countmat, annot=True, cmap="Greens")
plt.xlabel("Predicted Topics")
plt.ylabel("True Topics")
plt.show()

Plot Changes in Word Distribution
=================================

In [82]:
selected_topic = -5

# Rank recorded for epochs in which a word is not among the stored top words.
ABSENT_RANK = -5

# Sizes come from the recorded history instead of hard-coded 300 / 10, so
# the cell also works for runs of a different length or topic count.
n_epochs = len(top_words)
n_tracked_topics = len(top_words[0]) if n_epochs else 0

# Regroup: topic_words_list[t][e] is the top-word list of topic t at epoch e.
topic_words_list = [[top_words[e][t] for e in range(n_epochs)]
                    for t in range(n_tracked_topics)]

# dict_words[word][e] is the word's position in the selected topic's list at
# epoch e (higher means heavier, since the lists are sorted ascending).
dict_words = {}
for e in range(n_epochs) :
    for rank, word in enumerate(topic_words_list[selected_topic][e]) :
        dict_words.setdefault(word, [ABSENT_RANK] * n_epochs)[e] = rank

In [83]:
plt.figure(figsize=(10, 10))
for word, ranks in dict_words.items() :
    plt.plot(ranks, label=word)
    # Put the label just right of the last epoch. Derive the position from
    # the series itself instead of an `i` left over from another cell.
    plt.text(len(ranks), ranks[-1], word)
plt.ylabel(r"Rank of the word (according to $\beta_{kv}$)", fontsize=20)
plt.xlabel("Epochs", fontsize=20)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.tight_layout()
plt.savefig("svi-religion.pdf")
plt.show()

Print Top Words
===============

In [18]:
# Compute the top words once (the original recomputed the full ranking on
# every iteration) and print one line per topic.
for topic_words in get_top_words(word_list, lambda_zt, tnum=10) :
    print(topic_words)

['rights', 'armenians', 'war', 'israel', 'armenian', 'president', 'car', 'people', 'government', 'going']
['encryption', '13', 'software', 'disk', 'ftp', '145', 'files', '25', 'key', 'file']
['help', 'using', 've', 'thanks', 'windows', 'problem', 'need', 'like', 'does', 'use']
['really', 'said', 'year', 'time', 'good', 'think', 'like', 'just', 'don', 'know']
['16', 'db', '12', '11', '17', '20', '15', 'team', 'game', '10']
['right', 'think', 'say', 'believe', 'does', 'make', 'way', 'did', 'god', 'people']
['server', 'available', 'bible', 'university', 'list', 'send', 'information', 'mail', 'com', 'edu']
['package', 'output', 'pc', 'image', 'info', 'data', 'card', '00', 'program', 'space']
['bus', 'apple', 'mac', 'hard', 'driver', 'cx', 'power', 'screen', 'scsi', 'drive']
['0t', '1t', '34u', '1d9', 'pl', 'a86', 'b8f', 'g9v', 'max', 'ax']


Print Coherence
===============

In [None]:
def coherence(top_words) :
    """Pairwise tf-idf coherence matrices, one per topic.

    Parameters
    ----------
    top_words : list with, per topic, a list of vocabulary indices.

    Returns
    -------
    A list of symmetric (n, n) arrays. For i < j the entry is
    log((col_i . col_j + 1e-6) / sum(col_i)), where col_x is the tf-idf
    column of word x; the value is mirrored to [j, i] and the diagonal is 0.

    The tf-idf weights are fitted on the module-level count matrix `tf`.
    """
    tfidf = TfidfTransformer().fit_transform(tf)

    # All column dot products and column sums are computed once up front;
    # the original recomputed a dense dot product for every word pair.
    gram = tfidf.T.dot(tfidf).toarray()
    col_sums = np.asarray(tfidf.sum(axis=0)).ravel()

    tc = []
    for words in top_words :
        idx = np.asarray(words, dtype=int)
        num = gram[np.ix_(idx, idx)]
        # Row i is divided by the column sum of word i, matching the
        # original's denominator for the (i, j) pair with i < j.
        pair = np.log((num + 0.000001) / col_sums[idx][:, None])
        upper = np.triu(pair, k=1)
        tc.append(upper + upper.T)
    return tc

# Coherence over the 1000 heaviest words (as indices) of every topic.
tp = get_top_words(word_list, lambda_zt, nums=True, tnum=1000)
c = coherence(tp)
# Use a context manager so the file is flushed and closed.
with open("sv-10-300-cohenrence.p", "wb") as fh :
    pickle.dump(c, fh)

In [21]:
# Reload the cached coherence matrices; the context manager closes the file.
# NOTE: pickle.load executes arbitrary code; only load files you produced.
with open("sv-10-300-cohenrence.p", "rb") as fh :
    c = pickle.load(fh)

In [36]:
# One panel per topic: coherence among each topic's 100 heaviest words.
# The original loop ignored `i` and drew c[-1] ten times on the same axes.
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for i, ax in zip(range(10), axes.ravel()) :
    ax.imshow(c[i][-100:, -100:])
    ax.set_title("Topic %d" % i)
plt.tight_layout()
plt.show()

In [50]:
from scipy.interpolate import interp1d

# Mean pairwise coherence among the j heaviest words of each topic, for
# j = 2..99. The matrices are symmetric with a zero diagonal, so dividing
# the block sum by j*(j-1) averages over ordered pairs.
for i in range(10) :
    cscores = []
    for j in range(2, 100) :
        coh = c[i][-j:, -j:].sum()/(j * (j-1))
        cscores.append(coh)
    plt.plot(range(2, 100), cscores, label=i)
plt.ylabel("Coherence Score", fontsize=20)
# The x axis is the number of top words, not training epochs.
plt.xlabel("Number of top words", fontsize=20)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.legend(loc='upper right', ncol=4, fontsize=12)

plt.show()

In [31]:
# Reload the cached coherence matrices; the context manager closes the file.
# NOTE: pickle.load executes arbitrary code; only load files you produced.
with open("sv-10-300-cohenrence.p", "rb") as fh :
    c = pickle.load(fh)

In [33]:
# Per topic: normalised weight of each of the 20 heaviest words against that
# word's summed coherence with the other top-20 words, divided by 20.
top_k = 20
lsort = np.sort(lambda_zt, axis=1)
lsort = lsort / lsort.sum(axis=1, keepdims=True)
for i in range(10) :
    top_block = c[i][-top_k:, -top_k:]
    coh = top_block.sum(0)/(top_k)
    plt.scatter(lsort[i, -top_k:], coh, label=i, s=10)

plt.xlim(0.0, 0.1)
for set_axis_label, text in ((plt.ylabel, "Coherence Score"),
                             (plt.xlabel, r"Probability of word in topic $\beta_{kv}$")) :
    set_axis_label(text, fontsize=20)
for set_ticks in (plt.xticks, plt.yticks) :
    set_ticks(fontsize=18)
plt.legend(loc='lower right', ncol=3, fontsize=12, title="k")
plt.tight_layout()
plt.savefig("cohvsprobfull_svi.pdf")
plt.show()

In [35]:
# Per topic: for the words ranked 5th to 50th, normalised weight against the
# word's summed coherence with the 4 heaviest words, divided by 4.
n_anchor, n_span = 4, 50
lsort = np.sort(lambda_zt, axis=1)
lsort = lsort / lsort.sum(axis=1, keepdims=True)
for i in range(10) :
    anchor_block = c[i][-n_anchor:, -n_span:-n_anchor]
    coh = anchor_block.sum(0)/n_anchor
    plt.scatter(lsort[i, -n_span:-n_anchor], coh, s=10, label=i)

for set_axis_label, text in ((plt.ylabel, "Coherence Score"),
                             (plt.xlabel, r"Probability of word in topic $\beta_{kv}$")) :
    set_axis_label(text, fontsize=20)
for set_ticks in (plt.xticks, plt.yticks) :
    set_ticks(fontsize=18)
plt.legend(loc='lower right', ncol=2, fontsize=12, title="k")
plt.tight_layout()
plt.savefig("cohvsprob3_svi.pdf")
plt.show()

Plot Perplexity by Number of Topics
===================================

In [None]:
# Train one model per topic count and keep only the perplexity curves.
perp_by_ntopics = {}
perp_by_ntopics_test = {}

# run_SVI overwrites the module-level n_topics; remember the current value
# so the 10-topic model loaded earlier stays usable after the sweep.
n_topics_before_sweep = n_topics
for nt in tqdm_notebook(range(2, 30, 3)) :
    # Unpack into sweep-local names so the 10-topic lambda_zt / perps /
    # top_words used by the other cells are not overwritten.
    _sweep_lambda, sweep_perps, sweep_perps_test, _sweep_words = run_SVI(100, 32, nt=nt)
    perp_by_ntopics[nt] = sweep_perps
    perp_by_ntopics_test[nt] = sweep_perps_test
n_topics = n_topics_before_sweep

In [None]:
# Cache the sweep results; the context manager flushes and closes the file.
with open("svi-perps.p", "wb") as fh :
    pickle.dump([perp_by_ntopics, perp_by_ntopics_test], fh)

In [65]:
# Reload the cached sweep results; the context manager closes the file.
# NOTE: pickle.load executes arbitrary code; only load files you produced.
with open("svi-perps.p", "rb") as fh :
    perp_by_ntopics, perp_by_ntopics_test = pickle.load(fh)

In [68]:
# Last recorded training perplexity for each topic count in the sweep.
final_train_perps = [curve[-1] for curve in perp_by_ntopics.values()]
plt.plot(perp_by_ntopics.keys(), final_train_perps)
for set_axis_label, text in ((plt.xlabel, "Number of Topics"),
                             (plt.ylabel, "Final Perplexity")) :
    set_axis_label(text, fontsize=20)
for set_ticks in (plt.xticks, plt.yticks) :
    set_ticks(fontsize=18)
plt.tight_layout()
plt.savefig("perpbytopic_svi.pdf")
plt.show()

In [72]:
# Last recorded test perplexity for each topic count in the sweep.
final_test_perps = [curve[-1] for curve in perp_by_ntopics_test.values()]
plt.plot(perp_by_ntopics_test.keys(), final_test_perps)
# To overlay the training curve:
#plt.plot(perp_by_ntopics.keys(), [x[-1] for x in perp_by_ntopics.values()])
for set_axis_label, text in ((plt.xlabel, "Number of Topics"),
                             (plt.ylabel, "Final Perplexity")) :
    set_axis_label(text, fontsize=20)
for set_ticks in (plt.xticks, plt.yticks) :
    set_ticks(fontsize=18)
plt.tight_layout()
plt.savefig("perpbytopic_test_svi.pdf")
plt.show()