-
Notifications
You must be signed in to change notification settings - Fork 223
/
prodlda.py
349 lines (292 loc) · 11.9 KB
/
prodlda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
Example: ProdLDA with Flax and Haiku
====================================
In this example, we will follow [1] to implement the ProdLDA topic model from
Autoencoding Variational Inference For Topic Models by Akash Srivastava and Charles
Sutton [2]. This model returns consistently better topics than vanilla LDA and trains
much more quickly. Furthermore, it does not require a custom inference algorithm that
relies on complex mathematical derivations. This example also serves as an
introduction to Flax and Haiku modules in NumPyro.

Note that unlike [1, 2], this implementation uses a Dirichlet prior directly rather
than approximating it with a softmax-normal distribution.

For the interested reader, a nice extension of this model is the CombinedTM model [3]
which utilizes a pre-trained sentence transformer (like https://www.sbert.net/) to
generate a better representation of the encoded latent vector.

**References:**

1. http://pyro.ai/examples/prodlda.html
2. Akash Srivastava, & Charles Sutton. (2017). Autoencoding Variational Inference
   For Topic Models.
3. Federico Bianchi, Silvia Terragni, and Dirk Hovy (2021), "Pre-training is a Hot
   Topic: Contextualized Document Embeddings Improve Topic Coherence"
   (https://arxiv.org/abs/2004.03974)

.. image:: ../_static/img/examples/prodlda.png
    :align: center
"""
import argparse
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from wordcloud import WordCloud
import flax.linen as nn
import haiku as hk
import jax
from jax import device_put, random
import jax.numpy as jnp
import numpyro
from numpyro.contrib.module import flax_module, haiku_module
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceMeanField_ELBO
class HaikuEncoder:
    """Inference network (Haiku) mapping word counts to Dirichlet concentrations.

    Instances are plain callables rather than ``hk.Module`` subclasses; ``guide``
    wraps one with ``hk.transform_with_state`` because the BatchNorm layer
    carries state.

    Args:
        vocab_size: Vocabulary size. Stored, but ``__call__`` does not read it.
        num_topics: Number of topics, i.e. the width of the output.
        hidden: Width of each of the two hidden layers.
        dropout_rate: Dropout probability used when ``is_training`` is true.
    """

    def __init__(self, vocab_size, num_topics, hidden, dropout_rate):
        self._vocab_size = vocab_size
        self._num_topics = num_topics
        self._hidden = hidden
        self._dropout_rate = dropout_rate

    def __call__(self, inputs, is_training):
        """Return positive concentration parameters, one per topic.

        ``inputs`` is presumably a (batch, vocab_size) count matrix -- TODO
        confirm against callers. When ``is_training`` is false the dropout rate
        is forced to 0.0; the flag is also passed to BatchNorm.
        """
        rate = self._dropout_rate if is_training else 0.0

        x = inputs
        # two softplus hidden layers, created in the same order as before
        for _ in range(2):
            x = jax.nn.softplus(hk.Linear(self._hidden)(x))
        x = hk.dropout(hk.next_rng_key(), rate, x)
        x = hk.Linear(self._num_topics)(x)

        # NB: no learnable scale/offset, to reduce the number of parameters
        norm = hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)
        # exp maps the normalised log-concentrations onto positive values
        return jnp.exp(norm(x, is_training))
class HaikuDecoder:
    """Generative network (Haiku) mapping topic proportions to vocabulary logits.

    A plain callable; ``model`` wraps it with ``hk.transform_with_state`` for the
    BatchNorm state. The weight ``w`` of its bias-free linear layer is what
    ``main`` reads back as the topic-word matrix.

    Args:
        vocab_size: Number of output logits (one per vocabulary word).
        dropout_rate: Dropout probability used when ``is_training`` is true.
    """

    def __init__(self, vocab_size, dropout_rate):
        self._vocab_size = vocab_size
        self._dropout_rate = dropout_rate

    def __call__(self, inputs, is_training):
        """Return BatchNorm-normalised logits over the vocabulary.

        ``inputs`` is presumably a (batch, num_topics) array of topic
        proportions -- TODO confirm. When ``is_training`` is false the dropout
        rate is forced to 0.0; the flag is also passed to BatchNorm.
        """
        rate = self._dropout_rate if is_training else 0.0
        dropped = hk.dropout(hk.next_rng_key(), rate, inputs)
        logits = hk.Linear(self._vocab_size, with_bias=False)(dropped)
        norm = hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)
        return norm(logits, is_training)
class FlaxEncoder(nn.Module):
    """Inference network (Flax) mapping word counts to Dirichlet concentrations.

    Two softplus hidden layers, dropout, a projection to ``num_topics`` and a
    BatchNorm without bias/scale whose output is exponentiated.
    """

    vocab_size: int  # not read by __call__
    num_topics: int  # width of the output
    hidden: int  # width of each hidden layer
    dropout_rate: float  # dropout probability while training

    @nn.compact
    def __call__(self, inputs, is_training):
        """Return positive concentration parameters, one per topic.

        When ``is_training`` is false, dropout runs with ``deterministic=True``
        and BatchNorm runs with ``use_running_average=True``.
        """
        eval_mode = not is_training

        h = inputs
        # two softplus hidden layers, created in the same order as before
        for _ in range(2):
            h = nn.softplus(nn.Dense(self.hidden)(h))
        h = nn.Dropout(self.dropout_rate, deterministic=eval_mode)(h)
        h = nn.Dense(self.num_topics)(h)

        log_concentration = nn.BatchNorm(
            use_bias=False,
            use_scale=False,
            momentum=0.9,
            use_running_average=eval_mode,
        )(h)
        # exp maps the normalised log-concentrations onto positive values
        return jnp.exp(log_concentration)
class FlaxDecoder(nn.Module):
    """Generative network (Flax) mapping topic proportions to vocabulary logits.

    ``main`` reads the kernel of the bias-free ``Dense_0`` layer back as the
    topic-word matrix.
    """

    vocab_size: int  # number of output logits
    dropout_rate: float  # dropout probability while training

    @nn.compact
    def __call__(self, inputs, is_training):
        """Return BatchNorm-normalised logits over the vocabulary.

        When ``is_training`` is false, dropout runs with ``deterministic=True``
        and BatchNorm runs with ``use_running_average=True``.
        """
        eval_mode = not is_training
        x = nn.Dropout(self.dropout_rate, deterministic=eval_mode)(inputs)
        x = nn.Dense(self.vocab_size, use_bias=False)(x)
        batch_norm = nn.BatchNorm(
            use_bias=False,
            use_scale=False,
            momentum=0.9,
            use_running_average=eval_mode,
        )
        return batch_norm(x)
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """ProdLDA generative model.

    For each document in a subsampled mini-batch, topic proportions ``theta``
    are drawn from a flat Dirichlet prior; a neural decoder turns ``theta`` into
    vocabulary logits, and the observed counts are scored by a Multinomial whose
    total count is the document's word total.

    Args:
        docs: Document-term count matrix, one row per document (``docs.shape[0]``
            is the plate size).
        hyperparams: Dict with keys ``vocab_size``, ``num_topics``,
            ``dropout_rate`` and ``batch_size``.
        is_training: Passed to the decoder at call time.
        nn_framework: ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: If ``nn_framework`` is neither ``"flax"`` nor ``"haiku"``.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework")

    vocab_size = hyperparams["vocab_size"]
    dropout_rate = hyperparams["dropout_rate"]
    num_topics = hyperparams["num_topics"]
    use_flax = nn_framework == "flax"

    # Both wrappers are initialised with is_training=True so that BatchNorm is
    # initialised properly, independently of the caller's flag.
    if use_flax:
        decoder = flax_module(
            "decoder",
            FlaxDecoder(vocab_size, dropout_rate),
            input_shape=(1, num_topics),
            apply_rng=["dropout"],  # PRNG key for the dropout layers
            mutable=["batch_stats"],  # BatchNorm running statistics
            is_training=True,
        )
    else:
        decoder = haiku_module(
            "decoder",
            # transform_with_state because of the BatchNorm state
            hk.transform_with_state(HaikuDecoder(vocab_size, dropout_rate)),
            input_shape=(1, num_topics),
            apply_rng=True,
            is_training=True,
        )

    num_docs = docs.shape[0]
    with numpyro.plate("documents", num_docs, subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample("theta", dist.Dirichlet(jnp.ones(num_topics)))
        if use_flax:
            logits = decoder(theta, is_training, rngs={"dropout": numpyro.prng_key()})
        else:
            logits = decoder(numpyro.prng_key(), theta, is_training)
        numpyro.sample(
            "obs",
            dist.Multinomial(batch_docs.sum(-1), logits=logits),
            obs=batch_docs,
        )
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Amortised variational guide for ``model``.

    An encoder network maps each subsampled document's counts to the
    concentration of a Dirichlet posterior over its topic proportions ``theta``.

    Args:
        docs: Document-term count matrix, one row per document.
        hyperparams: Dict with keys ``vocab_size``, ``num_topics``, ``hidden``,
            ``dropout_rate`` and ``batch_size``.
        is_training: Passed to the encoder at call time.
        nn_framework: ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: If ``nn_framework`` is neither ``"flax"`` nor ``"haiku"``.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework")

    # Positional constructor arguments shared by FlaxEncoder and HaikuEncoder.
    encoder_args = (
        hyperparams["vocab_size"],
        hyperparams["num_topics"],
        hyperparams["hidden"],
        hyperparams["dropout_rate"],
    )
    input_shape = (1, hyperparams["vocab_size"])
    use_flax = nn_framework == "flax"

    # Both wrappers are initialised with is_training=True so that BatchNorm is
    # initialised properly, independently of the caller's flag.
    if use_flax:
        encoder = flax_module(
            "encoder",
            FlaxEncoder(*encoder_args),
            input_shape=input_shape,
            apply_rng=["dropout"],  # PRNG key for the dropout layers
            mutable=["batch_stats"],  # BatchNorm running statistics
            is_training=True,
        )
    else:
        encoder = haiku_module(
            "encoder",
            # transform_with_state because of the BatchNorm state
            hk.transform_with_state(HaikuEncoder(*encoder_args)),
            input_shape=input_shape,
            apply_rng=True,
            is_training=True,
        )

    num_docs = docs.shape[0]
    with numpyro.plate("documents", num_docs, subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        if use_flax:
            concentration = encoder(
                batch_docs, is_training, rngs={"dropout": numpyro.prng_key()}
            )
        else:
            concentration = encoder(numpyro.prng_key(), batch_docs, is_training)
        numpyro.sample("theta", dist.Dirichlet(concentration))
def load_data():
    """Fetch the 20 Newsgroups corpus and turn it into bag-of-words counts.

    Words appearing in more than half of the posts, in fewer than 20 posts, or
    in sklearn's English stop-word list are dropped.

    Returns:
        docs: Dense ``jnp`` array of word counts, one row per post.
        vocab: DataFrame with a ``word`` column and an ``index`` column holding
            each word's row position (used to look words up by column of
            ``docs``).
    """
    posts = fetch_20newsgroups(subset="all")["data"]
    vectorizer = CountVectorizer(max_df=0.5, min_df=20, stop_words="english")
    counts = vectorizer.fit_transform(posts)
    docs = jnp.array(counts.toarray())

    vocab = pd.DataFrame({"word": vectorizer.get_feature_names_out()})
    vocab["index"] = vocab.index
    return docs, vocab
def run_inference(docs, args):
    """Fit ProdLDA with SVI and return the result of ``SVI.run``.

    The returned object's ``params`` hold the learned encoder/decoder
    parameters. The PRNG seed is fixed at 0.

    Args:
        docs: Document-term count matrix, one row per document.
        args: Parsed CLI namespace; reads ``num_topics``, ``hidden``,
            ``dropout_rate``, ``batch_size``, ``learning_rate``, ``num_steps``,
            ``disable_progbar`` and ``nn_framework``.
    """
    docs = device_put(docs)
    hyperparams = {
        "vocab_size": docs.shape[1],
        "num_topics": args.num_topics,
        "hidden": args.hidden,
        "dropout_rate": args.dropout_rate,
        "batch_size": args.batch_size,
    }
    svi = SVI(
        model,
        guide,
        numpyro.optim.Adam(args.learning_rate),
        loss=TraceMeanField_ELBO(),
    )
    seed = random.PRNGKey(0)
    return svi.run(
        seed,
        args.num_steps,
        docs,
        hyperparams,
        is_training=True,
        progress_bar=not args.disable_progbar,
        nn_framework=args.nn_framework,
    )
def plot_word_cloud(b, ax, vocab, n, num_words=20):
    """Draw a word cloud of the highest-weight words of a single topic.

    Args:
        b: 1-D array of word weights for one topic, indexed by vocabulary
            position (in ``main`` this is a row of the softmaxed decoder
            weights).
        ax: Matplotlib axes to draw on; its axis lines are turned off.
        vocab: DataFrame with ``index`` and ``word`` columns, as returned by
            ``load_data``.
        n: Zero-based topic number, used only for the title ``Topic {n + 1}``.
        num_words: Number of top-weighted words to show (default 20).
    """
    # indices of the num_words largest weights, in descending order
    top = jnp.argsort(b)[::-1][:num_words]

    # a left merge keeps the rows in ``top`` order, so words line up with weights
    lookup = pd.DataFrame(top, columns=["index"])
    words = pd.merge(lookup, vocab[["index", "word"]], how="left", on="index")[
        "word"
    ].tolist()
    weights = b[top].tolist()
    freqs = dict(zip(words, weights))

    wc = WordCloud(background_color="white", width=800, height=500)
    wc = wc.generate_from_frequencies(freqs)
    ax.set_title(f"Topic {n + 1}")
    ax.imshow(wc, interpolation="bilinear")
    ax.axis("off")
def main(args):
    """Load the corpus, fit ProdLDA, and save one word cloud per topic.

    The figure is written to ``wordclouds.png`` in the working directory.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block).
    """
    docs, vocab = load_data()
    print(f"Dictionary size: {len(vocab)}")
    print(f"Corpus size: {docs.shape}")
    svi_result = run_inference(docs, args)

    # The bias-free decoder layer's weight matrix is the topic-word matrix.
    if args.nn_framework == "flax":
        beta = svi_result.params["decoder$params"]["Dense_0"]["kernel"]
    elif args.nn_framework == "haiku":
        beta = svi_result.params["decoder$params"]["linear"]["w"]
    # softmax over the last axis turns each row (topic) into a word distribution
    beta = jax.nn.softmax(beta)
    num_plotted = beta.shape[0]

    # the number of plots depends on the chosen number of topics.
    # add 2 to num topics to ensure we create a row for any remainder after division
    nrows = (args.num_topics + 2) // 3
    fig, axs = plt.subplots(nrows, 3, figsize=(14, 3 + 3 * nrows))
    axs = axs.flatten()
    for n in range(num_plotted):
        plot_word_cloud(beta[n], axs[n], vocab, n)

    # hide only the axes left over after the last plotted topic
    for ax in axs[num_plotted:]:
        ax.axis("off")
    fig.savefig("wordclouds.png")
if __name__ == "__main__":
    assert numpyro.__version__.startswith("0.9.2")
    parser = argparse.ArgumentParser(
        description="Probabilistic topic modelling with Flax and Haiku"
    )
    parser.add_argument("-n", "--num-steps", nargs="?", default=30_000, type=int)
    parser.add_argument("-t", "--num-topics", nargs="?", default=12, type=int)
    parser.add_argument("--batch-size", nargs="?", default=32, type=int)
    parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float)
    parser.add_argument("--hidden", nargs="?", default=200, type=int)
    parser.add_argument("--dropout-rate", nargs="?", default=0.2, type=float)
    parser.add_argument(
        "-dp",
        "--disable-progbar",
        action="store_true",
        default=False,
        help="Whether to disable progress bar",
    )
    parser.add_argument(
        "--device", default="cpu", type=str, help='use "cpu", "gpu" or "tpu".'
    )
    # Validate the framework at parse time. Otherwise a typo is only reported
    # by model/guide, after main() has already downloaded and vectorised the
    # whole corpus.
    parser.add_argument(
        "--nn-framework",
        nargs="?",
        default="flax",
        choices=["flax", "haiku"],
        help=(
            "The framework to use for constructing encoder / decoder. Options are "
            '"flax" or "haiku".'
        ),
    )
    args = parser.parse_args()
    numpyro.set_platform(args.device)
    main(args)