<a href="https://colab.research.google.com/github/tanthongtan/ptm/blob/master/sam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hyperparameters

In [None]:
# --- experiment identity ---
num_topic = 100    # number of topics (rows of mu)
dataset = 'wiki'   # dataset name passed to load_data and print_summary
method = 'sam'     # method label passed to print_summary

# --- model hyperparameters ---
alpha_scalar = 50.0 / num_topic   # fill value for every entry of the alpha tensor
c0 = 1e3                          # handed to the joint distribution together with mu0
kappa1 = 1e4                      # scales the mean direction in the von Mises-Fisher terms
prior_mu = 'neg'                  # possible options: neg, pos, mean

# --- GMC hyperparameters ---
num_samples = 1     # iterations accumulated after burn-in
num_burn = 5000     # iterations run before accumulation starts
S = 25000           # minibatch size; also the chunk size of the training-set evaluation
L = 20              # constructor argument of GeodesicMonteCarlo
eta_theta = 0.1     # initial eta for theta's geodesic (decayed each iteration)
rho_theta = 0.1     # rho for theta's geodesic
eta_mu = 0.00025    # initial eta for mu's geodesic (decayed each iteration)
rho_mu = 0.1        # rho for mu's geodesic

# Run Geodesic Monte Carlo (GMC) Inference

In [None]:
#only for google colab
import sys
import os
# Colab bootstrap: runs only when the google.colab module is already loaded
# in this interpreter; on any other host the whole block is skipped.
if 'google.colab' in sys.modules:
    # Show which GPU this session was given.
    !nvidia-smi
    # Clone the project repository and switch the working directory into it.
    # NOTE(review): presumably this is what makes the local modules imported
    # further down (geodesic, dataset, distributions, utils) importable - confirm.
    !git clone https://github.com/tanthongtan/ptm.git
    %cd '/content/ptm'
    # Unpack the reference corpus unless a wiki_final directory already exists.
    # NOTE(review): the archive is read from a Google Drive path, which
    # presumably requires Drive to be mounted at /content/drive first - confirm.
    if not os.path.isdir('wiki_final'):
        !unzip -q "/content/drive/My Drive/wiki_final.zip"

import torch
import torch.nn.functional as F
from geodesic import GeodesicMonteCarlo
from dataset import load_data, csr_to_torchsparse
import geodesic as g
import distributions as D
from tqdm.notebook import tqdm
import torch.distributions as dist
import numpy as np
from utils import print_topics, get_topics, vmf_perplexity, clustering_metrics_20news, print_summary

# Create every new tensor on the GPU when CUDA is available, otherwise on the
# CPU. The default type is a float (single-precision) tensor in both cases.
# `gpu` records which of the two was chosen.
gpu = torch.cuda.is_available()
if gpu:
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    torch.set_default_tensor_type(torch.FloatTensor)

# Load the selected dataset, then build torch sparse versions of the test and
# training matrices (placed according to the `gpu` flag).
data_tr, data_te, vocab, vocab_size, num_tr = load_data(
    use_tfidf=True,
    sublinear=False,
    normalize=True,
    dataset=dataset,
)
tensor_te = csr_to_torchsparse(data_te, gpu)
tensor_tr = csr_to_torchsparse(data_tr, gpu)

# Declare tensor hyperparameters.
# alpha: (1, num_topic) tensor with every entry equal to alpha_scalar.
alpha = torch.full((1, num_topic), alpha_scalar)

# mu0: L2-normalised prior direction selected by prior_mu.
#   'neg'  -> constant -1 vector of length vocab_size, normalised
#   'pos'  -> constant +1 vector of length vocab_size, normalised
#   'mean' -> sum of tensor_tr over dim 0, densified and normalised
# The branches are mutually exclusive, and an unsupported value is rejected
# here instead of leaving mu0 undefined and failing later with a NameError.
if prior_mu == 'neg':
    mu0 = F.normalize(torch.full((vocab_size,), -1.0), dim=-1)
elif prior_mu == 'pos':
    mu0 = F.normalize(torch.full((vocab_size,), 1.0), dim=-1)
elif prior_mu == 'mean':
    mu0 = F.normalize(torch.sparse.sum(tensor_tr, dim=0).to_dense(), dim=-1)
else:
    raise ValueError(
        "unsupported prior_mu {!r}; expected 'neg', 'pos' or 'mean'".format(prior_mu)
    )

# Random starting point for the sampler.
# theta: standard-normal draws of shape (num_tr, num_topic - 1).
theta = torch.randn(num_tr, num_topic - 1)
# mu: Gaussian noise scaled by 1/sqrt(vocab_size) is added to mu0, then every
# row is rescaled to unit L2 norm.
mu = torch.randn(num_topic, vocab_size) / (vocab_size ** 0.5) + mu0
mu = F.normalize(mu, p=2, dim=-1)

