In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.compat import v1 as tf1
from tensorflow.keras import layers as tfkl
import pandas as pd

tfb = tfp.bijectors
tfd = tfp.distributions
tfk = tfp.math.psd_kernels

from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import sys
sys.path.append('../src/')
import correlated_topic_model as ctmd
import dynamic_correlated_topic_model as dctm
from sklearn import metrics, preprocessing

from imp import reload
import os
from scipy import sparse as sp

In [None]:
# The dataset can be downloaded from:
# https://www.kaggle.com/jbencina/department-of-justice-20092018-press-releases/data#

# This cell assumes the dataset has already been downloaded and extracted to
# ~/Downloads/combined.json.

# NOTE(review): `datasets` is presumably the project-local module found through
# the '../src/' entry added to sys.path in the import cell -- confirm.
import datasets
df, corpus, vocabulary = datasets.get_doj('~/Downloads/combined.json')

In [None]:
# Rescale the time covariate (`df.days`) to the interval [-1, 1].
scaler = preprocessing.MinMaxScaler([-1, 1])
# Convert to a NumPy array before adding the trailing axis: if `df.days` is a
# pandas Series, indexing it directly with `[:, None]` is deprecated and raises
# on recent pandas versions.
index_points = scaler.fit_transform(np.asarray(df.days)[:, None])

np.random.seed(42)
# Dense float64 copy of the corpus with an extra axis inserted before the
# last (vocabulary) axis.
X = np.expand_dims(corpus.todense().astype(np.float64), -2)

# Split documents and index points; the *_sorted variants are additionally
# returned because `return_sorted=True` is requested.
(X_tr, X_ts, index_tr, index_ts, X_tr_sorted, X_ts_sorted,
 index_tr_sorted, index_ts_sorted
) = datasets.train_test_split(X, index_points, return_sorted=True)

# Map scaled index points back to the scaler's original units and convert
# them with pandas.
inverse_transform_fn = lambda x: pd.to_datetime(
    scaler.inverse_transform(x)[:, 0], format='%Y-%m')

# One row per document (the singleton axis is dropped) plus a 'days' column.
df_train = pd.DataFrame(X_tr_sorted[:, 0, :])
df_train['days'] = inverse_transform_fn(index_tr_sorted)

df_test = pd.DataFrame(X_ts_sorted[:, 0, :])
df_test['days'] = inverse_transform_fn(index_ts_sorted)

In [None]:
# Report the array shapes of the training and test splits.
print("Dataset shape: tr: {}, ts: {}".format(X_tr.shape, X_ts.shape))

In [None]:
# dok_tr = sp.dok_matrix(X_tr_sorted[:, 0, :])
# dok_ts = sp.dok_matrix(X_ts_sorted[:, 0, :])

# name = 'doj'
# save_pickle(dok_tr, '../data/{}_tr_doc.pkl'.format(name))
# save_pickle(dok_ts, '../data/{}_ts_doc.pkl'.format(name))
# save_pickle(vocabulary, '../data/{}_vocabulary.pkl'.format(name))

# save_pickle(index_tr_sorted, '../data/{}_tr_index.pkl'.format(name))
# save_pickle(index_ts_sorted, '../data/{}_ts_index.pkl'.format(name))

# X_sorted = np.vstack((X_tr_sorted[:, 0, :], X_ts_sorted[:, 0, :]))
# print_to_file_for_gdtm(
#     df_train.append(df_test),
#     vocabulary,
#     sp.dok_matrix(X_sorted), filename='doj_all',
#     patth='../data/'
# )

In [None]:
n_training_points = X_tr.shape[0]
batch_size = 50

# Pair every training document with its index point, shuffle over the whole
# training set (reshuffled on every pass), then cut into mini-batches.
documents_ds = tf.data.Dataset.from_tensor_slices(X_tr)
index_points_ds = tf.data.Dataset.from_tensor_slices(index_tr)
dataset = tf.data.Dataset.zip((documents_ds, index_points_ds))
dataset = dataset.shuffle(
    n_training_points, reshuffle_each_iteration=True)
data_tr = dataset.batch(batch_size)

In [None]:
dtype = np.float64


def _positive_variable(initial_value, name):
    """Return a Softplus-constrained `TransformedVariable` of the notebook dtype."""
    return tfp.util.TransformedVariable(
        initial_value, bijector=tfb.Softplus(), dtype=dtype, name=name)


# Ten equally spaced inducing locations on [-1, 1] for each of the three
# Gaussian processes (beta, mu, ell).
inducing_index_points_beta = np.linspace(-1, 1, 10)[:, None]
inducing_index_points_mu = np.linspace(-1, 1, 10)[:, None]
inducing_index_points_ell = np.linspace(-1, 1, 10)[:, None]

# beta: Matern-1/2 kernel.
amplitude_beta = _positive_variable(1., 'amplitude_beta')
length_scale_beta = _positive_variable(0.5, 'length_scale_beta')
kernel_beta = tfk.MaternOneHalf(
    amplitude=amplitude_beta, length_scale=length_scale_beta)

