In [1]:
import os
import sys
# Put the parent directory on the import path (presumably where the `util`
# package imported below lives -- confirm against the repo layout).
# Guarded so that re-running the cell does not add a duplicate entry.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from util.util_word2vec_preprocess import process_data
import tensorflow as tf

In [2]:
# References
# http://web.stanford.edu/class/cs20si/lectures/notes_04.pdf
# https://stackoverflow.com/questions/34870614/what-does-tf-nn-embedding-lookup-function-do
# https://www.tensorflow.org/tutorials/word2vec

In [3]:
class Word2Vec(object):
    """Word-embedding model trained with a noise-contrastive estimation loss.

    The model looks up an embedding for each center word id and is trained,
    through `tf.nn.nce_loss`, to predict the paired target word id.
    It is written against the TensorFlow 1.x graph/session API.

    Args:
        sess: the `tf.Session` used for variable initialisation and training.
        batch_size: number of (center, target) pairs per batch; the
            placeholders are built with this fixed size.
        vocab_size: number of rows in the embedding and NCE weight matrices.
        embed_size: dimensionality of each embedding vector.
        n_samples: number of negative classes sampled per batch by the NCE loss.
        epochs: number of training iterations run by `train`. Each iteration
            consumes exactly one batch, so this is a step count rather than
            a number of passes over the data.
    """

    def __init__(self, sess, batch_size, vocab_size, embed_size, n_samples, epochs):
        self.sess = sess
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.n_samples = n_samples
        self.epochs = epochs
        # The graph is built eagerly so the instance is ready to train.
        self._build_graph()

    def _build_graph(self):
        """Create placeholders, variables, the NCE loss and the training op."""
        # Scalars (indices into the vocab) are fed in, so the only dimension
        # is the batch size. nce_loss is given labels of shape [batch_size, 1].
        self.center = tf.placeholder(tf.int32, shape=[self.batch_size])
        self.target = tf.placeholder(tf.int32, shape=[self.batch_size, 1])

        # Initialise uniformly in [-1, 1), with dimensions vocab size by embed size.
        self.embed_matrix = tf.Variable(
            tf.random_uniform([self.vocab_size, self.embed_size], -1, 1)
        )
        # Output-side weights and biases consumed by the NCE loss.
        self.w = tf.Variable(tf.truncated_normal(
            [self.vocab_size, self.embed_size],
            stddev=0.1 / self.embed_size ** 0.5)
        )
        self.b = tf.Variable(tf.zeros([self.vocab_size]))

        # embedding_lookup intuitively means "select rows given the row id":
        # https://stackoverflow.com/questions/34870614/what-does-tf-nn-embedding-lookup-function-do
        # Here: fetch the rows of self.embed_matrix that are "id"-ed by self.center.
        self.embed = tf.nn.embedding_lookup(self.embed_matrix, self.center)

        # NCE handles the computational cost of applying the softmax function
        # over a large number of classes, i.e. the vocab size in this case.
        # For an intuitive explanation and resources see:
        # https://datascience.stackexchange.com/questions/13216/intuitive-explanation\
        #-of-noise-contrastive-estimation-nce-loss
        self.loss = tf.reduce_mean(tf.nn.nce_loss(
            weights=self.w,
            biases=self.b,
            labels=self.target,
            inputs=self.embed,
            num_sampled=self.n_samples,
            num_classes=self.vocab_size)
        )

        # We will just use the default hyperparameters defined in:
        # https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
        # Cant beat RMSprop + momentum :)
        self.optimizer = tf.train.AdamOptimizer().minimize(self.loss)

    def train(self, batch_gen):
        """Run `self.epochs` optimisation steps, one batch per step.

        Args:
            batch_gen: an iterator yielding batches where `batch[0]` is fed to
                the center placeholder and `batch[1]` to the target placeholder.

        Note:
            All variables are (re)initialised at the start of every call, so
            calling `train` again restarts from fresh weights. Every 1000
            steps the running mean of the loss over all steps so far is printed.
        """
        # Initialise through the session this model was given instead of
        # relying on an implicit default session (Operation.run() fails when
        # `sess` was not installed with `with sess:` / `as_default()`).
        self.sess.run(tf.global_variables_initializer())

        total_loss = 0
        # `range` and the built-in `next()` work on both Python 2 and 3,
        # unlike `xrange` and the generator's `.next()` method.
        for i_epoch in range(self.epochs):
            # Get batch
            batch = next(batch_gen)
            loss, _ = self.sess.run(
                [self.loss, self.optimizer],
                feed_dict={
                    self.center: batch[0],
                    self.target: batch[1]
                }
            )
            total_loss += loss
            if (i_epoch + 1) % 1000 == 0:
                print("Average loss at epoch {}: {:5.1f}".format(
                    i_epoch + 1, total_loss / (i_epoch + 1)))
            
            

In [4]:
# Hyper-parameters for this run.
vocab_size = 50000
batch_size = 128
embed_size = 128
skip_window = 1
n_samples = 64
epochs = 10000  # number of training iterations; each one consumes a single batch

# Batch generator; process_data comes from util.util_word2vec_preprocess.
batch_gen = process_data(vocab_size, batch_size, skip_window)

# Build the graph inside a fresh session and train. The session is closed
# when the `with` block exits.
with tf.Session() as sess:
    model = Word2Vec(
        sess=sess,
        batch_size=batch_size,
        vocab_size=vocab_size,
        embed_size=embed_size,
        n_samples=n_samples,
        epochs=epochs,
    )
    model.train(batch_gen=batch_gen)

Dataset ready
Average loss at epoch 1000: 238.7
Average loss at epoch 2000: 207.1
Average loss at epoch 3000: 185.8
Average loss at epoch 4000: 169.1
Average loss at epoch 5000: 155.6
Average loss at epoch 6000: 145.0
Average loss at epoch 7000: 135.9
Average loss at epoch 8000: 128.1
Average loss at epoch 9000: 120.9
Average loss at epoch 10000: 114.6
