In [1]:
import lda.datasets as dataset_loader # for loading a test dataset
import lda as lda_lib

import numpy as np
from scipy.special import digamma

from matplotlib import pyplot as plt
from IPython.display import display, Markdown
from tqdm import tqdm

md = lambda *args: display(Markdown(*args))

def flatten(L): return [e for l in L for e in l]
def draw_cat_prop(p): return np.random.choice(p.shape[0], p = p/np.sum(p))

np.random.seed(12345)

In [2]:
d_to_text = dataset_loader.load_reuters_titles()
v_to_text = dataset_loader.load_reuters_vocab()
nW = dataset_loader.load_reuters()

print(nW.shape, np.shape(d_to_text), np.shape(v_to_text))


(395, 4258) (395,) (4258,)


In [3]:
D = nW.shape[0] # set D to a smaller constant (below) to test with less documents
n_topics = 10
V = nW.shape[1]

D,n_topics = 10, 4 # smaller setup

def _line_to_list_of_words(l):
    return flatten(([v]*l[v] for v in range(len(l))))
w = [_line_to_list_of_words(nW[d,:]) for d in range(D)]


# Simple Gibbs sampler (very basic)

In [4]:
def lda_gibbs(K=42, LOOPS=112, α=5., β=.1):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    # initialize sampled variables
    z = [np.random.choice(K, len(w[d])) for d in range(D)]
    φ = np.zeros((K, V)) # we'll draw right away from the counts
    θ = np.zeros((D, K)) # we'll draw right away from the counts
    # initialize count tables
    c_dk = np.zeros((D, K))
    for d in range(D):
        c_dk[d,:] = np.bincount(z[d], minlength=K)
    c_kv = np.zeros((K, V))
    for d in range(D):
        for i in range(len(w[d])):
            c_kv[z[d][i], w[d][i]] += 1
    #print(np.sum(c_dk), np.sum(c_kv))
    
    md('... running gibbs sampler')
    for loop in tqdm(range(LOOPS)):
        for d in range(D):
            θ[d,:] = np.random.dirichlet(α[:] + c_dk[d,:])
        for k in range(K):
            φ[k,:] = np.random.dirichlet(β[:] + c_kv[k,:])
        for d in range(D):
            for i in range(len(w[d])):
                _wdi = w[d][i]
                _zdi = z[d][i]
                c_dk[d, _zdi] -= 1
                c_kv[_zdi, _wdi] -= 1
                z[d][i] = draw_cat_prop(φ[:,_wdi]*θ[d,:])
                _zdi = z[d][i]
                c_dk[d, _zdi] += 1
                c_kv[_zdi, _wdi] += 1
    
    md('We obtained the following topics:')
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)
    
    
lda_gibbs(n_topics)

... initializing

... running gibbs sampler

100%|██████████| 112/112 [00:07<00:00, 14.56it/s]


We obtained the following topics:

Topic 0: buckingham / largest / prince / reforms / committee / photo / made / american / baroque / confirmed / open / spared / family / comedy / hit / spokeswoman / reign / cost / radical / visiting

Topic 1: 1945 / tomorrow / white / religions / became / pair / raised / 2000 / wait / deaths / considered / several / taken / fitted / project / 1977 / gay / historic / settlement / feature

Topic 2: charles / royal / diana / divorce / take / queen / church / camilla / monday / british / intention / public / becomes / year / britain / 25 / interview / palace / dresden / denied

Topic 3: mother / teresa / sen / sister / heart / told / calcutta / home / doctors / saint / living / charity / missionaries / fever / respirator / dr / condition / 100 / earlier / birthday

### NB: Simple Gibbs sampler [timing: 04:46.00]

> Topic 0: political / minister / last / government / leader / president / party / former / church / million / michael / first / visit / years / romania / nation / support / country / christian / return

> Topic 1: charles / prince / family / church / public / royal / parker / british / diana / king / bowles / queen / marriage / years / britain / camilla / newspaper / throne / princess / couple

> Topic 2: pope / vatican / church / surgery / john / roman / paul / mass / trip / pontiff / world / hospital / rome / sunday / during / operation / poland / year / since / told

> Topic 3: u.s / harriman / france / clinton / paris / churchill / american / british / ambassador / died / first / late / film / president / winston / became / close / pamela / war / french

> Topic 4: people / years / world / church / war / ceremony / long / place / country / great / took / union / state / first / very / buried / south / several / communist / town

> Topic 5: yeltsin / president / russian / police / miami / kremlin / operation / power / say / russia / cunanan / heart / take / last / versace / away / several / tuesday / spokesman / chief

> Topic 6: n't / church / elvis / people / says / music / television / film / life / wright / catholic / while / bishop / little / women / told / tour / fans / father / every

