Skip to content

Commit

Permalink
Start Travis testing against TF 2.0 (#31)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 232046753
  • Loading branch information
Ryan Sepassi authored and Copybara-Service committed Feb 1, 2019
1 parent 78bb6ee commit 740e722
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Expand Up @@ -13,6 +13,7 @@ env:
matrix:
- TF_VERSION="tf-nightly"
- TF_VERSION="1.13.0rc0"
- TF_VERSION="tf2"
install:
- ./oss_scripts/oss_pip_install.sh
script:
Expand Down
15 changes: 14 additions & 1 deletion oss_scripts/oss_tests.sh
Expand Up @@ -14,8 +14,21 @@ function set_status() {
STATUS=$(($last_status || $STATUS))
}

# Certain datasets/tests don't work with TF2
# Skip them here, and link to a GitHub issue that explains why it doesn't work
# and what the plan is to support it.
TF2_IGNORE_TESTS=""
if [[ "$TF_VERSION" == "tf2" ]]
then
# * lsun_test: https://github.com/tensorflow/datasets/issues/34
TF2_IGNORE_TESTS="
tensorflow_datasets/image/lsun_test.py
"
fi
TF2_IGNORE=$(for test in $TF2_IGNORE_TESTS; do echo "--ignore=$test "; done)

# Run Tests
pytest --ignore="tensorflow_datasets/core/test_utils.py"
pytest $TF2_IGNORE --ignore="tensorflow_datasets/core/test_utils.py"
set_status

# Test notebooks
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_datasets/image/open_images.py
Expand Up @@ -178,7 +178,7 @@ def _split_generators(self, dl_manager):
paths = dl_manager.download_and_extract(_URLS)
source_str2int = self.info.features['objects']['source'].str2int
# Set the labels' names:
with tf.gfile.Open(paths['class_descriptions']) as classes_f:
with tf.io.gfile.GFile(paths['class_descriptions']) as classes_f:
classes = [l.split(',')[0]
for l in classes_f.read().split('\n') if l]
logging.info('Number of loaded classes: %s', len(classes))
Expand Down Expand Up @@ -254,7 +254,7 @@ def _load_objects(csv_paths, source_str2int, label_str2int,
csv_paths, csv_positions, prefix)
objects = collections.defaultdict(list)
for i, labels_path in enumerate(csv_paths):
with tf.gfile.Open(labels_path) as csv_f:
with tf.io.gfile.GFile(labels_path) as csv_f:
if csv_positions[i] > 0:
csv_f.seek(csv_positions[i])
else:
Expand All @@ -278,7 +278,7 @@ def _load_bboxes(csv_path, source_str2int, label_str2int,
logging.info('Loading CSVs %s from positions %s with prefix %s',
csv_path, csv_positions, prefix)
boxes = collections.defaultdict(list)
with tf.gfile.Open(csv_path) as csv_f:
with tf.io.gfile.GFile(csv_path) as csv_f:
if csv_positions[0] > 0:
csv_f.seek(csv_positions[0])
else:
Expand Down

0 comments on commit 740e722

Please sign in to comment.