Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 38 additions & 44 deletions official/recommendation/data_async_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import datetime
import gc
import functools
import logging
import multiprocessing
import json
import os
Expand All @@ -40,23 +39,25 @@
import tensorflow as tf

from absl import app as absl_app
from absl import logging as absl_logging
from absl import flags

from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import stat_utils


_log_file = None


def log_msg(msg):
"""Include timestamp info when logging messages to a file."""
if flags.FLAGS.redirect_logs:
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
absl_logging.info("[{}] {}".format(timestamp, msg))
print("[{}] {}".format(timestamp, msg), file=_log_file)
else:
absl_logging.info(msg)
sys.stdout.flush()
sys.stderr.flush()
print(msg, file=_log_file)
if _log_file:
_log_file.flush()


def get_cycle_folder_name(i):
Expand Down Expand Up @@ -395,61 +396,54 @@ def _generation_loop(


def main(_):
global _log_file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yuk. Oh well.

redirect_logs = flags.FLAGS.redirect_logs
cache_paths = rconst.Paths(
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)


log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
log_file = os.path.join(cache_paths.data_dir, log_file_name)
if log_file.startswith("gs://") and redirect_logs:
log_path = os.path.join(cache_paths.data_dir, log_file_name)
if log_path.startswith("gs://") and redirect_logs:
fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
print("Unable to log to {}. Falling back to {}"
.format(log_file, fallback_log_file))
log_file = fallback_log_file
.format(log_path, fallback_log_file))
log_path = fallback_log_file

# This server is generally run in a subprocess.
if redirect_logs:
print("Redirecting stdout and stderr to {}".format(log_file))
log_stream = open(log_file, "wt") # Note: not tf.gfile.Open().
stdout = log_stream
stderr = log_stream
print("Redirecting output of data_async_generation.py process to {}"
.format(log_path))
_log_file = open(log_path, "wt") # Note: not tf.gfile.Open().
try:
if redirect_logs:
absl_logging.get_absl_logger().addHandler(
hdlr=logging.StreamHandler(stream=stdout))
sys.stdout = stdout
sys.stderr = stderr
print("Logs redirected.")
try:
log_msg("sys.argv: {}".format(" ".join(sys.argv)))

if flags.FLAGS.seed is not None:
np.random.seed(flags.FLAGS.seed)

_generation_loop(
num_workers=flags.FLAGS.num_workers,
cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
num_neg=flags.FLAGS.num_neg,
num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
spillover=flags.FLAGS.spillover,
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
train_batch_size=flags.FLAGS.train_batch_size,
eval_batch_size=flags.FLAGS.eval_batch_size,
)
except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.")
except:
traceback.print_exc()
raise
log_msg("sys.argv: {}".format(" ".join(sys.argv)))

if flags.FLAGS.seed is not None:
np.random.seed(flags.FLAGS.seed)

_generation_loop(
num_workers=flags.FLAGS.num_workers,
cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
num_neg=flags.FLAGS.num_neg,
num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
spillover=flags.FLAGS.spillover,
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
train_batch_size=flags.FLAGS.train_batch_size,
eval_batch_size=flags.FLAGS.eval_batch_size,
)
except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.")
except:
traceback.print_exc(file=_log_file)
raise
finally:
log_msg("Shutting down generation subprocess.")
sys.stdout.flush()
sys.stderr.flush()
if redirect_logs:
log_stream.close()
_log_file.close()


def define_flags():
Expand Down
4 changes: 1 addition & 3 deletions official/recommendation/data_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
tf.logging.info(
"Generation subprocess command: {}".format(" ".join(subproc_args)))

proc = subprocess.Popen(args=subproc_args, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=False, env=subproc_env)
proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)

atexit.register(_shutdown, proc=proc)
atexit.register(tf.gfile.DeleteRecursively,
Expand Down