Add logging calls to NCF (#5576)
* first pass at __getattr__ abuse logger

* first pass at adding tags to NCF

* minor formatting updates

* fix tag name

* convert metrics to python floats

* getting closer...

* direct mlperf logs to a file

* small tweaks and add stitching

* update tags

* fix tag and add a sudo call

* tweak format of run.sh

* delint

* use distribution strategies for evaluation

* address PR comments

* delint and fix test

* adjust flag validation for xla

* add prefix to distinguish log stitching

* fix index bug

* fix clear cache for root user

* dockerize cache drop

* TIL some regex magic
Taylor Robie committed Oct 24, 2018
1 parent f2b702a commit 780f526
Showing 7 changed files with 421 additions and 55 deletions.
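
The three diffs reproduced below all call into official/utils/logs/mlperf_helper, which is not among the files shown in this excerpt. As a rough orientation only, here is a minimal, hypothetical sketch of the logging pattern those call sites assume: a TAGS object whose __getattr__ returns the attribute name as the tag string (the "__getattr__ abuse" from the commit message), a module-level LOGGER context manager, and an ncf_print that emits only while logging is enabled. The bodies below are guesses for illustration, not the module's actual code; set_ncf_root, stitch_ncf, and clear_system_caches are omitted.

# Hypothetical sketch only -- NOT the contents of mlperf_helper.py, which is
# not shown in this excerpt. It exists to make the ncf_print / TAGS / LOGGER
# call sites in the diffs below easier to read.
from __future__ import print_function

import json


class _TagContainer(object):
  """Returns any attribute name as its own tag string."""

  def __getattr__(self, name):
    return name


TAGS = _TagContainer()


class _MlperfLogger(object):
  """Context manager that switches tag emission on for the duration of a run."""

  enabled = False

  def __call__(self, enable=False):
    self.enabled = bool(enable)
    return self

  def __enter__(self):
    return self

  def __exit__(self, *unused_exc_info):
    self.enabled = False


LOGGER = _MlperfLogger()


def ncf_print(key, value=None, deferred=False):
  """Emit one MLPerf-style line when logging is enabled; otherwise do nothing."""
  if not LOGGER.enabled:
    return
  suffix = "" if value is None else ": {}".format(json.dumps(value))
  # The real module prefixes these lines so they can be stitched later
  # ("add prefix to distinguish log stitching"); the format here is illustrative.
  print("ncf {}{} (deferred={})".format(key, suffix, deferred))
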
95 changes: 67 additions & 28 deletions official/recommendation/data_async_generation.py
@@ -44,6 +44,7 @@
from official.recommendation import constants as rconst
from official.recommendation import stat_utils
from official.recommendation import popen_helper
from official.utils.logs import mlperf_helper


_log_file = None
@@ -222,11 +223,25 @@ def _construct_records(
"""
st = timeit.default_timer()

if not is_training:
if is_training:
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_TRAIN_NEG_GEN)
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.INPUT_HP_NUM_NEG, value=num_neg)

# set inside _process_shard()
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT, value=True)

else:
# Later logic assumes that all items for a given user are in the same batch.
assert not batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
assert num_neg == rconst.NUM_EVAL_NEGATIVES

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_EVAL_NEG_GEN)

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_HP_NUM_USERS,
value=num_positives)

assert epochs_per_cycle == 1 or is_training
num_workers = min([num_workers, len(training_shards) * epochs_per_cycle])

@@ -259,6 +274,7 @@ def _construct_records(
# user is grouped within a batch.
if is_training:
index_destinations = np.random.permutation(num_pts)
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
else:
index_destinations = np.arange(num_pts)

@@ -276,6 +292,8 @@ def _construct_records(
if num_padding:
# In order to have a full batch, randomly include points from earlier in
# the batch.

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
pad_sample_indices = np.random.randint(
low=0, high=num_pts, size=(num_padding,))
dest = np.arange(start=start_ind, stop=start_ind + num_padding)
@@ -287,10 +305,20 @@ def _construct_records(
# to interpret and discard the zero padded entries.
data[0][num_pts:] = 0

  # Check that no points were overlooked.
assert not np.sum(data[0] == -1)

if is_training:
# The number of points is slightly larger than num_pts due to padding.
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_SIZE,
value=int(data[0].shape[0]))
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_BATCH_SIZE,
value=batch_size)
else:
# num_pts is logged instead of int(data[0].shape[0]), because the size
# of the data vector includes zero pads which are ignored.
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_SIZE, value=num_pts)

batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
current_file_id = -1
current_batch_id = -1
@@ -316,6 +344,7 @@ def _construct_records(
if is_training:
# Empirically it is observed that placing the batch with repeated values at
# the start rather than the end improves convergence.
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
batches_by_file[0][0], batches_by_file[-1][-1] = \
batches_by_file[-1][-1], batches_by_file[0][0]

@@ -389,17 +418,6 @@ def _generation_loop(num_workers, # type: int
# type: (...) -> None
"""Primary run loop for data file generation."""

log_msg("Signaling that I am alive.")
with tf.gfile.Open(cache_paths.subproc_alive, "w") as f:
f.write("Generation subproc has started.")

@atexit.register
def remove_alive_file():
try:
tf.gfile.Remove(cache_paths.subproc_alive)
except tf.errors.NotFoundError:
return # Main thread has already deleted the entire cache dir.

log_msg("Entering generation loop.")
tf.gfile.MakeDirs(cache_paths.train_epoch_dir)
tf.gfile.MakeDirs(cache_paths.eval_data_subdir)
@@ -484,10 +502,29 @@ def _parse_flagfile(flagfile):
tf.gfile.Remove(flagfile_temp)


def write_alive_file(cache_paths):
"""Write file to signal that generation process started correctly."""
log_msg("Signaling that I am alive.")
with tf.gfile.Open(cache_paths.subproc_alive, "w") as f:
f.write("Generation subproc has started.")

@atexit.register
def remove_alive_file():
try:
tf.gfile.Remove(cache_paths.subproc_alive)
except tf.errors.NotFoundError:
return # Main thread has already deleted the entire cache dir.


def main(_):
# Note: The async process must execute the following two steps in the
# following order BEFORE doing anything else:
# 1) Write the alive file
# 2) Wait for the flagfile to be written.
global _log_file
cache_paths = rconst.Paths(
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
write_alive_file(cache_paths=cache_paths)

flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
_parse_flagfile(flagfile)
@@ -513,20 +550,22 @@ def main(_):
if flags.FLAGS.seed is not None:
np.random.seed(flags.FLAGS.seed)

_generation_loop(
num_workers=flags.FLAGS.num_workers,
cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
num_neg=flags.FLAGS.num_neg,
num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
num_users=flags.FLAGS.num_users,
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
train_batch_size=flags.FLAGS.train_batch_size,
eval_batch_size=flags.FLAGS.eval_batch_size,
deterministic=flags.FLAGS.seed is not None,
match_mlperf=flags.FLAGS.ml_perf,
)
with mlperf_helper.LOGGER(enable=flags.FLAGS.ml_perf):
mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
_generation_loop(
num_workers=flags.FLAGS.num_workers,
cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
num_neg=flags.FLAGS.num_neg,
num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
num_users=flags.FLAGS.num_users,
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
train_batch_size=flags.FLAGS.train_batch_size,
eval_batch_size=flags.FLAGS.eval_batch_size,
deterministic=flags.FLAGS.seed is not None,
match_mlperf=flags.FLAGS.ml_perf,
)
except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.")
except:
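
One effect of moving the alive-file write into write_alive_file and calling it at the very top of main is a strict startup handshake with the parent process in data_preprocessing.py (next file): the subprocess announces itself before it waits for flags, and the parent waits for that announcement before it publishes the flagfile. A toy illustration of that ordering follows; wait_for, subprocess_side, and parent_side are invented names, not code from this commit.

# Toy sketch of the startup handshake enforced by the reordering above.
import os
import time


def wait_for(path, timeout_sec=60.0, poll_sec=0.1):
  """Block until `path` exists, returning False if the timeout elapses."""
  deadline = time.time() + timeout_sec
  while time.time() < deadline:
    if os.path.exists(path):
      return True
    time.sleep(poll_sec)
  return False


def subprocess_side(alive_path, flagfile_path):
  # Step 1: signal liveness before doing anything else.
  with open(alive_path, "w") as f:
    f.write("Generation subproc has started.")
  # Step 2: block until the parent publishes the flags, then read them.
  assert wait_for(flagfile_path), "Parent never wrote the flagfile."
  with open(flagfile_path) as f:
    return f.read()


def parent_side(alive_path, flagfile_path, flag_text):
  # Only publish flags once the child has proven it is alive; otherwise a
  # crashed child would leave the parent waiting on data forever.
  assert wait_for(alive_path), "Generation subprocess did not start."
  with open(flagfile_path, "w") as f:
    f.write(flag_text)
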
59 changes: 40 additions & 19 deletions official/recommendation/data_preprocessing.py
@@ -46,6 +46,7 @@
from official.recommendation import constants as rconst
from official.recommendation import stat_utils
from official.recommendation import popen_helper
from official.utils.logs import mlperf_helper


DATASET_TO_NUM_USERS_AND_ITEMS = {
@@ -134,6 +135,9 @@ def _filter_index_sort(raw_rating_path, match_mlperf):
original_users = df[movielens.USER_COLUMN].unique()
original_items = df[movielens.ITEM_COLUMN].unique()

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.PREPROC_HP_MIN_RATINGS,
value=rconst.MIN_NUM_RATINGS)

# Map the ids of user and item to 0 based index for following processing
tf.logging.info("Generating user_map and item_map...")
user_map = {user: index for index, user in enumerate(original_users)}
@@ -147,6 +151,12 @@ def _filter_index_sort(raw_rating_path, match_mlperf):
num_users = len(original_users)
num_items = len(original_items)

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.PREPROC_HP_NUM_EVAL,
value=num_users * (1 + rconst.NUM_EVAL_NEGATIVES))
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT,
value=match_mlperf)

assert num_users <= np.iinfo(np.int32).max
assert num_items <= np.iinfo(np.uint16).max
assert df[movielens.USER_COLUMN].max() == num_users - 1
@@ -397,6 +407,27 @@ def _shutdown(proc):
tf.logging.error("Data generation subprocess could not be killed.")



def write_flagfile(flags_, ncf_dataset):
"""Write flagfile to begin async data generation."""
if ncf_dataset.deterministic:
flags_["seed"] = stat_utils.random_int32()

# We write to a temp file then atomically rename it to the final file,
# because writing directly to the final file can cause the data generation
# async process to read a partially written JSON file.
flagfile_temp = os.path.join(ncf_dataset.cache_paths.cache_root,
rconst.FLAGFILE_TEMP)
tf.logging.info("Preparing flagfile for async data generation in {} ..."
.format(flagfile_temp))
with tf.gfile.Open(flagfile_temp, "w") as f:
for k, v in six.iteritems(flags_):
f.write("--{}={}\n".format(k, v))
flagfile = os.path.join(ncf_dataset.cache_paths.cache_root, rconst.FLAGFILE)
tf.gfile.Rename(flagfile_temp, flagfile)
tf.logging.info(
"Wrote flagfile for async data generation in {}.".format(flagfile))

def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
num_data_readers=None, num_neg=4, epochs_per_cycle=1,
match_mlperf=False, deterministic=False,
@@ -405,6 +436,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
"""Preprocess data and start negative generation subprocess."""

tf.logging.info("Beginning data preprocessing.")
tf.gfile.MakeDirs(data_dir)
ncf_dataset = construct_cache(dataset=dataset, data_dir=data_dir,
num_data_readers=num_data_readers,
match_mlperf=match_mlperf,
@@ -431,25 +463,6 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
"ml_perf": match_mlperf,
}

if ncf_dataset.deterministic:
flags_["seed"] = stat_utils.random_int32()
tf.gfile.MakeDirs(data_dir)
# We write to a temp file then atomically rename it to the final file,
# because writing directly to the final file can cause the data generation
# async process to read a partially written JSON file.
flagfile_temp = os.path.join(ncf_dataset.cache_paths.cache_root,
rconst.FLAGFILE_TEMP)
tf.logging.info("Preparing flagfile for async data generation in {} ..."
.format(flagfile_temp))
with tf.gfile.Open(flagfile_temp, "w") as f:
for k, v in six.iteritems(flags_):
f.write("--{}={}\n".format(k, v))
flagfile = os.path.join(ncf_dataset.cache_paths.cache_root, rconst.FLAGFILE)
tf.gfile.Rename(flagfile_temp, flagfile)
tf.logging.info(
"Wrote flagfile for async data generation in {}."
.format(flagfile))

if use_subprocess:
tf.logging.info("Creating training file subprocess.")
subproc_env = os.environ.copy()
Expand Down Expand Up @@ -489,6 +502,14 @@ def cleanup():
raise ValueError("Generation subprocess did not start correctly. Data will "
"not be available; exiting to avoid waiting forever.")

# We start the async process and wait for it to signal that it is alive. It
# will then enter a loop waiting for the flagfile to be written. Once we see
# that the async process has signaled that it is alive, we clear the system
# caches and begin the run.
mlperf_helper.clear_system_caches()
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.RUN_START)
write_flagfile(flags_, ncf_dataset)

return ncf_dataset, cleanup


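
The temp-file-plus-rename in write_flagfile above is the standard trick for publishing a file atomically: the async reader either sees no flagfile or a complete one, never a half-written one. Below is a self-contained sketch of the same pattern with plain os calls; atomic_write and the example path are illustrative, not code from this commit.

# Illustrative only; `atomic_write` is an invented helper, not NCF code.
import os


def atomic_write(path, text):
  """Write `text` so that readers never observe a partially written `path`."""
  tmp_path = path + ".incomplete"
  with open(tmp_path, "w") as f:
    f.write(text)
  # os.rename is atomic on POSIX filesystems when source and destination are
  # on the same filesystem, which holds here because both sit in one directory.
  os.rename(tmp_path, path)


if __name__ == "__main__":
  atomic_write("/tmp/example_flagfile", "--num_workers=4\n--num_neg=4\n")
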
45 changes: 40 additions & 5 deletions official/recommendation/ncf_main.py
@@ -45,6 +45,7 @@
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.logs import mlperf_helper
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers

@@ -104,7 +105,8 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
return train_estimator, eval_estimator

distribution = distribution_utils.get_distribution_strategy(num_gpus=num_gpus)
run_config = tf.estimator.RunConfig(train_distribute=distribution)
run_config = tf.estimator.RunConfig(train_distribute=distribution,
eval_distribute=distribution)
params["eval_batch_size"] = eval_batch_size
model_fn = neumf_model.neumf_model_fn
if params["use_xla_for_gpu"]:
@@ -116,8 +118,10 @@


def main(_):
with logger.benchmark_context(FLAGS):
with logger.benchmark_context(FLAGS), mlperf_helper.LOGGER(FLAGS.ml_perf):
mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
run_ncf(FLAGS)
mlperf_helper.stitch_ncf()


def run_ncf(_):
@@ -218,10 +222,16 @@ def run_ncf(_):

pred_input_fn = None
total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
target_reached = False
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_LOOP)
for cycle_index in range(total_training_cycle):
assert FLAGS.epochs_between_evals == 1 or not mlperf_helper.LOGGER.enabled
tf.logging.info("Starting a training cycle: {}/{}".format(
cycle_index + 1, total_training_cycle))

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_EPOCH,
value=cycle_index)

# Train the model
train_input_fn, train_record_dir, batch_count = \
data_preprocessing.make_input_fn(
@@ -248,27 +258,49 @@ def run_ncf(_):
"producing incorrect shards.".format(
eval_batch_count, num_eval_steps))

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_START,
value=cycle_index)
eval_results = eval_estimator.evaluate(pred_input_fn, steps=num_eval_steps)
hr = float(eval_results[rconst.HR_KEY])
ndcg = float(eval_results[rconst.NDCG_KEY])
tf.logging.info("Evaluation complete.")

mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.EVAL_TARGET,
value={"epoch": cycle_index, "value": FLAGS.hr_threshold})
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_ACCURACY,
value={"epoch": cycle_index, "value": hr})
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.EVAL_HP_NUM_NEG,
value={"epoch": cycle_index, "value": rconst.NUM_EVAL_NEGATIVES})

# Logged by the async process during record creation.
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_HP_NUM_USERS,
deferred=True)

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_STOP, value=cycle_index)

# Benchmark the evaluation results
benchmark_logger.log_evaluation_result(eval_results)
# Log the HR and NDCG results.
hr = eval_results[rconst.HR_KEY]
ndcg = eval_results[rconst.NDCG_KEY]
tf.logging.info(
"Iteration {}: HR = {:.4f}, NDCG = {:.4f}".format(
cycle_index + 1, hr, ndcg))

# If some evaluation threshold is met
if model_helpers.past_stop_threshold(FLAGS.hr_threshold, hr):
target_reached = True
break

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.RUN_STOP,
value={"success": target_reached})
cleanup_fn() # Cleanup data construction artifacts and subprocess.

# Clear the session explicitly to avoid session delete error
tf.keras.backend.clear_session()

mlperf_helper.ncf_print(key=mlperf_helper.TAGS.RUN_FINAL)


def define_ncf_flags():
"""Add flags for running ncf_main."""
@@ -419,7 +451,10 @@ def eval_size_check(eval_batch_size):
"If True, use XLA for the model function. Only works when using a "
"GPU. On TPUs, XLA is always used"))

flags.mark_flags_as_mutual_exclusive(["use_xla_for_gpu", "tpu"])
xla_message = "--use_xla_for_gpu is incompatible with --tpu"
@flags.multi_flags_validator(["use_xla_for_gpu", "tpu"], message=xla_message)
def xla_validator(flag_dict):
return not flag_dict["use_xla_for_gpu"] or not flag_dict["tpu"]


if __name__ == "__main__":
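
The last hunk replaces flags.mark_flags_as_mutual_exclusive with a hand-written multi_flags_validator. A plausible reason, not stated in the diff, is that mutual exclusion treats any non-None value as "set", so a boolean flag that defaults to False would conflict with --tpu even when XLA is disabled; the validator only rejects the combination when both values are truthy. The standalone absl example below shows the same validator style; only the flag names come from the diff, and the script wrapper is illustrative.

# check_flags.py -- standalone illustration of the validator style used above.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_boolean("use_xla_for_gpu", False, "Compile the model with XLA.")
flags.DEFINE_string("tpu", None, "Name of the TPU to use, if any.")

_XLA_MESSAGE = "--use_xla_for_gpu is incompatible with --tpu"


@flags.multi_flags_validator(["use_xla_for_gpu", "tpu"], message=_XLA_MESSAGE)
def _xla_tpu_validator(flag_dict):
  # Allowed unless both are truthy; flag parsing fails with _XLA_MESSAGE otherwise.
  return not flag_dict["use_xla_for_gpu"] or not flag_dict["tpu"]


def main(_):
  print("use_xla_for_gpu={} tpu={}".format(FLAGS.use_xla_for_gpu, FLAGS.tpu))


if __name__ == "__main__":
  # e.g. `python check_flags.py --use_xla_for_gpu --tpu=my-tpu` fails validation.
  app.run(main)
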