> Topic 7: church / service / death / life / against / found / former / bernardin / home / funeral / us / thursday / died / cardinal / court / bishops / west / later / night / made

> Topic 8: mother / teresa / order / heart / work / missionaries / nuns / told / sister / charity / hospital / official / poor / successor / calcutta / home / india / last / nun / first

> Topic 9: city / germany / people / east / peace / against / prize / art / timor / simpson / award / museum / year / letter / years / told / capital / exhibition / group / culture


# Collapsed Gibbs sampler (faster "convergence")

In [5]:
def lda_collapsed_gibbs(K=42, LOOPS=112, α=5., β=.1):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    # initialize sampled variables
    z = [np.random.choice(K, len(w[d])) for d in range(D)]
    # initialize count tables
    c_dk = np.zeros((D, K))
    for d in range(D):
        c_dk[d,:] = np.bincount(z[d], minlength=K)
    c_kv = np.zeros((K, V))
    for d in range(D):
        for i in range(len(w[d])):
            c_kv[z[d][i], w[d][i]] += 1
    c_k = np.sum(c_kv, axis=1)
    #print(np.sum(c_dk), np.sum(c_kv), np.sum(c_k))
    
    md('... running collapsed gibbs sampler')
    for loop in tqdm(range(LOOPS)):
        for d in range(D):
            for i in range(len(w[d])):
                _wdi = w[d][i]
                _zdi = z[d][i]
                # remove the point from the counts
                c_dk[d, _zdi] -= 1
                c_kv[_zdi, _wdi] -= 1
                c_k[_zdi] -= 1
                # sample the new value
                _zdi = draw_cat_prop((α[:] + c_dk[d,:]) * (β[_wdi] + c_kv[:,_wdi]) / (β0 + c_k[:]))
                z[d][i] = _zdi
                # update the counts
                c_dk[d, _zdi] += 1
                c_kv[_zdi, _wdi] += 1
                c_k[_zdi] += 1
                
    md('We obtained the following topics:')
    φ = c_kv + β[None,:]
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)

lda_collapsed_gibbs(n_topics)

... initializing

... running collapsed gibbs sampler

100%|██████████| 112/112 [00:09<00:00, 11.77it/s]


We obtained the following topics:

Topic 0: 1993 / found / part / ashes / christians / earned / mark / bishop / regional / sunday / got / call / takes / conversation / neither / photo / clear / william / raising / paul

Topic 1: mother / teresa / heart / calcutta / sen / sunday / order / sister / condition / home / respirator / hospital / nun / doctors / charity / prayers / nuns / tuesday / told / missionaries

Topic 2: dresden / million / church / first / state / once / princess / moment / project / interview / crypt / service / become / bishop / baroque / centre / wednesday / 1992 / 1945 / raid

Topic 3: charles / church / diana / royal / prince / queen / camilla / palace / family / monday / divorce / parker / bowles / daily / british / marriage / england / years / take / britain

### NB: Collapsed Gibbs sampler, [timing: 05:25.00]

> Topic 0: city / art / great / century / museum / exhibition / tour / first / music / history / cultural / left / time / capital / set / show / while / want / since / culture

> Topic 1: political / yeltsin / president / russian / russia / leader / minister / country / party / kremlin / tuesday / moscow / operation / communist / union / power / percent / soviet / heart / say

> Topic 2: pope / vatican / paul / world / mass / john / surgery / church / rome / pontiff / trip / sunday / year / since / roman / during / visit / poland / hospital / month

> Topic 3: against / film / germany / group / people / east / rights / last / prize / peace / award / letter / christian / human / international / magazine / french / timor / spokesman / country

> Topic 4: u.s / harriman / clinton / elvis / churchill / paris / president / france / ambassador / died / late / american / first / husband / age / winston / war / state / born / death

> Topic 5: police / life / service / family / national / simpson / miami / versace / first / funeral / cunanan / people / men / church / home / night / held / wednesday / say / star

> Topic 6: charles / prince / family / royal / king / diana / public / queen / church / bowles / parker / camilla / marriage / newspaper / princess / married / throne / british / britain / years

> Topic 7: years / million / government / world / people / former / british / three / war / west / year / town / law / four / south / went / john / sale / women / letters

> Topic 8: church / catholic / n't / years / told / bishop / cardinal / bernardin / last / died / father / life / son / michael / wright / romania / former / during / know / death

> Topic 9: mother / teresa / order / heart / work / hospital / charity / nuns / sister / home / calcutta / missionaries / roman / world / poor / successor / last / doctors / nun / peace


----

# Using the lda library
This library implements a Collapsed Gibbs sampler but with cython instead of python loops (and with a slightly better sparse representation).
It is much faster (due to cython) compared to our above implementation

