Eager execution breaks fit_generator in tf.keras #18287

Closed
batzner opened this issue Apr 6, 2018 · 10 comments

@batzner

batzner commented Apr 6, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.12.6
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 1.7.0
  • Python version: 3.6.3
  • Numpy version: 1.14.2
  • Bazel version (if compiling from source): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: not installed
  • GPU model and memory: CPU only
  • Exact command to reproduce: Run code below

Describe the problem

Calling tf.enable_eager_execution() causes Keras's model.fit_generator to raise "RuntimeError: You must compile your model before using it.", even though the model has already been compiled. Calling model.fit, on the other hand, works.

Source code / logs

Minimum reproducible test case:

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # It works without this line

x, y = np.random.randn(100, 10), np.random.randn(100, 4)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(4, input_dim=10)])
model.compile(tf.train.RMSPropOptimizer(0.001), 'mse')

model.fit(x, y)  # Fitting without a generator works in eager mode

class Iterator:
    def __next__(self):
        return x, y

model.fit_generator(Iterator(), steps_per_epoch=10)

Log:

Epoch 1/1
100/100 [==============================] - 0s 445us/step - loss: 2.1153
Traceback (most recent call last):
  File "tmp.py", line 16, in <module>
    model.fit_generator(Iterator(), steps_per_epoch=10)
  File "/Users/kilian/.pyenv/versions/3.6.3/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/sequential.py", line 860, in fit_generator
    initial_epoch=initial_epoch)
  File "/Users/kilian/.pyenv/versions/3.6.3/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/training.py", line 1603, in fit_generator
    initial_epoch=initial_epoch)
  File "/Users/kilian/.pyenv/versions/3.6.3/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/training_generator.py", line 52, in fit_generator
    model._make_train_function()
  File "/Users/kilian/.pyenv/versions/3.6.3/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/training.py", line 578, in _make_train_function
    raise RuntimeError('You must compile your model before using it.')
RuntimeError: You must compile your model before using it.
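
As a side note, the Iterator class in the reproduction above defines only __next__. That is not shown to be the cause of the error here (the same code runs without eager execution), but a fuller batch iterator could be sketched as follows. BatchIterator and its parameters are hypothetical names for illustration and are not part of TensorFlow or Keras.

```python
class BatchIterator:
    """Endless iterator over fixed-size (x, y) batches.

    Hypothetical helper for illustration; not a TensorFlow/Keras class.
    """

    def __init__(self, xs, ys, batch_size):
        if len(xs) != len(ys):
            raise ValueError("xs and ys must have the same length")
        if len(xs) == 0:
            raise ValueError("xs and ys must not be empty")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self._xs = xs
        self._ys = ys
        self._batch_size = batch_size
        self._start = 0

    def __iter__(self):
        # An iterator returns itself, so it also works in for-loops.
        return self

    def __next__(self):
        # Wrap around at the end so StopIteration is never raised.
        if self._start >= len(self._xs):
            self._start = 0
        end = self._start + self._batch_size
        batch = (self._xs[self._start:end], self._ys[self._start:end])
        self._start = end
        return batch
```

Because the slicing works on any sequence type, the same sketch should also accept NumPy arrays such as the x and y in the reproduction; the last batch of each pass may be smaller than batch_size.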
tensorflowbutler added the "stat:awaiting response" label (Status - Awaiting response from author) on Apr 6, 2018
@tensorflowbutler
Member

Thank you for your post. We noticed you have not filled out the following field in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Bazel version

@batzner
Author

batzner commented Apr 6, 2018

Updated

@facaiy
Member

facaiy commented Apr 7, 2018

@batzner Could you test this with tf.keras instead of the contrib module?

@facaiy
Member

facaiy commented Apr 7, 2018

@fchollet Sounds like a problem. I checked the Keras code, and generators do not seem to be supported in eager mode, right? Has anyone worked on it?
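
For context, the loop that fit_generator conceptually performs (draw steps_per_epoch batches per epoch and train on each) can be sketched by hand. This is only a rough, framework-free illustration. run_epochs and train_step are hypothetical names, with train_step standing in for something like model.train_on_batch. Whether such a call works under eager execution in TF 1.7 is not confirmed in this thread.

```python
def run_epochs(batches, train_step, steps_per_epoch, epochs=1):
    """Draw steps_per_epoch batches per epoch and pass each to train_step.

    batches: any object supporting next(), yielding (x, y) pairs.
    train_step: hypothetical callable (x, y) -> numeric loss.
    Returns a list with the mean loss of each epoch.
    """
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    history = []
    for _ in range(epochs):
        total_loss = 0.0
        for _ in range(steps_per_epoch):
            # next() is called directly, so an object that defines only
            # __next__ (like the Iterator in the report) is accepted.
            x, y = next(batches)
            total_loss += float(train_step(x, y))
        history.append(total_loss / steps_per_epoch)
    return history
```

With the reproduction's objects, a call might look like run_epochs(Iterator(), model.train_on_batch, steps_per_epoch=10), under the unverified assumption that train_on_batch returns a scalar loss in that setup.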

@batzner
Author

batzner commented Apr 7, 2018

I updated the code to use tf.keras instead of contrib. The output is the same as before.

asimshankar assigned fchollet and unassigned michaelisard on Apr 7, 2018
tensorflowbutler removed the "stat:awaiting response" label (Status - Awaiting response from author) on Apr 7, 2018
@tensorflowbutler
Member

Nagging Assignee @fchollet: It has been 15 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@AakashKumarNain
Member

AakashKumarNain commented May 10, 2018

I faced a similar issue. When using fit_generator() in eager mode, keras throws a NotImplemented error.

@fchollet
Member

Thanks for the bug report. I have fixed the issue and the fix will soon be available in the TF nightly release.

@DollarAkshay

I just noticed that calling tf.enable_eager_execution() drastically changes how training behaves with fit_generator(). Here are results from two runs on a pre-trained VGG19 model.

Without Eager Execution
image

  • The loss reduces drastically after one epoch
  • The training accuracy goes to 67% after epoch 1 and 93% after epoch 2
  • The validation accuracy jumps to 80% after one epoch

With Eager Execution
image

  • The loss is stuck at 15 and reduces by a tiny amount every epoch
  • The training accuracy is at 4% after epoch 1 and at 5% after epoch 2
  • The validation accuracy is also at 4-5%

This is quite a strong difference. I am guessing that this is not expected?

Link to Full Notebook : here

Dataset used for Training : HackerEarth ML Challenge
