In [177]:
import torch
import torch.nn as nn
from torch.autograd import grad

from sklearn.datasets import fetch_20newsgroups

import gensim

from collections import Counter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load and preprocess data

In [314]:
# Load documents
documents = fetch_20newsgroups(subset='train', remove=('headers', 'footers',  'quotes'), shuffle=True)

# Preprocess with gensim and remove stopwords
documents = [[word for word in gensim.utils.simple_preprocess(document) if not word in gensim.parsing.preprocessing.STOPWORDS]
             for document in documents['data'][:100]]

# Remove rare words by applying a frequency threshold
freq_threshold = 5
counts = Counter([word for document in documents for word in document])

documents = [list(filter(lambda word : counts[word] > freq_threshold, document)) for document in documents]
documents = list(filter(lambda document : len(document) > 0, documents))

# Map words to integer indices
counts = Counter([word for document in documents for word in document])

word_idx_dict = {word : i for i, word in enumerate(counts.keys())}
idx_word_dict = {i : word for word, i in word_idx_dict.items()}

documents = [torch.tensor(list(map(lambda word : word_idx_dict[word], document))) for document in documents]

In [335]:
class VariationalLDA(nn.Module):
    
    def __init__(self, documents, vocab_size, num_topics, log_eta, log_alpha):
        
        super(VariationalLDA, self).__init__()
        
        self.D = len(documents)
        self.K = num_topics
        self.V = vocab_size
        
        self.words = documents
        
        k_ones = torch.ones(size=(self.K,)).to(device)
        
        self.log_eta = nn.Parameter(torch.tensor(log_eta).to(device))
        self.log_alpha = nn.Parameter((log_alpha * k_ones + torch.rand(self.K).to(device)).to(device))
        
        self.gamma = [torch.exp(log_alpha * k_ones) + k_ones * len(Counter(document)) / self.K
                      for document in documents]
        
        self.phi = [torch.ones(size=(document.shape[0], self.K)).to(device) / self.K
                    for document in documents]
        
        self.lamda = torch.ones(size=(self.K, self.V)).to(device)
    
    
    def elbo(self):
        
        self.eta = torch.exp(self.log_eta)
        self.alpha = torch.exp(self.log_alpha)
        
        elbo = 0
        
        eta_vector = self.eta * torch.ones(self.V).to(device)
        
        lamda_digamma_diff = self.digamma_difference(self.lamda)
        gamma_digamma_diff = [self.digamma_difference(gamma_d) for gamma_d in self.gamma]
        
        eta_log_gamma_diff = self.log_gamma_difference(eta_vector)
        alpha_log_gamma_diff = self.log_gamma_difference(self.alpha)
        gamma_log_gamma_diff = [self.log_gamma_difference(gamma_d) for gamma_d in self.gamma]
        lamda_log_gamma_diff = self.log_gamma_difference(self.lamda)
        
        elbo = elbo + sum([torch.einsum('nk, kn ->', phi_d, lamda_digamma_diff[:, words_d])
                           for phi_d, words_d in zip(self.phi, self.words)])
        
        elbo = elbo + sum([torch.einsum('nk, k ->', phi_d, gamma_digamma_diff_d)
                           for phi_d, gamma_digamma_diff_d in zip(self.phi, gamma_digamma_diff)])
        
        elbo = elbo + self.D * alpha_log_gamma_diff
        
        elbo = elbo + sum([torch.einsum('k, k ->', self.alpha - 1, gamma_digamma_diff_d)
                           for gamma_digamma_diff_d in gamma_digamma_diff])
        
        elbo = elbo + self.K * eta_log_gamma_diff
        
        elbo = elbo + torch.einsum('v, kv ->', eta_vector - 1, lamda_digamma_diff)
        
        elbo = elbo - sum([torch.sum(phi_d * torch.log(phi_d)) for phi_d in self.phi])
        
        elbo = elbo - sum(gamma_log_gamma_diff)
        
        elbo = elbo - sum([torch.einsum('k, k ->', gamma_d - 1, gamma_digamma_diff_d)
                           for gamma_d, gamma_digamma_diff_d in zip(self.gamma, gamma_digamma_diff)])
        
        elbo = elbo - torch.sum(lamda_log_gamma_diff)
        
        elbo = elbo - torch.einsum('kv, kv ->', self.lamda - 1, lamda_digamma_diff)
        
        return elbo
    
    
    def forward(self):
        return - self.elbo()
    
    
    def variational_parameter_step(self):
        pass
    
    
    def model_parameter_step(self):
        pass
    
    
    def perplexity(self):
        pass
    
    
    def digamma_difference(self, tensor):
        return torch.digamma(tensor) - torch.digamma(torch.sum(tensor, dim=-1)[..., None])
    
    
    def log_gamma_difference(self, tensor):
        return torch.lgamma(torch.sum(tensor, dim=-1)) - torch.sum(torch.lgamma(tensor), dim=-1)

In [336]:
lda = VariationalLDA(documents, len(counts), num_topics=5, log_eta=0., log_alpha=0.)

In [None]:
optimizer = torch.optim.SGD(lda.parameters(), lr=1e-5)

for i in range(1000):
    
    if i % 100 == 0:
        print(torch.exp(lda.log_eta).cpu().detach().numpy())
        
        print(torch.exp(lda.log_alpha).cpu().detach().numpy())

    optimizer.zero_grad()

    neg_elbo = lda.forward()

    neg_elbo.backward()

    optimizer.step()

1.0
[1.0951843 2.1340022 2.1017673 1.1532717 1.3771474]
1.0
[1.1744578 2.0392492 2.0164607 1.230205  1.4367872]
1.0
[1.2514659 1.9743441 1.9577085 1.3039268 1.4912012]
1.0
[1.3257719 1.9308741 1.9184316 1.3742497 1.5414073]
1.0
[1.3970369 1.9033227 1.8938441 1.4410876 1.5882629]
1.0
[1.4650337 1.8879035 1.8805801 1.504446  1.6324649]
1.0
[1.5296443 1.8819125 1.8761939 1.564409  1.6745648]
