--num-processes causes build to error #5278

Closed
BogueUser opened this issue Feb 5, 2024 · 6 comments

@BogueUser

What I need help with / What I was wondering
I am trying to run tfds build myDataset with multiprocessing, since the dataset is pretty wide and my bottleneck seems to be TensorFlow itself.
When I use --num-processes 12, the build hangs and the following error is printed to the terminal.

Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.11/multiprocessing/pool.py", line 114, in worker
    task = get()
           ^^^^^
  File "/usr/lib/python3.11/multiprocessing/queues.py", line 367, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'Mk0'

I haven't a clue in the slightest what is called before that to get the issue.
My dataset does build without the flag, it just takes for friggen ever.

What I've tried so far
Running it again.
Using tfds-nightly
Restarting my computer.
Using tensorflow-datasets
Yelling at it.
Asking nicely.
Reading and failing to understand the source code
Complaining on the internet.

It would be nice if...
I may have missed it but it appears there is no documentation for --num-processes. Everything else I have found indicates that tfds local builds are single core only.
Is it not fully implemented?

Environment information
(if applicable)

  • Operating System: Arch Linux
  • Python version: 3.11.5
  • tensorflow-datasets version: 4.9.4
  • tensorflow version: 2.14.0

If y'all need any more information, please let me know.

@BogueUser BogueUser added the help label Feb 5, 2024
@fineguy
Collaborator

fineguy commented Feb 6, 2024

Parallelization is done by processing each config in a separate process:

with multiprocessing.Pool(args.num_processes) as pool:
    pool.map(process_builder_fn, builders)

This means that there's no reason to set --num-processes higher than the number of distinct configs for your dataset.
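A minimal sketch of why extra workers sit idle when there is one task per config. The config names and the build function are hypothetical, and ThreadPool (which shares the Pool API) stands in for multiprocessing.Pool only so that the snippet runs anywhere without a main guard:

```python
import threading
import time
from multiprocessing.pool import ThreadPool

configs = ["small", "medium", "large"]  # hypothetical config names
num_processes = 12                      # more workers than configs


def build(config):
    """Stand-in for building one config; reports which worker ran it."""
    time.sleep(0.05)
    return config, threading.get_ident()


with ThreadPool(num_processes) as pool:
    # One task per config, exactly like pool.map(process_builder_fn, builders).
    results = pool.map(build, configs)

# At most len(configs) workers can ever receive a task.
busy = len({worker for _, worker in results})
print(f"{busy} of {num_processes} workers did any work")
```

With three configs, no more than three workers can be busy, so the remaining workers never receive a task.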

Because of the way multiprocessing works in Python, each builder (including its references to the Mk0 module) is pickled in the main process and then unpickled in the child processes. During unpickling, the child process must be able to import Mk0, which is most probably not installed in your Python environment.
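The failure mode can be reproduced with the standard library alone. In this sketch the module name fake_builder_module and the Builder class are hypothetical stand-ins for Mk0, and removing the module from sys.modules simulates a child process that cannot import it:

```python
import pickle
import sys
import types

# Hypothetical module that only the "parent" process knows about.
module = types.ModuleType("fake_builder_module")
exec("class Builder:\n    pass\n", module.__dict__)
sys.modules["fake_builder_module"] = module

# Pickle stores a reference ("fake_builder_module", "Builder") plus the
# instance state, not the class definition itself.
payload = pickle.dumps(module.Builder())

# Simulate a child process in which the module is not importable.
del sys.modules["fake_builder_module"]


def try_unpickle(data):
    """Return 'ok' or a description of the import failure."""
    try:
        pickle.loads(data)
        return "ok"
    except ModuleNotFoundError as err:
        return f"{type(err).__name__}: {err}"


result = try_unpickle(payload)
print(result)
```

The error has the same shape as the traceback in the report: unpickling needs to import the defining module by name, and the import fails.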

You could make your dataset builder code discoverable by Python (e.g. by including it in TFDS and reinstalling the library locally). Or perhaps you could consider implementing parallelization in your dataset builder, or using Beam.

@fineguy fineguy self-assigned this Feb 6, 2024
@fineguy
Collaborator

fineguy commented Feb 6, 2024

#5279 should make child processes aware of your dataset builder.

@BogueUser
Author

#5279 should make child processes aware of your dataset builder.

I can now run with the --num-processes flag which is pretty cool. Thank you for fixing that!

This means that there's no reason to set --num-processes higher than the number of distinct configs for your dataset.

Is it possible to effectively split my dataset into different configs then merge them together when done?

Or perhaps you could consider implementing parallelization in your dataset builder

My code unfortunately isn't the slow part. It's my abuse of TFDS, since I have 138 features in my dataset. My code takes 2.5 seconds per 3000 examples, whereas generation happens at 80 examples per second.
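To put those two numbers side by side, here is a quick back-of-the-envelope calculation. It assumes that 80 examples per second is the end-to-end rate of the build, which is an interpretation rather than something measured separately:

```python
batch = 3000          # examples per batch, from the timing above
own_seconds = 2.5     # time spent in the user's own code per batch
overall_rate = 80.0   # examples per second (assumed to be end to end)

own_rate = batch / own_seconds        # examples per second in user code
total_seconds = batch / overall_rate  # wall time for one batch
own_share = own_seconds / total_seconds

print(f"user code: {own_rate:.0f} examples/s")
print(f"whole build: {total_seconds:.1f} s per {batch} examples")
print(f"user code share of wall time: {own_share:.1%}")
```

Under that assumption the user code accounts for well under a tenth of the wall time, so speeding it up alone would barely change the total.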

@fineguy
Collaborator

fineguy commented Feb 7, 2024

I can now run with the --num-processes flag which is pretty cool. Thank you for fixing that!

Great, really glad that it worked for you! I'll be closing this issue then.

Is it possible to effectively split my dataset into different configs then merge them together when done?

TFDS doesn't natively support mixing datasets, but you can use some other tools for that, e.g. https://github.com/google/seqio

My code unfortunately isn't the slow part. It's my abuse of TFDS, since I have 138 features in my dataset. My code takes 2.5 seconds per 3000 examples, whereas generation happens at 80 examples per second.

It's usually very straightforward to parallelize example generation:

def _generate_examples(self, data) -> split_builder_lib.SplitGenerator:
    dataset_info = self._info()
    if self._tfds_num_proc is None:
        for index, example in enumerate(data):
            yield _convert_example(index, example, dataset_info.features)
    else:
        with multiprocessing.Pool(processes=self._tfds_num_proc) as pool:
            examples = pool.starmap(
                functools.partial(_convert_example, features=dataset_info.features),
                enumerate(data),
            )
            yield from examples

Or with Beam:

def _generate_query_examples(self, pipeline):
    """Generates examples for query split."""
    beam = tfds.core.lazy_imports.apache_beam
    return (
        pipeline
        | 'AddEmptyFeatures'
        >> beam.Map(
            functools.partial(
                _append_constant_features,
                mapping={
                    'passage_id': '',
                    'passage': '',
                    'passage_metadata': '{}',
                    'score': -1,
                },
            )
        )
        | 'GetHashKey' >> beam.Map(lambda x: (_get_hash(x), x))
    )
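As a variant of the multiprocessing approach above: pool.starmap collects every result into a list before anything is yielded, which can be memory hungry for large datasets. The sketch below streams results with imap instead. The names convert_example and generate_examples are hypothetical and this is not TFDS code. ThreadPool is used only so the snippet runs anywhere; for CPU-bound conversion you would swap in multiprocessing.Pool, which exposes the same imap method.

```python
import functools
from multiprocessing.pool import ThreadPool


def convert_example(indexed, features):
    """Hypothetical converter: returns a (key, example) pair."""
    index, raw = indexed
    return index, {name: raw[name] for name in features}


def generate_examples(data, features, num_proc=None):
    """Yield (key, example) pairs, optionally through a worker pool."""
    convert = functools.partial(convert_example, features=features)
    if num_proc is None:
        yield from map(convert, enumerate(data))
    else:
        with ThreadPool(processes=num_proc) as pool:
            # imap keeps input order and yields results as they become
            # available; chunksize batches the work sent to each worker.
            yield from pool.imap(convert, enumerate(data), chunksize=64)


rows = [{"a": i, "b": i * 2, "unused": None} for i in range(1000)]
serial = list(generate_examples(rows, ("a", "b")))
parallel = list(generate_examples(rows, ("a", "b"), num_proc=4))
print(serial == parallel)
```

Because imap preserves input order, the serial and pooled paths produce the same sequence of keys and examples.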

@fineguy fineguy closed this as completed Feb 7, 2024
@BogueUser
Author

Ok. Thanks for your help with this.

I'll look into parallel example generation and see if I can get that working with my dataset.

@mathpluscode

Hi, I have unfortunately run into the same issue and had no luck with tf-nightly or with multiprocessing on Python 3.9/3.10 on macOS (Python 3.10 was necessary to use tensorflow-datasets==4.9.4 and the nightly versions with the fix).

However, I think I managed to bypass this using BUILDER_CONFIGS. Let's say I want to generate the datasets in parallel with 10 workers. I simply need to define 10 configs as below,

from typing import ClassVar

import tensorflow_datasets as tfds


class Builder(tfds.core.GeneratorBasedBuilder, skip_registration=True):
    VERSION = tfds.core.Version("1.0.0")
    # One config per worker, named "1" through "10".
    BUILDER_CONFIGS: ClassVar[list[tfds.core.BuilderConfig]] = [
        tfds.core.BuilderConfig(name=str(group)) for group in range(1, 11)
    ]

then run the CLI per group one by one.

tfds build my_dataset --config 1
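If you would rather not start each build by hand, the per-config commands could be launched from a small script. This is only a sketch: build_commands and run_all are hypothetical helpers, and it assumes that concurrent builds of different configs do not interfere with each other in your data directory.

```python
import subprocess
from concurrent.futures import ThreadPoolExecutor


def build_commands(dataset, groups):
    """Return one `tfds build` command (as an argv list) per config."""
    return [["tfds", "build", dataset, "--config", str(g)] for g in groups]


def run_all(commands, max_parallel=4):
    """Run the commands, at most max_parallel at a time.

    Returns the exit codes in the same order as the commands.
    """
    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
        finished = executor.map(subprocess.run, commands)
        return [proc.returncode for proc in finished]


commands = build_commands("my_dataset", range(1, 11))
print(commands[0])
# To actually launch the builds: exit_codes = run_all(commands)
```

Threads are enough here because each thread only waits on an external process; the heavy work happens in the tfds subprocesses.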

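Later we can load and combine the datasets as below (note that sample_from_datasets interleaves the datasets by random sampling rather than appending one after the other)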

ds1 = tfds.load('my_dataset/1', split='train')
ds2 = tfds.load('my_dataset/2', split='train')
ds = tf.data.Dataset.sample_from_datasets([ds1, ds2])

I tested this locally and it does seem to work. Please let me know if there are any hidden traps 🙌 If not, I hope it helps others who are facing the same issue!
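One thing this approach relies on is that each config generates a disjoint slice of the source data, so that the combined result contains every example exactly once. A minimal sketch of a round-robin split follows. The helper shard_for_group and the file names are hypothetical; inside a builder, the group number could be derived from the config name, for example int(self.builder_config.name).

```python
def shard_for_group(items, group, num_groups):
    """Return the items assigned to a 1-based group, round-robin."""
    if not 1 <= group <= num_groups:
        raise ValueError(f"group must be in 1..{num_groups}, got {group}")
    return items[group - 1::num_groups]


files = [f"file_{i:03d}.csv" for i in range(25)]  # hypothetical inputs
shards = [shard_for_group(files, g, 10) for g in range(1, 11)]

# Every file lands in exactly one shard.
combined = sorted(name for shard in shards for name in shard)
print(combined == files)
```

Round-robin assignment keeps the shards close to the same size even when the number of inputs is not a multiple of the number of configs.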