# mu: exponentiated-quadratic kernel.
amplitude_mu = _positive_variable(1., 'amplitude_mu')
length_scale_mu = _positive_variable(0.5, 'length_scale_mu')
kernel_mu = tfk.ExponentiatedQuadratic(
    amplitude=amplitude_mu, length_scale=length_scale_mu)

# ell: exponentiated-quadratic kernel.
amplitude_ell = _positive_variable(1., 'amplitude_ell')
length_scale_ell = _positive_variable(0.5, 'length_scale_ell')
kernel_ell = tfk.ExponentiatedQuadratic(
    amplitude=amplitude_ell, length_scale=length_scale_ell)

# Pick up any edits made to the local model modules since they were imported.
reload(ctmd)
reload(dctm);

# Distinct training index points, as a column vector.
unique_index_points = np.unique(index_tr)[:, None]

mdl = dctm.DCTM(
    n_topics=30,
    n_words=vocabulary.size,
    kernel_beta=kernel_beta,
    index_points_beta=unique_index_points,
    inducing_index_points_beta=inducing_index_points_beta,
    kernel_ell=kernel_ell,
    kernel_mu=kernel_mu,
    index_points_mu=unique_index_points,
    index_points_ell=unique_index_points,
    inducing_index_points_mu=inducing_index_points_mu,
    inducing_index_points_ell=inducing_index_points_ell,
    layer_sizes=(500, 300, 200),
    jitter_beta=1e-6,
    jitter_mu=1e-5,
    jitter_ell=1e-6,
    encoder_jitter=1e-8,
    dtype=dtype)

# Training configuration and per-epoch history containers.
n_iter = 2
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
losses = []
perplexities = []

In [None]:
# checkpoint_directory = "../tmp/training_checkpoints-30-topics"
# checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# checkpoint = tf.train.Checkpoint(model=mdl)

# status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
# mdl = checkpoint.model

In [None]:
pbar = tqdm(range(n_iter), disable=False)
with tf.device('gpu'):
    for epoch in pbar:
        # Running totals over all mini-batches of this epoch.
        loss_value, perplexity_value = 0, 0

        for x_batch, index_points_batch in data_tr:
            # The KL term is weighted by the fraction of the training set
            # contained in this mini-batch.
            batch_fraction = (
                float(x_batch.shape[0]) / float(n_training_points))
            loss, perpl = mdl.batch_optimize(
                x_batch,
                optimizer=optimizer,
                observation_index_points=index_points_batch,
                trainable_variables=None,
                kl_weight=batch_fraction)
            loss = tf.reduce_mean(loss, 0)
            loss_value += loss
            perplexity_value += perpl

        # Both displayed numbers are sums over the epoch's mini-batches.
        description = 'loss {:.3e}, perpl {:.3e}'.format(
            loss_value, perplexity_value)
        pbar.set_description(description)
        # Optional periodic checkpointing (disabled):
        #     if epoch % 50 == 0:
        #         checkpoint.save(file_prefix=checkpoint_prefix)

        losses.append(loss_value)
        perplexities.append(perplexity_value)

In [None]:
# checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
def perplexity_test(self, X, index_points, batch_size):
    """Compute the perplexity of a model on the given documents.

    For every document the ELBO (evaluated with ``kl_weight=0.``) is negated
    and divided by the document's word count; the result is the exponential
    of the mean of those per-document values.

    Args:
        self: model exposing ``elbo(x, observation_index_points=..., kl_weight=...)``.
        X: document-term counts; words are summed over the last axis.
        index_points: index points aligned with ``X`` along the first axis.
        batch_size: number of documents evaluated per call to ``self.elbo``.

    Returns:
        Scalar tensor holding the perplexity.
    """
    dataset = tf.data.Dataset.zip(
        tuple(map(tf.data.Dataset.from_tensor_slices, (X, index_points))))
    data_ts = dataset.batch(batch_size)

    log_perplexity = []
    # Iterate over the dataset built from the arguments. The previous version
    # looped over the global training set `data_tr`, so the reported "test"
    # perplexity ignored X, index_points and batch_size.
    for x_batch, index_points_batch in data_ts:
        words_per_document = tf.reduce_sum(input_tensor=x_batch, axis=-1)
        elbo = self.elbo(
            x_batch, observation_index_points=index_points_batch,
            kl_weight=0.)
        # Collect one entry per element along the leading axis so that batches
        # of different sizes can be averaged together below.
        log_perplexity.extend(list(-elbo / words_per_document))
    perplexity = tf.exp(tf.reduce_mean(log_perplexity))
    return perplexity

# Call perplexity_test with the test arrays (X_ts, index_ts) in batches of 100
# and print the result.
with tf.device('gpu'):
    perpl = perplexity_test(mdl, X_ts, index_ts, batch_size=100)
    print(perpl)
# 484.62  (value noted from an earlier run; not regenerated here)

In [None]:
# Per-epoch training loss on a logarithmic y-axis.
plt.semilogy(losses);

In [None]:
def inverse_transform_fn(x):
    """Undo the index-point scaling and format the result as 'YYYY-MM' strings."""
    original_scale = scaler.inverse_transform(x)[:, 0]
    return pd.to_datetime(original_scale).strftime('%Y-%m')


