TensorBoard Modifications for Word2Vec Example (#14908)
* TensorBoard modifications are added to visualize the loss graph and embeddings.

* Update word2vec_basic.py

* A flag is added for the log directory.
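
A minimal usage sketch for the new flag (the directory name is illustrative, and TensorBoard is assumed to be installed alongside TensorFlow):

    # Train and write summaries, the checkpoint and the embedding metadata
    # to ./word2vec_log instead of the default ./log next to the script.
    python word2vec_basic.py --log_dir word2vec_log

    # Point TensorBoard at the same directory to view the loss scalar,
    # the graph and the embedding projector.
    tensorboard --logdir word2vec_log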
ozgyal authored and drpngx committed Jan 20, 2018
1 parent 9563c09 commit 8126571
Showing 1 changed file with 85 additions and 19 deletions.
104 changes: 85 additions & 19 deletions tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -21,6 +21,8 @@
import collections
import math
import os
import sys
import argparse
import random
from tempfile import gettempdir
import zipfile
@@ -30,6 +32,24 @@
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.contrib.tensorboard.plugins import projector

# Give a folder path as an argument with '--log_dir' to save
# TensorBoard summaries. The default is a 'log' folder in the current directory.
current_path = os.path.dirname(os.path.realpath(sys.argv[0]))

parser = argparse.ArgumentParser()
parser.add_argument(
    '--log_dir',
    type=str,
    default=os.path.join(current_path, 'log'),
    help='The log directory for TensorBoard summaries.')
FLAGS, unparsed = parser.parse_known_args()

# Create the directory for TensorBoard files if it does not already exist.
if not os.path.exists(FLAGS.log_dir):
  os.makedirs(FLAGS.log_dir)

# Step 1: Download the data.
url = 'http://mattmahoney.net/dc/'

@@ -156,38 +176,47 @@ def generate_batch(batch_size, num_skips, skip_window):
with graph.as_default():

  # Input data.
  train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
  train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
  valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
  with tf.name_scope('inputs'):
    train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
    train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

  # Ops and variables pinned to the CPU because of missing GPU implementation
  with tf.device('/cpu:0'):
    # Look up embeddings for inputs.
    embeddings = tf.Variable(
        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    embed = tf.nn.embedding_lookup(embeddings, train_inputs)
    with tf.name_scope('embeddings'):
      embeddings = tf.Variable(
          tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
      embed = tf.nn.embedding_lookup(embeddings, train_inputs)

    # Construct the variables for the NCE loss
    nce_weights = tf.Variable(
        tf.truncated_normal([vocabulary_size, embedding_size],
                            stddev=1.0 / math.sqrt(embedding_size)))
    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
    with tf.name_scope('weights'):
      nce_weights = tf.Variable(
          tf.truncated_normal([vocabulary_size, embedding_size],
                              stddev=1.0 / math.sqrt(embedding_size)))
    with tf.name_scope('biases'):
      nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

  # Compute the average NCE loss for the batch.
  # tf.nn.nce_loss automatically draws a new sample of the negative labels each
  # time we evaluate the loss.
  # Explanation of the meaning of NCE loss:
  # http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
  loss = tf.reduce_mean(
      tf.nn.nce_loss(weights=nce_weights,
                     biases=nce_biases,
                     labels=train_labels,
                     inputs=embed,
                     num_sampled=num_sampled,
                     num_classes=vocabulary_size))
  with tf.name_scope('loss'):
    loss = tf.reduce_mean(
        tf.nn.nce_loss(weights=nce_weights,
                       biases=nce_biases,
                       labels=train_labels,
                       inputs=embed,
                       num_sampled=num_sampled,
                       num_classes=vocabulary_size))

  # Add the loss value as a scalar to the summary.
  tf.summary.scalar('loss', loss)

  # Construct the SGD optimizer using a learning rate of 1.0.
  optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
  with tf.name_scope('optimizer'):
    optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

  # Compute the cosine similarity between minibatch examples and all embeddings.
  norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
@@ -197,13 +226,22 @@ def generate_batch(batch_size, num_skips, skip_window):
  similarity = tf.matmul(
      valid_embeddings, normalized_embeddings, transpose_b=True)

  # Merge all summaries.
  merged = tf.summary.merge_all()

  # Add variable initializer.
  init = tf.global_variables_initializer()

  # Create a saver.
  saver = tf.train.Saver()

# Step 5: Begin training.
num_steps = 100001

with tf.Session(graph=graph) as session:
  # Open a writer to write summaries.
  writer = tf.summary.FileWriter(FLAGS.log_dir, session.graph)

  # We must initialize all variables before we use them.
  init.run()
  print('Initialized')
@@ -214,10 +252,21 @@ def generate_batch(batch_size, num_skips, skip_window):
        batch_size, num_skips, skip_window)
    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}

    # Define metadata variable.
    run_metadata = tf.RunMetadata()

    # We perform one update step by evaluating the optimizer op (including it
    # in the list of returned values for session.run()).
    _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
    # Also evaluate the merged op to get all summaries from the returned
    # "summary" variable. Feed the metadata variable to the session so the
    # graph can be visualized in TensorBoard.
    _, summary, loss_val = session.run(
        [optimizer, merged, loss],
        feed_dict=feed_dict,
        run_metadata=run_metadata)
    average_loss += loss_val

    # Add the returned summaries to the writer at each step.
    writer.add_summary(summary, step)
    # Add metadata to visualize the graph for the last run.
    if step == (num_steps - 1):
      writer.add_run_metadata(run_metadata, 'step%d' % step)

    if step % 2000 == 0:
      if step > 0:
@@ -240,6 +289,23 @@ def generate_batch(batch_size, num_skips, skip_window):
        print(log_str)
  final_embeddings = normalized_embeddings.eval()

  # Write corresponding labels for the embeddings.
  with open(FLAGS.log_dir + '/metadata.tsv', 'w') as f:
    for i in xrange(vocabulary_size):
      f.write(reverse_dictionary[i] + '\n')

  # Save the model for checkpoints.
  saver.save(session, os.path.join(FLAGS.log_dir, 'model.ckpt'))

  # Create a configuration for visualizing embeddings with the labels in
  # TensorBoard.
  config = projector.ProjectorConfig()
  embedding_conf = config.embeddings.add()
  embedding_conf.tensor_name = embeddings.name
  embedding_conf.metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv')
  projector.visualize_embeddings(writer, config)

  writer.close()
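  # After a full run, FLAGS.log_dir is expected to contain roughly the
  # following files (every name except events.* comes from the code above;
  # the events file is produced by the FileWriter itself):
  #   checkpoint, model.ckpt.*   - variables saved by saver.save()
  #   metadata.tsv               - one vocabulary word per line
  #   projector_config.pbtxt     - written by projector.visualize_embeddings()
  #   events.out.tfevents.*      - the loss summaries and the graph
  # Pointing TensorBoard at this directory shows the loss curve, the graph with
  # the name scopes added above, and the embedding projector.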

# Step 6: Visualize the embeddings.


