# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DatasetBuilder base class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import itertools
import os
import sys
from absl import logging
import six
import tensorflow as tf
from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core import constants
from tensorflow_datasets.core import dataset_utils
from tensorflow_datasets.core import download
from tensorflow_datasets.core import file_format_adapter
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import registered
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import tfrecords_reader
from tensorflow_datasets.core import tfrecords_writer
from tensorflow_datasets.core import units
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.utils import gcs_utils
import termcolor
FORCE_REDOWNLOAD = download.GenerateMode.FORCE_REDOWNLOAD
REUSE_CACHE_IF_EXISTS = download.GenerateMode.REUSE_CACHE_IF_EXISTS
REUSE_DATASET_IF_EXISTS = download.GenerateMode.REUSE_DATASET_IF_EXISTS
GCS_HOSTED_MSG = """\
Dataset {name} is hosted on GCS. You can skip download_and_prepare by setting
data_dir=gs://tfds-data/datasets. If you find that read performance is slow,
copy the data locally with gsutil:
gsutil -m cp -R {gcs_path} {local_data_dir_no_version}
"""
class BuilderConfig(object):
"""Base class for `DatasetBuilder` data configuration.
DatasetBuilder subclasses with data configuration options should subclass
`BuilderConfig` and add their own properties.
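  A minimal sketch of such a subclass (the class name and the `resolution`
  attribute below are illustrative, not part of the library):
  ```python
  class MyDatasetConfig(tfds.core.BuilderConfig):
    def __init__(self, resolution=None, **kwargs):
      # `resolution` is a hypothetical dataset-specific option; `name`,
      # `version` and `description` are forwarded to `BuilderConfig`.
      super(MyDatasetConfig, self).__init__(**kwargs)
      self.resolution = resolution
  ```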
"""
@api_utils.disallow_positional_args
def __init__(self, name, version=None, supported_versions=None,
description=None):
self._name = name
self._version = version
self._supported_versions = supported_versions or []
self._description = description
@property
def name(self):
return self._name
@property
def version(self):
return self._version
@property
def supported_versions(self):
return self._supported_versions
@property
def description(self):
return self._description
def __repr__(self):
return "<{cls_name} name={name}, version={version}>".format(
cls_name=type(self).__name__,
name=self.name,
version=self.version or "None")
@six.add_metaclass(registered.RegisteredDataset)
class DatasetBuilder(object):
"""Abstract base class for all datasets.
`DatasetBuilder` has 3 key methods:
* `tfds.DatasetBuilder.info`: documents the dataset, including feature
names, types, and shapes, version, splits, citation, etc.
* `tfds.DatasetBuilder.download_and_prepare`: downloads the source data
and writes it to disk.
* `tfds.DatasetBuilder.as_dataset`: builds an input pipeline using
`tf.data.Dataset`s.
**Configuration**: Some `DatasetBuilder`s expose multiple variants of the
dataset by defining a `tfds.core.BuilderConfig` subclass and accepting a
config object (or name) on construction. Configurable datasets expose a
pre-defined set of configurations in `tfds.DatasetBuilder.builder_configs`.
Typical `DatasetBuilder` usage:
```python
mnist_builder = tfds.builder("mnist")
mnist_info = mnist_builder.info
mnist_builder.download_and_prepare()
datasets = mnist_builder.as_dataset()
train_dataset, test_dataset = datasets["train"], datasets["test"]
assert isinstance(train_dataset, tf.data.Dataset)
# And then the rest of your input pipeline
train_dataset = train_dataset.repeat().shuffle(1024).batch(128)
train_dataset = train_dataset.prefetch(2)
features = tf.compat.v1.data.make_one_shot_iterator(train_dataset).get_next()
image, label = features['image'], features['label']
```
"""
# Name of the dataset, filled by metaclass based on class name.
name = None
# Semantic version of the dataset (ex: tfds.core.Version('1.2.0'))
VERSION = None
# List dataset versions which can be loaded using current code.
# Data can only be prepared with canonical VERSION or above.
SUPPORTED_VERSIONS = []
# Named configurations that modify the data generated by download_and_prepare.
BUILDER_CONFIGS = []
# Set to True for datasets that are under active development and should not
# be available through tfds.{load, builder} or documented in overview.md.
IN_DEVELOPMENT = False
@api_utils.disallow_positional_args
def __init__(self, data_dir=None, config=None, version=None):
"""Constructs a DatasetBuilder.
Callers must pass arguments as keyword arguments.
Args:
data_dir: `str`, directory to read/write data. Defaults to
"~/tensorflow_datasets".
config: `tfds.core.BuilderConfig` or `str` name, optional configuration
for the dataset that affects the data generated on disk. Different
`builder_config`s will have their own subdirectories and versions.
      version: `str`. Optional version at which to load the dataset. An error
        is raised if the specified version cannot be satisfied. Eg: '1.2.3',
        '1.2.*'.
The special value "experimental_latest" will use the highest version,
even if not default. This is not recommended unless you know what you
are doing, as the version could be broken.
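    For example, to request any patch release of version 1.2 (a sketch;
    "my_dataset" is a placeholder dataset name):
    ```python
    builder = tfds.builder("my_dataset", version="1.2.*")
    ```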
"""
self._builder_config = self._create_builder_config(config)
# Extract code version (VERSION or config)
if not self._builder_config and not self.VERSION:
raise AssertionError(
"DatasetBuilder {} does not have a defined version. Please add a "
"`VERSION = tfds.core.Version('x.y.z')` to the class.".format(
self.name))
self._version = self._pick_version(version)
self._data_dir_root = os.path.expanduser(data_dir or constants.DATA_DIR)
self._data_dir = self._build_data_dir()
if tf.io.gfile.exists(self._data_dir):
logging.info("Overwrite dataset info from restored data version.")
self.info.read_from_directory(self._data_dir)
else: # Use the code version (do not restore data)
      logging.info("Load pre-computed DatasetInfo (e.g. splits) from bucket.")
self.info.initialize_from_bucket()
def _pick_version(self, requested_version):
"""Returns utils.Version instance, or raise AssertionError."""
if self._builder_config:
canonical_version = self._builder_config.version
supported_versions = self._builder_config.supported_versions
else:
canonical_version = self.VERSION
supported_versions = self.SUPPORTED_VERSIONS
versions = [
utils.Version(v) if isinstance(v, six.string_types) else v
for v in [canonical_version] + supported_versions
]
if requested_version == "experimental_latest":
return max(versions)
for version in versions:
if requested_version is None or version.match(requested_version):
return version
available_versions = [str(v)
for v in [canonical_version] + supported_versions]
msg = "Dataset {} cannot be loaded at version {}, only: {}.".format(
self.name, requested_version, ", ".join(available_versions))
raise AssertionError(msg)
@property
def version(self):
return self._version
@property
def data_dir(self):
return self._data_dir
@utils.memoized_property
def info(self):
"""`tfds.core.DatasetInfo` for this builder."""
    # Ensure .info hasn't been called before versioning is set up.
    # Otherwise, backward compatibility cannot be guaranteed, as some code will
    # depend on the code version instead of the restored data version.
    if not getattr(self, "_version", None):
      # Message for developers creating a new dataset. Will trigger if they are
      # using .info in the constructor before calling super().__init__.
      raise AssertionError(
          "Info should not be called before the version has been defined. "
          "Otherwise, the created .info may not match the info version from "
          "the restored dataset.")
return self._info()
@api_utils.disallow_positional_args
def download_and_prepare(self, download_dir=None, download_config=None):
"""Downloads and prepares dataset for reading.
Args:
download_dir: `str`, directory where downloaded files are stored.
Defaults to "~/tensorflow-datasets/downloads".
download_config: `tfds.download.DownloadConfig`, further configuration for
downloading and preparing dataset.
Raises:
IOError: if there is not enough disk space available.
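    A typical call, reusing already-prepared data when present (a minimal
    sketch; the values shown are illustrative):
    ```python
    builder = tfds.builder("mnist")
    config = tfds.download.DownloadConfig(
        download_mode=tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS)
    builder.download_and_prepare(download_config=config)
    ```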
"""
download_config = download_config or download.DownloadConfig()
data_exists = tf.io.gfile.exists(self._data_dir)
if data_exists and download_config.download_mode == REUSE_DATASET_IF_EXISTS:
logging.info("Reusing dataset %s (%s)", self.name, self._data_dir)
return
# Data may exist on GCS
if not data_exists:
self._maybe_log_gcs_data_dir()
dl_manager = self._make_download_manager(
download_dir=download_dir,
download_config=download_config)
# Currently it's not possible to overwrite the data because it would
# conflict with versioning: If the last version has already been generated,
# it will always be reloaded and data_dir will be set at construction.
if data_exists:
raise ValueError(
"Trying to overwrite an existing dataset {} at {}. A dataset with "
"the same version {} already exists. If the dataset has changed, "
"please update the version number.".format(self.name, self._data_dir,
self.version))
logging.info("Generating dataset %s (%s)", self.name, self._data_dir)
if not utils.has_sufficient_disk_space(
self.info.size_in_bytes, directory=self._data_dir_root):
raise IOError("Not enough disk space. Needed: %s" %
units.size_str(self.info.size_in_bytes))
self._log_download_bytes()
# Create a tmp dir and rename to self._data_dir on successful exit.
with file_format_adapter.incomplete_dir(self._data_dir) as tmp_data_dir:
# Temporarily assign _data_dir to tmp_data_dir to avoid having to forward
# it to every sub function.
with utils.temporary_assignment(self, "_data_dir", tmp_data_dir):
self._download_and_prepare(
dl_manager=dl_manager,
download_config=download_config)
# NOTE: If modifying the lines below to put additional information in
# DatasetInfo, you'll likely also want to update
# DatasetInfo.read_from_directory to possibly restore these attributes
# when reading from package data.
# Update the DatasetInfo metadata by computing statistics from the data.
if (download_config.compute_stats == download.ComputeStatsMode.SKIP or
download_config.compute_stats == download.ComputeStatsMode.AUTO and
bool(self.info.splits.total_num_examples)
):
logging.info(
"Skipping computing stats for mode %s.",
download_config.compute_stats)
        else:  # Mode is forced or stats do not exist yet
logging.info("Computing statistics.")
self.info.compute_dynamic_properties()
self.info.size_in_bytes = dl_manager.downloaded_size
# Write DatasetInfo to disk, even if we haven't computed the statistics.
self.info.write_to_directory(self._data_dir)
self._log_download_done()
@api_utils.disallow_positional_args
def as_dataset(self,
split=None,
batch_size=None,
shuffle_files=False,
decoders=None,
as_supervised=False,
in_memory=None):
# pylint: disable=line-too-long
"""Constructs a `tf.data.Dataset`.
Callers must pass arguments as keyword arguments.
The output types vary depending on the parameters. Examples:
```python
builder = tfds.builder('imdb_reviews')
builder.download_and_prepare()
# Default parameters: Returns the dict of tf.data.Dataset
ds_all_dict = builder.as_dataset()
assert isinstance(ds_all_dict, dict)
print(ds_all_dict.keys()) # ==> ['test', 'train', 'unsupervised']
assert isinstance(ds_all_dict['test'], tf.data.Dataset)
# Each dataset (test, train, unsup.) consists of dictionaries
# {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
# 'text': <tf.Tensor: .. dtype=string, numpy=b"I've watched the movie ..">}
# {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
# 'text': <tf.Tensor: .. dtype=string, numpy=b'If you love Japanese ..'>}
# With as_supervised: tf.data.Dataset only contains (feature, label) tuples
ds_all_supervised = builder.as_dataset(as_supervised=True)
assert isinstance(ds_all_supervised, dict)
print(ds_all_supervised.keys()) # ==> ['test', 'train', 'unsupervised']
assert isinstance(ds_all_supervised['test'], tf.data.Dataset)
# Each dataset (test, train, unsup.) consists of tuples (text, label)
# (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
# (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
# Same as above plus requesting a particular split
ds_test_supervised = builder.as_dataset(as_supervised=True, split='test')
assert isinstance(ds_test_supervised, tf.data.Dataset)
# The dataset consists of tuples (text, label)
# (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
# (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
```
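    The `decoders` argument can customize feature decoding (a sketch;
    `image_builder` is a placeholder for a builder whose features include an
    'image' key, and `tfds.decode.SkipDecoding` keeps the raw encoded bytes):
    ```python
    ds = image_builder.as_dataset(
        split='train',
        decoders={'image': tfds.decode.SkipDecoding()})
    ```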
Args:
split: `tfds.core.SplitBase`, which subset(s) of the data to read. If None
(default), returns all splits in a dict
`<key: tfds.Split, value: tf.data.Dataset>`.
batch_size: `int`, batch size. Note that variable-length features will
be 0-padded if `batch_size` is set. Users that want more custom behavior
should use `batch_size=None` and use the `tf.data` API to construct a
custom pipeline. If `batch_size == -1`, will return feature
dictionaries of the whole dataset with `tf.Tensor`s instead of a
`tf.data.Dataset`.
shuffle_files: `bool`, whether to shuffle the input files. Defaults to
`False`.
      decoders: Nested dict of `Decoder` objects which allow customizing the
        decoding. The structure should match the feature structure, but only
        customized feature keys need to be present. See
        [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
        for more info.
as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
will have a 2-tuple structure `(input, label)` according to
`builder.info.supervised_keys`. If `False`, the default,
the returned `tf.data.Dataset` will have a dictionary with all the
features.
in_memory: `bool`, if `True`, loads the dataset in memory which
increases iteration speeds. Note that if `True` and the dataset has
unknown dimensions, the features will be padded to the maximum
size across the dataset.
Returns:
      `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
        tf.data.Dataset>`.
If `batch_size` is -1, will return feature dictionaries containing
the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
"""
# pylint: enable=line-too-long
logging.info("Constructing tf.data.Dataset for split %s, from %s",
split, self._data_dir)
if not tf.io.gfile.exists(self._data_dir):
raise AssertionError(
("Dataset %s: could not find data in %s. Please make sure to call "
"dataset_builder.download_and_prepare(), or pass download=True to "
"tfds.load() before trying to access the tf.data.Dataset object."
) % (self.name, self._data_dir_root))
# By default, return all splits
if split is None:
split = {s: s for s in self.info.splits}
# Create a dataset for each of the given splits
build_single_dataset = functools.partial(
self._build_single_dataset,
shuffle_files=shuffle_files,
batch_size=batch_size,
decoders=decoders,
as_supervised=as_supervised,
in_memory=in_memory,
)
datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
return datasets
def _build_single_dataset(
self,
split,
shuffle_files,
batch_size,
decoders,
as_supervised,
in_memory):
"""as_dataset for a single split."""
if isinstance(split, six.string_types):
split = splits_lib.Split(split)
wants_full_dataset = batch_size == -1
if wants_full_dataset:
batch_size = self.info.splits.total_num_examples or sys.maxsize
# If the dataset is small, load it in memory
dataset_shape_is_fully_defined = (
dataset_utils.features_shape_is_fully_defined(self.info.features))
in_memory_default = False
# TODO(tfds): Consider default in_memory=True for small datasets with
# fully-defined shape.
# Expose and use the actual data size on disk and rm the manual
# name guards. size_in_bytes is the download size, which is misleading,
# particularly for datasets that use manual_dir as well as some downloads
# (wmt and diabetic_retinopathy_detection).
# in_memory_default = (
# self.info.size_in_bytes and
# self.info.size_in_bytes <= 1e9 and
# not self.name.startswith("wmt") and
# not self.name.startswith("diabetic") and
# dataset_shape_is_fully_defined)
in_memory = in_memory_default if in_memory is None else in_memory
# Build base dataset
if in_memory and not wants_full_dataset:
# TODO(tfds): Enable in_memory without padding features. May be able
# to do by using a requested version of tf.data.Dataset.cache that can
# persist a cache beyond iterator instances.
if not dataset_shape_is_fully_defined:
logging.warning("Called in_memory=True on a dataset that does not "
"have fully defined shapes. Note that features with "
"variable length dimensions will be 0-padded to "
"the maximum length across the dataset.")
full_bs = self.info.splits.total_num_examples or sys.maxsize
# If using in_memory, escape all device contexts so we can load the data
# with a local Session.
with tf.device(None):
dataset = self._as_dataset(
split=split, shuffle_files=shuffle_files, decoders=decoders)
# Use padded_batch so that features with unknown shape are supported.
dataset = dataset.padded_batch(full_bs, dataset.output_shapes)
dataset = tf.data.Dataset.from_tensor_slices(
next(dataset_utils.as_numpy(dataset)))
else:
dataset = self._as_dataset(
split=split, shuffle_files=shuffle_files, decoders=decoders)
if batch_size:
# Use padded_batch so that features with unknown shape are supported.
dataset = dataset.padded_batch(batch_size, dataset.output_shapes)
if as_supervised:
if not self.info.supervised_keys:
raise ValueError(
"as_supervised=True but %s does not support a supervised "
"(input, label) structure." % self.name)
input_f, target_f = self.info.supervised_keys
dataset = dataset.map(lambda fs: (fs[input_f], fs[target_f]),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
# If shuffling, allow pipeline to be non-deterministic
options = tf.data.Options()
options.experimental_deterministic = not shuffle_files
dataset = dataset.with_options(options)
if wants_full_dataset:
return tf.data.experimental.get_single_element(dataset)
return dataset
  def _maybe_log_gcs_data_dir(self):
    """Logs a hint to use the GCS-hosted copy of the dataset, if one exists."""
if not gcs_utils.is_dataset_on_gcs(self.info.full_name):
return
gcs_path = os.path.join(constants.GCS_DATA_DIR, self.info.full_name)
msg = GCS_HOSTED_MSG.format(
name=self.name,
gcs_path=gcs_path,
local_data_dir_no_version=os.path.split(self._data_dir)[0])
logging.info(msg)
def _relative_data_dir(self, with_version=True):
"""Relative path of this dataset in data_dir."""
builder_data_dir = self.name
builder_config = self._builder_config
if builder_config:
builder_data_dir = os.path.join(builder_data_dir, builder_config.name)
if not with_version:
return builder_data_dir
version = self._version
version_data_dir = os.path.join(builder_data_dir, str(version))
return version_data_dir
def _build_data_dir(self):
"""Return the data directory for the current version."""
builder_data_dir = os.path.join(
self._data_dir_root, self._relative_data_dir(with_version=False))
version_data_dir = os.path.join(
self._data_dir_root, self._relative_data_dir(with_version=True))
def _other_versions_on_disk():
"""Returns previous versions on disk."""
if not tf.io.gfile.exists(builder_data_dir):
return []
version_dirnames = []
for dir_name in tf.io.gfile.listdir(builder_data_dir):
try:
version_dirnames.append((utils.Version(dir_name), dir_name))
except ValueError: # Invalid version (ex: incomplete data dir)
pass
version_dirnames.sort(reverse=True)
return version_dirnames
# Check and warn if other versions exist on disk
version_dirs = _other_versions_on_disk()
if version_dirs:
other_version = version_dirs[0][0]
if other_version != self._version:
warn_msg = (
"Found a different version {other_version} of dataset {name} in "
"data_dir {data_dir}. Using currently defined version "
"{cur_version}.".format(
other_version=str(other_version),
name=self.name,
data_dir=self._data_dir_root,
cur_version=str(self._version)))
logging.warning(warn_msg)
return version_data_dir
def _log_download_done(self):
msg = ("Dataset {name} downloaded and prepared to {data_dir}. "
"Subsequent calls will reuse this data.").format(
name=self.name,
data_dir=self._data_dir,
)
termcolor.cprint(msg, attrs=["bold"])
def _log_download_bytes(self):
    # Print is intentional: we want this to always go to stdout so the user
    # has the information needed to cancel the download/preparation.
    # This comes right before the progress bar.
size_text = units.size_str(self.info.size_in_bytes)
termcolor.cprint(
"Downloading and preparing dataset %s (%s) to %s..." %
(self.name, size_text, self._data_dir),
attrs=["bold"])
# TODO(tfds): Should try to estimate the available free disk space (if
# possible) and raise an error if not.
@abc.abstractmethod
def _info(self):
"""Construct the DatasetInfo object. See `DatasetInfo` for details.
Warning: This function is only called once and the result is cached for all
following .info() calls.
Returns:
dataset_info: (DatasetInfo) The dataset information
"""
raise NotImplementedError
@abc.abstractmethod
def _download_and_prepare(self, dl_manager, download_config=None):
"""Downloads and prepares dataset for reading.
    This is the internal implementation to override, called when the user
    calls `download_and_prepare`. It should download all required data and
    generate the pre-processed dataset files.
Args:
dl_manager: (DownloadManager) `DownloadManager` used to download and cache
data.
download_config: `DownloadConfig`, Additional options.
"""
raise NotImplementedError
@abc.abstractmethod
def _as_dataset(self, split, decoders=None, shuffle_files=False):
"""Constructs a `tf.data.Dataset`.
    This is the internal implementation to override, called when the user
    calls `as_dataset`. It should read the pre-processed dataset files and
    generate the `tf.data.Dataset` object.
Args:
split: `tfds.Split` which subset of the data to read.
      decoders: Nested structure of `Decoder` objects to customize the dataset
        decoding.
shuffle_files: `bool`, whether to shuffle the input files. Optional,
defaults to `False`.
Returns:
`tf.data.Dataset`
"""
raise NotImplementedError
def _make_download_manager(self, download_dir, download_config):
download_dir = download_dir or os.path.join(self._data_dir_root,
"downloads")
extract_dir = (download_config.extract_dir or
os.path.join(download_dir, "extracted"))
manual_dir = (download_config.manual_dir or
os.path.join(download_dir, "manual"))
manual_dir = os.path.join(manual_dir, self.name)
return download.DownloadManager(
dataset_name=self.name,
download_dir=download_dir,
extract_dir=extract_dir,
manual_dir=manual_dir,
force_download=(download_config.download_mode == FORCE_REDOWNLOAD),
force_extraction=(download_config.download_mode == FORCE_REDOWNLOAD),
register_checksums=download_config.register_checksums,
)
@property
def builder_config(self):
"""`tfds.core.BuilderConfig` for this builder."""
return self._builder_config
def _create_builder_config(self, builder_config):
"""Create and validate BuilderConfig object."""
if builder_config is None and self.BUILDER_CONFIGS:
builder_config = self.BUILDER_CONFIGS[0]
logging.info("No config specified, defaulting to first: %s/%s", self.name,
builder_config.name)
if not builder_config:
return None
if isinstance(builder_config, six.string_types):
name = builder_config
builder_config = self.builder_configs.get(name)
if builder_config is None:
raise ValueError("BuilderConfig %s not found. Available: %s" %
(name, list(self.builder_configs.keys())))
name = builder_config.name
if not name:
raise ValueError("BuilderConfig must have a name, got %s" % name)
is_custom = name not in self.builder_configs
if is_custom:
logging.warning("Using custom data configuration %s", name)
else:
if builder_config is not self.builder_configs[name]:
raise ValueError(
"Cannot name a custom BuilderConfig the same as an available "
"BuilderConfig. Change the name. Available BuilderConfigs: %s" %
(list(self.builder_configs.keys())))
if not builder_config.version:
raise ValueError("BuilderConfig %s must have a version" % name)
if not builder_config.description:
raise ValueError("BuilderConfig %s must have a description" % name)
return builder_config
@utils.classproperty
@classmethod
@utils.memoize()
def builder_configs(cls):
"""Pre-defined list of configurations for this builder class."""
config_dict = {config.name: config for config in cls.BUILDER_CONFIGS}
if len(config_dict) != len(cls.BUILDER_CONFIGS):
names = [config.name for config in cls.BUILDER_CONFIGS]
raise ValueError(
"Names in BUILDER_CONFIGS must not be duplicated. Got %s" % names)
return config_dict
class FileAdapterBuilder(DatasetBuilder):
"""Base class for datasets with data generation based on file adapter.
`FileFormatAdapter`s are defined in
`tensorflow_datasets.core.file_format_adapter` and specify constraints on the
feature dictionaries yielded by example generators. See the class docstrings.
"""
@utils.memoized_property
def _example_specs(self):
return self.info.features.get_serialized_info()
@utils.memoized_property
def _file_format_adapter(self):
# Load the format adapter (TF-Record,...)
# The file_format_adapter module will eventually be replaced by
# tfrecords_{reader,writer} modules.
file_adapter_cls = file_format_adapter.TFRecordExampleAdapter
return file_adapter_cls(self._example_specs)
@property
def _tfrecords_reader(self):
return tfrecords_reader.Reader(self._data_dir, self._example_specs)
@abc.abstractmethod
def _split_generators(self, dl_manager):
"""Specify feature dictionary generators and dataset splits.
This function returns a list of `SplitGenerator`s defining how to generate
data and what splits to use.
Example:
      return [
tfds.SplitGenerator(
name=tfds.Split.TRAIN,
num_shards=10,
gen_kwargs={'file': 'train_data.zip'},
),
tfds.SplitGenerator(
name=tfds.Split.TEST,
num_shards=5,
gen_kwargs={'file': 'test_data.zip'},
),
]
The above code will first call `_generate_examples(file='train_data.zip')`
to write the train data, then `_generate_examples(file='test_data.zip')` to
write the test data.
Datasets are typically split into different subsets to be used at various
stages of training and evaluation.
Note that for datasets without a `VALIDATION` split, you can use a
fraction of the `TRAIN` data for evaluation as you iterate on your model
so as not to overfit to the `TEST` data.
For downloads and extractions, use the given `download_manager`.
Note that the `DownloadManager` caches downloads, so it is fine to have each
generator attempt to download the source data.
    A good practice is to download all data in this function, and then
    distribute the relevant parts to each split with the `gen_kwargs` argument.
Args:
dl_manager: (DownloadManager) Download manager to download the data
Returns:
`list<SplitGenerator>`.
"""
raise NotImplementedError()
@abc.abstractmethod
def _prepare_split(self, split_generator, **kwargs):
"""Generate the examples and record them on disk.
Args:
split_generator: `SplitGenerator`, Split generator to process
**kwargs: Additional kwargs forwarded from _download_and_prepare (ex:
beam pipeline)
"""
raise NotImplementedError()
def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
if not tf.io.gfile.exists(self._data_dir):
tf.io.gfile.makedirs(self._data_dir)
# Generating data for all splits
split_dict = splits_lib.SplitDict()
for split_generator in self._split_generators(dl_manager):
if splits_lib.Split.ALL == split_generator.split_info.name:
        raise ValueError(
            "tfds.Split.ALL is a special split keyword corresponding to the "
            "union of all splits, so cannot be used as a key in "
            "`_split_generators()`."
)
logging.info("Generating split %s", split_generator.split_info.name)
split_dict.add(split_generator.split_info)
      # Prepare split will record examples associated with the split
self._prepare_split(split_generator, **prepare_split_kwargs)
# Update the info object with the splits.
self.info.update_splits_if_different(split_dict)
def _as_dataset(
self,
split=splits_lib.Split.TRAIN,
decoders=None,
shuffle_files=False):
if self.version.implements(utils.Experiment.S3):
dataset = self._tfrecords_reader.read(
self.name, split, self.info.splits.values(), shuffle_files)
else:
      # Resolve the named split tree into concrete splits
read_instruction = split.get_read_instruction(self.info.splits)
# Extract the list of SlicedSplitInfo objects containing the splits
# to use and their associated slice
list_sliced_split_info = read_instruction.get_list_sliced_split_info()
# Resolve the SlicedSplitInfo objects into a list of
# {'filepath': 'path/to/data-00032-00100', 'mask': [True, False, ...]}
instruction_dicts = self._slice_split_info_to_instruction_dicts(
list_sliced_split_info)
# Load the dataset
dataset = dataset_utils.build_dataset(
instruction_dicts=instruction_dicts,
dataset_from_file_fn=self._file_format_adapter.dataset_from_filename,
shuffle_files=shuffle_files,
)
decode_fn = functools.partial(
self.info.features.decode_example, decoders=decoders)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
  def _slice_split_info_to_instruction_dicts(self, list_sliced_split_info):
    """Return the list of files to read and the reading mask for each file."""
instruction_dicts = []
for sliced_split_info in list_sliced_split_info:
mask = splits_lib.slice_to_percent_mask(sliced_split_info.slice_value)
# Compute filenames from the given split
filepaths = list(sorted(self._build_split_filenames(
sliced_split_info.split_info)))
# Compute the offsets
if sliced_split_info.split_info.num_examples:
shard_id2num_examples = splits_lib.get_shard_id2num_examples(
sliced_split_info.split_info.num_shards,
sliced_split_info.split_info.num_examples,
)
mask_offsets = splits_lib.compute_mask_offsets(shard_id2num_examples)
else:
logging.warning(
"Statistics not present in the dataset. TFDS is not able to load "
"the total number of examples, so using the subsplit API may not "
"provide precise subsplits."
)
mask_offsets = [0] * len(filepaths)
for filepath, mask_offset in zip(filepaths, mask_offsets):
instruction_dicts.append({
"filepath": filepath,
"mask": mask,
"mask_offset": mask_offset,
})
return instruction_dicts
def _build_split_filenames(self, split_info):
"""Construct the split filenames associated with the split info.
The filenames correspond to the pre-processed datasets files present in
the root directory of the dataset.
Args:
split_info: (SplitInfo) needed split.
Returns:
      filenames: (list[str]) The list of file paths corresponding to the
        split info object.
"""
return naming.filepaths_for_dataset_split(
dataset_name=self.name,
split=split_info.name,
num_shards=split_info.num_shards,
data_dir=self._data_dir,
filetype_suffix=self._file_format_adapter.filetype_suffix,
)
class GeneratorBasedBuilder(FileAdapterBuilder):
"""Base class for datasets with data generation based on dict generators.
`GeneratorBasedBuilder` is a convenience class that abstracts away much
of the data writing and reading of `DatasetBuilder`. It expects subclasses to
implement generators of feature dictionaries across the dataset splits
(`_split_generators`) and to specify a file type
(`_file_format_adapter`). See the method docstrings for details.
`FileFormatAdapter`s are defined in
`tensorflow_datasets.core.file_format_adapter` and specify constraints on the
feature dictionaries yielded by example generators. See the class docstrings.
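  A minimal subclass sketch (the dataset, URL, file layout and feature names
  below are illustrative, not a real dataset):
  ```python
  class MyDataset(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version("0.1.0")
    def _info(self):
      return tfds.core.DatasetInfo(
          builder=self,
          features=tfds.features.FeaturesDict({
              "text": tfds.features.Text(),
              "label": tfds.features.ClassLabel(names=["neg", "pos"]),
          }),
      )
    def _split_generators(self, dl_manager):
      # Hypothetical archive containing train.csv and test.csv.
      path = dl_manager.download_and_extract("https://example.com/data.zip")
      return [
          tfds.core.SplitGenerator(
              name=tfds.Split.TRAIN,
              num_shards=1,
              gen_kwargs={"path": os.path.join(path, "train.csv")},
          ),
          tfds.core.SplitGenerator(
              name=tfds.Split.TEST,
              num_shards=1,
              gen_kwargs={"path": os.path.join(path, "test.csv")},
          ),
      ]
    def _generate_examples(self, path):
      # Each line of the hypothetical CSV files holds "label,text".
      with tf.io.gfile.GFile(path) as f:
        for i, line in enumerate(f):
          label, text = line.strip().split(",", 1)
          yield i, {"text": text, "label": label}
  ```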
"""
@abc.abstractmethod
def _generate_examples(self, **kwargs):
"""Default function generating examples for each `SplitGenerator`.
    This function preprocesses the examples from the raw data into the
    pre-processed dataset files.
This function is called once for each `SplitGenerator` defined in
`_split_generators`. The examples yielded here will be written on
disk.
Args:
**kwargs: (dict) Arguments forwarded from the SplitGenerator.gen_kwargs
Yields:
example: (`dict<str feature_name, feature_value>`), a feature dictionary
ready to be encoded and written to disk. The example will be
encoded with `self.info.features.encode_example({...})`.
"""
raise NotImplementedError()
def _download_and_prepare(self, dl_manager, download_config):
# Extract max_examples_per_split and forward it to _prepare_split
super(GeneratorBasedBuilder, self)._download_and_prepare(
dl_manager=dl_manager,
max_examples_per_split=download_config.max_examples_per_split,
)
def _prepare_split_legacy(self, generator, split_info):
# TODO(pierrot): delete function once S3 has been fully rolled-out.
    # For builders having both S3 and non-S3 versions: drop the key if one
    # was yielded.
generator = (ex[1] if isinstance(ex, tuple) else ex for ex in generator)
generator = (self.info.features.encode_example(ex) for ex in generator)
output_files = self._build_split_filenames(split_info)
self._file_format_adapter.write_from_generator(generator, output_files)
def _prepare_split(self, split_generator, max_examples_per_split):
generator = self._generate_examples(**split_generator.gen_kwargs)
split_info = split_generator.split_info
if max_examples_per_split is not None:
logging.warning("Splits capped at %s examples max.",
max_examples_per_split)
generator = itertools.islice(generator, max_examples_per_split)
if not self.version.implements(utils.Experiment.S3):
return self._prepare_split_legacy(generator, split_info)
fname = "{}-{}.tfrecord".format(self.name, split_generator.name)
fpath = os.path.join(self._data_dir, fname)
writer = tfrecords_writer.Writer(self._example_specs, fpath,
hash_salt=split_generator.name)
for key, record in utils.tqdm(generator, unit=" examples",
total=split_info.num_examples, leave=False):
example = self.info.features.encode_example(record)
writer.write(key, example)
shard_lengths = writer.finalize()
split_generator.split_info.shard_lengths.extend(shard_lengths)
class BeamBasedBuilder(FileAdapterBuilder):
"""Beam based Builder."""
@abc.abstractmethod
def _build_pcollection(self, pipeline, **kwargs):
"""Build the beam pipeline examples for each `SplitGenerator`.
This function extracts examples from the raw data with parallel transforms
in a Beam pipeline. It is called once for each `SplitGenerator` defined in
`_split_generators`. The examples from the PCollection will be
encoded and written to disk.
Beam liquid sharding can be used by setting num_shards to `None` in the
`SplitGenerator`.
Warning: When running in a distributed setup, make sure that the data
which will be read (download_dir, manual_dir,...) and written (data_dir)
    can be accessed by the worker jobs. The data should be located in a
shared filesystem, like GCS.
Example:
```
def _build_pcollection(pipeline, extracted_dir):
return (
pipeline
          | beam.Create(tf.io.gfile.listdir(extracted_dir))
| beam.Map(_process_file)
)
```
Args:
pipeline: `beam.Pipeline`, root Beam pipeline
**kwargs: Arguments forwarded from the SplitGenerator.gen_kwargs
Returns:
pcollection: `PCollection`, an Apache Beam PCollection containing the
example to send to `self.info.features.encode_example(...)`.
"""
raise NotImplementedError()
def _download_and_prepare(self, dl_manager, download_config):
# Create the Beam pipeline and forward it to _prepare_split
beam = lazy_imports_lib.lazy_imports.apache_beam
if not download_config.beam_runner and not download_config.beam_options:
raise ValueError(
"Trying to generate a dataset using Apache Beam, yet no Beam Runner "
"or PipelineOptions() has been provided. Please pass a "
"tfds.download.DownloadConfig(beam_runner=...) object to the "
"builder.download_and_prepare(download_config=...) method"
)
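    # For example (a sketch; DirectRunner is just one possible runner):
    #   dl_config = tfds.download.DownloadConfig(
    #       beam_options=beam.options.pipeline_options.PipelineOptions(
    #           flags=["--runner=DirectRunner"]))
    #   builder.download_and_prepare(download_config=dl_config)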
# Use a single pipeline for all splits
with beam.Pipeline(
runner=download_config.beam_runner,
options=download_config.beam_options,
) as pipeline:
# TODO(tfds): Should eventually try to add support to
# download_config.max_examples_per_split
super(BeamBasedBuilder, self)._download_and_prepare(
dl_manager,
pipeline=pipeline,
)
    # Update the number of shards for splits where liquid sharding was used.
split_dict = self.info.splits
for split_info in split_dict.values():
if not split_info.num_shards:
output_prefix = naming.filename_prefix_for_split(
self.name, split_info.name)
output_prefix = os.path.join(self._data_dir, output_prefix)
split_info.num_shards = len(tf.io.gfile.glob(output_prefix + "*"))
self.info.update_splits_if_different(split_dict)
def _prepare_split(self, split_generator, pipeline):
beam = lazy_imports_lib.lazy_imports.apache_beam
if not tf.io.gfile.exists(self._data_dir):
tf.io.gfile.makedirs(self._data_dir)
split_info = split_generator.split_info
output_prefix = naming.filename_prefix_for_split(
self.name, split_info.name)
output_prefix = os.path.join(self._data_dir, output_prefix)
# Note: We need to wrap the pipeline in a PTransform to avoid re-using the
# same label names for each split
@beam.ptransform_fn
    def _build_pcollection(pipeline):
      """PTransform that builds a single split."""
# Encode the PCollection
pcoll_examples = self._build_pcollection(
pipeline, **split_generator.gen_kwargs)
pcoll_examples |= "Encode" >> beam.Map(self.info.features.encode_example)
# Write the example to disk
return self._file_format_adapter.write_from_pcollection(
pcoll_examples,
file_path_prefix=output_prefix,
num_shards=split_info.num_shards,
)
# Add the PCollection to the pipeline
_ = pipeline | split_info.name >> _build_pcollection() # pylint: disable=no-value-for-parameter