In [175]:
# Load the IPython autoreload extension so that edited local modules can be
# re-imported later with %autoreload.
%load_ext autoreload
from lda_transformer import LDATransformer
from sklearn.datasets import fetch_20newsgroups
# Raw 20 Newsgroups corpus, shuffled with a fixed seed and with headers,
# footers and quoted replies removed.
# NOTE(review): in the cells visible here `dataset` is only referenced from
# commented-out code — confirm whether this fetch is still needed.
dataset = fetch_20newsgroups(shuffle=True, random_state=1,
                             remove=('headers', 'footers', 'quotes'))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [227]:
import pickle
import urllib
import urllib.request
from os import makedirs
from os.path import exists, join

import numpy as np
import scipy
import scipy.sparse

# Base URL of the preprocessed 20 Newsgroups files; the URL embeds a commit
# hash of the source repository.
ROOT_PATH = "https://github.com/akashgit/autoencoding_vi_for_topic_models/raw/9db556361409ecb3a732f99b4ef207aeb8516f83/data/20news_clean"
# File name pattern of one data split; "{split}" is filled in by
# newsgroups_dataset() with the split name.
FILE_TEMPLATE = "{split}.txt.npy"
# Local directory handed to download() as the cache location for vocab.pkl.
data_dir = 'data'

def download(directory, filename):
    """Return the local path of ``filename``, downloading it first if needed.

    The file is looked up as ``join(directory, filename)``. If it already
    exists there, nothing is downloaded. Otherwise ``directory`` is created
    (when missing) and the file is fetched from ``ROOT_PATH``.

    Args:
        directory: Local directory used as the download cache.
        filename: Name of the file, both under ``ROOT_PATH`` and inside
            ``directory``.

    Returns:
        The local file path ``join(directory, filename)``.
    """
    filepath = join(directory, filename)
    if exists(filepath):
        return filepath
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the same directory.
    makedirs(directory, exist_ok=True)
    # URLs always use "/" as separator; os.path.join would insert a backslash
    # on Windows and produce an invalid URL.
    url = "%s/%s" % (ROOT_PATH.rstrip("/"), filename)
    print("Downloading %s to %s" % (url, filepath))
    urllib.request.urlretrieve(url, filepath)
    return filepath


# The vocabulary is loaded only after download() has been defined, so this
# cell also works on a fresh kernel (Restart & Run All).
# NOTE(review): pickle.load can execute arbitrary code contained in the file.
# The file comes from the URL in ROOT_PATH; only keep this if that source is
# trusted.
with open(download(data_dir, "vocab.pkl"), "rb") as f:
    words_to_idx = pickle.load(f)

def newsgroups_dataset(directory, split_name, num_words):
    """Load one 20 Newsgroups split as a sparse bag-of-words count matrix.

    Args:
        directory: Local cache directory, forwarded to ``download``.
        split_name: Value substituted for ``{split}`` in ``FILE_TEMPLATE``
            (this notebook uses ``'train'`` and ``'test'``).
        num_words: Number of columns of the result, i.e. the vocabulary size.

    Returns:
        A SciPy CSR matrix of dtype ``float32`` with shape
        ``(num_documents, num_words)``. Entry ``(d, w)`` is the number of
        times word id ``w`` occurs in document ``d``.
    """
    path = download(directory, FILE_TEMPLATE.format(split=split_name))
    # NOTE(review): allow_pickle=True unpickles objects stored in the file and
    # can execute arbitrary code; only keep this for a trusted download source.
    documents = np.load(path, encoding='bytes', allow_pickle=True)
    # The last row is empty in both train and test.
    documents = documents[:-1]

    num_documents = documents.shape[0]
    print(f'num_documents: {num_documents}')
    print(f'num_words: {num_words}')

    # Each row is a list of word ids in the document. Collect one
    # (document index, word id) coordinate per word occurrence. Two flat lists
    # stay valid even when no document contains a word, where slicing a 2-D
    # index array would fail.
    row_ids = []
    col_ids = []
    for doc_idx, word_ids in enumerate(documents):
        row_ids.extend([doc_idx] * len(word_ids))
        col_ids.extend(word_ids)
    row_ids = np.asarray(row_ids, dtype=np.int64)
    col_ids = np.asarray(col_ids, dtype=np.int64)

    # Repeated coordinates are summed when the COO matrix is converted to CSR,
    # which turns the per-occurrence ones into word counts. CSR allows fast
    # querying of documents (rows).
    counts = scipy.sparse.coo_matrix(
        (np.ones(row_ids.shape[0]), (row_ids, col_ids)),
        shape=(num_documents, num_words),
        dtype=np.float32)
    return counts.tocsr()
