In [82]:
import numpy as np
import pandas as pd
from scipy.stats import poisson, gamma, dirichlet, multinomial
from tqdm import tqdm


In [83]:
!python preprocess.py

In [84]:
x_data = pd.read_csv('data.csv')


In [85]:
x_data = x_data.iloc[:,1:-1]

In [86]:
def initialize_params(x_data, K,alpha):
    N, V = x_data.shape
    topic_assignments = np.random.randint(0, K, size=N)
    topic_word_counts = np.zeros((K, V))
    topic_counts = np.zeros(K)
    lambdas = np.random.gamma(1.0, 1.0, size=(K, V))
    pis = dirichlet.rvs(alpha * np.ones(K), size=N)
    
    for i in range(N):
        topic = topic_assignments[i]
        topic_counts[topic] += 1
        topic_word_counts[topic] += x_data[i]

    return topic_assignments, topic_word_counts, topic_counts, lambdas, pis

In [87]:
gamma.rvs(a=[1,100], scale=1.0)

array([  0.90753515, 112.43588239])

In [88]:
def gibbs_sampling(x_data, K, alpha, beta, num_iters=1000):
    N, V = x_data.shape
    topic_assignments, topic_word_counts, topic_counts, lambdas, pis = initialize_params(x_data, K, alpha)

    for iteration in tqdm(range(num_iters)):
        for i in range(N):
            current_topic = topic_assignments[i]
            topic_counts[current_topic] -= 1
            topic_word_counts[current_topic] -= x_data[i]

            log_topic_probs = np.zeros(K)
            for k in range(K):
                log_topic_word_prob = np.sum(poisson.logpmf(x_data[i], mu=lambdas[k]))
                log_topic_prior_prob = np.log(pis[i, k])
                log_topic_probs[k] = log_topic_word_prob + log_topic_prior_prob

            # Subtract the max log probability to avoid numerical instability
            log_topic_probs = log_topic_probs - np.max(log_topic_probs)
            topic_probs = np.exp(log_topic_probs)
            topic_probs /= topic_probs.sum()

            # Sample new_topic from a multinomial distribution
            new_topic = np.random.choice(K, p=topic_probs)

            topic_assignments[i] = new_topic
            topic_counts[new_topic] += 1
            topic_word_counts[new_topic] += x_data[i]

            # Update lambdas for the new topic
            lambdas[new_topic] = gamma.rvs(a=topic_word_counts[new_topic] + beta, scale=1.0)

            # Update pi for the current document
            pis[i] = dirichlet.rvs(alpha + topic_counts)

    return topic_assignments, topic_word_counts, topic_counts, lambdas, pis

In [96]:
## Very Crude Pruning TODO Make this more reliable!

vocabs = []
with open('./vocab.txt', 'r') as f:
    for line in f:
        vocabs.append(line.strip())
        
len(vocabs)
del vocabs[8000:-1]
del vocabs[0:2000]
len(vocabs)
pruned_x = x_data[vocabs]

In [97]:

# Set parameters
K = 10  # Number of topics
alpha = 1.0  # Dirichlet prior parameter
beta = 1.0  # Gamma prior parameter
# num_iters = 100  # Number of iterations 
num_iters = 40 # Reduced Iterations 

### Take Pruned Value
# x_data_np = x_data.values
x_data_np = pruned_x.values


# Run the Gibbs sampler
print(x_data_np.shape)
topic_assignments, topic_word_counts, topic_counts, lambdas, pis = gibbs_sampling(x_data_np, K, alpha, beta, num_iters)


(2246, 6001)


100%|██████████| 40/40 [03:31<00:00,  5.28s/it]


In [98]:
top_word_index = {}
for idx, ld in enumerate(lambdas):
    top_word_index[idx] =  sorted(range(len(ld)), key=lambda x: ld[x])[-30:]
    top_word_index[idx].reverse()

In [99]:
toplist = pd.DataFrame()

In [100]:
for topic in top_word_index:
    # print(f"Topic K={topic}:   ", )
    tmp =[]
    for idx in top_word_index[topic]:
        tmp.append(vocabs[idx])
    toplist[topic] = tmp

In [101]:
toplist

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,provision,gunmen,sandinista,venus,ames,nelson,cable,venus,pacs,drinking
1,parole,reed,bcspehealth,carpenter,wedtech,macmillan,robertson,inheritance,yeutter,militants
2,fujimori,bushels,orion,maxwell,edgemont,peres,creek,creek,maxwell,stolen
3,castro,hezbollah,fatal,peres,manhattan,spacecraft,ceausescu,particular,cdy,hijackers
4,ceausescu,cloudy,sand,radar,actress,corroon,quake,miners,subsidies,andreas
5,walsh,andreas,symphony,infected,whites,exercise,verdict,fda,infected,terrorists
6,klerk,disney,device,abuse,taylor,maxwell,solid,robertson,macmillan,birth
7,policemen,inmates,activist,bird,mexican,venus,vargas,aristide,retailers,ballot
8,warmus,hijackers,employers,macmillan,grigoryants,electoral,rocket,veto,ruby,auction
9,convoy,drinking,species,welfare,polish,pennsylvania,apply,arrangement,trash,gunmen


In [95]:
vocabs

['years',
 'first',
 'police',
 'state',
 'states',
 'officials',
 'soviet',
 'united',
 'bush',
 'time',
 'three',
 'billion',
 'today',
 'national',
 'told',
 'american',
 'thursday',
 'federal',
 'house',
 'week',
 'court',
 'day',
 'tuesday',
 'made',
 'news',
 'wednesday',
 'monday',
 'friday',
 'say',
 'company',
 'city',
 'party',
 'just',
 'group',
 'york',
 'market',
 'report',
 'department',
 'military',
 'south',
 'union',
 'members',
 'home',
 'west',
 'political',
 'reported',
 'make',
 'going',
 'office',
 'get',
 'spokesman',
 'dont',
 'world',
 'like',
 'four',
 'think',
 'committee',
 'back',
 'work',
 'defense',
 'says',
 'country',
 'war',
 'congress',
 'nations',
 'foreign',
 'official',
 'public',
 'trade',
 'take',
 'prices',
 'month',
 'general',
 'economic',
 'five',
 'air',
 'money',
 'stock',
 'called',
 'found',
 'dukakis',
 'days',
 'campaign',
 'law',
 'months',
 'program',
 'case',
 'asked',
 'workers',
 'administration',
 'late',
 'business',
 'east',
 'm