In [6]:
L = lda_lib.LDA(n_topics, 112, 5., .1, refresh=100)
L.fit(nW)

φ = L.components_
for k in range(n_topics):
    txt = 'Topic '+str(k)+': '
    words = np.argsort(φ[k,:])[::-1]
    txt += ' / '.join([v_to_text[v] for v in words[:20]])
    md(txt)


INFO:lda:n_documents: 395
INFO:lda:vocab_size: 4258
INFO:lda:n_words: 84010
INFO:lda:n_topics: 4
INFO:lda:n_iter: 112
INFO:lda:<0> log likelihood: -802269
INFO:lda:<100> log likelihood: -683441
INFO:lda:<111> log likelihood: -683110


Topic 0: charles / king / british / harriman / u.s / prince / church / first / clinton / died / family / elvis / royal / diana / churchill / years / public / son / queen / marriage

Topic 1: city / people / film / police / against / life / germany / years / church / n't / made / show / french / home / director / international / music / national / simpson / year

Topic 2: pope / mother / teresa / catholic / church / vatican / order / world / hospital / roman / john / told / doctors / sunday / heart / surgery / last / peace / paul / mass

Topic 3: church / president / government / political / yeltsin / country / last / leader / russian / minister / former / war / under / russia / million / three / people / union / party / says

### NB: `lda` library (Collapsed Gibbs sampler) [timing: 00:01.66]

> Topic 0: yeltsin / president / russian / russia / bernardin / last / union / kremlin / moscow / orthodox / operation / communist / soviet / take / say / power / country / political / church / under

> Topic 1: mother / teresa / order / heart / work / hospital / told / charity / nuns / official / calcutta / missionaries / sister / home / last / election / world / poor / senior / successor

> Topic 2: harriman / u.s / clinton / war / churchill / paris / died / president / france / british / ambassador / american / first / became / winston / campaign / minister / state / party / husband

> Topic 3: city / million / century / years / music / used / since / first / art / made / exhibition / year / museum / off / culture / history / churches / cultural / capital / including

> Topic 4: pope / vatican / church / paul / john / mass / sunday / rome / world / during / surgery / roman / pontiff / trip / day / since / hospital / visit / poland / health

> Topic 5: church / service / police / south / funeral / family / simpson / miami / versace / found / cunanan / national / held / home / death / wednesday / friday / night / law / thursday

> Topic 6: elvis / film / germany / against / french / german / group / fans / west / letter / people / called / says / every / king / magazine / festival / france / concert / made

> Topic 7: charles / prince / king / public / royal / diana / family / queen / church / british / bowles / parker / camilla / britain / marriage / princess / throne / years / married / love

> Topic 8: n't / life / years / told / people / own / very / time / church / television / never / first / world / say / women / catholic / show / bishop / year / later

> Topic 9: government / political / former / church / leader / minister / east / years / last / country / peace / prize / people / catholic / rights / party / ceremony / award / president / united


# Now we show how to make ours faster
- with numba (we have to jump through a few hoops, but very good gain) (3.8s)
- then with an improved sparse representation (not much gain as we still loop in pure python) (5min 20s)
- then with both (even better gain) (2.55s)
- reminder: cython impl from the library (1.5s)

In [7]:
import numba

# Collapsed Gibbs sampler, with numba

In [8]:
def lda_collapsed_gibbs_numba(K=42, LOOPS=112, α=5., β=.1, ___globalW=w):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    # Hoop 1: use special lists
    w = numba.typed.List()
    for _wd in ___globalW:
        w.append(numba.typed.List(_wd))
    
    # initialize sampled variables
    z = [np.random.choice(K, len(w[d])) for d in range(D)]
    z = numba.typed.List(z)
        
    # initialize count tables
    c_dk = np.zeros((D, K))
    for d in range(D):
        c_dk[d,:] = np.bincount(z[d], minlength=K)
    c_kv = np.zeros((K, V))
    for d in range(D):
        for i in range(len(w[d])):
            c_kv[z[d][i], w[d][i]] += 1
    c_k = np.sum(c_kv, axis=1)
    #print(np.sum(c_dk), np.sum(c_kv), np.sum(c_k))
    
    md('... running collapsed gibbs sampler')
    
    # based on https://github.com/numba/numba/issues/2539#issuecomment-507306369
    @numba.jit(nopython=True)
    def draw_cat_prop(p):
        p /= np.sum(p)
        # hoop 2: work around the missing p parameter to random.choice
        return np.searchsorted(np.cumsum(p), np.random.random(), side="right")

    @numba.jit(nopython=True)
    def loopit(c_dk, c_kv, c_k, w, z):
        for d in range(D):
            _wd = w[d]
            for i in range(len(_wd)):
                _wdi = _wd[i]
                _zdi = z[d][i]
                # remove the point from the counts
                c_dk[d, _zdi] -= 1
                c_kv[_zdi, _wdi] -= 1
                c_k[_zdi] -= 1
                # sample the new value
                _zdi = draw_cat_prop((α[:] + c_dk[d,:]) * (β[_wdi] + c_kv[:,_wdi]) / (β0 + c_k[:]))
                z[d][i] = _zdi
                # update the counts
                c_dk[d, _zdi] += 1
                c_kv[_zdi, _wdi] += 1
                c_k[_zdi] += 1

    
    for loop in tqdm(range(LOOPS)):
        loopit(c_dk, c_kv, c_k, w, z)
                
    md('We obtained the following topics:')
    φ = c_kv + β[None,:]
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)
    