# Declare GMC transition kernels.
kernel = GeodesicMonteCarlo(L)

# Sampled variables, their initial step sizes, and the geodesic used for each.
params = {'theta': theta, 'mu': mu}
init_etas = {'theta': eta_theta, 'mu': eta_mu}
geodesics = {
    'theta': g.RnGeodesic(eta=eta_theta, rho=rho_theta),
    'mu': g.SphericalGeodesic(eta=eta_mu, rho=rho_mu),
}

# Initial velocities: i.i.d. standard-normal noise with the same shape, dtype
# and device as each parameter, passed through that geodesic's projection.
# Drawing with randn_like avoids building a dense identity covariance of size
# (last_dim x last_dim) - vocab_size x vocab_size for mu - only to sample
# independent unit-variance coordinates; the distribution is the same.
vs = {
    name: geodesics[name].projection(params[name], torch.randn_like(params[name]))
    for name in params
}

#start sampling loop
t = tqdm(range(num_samples+num_burn))
# Running sums of the post-burn-in samples. They start as the int 0; the first
# `+=` with a tensor yields a fresh tensor, so later in-place writes to
# theta / mu do not alias the accumulators.
theta_samples = 0
mu_samples = 0
for i in t:
    # Pick a random minibatch of up to S training rows.
    idx = torch.randperm(num_tr)[:S]
    x_batch = csr_to_torchsparse(data_tr[idx.cpu()], gpu)

    # Swap the full theta / velocity tensors for their minibatch rows so the
    # kernel only sees (and updates) the rows that match x_batch.
    theta = params['theta']
    params['theta'] = theta[idx]
    v_theta = vs['theta']
    vs['theta'] = v_theta[idx]

    # Decay every step size as init_eta * (i+1)^(-1/5).
    for name in geodesics:
        geodesics[name].eta = init_etas[name] * ((i+1) ** (-1./5.))
    params, vs = kernel.stochastic_transition(params, vs, geodesics, D.SamJointDistributionWithStickDir(x_batch, alpha, c0, mu0, kappa1))

    # Scatter the updated minibatch rows back into the full tensors and
    # restore the full tensors in the dicts for the next iteration.
    theta[idx] = params['theta']
    v_theta[idx] = vs['theta']
    params['theta'] = theta
    vs['theta'] = v_theta

    mu = params['mu']

    # NaN != NaN, so this detects any NaN entry in mu. Stop sampling, and say
    # so: a silent break is indistinguishable from a finished run.
    if torch.any(mu != mu):
        print("\nNaN encountered in mu at iteration", i, "- stopping sampling")
        break

    # Accumulate samples once burn-in is over.
    if i >= num_burn:
        theta_samples += theta
        mu_samples += mu

    # Diagnostics every 100 iterations.
    if i % 100 == 0:
        print("\ncurrent iteration:", i)
        # Sum of row norms of mu, printed next to num_topic for comparison.
        print("mu norms", mu.norm(dim=-1).sum(), num_topic)
        # Per-row L2 norm of the element-wise squared entries of mu.
        print("sparsity",(torch.abs(mu)**2.).norm(dim=-1))
        print("sparsitymean",(torch.abs(mu)**2.).norm(dim=-1).mean())
        # Stick-breaking transform of theta; total of its row sums is printed
        # next to num_tr for comparison.
        pi = dist.StickBreakingTransform()(theta)
        print("pi sums", pi.sum(dim=-1).sum(), num_tr)

        # Training-set log likelihood and cosine similarity, evaluated in
        # consecutive chunks of S rows.
        sum_ll = 0.0
        sum_cs = 0.0
        for j in range(int(np.ceil(num_tr/S))):
            curr_pi = pi[j*S:j*S+S]
            curr_tensor_tr = csr_to_torchsparse(data_tr[j*S:j*S+S], gpu)
            # Normalised mixture of the rows of mu, weighted by curr_pi.
            curr_avg = F.normalize(torch.matmul(curr_pi,mu), p=2, dim=-1)
            sum_ll += D.log_prob_von_mises_fisher(kappa1 * curr_avg, curr_tensor_tr).sum()
            sum_cs += D.sparse_dense_dot(curr_tensor_tr, curr_avg).sum()

        print("log likelihood", sum_ll / num_tr)
        print("cosine similarity", sum_cs / num_tr)
        print("perplexity", vmf_perplexity(tensor_te, mu, kappa1, alpha, N=1000),"\n")

        # Mean dot product over all unordered pairs of distinct rows of mu.
        # One Gram matrix plus its strict upper triangle replaces a Python
        # double loop over num_topic*(num_topic-1)/2 pairs.
        gram = torch.matmul(mu, mu.t())
        count_cs = num_topic * (num_topic - 1) // 2
        sum_cs_spread = torch.triu(gram, diagonal=1).sum()
        print("mean cs spread", sum_cs_spread / count_cs,"\n")

    # Print the current topics every 1000 iterations.
    if i % 1000 == 0:
        emb = mu.cpu().numpy()
        print_topics(get_topics(emb,vocab))
        print("")