reload(dctm)
# Print topics at every tenth distinct training index point.
tops = dctm.print_topics(
    mdl,
    index_points=np.unique(index_tr)[::10],
    vocabulary=vocabulary,
    inverse_transform_fn=inverse_transform_fn,
    top_n_topic=5,
    top_n_time=5)

In [None]:
# 100 equally spaced evaluation points on the scaled time axis.
test_points = np.linspace(-1, 1, 100)[:, None]

# Draw 1200 samples from the ell surrogate posterior at the test points and
# turn them into correlation / covariance samples.
ell_samples = mdl.surrogate_posterior_ell.sample(
    1200, index_points=test_points)
corr_sample, Sigma_sample = dctm.get_correlation(ell_samples)

# Percentiles across the sample axis.
# NOTE: the *_10p / *_90p names hold the 5th and 95th percentiles.
corr_10p, corr, corr_90p = (
    tfp.stats.percentile(corr_sample, q, axis=0) for q in (5, 50, 95))
Sigma_10p, Sigma, Sigma_90p = (
    tfp.stats.percentile(Sigma_sample, q, axis=0) for q in (5, 50, 95))

In [None]:
# Store the topic count on the model, read from the second batch dimension of
# the beta surrogate posterior; later cells rely on `mdl.n_topics`.
mdl.n_topics = mdl.surrogate_posterior_beta.batch_shape[1]

In [None]:
def top_words(beta, vocab, top_n=10):
    """Return the highest-weighted unique words of a single topic.

    ``beta`` holds the word weights of one topic; its last axis may span
    several time points, in which case the most heavily weighted unique words
    across all of them are returned.

    Args:
        beta: word weights for one topic; the last axis has one entry per
            time point.
        vocab: vocabulary whose entries decode as UTF-8 bytes once converted
            to a tensor.
        top_n: maximum number of unique words to return.

    Returns:
        List of at most ``top_n`` decoded words, ordered by decreasing weight.
        Fewer are returned when the vocabulary has fewer unique words.
    """
    # Repeat every word once per time point so that the flattened vocabulary
    # lines up element-wise with the flattened weights.
    repeated_vocab = tf.reshape(
        tf.tile(tf.expand_dims(vocab, -1), [1, beta.shape[-1]]), [-1])
    descending = tf.argsort(tf.reshape(beta, [-1]))[::-1].numpy()

    selected = []
    for raw_word in repeated_vocab.numpy()[descending]:
        # Stop when enough words are collected; running out of candidates
        # simply ends the loop instead of raising StopIteration.
        if len(selected) >= top_n:
            break
        word = raw_word.decode('utf8')
        if word not in selected:
            selected.append(word)
    return selected


In [None]:
topics = tops
topic_num = 0

# Median correlation over time between the reference topic and each of the
# other topics among the first ten.
plt.title("Topic {}: {}".format(topic_num, topics[topic_num]))
for t in range(mdl.n_topics)[:10]:
    if t != topic_num:
        plt.plot(corr[:, topic_num, t], label='{}:{}'.format(t, topics[t]))

# Label every tenth test point with its original date.
tick_positions = range(test_points.size)[::10]
tick_labels = inverse_transform_fn(test_points)[::10]
plt.xticks(tick_positions, tick_labels, rotation=45);

# Legend outside the axes, to the right.
plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5));

In [None]:
# Model predictions for the training documents, with the singleton middle
# axis dropped.
topic_tr = mdl.predict(X_tr)[:, 0, :].numpy()

# Average the predictions of all documents that share an index point:
# cc[k, j] is the mean of topic k at the j-th distinct index point.
unique_index_tr = np.unique(index_tr)
cc = np.zeros([mdl.n_topics, unique_index_tr.size])
for j, i in enumerate(unique_index_tr):
    # Tolerance-based match rather than exact float equality.
    idx = np.ravel(np.abs(index_tr - i) < 1e-7)
    cc[:, j] = topic_tr[idx].mean(0)

In [None]:
# Softmax across the first axis of the mu posterior mean at the test points;
# `mu[t]` below selects one topic's curve.
mu = tf.nn.softmax(
    mdl.surrogate_posterior_mu.get_marginal_distribution(test_points).mean(),
    axis=0)

colors = plt.cm.jet(np.linspace(0, 1, mdl.n_topics))
unique_index_tr = np.unique(index_tr)
# 1-D tick positions: `test_points` is a column vector, and newer Matplotlib
# versions reject 2-D tick locations.
tick_positions = test_points[::10, 0]
tick_labels = inverse_transform_fn(test_points)[::10]

# One figure per topic. Use the model's topic count instead of a hard-coded
# 30 so the loop stays consistent with `colors` and `cc`.
for t in range(mdl.n_topics):
    # Curve derived from the mu posterior mean over the test points.
    plt.plot(test_points[:, 0], mu[t], label=topics[t], color=colors[t]);
    # Averaged training predictions at each distinct training index point.
    plt.plot(unique_index_tr, cc[t], label='{}'.format(topics[t]),
             color=colors[t])

    plt.xticks(tick_positions, tick_labels, rotation=45);
    plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5));
    plt.show()