In [2]:
from matplotlib import pyplot as plt
from scipy import stats
import numpy as np
import pandas as pd

In [3]:
class CAVI():
    def __init__(self, data):
        self.data = data

        # For PI (PI ~ Dirichlet)
        self.pi_a = 0.5

        # For Lambda (Lambda ~ Gamma)
        self.gamma_a = 1
        self.gamma_b = 1

        self.elbo = []

    def init_params(self):
        self.pi = np.random.dirichlet([self.pi_a] * self.data.shape[1])
        self.gamma = np.random.gamma(self.gamma_a, 1 / self.gamma_b, size=self.data.shape[1])
        self.z = np.random.multinomial(1, self.pi, size=self.data.shape[0])

    def update(self):
        self.update_pi()
        self.update_gamma()
        self.update_z()
    
    def update_pi(self):
        self.pi = (self.z.sum(axis=0) + self.pi_a - 1) / (self.data.shape[0] + self.pi_a + self.pi.sum() - 1)
    
    def update_gamma(self):
        self.gamma = self.gamma_a + self.z.sum(axis=0)

    def update_z(self):
        for i in range(self.data.shape[0]):
            self.z[i] = np.random.multinomial(1, self.pi * stats.poisson.pmf(self.data[i], self.gamma))

    def calculate_elbo(self):
        logpx = np.sum(self.z * np.log(stats.poisson.pmf(self.data, self.gamma).T))
        logpz = np.sum(self.z * np.log(self.pi))
        logppi = stats.dirichlet(self.pi_a * np.ones(self.data.shape[1])).logpdf(self.pi)
        logpgamma = np.sum(stats.gamma(self.gamma_a, scale=1 / self.gamma_b).logpdf(self.gamma))
        logqz = np.sum(self.z * np.log(self.z))
        logqpi = stats.dirichlet(self.pi_a * np.ones(self.data.shape[1])).logpdf(self.pi)
        logqgamma = np.sum(stats.gamma(self.gamma_a, scale=1 / self.gamma_b).logpdf(self.gamma))
        return logpx + logpz + logppi + logpgamma - logqz - logqpi - logqgamma


    def fit(self, max_iter=100):
        self.init_params()
        for i in range(max_iter):
            self.update()
            self.elbo.append(self.calculate_elbo())
            if self.elbo[-1] - self.elbo[-2] < 1e-5:
                break


In [4]:
data = pd.DataFrame()
vocabs = []
with open('./vocab.txt', 'r') as f:
    for line in f:
        vocabs.append(line.split()[0])

rows = []
count = []
with open('./ap.dat', 'r') as f:
    for line in f:
        tmp = line.split(' ')
        count.append(tmp[0])
        tmp = tmp[1:]
        row = [0] * len(vocabs)
        for elem in tmp:
            index, value = elem.split(':')
            row[int(index)] = int(value)
        rows.append(row)
            
x_data = pd.DataFrame(rows, columns = vocabs)

In [6]:
cavi = CAVI(x_data.values)
cavi.fit(max_iter=100)

ValueError: pvals < 0, pvals > 1 or pvals contains NaNs