# `Word2Vec`

In [15]:
# Vanilla PyLibraries
import os
import sys
import datetime as dt

# Third party Librarires
import numpy as np
import tensorflow as tf

# Custom libraries
from dataset import TextDataset

## initializing `Word2Vec`

In [16]:
data_dir = 'datasets/wiki.valid.raw'
save_file = 'datasets/saved/data.pkl'

w2v = TextDataset(data_dir=data_dir, logging=True)
w2v.create()
# w2v.save(save_file=save_file)
# w2v = w2v.load(save_file=save_file)

Processing 8,224 of 8,224 sentences.	Time taken: 0:00:36.157439

### Hyperparameters

In [17]:
# Model/Network
embedding_dim = 50
learning_rate = 1e-3
vocab_size = w2v.vocab_size

# Training
epochs = 5000
save_interval = 100
batch_size = 25

## Training with a `tensorflow` model

In [18]:
# Model's placeholders
X = tf.placeholder(tf.float32, shape=[None, vocab_size], name='X_palceholder')
y = tf.placeholder(tf.float32, shape=[None, vocab_size], name='y_placeholder')
y_true = tf.argmax(y, axis=1)

### Building the Network

In [19]:
# Input -> Hidden
W1 = tf.Variable(tf.truncated_normal(shape=[vocab_size, embedding_dim]))
b1 = tf.Variable(tf.zeros(shape=[embedding_dim]))
hidden = tf.matmul(X, W1) + b1

In [20]:
# Hidden -> Output
W2 = tf.Variable(tf.truncated_normal(shape=[embedding_dim, vocab_size]))
b2 = tf.Variable(tf.zeros(shape=[vocab_size]))
y_hat = tf.matmul(hidden, W2) + b2
y_norm = tf.nn.softmax(y_hat)
y_pred = tf.argmax(y_norm, axis=1)

### Loss, training and accuracy

In [21]:
xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=y_hat, labels=y, name='xentropy')
loss = tf.reduce_mean(xentropy, name='loss')

# optimizer
global_step = tf.Variable(0, trainable=False, name='global_step')
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_step = optimizer.minimize(loss, global_step=global_step)

### Accuracy

In [22]:
correct = tf.equal(y_pred, y_true)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

## Initializing global variables and  `tf.Session()`

In [23]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

## Tensorboard

In [24]:
# Files & directories
saved_path = 'saved/'

model_dir = os.path.join(saved_path, 'models/')  # Trained model's directory
model_path = os.path.join(model_dir, 'model.ckpt')  # Pre-trained model
tensorboard_dir = os.path.join(saved_path, 'tensorboard/')  # summary protobuf
logdir = os.path.join(tensorboard_dir, 'log')   # summary's file writer

# Summaries
tf.summary.scalar('Loss', loss)
tf.summary.scalar('Accuracy', accuracy)
merged = tf.summary.merge_all()

# saver & writer
saver = tf.train.Saver()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)

### Restore last checkpoint

In [25]:
# Restore or create
if tf.gfile.Exists(model_dir):
    try:
        print('INFO: Attempting to restore last checkpoint')
        last_ckpt = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess=sess, save_path=last_ckpt)
        print(f'INFO: Successfully restored checkpoint - {last_ckpt}')
    except Exception as e:
        sys.stderr.write(f'ERR: Could not restore checkpoint. {e}')
        sys.stderr.flush()
else:
    tf.gfile.MakeDirs(model_dir)
    print(f'INFO: Checkpoint directory created - {model_dir}')

INFO: Attempting to restore last checkpoint
INFO:tensorflow:Restoring parameters from saved/models/model.ckpt-4901


ERR: Could not restore checkpoint. Key global_step_1 not found in checkpoint
	 [[Node: save_1/RestoreV2_29 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_29/tensor_names, save_1/RestoreV2_29/shape_and_slices)]]

Caused by op 'save_1/RestoreV2_29', defined at:
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/pytho

## Training

In [12]:
train_start = dt.datetime.now()
for i in range(epochs):
    # Train
    X_batch, y_batch = w2v.next_batch(batch_size=batch_size, shuffle=True)
    _, i_global = sess.run([train_step, global_step], feed_dict={X: X_batch, y: y_batch})
    # Save at interval
    if i % save_interval == 0:
        # Save model
        saver.save(sess=sess, save_path=model_path, global_step=global_step)
        # Tensorboard summary
        summary = sess.run(merged, feed_dict={X: X_batch, y: y_batch})
        writer.add_summary(summary=summary, global_step=i_global)
    sys.stdout.write(f'\rIter: {i+1:,}\tGlobal steps: {i_global:,}'
                     f'\tTime taken: {dt.datetime.now() - train_start}')
    sys.stdout.flush()

Iter: 5,000	Global steps: 5,000	Time taken: 0:05:22.476653

## Word vectors

In [13]:
word_vectors = sess.run(W1 + b1)
print(word_vectors.shape)

(17588, 50)


In [14]:
word_vectors[w2v.word2id['you']]

array([ -3.83681208e-01,  -1.88688326e+00,  -6.29122406e-02,
         8.34418058e-01,   5.04731297e-01,  -1.50038385e+00,
         7.60304868e-01,   6.93101525e-01,   2.07367361e-01,
        -2.11913705e+00,  -5.94396293e-01,  -6.24938965e-01,
         7.17186093e-01,   1.36648226e+00,   1.52218103e+00,
         5.17773807e-01,  -1.81506714e-03,  -5.21959960e-01,
        -4.70628560e-01,  -1.39591992e-02,  -4.68842804e-01,
         6.02150857e-01,   2.01299816e-01,   3.94365966e-01,
        -1.16026449e+00,   7.18030095e-01,  -6.85940325e-01,
         5.34943700e-01,  -1.05071521e+00,   1.87786102e-01,
        -1.80632210e+00,   1.21711850e+00,   1.83124340e+00,
         1.53925550e+00,   3.72910887e-01,   1.66313446e+00,
         1.51644218e+00,   3.47793996e-01,   1.45923102e+00,
         4.10007238e-01,   1.61073136e+00,   3.67032886e-01,
         1.16520548e+00,  -6.75011277e-01,  -3.13062608e-01,
        -1.03816330e+00,  -1.27810693e+00,   8.23863983e-01,
        -1.40525579e-01,