lda_collapsed_gibbs_numba(n_topics)

... initializing

... running collapsed gibbs sampler

100%|██████████| 112/112 [00:00<00:00, 133.62it/s]


We obtained the following topics:

Topic 0: prince / palace / daily / royal / changes / reforms / committee / heir-to-the-throne / head / princess / ending / main / moves / monarchy / monarch / throne / changing / 1992 / marriages / supreme

Topic 1: mother / teresa / heart / calcutta / sen / order / sunday / sister / condition / home / respirator / told / hospital / doctors / prayers / tuesday / nun / charity / nuns / missionaries

Topic 2: charles / church / diana / queen / camilla / british / divorce / against / bowles / parker / britain / years / monday / million / take / buckingham / marriage / since / family / newspaper

Topic 3: dresden / next / service / centre / project / famous / spent / officials / crypt / moment / ashes / regional / baroque / news / again / divorced / inauguration / far / 1945 / ceremony

### NB: Numba for Collapsed Gibbs sampler [timing: 00:03.48]

> Topic 0: city / germany / german / christian / million / art / century / since / history / set / letter / exhibition / museum / year / nazi / capital / against / jews / international / government

> Topic 1: charles / prince / king / public / diana / royal / family / queen / parker / bowles / britain / british / camilla / marriage / church / years / princess / throne / married / newspaper

> Topic 2: mother / teresa / order / work / heart / charity / hospital / nuns / sister / told / calcutta / home / missionaries / world / successor / senior / roman / poor / last / official

> Topic 3: church / n't / say / very / television / while / bishop / years / world / women / father / time / long / go / never / three / held / come / wright / catholic

> Topic 4: yeltsin / political / president / russian / government / russia / minister / michael / country / kremlin / moscow / power / communist / soviet / leader / romania / party / take / operation / orthodox

> Topic 5: church / years / during / last / bernardin / life / against / catholic / john / people / death / national / french / cardinal / former / died / court / france / time / south

> Topic 6: elvis / film / life / people / music / simpson / first / fans / won / king / years / west / every / festival / tour / best / concert / stars / used / world

> Topic 7: u.s / harriman / clinton / churchill / war / paris / british / first / president / ambassador / france / american / late / party / monday / winston / former / minister / husband / state

> Topic 8: police / people / service / funeral / told / home / family / east / ceremony / miami / thursday / versace / cunanan / peace / church / spokesman / night / prize / city / friday

> Topic 9: pope / vatican / paul / during / church / surgery / john / mass / since / roman / trip / rome / pontiff / hospital / world / visit / doctors / sunday / health / poland



# Better representation for Collapsed Gibbs sampler (same speed, as it is still pure python)

In [9]:
def lda_collapsed_gibbs_goodrepr(K=42, LOOPS=112, α=5., β=.1):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    allw = np.array([_wdi for _wd in w for _wdi in _wd])
    alld = np.array([d for d in range(D) for i in range(len(w[d]))])
    S = alld.shape[0] # total size across all the (considered) documents
    # initialize sampled variables
    z = np.random.choice(K, S)
    # initialize count tables
    c_dk = np.zeros((D, K))
    for d in range(D):
        c_dk[d,:] = np.bincount(z[alld==d], minlength=K)
    c_kv = np.zeros((K, V))
    for di in range(S):
        c_kv[z[di], allw[di]] += 1
    c_k = np.sum(c_kv, axis=1)
    #print(np.sum(c_dk), np.sum(c_kv), np.sum(c_k))
    
    md('... running collapsed gibbs sampler')
    for loop in tqdm(range(LOOPS)):
        for di in range(S):
            d = alld[di]
            _wdi = allw[di]
            _zdi = z[di]
            # remove the point from the counts
            c_dk[d, _zdi] -= 1
            c_kv[_zdi, _wdi] -= 1
            c_k[_zdi] -= 1
            # sample the new value
            _zdi = draw_cat_prop((α[:] + c_dk[d,:]) * (β[_wdi] + c_kv[:,_wdi]) / (β0 + c_k[:]))
            z[di] = _zdi
            # update the counts
            c_dk[d, _zdi] += 1
            c_kv[_zdi, _wdi] += 1
            c_k[_zdi] += 1
                
    md('We obtained the following topics:')
    φ = c_kv + β[None,:]
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)
    
    
lda_collapsed_gibbs_goodrepr(n_topics)

