# Example for LDA (Latent Dirichlet Allocation)

In [1]:
using LinearAlgebra
using PyPlot, PyCall
using ArgParse
using Distributions

push!(LOAD_PATH, ".")
import LDA

# Example for the 20 Newsgroups dataset

In [14]:
"""
    newsgroup_example(; categories, K=length(categories), max_iter=100)

Run variational inference for LDA on a subset of scikit-learn's 20 Newsgroups
training split.

The documents of the requested `categories` are fetched through
`sklearn.datasets.fetch_20newsgroups`, turned into a document-by-term count matrix
with `CountVectorizer`, and passed to `LDA.VI` together with a prior built by
`LDA.LDAModel` whose `alpha` (D×K) and `beta` (K×V) arrays are filled with ones.

# Keywords
- `categories`: newsgroup names to load. Defaults to the four groups
  `"alt.atheism"`, `"soc.religion.christian"`, `"comp.graphics"`, `"sci.med"`.
- `K::Integer`: number of topics. Defaults to `length(categories)` (4 for the
  default categories).
- `max_iter::Integer`: iteration count handed to `LDA.VI`. Defaults to `100`.

# Returns
A tuple `(twenty_train, posterior, S_est)`: the object returned by
`fetch_20newsgroups` followed by the two values returned by `LDA.VI`
(their exact meaning is defined in the `LDA` module).

# Throws
- `ArgumentError` if `K` is not positive.
"""
function newsgroup_example(;
    categories=["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"],
    K::Integer=length(categories),
    max_iter::Integer=100,
)
    K > 0 || throw(ArgumentError("K must be positive, got $K"))

    datasets = pyimport("sklearn.datasets")
    fe_text = pyimport("sklearn.feature_extraction.text")
    np = pyimport("numpy")

    # Fixed random_state keeps the shuffled document order reproducible.
    twenty_train = datasets.fetch_20newsgroups(subset="train", categories=categories,
                                               shuffle=true, random_state=0)

    # Bag-of-words counts stored as 32-bit integers.
    cnt_vect = fe_text.CountVectorizer(dtype=np.int32)
    X_train_counts = cnt_vect.fit_transform(twenty_train["data"])

    # shape is (number of documents, vocabulary size).
    D, V = X_train_counts.shape
    alpha = ones(D, K)
    beta = ones(K, V)

    prior = LDA.LDAModel(D, K, V, alpha, beta)
    # toarray() densifies the count matrix before it is handed to LDA.VI.
    posterior, S_est = LDA.VI(X_train_counts.toarray(), prior, max_iter)

    return twenty_train, posterior, S_est
end

newsgroup_example (generic function with 1 method)

In [None]:
# Run the example and bind its three return values to top-level names.
twenty_train, posterior, S_est = newsgroup_example()