In [175]:
import jax
import jax.numpy as jnp
from pathlib import Path 

In [176]:
# limit the vocabulary V
V = 1000
# limit the topics 
T = 25
# limit the epochs
epochs = 50

stop_words = [w.strip() for w in Path("stoplist.txt").read_text().splitlines()]
stop_words = set(stop_words) | set([".", "-", "?"])


In [177]:
doc = Path("documents.txt")

def get_tokens(text):
  t = text.strip().split("\t")[-1].lower().replace("?", " ?").replace(".", " .").split()
  return [w for w in t if w not in stop_words and len(w) > 4]
  # someone should research how these tiny choices influence foundation models

from collections import Counter

c = Counter()
for text in doc.read_text().splitlines():
  toks = get_tokens(text)
  for t in toks:
    c[t] += 1


In [178]:
vocab = {v : idx for idx, (v, _) in enumerate(c.most_common(V))}
# find the |V| most common words in counter c
# and give each an idx
# so the most commonly occuring word has idx 0
# could even do the above with Counter(chain(*list_of_list_of_sentences))

In [179]:
# time to encode data
data = []
labels = []
for idx, text in enumerate(doc.read_text().splitlines()):
  toks = get_tokens(text)
  toks = [t for t in toks if t in vocab]
  data += [vocab.get(t) for t in toks] # get index of t
  labels += [idx] * len(toks) # the tokens belong to doc_idx
# data -- list of words 
# labels -- the doc they belong to

In [180]:
N = len(data) # data size
assert N == len(labels)

print(N)

128056


In [181]:
# M = len(Counter(labels))
# the above is wrong.. because there could be docs in data not represented
# ie, some doc_IDs do not occur in labels[] array
M = idx + 1 # number of documents

In [182]:
key = jax.random.PRNGKey(0)

data = jnp.array(data, dtype = jnp.int32)
labels = jnp.array(labels, dtype = jnp.int32)

In [183]:
topic_assigned = jax.random.randint(key, (N, ), 0, T, dtype = jnp.int32)
# for each word in each doc -- assign a topic randomly from [0, T-1] range

In [184]:
topic_word = jax.ops.index_add(jnp.zeros((T, V), dtype = jnp.int32),
                               jax.ops.index[topic_assigned, data], 1)
# at indices of intersection -- add a 1
# equivalent to for t,v in zip(topic_token, data): zeros_mat[t, v] = 1


In [185]:
doc_topic = jax.ops.index_add(jnp.zeros((M, T), dtype = jnp.int32),
                               jax.ops.index[labels, topic_assigned], 1)
# what are the topics present in each document
# row -- each document; column -- topics

In [186]:
doc_word = jax.ops.index_add(jnp.zeros((M, V), dtype = jnp.int32),
                              jax.ops.index[labels, data], 1)

In [187]:
jnp.sum(topic_word), jnp.sum(doc_topic), jnp.sum(doc_word)

(DeviceArray(128056, dtype=int32),
 DeviceArray(128056, dtype=int32),
 DeviceArray(128056, dtype=int32))

In [188]:
len(data), len(labels), M, N

(128056, 128056, 9743, 128056)

In [189]:
# hyperparams
ALPHA, BETA = 0.1, 0.01

In [190]:
reverse_vocab = {v : k for k, v in vocab.items()}

