Make full model before calling set_model on callback #21244

Merged: 5 commits, Sep 28, 2018
40 changes: 40 additions & 0 deletions tensorflow/python/keras/callbacks_test.py
@@ -30,6 +30,7 @@

from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
@@ -1222,6 +1223,45 @@ def test_RemoteMonitorWithJsonPayload(self):
        callbacks=cbks,
        epochs=1)

  def test_fit_generator_with_callback(self):

    class TestCallback(keras.callbacks.Callback):
      def set_model(self, model):
        # Check the model operations for the optimizer operations that
        # the _make_train_function adds under a named scope for the
        # optimizer. This ensures the full model is populated before the
        # set_model callback is called.
        optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__
        graph_def = ops.get_default_graph().as_graph_def()
        for node in graph_def.node:
          if node.name.startswith(optimizer_name_scope):
            return
        raise RuntimeError('The optimizer operations are not present in the '
                           'model graph when the Callback.set_model function '
                           'is called')
    np.random.seed(1337)

    def generator():
      x = np.random.randn(10, 100).astype(np.float32)
      y = np.random.randn(10, 10).astype(np.float32)
      while True:
        yield x, y

    with self.cached_session():
      model = testing_utils.get_small_sequential_mlp(
          num_hidden=10, num_classes=10, input_dim=100)
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])
      model.fit_generator(
          generator(),
          steps_per_epoch=2,
          epochs=1,
          validation_data=generator(),
          validation_steps=2,
          callbacks=[TestCallback()],
          verbose=0)

if __name__ == '__main__':
test.main()
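
For context, a minimal sketch of the kind of callback this change protects: one that inspects the model's graph inside set_model, much as the TensorBoard callback does when writing the graph. The class name GraphSnapshotCallback is hypothetical and not part of this PR; it reuses the keras and ops imports from the test file above.

class GraphSnapshotCallback(keras.callbacks.Callback):
  """Illustrative only: snapshots the graph when the model is attached."""

  def set_model(self, model):
    super(GraphSnapshotCallback, self).set_model(model)
    # Before this PR, fit_generator attached callbacks before calling
    # _make_train_function, so the optimizer ops under
    # 'training/<OptimizerName>' were missing from a snapshot taken here.
    self.graph_def = ops.get_default_graph().as_graph_def()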
11 changes: 11 additions & 0 deletions tensorflow/python/keras/engine/training_generator.py
@@ -21,6 +21,7 @@

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
@@ -48,6 +49,10 @@ def fit_generator(model,
  epoch = initial_epoch

  do_validation = bool(validation_data)
  if not context.executing_eagerly():
Contributor: As you mentioned, there are callbacks that depend on the model being set. Can you add unit tests for this change so that the intention here is clearer?

Contributor (Author): Yes, I can write a test for this.

Contributor (Author): Done.

    model._make_train_function()
    if do_validation:
      model._make_test_function()

  is_sequence = isinstance(generator, Sequence)
  if not is_sequence and use_multiprocessing and workers > 1:
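
In short, the hunk above makes graph mode build the train (and, when validating, test) functions before any callback sees the model. A condensed sketch of the resulting ordering; the callback loop below is an assumption about the shape of the surrounding fit_generator code, not a quote of it:

if not context.executing_eagerly():
  model._make_train_function()   # adds the optimizer ops to the graph now
  if do_validation:
    model._make_test_function()

for callback in callbacks:       # assumed shape of the callback wiring
  callback.set_model(model)      # callbacks now see the fully built graph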
@@ -233,6 +238,9 @@ def evaluate_generator(model,
                       use_multiprocessing=False,
                       verbose=0):
  """See docstring for `Model.evaluate_generator`."""
  if not context.executing_eagerly():
    model._make_test_function()

  if hasattr(model, 'metrics'):
    for m in model.stateful_metric_functions:
      m.reset_states()
@@ -342,6 +350,9 @@ def predict_generator(model,
                      use_multiprocessing=False,
                      verbose=0):
  """See docstring for `Model.predict_generator`."""
  if not context.executing_eagerly():
    model._make_test_function()

  steps_done = 0
  wait_time = 0.01
  all_outs = []
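
A minimal usage sketch exercising the two paths above, reusing the model and generator from the new test; none of this appears in the PR itself:

model = testing_utils.get_small_sequential_mlp(
    num_hidden=10, num_classes=10, input_dim=100)
model.compile(loss='categorical_crossentropy', optimizer='sgd')
# With this change, the test function is built here, up front, rather
# than lazily on the first batch drawn from the generator.
model.evaluate_generator(generator(), steps=2)
model.predict_generator(generator(), steps=2)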