In [82]:
from matplotlib import pyplot as plt
from scipy.stats import multinomial,gamma,dirichlet,poisson
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.special import psi  
from scipy.special import polygamma

In [136]:
class CAVI():
    def __init__(self, data,pi_a,gamma_a,gamma_b,K):

        # Number of Topics
        self.K = K

        # For Data & Size
        self.data = data
        self.N, self.V = data.shape

        # For PI (PI ~ Dirichlet)
        self.pi_a = pi_a

        # For Lambda (Lambda ~ Gamma)
        self.gamma_a = gamma_a
        self.gamma_b = gamma_b

        self.elbo = []

    def init_params(self):
        self.pis = dirichlet.rvs([self.pi_a] * self.K).flatten()
        self.lambdas = gamma.rvs(self.gamma_a, 1 / self.gamma_b, size=(self.K,self.V))
        self.Z = np.zeros((self.N, self.K))
        for i in range(self.N):
            self.Z[i] = multinomial.rvs(n=1, p=self.pis)

        print('----initialize random variables finished----')
        print(f'PIS: {self.pis.shape}, Lambdas: {self.lambdas.shape}, Z:{self.Z.shape}')

    def update(self):
        self.update_z()
        self.expected_counts = np.dot(self.Z.T, self.data)
        self.update_pi()
        
        self.update_lambda()
        print('----update parameters finished----')
    
    def update_pi(self):
        self.pis = np.sum(self.Z, axis=0) + self.pi_a
    
    def update_lambda(self):
        for k in range(self.K):
            self.lambdas[k] = self.expected_counts[k] + self.gamma_a

    def update_z(self):
        for i in range(self.N):
            for k in range(self.K):
                self.Z[i, k] = np.exp(psi(self.pis[k]) + np.sum(self.data[i] * (psi(self.lambdas[k]) - np.log(self.lambdas[k]))))
            self.Z[i] /= np.sum(self.Z[i])

    def calculate_elbo(self):
        expected_log_likelihood = self.compute_log_likelihood()
        expected_entropy = self.compute_entropy()
        return expected_log_likelihood + expected_entropy
    
    def compute_log_likelihood(self):
        exp_poisson = np.sum(poisson.logpmf(self.expected_counts,self.lambdas))
        exp_gamma = np.sum(gamma.logpdf(self.lambdas, self.gamma_a, scale= 1/self.gamma_b))
        exp_multinomial = np.sum(multinomial.logpmf(self.Z, n=1, p=self.pis))
        print(f'{exp_poisson},{exp_gamma},{exp_multinomial}')
        exp_dirichlet = dirichlet.logpdf(self.pis, [self.pi_a]*len(self.pis))
        
        return exp_poisson + exp_multinomial + exp_dirichlet + exp_gamma
    
    def compute_entropy(self):
        exp_log_q_Z = np.sum(multinomial.logpmf(self.Z, n=1, p=self.pis))
        exp_log_q_pi = dirichlet.logpdf(self.pis, [self.pi_a]*len(self.pis))
        exp_log_q_lambda = np.sum(gamma.logpdf(self.lambdas, self.gamma_a, scale=1/self.gamma_b))

        return exp_log_q_Z + exp_log_q_pi + exp_log_q_lambda

    def fit(self, max_iter=100):
        self.init_params()
        print('----start iteration----')
        for i in tqdm(range(max_iter),total=max_iter,desc='VI',ncols=100,ascii=' =',leave=True):
            self.update()
            self.elbo.append(self.calculate_elbo())
            if self.elbo[-1] - self.elbo[-2] < 1e-5:
                break
        


In [13]:
data = pd.DataFrame()
vocabs = []
with open('./vocab.txt', 'r') as f:
    for line in f:
        vocabs.append(line.split()[0])

rows = []
count = []
with open('./ap.dat', 'r') as f:
    for line in f:
        tmp = line.split(' ')
        count.append(tmp[0])
        tmp = tmp[1:]
        row = [0] * len(vocabs)
        for elem in tmp:
            index, value = elem.split(':')
            row[int(index)] = int(value)
        rows.append(row)
            
x_data = pd.DataFrame(rows, columns = vocabs)

In [109]:
gamma_a = 1
gamma_b = 1
pi_a = 0.5
K = 8

In [137]:
cavi = CAVI(data=x_data.values,pi_a=pi_a,gamma_a=gamma_a,gamma_b=gamma_b,K=K)

In [138]:
cavi.fit(max_iter=100)

----initialize random variables finished----
PIS: (8,), Lambdas: (8, 10473), Z:(2246, 8)
----start iteration----


VI:   0%|                                                                   | 0/100 [00:07<?, ?it/s]

----update parameters finished----
-inf,-519621.99999999994,nan





ValueError: Each entry in 'x' must be smaller or equal one.