<a href="https://colab.research.google.com/github/tanthongtan/ptm/blob/master/sam_geodesic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hyperparameters

In [0]:
# ---- model hyperparameters ----
num_topic = 20   # number of topics
alpha = 0.5      # scalar here; the inference cell expands it to a (1, num_topic) tensor
c0 = 1e3         # passed to the mu full-conditional distribution
kappa0 = 1e3     # passed to the mu full-conditional distribution
kappa1 = 3e3     # passed to both the theta and mu full-conditional distributions

# ---- GMC (geodesic Monte Carlo) sampler hyperparameters ----
num_samples = 50_000  # post-burn-in iterations whose draws are averaged
num_burn = 10_000     # burn-in iterations run before averaging starts
L = 5                 # first argument of SphericalGeodesicMonteCarlo (presumably steps per transition - confirm in geodesic.py)
eps = 2e-5            # second argument of SphericalGeodesicMonteCarlo (presumably the step size - confirm in geodesic.py)

# Run GMC (Geodesic Monte Carlo) Inference

In [0]:
import torch
import torch.nn.functional as F
from geodesic import SphericalGeodesicMonteCarlo
import dataset
import geodesic
import distributions as D
import time
import sys

#only for google colab
if 'google.colab' in sys.modules:
    !git clone https://github.com/tanthongtan/ptm.git
    %cd '/content/ptm'

#make all tensors cuda if available and double
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (torch.set_default_dtype / torch.set_default_device replace it).
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
    torch.set_default_tensor_type(torch.DoubleTensor)

#Load Data
data_tr, tensor_tr, vocab, vocab_size, num_tr = dataset.load_20news_diff()

#declare tensor hyperparameters
# NOTE: this rebinds the scalar `alpha` from the hyperparameter cell, so re-run
# that cell before re-running this one.
alpha = torch.full((1,num_topic), alpha)
mu0 = F.normalize(tensor_tr.sum(dim=0),dim=-1)

#randomly initialize model parameters (each row L2-normalized to unit length)
theta = F.normalize(torch.randn(num_tr, num_topic), p=2, dim=-1)
mu = F.normalize(torch.randn(num_topic, vocab_size), p=2, dim=-1)

#declare GMC transition kernels
kernel = SphericalGeodesicMonteCarlo(L, eps)

#start sampling loop: num_burn burn-in iterations, then num_samples iterations
#whose states are summed into theta_samples / mu_samples (averaged later).
real_start = time.time()
for i in range(num_samples+num_burn):
    start_time = time.time()
    theta=kernel.transition(theta, D.SamFullConditionalThetaDistribution(tensor_tr, mu, alpha, kappa1))
    mu = kernel.transition(mu, D.SamFullConditionalMuDistribution(tensor_tr, theta, c0, mu0, kappa0, kappa1))
    if i == num_burn:
        # Start the running sums from a COPY of the chain state. Binding the
        # accumulators to `theta`/`mu` directly would alias them: if a later
        # transition returns the same tensor object (e.g. a rejected
        # proposal), the in-place `+=` would double both the sum and the
        # live chain state. detach() is a no-op for tensors without autograd
        # history and keeps the sums from holding onto any graph otherwise.
        theta_samples = theta.detach().clone()
        mu_samples = mu.detach().clone()
    elif i > num_burn:
        theta_samples += theta.detach()
        mu_samples += mu.detach()
    if i % 50 == 0:
        print('accept prob:', geodesic.accept_prob,' current iter:', i, ' time per iter:', time.time() - start_time, ' time taken:', time.time()-real_start)