# Training split as a sparse document-by-word matrix. Uses data_dir (the same
# cache directory as the vocabulary load) instead of repeating the literal.
tr = newsgroups_dataset(data_dir, 'train', len(words_to_idx))

num_documents: 11258
num_words: 1995


In [233]:
# Test split, used as validation data by the training cell. Uses data_dir
# (the same cache directory as the vocabulary load) instead of repeating the
# literal.
val = newsgroups_dataset(data_dir, 'test', len(words_to_idx))

num_documents: 7487
num_words: 1995


In [242]:
# Show the repr of the validation matrix (shape, dtype, stored elements).
val

<7487x1995 sparse matrix of type '<class 'numpy.float32'>'
	with 410358 stored elements in Compressed Sparse Row format>

In [234]:
# %autoreload
# t = LDATransformer()
# t.fit(dataset['data'])
# X = t.transform(dataset['data'])

In [235]:
# doc, row = zip(*X)

In [236]:
# max(row)

In [237]:
# len(set(doc))

In [238]:
# len(set(row))

In [239]:
# len(X)

In [None]:
# Re-import modules that changed on disk (e.g. the local `lda` module) before
# using them.
%autoreload
from lda import LDA
import tensorflow as tf
# NOTE(review): tf.logging is the TensorFlow 1.x logging API — presumably this
# notebook targets TF 1.x; confirm before upgrading TensorFlow.
tf.logging.set_verbosity(tf.logging.INFO)

# g = tf.Graph()
# with g.as_default():
# Fit an LDA model with 50 topics on the training matrix, passing the test
# split as validation data.
# NOTE(review): no random seed is passed here and the logged config shows
# '_tf_random_seed': None, so runs are presumably not reproducible — confirm.
m = LDA(n_topics=50, learning_rate=3e-4)
m.fit(tr, X_val=val)

W0811 08:34:33.794562 4417594816 lda.py:414] Deleting old log directory at /tmp/lda/
I0811 08:34:33.803490 4417594816 estimator.py:209] Using config: {'_model_dir': '/tmp/lda/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 10000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x2ebae39d0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
I0811 08:34:33.850023 

I0811 08:35:11.178234 4417594816 basic_session_run_hooks.py:260] loss = 540.8634, step = 3401 (0.948 sec)
I0811 08:35:12.151253 4417594816 basic_session_run_hooks.py:692] global_step/sec: 102.592
I0811 08:35:12.152925 4417594816 basic_session_run_hooks.py:260] loss = 715.9923, step = 3501 (0.975 sec)
I0811 08:35:13.124464 4417594816 basic_session_run_hooks.py:692] global_step/sec: 102.754
I0811 08:35:13.126034 4417594816 basic_session_run_hooks.py:260] loss = 564.7883, step = 3601 (0.973 sec)
I0811 08:35:14.102807 4417594816 basic_session_run_hooks.py:692] global_step/sec: 102.213
I0811 08:35:14.104517 4417594816 basic_session_run_hooks.py:260] loss = 616.4719, step = 3701 (0.978 sec)
I0811 08:35:15.075431 4417594816 basic_session_run_hooks.py:692] global_step/sec: 102.814
I0811 08:35:15.077633 4417594816 basic_session_run_hooks.py:260] loss = 393.18628, step = 3801 (0.973 sec)
I0811 08:35:16.121706 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.5777
I0811 08:35:16.1232

