Skip to content

Commit

Permalink
Enable Model training/eval from generator in eager execution. Fixes t…
Browse files Browse the repository at this point in the history
…ensorflow#18287

PiperOrigin-RevId: 196171525
  • Loading branch information
fchollet authored and tensorflower-gardener committed May 10, 2018
1 parent 7a49337 commit 1b67ccb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
18 changes: 18 additions & 0 deletions tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,24 @@ def test_model_methods_with_eager_tensors_single_io(self):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)

def test_generator_methods(self):
model = keras.Sequential()
model.add(keras.layers.Dense(4, input_shape=(3,)))
optimizer = RMSPropOptimizer(learning_rate=0.001)
model.compile(optimizer, 'mse', metrics=['mae'])

x = np.random.random((10, 3))
y = np.random.random((10, 4))

def iterator():
while 1:
yield x, y

model.fit_generator(iterator(), steps_per_epoch=3, epochs=1)
model.evaluate_generator(iterator(), steps=3)
out = model.predict_generator(iterator(), steps=3)
self.assertEqual(out.shape, (30, 4))


class LossWeightingTest(test.TestCase):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ def fit_generator(model,
epoch = initial_epoch

do_validation = bool(validation_data)
model._make_train_function()
if do_validation:
model._make_test_function()

is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
Expand Down Expand Up @@ -252,8 +249,6 @@ def evaluate_generator(model,
workers=1,
use_multiprocessing=False):
"""See docstring for `Model.evaluate_generator`."""
model._make_test_function()

steps_done = 0
wait_time = 0.01
all_outs = []
Expand Down Expand Up @@ -346,8 +341,6 @@ def predict_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.predict_generator`."""
model._make_predict_function()

steps_done = 0
wait_time = 0.01
all_outs = []
Expand Down

0 comments on commit 1b67ccb

Please sign in to comment.