... initializing

... running collapsed gibbs sampler

100%|██████████| 112/112 [00:09<00:00, 11.53it/s]


We obtained the following topics:

Topic 0: mother / teresa / heart / calcutta / sunday / sen / order / condition / sister / home / respirator / hospital / doctors / nun / nuns / tuesday / charity / told / prayers / fever

Topic 1: charles / diana / royal / queen / prince / palace / camilla / take / church / family / divorce / monday / parker / british / daily / britain / bowles / spokeswoman / buckingham / marriage

Topic 2: later / strategic / telegraph / difficult / taken / latest / member / john / inspired / 20 / working / late / married / huge / woman / bar / hill / end / responded / wife

Topic 3: church / dresden / members / bishop / million / once / service / state / years / found / raid / days / 1945 / ashes / tabloid / ceremony / centre / britons / marks / unable

### NB: Better representation for Collapsed Gibbs sampler [timing: 05:33.00]

> Topic 0: years / city / million / world / set / simpson / time / century / art / museum / used / during / since / own / part / once / churches / first / music / great

> Topic 1: political / yeltsin / president / russian / russia / leader / country / last / union / kremlin / state / moscow / minister / percent / under / soviet / government / party / communist / power

> Topic 2: mother / teresa / order / heart / work / charity / official / hospital / told / nuns / world / home / calcutta / missionaries / sister / peace / last / poor / election / head

> Topic 3: church / died / service / people / south / funeral / former / local / government / ceremony / during / last / michael / first / visit / friday / romania / country / wednesday / white

> Topic 4: u.s / harriman / clinton / president / war / churchill / american / paris / british / ambassador / france / became / minister / home / party / winston / very / husband / state / former

> Topic 5: against / film / life / french / bernardin / year / years / rights / france / court / death / cardinal / west / festival / chicago / cancer / found / magazine / national / conservative

> Topic 6: charles / prince / diana / royal / family / king / public / queen / years / british / parker / bowles / camilla / marriage / britain / princess / throne / married / church / england

> Topic 7: church / n't / told / people / show / very / catholic / say / television / bishop / know / life / news / think / time / come / wright / want / came / love

> Topic 8: pope / vatican / john / paul / church / mass / roman / surgery / world / sunday / catholic / rome / trip / pontiff / hospital / three / left / day / since / poland

> Topic 9: elvis / police / germany / city / german / called / international / fans / miami / versace / cunanan / thursday / death / off / jews / nazi / king / letter / concert / states



# Numba+Representation for Collapsed Gibbs sampler

In [10]:
def lda_collapsed_gibbs_goodrepr_numba(K=42, LOOPS=112, α=5., β=.1):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    allw = np.array([_wdi for _wd in w for _wdi in _wd])
    alld = np.array([d for d in range(D) for i in range(len(w[d]))])
    S = alld.shape[0] # total size across all the (considered) documents
    # initialize sampled variables
    z = np.random.choice(K, S)
    # initialize count tables
    c_dk = np.zeros((D, K))
    for d in range(D):
        c_dk[d,:] = np.bincount(z[alld==d], minlength=K)
    c_kv = np.zeros((K, V))
    for di in range(S):
        c_kv[z[di], allw[di]] += 1
    c_k = np.sum(c_kv, axis=1)
    #print(np.sum(c_dk), np.sum(c_kv), np.sum(c_k))
    
    md('... running collapsed gibbs sampler')
    # based on https://github.com/numba/numba/issues/2539#issuecomment-507306369
    @numba.jit(nopython=True)
    def draw_cat_prop(p):
        p /= np.sum(p)
        # hoop 2: work around the missing p parameter to random.choice
        return np.searchsorted(np.cumsum(p), np.random.random(), side="right")

    @numba.jit(nopython=True)
    def loopit(c_dk, c_kv, c_k, alld, allw, z):
        for di in range(S):
            d = alld[di]
            _wdi = allw[di]
            _zdi = z[di]
            # remove the point from the counts
            c_dk[d, _zdi] -= 1
            c_kv[_zdi, _wdi] -= 1
            c_k[_zdi] -= 1
            # sample the new value
            _zdi = draw_cat_prop((α[:] + c_dk[d,:]) * (β[_wdi] + c_kv[:,_wdi]) / (β0 + c_k[:]))
            z[di] = _zdi
            # update the counts
            c_dk[d, _zdi] += 1
            c_kv[_zdi, _wdi] += 1
            c_k[_zdi] += 1

    #loopit(c_dk, c_kv, c_k, alld, allw, z, LOOPS)
    for loop in tqdm(range(LOOPS)):
        loopit(c_dk, c_kv, c_k, alld, allw, z)
                
    md('We obtained the following topics:')
    φ = c_kv + β[None,:]
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)
    
    
lda_collapsed_gibbs_goodrepr_numba(n_topics)

