In [None]:
import os
import time
import pickle
import shelve
import os.path
import chainer

import numpy as np
import chainer.optimizers as O

from chainer import cuda
from lda2vec import utils
from lda2vec import LDA2Vec
from chainer import serializers
from lda2vec import prepare_topics, print_top_words_per_topic, topic_coherence

In [None]:
# Directory holding the preprocessed corpus artifacts
base_path = 'drive/My Drive/data-topic-modeling/'
# Paths of the individual artifact files
fn_vocab = base_path + 'vocab.pkl'
fn_corpus = base_path + 'corpus.pkl'
fn_flatnd = base_path + 'flattened.npy'
fn_docids = base_path + 'doc_ids.npy'
fn_vectors = base_path + 'vectors.npy'
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that were produced by a trusted preprocessing run.
# The `with` blocks make sure the file handles are closed after loading.
with open(fn_vocab, 'rb') as fh:
    vocab = pickle.load(fh)
with open(fn_corpus, 'rb') as fh:
    corpus = pickle.load(fh)
# Arrays saved with np.save
flattened = np.load(fn_flatnd)
doc_ids = np.load(fn_docids)
vectors = np.load(fn_vectors)

In [None]:
# GPU index comes from the CUDA_GPU environment variable; falls back to 0
gpu_id = int(os.environ.get('CUDA_GPU', 0))
# Make that device the current one for subsequent chainer CUDA calls
cuda.get_device(gpu_id).use()
print("Using GPU:" + str(gpu_id))

Using GPU:0


In [None]:
# Model parameters. Several of them can be overridden through environment
# variables of the same name; the second argument is the default.

# Sizes derived from the data: largest id plus one
n_docs = doc_ids.max() + 1          # number of documents
n_vocab = flattened.max() + 1       # number of unique words in the vocabulary

# 'Strength' of the Dirichlet prior; 200.0 seems to work well
clambda = 200.0
# Number of topics to fit
n_topics = int(os.environ.get('n_topics', 20))
# Number of tokens handed to the model per training step
batchsize = 4096
# Exponent applied for negative sampling
power = float(os.environ.get('power', 0.75))
# Initialize with pretrained word vectors (env value is parsed as an int flag)
pretrained = bool(int(os.environ.get('pretrained', True)))
# Sampling temperature
temperature = float(os.environ.get('temperature', 1.0))
# Number of dimensions in a single word vector
n_units = int(os.environ.get('n_units', 300))

# String representation for every compact key, truncated to the vocabulary size
words = corpus.word_list(vocab)[:n_vocab]

# Tokens per document: count how often each document id occurs, then scatter
# the counts into a dense array so ids that never occur keep a length of 0
doc_idx, lengths = np.unique(doc_ids, return_counts=True)
doc_lengths = np.zeros(n_docs, dtype=np.int32)
doc_lengths[doc_idx] = lengths

# Occurrences per token id, built the same way
tok_idx, freq = np.unique(flattened, return_counts=True)
term_frequency = np.zeros(n_vocab, dtype=np.int32)
term_frequency[tok_idx] = freq

In [None]:
# Debug aid: list every name in the current namespace whose printed value is
# short (under 100 characters) and does not contain '<', in sorted name order.
for key in sorted(locals()):
    val = locals()[key]
    if '<' in str(val) or len(str(val)) >= 100:
        continue
    print(key, val)

Out {}
_ 
__ 
___ 
__doc__ Automatically created module for IPython interactive environment
__loader__ None
__name__ __main__
__package__ None
__spec__ None
_dh ['/content']
_exit_code 0
_i1 !pip install pylda2vec
_i2 !python -m spacy download en_core_web_md
_i3 !pip install jellyfish
_oh {}
base_path drive/My Drive/data-topic-modeling/
batchsize 4096
clambda 200.0
doc_ids [    0     0     0 ... 11008 11008 11008]
doc_idx [    0     1     2 ... 11006 11007 11008]
doc_lengths [100  92 333 ... 115  63  50]
flattened [  10   38 1311 ...   50   50   50]
fn_corpus drive/My Drive/data-topic-modeling/corpus.pkl
fn_docids drive/My Drive/data-topic-modeling/doc_ids.npy
fn_flatnd drive/My Drive/data-topic-modeling/flattened.npy
fn_vectors drive/My Drive/data-topic-modeling/vectors.npy
fn_vocab drive/My Drive/data-topic-modeling/vocab.pkl
freq [105415 103788 100993 ...     30     30     29]
gpu_id 0
lengths [100  92 333 ... 115  63  50]
n_docs 11009
n_topics 20
n_units 300
n_vocab 5845
power 0.75

In [None]:
# Build the LDA2Vec model from the sizes and hyper-parameters defined above.
# `counts` receives the per-token frequencies; 15 samples are requested via
# n_samples (presumably negative samples per token -- confirm against lda2vec).
model = LDA2Vec(
    n_documents=n_docs,
    n_document_topics=n_topics,
    n_units=n_units,
    n_vocab=n_vocab,
    counts=term_frequency,
    n_samples=15,
    power=power,
    temperature=temperature,
)