Wed Dec  9 12:24:27 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    25W / 300W |      0MiB / 16130MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

HBox(children=(FloatProgress(value=0.0, max=5001.0), HTML(value='')))


current iteration: 0
mu norms tensor(100.) 100
sparsity tensor([0.0231, 0.0678, 0.0278, 0.0196, 0.0267, 0.0213, 0.0241, 0.0316, 0.0232,
        0.0227, 0.0381, 0.0288, 0.0339, 0.0354, 0.0404, 0.0467, 0.0460, 0.0427,
        0.0420, 0.0417, 0.0511, 0.0476, 0.0455, 0.0489, 0.0490, 0.0465, 0.0495,
        0.0486, 0.0508, 0.0486, 0.0478, 0.0505, 0.0482, 0.0467, 0.0461, 0.0449,
        0.0463, 0.0468, 0.0447, 0.0409, 0.0434, 0.0443, 0.0456, 0.0402, 0.0428,
        0.0402, 0.0405, 0.0411, 0.0405, 0.0378, 0.0362, 0.0375, 0.0351, 0.0334,
        0.0329, 0.0343, 0.0336, 0.0327, 0.0308, 0.0298, 0.0298, 0.0288, 0.0302,
        0.0268, 0.0286, 0.0257, 0.0268, 0.0234, 0.0236, 0.0231, 0.0246, 0.0236,
        0.0223, 0.0227, 0.0207, 0.0213, 0.0196, 0.0193, 0.0182, 0.0178, 0.0175,
        0.0182, 0.0171, 0.0160, 0.0158, 0.0152, 0.0151, 0.0145, 0.0147, 0.0144,
        0.0142, 0.0140, 0.0140, 0.0136, 0.0136, 0.0136, 0.0134, 0.0133, 0.0132,
        0.0131])
sparsitymean tensor(0.0316)
pi sums tensor(170

# Get Topic Coherence

In [None]:
# The accumulators start as the int 0 and only become tensors once a
# post-burn-in sample has been added. If the sampling loop stopped before
# that point, fail here with a clear message instead of an opaque error
# further down.
if not isinstance(mu_samples, torch.Tensor) or not isinstance(theta_samples, torch.Tensor):
    raise RuntimeError(
        "no post-burn-in samples were collected; rerun the sampling cell before evaluating"
    )

# Average the accumulated samples.
# NOTE(review): the averaged mu is not re-normalised here; with more than one
# sample its rows may no longer have unit norm - confirm this is intended.
mu_final = mu_samples / num_samples
theta_final = theta_samples / num_samples

# Held-out perplexity of the averaged topics.
print("final perplexity", vmf_perplexity(tensor_te, mu_final, kappa1, alpha, N=1000))

# Extract the topics from the averaged mu and print the summary for this
# method / dataset.
emb = mu_final.cpu().numpy()
topics = get_topics(emb, vocab)
print('prior_mu:', prior_mu)
print_summary(topics,method,dataset)

# Clustering metrics are only computed for the 20news dataset, from the
# stick-breaking transform of the averaged theta.
if dataset == '20news':
    pi = dist.StickBreakingTransform()(theta_final)
    pi = pi.cpu().numpy()
    clustering_metrics_20news(pi)

final perplexity tensor(-50083.5391)
prior_mu: neg

Method  = sam
Number of topics = 100
Dataset = wiki 

 NPMI       TU         Topic
 0.06257    0.61167    town castle century village population district centre local river railway
 0.14715    0.75000    film films directed movie director festival released starring plot tamil
 0.20300    0.75000    album released songs track tracks albums studio release listing recorded
 0.19833    0.85000    church catholic pope st christian jesus churches saint bishop christ
 0.18090    0.51667    club team scored cup goal football goals career match fc
 0.12780    0.82500    bank tax market financial rate income credit price value investment
 0.20165    0.73333    university college campus students research science professor education institute faculty
 0.22137    0.88333    football nfl yards bowl yard touchdown touchdowns quarterback coach tackles
 0.25180    0.90000    aircraft air flight force wing fighter pilot flying squadron aviation
 0.1628