... initializing

... running collapsed gibbs sampler

100%|██████████| 112/112 [00:00<00:00, 253.27it/s]


We obtained the following topics:

Topic 0: mother / teresa / heart / calcutta / sunday / sen / order / sister / condition / home / hospital / respirator / nun / doctors / told / tuesday / charity / prayers / nuns / fever

Topic 1: church / dresden / million / love / part / service / main / high / temperature / ceremony / baroque / far / inauguration / centre / project / john / ashes / allowed / 1945 / crypt

Topic 2: charles / diana / church / queen / royal / prince / camilla / palace / bowles / parker / take / british / family / divorce / britain / england / buckingham / daily / spokeswoman / marriage

Topic 3: monday / committee / member / members / marks / cannot / agreed / very / might / showed / pay / strategic / minister / 1949 / marry / governor / head / century / moment / raise

### NB: Numba+representation for Collapsed Gibbs sampler [timing: 00:02.40]

> Topic 0: political / city / russian / country / leader / state / minister / union / president / russia / since / moscow / part / capital / year / christian / time / power / soviet / percent

> Topic 1: church / catholic / told / bishop / president / bernardin / people / peace / ceremony / cardinal / east / life / prize / former / death / during / wright / last / years / roman

> Topic 2: pope / vatican / yeltsin / surgery / operation / paul / mass / church / health / doctors / rome / hospital / during / pontiff / trip / world / since / last / tuesday / sunday

> Topic 3: charles / harriman / prince / u.s / clinton / british / royal / diana / churchill / queen / marriage / public / parker / bowles / camilla / family / ambassador / princess / paris / england

> Topic 4: war / first / last / million / government / century / people / years / michael / church / visit / exhibition / museum / romania / place / put / small / art / five / three

> Topic 5: against / film / germany / french / german / france / group / west / letter / country / festival / book / people / nazi / american / director / says / magazine / paris / rights

> Topic 6: n't / world / television / year / church / people / until / while / time / never / day / several / few / own / south / think / among / off / end / days

> Topic 7: mother / teresa / order / heart / charity / hospital / work / election / nuns / calcutta / sister / missionaries / told / official / religious / home / poor / head / senior / nun

> Topic 8: years / elvis / first / life / king / death / simpson / died / fans / father / left / every / world / women / year / black / three / music / children / popular

> Topic 9: family / police / service / church / funeral / former / home / thursday / miami / versace / very / cunanan / york / died / night / city / us / told / kennedy / men


# Now we try with some Variational Inference

# Collapsed Variational Inference, from intuitive (wrong?) derivations

In [11]:
def lda_collapsed_vi(K=42, LOOPS=112, α=5., β=.1):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    # initialize variational parameters
    ν = [np.random.dirichlet(np.ones(K), len(w[d])) for d in range(D)]

    
    def expdigamma_big(x):
        x = x - 0.5
        return x + 1/(24*x) - 37/(5760*x**3)

    def expdigamma(x):
        return expdigamma_big(x+1) / np.exp(1/x)
        #return expdigamma_big(x+2) / np.exp(1/x + 1/(x+1))
    
    md('... running collapsed variational inference')
    for loop in tqdm(range(LOOPS)):
        νnew = []
        # precompute what can be
        sumν_kv = np.zeros((K, V))
        for d in range(D):
            _wd = w[d]
            Nd = len(_wd)
            νd = ν[d] # (Nd, K)
            for i in range(Nd):
                _wdi = _wd[i]
                sumν_kv[:,_wdi] += νd[i, :]
        sumν_k = np.sum(sumν_kv, axis=1)
        # compute the actual new ν
        for d in range(D):
            _wd = w[d]
            Nd = len(_wd)
            νd = ν[d]                  # (Nd, K)
            νnewd = np.zeros((Nd, K))
            νnew.append(νnewd)
            ψ0 = digamma(α0 + Nd - 1)
            αsumνd_k = α + np.sum(νd, axis=0)
            for i in range(Nd):
                _wdi = _wd[i]
                p = ( expdigamma(β[_wdi] + sumν_kv[:,_wdi] - νd[i,:])
                    / expdigamma(β0 + sumν_k - νd[i,:])
                    * expdigamma(αsumνd_k - νd[i,:])
                    #/ expdigamma(α0 + Nd - 1)
                )
                #p = digamma(β[_wdi] + sumν_kv[:,_wdi] - νd[i,:]) - digamma(β0 + sumν_k - νd[i,:]) + digamma(αsumνd_k - νd[i,:]) - ψ0
                #p = np.exp(p)
                p = p / np.sum(p)
                νnewd[i,:] = p
        ν = νnew
                
    md('We obtained the following topics:')
    φ = sumν_kv + β[None,:]
    print(φ.shape)
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)
    
