In [2]:
import numpy as np
import pandas as pd
from scipy.stats import poisson, gamma, dirichlet, multinomial
from tqdm import tqdm


In [3]:
!python preprocess.py

In [4]:
x_data = pd.read_csv('data.csv')


In [5]:
x_data = x_data.iloc[:,1:-1]

In [6]:
def initialize_params(x_data, K,alpha):
    N, V = x_data.shape
    topic_assignments = np.random.randint(0, K, size=N)
    topic_word_counts = np.zeros((K, V))
    topic_counts = np.zeros(K)
    lambdas = np.random.gamma(1.0, 1.0, size=(K, V))
    pis = dirichlet.rvs(alpha * np.ones(K), size=N)
    
    for i in range(N):
        topic = topic_assignments[i]
        topic_counts[topic] += 1
        topic_word_counts[topic] += x_data[i]

    return topic_assignments, topic_word_counts, topic_counts, lambdas, pis

In [39]:
gamma.rvs(a=[1,100], scale=1.0)

array([6.76362861e-02, 8.91860559e+01])

In [7]:
def gibbs_sampling(x_data, K, alpha, beta, num_iters=1000):
    N, V = x_data.shape
    topic_assignments, topic_word_counts, topic_counts, lambdas, pis = initialize_params(x_data, K, alpha)

    for iteration in tqdm(range(num_iters)):
        for i in range(N):
            current_topic = topic_assignments[i]
            topic_counts[current_topic] -= 1
            topic_word_counts[current_topic] -= x_data[i]

            log_topic_probs = np.zeros(K)
            for k in range(K):
                log_topic_word_prob = np.sum(poisson.logpmf(x_data[i], mu=lambdas[k]))
                log_topic_prior_prob = np.log(pis[i, k])
                log_topic_probs[k] = log_topic_word_prob + log_topic_prior_prob

            # Subtract the max log probability to avoid numerical instability
            log_topic_probs = log_topic_probs - np.max(log_topic_probs)
            topic_probs = np.exp(log_topic_probs)
            topic_probs /= topic_probs.sum()

            # Sample new_topic from a multinomial distribution
            new_topic = np.random.choice(K, p=topic_probs)

            topic_assignments[i] = new_topic
            topic_counts[new_topic] += 1
            topic_word_counts[new_topic] += x_data[i]

            # Update lambdas for the new topic
            lambdas[new_topic] = gamma.rvs(a=topic_word_counts[new_topic] + beta, scale=1.0)

            # Update pi for the current document
            pis[i] = dirichlet.rvs(alpha + topic_counts)

    return topic_assignments, topic_word_counts, topic_counts, lambdas, pis

In [33]:
## Very Crude Pruning TODO Make this more reliable!

vocabs = []
with open('./vocab.txt', 'r') as f:
    for line in f:
        vocabs.append(line.strip())
        
len(vocabs)
del vocabs[8000:-1]
del vocabs[0:10]
len(vocabs)
pruned_x = x_data[vocabs]

In [35]:

# Set parameters
K = 10  # Number of topics
alpha = 1.0  # Dirichlet prior parameter
beta = 1.0  # Gamma prior parameter
# num_iters = 100  # Number of iterations 
num_iters = 20 # Reduced Iterations 

### Take Pruned Value
# x_data_np = x_data.values
x_data_np = pruned_x.values


# Run the Gibbs sampler
print(x_data_np.shape)
topic_assignments, topic_word_counts, topic_counts, lambdas, pis = gibbs_sampling(x_data_np, K, alpha, beta, num_iters)


(2246, 7991)


100%|██████████| 20/20 [02:12<00:00,  6.63s/it]


In [37]:
lambdas[1]

array([183.01910867, 157.76521785, 152.50389722, ...,   0.78867308,
         4.8711386 ,   0.4820547 ])

In [18]:
topic_word_counts

array([[  0.,   0.,   0., ...,   0.,   0.,   0.],
       [229., 252., 278., ...,   2.,   1.,   0.],
       [358., 268., 235., ...,   2.,   0.,   1.],
       ...,
       [295., 288., 202., ...,   0.,   0.,   1.],
       [114., 145., 148., ...,   1.,   2.,   0.],
       [250., 232., 299., ...,   0.,   0.,   1.]])