I0811 08:35:54.878182 4417594816 basic_session_run_hooks.py:260] loss = 409.30875, step = 7601 (1.109 sec)
I0811 08:35:55.955450 4417594816 basic_session_run_hooks.py:692] global_step/sec: 92.6002
I0811 08:35:55.958425 4417594816 basic_session_run_hooks.py:260] loss = 420.68973, step = 7701 (1.080 sec)
I0811 08:35:56.977673 4417594816 basic_session_run_hooks.py:692] global_step/sec: 97.8259
I0811 08:35:56.979330 4417594816 basic_session_run_hooks.py:260] loss = 518.0114, step = 7801 (1.021 sec)
I0811 08:35:57.989398 4417594816 basic_session_run_hooks.py:692] global_step/sec: 98.8405
I0811 08:35:57.991311 4417594816 basic_session_run_hooks.py:260] loss = 648.2583, step = 7901 (1.012 sec)
I0811 08:35:59.036250 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.5247
I0811 08:35:59.037374 4417594816 basic_session_run_hooks.py:260] loss = 635.4867, step = 8001 (1.046 sec)
I0811 08:36:00.079560 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.849
I0811 08:36:00.0811

elbo
-596.42816

kl
0.00032807313

loss
596.4841

perplexity
1153.8445

reconstruction
-596.4278

global_step
10000




I0811 08:36:23.186531 4417594816 estimator.py:1147] Done calling model_fn.
I0811 08:36:23.188687 4417594816 basic_session_run_hooks.py:541] Create CheckpointSaverHook.
I0811 08:36:23.308595 4417594816 monitored_session.py:240] Graph was finalized.
I0811 08:36:23.311723 4417594816 saver.py:1280] Restoring parameters from /tmp/lda/model.ckpt-10000
I0811 08:36:23.398283 4417594816 session_manager.py:500] Running local_init_op.
I0811 08:36:23.419626 4417594816 session_manager.py:502] Done running local_init_op.
I0811 08:36:23.925636 4417594816 basic_session_run_hooks.py:606] Saving checkpoints for 10000 into /tmp/lda/model.ckpt.
I0811 08:36:24.365496 4417594816 basic_session_run_hooks.py:262] loss = 495.42853, step = 10001
I0811 08:36:25.407444 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.9377
I0811 08:36:25.409293 4417594816 basic_session_run_hooks.py:260] loss = 588.6762, step = 10101 (1.044 sec)
I0811 08:36:26.421411 4417594816 basic_session_run_hooks.py:692] global_st

I0811 08:37:03.980357 4417594816 basic_session_run_hooks.py:692] global_step/sec: 100.054
I0811 08:37:03.981637 4417594816 basic_session_run_hooks.py:260] loss = 507.4223, step = 13901 (0.998 sec)
I0811 08:37:04.974192 4417594816 basic_session_run_hooks.py:692] global_step/sec: 100.62
I0811 08:37:04.975417 4417594816 basic_session_run_hooks.py:260] loss = 387.39078, step = 14001 (0.994 sec)
I0811 08:37:05.971054 4417594816 basic_session_run_hooks.py:692] global_step/sec: 100.315
I0811 08:37:05.972491 4417594816 basic_session_run_hooks.py:260] loss = 618.4244, step = 14101 (0.997 sec)
I0811 08:37:06.958670 4417594816 basic_session_run_hooks.py:692] global_step/sec: 101.254
I0811 08:37:06.960695 4417594816 basic_session_run_hooks.py:260] loss = 524.05084, step = 14201 (0.988 sec)
I0811 08:37:07.905459 4417594816 basic_session_run_hooks.py:692] global_step/sec: 105.619
I0811 08:37:07.906853 4417594816 basic_session_run_hooks.py:260] loss = 292.97815, step = 14301 (0.946 sec)
I0811 08:37:0

