Removed deprecated TensorFlow dataset APIs #680
Looks good! Thanks for putting this together so quickly. Just a handful of comments.
examples/tensorflow_mnist.py
Outdated

import tensorflow as tf
import keras
Can we use tf.keras instead? Ideally, we'd like this example to work with standalone TensorFlow without requiring the user to install standalone keras as well.
I initially did that, but the problem is that tf.keras is not available in TF 1.1.x. I can check the TF version in the code and import either keras or tf.keras accordingly.
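I would do a version check against v1.4.0, similar to what we do for horovod.tensorflow.keras:

    from distutils.version import LooseVersion

    # tf.keras is available from TF 1.4.0; older releases ship it under contrib.
    if LooseVersion(tf.__version__) >= LooseVersion("1.4.0"):
        from tensorflow import keras
    else:
        from tensorflow.contrib import keras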
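For readers who want to try the version gate without TensorFlow installed, here is a minimal sketch of the same decision in plain Python. The helpers `parse_version` and `keras_import_path` are hypothetical names introduced only for illustration; they are not part of Horovod or TensorFlow. The hand-rolled parser is an assumption made so the snippet does not depend on distutils, which newer Python releases no longer ship.

```python
def parse_version(version):
    """Turn a version string such as '1.4.0' into a comparable tuple of ints.

    Hypothetical helper: keeps only the leading digits of each dotted piece,
    so a suffix like '0-rc1' is read as 0.
    """
    parts = []
    for piece in version.split('.'):
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad short versions ('1.4') so they compare like '1.4.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def keras_import_path(tf_version):
    """Return the module path to import Keras from for a given TF version."""
    if parse_version(tf_version) >= (1, 4, 0):
        return 'tensorflow.keras'
    return 'tensorflow.contrib.keras'


# Possible usage, assuming TensorFlow is installed:
#   import importlib
#   import tensorflow as tf
#   keras = importlib.import_module(keras_import_path(tf.__version__))
```

Returning the module path as a string keeps the decision testable on its own; the actual import can then be done with `importlib.import_module`.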
examples/tensorflow_mnist.py
Outdated
@@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/usr/bin/env python
# !/usr/bin/env python
Nit: remove leading space
fixed.
examples/tensorflow_mnist.py
Outdated
import horovod.tensorflow as hvd
import numpy as np

layers = tf.contrib.layers
Is it possible to use tf.layers instead of tf.contrib.layers as well?
Changed this as well.
import os
import shutil

import keras
As above: tf.keras
same as above.
examples/tensorflow_mnist.py
Outdated
x_test, y_test) = keras.datasets.mnist.load_data(
    'MNIST-data-%d' % hvd.rank())

x_train = np.reshape(x_train, (-1, 784)) / 255
Is this safe in Python 2, where integer division is the default? Maybe we can divide by 255.0 to be safe?
fixed.
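To make the division concern concrete, here is a small pure-Python sketch (no NumPy, so it is only an illustration of the arithmetic, not of the example script itself). It uses `//` to reproduce what Python 2 does for `int / int`, and contrasts that with dividing by the float literal `255.0`. The sample pixel values are made up for the illustration.

```python
# Raw MNIST pixel intensities are integers in the range 0..255.
pixels = [0, 128, 255]

# Python 2 semantics for int / int, spelled with // so that it behaves the
# same on every interpreter: the fractional part is discarded.
floored = [p // 255 for p in pixels]

# Dividing by a float literal forces true division on Python 2 and 3 alike,
# which is what normalizing to the range [0, 1] needs.
normalized = [p / 255.0 for p in pixels]
```

With floor division every pixel below 255 collapses to 0, so the normalized features would lose almost all information; dividing by `255.0` avoids that regardless of interpreter version.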
examples/tensorflow_mnist.py
Outdated
# When running tests, if dataset is previously downloaded, it may cause
# the tests to fail. In this case, we need to remove the dataset cache
# folder first and download the dataset again.
cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
Hmmm, is it necessary to remove the dataset every time? This seems expensive.
When running with MPI and more than one process, every process tries to download the data, which can cause a race condition. Multiple processes might simultaneously check whether the dataset folder exists and then try to create the folder and download the data. However, only one of them succeeds, and the rest fail with an OS-level IOError.
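The check-then-create race described above can be sketched with the standard library alone. This is one possible way to tolerate it, not what the PR does: `ensure_dir` and `rank_cache_dir` are hypothetical helpers introduced for illustration, and the `'MNIST-data-%d'` naming is borrowed from the example script. The sketch covers only directory creation; it says nothing about how Keras handles the download itself.

```python
import errno
import os


def ensure_dir(path):
    """Create `path` if needed, tolerating another process creating it first."""
    try:
        # No prior os.path.exists() check: checking first and creating second
        # is exactly the window in which another process can sneak in.
        os.makedirs(path)
    except OSError as ex:
        # EEXIST means some other process won the race; any other error
        # (permissions, a file in the way, ...) is still reported.
        if ex.errno != errno.EEXIST or not os.path.isdir(path):
            raise
    return path


def rank_cache_dir(base_dir, rank):
    """Return a per-rank dataset directory so ranks do not share one folder."""
    return ensure_dir(os.path.join(base_dir, 'MNIST-data-%d' % rank))
```

Attempting the creation and handling the "already exists" error avoids deleting and re-downloading the dataset on every run, which was the cost raised in the review.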
    (train_data, train_labels), (eval_data, eval_labels) = \
        keras.datasets.mnist.load_data('MNIST-data-%d' % hvd.rank())
except OSError as ex:
    # When running tests, if dataset is previously downloaded, it may cause
Same comment as above: is it necessary to remove all?
    'MNIST-data-%d' % hvd.rank())

# reshape the features and normalize them between 0 and 1
train_data = np.reshape(train_data, (-1, 784)) / 255
Same comment as above: division by int vs float.
fixed.
Thanks for the PR! Few comments inline.
examples/tensorflow_mnist.py
Outdated
def main(_):
    # Horovod: initialize Horovod.
    hvd.init()

    # Download and load MNIST dataset.
    mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())
    dataset_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
    'MNIST-data-%d' % hvd.rank())
Fix formatting
fixed.
@@ -14,10 +14,20 @@
# ==============================================================================
#!/usr/bin/env python

import os
Move to the next import group
examples/tensorflow_mnist.py
Outdated

from distutils.version import LooseVersion

if LooseVersion(tf.__version__) >= LooseVersion("1.4.0"):
I actually think it's OK to assume a reasonably fresh TF version in examples, so just using tf.keras should be OK (and cleaner).
And if this breaks integration tests for TF 1.1.0, we can fix them by adding code in .travis.yml to conditionally patch the import. This way users don't have to see this code.
LGTM!
Some of the examples were using deprecated TensorFlow APIs to load the data; these changes fix the problem.
Issue #673