In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

tfk = tf.keras

import numpy as np

# Getting and preprocessing dataset

In [23]:
tf.random.set_seed(0)

num_articles = 100

dataset = tfds.load(name="cnn_dailymail", split=tfds.Split.ALL)
dataset = dataset.shuffle(buffer_size=512).repeat()
dataset = dataset.batch(num_articles)

for example in dataset.take(1):
    
    # Dowload all articles and remove punctuation
    articles = tf.strings.regex_replace(example['article'],
                                        tf.constant("[[:punct:]]"),
                                        tf.constant(""),
                                        replace_global=True)
    
    # Split articles into words
    words = tf.strings.split(articles)
    
    # Count occurences of unique words
    unique, idx, counts = tf.unique_with_counts(words.values)
    
    vocab_dict = dict(zip(unique.numpy(), np.arange(len(unique))))
    vocab = tf.gather(unique, tf.where(counts > 1)).numpy()[:, 0]
    
    def get_rare_word_mask(tensor):
        
        if len(tensor.shape) > 1:
            return tf.ragged.map_flat_values(get_rare_word_mask, tensor)

        else:
            return tf.map_fn(lambda x : str(int(x in vocab)), tensor)
    
    def indexify(tensor):

        if len(tensor.shape) > 1:
            return tf.ragged.map_flat_values(indexify, tensor)

        else:
            return tf.map_fn(lambda x : str(vocab_dict[x.numpy()]), tensor)
    
    word_mask = tf.ragged.map_flat_values(get_rare_word_mask, words)
    word_mask = tf.strings.to_number(word_mask, out_type=np.int32)
    word_mask = tf.cast(word_mask, dtype=tf.bool)
    
    words = tf.ragged.boolean_mask(words, word_mask)
    words = tf.ragged.map_flat_values(indexify, words)

In [24]:
print(len(unique), len(tf.unique_with_counts(words.values)[0]))

10140 5226


# Model

In [115]:
class LDA:
    
    def __init__(self, data, num_topics):
        
        self.data = data
        
        self.K = num_topics
        self.D = self.data.shape[0]
        self.V = len(tf.unique(words.values).y)
        self.W = tf.ragged.stack([tf.one_hot(row, self.V) for row in data], axis=0)
        
        # Model parameters
        self.alpha = tf.Variable(tf.ones(shape=(self.K,)))
        self.eta = tf.Variable(1.0)
        
        # Variational posterior parameters
        self.lamda = tf.Variable(tf.ones(shape=(self.K, self.V)))
        self.gamma = tf.Variable(tf.ones(shape=(self.D, self.K)))
        self.phi = tf.ragged.stack([tf.ones_like(data) for i in range(self.K)], axis=-1)
        self.phi = self.phi / tf.reduce_sum(self.phi, axis=-1)[..., None]
        
        print(self.W.shape)
        
    
    def single_variational_parameter_update(self):
        
        gamma = self.alpha[None, ...] + tf.reduce_sum(self.phi, axis=1)

        lamda = self.eta[None, None] + tf.reduce_sum(self.phi[..., None], axis=1)

In [121]:
LDA(data=words[:100], num_topics=5)

(100, None, None)


<__main__.LDA at 0x1b2a52860>

<tf.Tensor: id=14349337, shape=(3,), dtype=float32, numpy=array([1., 1., 0.], dtype=float32)>