I0811 08:37:46.354999 4417594816 basic_session_run_hooks.py:692] global_step/sec: 96.9873
I0811 08:37:46.357136 4417594816 basic_session_run_hooks.py:260] loss = 414.1195, step = 18101 (1.032 sec)
I0811 08:37:47.352155 4417594816 basic_session_run_hooks.py:692] global_step/sec: 100.285
I0811 08:37:47.353559 4417594816 basic_session_run_hooks.py:260] loss = 1215.5198, step = 18201 (0.996 sec)
I0811 08:37:48.352243 4417594816 basic_session_run_hooks.py:692] global_step/sec: 99.9912
I0811 08:37:48.354239 4417594816 basic_session_run_hooks.py:260] loss = 469.1147, step = 18301 (1.001 sec)
I0811 08:37:49.361339 4417594816 basic_session_run_hooks.py:692] global_step/sec: 99.0985
I0811 08:37:49.363368 4417594816 basic_session_run_hooks.py:260] loss = 794.35876, step = 18401 (1.009 sec)
I0811 08:37:50.328892 4417594816 basic_session_run_hooks.py:692] global_step/sec: 103.353
I0811 08:37:50.330676 4417594816 basic_session_run_hooks.py:260] loss = 868.2528, step = 18501 (0.967 sec)
I0811 08:37:5

elbo
-593.57776

kl
2.2910603e-05

loss
593.63367

perplexity
1107.724

reconstruction
-593.57776

global_step
20000




I0811 08:38:09.185754 4417594816 estimator.py:1147] Done calling model_fn.
I0811 08:38:09.188213 4417594816 basic_session_run_hooks.py:541] Create CheckpointSaverHook.
I0811 08:38:09.304349 4417594816 monitored_session.py:240] Graph was finalized.
I0811 08:38:09.306898 4417594816 saver.py:1280] Restoring parameters from /tmp/lda/model.ckpt-20000
I0811 08:38:09.396704 4417594816 session_manager.py:500] Running local_init_op.
I0811 08:38:09.417471 4417594816 session_manager.py:502] Done running local_init_op.
I0811 08:38:09.918674 4417594816 basic_session_run_hooks.py:606] Saving checkpoints for 20000 into /tmp/lda/model.ckpt.
I0811 08:38:10.344127 4417594816 basic_session_run_hooks.py:262] loss = 489.34753, step = 20001
I0811 08:38:11.332593 4417594816 basic_session_run_hooks.py:692] global_step/sec: 101.112
I0811 08:38:11.333951 4417594816 basic_session_run_hooks.py:260] loss = 688.89624, step = 20101 (0.990 sec)
I0811 08:38:12.287664 4417594816 basic_session_run_hooks.py:692] global_s

I0811 08:38:49.641997 4417594816 basic_session_run_hooks.py:692] global_step/sec: 105.492
I0811 08:38:49.644459 4417594816 basic_session_run_hooks.py:260] loss = 1075.2699, step = 23901 (0.948 sec)
I0811 08:38:50.598524 4417594816 basic_session_run_hooks.py:692] global_step/sec: 104.545
I0811 08:38:50.600356 4417594816 basic_session_run_hooks.py:260] loss = 448.86438, step = 24001 (0.956 sec)
I0811 08:38:51.553316 4417594816 basic_session_run_hooks.py:692] global_step/sec: 104.735
I0811 08:38:51.555409 4417594816 basic_session_run_hooks.py:260] loss = 414.18622, step = 24101 (0.955 sec)
I0811 08:38:52.566509 4417594816 basic_session_run_hooks.py:692] global_step/sec: 98.6982
I0811 08:38:52.567945 4417594816 basic_session_run_hooks.py:260] loss = 526.42114, step = 24201 (1.013 sec)
I0811 08:38:53.617048 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.1888
I0811 08:38:53.618710 4417594816 basic_session_run_hooks.py:260] loss = 431.11874, step = 24301 (1.051 sec)
I0811 08:3

