Skip to content
Branch: master
Find file Copy path
Find file Copy path
12 contributors

Users who have contributed to this file

@rsepassi @Conchylicultor @pierrot0 @afrozenator @us @brettkoonce @habernal @ChanchalKumarMaji @yashk2810 @MarkDaoust @jsimsa @adarob
1074 lines (906 sloc) 42 KB
# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DatasetBuilder base class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import itertools
import os
import sys
from absl import logging
import six
import tensorflow as tf
from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core import constants
from tensorflow_datasets.core import dataset_utils
from tensorflow_datasets.core import download
from tensorflow_datasets.core import file_format_adapter
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import registered
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import tfrecords_reader
from tensorflow_datasets.core import tfrecords_writer
from tensorflow_datasets.core import units
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.utils import gcs_utils
import termcolor
# Message shown to the user when a dataset is already pre-generated on the
# public TFDS GCS bucket. The scrape dropped the assignment, leaving the
# message text as bare (syntactically invalid) prose; restore it as a module
# constant. Placeholders: {name}, {gcs_path}, {local_data_dir_no_version}.
GCS_HOSTED_MSG = """\
Dataset {name} is hosted on GCS. You can skip download_and_prepare by setting
data_dir=gs://tfds-data/datasets. If you find
that read performance is slow, copy the data locally with gsutil:
gsutil -m cp -R {gcs_path} {local_data_dir_no_version}
"""
class BuilderConfig(object):
"""Base class for `DatasetBuilder` data configuration.
DatasetBuilder subclasses with data configuration options should subclass
`BuilderConfig` and add their own properties.
def __init__(self, name, version=None, supported_versions=None,
self._name = name
self._version = version
self._supported_versions = supported_versions or []
self._description = description
def name(self):
return self._name
def version(self):
return self._version
def supported_versions(self):
return self._supported_versions
def description(self):
return self._description
def __repr__(self):
return "<{cls_name} name={name}, version={version}>".format(
version=self.version or "None")
class DatasetBuilder(object):
  """Abstract base class for all datasets.

  `DatasetBuilder` has 3 key methods:

  * ``: documents the dataset, including feature
  names, types, and shapes, version, splits, citation, etc.
  * `tfds.DatasetBuilder.download_and_prepare`: downloads the source data
  and writes it to disk.
  * `tfds.DatasetBuilder.as_dataset`: builds an input pipeline using
  **Configuration**: Some `DatasetBuilder`s expose multiple variants of the
  dataset by defining a `tfds.core.BuilderConfig` subclass and accepting a
  config object (or name) on construction. Configurable datasets expose a
  pre-defined set of configurations in `tfds.DatasetBuilder.builder_configs`.

  Typical `DatasetBuilder` usage:

  mnist_builder = tfds.builder("mnist")
  mnist_info =
  datasets = mnist_builder.as_dataset()
  train_dataset, test_dataset = datasets["train"], datasets["test"]
  assert isinstance(train_dataset,
  # And then the rest of your input pipeline
  train_dataset = train_dataset.repeat().shuffle(1024).batch(128)
  train_dataset = train_dataset.prefetch(2)
  features =
  image, label = features['image'], features['label']
  """

  # Name of the dataset, filled by metaclass based on class name.
  name = None

  # Semantic version of the dataset (ex: tfds.core.Version('1.2.0'))
  # List dataset versions which can be loaded using current code.
  # Data can only be prepared with canonical VERSION or above.
  # Named configurations that modify the data generated by download_and_prepare.
  # Set to True for datasets that are under active development and should not
  # be available through tfds.{load, builder} or documented in
  # NOTE(review): the class attributes these comments describe (VERSION,
  # SUPPORTED_VERSIONS, BUILDER_CONFIGS, IN_DEVELOPMENT) appear to have been
  # lost when this file was scraped — confirm against upstream.

  def __init__(self, data_dir=None, config=None, version=None):
    """Constructs a DatasetBuilder.

    Callers must pass arguments as keyword arguments.

    Args:
      data_dir: `str`, directory to read/write data. Defaults to
      config: `tfds.core.BuilderConfig` or `str` name, optional configuration
        for the dataset that affects the data generated on disk. Different
        `builder_config`s will have their own subdirectories and versions.
      version: `str`. Optional version at which to load the dataset. An error is
        raised if specified version cannot be satisfied. Eg: '1.2.3', '1.2.*'.
        The special value "experimental_latest" will use the highest version,
        even if not default. This is not recommended unless you know what you
        are doing, as the version could be broken.
    """
    self._builder_config = self._create_builder_config(config)
    # Extract code version (VERSION or config)
    if not self._builder_config and not self.VERSION:
      raise AssertionError(
          "DatasetBuilder {} does not have a defined version. Please add a "
          "`VERSION = tfds.core.Version('x.y.z')` to the class.".format(
    # NOTE(review): the argument and closing parens of the `.format(` call
    # above were dropped by the scrape.
    self._version = self._pick_version(version)
    self._data_dir_root = os.path.expanduser(data_dir or constants.DATA_DIR)
    self._data_dir = self._build_data_dir()
    # NOTE(review): the two lines below are scrape artifacts — `if`/`else`
    # conditions and `` calls were fused onto single lines and
    # intermediate statements lost.
    if"Overwrite dataset info from restored data version.")
    else: # Use the code version (do not restore data)"Load pre-computed datasetinfo (eg: splits) from bucket.")

  def _pick_version(self, requested_version):
    """Returns utils.Version instance, or raise AssertionError."""
    if self._builder_config:
      canonical_version = self._builder_config.version
      supported_versions = self._builder_config.supported_versions
    # NOTE(review): an `else:` was likely dropped here — the two assignments
    # below originally covered the no-builder-config case.
    canonical_version = self.VERSION
    supported_versions = self.SUPPORTED_VERSIONS
    versions = [
        utils.Version(v) if isinstance(v, six.string_types) else v
        for v in [canonical_version] + supported_versions
    # NOTE(review): the closing `]` of this list comprehension was dropped.
    if requested_version == "experimental_latest":
      return max(versions)
    for version in versions:
      if requested_version is None or version.match(requested_version):
        return version
    available_versions = [str(v)
                          for v in [canonical_version] + supported_versions]
    msg = "Dataset {} cannot be loaded at version {}, only: {}.".format(, requested_version, ", ".join(available_versions))
    raise AssertionError(msg)

  # NOTE(review): `@property` decorators on version/data_dir/info appear to
  # have been dropped by the scrape.
  def version(self):
    return self._version

  def data_dir(self):
    return self._data_dir

  def info(self):
    """`tfds.core.DatasetInfo` for this builder."""
    # Ensure .info hasn't been called before versioning is set-up
    # Otherwise, backward compatibility cannot be guaranteed as some code will
    # depend on the code version instead of the restored data version
    if not getattr(self, "_version", None):
      # Message for developer creating new dataset. Will trigger if they are
      # using .info in the constructor before calling super().__init__
      raise AssertionError(
          "Info should not been called before version has been defined. "
          "Otherwise, the created .info may not match the info version from "
          "the restored dataset.")
    return self._info()

  def download_and_prepare(self, download_dir=None, download_config=None):
    """Downloads and prepares dataset for reading.

    Args:
      download_dir: `str`, directory where downloaded files are stored.
        Defaults to "~/tensorflow-datasets/downloads".
      download_config: ``, further configuration for
        downloading and preparing dataset.

    Raises:
      IOError: if there is not enough disk space available.
    """
    download_config = download_config or download.DownloadConfig()
    data_exists =
    # NOTE(review): several statements below were fused or truncated by the
    # scrape (missing `return`s, call arguments, and enum qualifiers).
    if data_exists and download_config.download_mode == REUSE_DATASET_IF_EXISTS:"Reusing dataset %s (%s)",, self._data_dir)
    # Data may exist on GCS
    if not data_exists:
      dl_manager = self._make_download_manager(
    # Currently it's not possible to overwrite the data because it would
    # conflict with versioning: If the last version has already been generated,
    # it will always be reloaded and data_dir will be set at construction.
    if data_exists:
      raise ValueError(
          "Trying to overwrite an existing dataset {} at {}. A dataset with "
          "the same version {} already exists. If the dataset has changed, "
          "please update the version number.".format(, self._data_dir,
          self.version))"Generating dataset %s (%s)",, self._data_dir)
    if not utils.has_sufficient_disk_space(, directory=self._data_dir_root):
      raise IOError("Not enough disk space. Needed: %s" %
    # Create a tmp dir and rename to self._data_dir on successful exit.
    with file_format_adapter.incomplete_dir(self._data_dir) as tmp_data_dir:
      # Temporarily assign _data_dir to tmp_data_dir to avoid having to forward
      # it to every sub function.
      with utils.temporary_assignment(self, "_data_dir", tmp_data_dir):
        # NOTE: If modifying the lines below to put additional information in
        # DatasetInfo, you'll likely also want to update
        # DatasetInfo.read_from_directory to possibly restore these attributes
        # when reading from package data.
        # Update the DatasetInfo metadata by computing statistics from the data.
        if (download_config.compute_stats == download.ComputeStatsMode.SKIP or
            download_config.compute_stats == download.ComputeStatsMode.AUTO and
          "Skipping computing stats for mode %s.",
        else: # Mode is forced or stats do not exists yet"Computing statistics.") = dl_manager.downloaded_size
        # Write DatasetInfo to disk, even if we haven't computed the statistics.

  def as_dataset(self,
    # NOTE(review): the rest of this signature (split, batch_size,
    # shuffle_files, decoders, as_supervised, in_memory) was lost in the
    # scrape.
    # pylint: disable=line-too-long
    """Constructs a ``.

    Callers must pass arguments as keyword arguments.

    The output types vary depending on the parameters. Examples:

    builder = tfds.builder('imdb_reviews')
    # Default parameters: Returns the dict of
    ds_all_dict = builder.as_dataset()
    assert isinstance(ds_all_dict, dict)
    print(ds_all_dict.keys()) # ==> ['test', 'train', 'unsupervised']
    assert isinstance(ds_all_dict['test'],
    # Each dataset (test, train, unsup.) consists of dictionaries
    # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
    # 'text': <tf.Tensor: .. dtype=string, numpy=b"I've watched the movie ..">}
    # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
    # 'text': <tf.Tensor: .. dtype=string, numpy=b'If you love Japanese ..'>}
    # With as_supervised: only contains (feature, label) tuples
    ds_all_supervised = builder.as_dataset(as_supervised=True)
    assert isinstance(ds_all_supervised, dict)
    print(ds_all_supervised.keys()) # ==> ['test', 'train', 'unsupervised']
    assert isinstance(ds_all_supervised['test'],
    # Each dataset (test, train, unsup.) consists of tuples (text, label)
    # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
    # <tf.Tensor: ... dtype=int64, numpy=1>)
    # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
    # <tf.Tensor: ... dtype=int64, numpy=1>)
    # Same as above plus requesting a particular split
    ds_test_supervised = builder.as_dataset(as_supervised=True, split='test')
    assert isinstance(ds_test_supervised,
    # The dataset consists of tuples (text, label)
    # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
    # <tf.Tensor: ... dtype=int64, numpy=1>)
    # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
    # <tf.Tensor: ... dtype=int64, numpy=1>)

    Args:
      split: `tfds.core.SplitBase`, which subset(s) of the data to read. If None
        (default), returns all splits in a dict
        `<key: tfds.Split, value:>`.
      batch_size: `int`, batch size. Note that variable-length features will
        be 0-padded if `batch_size` is set. Users that want more custom behavior
        should use `batch_size=None` and use the `` API to construct a
        custom pipeline. If `batch_size == -1`, will return feature
        dictionaries of the whole dataset with `tf.Tensor`s instead of a
      shuffle_files: `bool`, whether to shuffle the input files. Defaults to
      decoders: Nested dict of `Decoder` objects which allow to customize the
        decoding. The structure should match the feature structure, but only
        customized feature keys need to be present. See
        [the guide](
        for more info.
      as_supervised: `bool`, if `True`, the returned ``
        will have a 2-tuple structure `(input, label)` according to
        ``. If `False`, the default,
        the returned `` will have a dictionary with all the
      in_memory: `bool`, if `True`, loads the dataset in memory which
        increases iteration speeds. Note that if `True` and the dataset has
        unknown dimensions, the features will be padded to the maximum
        size across the dataset.

    Returns:
      ``, or if `split=None`, `dict<key: tfds.Split, value:>`.
      If `batch_size` is -1, will return feature dictionaries containing
      the entire dataset in `tf.Tensor`s instead of a ``.
    """
    # NOTE(review): the line below fused the pylint re-enable comment with a
    # `` call during the scrape.
    # pylint: enable=line-too-long"Constructing for split %s, from %s",
    split, self._data_dir)
    if not
      raise AssertionError(
          ("Dataset %s: could not find data in %s. Please make sure to call "
           "dataset_builder.download_and_prepare(), or pass download=True to "
           "tfds.load() before trying to access the object."
          ) % (, self._data_dir_root))
    # By default, return all splits
    if split is None:
      split = {s: s for s in}
    # Create a dataset for each of the given splits
    build_single_dataset = functools.partial(
    datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
    return datasets

  def _build_single_dataset(
    # NOTE(review): this signature was truncated by the scrape (split,
    # shuffle_files, batch_size, decoders, as_supervised, in_memory).
    """as_dataset for a single split."""
    if isinstance(split, six.string_types):
      split = splits_lib.Split(split)
    wants_full_dataset = batch_size == -1
    if wants_full_dataset:
      batch_size = or sys.maxsize
    # If the dataset is small, load it in memory
    dataset_shape_is_fully_defined = (
    in_memory_default = False
    # TODO(tfds): Consider default in_memory=True for small datasets with
    # fully-defined shape.
    # Expose and use the actual data size on disk and rm the manual
    # name guards. size_in_bytes is the download size, which is misleading,
    # particularly for datasets that use manual_dir as well as some downloads
    # (wmt and diabetic_retinopathy_detection).
    # in_memory_default = (
    # and
    # <= 1e9 and
    # not"wmt") and
    # not"diabetic") and
    # dataset_shape_is_fully_defined)
    in_memory = in_memory_default if in_memory is None else in_memory
    # Build base dataset
    if in_memory and not wants_full_dataset:
      # TODO(tfds): Enable in_memory without padding features. May be able
      # to do by using a requested version of that can
      # persist a cache beyond iterator instances.
      if not dataset_shape_is_fully_defined:
        logging.warning("Called in_memory=True on a dataset that does not "
                        "have fully defined shapes. Note that features with "
                        "variable length dimensions will be 0-padded to "
                        "the maximum length across the dataset.")
      full_bs = or sys.maxsize
      # If using in_memory, escape all device contexts so we can load the data
      # with a local Session.
      with tf.device(None):
        dataset = self._as_dataset(
            split=split, shuffle_files=shuffle_files, decoders=decoders)
        # Use padded_batch so that features with unknown shape are supported.
        dataset = dataset.padded_batch(full_bs, dataset.output_shapes)
        dataset =
    # NOTE(review): an `else:` introducing the not-in-memory path was likely
    # dropped here by the scrape.
    dataset = self._as_dataset(
        split=split, shuffle_files=shuffle_files, decoders=decoders)
    if batch_size:
      # Use padded_batch so that features with unknown shape are supported.
      dataset = dataset.padded_batch(batch_size, dataset.output_shapes)
    if as_supervised:
      if not
        raise ValueError(
            "as_supervised=True but %s does not support a supervised "
            "(input, label) structure." %
      input_f, target_f =
      dataset = fs: (fs[input_f], fs[target_f]),
    dataset = dataset.prefetch(
    # If shuffling, allow pipeline to be non-deterministic
    options =
    options.experimental_deterministic = not shuffle_files
    dataset = dataset.with_options(options)
    if wants_full_dataset:
      return dataset

  def _maybe_log_gcs_data_dir(self):
    """If data is on GCS, set _data_dir to GCS path."""
    # NOTE(review): all three statements below lost their trailing arguments
    # and the `logging` call that consumed `msg` in the scrape.
    if not gcs_utils.is_dataset_on_gcs(
    gcs_path = os.path.join(constants.GCS_DATA_DIR,
    msg = GCS_HOSTED_MSG.format(,

  def _relative_data_dir(self, with_version=True):
    """Relative path of this dataset in data_dir."""
    builder_data_dir =
    builder_config = self._builder_config
    if builder_config:
      builder_data_dir = os.path.join(builder_data_dir,
    if not with_version:
      return builder_data_dir
    version = self._version
    version_data_dir = os.path.join(builder_data_dir, str(version))
    return version_data_dir

  def _build_data_dir(self):
    """Return the data directory for the current version."""
    builder_data_dir = os.path.join(
        self._data_dir_root, self._relative_data_dir(with_version=False))
    version_data_dir = os.path.join(
        self._data_dir_root, self._relative_data_dir(with_version=True))

    def _other_versions_on_disk():
      """Returns previous versions on disk."""
      if not
        return []
      version_dirnames = []
      for dir_name in
        version_dirnames.append((utils.Version(dir_name), dir_name))
      except ValueError: # Invalid version (ex: incomplete data dir)
      # NOTE(review): the `try:` matching the `except` above, and a sort of
      # version_dirnames, were likely dropped by the scrape.
      return version_dirnames

    # Check and warn if other versions exist on disk
    version_dirs = _other_versions_on_disk()
    if version_dirs:
      other_version = version_dirs[0][0]
      if other_version != self._version:
        warn_msg = (
            "Found a different version {other_version} of dataset {name} in "
            "data_dir {data_dir}. Using currently defined version "
        # NOTE(review): the tail of this message, its .format(...) call and
        # the logging.warning(warn_msg) were lost in the scrape.
    return version_data_dir

  def _log_download_done(self):
    # Tell the user where the prepared data lives so later calls can reuse it.
    msg = ("Dataset {name} downloaded and prepared to {data_dir}. "
           "Subsequent calls will reuse this data.").format(,
    termcolor.cprint(msg, attrs=["bold"])

  def _log_download_bytes(self):
    # Print is intentional: we want this to always go to stdout so user has
    # information needed to cancel download/preparation if needed.
    # This comes right before the progress bar.
    size_text = units.size_str(
    # NOTE(review): the print/cprint call wrapping the two lines below was
    # lost in the scrape.
    "Downloading and preparing dataset %s (%s) to %s..." %
    (, size_text, self._data_dir),
    # TODO(tfds): Should try to estimate the available free disk space (if
    # possible) and raise an error if not.

  def _info(self):
    """Construct the DatasetInfo object. See `DatasetInfo` for details.

    Warning: This function is only called once and the result is cached for all
    following .info() calls.

    Returns:
      dataset_info: (DatasetInfo) The dataset information
    """
    raise NotImplementedError

  def _download_and_prepare(self, dl_manager, download_config=None):
    """Downloads and prepares dataset for reading.

    This is the internal implementation to overwrite called when user calls
    `download_and_prepare`. It should download all required data and generate
    the pre-processed datasets files.

    Args:
      dl_manager: (DownloadManager) `DownloadManager` used to download and cache
      download_config: `DownloadConfig`, Additional options.
    """
    raise NotImplementedError

  def _as_dataset(self, split, decoders=None, shuffle_files=False):
    """Constructs a ``.

    This is the internal implementation to overwrite called when user calls
    `as_dataset`. It should read the pre-processed datasets files and generate
    the `` object.

    Args:
      split: `tfds.Split` which subset of the data to read.
      decoders: Nested structure of `Decoder` object to customize the dataset
      shuffle_files: `bool`, whether to shuffle the input files. Optional,
        defaults to `False`.
    """
    raise NotImplementedError

  def _make_download_manager(self, download_dir, download_config):
    # NOTE(review): several call arguments below were truncated by the scrape
    # (join components, DownloadManager kwargs, enum qualifiers).
    download_dir = download_dir or os.path.join(self._data_dir_root,
    extract_dir = (download_config.extract_dir or
                   os.path.join(download_dir, "extracted"))
    manual_dir = (download_config.manual_dir or
                  os.path.join(download_dir, "manual"))
    manual_dir = os.path.join(manual_dir,
    return download.DownloadManager(,
        force_download=(download_config.download_mode == FORCE_REDOWNLOAD),
        force_extraction=(download_config.download_mode == FORCE_REDOWNLOAD),

  # NOTE(review): likely originally decorated with @property.
  def builder_config(self):
    """`tfds.core.BuilderConfig` for this builder."""
    return self._builder_config

  def _create_builder_config(self, builder_config):
    """Create and validate BuilderConfig object."""
    if builder_config is None and self.BUILDER_CONFIGS:
      # NOTE(review): the line below fused the assignment with a
      # call during the scrape.
      builder_config = self.BUILDER_CONFIGS[0]"No config specified, defaulting to first: %s/%s",,
    if not builder_config:
      return None
    if isinstance(builder_config, six.string_types):
      name = builder_config
      builder_config = self.builder_configs.get(name)
      if builder_config is None:
        raise ValueError("BuilderConfig %s not found. Available: %s" %
                         (name, list(self.builder_configs.keys())))
    name =
    if not name:
      raise ValueError("BuilderConfig must have a name, got %s" % name)
    is_custom = name not in self.builder_configs
    if is_custom:
      logging.warning("Using custom data configuration %s", name)
    # NOTE(review): an `else:` was likely dropped before the check below —
    # it only makes sense for the non-custom case.
    if builder_config is not self.builder_configs[name]:
      raise ValueError(
          "Cannot name a custom BuilderConfig the same as an available "
          "BuilderConfig. Change the name. Available BuilderConfigs: %s" %
    if not builder_config.version:
      raise ValueError("BuilderConfig %s must have a version" % name)
    if not builder_config.description:
      raise ValueError("BuilderConfig %s must have a description" % name)
    return builder_config

  # NOTE(review): takes `cls` — likely originally decorated with a
  # classproperty-style decorator dropped by the scrape.
  def builder_configs(cls):
    """Pre-defined list of configurations for this builder class."""
    config_dict = { config for config in cls.BUILDER_CONFIGS}
    if len(config_dict) != len(cls.BUILDER_CONFIGS):
      names = [ for config in cls.BUILDER_CONFIGS]
      raise ValueError(
          "Names in BUILDER_CONFIGS must not be duplicated. Got %s" % names)
    return config_dict
class FileAdapterBuilder(DatasetBuilder):
  """Base class for datasets with data generation based on file adapter.

  `FileFormatAdapter`s are defined in
  `tensorflow_datasets.core.file_format_adapter` and specify constraints on the
  feature dictionaries yielded by example generators. See the class docstrings.
  """

  # NOTE(review): property/memoized-property decorators on the accessors
  # below appear to have been dropped by the scrape.
  def _example_specs(self):
    """Spec describing the serialized example features (abstract accessor).

    NOTE(review): the original body of this accessor was lost in the scrape.
    """

  def _file_format_adapter(self):
    # Load the format adapter (TF-Record,...)
    # The file_format_adapter module will eventually be replaced by
    # tfrecords_{reader,writer} modules.
    file_adapter_cls = file_format_adapter.TFRecordExampleAdapter
    return file_adapter_cls(self._example_specs)

  def _tfrecords_reader(self):
    return tfrecords_reader.Reader(self._data_dir, self._example_specs)

  def _split_generators(self, dl_manager):
    """Specify feature dictionary generators and dataset splits.

    This function returns a list of `SplitGenerator`s defining how to generate
    data and what splits to use. Example (fragments survived the scrape):

      gen_kwargs={'file': ''},
      gen_kwargs={'file': ''},

    The above code will first call `_generate_examples(file='')`
    to write the train data, then `_generate_examples(file='')` to
    write the test data.

    Datasets are typically split into different subsets to be used at various
    stages of training and evaluation.

    Note that for datasets without a `VALIDATION` split, you can use a
    fraction of the `TRAIN` data for evaluation as you iterate on your model
    so as not to overfit to the `TEST` data.

    For downloads and extractions, use the given `download_manager`.
    Note that the `DownloadManager` caches downloads, so it is fine to have each
    generator attempt to download the source data.

    A good practice is to download all data in this function, and then
    distribute the relevant parts to each split with the `gen_kwargs` argument

    Args:
      dl_manager: (DownloadManager) Download manager to download the data
    """
    raise NotImplementedError()

  def _prepare_split(self, split_generator, **kwargs):
    """Generate the examples and record them on disk.

    Args:
      split_generator: `SplitGenerator`, Split generator to process
      **kwargs: Additional kwargs forwarded from _download_and_prepare (ex:
        beam pipeline)
    """
    raise NotImplementedError()

  def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
    # NOTE(review): the condition and body of the guard below were truncated
    # by the scrape.
    if not
    # Generating data for all splits
    split_dict = splits_lib.SplitDict()
    for split_generator in self._split_generators(dl_manager):
      if splits_lib.Split.ALL ==
        raise ValueError(
            "tfds.Split.ALL is a special split keyword corresponding to the "
            "union of all splits, so cannot be used as key in "
            )"Generating split %s",
      # Prepare split will record examples associated to the split
      self._prepare_split(split_generator, **prepare_split_kwargs)
    # Update the info object with the splits.

  def _as_dataset(
    # NOTE(review): this signature and several call arguments below were
    # truncated by the scrape.
    if self.version.implements(utils.Experiment.S3):
      dataset =, split,, shuffle_files)
    # Resolve all the named split tree by real ones
    read_instruction = split.get_read_instruction(
    # Extract the list of SlicedSplitInfo objects containing the splits
    # to use and their associated slice
    list_sliced_split_info = read_instruction.get_list_sliced_split_info()
    # Resolve the SlicedSplitInfo objects into a list of
    # {'filepath': 'path/to/data-00032-00100', 'mask': [True, False, ...]}
    instruction_dicts = self._slice_split_info_to_instruction_dicts(
    # Load the dataset
    dataset = dataset_utils.build_dataset(
    decode_fn = functools.partial(, decoders=decoders)
    dataset =
    return dataset

  def _slice_split_info_to_instruction_dicts(self, list_sliced_split_info):
    """Return the list of files and reading mask of the files to read."""
    instruction_dicts = []
    for sliced_split_info in list_sliced_split_info:
      mask = splits_lib.slice_to_percent_mask(sliced_split_info.slice_value)
      # Compute filenames from the given split
      filepaths = list(sorted(self._build_split_filenames(
      # Compute the offsets
      if sliced_split_info.split_info.num_examples:
        shard_id2num_examples = splits_lib.get_shard_id2num_examples(
        mask_offsets = splits_lib.compute_mask_offsets(shard_id2num_examples)
      # NOTE(review): an `else:` with a logging.warning(...) call wrapping the
      # message below was likely dropped by the scrape.
      "Statistics not present in the dataset. TFDS is not able to load "
      "the total number of examples, so using the subsplit API may not "
      "provide precise subsplits."
      mask_offsets = [0] * len(filepaths)
      for filepath, mask_offset in zip(filepaths, mask_offsets):
        # NOTE(review): an `instruction_dicts.append({` line was likely
        # dropped before this dict literal.
        "filepath": filepath,
        "mask": mask,
        "mask_offset": mask_offset,
    return instruction_dicts

  def _build_split_filenames(self, split_info):
    """Construct the split filenames associated with the split info.

    The filenames correspond to the pre-processed datasets files present in
    the root directory of the dataset.

    Args:
      split_info: (SplitInfo) needed split.

    Returns:
      filenames: (list[str]) The list of filenames path corresponding to the
        split info object
    """
    return naming.filepaths_for_dataset_split(,,
class GeneratorBasedBuilder(FileAdapterBuilder):
  """Base class for datasets with data generation based on dict generators.

  `GeneratorBasedBuilder` is a convenience class that abstracts away much
  of the data writing and reading of `DatasetBuilder`. It expects subclasses to
  implement generators of feature dictionaries across the dataset splits
  (`_split_generators`) and to specify a file type
  (`_file_format_adapter`). See the method docstrings for details.

  `FileFormatAdapter`s are defined in
  `tensorflow_datasets.core.file_format_adapter` and specify constraints on the
  feature dictionaries yielded by example generators. See the class docstrings.
  """

  def _generate_examples(self, **kwargs):
    """Default function generating examples for each `SplitGenerator`.

    This function preprocess the examples from the raw data to the preprocessed
    dataset files.

    This function is called once for each `SplitGenerator` defined in
    `_split_generators`. The examples yielded here will be written on

    Args:
      **kwargs: (dict) Arguments forwarded from the SplitGenerator.gen_kwargs

    Yields:
      example: (`dict<str feature_name, feature_value>`), a feature dictionary
        ready to be encoded and written to disk. The example will be
        encoded with `{...})`.
    """
    raise NotImplementedError()

  def _download_and_prepare(self, dl_manager, download_config):
    # Extract max_examples_per_split and forward it to _prepare_split
    # NOTE(review): the arguments of the super() call below were lost in the
    # scrape.
    super(GeneratorBasedBuilder, self)._download_and_prepare(

  def _prepare_split_legacy(self, generator, split_info):
    # TODO(pierrot): delete function once S3 has been fully rolled-out.
    # For builders having both S3 and non S3 versions: drop key if any yielded.
    generator = (ex[1] if isinstance(ex, tuple) else ex for ex in generator)
    generator = ( for ex in generator)
    output_files = self._build_split_filenames(split_info)
    self._file_format_adapter.write_from_generator(generator, output_files)

  def _prepare_split(self, split_generator, max_examples_per_split):
    generator = self._generate_examples(**split_generator.gen_kwargs)
    split_info = split_generator.split_info
    if max_examples_per_split is not None:
      logging.warning("Splits capped at %s examples max.",
      generator = itertools.islice(generator, max_examples_per_split)
    if not self.version.implements(utils.Experiment.S3):
      return self._prepare_split_legacy(generator, split_info)
    # NOTE(review): format args of fname and trailing kwargs of the Writer
    # constructor were truncated by the scrape; shard_lengths is computed but
    # the statement recording it into split_info was likely dropped.
    fname = "{}-{}.tfrecord".format(,
    fpath = os.path.join(self._data_dir, fname)
    writer = tfrecords_writer.Writer(self._example_specs, fpath,
    for key, record in utils.tqdm(generator, unit=" examples",
                                  total=split_info.num_examples, leave=False):
      example =
      writer.write(key, example)
    shard_lengths = writer.finalize()
class BeamBasedBuilder(FileAdapterBuilder):
  """Beam based Builder."""

  def _build_pcollection(self, pipeline, **kwargs):
    """Build the beam pipeline examples for each `SplitGenerator`.

    This function extracts examples from the raw data with parallel transforms
    in a Beam pipeline. It is called once for each `SplitGenerator` defined in
    `_split_generators`. The examples from the PCollection will be
    encoded and written to disk.

    Beam liquid sharding can be used by setting num_shards to `None` in the

    Warning: When running in a distributed setup, make sure that the data
    which will be read (download_dir, manual_dir,...) and written (data_dir)
    can be accessed by the workers jobs. The data should be located in a
    shared filesystem, like GCS.

    Example (fragments survived the scrape):

      def _build_pcollection(pipeline, extracted_dir):
        return (
            | beam.Create(
            | beam.Map(_process_file)

    Args:
      pipeline: `beam.Pipeline`, root Beam pipeline
      **kwargs: Arguments forwarded from the SplitGenerator.gen_kwargs

    Returns:
      pcollection: `PCollection`, an Apache Beam PCollection containing the
        example to send to ``.
    """
    raise NotImplementedError()

  def _download_and_prepare(self, dl_manager, download_config):
    # Create the Beam pipeline and forward it to _prepare_split
    beam = lazy_imports_lib.lazy_imports.apache_beam
    if not download_config.beam_runner and not download_config.beam_options:
      # NOTE(review): the closing paren of this raise was dropped by the
      # scrape.
      raise ValueError(
          "Trying to generate a dataset using Apache Beam, yet no Beam Runner "
          "or PipelineOptions() has been provided. Please pass a "
          " object to the "
          "builder.download_and_prepare(download_config=...) method"
    # Use a single pipeline for all splits
    with beam.Pipeline(
    ) as pipeline:
      # TODO(tfds): Should eventually try to add support to
      # download_config.max_examples_per_split
      # NOTE(review): Pipeline(...) kwargs and the super() call arguments were
      # truncated by the scrape.
      super(BeamBasedBuilder, self)._download_and_prepare(
    # Update the number of shards for splits where liquid sharding were used.
    split_dict =
    for split_info in split_dict.values():
      if not split_info.num_shards:
        output_prefix = naming.filename_prefix_for_split(,
        output_prefix = os.path.join(self._data_dir, output_prefix)
        split_info.num_shards = len( + "*"))

  def _prepare_split(self, split_generator, pipeline):
    beam = lazy_imports_lib.lazy_imports.apache_beam
    # NOTE(review): the guard condition and several call arguments below were
    # truncated by the scrape.
    if not
    split_info = split_generator.split_info
    output_prefix = naming.filename_prefix_for_split(,
    output_prefix = os.path.join(self._data_dir, output_prefix)
    # Note: We need to wrap the pipeline in a PTransform to avoid re-using the
    # same label names for each split

    def _build_pcollection(pipeline):
      """PTransformation which build a single split."""
      # Encode the PCollection
      pcoll_examples = self._build_pcollection(
          pipeline, **split_generator.gen_kwargs)
      pcoll_examples |= "Encode" >> beam.Map(
      # Write the example to disk
      return self._file_format_adapter.write_from_pcollection(

    # Add the PCollection to the pipeline
    _ = pipeline | >> _build_pcollection() # pylint: disable=no-value-for-parameter
You can’t perform that action at this time.