# L3: Topic Models
### 732A92/TDDE16 Text Mining
Måns Magnusson

The purpose of this lab is to implement the standard collapsed Gibbs sampling algorithm for Latent Dirichlet Allocation in Python. You will be supplied with starter code, a small corpus of State of the Union addresses from 1975 to 2000 (one document per paragraph), and a list of English stop words. The code is implemented as a class, `LDAGibbs`, and you are expected to replace central parts of it with your own implementations.

### 1. Sampling

Implement the basic collapsed Gibbs sampling algorithm for Latent Dirichlet Allocation. Use the starter code and add the missing components (the sampler). We use the fact that

$$p(z_{i}=k\mid\mathbf{z}_{-i},\mathbf{w})\propto\left(\alpha+n_{d,k}^{(d)}\right)\frac{\beta+n_{k,w_{i}}^{(w)}}{\sum_{v=1}^{V}\left(\beta+n_{k,v}^{(w)}\right)}=\left(\alpha+n_{d,k}^{(d)}\right)\frac{\beta+n_{k,w_{i}}^{(w)}}{V\beta+n_{k}}$$

to simplify computations, where $K$ is the number of topics, $V$ is the vocabulary size and $D$ is the number of documents. $\mathbf{n}^{(d)}$ is a $D\times K$ count matrix holding the number of topic indicators per document $d$ and topic $k$; $\mathbf{n}^{(w)}$ is a $K\times V$ count matrix holding the number of topic indicators per topic $k$ and word type $w$; and $\mathbf{n}$ is a vector of length $K$ containing the total number of topic indicators in each topic, so that $n_k=\sum_{v=1}^{V}n^{(w)}_{k,v}$ (this is what turns the sum in the denominator into $V\beta+n_k$). In the formula, $d$ is the document containing token $i$, and all counts are computed with token $i$ removed. The detailed algorithm can be found below:

__Data:__ tokenized corpus $\mathbf{w}$, priors $\alpha, \beta$ <br>
__Result:__ topic indicators $\mathbf{z}$

Init topic indicators $\mathbf{z}$ randomly per token<br>
Init topic probability vector $\mathbf{p}$<br>
Init $\mathbf{n}^{(w)}$, the topic-by-word-type count matrix of size ($K \times V$), from $\mathbf{z}$<br>
Init $\mathbf{n}^{(d)}$, the document-by-topic count matrix of size ($D \times K$), from $\mathbf{z}$<br>
Init $\mathbf{n}$, the topic count vector of length $K$, from $\mathbf{z}$<br>

for $g \leftarrow 1$ __to__ _num_\__iterations_ __do__<br>
&emsp;&emsp;// Iterate over all tokens<br>
&emsp;&emsp;for $i \leftarrow 1$ __to__ $N$ __do__<br>
&emsp;&emsp;&emsp;&emsp;// Remove the current topic indicator $z_i$ from $\mathbf{n}^{(w)}$, $\mathbf{n}^{(d)}$ and $\mathbf{n}$<br>
&emsp;&emsp;&emsp;&emsp;$n^{(w)}_{z_i,w_i}$ -= 1, $n^{(d)}_{d_i,z_i}$ -= 1, $n_{z_i}$ -= 1<br>
&emsp;&emsp;&emsp;&emsp;for $k \leftarrow 1$ __to__ $K$ __do__<br>
&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;// Compute the unnormalized probability of each topic indicator<br>
&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;$\mathbf{p}_k \leftarrow \left(\alpha+n_{d,k}^{(d)}\right)\frac{\left(\beta+n_{k,w_{i}}^{(w)}\right)}{\left(V\beta+n_{k}\right)}$<br>
&emsp;&emsp;&emsp;&emsp;__end__<br>
&emsp;&emsp;&emsp;&emsp;// Sample the topic indicator<br>
&emsp;&emsp;&emsp;&emsp;$z_i \leftarrow $ Categorical($\mathbf{p}$)<br>
&emsp;&emsp;&emsp;&emsp;// Add the new topic indicator $z_i$ to $\mathbf{n}^{(w)}$, $\mathbf{n}^{(d)}$ and $\mathbf{n}$<br>
&emsp;&emsp;&emsp;&emsp;$n^{(w)}_{z_i,w_i}$ += 1, $n^{(d)}_{d_i,z_i}$ += 1, $n_{z_i}$ += 1<br>
&emsp;&emsp;__end__<br>
__end__

For a complete derivation of the collapsed Gibbs sampler for LDA, see https://lingpipe.files.wordpress.com/2010/07/lda3.pdf.
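The pseudocode can be sketched as a self-contained NumPy program on a toy corpus of integer word IDs, independent of the lab's `LDAGibbs` class. The names `init_counts` and `gibbs_sweep` are hypothetical, and $\mathbf{n}^{(w)}$ is stored as $K\times V$ as in the text (the starter code's `term_topics` is its transpose). The inner loop over $k$ is replaced by one vectorized expression; the per-token order of remove, sample, add is unchanged.

```python
import numpy as np

def init_counts(docs, K, V, rng):
    """Random initial topic per token and the counts n^(w) (K x V), n^(d) (D x K), n (K)."""
    z = [rng.integers(K, size=len(doc)) for doc in docs]
    n_w = np.zeros((K, V))
    n_d = np.zeros((len(docs), K))
    n = np.zeros(K)
    for d, (doc, z_d) in enumerate(zip(docs, z)):
        for w, k in zip(doc, z_d):
            n_w[k, w] += 1
            n_d[d, k] += 1
            n[k] += 1
    return z, n_w, n_d, n

def gibbs_sweep(docs, z, n_w, n_d, n, alpha, beta, rng):
    """One pass of collapsed Gibbs sampling over all tokens; updates z and counts in place."""
    K, V = n_w.shape
    for d, doc in enumerate(docs):
        for i, w in enumerate(doc):
            k = z[d][i]
            # Remove the current indicator z_i from n^(w), n^(d) and n
            n_w[k, w] -= 1
            n_d[d, k] -= 1
            n[k] -= 1
            # Unnormalized p(z_i = k | z_-i, w) for all K topics at once
            p = (alpha + n_d[d]) * (beta + n_w[:, w]) / (V * beta + n)
            # Sample the new indicator from Categorical(p / sum(p))
            k = rng.choice(K, p=p / p.sum())
            z[d][i] = k
            # Add the new indicator back to the counts
            n_w[k, w] += 1
            n_d[d, k] += 1
            n[k] += 1

rng = np.random.default_rng(0)
docs = [[0, 1, 2, 0], [3, 4, 3, 5], [0, 2, 1], [4, 5, 3]]  # toy corpus, V = 6 word types
z, n_w, n_d, n = init_counts(docs, K=2, V=6, rng=rng)
for g in range(50):
    gibbs_sweep(docs, z, n_w, n_d, n, alpha=0.1, beta=0.01, rng=rng)
print([z_d.tolist() for z_d in z])
```

Because each count is decremented just before computing `p` and incremented again right after sampling, the counts always match $\mathbf{z}$ and never go negative, so no clamping such as `max(0, ...)` is needed.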

```python
import numpy, random, scipy.special
from tm3 import LDAGibbs
import matplotlib.pyplot as plt

class MyGibbs(LDAGibbs):

    def __init__(self, num_topics, docs_file_name, stop_list_file_name=None):
        self.num_topics = num_topics
        self.num_docs = 0
        self.docs = []
        ## Prepare set of stop words
        self.stop_words = set()
        if stop_list_file_name is not None:
            with open(stop_list_file_name) as f:
                for line in f:
                    word = line.rstrip()
                    self.stop_words.add(word)
        self.read_documents(docs_file_name)
        self.initialize_matrices()
        self.total_tokens = sum(self.doc_totals)

    def read_documents(self, filename):
        """Reads documents from a file, filters stop words and initializes
        the vocabulary. Also converts tokens to integer term IDs."""
        self.vocab = []
        self.vocab_ids = {}
        with open(filename) as f:
            for line in f:
                line = line.replace(".", " ").replace(",", " ").lower()
                self.num_docs += 1
                tokens = []
                for w in line.split():
                    if w not in self.stop_words:
                        if w in self.vocab_ids:
                            tokens.append(self.vocab_ids[w])
                        else:
                            term_id = len(self.vocab)
                            self.vocab.append(w)
                            self.vocab_ids[w] = term_id
                            tokens.append(term_id)
                self.docs.append({'tokens': tokens})
        self.num_terms = len(self.vocab)
        print("Read {} documents with a total of {} terms".format(self.num_docs, self.num_terms))

    def initialize_matrices(self):
        """Initializes numpy arrays for the matrix computations performed
        by the sampler during the MCMC process."""
        ## Set up numpy matrices (term_topics is n^(w) transposed: V x K)
        self.term_topics = numpy.zeros((self.num_terms, self.num_topics)) # n^w
        self.doc_topics = numpy.zeros((self.num_docs, self.num_topics))   # n^d
        self.topic_totals = numpy.zeros(self.num_topics)                  # n
        self.doc_totals = numpy.zeros(self.num_docs)
        ## Initialize topics randomly
        for doc_id in range(self.num_docs):
            doc = self.docs[doc_id]
            ## Create an array of random topic assignments
            doc['topics'] = [random.randrange(self.num_topics) for token in doc['tokens']]
            ## Construct the initial summary statistics
            for token, topic in zip(doc['tokens'], doc['topics']):
                self.term_topics[token][topic] += 1 # n_wk
                self.doc_topics[doc_id][topic] += 1 # n_dk
                self.topic_totals[topic] += 1       # n_k
                self.doc_totals[doc_id] += 1

    def run(self, num_iterations=50, alpha=0.1, beta=0.01):
        self.logprobs = []
        for iteration in range(num_iterations):
            self.make_draw(alpha, beta)
            logprob = self.compute_logprob(alpha, beta)
            self.logprobs.append(logprob)
            print("iteration {}, {}".format(iteration, logprob))

    def make_draw(self, alpha, beta):
        """One sweep of collapsed Gibbs sampling over every token in the corpus."""
        V = self.num_terms
        for doc_id, doc in enumerate(self.docs):
            for i, token in enumerate(doc['tokens']):
                # Remove the current topic indicator z_i from n^(w), n^(d) and n
                topic = doc['topics'][i]
                self.term_topics[token, topic] -= 1
                self.doc_topics[doc_id, topic] -= 1
                self.topic_totals[topic] -= 1
                # Unnormalized probability of every topic k at once (vectorized over k)
                p = ((alpha + self.doc_topics[doc_id]) * (beta + self.term_topics[token])
                     / (V * beta + self.topic_totals))
                # Sample the new topic indicator from Categorical(p)
                topic = int(numpy.random.choice(self.num_topics, p=p / p.sum()))
                doc['topics'][i] = topic
                # Add the new topic indicator to n^(w), n^(d) and n
                self.term_topics[token, topic] += 1
                self.doc_topics[doc_id, topic] += 1
                self.topic_totals[topic] += 1

    def print_topics(self, j):
        ## TODO: implement this function for exercise 2
        super().print_topics(j)

    def plot(self):
        ## TODO: implement this function for exercise 3
        super().plot()

    def compute_logprob(self, alpha, beta):
        ## TODO: implement this function for the bonus exercise
        return super().compute_logprob(alpha, beta)
```

Implement the `make_draw` function above. Your results should be very similar to those obtained with the parent class's implementation.

```python
num_topics = 10
num_iterations = 10

model = MyGibbs(num_topics, 'sotu_1975_2000.txt', 'stoplist_en.txt')
model.run(num_iterations)
```

```
Read 2898 documents with a total of 8695 terms
```

### 2. Top terms
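Implement the `print_topics` function to extract, for each topic $k$, the `j` largest counts in row $k$ of $\mathbf{n}^{(w)}$. The corresponding word types are the most probable words in that topic. Note that the starter code stores this matrix transposed, as `term_topics` with shape $V\times K$, so topic $k$ is column $k$ there.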
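One way to do this with NumPy is to sort each topic's counts in descending order. The sketch below assumes, as in the starter code's `initialize_matrices`, that the counts are stored as `term_topics` with shape $V\times K$ (topic $k$ is column $k$); the function name `top_terms` and the toy vocabulary are hypothetical.

```python
import numpy as np

def top_terms(term_topics, vocab, j):
    """For each topic (column), return the j word types with the largest counts."""
    # Sort every column in descending order of count and keep the first j row indices
    order = np.argsort(-term_topics, axis=0, kind="stable")[:j]
    return [[vocab[v] for v in order[:, k]] for k in range(term_topics.shape[1])]

vocab = ["tax", "jobs", "peace", "war", "budget"]
term_topics = np.array([[5, 0],
                        [3, 1],
                        [0, 4],
                        [1, 6],
                        [4, 0]], dtype=float)  # V = 5 word types x K = 2 topics
result = top_terms(term_topics, vocab, 3)
for k, words in enumerate(result):
    print("Topic {}: {}".format(k, " ".join(words)))
```

Inside `MyGibbs`, `print_topics(self, j)` could call this helper with `self.term_topics` and `self.vocab`. Ranking by raw counts gives the same order as ranking by the estimated word probabilities, since these are an increasing function of the counts within a topic.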

```python
model.print_topics(10)
```

### 3. Explore the data

Run your implementation on the State of the Union corpus until convergence with 10 topics; don't forget to remove stop words. Plot the log marginal posterior against the number of iterations. How many iterations do you need until convergence? How do you interpret the topics?

[Hint: You can use the `plot` function to plot the log marginal posterior for each iteration. To display the plot in Jupyter, run `%matplotlib inline` before plotting for the first time.]

```python
%matplotlib inline
model.plot()
```

#### Answer here:

### 4. Simulate a new State of the Union speech

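Write a function `new_speech` that uses the estimated topic-word probabilities $\Phi$ (the $K\times V$ matrix whose row $k$ is the word distribution $\phi_k$ of topic $k$) from your `MyGibbs` model, fitted with stop words removed, to simulate a new State of the Union speech. Start by simulating topic proportions $\theta_d \sim \mathrm{Dirichlet}(\alpha)$ with $\alpha = 0.5$, then simulate each word of the document by drawing a topic from $\theta_d$ and a word from that topic's distribution. Does the result make sense? Why or why not?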
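A sketch of the generative step, assuming $\Phi$ is estimated by the posterior mean given the current $\mathbf{z}$, $\hat\phi_{k,v}=(n^{(w)}_{k,v}+\beta)/(n_k+V\beta)$, and that the function receives the raw arrays (`term_topics` of shape $V\times K$ and `vocab`, as in the starter code) rather than the model object. The name `simulate_speech` and the toy data are hypothetical.

```python
import numpy as np

def simulate_speech(term_topics, vocab, alpha, beta, num_words, rng):
    """theta ~ Dir(alpha); then for each word z ~ Cat(theta) and w ~ Cat(phi_z)."""
    V, K = term_topics.shape
    # Posterior-mean estimate of Phi (K x V): (n_kv + beta) / (n_k + V * beta)
    phi = (term_topics.T + beta) / (term_topics.sum(axis=0)[:, None] + V * beta)
    theta = rng.dirichlet(np.full(K, alpha))        # topic proportions of the new speech
    topics = rng.choice(K, size=num_words, p=theta)  # one topic indicator per word
    return " ".join(vocab[rng.choice(V, p=phi[k])] for k in topics)

vocab = ["tax", "jobs", "peace", "war", "budget"]
term_topics = np.array([[5, 0], [3, 1], [0, 4], [1, 6], [4, 0]], dtype=float)
speech = simulate_speech(term_topics, vocab, alpha=0.5, beta=0.01,
                         num_words=12, rng=np.random.default_rng(1))
print(speech)
```

With the fitted model, `new_speech(model, alpha, num_words)` could call this with `model.term_topics` and `model.vocab`, using the same `beta` that was passed to `run`. Since words are drawn independently given their topics, the output is a bag of words with no syntax, which is worth keeping in mind when answering whether it makes sense.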

```python
def new_speech(model, alpha, num_words):
    speech = []
    return " ".join(speech)
```

```python
new_speech(model, 0.5, 100)
```

#### Answer here:

### Bonus assignment:

To get a better understanding of how to implement the underlying model or similar models, you may want to implement your own function to compute the log marginal posterior. If so, implement the `compute_logprob` function.

$$\begin{align}
\log p(\mathbf{z}\mid\mathbf{w}) =& \log p(\mathbf{w}\mid\mathbf{z},\beta)+\log p(\mathbf{z}\mid\alpha)+C \\
=& \sum_{k=1}^{K}\log\left[\frac{\Gamma\left(\sum_{v=1}^{V}\beta\right)}{\prod_{v=1}^{V}\Gamma\left(\beta\right)}\frac{\prod_{v=1}^{V}\Gamma\left(n_{kv}^{(w)}+\beta\right)}{\Gamma\left(\sum_{v=1}^{V}\left(n_{kv}^{(w)}+\beta\right)\right)}\right]+\sum_{d=1}^{D}\log\left[\frac{\Gamma\left(\sum_{k=1}^{K}\alpha\right)}{\prod_{k=1}^{K}\Gamma\left(\alpha\right)}\frac{\prod_{k=1}^{K}\Gamma\left(n_{dk}^{(d)}+\alpha\right)}{\Gamma\left(\sum_{k=1}^{K}\left(n_{dk}^{(d)}+\alpha\right)\right)}\right]+C \\
=& K\log\Gamma\left(V\beta\right)-KV\log\Gamma\left(\beta\right)+\sum_{k=1}^{K}\sum_{v=1}^{V}\log\Gamma\left(n_{kv}^{(w)}+\beta\right)-\sum_{k=1}^{K}\log\Gamma\left(n_{k}+V\beta\right)\\
&+ D\log\Gamma\left(K\alpha\right)-DK\log\Gamma\left(\alpha\right)+\sum_{d=1}^{D}\sum_{k=1}^{K}\log\Gamma\left(n_{dk}^{(d)}+\alpha\right)-\sum_{d=1}^{D}\log\Gamma\left(\sum_{k=1}^{K}n_{dk}^{(d)}+K\alpha\right)+C
\end{align}$$

where $C=-\log p(\mathbf{w})$ does not depend on $\mathbf{z}$ and can be ignored when monitoring convergence.

In Python, use `scipy.special.gammaln` for $\log\Gamma(x)$ (if you run into problems, you might try `math.lgamma` instead).
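A vectorized sketch of the last line of the formula, dropping the additive constant $-\log p(\mathbf{w})$, which does not depend on $\mathbf{z}$. It assumes symmetric scalar priors and count arrays shaped as in the formula ($\mathbf{n}^{(w)}$ as $K\times V$, $\mathbf{n}^{(d)}$ as $D\times K$). Since the starter code stores `term_topics` as $V\times K$, a `compute_logprob` method would pass `self.term_topics.T` and `self.doc_topics`. The function name `log_marginal_posterior` is hypothetical.

```python
import numpy as np
from scipy.special import gammaln

def log_marginal_posterior(n_w, n_d, alpha, beta):
    """log p(z | w) up to the constant -log p(w), for symmetric alpha and beta.

    n_w: K x V topic-word counts, n_d: D x K document-topic counts.
    """
    K, V = n_w.shape
    D = n_d.shape[0]
    # Topic-word part: sum over topics of Dirichlet-multinomial log terms
    topic_part = (K * gammaln(V * beta) - K * V * gammaln(beta)
                  + gammaln(n_w + beta).sum()
                  - gammaln(n_w.sum(axis=1) + V * beta).sum())   # n_k + V*beta
    # Document-topic part: sum over documents
    doc_part = (D * gammaln(K * alpha) - D * K * gammaln(alpha)
                + gammaln(n_d + alpha).sum()
                - gammaln(n_d.sum(axis=1) + K * alpha).sum())
    return topic_part + doc_part

n_w = np.array([[2, 0, 1, 0], [0, 3, 0, 1], [1, 1, 1, 1]], dtype=float)  # K=3, V=4
n_d = np.array([[3, 1, 2], [0, 3, 2]], dtype=float)                       # D=2, K=3
lp = log_marginal_posterior(n_w, n_d, alpha=0.1, beta=0.01)
print(lp)
```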