I0811 08:39:32.708058 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.765
I0811 08:39:32.709293 4417594816 basic_session_run_hooks.py:260] loss = 521.0239, step = 28101 (1.044 sec)
I0811 08:39:33.807689 4417594816 basic_session_run_hooks.py:692] global_step/sec: 90.94
I0811 08:39:33.809469 4417594816 basic_session_run_hooks.py:260] loss = 660.8716, step = 28201 (1.100 sec)
I0811 08:39:34.807759 4417594816 basic_session_run_hooks.py:692] global_step/sec: 99.9923
I0811 08:39:34.809625 4417594816 basic_session_run_hooks.py:260] loss = 755.045, step = 28301 (1.000 sec)
I0811 08:39:35.850054 4417594816 basic_session_run_hooks.py:692] global_step/sec: 95.9427
I0811 08:39:35.851583 4417594816 basic_session_run_hooks.py:260] loss = 687.1467, step = 28401 (1.042 sec)
I0811 08:39:36.847403 4417594816 basic_session_run_hooks.py:692] global_step/sec: 100.266
I0811 08:39:36.849786 4417594816 basic_session_run_hooks.py:260] loss = 573.83356, step = 28501 (0.998 sec)
I0811 08:39:37.900

elbo
-593.3253

kl
7.657927e-06

loss
593.3813

perplexity
1103.8638

reconstruction
-593.3253

global_step
30000




I0811 08:39:55.818456 4417594816 estimator.py:1147] Done calling model_fn.
I0811 08:39:55.820935 4417594816 basic_session_run_hooks.py:541] Create CheckpointSaverHook.
I0811 08:39:55.945163 4417594816 monitored_session.py:240] Graph was finalized.
I0811 08:39:55.947604 4417594816 saver.py:1280] Restoring parameters from /tmp/lda/model.ckpt-30000
I0811 08:39:56.037698 4417594816 session_manager.py:500] Running local_init_op.
I0811 08:39:56.059741 4417594816 session_manager.py:502] Done running local_init_op.
I0811 08:39:56.549352 4417594816 basic_session_run_hooks.py:606] Saving checkpoints for 30000 into /tmp/lda/model.ckpt.
I0811 08:39:56.973101 4417594816 basic_session_run_hooks.py:262] loss = 623.6709, step = 30001
I0811 08:39:58.038998 4417594816 basic_session_run_hooks.py:692] global_step/sec: 93.7873
I0811 08:39:58.040432 4417594816 basic_session_run_hooks.py:260] loss = 544.70105, step = 30101 (1.067 sec)
I0811 08:39:58.997410 4417594816 basic_session_run_hooks.py:692] global_st

I0811 08:40:37.910223 4417594816 basic_session_run_hooks.py:692] global_step/sec: 88.6927
I0811 08:40:37.911984 4417594816 basic_session_run_hooks.py:260] loss = 944.24976, step = 33901 (1.127 sec)
I0811 08:40:38.969446 4417594816 basic_session_run_hooks.py:692] global_step/sec: 94.4085
I0811 08:40:38.971693 4417594816 basic_session_run_hooks.py:260] loss = 481.62778, step = 34001 (1.060 sec)
I0811 08:40:40.150248 4417594816 basic_session_run_hooks.py:692] global_step/sec: 84.6882
I0811 08:40:40.151797 4417594816 basic_session_run_hooks.py:260] loss = 1316.8501, step = 34101 (1.180 sec)
I0811 08:40:41.262739 4417594816 basic_session_run_hooks.py:692] global_step/sec: 89.8883
I0811 08:40:41.264142 4417594816 basic_session_run_hooks.py:260] loss = 690.1309, step = 34201 (1.112 sec)
I0811 08:40:42.399588 4417594816 basic_session_run_hooks.py:692] global_step/sec: 87.9635
I0811 08:40:42.401073 4417594816 basic_session_run_hooks.py:260] loss = 624.7378, step = 34301 (1.137 sec)
I0811 08:40: