Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow GANEstimator get_hooks_fn to be set manually #14723

Merged
merged 2 commits into from
Dec 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(self,
discriminator_loss_fn=None,
generator_optimizer=None,
discriminator_optimizer=None,
get_hooks_fn=None,
add_summaries=None,
use_loss_summaries=True,
config=None):
Expand Down Expand Up @@ -132,6 +133,10 @@ def __init__(self,
work.
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
list of hooks. These hooks are run on the generator and discriminator
train ops, and can be used to implement the GAN training scheme.
Defaults to `train.get_sequential_train_hooks()`.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
Expand All @@ -146,7 +151,7 @@ def _model_fn(features, labels, mode):
else discriminator_optimizer)
gan_head = head_lib.gan_head(
generator_loss_fn, discriminator_loss_fn, gopt, dopt,
use_loss_summaries)
use_loss_summaries, get_hooks_fn=get_hooks_fn)
return _gan_model_fn(
features, labels, mode, generator_fn, discriminator_fn, gan_head,
add_summaries)
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/contrib/gan/python/estimator/python/head_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GANHead(head._Head): # pylint: disable=protected-access
def __init__(self, generator_loss_fn, discriminator_loss_fn,
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
get_hooks_fn=None,
name=None):
"""`Head` for GAN training.

Expand All @@ -86,10 +86,12 @@ def __init__(self, generator_loss_fn, discriminator_loss_fn,
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
of hooks.
of hooks. Defaults to `train.get_sequential_train_hooks()`
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
if get_hooks_fn is None:
get_hooks_fn = tfgan_train.get_sequential_train_hooks()
# TODO(joelshor): Validate inputs.

if use_loss_summaries in [True, False]:
Expand Down