# `Word2Vec`

In [1]:
# Standard library
import os
import sys
import datetime as dt

# Third-party libraries
import numpy as np
import tensorflow as tf

# Custom libraries
from dataset import TextDataset

## Initializing the `Word2Vec` dataset

In [2]:
# Paths: a pickle location (not referenced by the cells shown in this
# notebook) and the raw corpus handed to the dataset.
save_file = 'datasets/saved/data.pkl'
data_dir = 'datasets/wiki.valid.raw'

# Construct the dataset wrapper and run its `create()` step.
w2v = TextDataset(logging=True, data_dir=data_dir)
w2v.create()

Processing 8,224 of 8,224 sentences. Time taken: 0:00:32.279244

### Hyperparameters

In [3]:
# Model/Network: embedding width, optimizer step size, and the vocabulary
# size reported by the dataset object.
vocab_size = w2v.vocab_size
embedding_dim, learning_rate = 50, 1e-3

# Training: `epochs` is the number of loop iterations (one batch each),
# `save_interval` is how often summaries/checkpoints are written.
epochs, save_interval, batch_size = 10000, 50, 25

## Training with a `tensorflow` model

In [4]:
# Model's placeholders: every row of `X` and `y` has `vocab_size` entries and
# the batch dimension is left open (None).
# NOTE(review): presumably each row is a one-hot encoded word — confirm
# against `TextDataset.next_batch`.
X = tf.placeholder(tf.float32, shape=[None, vocab_size], name='X_placeholder')
y = tf.placeholder(tf.float32, shape=[None, vocab_size], name='y_placeholder')
# Index of the largest entry in each label row; compared with the predicted
# index when computing accuracy.
y_true = tf.argmax(y, axis=1)

### Building the Network

In [5]:
# Input -> Hidden: affine map from vocabulary-sized rows to
# `embedding_dim`-sized rows.
W1 = tf.Variable(tf.truncated_normal([vocab_size, embedding_dim]))
b1 = tf.Variable(tf.zeros([embedding_dim]))
hidden = tf.matmul(X, W1) + b1

In [6]:
# Hidden -> Output: affine map from the hidden layer back to
# vocabulary-sized rows.
W2 = tf.Variable(tf.truncated_normal([embedding_dim, vocab_size]))
b2 = tf.Variable(tf.zeros([vocab_size]))
y_hat = tf.matmul(hidden, W2) + b2    # raw (unnormalised) scores
y_norm = tf.nn.softmax(y_hat)         # scores normalised with softmax
y_pred = tf.argmax(y_norm, axis=1)    # index of the highest score per row

### Loss and training

In [7]:
# Softmax cross-entropy computed from the raw scores `y_hat` against the
# labels `y`, then averaged into a single scalar loss.
xentropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=y,
    logits=y_hat,
    name='xentropy',
)
loss = tf.reduce_mean(xentropy, name='loss')

# Adam optimizer; `train_step` is the op executed once per batch.
optimizer = tf.train.AdamOptimizer(learning_rate)
train_step = optimizer.minimize(loss)

### Accuracy

In [8]:
# A row counts as correct when the predicted index equals the label index.
correct = tf.equal(y_pred, y_true)
# Mean of the 0/1 correctness flags, i.e. the fraction of correct rows.
accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

## Initializing global variables and `tf.Session()`

In [9]:
# Create the initializer op for all variables defined so far, open a
# session, and run the initializer in it.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

## TensorBoard

In [10]:
# Files & directories
save_path = 'models/'  # Trained model; passed to `saver` as the checkpoint prefix
tensorboard_dir = 'tensorboard/'  # summary protobuf
logdir = os.path.join(tensorboard_dir, 'log')   # summary's file writer

# Summaries: track the scalar loss and accuracy, merged into a single op
tf.summary.scalar('Loss', loss)
tf.summary.scalar('Accuracy', accuracy)
merged = tf.summary.merge_all()

# saver & writer
saver = tf.train.Saver()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)

# Restore or create.
# Restore only when a checkpoint was actually written at `save_path`: the
# saver's default (V2) format stores an index file at `<prefix>.index`.
# Counting directory entries would also fire on unrelated files and make
# `saver.restore` raise.
checkpoint_index = save_path + '.index'
if not tf.gfile.Exists(save_path):
    tf.gfile.MakeDirs(save_path)
elif tf.gfile.Exists(checkpoint_index):
    saver.restore(sess=sess, save_path=save_path)

## Training

In [11]:
# Run `epochs` training iterations (one batch each), periodically writing a
# TensorBoard summary and a checkpoint, and report progress on one line.
train_start = dt.datetime.now()
for i in range(epochs):
    # Train
    X_batch, y_batch = w2v.next_batch(batch_size=batch_size, shuffle=True)
    sess.run(train_step, feed_dict={X: X_batch, y: y_batch})
    # Save at interval, and always on the last iteration so the updates made
    # after the final interval boundary are not lost.
    if i % save_interval == 0 or i == epochs - 1:
        # Tensorboard summary (evaluated on the batch just trained on)
        summary = sess.run(merged, feed_dict={X: X_batch, y: y_batch})
        writer.add_summary(summary=summary, global_step=i)
        # Save model
        saver.save(sess=sess, save_path=save_path)
    sys.stdout.write('\r{:,} of {:,} epochs\tTime taken: {}'.format(i+1, 
                                                                    epochs, 
                                                                    dt.datetime.now() - train_start))
    # Make the carriage-return progress line appear immediately.
    sys.stdout.flush()

# Push any summary events still buffered by the writer to disk.
writer.flush()

10,000 of 10,000 epochs	Time taken: 0:08:51.632382

## Word vectors

In [12]:
# `W1 + b1` gives, for every vocabulary index, the hidden-layer output that
# a one-hot input at that index would produce; evaluate it once to obtain
# one vector per vocabulary entry.
embedding_op = W1 + b1
word_vectors = sess.run(embedding_op)
print(word_vectors.shape)

(17588, 50)


In [13]:
# Map the word to its row index via the dataset's `word2id` table, then
# display that row of the vector matrix.
you_id = w2v.word2id['you']
word_vectors[you_id]

array([ 0.34771562, -1.48245394, -0.32719854,  1.17947757,  0.70884764,
        0.4491581 ,  1.41901803,  1.24369752,  0.69479764, -1.02207363,
       -0.60569561, -0.67520511, -0.55108023,  0.92150593,  1.1518805 ,
       -0.0336429 , -1.62853384, -1.21681619, -0.70099843, -1.46538532,
        0.02241147, -0.22869167,  0.18646038, -0.20468491,  0.1531736 ,
        0.22502221,  0.52767438,  0.46577272, -0.07892847,  0.10553825,
        1.09543025,  0.32182369,  1.05936515,  0.47451001,  0.32495922,
        1.48756492, -0.24477828, -0.99777889,  1.9758451 ,  0.72210848,
       -0.30577436, -0.10062671, -0.38749754,  0.14672384, -0.63682145,
       -0.84977138, -0.10278291,  0.09870076,  1.57016897, -0.48524785], dtype=float32)