Skip to content

Commit

Permalink
Fix tfds build with multiprocessing.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 604609960
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Feb 6, 2024
1 parent 69e781f commit 9e70806
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions tensorflow_datasets/scripts/cli/build.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,8 +22,7 @@
import json
import multiprocessing
import os
import typing
from typing import Dict, Iterator, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union

from absl import logging
from etils import epath
Expand Down Expand Up @@ -299,10 +298,17 @@ def _build_datasets(args: argparse.Namespace) -> None:
else:
datasets = datasets or [''] # Empty string for default

# Import builder classes
builders_cls_and_kwargs = [
_get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
for dataset in datasets
]

# Parallelize datasets generation.
builders = itertools.chain(
*(_make_builders(args, dataset) for dataset in datasets)
)
builders = itertools.chain(*(
_make_builders(args, builder_cls, builder_kwargs)
for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
))
process_builder_fn = functools.partial(
_download if args.download_only else _download_and_prepare, args
)
Expand All @@ -317,14 +323,19 @@ def _build_datasets(args: argparse.Namespace) -> None:

def _make_builders(
args: argparse.Namespace,
ds_to_build: str,
builder_cls: Type[tfds.core.DatasetBuilder],
builder_kwargs: Dict[str, Any],
) -> Iterator[tfds.core.DatasetBuilder]:
"""Yields builders to generate."""
builder_cls, builder_kwargs = _get_builder_cls(
ds_to_build,
has_imports=bool(args.imports),
)
"""Yields builders to generate.
Args:
args: Command line arguments.
builder_cls: Dataset builder class.
builder_kwargs: Dataset builder kwargs.
Yields:
Initialized dataset builders.
"""
# Eventually overwrite version
if args.experimental_latest_version:
if 'version' in builder_kwargs:
Expand Down Expand Up @@ -362,12 +373,12 @@ def _make_builders(
yield make_builder()


def _get_builder_cls(
def _get_builder_cls_and_kwargs(
ds_to_build: str,
*,
has_imports: bool,
) -> Tuple[Type[tfds.core.DatasetBuilder], Dict[str, str]]:
"""Infer the builder class to build.
) -> Tuple[Type[tfds.core.DatasetBuilder], Dict[str, Any]]:
"""Infer the builder class to build and its kwargs.
Args:
ds_to_build: Dataset argument.
Expand Down Expand Up @@ -410,7 +421,6 @@ def _get_builder_cls(
logging.info(
f'Loading dataset {ds_to_build} from imports: {builder_cls.__module__}'
)
builder_kwargs = typing.cast(Dict[str, str], builder_kwargs)
return builder_cls, builder_kwargs


Expand Down

0 comments on commit 9e70806

Please sign in to comment.