Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions official/utils/testing/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import tempfile


def run_synthetic(main, tmp_root, extra_flags=None):
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
"""Performs a minimal run of a model.

This function is intended to test for syntax errors throughout a model. A
Expand All @@ -37,15 +37,22 @@ def run_synthetic(main, tmp_root, extra_flags=None):
function is "<MODULE>.main(argv)".
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
"""

extra_flags = [] if extra_flags is None else extra_flags

model_dir = tempfile.mkdtemp(dir=tmp_root)

args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_between_evals", "1", "--use_synthetic_data",
"--max_train_steps", "1"] + extra_flags
"--epochs_between_evals", "1"] + extra_flags

if synth:
args.append("--use_synthetic_data")

if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])

try:
main(args)
Expand Down
8 changes: 6 additions & 2 deletions official/wide_deep/wide_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
}


LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}


def build_model_columns():
"""Builds a set of wide and deep feature columns."""
# Continuous columns
Expand Down Expand Up @@ -190,10 +193,11 @@ def train_input_fn():
def eval_input_fn():
return input_fn(test_file, 1, False, flags.batch_size)

loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
train_hooks = hooks_helper.get_train_hooks(
flags.hooks, batch_size=flags.batch_size,
tensors_to_log={'average_loss': 'head/truediv',
'loss': 'head/weighted_loss/Sum'})
tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
'loss': loss_prefix + 'head/weighted_loss/Sum'})

# Train and evaluate the model every `flags.epochs_between_evals` epochs.
for n in range(flags.train_epochs // flags.epochs_between_evals):
Expand Down
33 changes: 33 additions & 0 deletions official/wide_deep/wide_deep_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import tensorflow as tf # pylint: disable=g-bad-import-order

from official.utils.testing import integration
from official.wide_deep import wide_deep

tf.logging.set_verbosity(tf.logging.ERROR)
Expand Down Expand Up @@ -54,6 +55,14 @@ def setUp(self):
with tf.gfile.Open(self.input_csv, 'w') as temp_csv:
temp_csv.write(TEST_INPUT)

with tf.gfile.Open(TEST_CSV, "r") as temp_csv:
test_csv_contents = temp_csv.read()

# Used for end-to-end tests.
for fname in ['adult.data', 'adult.test']:
with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv:
test_csv.write(test_csv_contents)

def test_input_fn(self):
dataset = wide_deep.input_fn(self.input_csv, 1, False, 1)
features, labels = dataset.make_one_shot_iterator().get_next()
Expand Down Expand Up @@ -107,6 +116,30 @@ def input_fn():
def test_wide_deep_estimator_training(self):
self.build_and_test_estimator('wide_deep')

def test_end_to_end_wide(self):
integration.run_synthetic(
main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'wide',
],
synth=False, max_train=None)

def test_end_to_end_deep(self):
integration.run_synthetic(
main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'deep',
],
synth=False, max_train=None)

def test_end_to_end_wide_deep(self):
integration.run_synthetic(
main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'wide_deep',
],
synth=False, max_train=None)


if __name__ == '__main__':
tf.test.main()