accept prob: tensor(0.9961)  current iter: 0  time per iter: 0.39033079147338867  time taken: 0.3905174732208252
accept prob: tensor(1.0118)  current iter: 50  time per iter: 0.09789228439331055  time taken: 5.327042818069458
accept prob: tensor(1.0118)  current iter: 100  time per iter: 0.09692788124084473  time taken: 10.248358011245728
accept prob: tensor(1.0157)  current iter: 150  time per iter: 0.0975198745727539  time taken: 15.183617353439331
accept prob: tensor(1.0157)  current iter: 200  time per iter: 0.09804177284240723  time taken: 20.117740154266357
accept prob: tensor(1.)  current iter: 250  time per iter: 0.09637618064880371  time taken: 25.046432971954346
accept prob: tensor(1.0157)  current iter: 300  time per iter: 0.09716367721557617  time taken: 30.00232243537903
accept prob: tensor(1.0645)  current iter: 350  time per iter: 0.10164475440979004  time taken: 34.93155813217163
accept prob: tensor(1.)  current iter: 400  time per iter: 0.0973057746887207  time taken: 

# Inspect Topic Coherence (top words per topic, labelled by keyword matches)

In [0]:
# Keyword lists used to attach a human-readable label to each learned topic.
# Every key is a label exactly five characters wide (padded with spaces,
# presumably so the printed labels line up). Several keywords are deliberate
# stems ('libert', 'palest') meant to match any word that starts with them.
associations = {
    'jesus': ['prophet', 'jesus', 'matthew', 'christ', 'worship', 'church'],
    'comp ': ['floppy', 'windows', 'microsoft', 'monitor', 'workstation', 'macintosh', 
              'printer', 'programmer', 'colormap', 'scsi', 'jpeg', 'compression'],
    'car  ': ['wheel', 'tire'],
    'polit': ['amendment', 'libert', 'regulation', 'president'],
    'crime': ['violent', 'homicide', 'rape'],
    'midea': ['lebanese', 'israel', 'lebanon', 'palest'],
    'sport': ['coach', 'hitter', 'pitch'],
    'gears': ['helmet', 'bike'],
    'nasa ': ['orbit', 'spacecraft'],
}
def identify_topic_in_line(line):
    """Return the labels from ``associations`` whose keywords occur in *line*.

    *line* is a whitespace-separated string of words. A label matches when one
    of its keywords is a prefix of some word in the line. Stems therefore still
    match ('libert' -> 'liberty', 'palest' -> 'palestinian'), but a keyword
    buried mid-word ('tire' in 'entire', 'rape' in 'grape') no longer yields a
    false label. Labels come back in ``associations`` order, each at most once.
    """
    words = line.split()
    return [
        topic
        for topic, keywords in associations.items()
        # str.startswith accepts a tuple of prefixes: one call per word
        if any(word.startswith(tuple(keywords)) for word in words)
    ]

def print_top_words(beta, feature_names, n_top_words=10):
    """Print each topic's highest-weighted words together with keyword labels.

    Args:
        beta: sequence of per-topic weight rows; each row must support
            ``argsort()`` (e.g. rows of a NumPy array).
        feature_names: indexable collection mapping a column index to its word.
        n_top_words: number of largest-weight words to show for every topic.

    For each topic this prints the labels found by ``identify_topic_in_line``
    (joined by '|') on one line, then the indented top words on the next.
    """
    print('---------------Printing the Topics------------------')
    for row in beta:
        # argsort is ascending; reverse it to rank column indices by descending weight
        ranked = row.argsort()[::-1][:n_top_words]
        words = " ".join(feature_names[idx] for idx in ranked)
        print('|'.join(identify_topic_in_line(words)))
        print('     {}'.format(words))
    print('---------------End of Topics------------------')

# Average of the topic vectors: mu_samples is the running sum of the
# num_samples post-burn-in draws of mu, so dividing gives their mean.
# detach() is a no-op for tensors without autograd history and keeps .numpy()
# from failing if the sampler ever returns grad-tracking tensors.
emb = (mu_samples / num_samples).detach().cpu().numpy()
# vocab appears to map word -> column index (the names are ordered by its
# values), so sorting its keys by value gives feature_names[j] == word of column j.
print_top_words(emb, sorted(vocab, key=vocab.get))