lda_collapsed_vi(n_topics)

... initializing

... running collapsed variational inference

100%|██████████| 112/112 [00:12<00:00,  9.31it/s]


We obtained the following topics:

(4, 4258)


Topic 0: mother / teresa / heart / calcutta / sunday / sen / order / sister / condition / home / doctors / nun / hospital / told / charity / prayers / nuns / tuesday / missionaries / fever

Topic 1: charles / church / diana / queen / prince / royal / years / camilla / palace / monday / against / take / bowles / parker / divorce / family / daily / dresden / million / britain

Topic 2: respirator / people / saint / house / world / thousands / pray / admitted / roman / life / birthday / dr / long / days / irregular / several / blessed / hope / tomorrow / problems

Topic 3: british / year / praying / n't / age / nada / strength / buckingham / under / spokeswoman / reporters / marks / reforms / committee / changes / showing / public / might / newspaper / help

### NB: Hacky Variational Inference [timing: 00:07.26]

> Topic 0: city / germany / german / century / art / war / capital / west / union / exhibition / museum / soviet / letter / churches / nazi / says / international / great / jews / world

> Topic 1: people / political / east / peace / rights / group / went / prize / country / against / award / economic / last / saturday / president / told / campaign / letters / timor / human

> Topic 2: charles / king / prince / elvis / diana / royal / family / queen / public / parker / bowles / camilla / marriage / princess / married / throne / music / years / love / first

> Topic 3: church / n't / film / people / against / south / years / women / catholic / french / life / know / won / never / bishop / wright / father / very / made / few

> Topic 4: yeltsin / president / doctors / operation / russian / last / since / surgery / leader / hospital / tuesday / russia / kremlin / news / month / monday / three / heart / minister / statement

> Topic 5: mother / teresa / order / heart / work / charity / nuns / sister / home / calcutta / told / missionaries / head / hospital / official / election / poor / india / last / members

> Topic 6: former / million / party / year / years / leader / time / political / simpson / made / past / television / months / country / three / kennedy / parliament / sale / court / vote

> Topic 7: church / government / british / ceremony / last / minister / michael / former / law / during / public / years / officials / newspaper / romania / become / visit / while / opinion / friday

> Topic 8: harriman / u.s / clinton / churchill / paris / france / died / president / american / ambassador / british / police / war / late / service / home / funeral / death / became / family

> Topic 9: pope / church / vatican / john / world / during / catholic / paul / roman / mass / cardinal / left / sunday / bernardin / life / day / rome / years / trip / pontiff


# Collapsed Variational Inference, from intuitive (wrong?) derivations + repr+numba

In [12]:
def lda_collapsed_vi_numba(K=42, LOOPS=112, α=5., β=.1):
    if type(α) == type(1.):
        α = np.ones(K)*α
    if type(β) == type(1.):
        β = np.ones(V)*β
        
    α0 = np.sum(α)
    β0 = np.sum(β)
    
    md('... initializing')
    allw = np.array([_wdi for _wd in w for _wdi in _wd])
    alld = np.array([d for d in range(D) for i in range(len(w[d]))])
    S = alld.shape[0] # total size across all the (considered) documents

    # initialize variational parameters
    ν = np.random.dirichlet(np.ones(K), S)

    @numba.jit(nopython=True)
    def expdigamma_big(x):
        x = x - 0.5
        return x + 1/(24*x) #- 37/(5760*x**3)

    @numba.jit(nopython=True)
    def expdigamma(x):
        return expdigamma_big(x+1) / np.exp(1/x)
        return expdigamma_big(x+2) / np.exp(1/x + 1/(x+1))
    
    @numba.jit(nopython=True)
    def loopit(ν, alld, allw):
        νnew = np.zeros(ν.shape) # (S, K)
        # precompute what can be
        sumν_kv = np.zeros((K, V))
        sumν_dk = np.zeros((D, K))
        N = np.zeros(D)
        for di in range(S):
            d = alld[di]
            N[d] += 1
            _wdi = allw[di]
            sumν_kv[:,_wdi] += ν[di, :]
            sumν_dk[d,:] += ν[di, :]

        sumν_k = np.sum(sumν_kv, axis=1)
        
        for di in range(S):
            d = alld[di]
            Nd = N[d]
            _wdi = allw[di]
            p = ( expdigamma(β[_wdi] + sumν_kv[:,_wdi] - ν[di,:])
                / expdigamma(β0 + sumν_k - ν[di,:])
                * expdigamma(α + sumν_dk[d, :] - ν[di,:])
                #/ expdigamma(α0 + Nd - 1)
                )
            #             _zdi = draw_cat_prop(   (α[:] + c_dk[d,:]) * (β[_wdi] + c_kv[:,_wdi]) / (β0 + c_k[:]))
            #p = np.exp(p)
            p = p / np.sum(p)
            νnew[di,:] = p
            
        return νnew, sumν_kv


    md('... running collapsed variational inference')
    for loop in tqdm(range(LOOPS)):
        ν, sumν_kv = loopit(ν, alld, allw)
                
    md('We obtained the following topics:')
    φ = sumν_kv + β[None,:]
    print(φ.shape)
    #plt.imshow(φ)
    #plt.show()
    for k in range(K):
        txt = 'Topic '+str(k)+': '
        words = np.argsort(φ[k,:])[::-1]
        txt += ' / '.join([v_to_text[v] for v in words[:20]])
        md(txt)
    
