Log if dataset is already available in our GCS bucket (redux)
PiperOrigin-RevId: 241424282
Ryan Sepassi authored and Copybara-Service committed Apr 1, 2019
1 parent 592fdef commit a81a42f
Showing 6 changed files with 175 additions and 55 deletions.
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/constants.py
@@ -21,6 +21,8 @@
# Directory where to store processed datasets.
DATA_DIR = os.path.join("~", "tensorflow_datasets")

GCS_DATA_DIR = "gs://tfds-data/datasets"

# Suffix of files / directories which aren't finished downloading / extracting.
INCOMPLETE_SUFFIX = ".incomplete"
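The new GCS_DATA_DIR constant names the public bucket where prepared datasets are mirrored. As a minimal sketch (the "mnist" builder name is only an example, and the dataset must actually be published in the bucket), a user could point a builder at it and skip local preparation:

import tensorflow_datasets as tfds
from tensorflow_datasets.core import constants

# Read directly from the public bucket instead of preparing the data locally.
builder = tfds.builder("mnist", data_dir=constants.GCS_DATA_DIR)
train_ds = builder.as_dataset(split="train")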

55 changes: 49 additions & 6 deletions tensorflow_datasets/core/dataset_builder.py
@@ -39,6 +39,7 @@
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import units
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.utils import gcs_utils

import termcolor

@@ -47,6 +48,13 @@
REUSE_CACHE_IF_EXISTS = download.GenerateMode.REUSE_CACHE_IF_EXISTS
REUSE_DATASET_IF_EXISTS = download.GenerateMode.REUSE_DATASET_IF_EXISTS

GCS_HOSTED_MSG = """\
Dataset {name} is hosted on GCS. You can skip download_and_prepare by setting
data_dir=gs://tfds-data/datasets. If you find
that read performance is slow, copy the data locally with gsutil:
gsutil -m cp -R {gcs_path} {local_data_dir_no_version}
"""


class BuilderConfig(object):
"""Base class for `DatasetBuilder` data configuration.
@@ -194,11 +202,14 @@ def download_and_prepare(self, download_dir=None, download_config=None):

download_config = download_config or download.DownloadConfig()
data_exists = tf.io.gfile.exists(self._data_dir)
if (data_exists and
download_config.download_mode == REUSE_DATASET_IF_EXISTS):
if data_exists and download_config.download_mode == REUSE_DATASET_IF_EXISTS:
logging.info("Reusing dataset %s (%s)", self.name, self._data_dir)
return

# Data may exist on GCS
if not data_exists:
self._maybe_log_gcs_data_dir()

dl_manager = self._make_download_manager(
download_dir=download_dir,
download_config=download_config)
@@ -243,6 +254,7 @@ def download_and_prepare(self, download_dir=None, download_config=None):
self.info.size_in_bytes = dl_manager.downloaded_size
# Write DatasetInfo to disk, even if we haven't computed the statistics.
self.info.write_to_directory(self._data_dir)
self._log_download_done()

@api_utils.disallow_positional_args
def as_dataset(self,
@@ -340,14 +352,37 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
return tf.data.experimental.get_single_element(dataset)
return dataset

def _build_data_dir(self):
"""Return the data directory for the current version."""
builder_data_dir = os.path.join(self._data_dir_root, self.name)
def _maybe_log_gcs_data_dir(self):
"""If data is on GCS, set _data_dir to GCS path."""
if not gcs_utils.is_gcs_dataset_accessible(self.info.full_name):
return

gcs_path = os.path.join(constants.GCS_DATA_DIR, self.info.full_name)
msg = GCS_HOSTED_MSG.format(
name=self.name,
gcs_path=gcs_path,
local_data_dir_no_version=os.path.split(self._data_dir)[0])
logging.info(msg)

def _relative_data_dir(self, with_version=True):
"""Relative path of this dataset in data_dir."""
builder_data_dir = self.name
builder_config = self._builder_config
if builder_config:
builder_data_dir = os.path.join(builder_data_dir, builder_config.name)
if not with_version:
return builder_data_dir

version = self._version
version_data_dir = os.path.join(builder_data_dir, str(version))
return version_data_dir
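A quick sketch of the layout these helpers produce, assuming a hypothetical builder "mnist" at version 1.0.0 with no builder config and a local data_dir_root:

import os

relative = os.path.join("mnist", "1.0.0")  # what _relative_data_dir() would return
version_data_dir = os.path.join(
    os.path.expanduser("~/tensorflow_datasets"), relative)  # what _build_data_dir() then returns
# With a builder config named "small", the relative path would instead be "mnist/small/1.0.0".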

def _build_data_dir(self):
"""Return the data directory for the current version."""
builder_data_dir = os.path.join(
self._data_dir_root, self._relative_data_dir(with_version=False))
version_data_dir = os.path.join(
self._data_dir_root, self._relative_data_dir(with_version=True))

def _other_versions_on_disk():
"""Returns previous versions on disk."""
@@ -380,13 +415,21 @@ def _other_versions_on_disk():

return version_data_dir

def _log_download_done(self):
msg = ("Dataset {name} downloaded and prepared to {data_dir}. "
"Subsequent calls will reuse this data.").format(
name=self.name,
data_dir=self._data_dir,
)
termcolor.cprint(msg, attrs=["bold"])

def _log_download_bytes(self):
# Print is intentional: we want this to always go to stdout so user has
# information needed to cancel download/preparation if needed.
# This comes right before the progress bar.
size_text = units.size_str(self.info.size_in_bytes)
termcolor.cprint(
"Downloading / extracting dataset %s (%s) to %s..." %
"Downloading and preparing dataset %s (%s) to %s..." %
(self.name, size_text, self._data_dir),
attrs=["bold"])
# TODO(tfds): Should try to estimate the available free disk space (if
46 changes: 3 additions & 43 deletions tensorflow_datasets/core/dataset_info.py
@@ -39,33 +39,26 @@
import posixpath
import pprint
import tempfile
from xml.etree import ElementTree

from absl import logging
import numpy as np
import requests
import tensorflow as tf

from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core import dataset_utils
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.proto import dataset_info_pb2
from tensorflow_datasets.core.utils import gcs_utils
from google.protobuf import json_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2


# Name of the file to output the DatasetInfo protobuf object.
DATASET_INFO_FILENAME = "dataset_info.json"

LICENSE_FILENAME = "LICENSE"

# GCS
GCS_URL = "http://storage.googleapis.com/tfds-data"
GCS_BUCKET = "gs://tfds-data"
GCS_DATASET_INFO_PATH = "dataset_info"

INFO_STR = """tfds.core.DatasetInfo(
name='{name}',
version={version},
@@ -377,14 +370,13 @@ def initialize_from_bucket(self):
# In order to support Colab, we use the HTTP GCS API to access the metadata
# files. They are copied locally and then loaded.
tmp_dir = tempfile.mkdtemp("tfds")
data_files = gcs_dataset_files(self.full_name)
data_files = gcs_utils.gcs_dataset_info_files(self.full_name)
if not data_files:
logging.info("No GCS info files found for %s", self.full_name)
return
logging.info("Loading info from GCS for %s", self.full_name)
for fname in data_files:
out_fname = os.path.join(tmp_dir, os.path.basename(fname))
download_gcs_file(fname, out_fname)
gcs_utils.download_gcs_file(fname, out_fname)
self.read_from_directory(tmp_dir)
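A rough usage sketch of initialize_from_bucket (network access required; the "mnist" builder name is only an example):

import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
# Pull the published dataset_info files from GCS instead of generating them locally.
builder.info.initialize_from_bucket()
print(builder.info.splits)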

def __str__(self):
@@ -571,35 +563,3 @@ def read_from_json(json_filename):
parsed_proto = json_format.Parse(dataset_info_json_str,
dataset_info_pb2.DatasetInfo())
return parsed_proto


def download_gcs_file(path, out_fname=None):
"""Download a file from GCS, optionally to a file."""
url = "/".join([GCS_URL, path])
stream = bool(out_fname)
resp = requests.get(url, stream=stream)
if not resp.ok:
raise ValueError("GCS bucket inaccessible")
if out_fname:
with tf.io.gfile.GFile(out_fname, "wb") as f:
for chunk in resp.iter_content(1024):
f.write(chunk)
else:
return resp.content


@utils.memoize()
def gcs_files():
top_level_xml_str = download_gcs_file("")
xml_root = ElementTree.fromstring(top_level_xml_str)
filenames = [el[0].text for el in xml_root if el.tag.endswith("Contents")]
return filenames


def gcs_dataset_files(dataset_dir):
"""Return paths to GCS files in the given dataset directory."""
prefix = posixpath.join(GCS_DATASET_INFO_PATH, dataset_dir, "")
# Filter for this dataset
filenames = [el for el in gcs_files()
if el.startswith(prefix) and len(el) > len(prefix)]
return filenames
68 changes: 68 additions & 0 deletions tensorflow_datasets/core/utils/gcs_utils.py
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for accessing TFDS GCS buckets."""

import posixpath
from xml.etree import ElementTree

import requests
import tensorflow as tf

from tensorflow_datasets.core import utils

GCS_URL = "http://storage.googleapis.com"

# for dataset_info/
GCS_BUCKET = posixpath.join(GCS_URL, "tfds-data")
GCS_DATASET_INFO_DIR = "dataset_info"
GCS_DATASETS_DIR = "datasets"


def download_gcs_file(path, out_fname=None):
"""Download a file from GCS, optionally to a file."""
url = posixpath.join(GCS_BUCKET, path)
stream = bool(out_fname)
resp = requests.get(url, stream=stream)
if not resp.ok:
raise ValueError("GCS bucket inaccessible")
if out_fname:
with tf.io.gfile.GFile(out_fname, "wb") as f:
for chunk in resp.iter_content(1024):
f.write(chunk)
else:
return resp.content


@utils.memoize()
def gcs_files():
top_level_xml_str = download_gcs_file("")
xml_root = ElementTree.fromstring(top_level_xml_str)
filenames = [el[0].text for el in xml_root if el.tag.endswith("Contents")]
return filenames


def gcs_dataset_info_files(dataset_dir):
"""Return paths to GCS files in the given dataset directory."""
prefix = posixpath.join(GCS_DATASET_INFO_DIR, dataset_dir, "")
# Filter for this dataset
filenames = [el for el in gcs_files()
if el.startswith(prefix) and len(el) > len(prefix)]
return filenames


def is_gcs_dataset_accessible(dataset_dir):
info_file = posixpath.join(GCS_DATASETS_DIR, dataset_dir, "dataset_info.json")
return info_file in gcs_files()
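A minimal sketch of how the helpers in this new module fit together (network access required; the "mnist/1.0.0" dataset path is illustrative):

import os
from tensorflow_datasets.core.utils import gcs_utils

if gcs_utils.is_gcs_dataset_accessible("mnist/1.0.0"):
  for fname in gcs_utils.gcs_dataset_info_files("mnist/1.0.0"):
    out_path = os.path.join("/tmp", os.path.basename(fname))
    gcs_utils.download_gcs_file(fname, out_fname=out_path)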
42 changes: 42 additions & 0 deletions tensorflow_datasets/core/utils/gcs_utils_test.py
@@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GCS utils test."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow_datasets as tfds
from tensorflow_datasets import testing
from tensorflow_datasets.core.utils import gcs_utils


class GcsUtilsTest(testing.TestCase):

def test_is_dataset_accessible(self):
# Re-enable GCS access. TestCase disables it.
with self.gcs_access():
self.assertTrue(gcs_utils.is_gcs_dataset_accessible("mnist/1.0.0"))

def test_mnist(self):
with self.gcs_access():
mnist = tfds.image.MNIST(data_dir="gs://tfds-data/datasets")
example = next(tfds.as_numpy(mnist.as_dataset(split="train").take(1)))
_ = example["image"], example["label"]


if __name__ == "__main__":
testing.test_main()
17 changes: 11 additions & 6 deletions tensorflow_datasets/testing/test_case.py
@@ -27,13 +27,15 @@
from absl.testing import absltest
import six
import tensorflow as tf
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core.utils import gcs_utils



GCS_ACCESS_FNS = {
"original": dataset_info.gcs_dataset_files,
"dummy": lambda _: []
"original_info": gcs_utils.gcs_dataset_info_files,
"dummy_info": lambda _: [],
"original_datasets": gcs_utils.is_gcs_dataset_accessible,
"dummy_datasets": lambda _: False,
}


@@ -49,15 +51,18 @@ def setUpClass(cls):
super(TestCase, cls).setUpClass()
cls.test_data = os.path.join(os.path.dirname(__file__), "test_data")
# Test must not communicate with GCS.
dataset_info.gcs_dataset_files = GCS_ACCESS_FNS["dummy"]
gcs_utils.gcs_dataset_info_files = GCS_ACCESS_FNS["dummy_info"]
gcs_utils.is_gcs_dataset_accessible = GCS_ACCESS_FNS["dummy_datasets"]

@contextlib.contextmanager
def gcs_access(self):
# Restore GCS access
dataset_info.gcs_dataset_files = GCS_ACCESS_FNS["original"]
gcs_utils.gcs_dataset_info_files = GCS_ACCESS_FNS["original_info"]
gcs_utils.is_gcs_dataset_accessible = GCS_ACCESS_FNS["original_datasets"]
yield
# Revert access
dataset_info.gcs_dataset_files = GCS_ACCESS_FNS["dummy"]
gcs_utils.gcs_dataset_info_files = GCS_ACCESS_FNS["dummy_info"]
gcs_utils.is_gcs_dataset_accessible = GCS_ACCESS_FNS["dummy_datasets"]

def setUp(self):
super(TestCase, self).setUp()
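The same stubbing pattern, sketched outside of TestCase, for any test that must never reach the network (the lambdas mirror the dummy entries in GCS_ACCESS_FNS above):

from tensorflow_datasets.core.utils import gcs_utils

gcs_utils.gcs_dataset_info_files = lambda _: []        # pretend nothing is published
gcs_utils.is_gcs_dataset_accessible = lambda _: False  # pretend the bucket is unreachable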
