In [187]:
from matplotlib import pyplot as plt
from scipy.stats import multinomial,gamma,dirichlet,poisson
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.special import psi  
from scipy.special import digamma, gammaln

In [259]:
class CAVI():
    def __init__(self, data,pi_a,gamma_a,gamma_b,K):

        # Number of Topics
        self.K = K

        # For Data & Size
        self.data = data
        self.N, self.V = data.shape

        # For PI (PI ~ Dirichlet)
        self.pi_a = pi_a

        # For Lambda (Lambda ~ Gamma)
        self.gamma_a = gamma_a
        self.gamma_b = gamma_b

        self.elbo = []

    def init_params(self):
        self.pis = dirichlet.rvs([self.pi_a] * self.K).flatten()
        self.lambdas = gamma.rvs(self.gamma_a, 1 / self.gamma_b, size=(self.K,self.V))
        self.Z = np.zeros((self.N, self.K))
        for i in range(self.N):
            self.Z[i] = multinomial.rvs(n=1, p=self.pis)
        self.expected_counts = np.dot(self.Z.T, self.data)
        print('----initialize random variables finished----')
        #print(f'PIS: {self.pis}, Lambdas: {self.lambdas}, Z:{self.Z}')

    def update(self):
        self.update_z()
        self.expected_counts = np.dot(self.Z.T, self.data)
        self.update_pi()
        print(self.pis)
        self.update_lambda()
        #for k in range(self.K):
        #    print(self.lambdas[k])
        #print('----update parameters finished----')
        
    
    def update_pi(self):
        self.pis = np.sum(self.Z, axis=0) + self.pi_a
        self.pis = self.pis / np.sum(self.pis)
    
    def update_lambda(self):
        for k in range(self.K):
            self.lambdas[k] = (self.expected_counts[k] + self.gamma_a) / (self.gamma_b + np.sum(self.Z[:, k]))
            #print(self.expected_counts[k],np.sum(self.Z[:, k]))

    def update_z(self):

        for i in range(self.N):  # for each document
            E_Z_ln_pi = psi(self.Z[i]) + psi(self.pis) - psi(np.sum(self.pis))
            E_X_ln_lambda_minus_lambda_minus_ln_X_fact = np.dot(self.data[i], (psi(self.lambdas.T) - np.log(self.lambdas.T)))

            ln_q = E_Z_ln_pi + E_X_ln_lambda_minus_lambda_minus_ln_X_fact

            self.Z[i] = np.exp(ln_q - np.max(ln_q))
            self.Z[i] /= np.sum(self.Z[i])  # normalize
        self.Z = np.eye(self.Z.shape[1])[np.argmax(self.Z, axis=1)]
     

    def calculate_elbo(self):
        expected_log_likelihood = self.compute_log_likelihood()
        expected_entropy = self.compute_entropy()
        return expected_log_likelihood - expected_entropy
    
    
    
    def compute_log_likelihood(self):
        exp_poisson = np.sum(poisson.logpmf(self.data,np.dot(self.Z,self.lambdas)))
        exp_gamma = np.sum(gamma.logpdf(self.lambdas, self.gamma_a, scale= 1/self.gamma_b))
        exp_multinomial = np.sum(multinomial.logpmf(self.Z, n=1, p=self.pis))
        #print(f'{self.pis}')
        exp_dirichlet = dirichlet.logpdf(self.pis, np.ones(self.K)*self.pi_a)
        
        return exp_poisson + exp_multinomial + exp_dirichlet + exp_gamma
    
    def compute_entropy(self):
        exp_log_q_Z = np.sum(multinomial.logpmf(self.Z, n=1, p=self.pis))
        exp_log_q_pi = dirichlet.logpdf(self.pis, np.ones(self.K)*self.pi_a)
        exp_log_q_lambda = np.sum(gamma.logpdf(self.lambdas, self.gamma_a, scale=1/self.gamma_b))

        return exp_log_q_Z + exp_log_q_pi + exp_log_q_lambda

    def fit(self, max_iter=100):
        self.init_params()
        print('----start iteration----')
        for i in tqdm(range(max_iter),total=max_iter,desc='VI',ncols=100,ascii=' =',leave=True):
            self.update()
            self.elbo.append(self.calculate_elbo())
            if (len(self.elbo) >= 2) and (np.abs(self.elbo[-1] - self.elbo[-2]) < 1e-6):
                print(i)
                break
        print('----finish iteration----')


In [32]:
data = pd.DataFrame()
vocabs = []
with open('./vocab.txt', 'r') as f:
    for line in f:
        vocabs.append(line.split()[0])

rows = []
count = []
with open('./ap.dat', 'r') as f:
    for line in f:
        tmp = line.split(' ')
        count.append(tmp[0])
        tmp = tmp[1:]
        row = [0] * len(vocabs)
        for elem in tmp:
            index, value = elem.split(':')
            row[int(index)] = int(value)
        rows.append(row)
            
x_data = pd.DataFrame(rows, columns = vocabs)

In [159]:
vocabs = []
with open('./vocab.txt', 'r') as f:
    for line in f:
        vocabs.append(line.strip())
        
tmp = x_data.sum()
tmp = np.array(tmp)
to_be_pruned = []
for idx, elem in enumerate(tmp):
    if elem > 1000 or elem < 10:
        to_be_pruned.append(idx)

for i in sorted(to_be_pruned, reverse=True):
    del vocabs[i]
    
pruned_x = x_data[vocabs]

In [160]:
pruned_x.sum()

officials      1000
soviet          999
united          998
bush            949
time            948
               ... 
chicken          10
homosexuals      10
rocked           10
locate           10
frohnmayer       10
Length: 7262, dtype: int64

In [264]:
gamma_a = 1
gamma_b = 1
pi_a = 1
K = 8

In [265]:
# cavi = CAVI(data=x_data.values,pi_a=pi_a,gamma_a=gamma_a,gamma_b=gamma_b,K=K)
cavi = CAVI(data=pruned_x.values,pi_a=pi_a,gamma_a=gamma_a,gamma_b=gamma_b,K=K)

In [266]:
cavi.fit(max_iter=100)

----initialize random variables finished----
----start iteration----


VI:   1%|                                                           | 1/100 [00:10<17:21, 10.52s/it]

1
----finish iteration----





In [118]:
cavi.lambdas.shape

(8, 898)

In [267]:
top_word_index = {}
for idx, ld in enumerate(cavi.lambdas):
    top_word_index[idx] =  sorted(range(len(ld)), key=lambda x: ld[x])[-30:]
    top_word_index[idx].reverse()

toplist = pd.DataFrame()
for topic in top_word_index:
    # print(f"Topic K={topic}:   ", )
    tmp =[]
    for idx in top_word_index[topic]:
        tmp.append(vocabs[idx])
    toplist[topic] = tmp

In [268]:
toplist

Unnamed: 0,0,1,2,3,4,5,6,7
0,bush,united,dukakis,officials,bush,today,bush,soviet
1,soviet,south,today,national,soviet,house,house,united
2,officials,military,bush,time,market,court,children,officials
3,national,officials,told,three,united,south,available,billion
4,united,tuesday,national,billion,time,american,parents,time
5,time,house,federal,tuesday,american,three,national,three
6,told,say,three,united,union,officials,today,bush
7,wednesday,mrs,united,american,monday,time,time,told
8,billion,told,time,party,made,united,number,today
9,made,soviet,york,soviet,week,friday,billion,federal