lda_collapsed_vi_numba(n_topics)

... initializing

... running collapsed variational inference

100%|██████████| 112/112 [00:02<00:00, 46.65it/s]


We obtained the following topics:

(4, 4258)


Topic 0: charles / church / diana / queen / prince / royal / years / camilla / palace / take / parker / bowles / divorce / british / daily / britain / dresden / against / marriage / england

Topic 1: mother / teresa / heart / calcutta / sunday / sen / home / respirator / nun / hospital / doctors / nuns / tuesday / fever / saint / people / catholic / nursing / house / world

Topic 2: family / special / million / age / princess / reporters / tomorrow / changes / reforms / wales / head / first / bishop / services / help / died / anything / albanian-born / evening / state

Topic 3: order / sister / condition / told / prayers / charity / missionaries / woodlands / poor / peace / day / remained / vomiting / senior / nobel / slightly / dr / known / god / including

### NB: Numba+representation for Hacky Variational Inference [timing: 00:18.30]

> Topic 0: church / police / service / family / south / told / father / died / funeral / friday / us / national / home / miami / thursday / versace / went / body / held / cunanan

> Topic 1: u.s / harriman / british / prince / clinton / royal / churchill / party / war / president / minister / britain / paris / american / ambassador / prime / son / france / died / became

> Topic 2: charles / public / diana / queen / church / bowles / parker / camilla / newspaper / love / media / simpson / family / couple / marriage / divorce / woman / together / years / princess

> Topic 3: pope / church / vatican / john / catholic / paul / mass / during / cardinal / east / roman / surgery / world / bernardin / rome / pontiff / peace / trip / years / left

> Topic 4: mother / teresa / order / hospital / heart / doctors / work / charity / last / election / told / nuns / sister / official / calcutta / tuesday / missionaries / sunday / successor / since

> Topic 5: church / leader / people / years / three / time / say / last / including / reports / week / exhibition / saying / end / himself / former / go / under / around / although

> Topic 6: elvis / first / people / life / saturday / n't / told / fans / king / death / condition / age / every / mark / music / live / politics / called / own / lives

> Topic 7: president / yeltsin / political / russian / russia / country / michael / operation / kremlin / government / moscow / take / orthodox / power / communist / union / return / economic / museum / europe

> Topic 8: against / film / group / germany / french / people / france / show / director / rights / award / west / festival / art / made / country / german / international / magazine / conservative

> Topic 9: city / years / million / year / people / king / great / while / government / world / until / capital / century / old / southern / tour / past / set / made / air


----

----

# Exponential of digamma/ψ/F approximation (checking its quality)

In [13]:
def expdigamma_big(x):
    x = x - 0.5
    return x + 1/(24*x) - 37/(5760*x**3)

def expdigamma(x):
    #return expdigamma_big(x+1) / np.exp(1/x)
    return expdigamma_big(x+2) / np.exp(1/x + 1/(x+1))
    sub = 1
    while x <= 2:
        sub *= np.exp(1/x)
        x = x+1
    return expdigamma_big(x) / sub

x = np.linspace(.1, 10, 300)
ytrue = np.exp(digamma(x))
yapprox = np.array([expdigamma(x) for x in x])

###plt.plot(x, ytrue, label="true")
###plt.plot(x, expdigamma(x), label="approx")

#plt.plot(x, (ytrue - yapprox) / (ytrue + yapprox) * 2, label="rel ")
#plt.legend()