In [None]:
# Restore a previous run if a checkpoint exists; otherwise optionally seed the
# word embeddings with the pretrained vectors.
if os.path.exists('lda2vec.hdf5'):
    print("Reloading from saved")
    serializers.load_hdf5("lda2vec.hdf5", model)
elif pretrained:
    # Only applied on a fresh start: copying the pretrained vectors after a
    # checkpoint was loaded would overwrite the embeddings just restored.
    # Rows beyond the vocabulary size are dropped from `vectors`.
    model.sampler.W.data[:, :] = vectors[:n_vocab, :]

In [None]:
# Move the model parameters onto the GPU before the optimizer is attached
model.to_gpu()

# Adam with its default settings, bound to the model
optimizer = O.Adam()
optimizer.setup(model)

# Gradient clipping hook with a threshold of 5.0
clip = chainer.optimizer.GradientClipping(threshold=5.0)
optimizer.add_hook(clip)

In [None]:
# Counters used by the training loop: `j` counts batches, `epoch` counts passes
j = 0
epoch = 0
# Share of all tokens that a single batch represents; the training loop
# multiplies the prior by it before back-propagating
fraction = float(batchsize) / flattened.shape[0]
# Persistent key/value store for per-epoch progress records
progress = shelve.open('progress.shelve')

In [None]:
# Main training loop: 100 passes over the corpus. Each pass first snapshots the
# current topics (printing their top words and writing a pyLDAvis-style
# archive), then runs one optimizer step per batch, and finally checkpoints
# the model to lda2vec.hdf5.

# Log line template; it never changes, so it is built once outside the loops
msg = ("J:{j:05d} E:{epoch:05d} L:{loss:1.3e} "
       "P:{prior:1.3e} R:{rate:1.3e}")

for epoch in range(100):
    # Copy the current parameters to host memory and derive the topic data
    data = prepare_topics(cuda.to_cpu(model.mixture.weights.W.data).copy(),
                          cuda.to_cpu(model.mixture.factors.W.data).copy(),
                          cuda.to_cpu(model.sampler.W.data).copy(),
                          words)
    top_words = print_top_words_per_topic(data)
    if j % 100 == 0 and j > 100:
        coherence = topic_coherence(top_words)
        # A dedicated loop variable is required here: reusing `j` would
        # overwrite the global batch counter and corrupt the logging cadence.
        for topic_id in range(n_topics):
            print(topic_id, coherence[(topic_id, 'cv')])
        kw = dict(top_words=top_words, coherence=coherence, epoch=epoch)
        progress[str(epoch)] = pickle.dumps(kw)
        # Flush to disk so the record survives an interrupted kernel
        progress.sync()
    data['doc_lengths'] = doc_lengths
    data['term_frequency'] = term_frequency
    # np.savez appends '.npz' to the name when it is missing
    np.savez('topics.pyldavis', **data)
    print(epoch)
    for d, f in utils.chunks(batchsize, doc_ids, flattened):
        t0 = time.time()
        model.cleargrads()
        # NOTE(review): fit_partial presumably back-propagates the
        # reconstruction loss itself and returns its value -- confirm in lda2vec.
        l = model.fit_partial(d.copy(), f.copy())
        if j % 500 == 0:
            print("after partial fitting:", l)
        # Scale the prior by the batch's share of the corpus before
        # back-propagating it
        prior = model.prior()
        loss = prior * fraction
        loss.backward()
        optimizer.update()
        # Bring the values to host memory for logging
        prior.to_cpu()
        loss.to_cpu()
        t1 = time.time()
        dt = t1 - t0
        # Nominal throughput in tokens per second (assumes a full batch)
        rate = batchsize / dt
        logs = dict(loss=float(l), epoch=epoch, j=j,
                    prior=float(prior.data), rate=rate)
        if j % 500 == 0:
            print(msg.format(**logs))
        j += 1
    # Checkpoint after every epoch so training can be resumed
    serializers.save_hdf5("lda2vec.hdf5", model)

Top words in topic 0 <SKIP> ide scsi megs windows meg card pc ram /
Top words in topic 1 sharks <SKIP> njd nhl ahl season games mvp league standings
Top words in topic 2 ripem encryption pgp <SKIP> ciphertext pem patent cipher cryptanalysis classified
Top words in topic 3 soldiers armenians refugees sharks armenian <SKIP> helicopter inhabitants agdam azeri
Top words in topic 4 sharks <SKIP> nhl innings goalie team scored players games leafs
Top words in topic 5 <SKIP> scsi ide megs meg card turbo drives ram sony
Top words in topic 6 <SKIP> bike bikes honda ride ide helmet riding car duo
Top words in topic 7 <SKIP> $ 10.00 shipping comics 1 games / 25.00 4
Top words in topic 8 <SKIP> islam islamic scholars christians god jesus resurrection jews bible
Top words in topic 9 <SKIP> christians god jesus resurrection soldiers davidians that islam islamic
Top words in topic 10 <SKIP> that msg christians jail you god i convince we
Top words in topic 11 <SKIP> files motif directory interface win