In [191]:
def token_loop(state, scanned):
  """
  runs once for every word in every document

  (topic_word, doc_topic, topic_cnt) in state
  topic_word [T x V]
  doc_topic [M x T]
  topic_cnt [T]

  (topic_assigned, data, doc, key) in scanned
  topic_assigned [1] -- for the current word, whats the topic assigned
  data [1] -- a data point, ie, a word/token
  doc [1] -- label of that data point -- ie, which document it belongs to
  key [1] -- jax stuff
  """
  topic_word, doc_topic, topic_cnt = state 
  topic_assigned, data, doc, key = scanned # current point we are looking at

  local_tw = topic_word[:, data].at[topic_assigned].add(-1) # remove topic at which it belongs [T, V]
  local_dt = doc_topic[doc].at[topic_assigned].add(-1) # remove corresponding token in doc [M, T]
  local_tc = topic_cnt.at[topic_assigned].add(-1) # reduce cnt for that topic too [T]

  # update the distribution 
  dist = ((local_tw + BETA) / (local_tc + V * BETA)) \
        * ((local_dt + ALPHA) / (doc_word.sum(axis = -1)[doc] + T * ALPHA)) # E [phi_tv] * E[theta_dt]
  new_topic = jax.random.categorical(key, jnp.log(dist)) # sample topic from new dist

  def update(_):
    # state update after each scan
    return (topic_word.at[new_topic, data].add(1).at[topic_assigned, data].add(-1),
            doc_topic.at[doc, new_topic].add(1).at[doc, topic_assigned].add(-1),
            topic_cnt.at[new_topic].add(1).at[topic_assigned].add(-1),
    )
    
  return jax.lax.cond((new_topic != topic_assigned),
                      update,
                      lambda _ : (topic_word, doc_topic, topic_cnt), None
         ), (new_topic, None, None, None)



In [192]:
def mcmc(state):
  """
  looks more like gibbs sampling -- since all samples are accepted
  
  topic_cnt [T x 1] -- number of words assigned to topic
  topic_word [T x V] -- words in topic 
  doc_topic [M x T] -- topics in document
  topic_assigned [N] -- topic assigned to each word; where N is total num of words
  key -- jax stuff
  """
  topic_cnt, topic_word, doc_topic, topic_assigned, key = state
  keys = jax.random.split(key, N + 1) # split for each word .. and another for what/

  (topic_word, doc_topic, topic_cnt), (topic_assigned, _ , _, _) = \
    jax.lax.scan(token_loop, # function 
                 (topic_word, doc_topic, topic_cnt), # init state
                 (topic_assigned, data, labels, keys[1:])) # xs
    # scan effectively scans through each data point.. for x in xs
    # and accumulates result in state [kinda like a reduce function]

  return topic_cnt, topic_word, doc_topic, topic_assigned, keys[0] # don't get the keys thing.. ignoring for now

In [193]:
def run(topic_word, doc_topic, topic_assigned):
  key = jax.random.PRNGKey(1)
  topic_cnt = topic_word.sum(axis = -1)
  for i in range(50):
    (topic_cnt, topic_word, doc_topic, topic_assigned, key) = \
      mcmc((topic_cnt, topic_word, doc_topic, topic_assigned, key))
    return topic_word, doc_topic, topic_assigned

topic_word, doc_topic, topic_assigned = run(topic_word, doc_topic, topic_assigned)

DeviceArray([ 4,  2,  2,  2, 14,  7, 21, 21,  2,  2], dtype=int32)

In [196]:
out = topic_word / topic_word.sum(axis = -1, keepdims = True)

for i in range(T):
  print("TOPIC", i, [reverse_vocab[int(x)] for x in reversed(jnp.argsort(out[i])[-5:])])

TOPIC 0 ['america', 'nation', 'health', 'national', 'american']
TOPIC 1 ['government', 'states', 'nation', 'america', 'federal']
TOPIC 2 ['government', 'country', 'security', 'united', 'freedom']
TOPIC 3 ['government', 'american', 'program', 'america', 'against']
TOPIC 4 ['under', 'states', 'economic', 'american', 'public']
TOPIC 5 ['government', 'about', 'american', 'defense', 'military']
TOPIC 6 ['american', 'federal', 'billion', 'country', 'government']
TOPIC 7 ['american', 'economic', 'federal', 'states', 'present']
TOPIC 8 ['government', 'national', 'peace', 'nation', 'federal']
TOPIC 9 ['government', 'national', 'public', 'nations', 'american']
TOPIC 10 ['states', 'national', 'federal', 'without', 'about']
TOPIC 11 ['american', 'country', 'strengthen', 'americans', 'security']
TOPIC 12 ['government', 'national', 'nation', 'american', 'right']
TOPIC 13 ['american', 'national', 'government', 'united', 'program']
TOPIC 14 ['government', 'national', 'american', 'america', 'united']
T

In [None]:
# how to use gradients... to improve this
# what if the sampling distrib was fancier
# or sampling was HMC
# will have to sample latents with requires_grad = True