On BatchedDataLoader performance #740

Closed
jarandaf opened this issue Mar 14, 2022 · 8 comments

@jarandaf

Hi all,

I am trying to train a PyTorch model on a pretty big dataset (on the order of millions of samples, ~100 columns, including scalars and arrays) stored as Parquet files. After reading the docs, BatchedDataLoader seems to be the right choice.

I have been looking at the BatchedDataLoader class and, although Parquet files are read in parallel with PyArrow, batches seem to be built on demand in an iterative way. This does not keep the GPU busy: during training I don't observe GPU usage above 20%, and the usage is very unstable.

I am afraid the GPU sits idle waiting for those batches to be built. Would it be possible to build them in advance?

@selitvin
Collaborator

The batch-building implementation in BatchedDataLoader should be fairly efficient. Are you sure the slowness comes from BatchedDataLoader itself? Could it be that the data is not supplied to the BatchedDataLoader fast enough? Did you try tweaking the parameters you pass to make_batch_reader (specifically reader_pool_type and workers_count)?
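
For example, a quick experiment along those lines might look like this (the dataset URL and the concrete values here are placeholders to try, not recommendations):

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

# Vary reader_pool_type and workers_count and compare batch throughput.
reader = make_batch_reader('file:///path/to/dataset.parquet',  # placeholder URL
                           reader_pool_type='thread',          # also try 'process'
                           workers_count=16)                    # try more/fewer workers
with BatchedDataLoader(reader, batch_size=1024) as loader:
    for batch in loader:
        pass  # consume batches and measure throughput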

@jarandaf
Author

jarandaf commented Mar 15, 2022

Hi @selitvin, thank you for your answer.

Yes, I tried both arguments and did not notice big improvements (thread vs. process pool, more/fewer workers, etc.).

> Could it be that the data is not supplied fast enough to the BatchedDataLoader?

As far as I understand, regardless of how fast the Parquet files are read in parallel and made available to the underlying results queue, the batches are still built on demand while iterating the BatchedDataLoader, right?

I profiled a piece of code that simply consumed the dataset as follows:

from tqdm import tqdm
from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

PARQUET_PATH = 'file:///Users/jarandaf/some_big_dataset.parquet'
READER_POOL_TYPE = 'thread'
N_WORKERS = 10
BATCH_SIZE = 1024
COLS2KEEP = [...]  # list of columns to load, around 100

reader = make_batch_reader(PARQUET_PATH,
                           reader_pool_type=READER_POOL_TYPE,
                           workers_count=N_WORKERS,
                           schema_fields=COLS2KEEP)
with BatchedDataLoader(reader, batch_size=BATCH_SIZE) as loader:
    for i, batch in tqdm(enumerate(loader)):
        pass

I observed a throughput of around 30 batches/s. From the profiling results it seems that more time is spent building batches than actually reading the Parquet files and converting them to the proper types (I found this quite surprising).

[Attached image: profiling results ("profile")]

Note: You can download the above image and open it with your browser to see more details.

Does all this look reasonable for such a dataset (~100 columns, a couple of them arrays), or would you expect higher reading performance? I should mention that if I only select a couple of columns, the dataset is read blazingly fast.

@selitvin
Collaborator

Got it. Interesting. Indeed, a large number of columns is tricky: since it is handled by these two loops, it might end up pretty slow.

Couple of ideas:

  1. Add pipelining (a thread + a queue): this way batch construction would be done in parallel with training. I think it's not a large undertaking and can be done externally to petastorm; alternatively, we could think of adding this as a built-in feature of BatchedDataLoader.
  2. Your proposal of moving the batching into the worker processes/threads might be doable, but I am afraid it would be a bit trickier to implement. Also, a concern: shuffling done on a per-worker basis would result in worse shuffling quality.

If you are interested, feel free to propose a PR - we can work together to get it into the petastorm codebase.

@jarandaf
Author

Could you please elaborate on 1?

@selitvin
Collaborator

In (1) I am referring to the following idea:

  • Implement a class that has the same interface as BatchedDataLoader
  • It is instantiated with an instance of petastorm's BatchedDataLoader
  • In the constructor it instantiates a queue (bounded size) and a thread
  • On the thread, we continuously read batches from BatchedDataLoader and store results in the queue
  • __iter__ of the new class returns data from the queue.

This way, batching is done on a background thread, and the main thread can drive GPU-based training while the CPU (and the GIL) is busy creating batches.
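
For concreteness, a minimal sketch of such a wrapper (class and parameter names here are illustrative, not petastorm API):

import queue
import threading

class PrefetchingLoader(object):
    """Wraps an existing BatchedDataLoader and builds batches on a background
    thread, handing them to the training loop through a bounded queue."""

    _DONE = object()  # sentinel marking the end of the wrapped loader

    def __init__(self, loader, max_prefetch=10):
        self._loader = loader
        self._queue = queue.Queue(maxsize=max_prefetch)
        self._thread = threading.Thread(target=self._producer, daemon=True)
        self._thread.start()

    def _producer(self):
        # Continuously pull batches from the wrapped loader; put() blocks when
        # the queue is full, so memory usage stays bounded.
        for batch in self._loader:
            self._queue.put(batch)
        self._queue.put(self._DONE)

    def __iter__(self):
        while True:
            batch = self._queue.get()
            if batch is self._DONE:
                return
            yield batch

The training loop would then iterate over PrefetchingLoader(loader) instead of loader directly, e.g.:

with BatchedDataLoader(reader, batch_size=BATCH_SIZE) as loader:
    for batch in PrefetchingLoader(loader):
        train_step(batch)  # train_step is a placeholder for the training code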

@jarandaf
Author

Thank you for your clarification. Yeah, I implemented this and it indeed improved GPU usage. I might submit a PR in the coming days.

@selitvin
Collaborator

I wonder if we should put this mechanism behind the current BatchedDataLoader implementation, made optional with a switch.
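
Purely as an illustration of what that switch might look like (the prefetch_batches argument is hypothetical, not an existing BatchedDataLoader parameter):

# Hypothetical API sketch; prefetch_batches is not an existing petastorm argument.
with BatchedDataLoader(reader, batch_size=1024, prefetch_batches=10) as loader:
    for batch in loader:
        pass  # with prefetch_batches > 0, batches would be built on a background thread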

@jarandaf
Author

I think it definitely makes sense!
