
Do not materialize entire randperm in RandomSampler #103339

Closed
wants to merge 5 commits

Conversation

@aviverma01 (Contributor) commented Jun 9, 2023

In our DDP training workloads, each rank was initializing a RandomSampler for a dataset with a length of 3.5 billion items. We noticed that when this sampler was in scope, gc.collect() calls were taking on the order of seconds to run, which would slow down the entire training iteration. This is because when we call torch.randperm(n).tolist(), we create a Python list of 3.5 billion items, which massively slows down the periodic mark-and-sweep garbage collection.

This PR swaps out the .tolist() call for a .numpy() call and calls .item() on each element only as it is requested (see the sketch after this list). This has two benefits:

  1. The first call to RandomSampler.__next__ should be about twice as fast, since .numpy() does not copy the contents of the original tensor.
  2. The runtime of gc.collect() calls no longer scales linearly with the size of the dataset passed to RandomSampler.
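For readers who don't want to open the diff, here is a simplified before/after sketch of the relevant part of RandomSampler.__iter__. This is illustrative only; the committed code also handles num_samples, replacement, and the tail slice.

import torch

# Before (simplified): .tolist() materializes n separate Python int objects up front.
def iter_indices_old(n, generator=None):
    yield from torch.randperm(n, generator=generator).tolist()

# After (simplified): the permutation stays a single array, and each index is
# converted to a Python scalar only when it is actually requested.
def iter_indices_new(n, generator=None):
    indices = torch.randperm(n, generator=generator).numpy()
    for i in indices:
        yield i.item()

# Usage: next(iter_indices_new(10)) yields one index in [0, 10).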

I've attached some timeit samples to illustrate the speedups with this PR:

Main (no GC):        51.72115747816861 s
Main (10 GC calls):  83.61965207383037 s
PR (no GC):          33.06403830461204 s
PR (10 GC calls):    33.959467427805066 s

Code

from timeit import timeit


baseline_no_gc = """
import torch

n = int(1e9)
steps = n // 100

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
"""


baseline_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""


numpy_no_gc = """
import torch
n = int(1e9)
steps = n // 100

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
"""

numpy_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""


if __name__ == "__main__":
    print("Main (no GC): ", timeit(baseline_no_gc, number=1))
    print("Main (10 GC calls)", timeit(baseline_gc, number=1))
    print("PR (no GC)",  timeit(numpy_no_gc, number=1))
    print("PR (10 GC calls)", timeit(numpy_gc, number=1))

@pytorch-bot (bot) commented Jun 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103339

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8f5e445 with merge base bc2caa7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla (bot) commented Jun 9, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot bot added the "release notes: dataloader" label Jun 9, 2023
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
indices = torch.randperm(n, generator=generator)
for i in indices:
    yield i.item()
@vadimkantorov (Contributor) commented Jun 9, 2023

One minor issue is that repeated .item() calls are slow (#29973),

so maybe tolist() is actually not that bad? (Especially if indices itself is already materialized as a tensor.)
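For context, a rough illustration of the per-element overhead being discussed. Timings are machine-dependent and this snippet is not part of the PR; iterating a torch tensor yields 0-dim tensors, each paying a Python/C round trip in .item().

from timeit import timeit
import torch

t = torch.randperm(1_000_000)

# Per-element .item(): one Python/C round trip per index.
per_item = timeit(lambda: [x.item() for x in t], number=1)

# .tolist(): everything converted up front into one large Python list.
as_list = timeit(lambda: t.tolist(), number=1)

print(f"per-element .item(): {per_item:.2f} s, .tolist(): {as_list:.2f} s")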

@aviverma01 (Contributor, author) replied:

Thanks for the suggestion! I've updated the code to call .numpy() on the tensor before iterating on it, which should avoid the slow .item() calls.

@vadimkantorov (Contributor) commented Jun 9, 2023

Well, if you can afford an extra allocation, why not just yield from indices.tolist()? Is it because the indices Python list would take too much memory?

I don't know what the current state of affairs is regarding whether numpy is an obligatory dependency.

Contributor:
Maybe worth filing a feature request for iterating a tensor's elements as Python values? Or for supporting memoryview on tensors (so that they can be iterated)...

@aviverma01 (Contributor, author) commented Jun 9, 2023

I've added some comments in the PR description, but the main issue is that calling .tolist() on a torch tensor of size 1 billion+ adds massive garbage-collection overhead, because we have just allocated billions of individual Python int objects that need to be managed separately. By using a numpy array instead, the garbage collector only needs to keep track of one object regardless of the dataset size, making garbage collection much faster during training.
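A scaled-down illustration of that effect (hypothetical snippet, not from the PR; sizes reduced from the billions discussed above, and absolute numbers depend on the machine):

import gc
import time
import numpy as np

n = 10_000_000  # far smaller than the 1e9+ datasets above, but enough to see the trend

big_list = list(range(n))          # n separate Python int objects for the GC to walk
t0 = time.perf_counter()
gc.collect()
print("gc.collect() with a Python list in scope:", time.perf_counter() - t0)

del big_list
big_array = np.arange(n)           # a single object from the GC's point of view
t0 = time.perf_counter()
gc.collect()
print("gc.collect() with a numpy array in scope:", time.perf_counter() - t0)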

Contributor:

I also created #103352.

Contributor:

Could this just be yield from torch.randperm(n, generator=generator).numpy() (plus some indexing), keeping the existing terse syntax?

@aviverma01 (Contributor, author) commented Jun 13, 2023

Done in 7a7a1e9; note that there is an additional map call compared to the suggestion above, to ensure that the yielded type matches.
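Presumably the updated line looks something like the following. This is only a guess at the shape of 7a7a1e9, not the committed code:

import torch

def permuted_indices(n, generator=None):
    # Sketch only: the permutation stays a single numpy array (one object for
    # the garbage collector), and map(int, ...) converts each numpy integer
    # scalar back into a Python int, matching the type .tolist() used to yield.
    yield from map(int, torch.randperm(n, generator=generator).numpy())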

Member:

@aviverma01 Should we also change the other branch, yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() ?

@aviverma01 (Contributor, author) replied:

Done in b450690
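The replacement=True branch presumably changed along the same lines. Again, this is just a sketch of what b450690 might look like, not the actual commit:

import torch

def replacement_indices(n, generator=None):
    # Sketch only: each draw of 32 indices stays a single numpy array instead of
    # becoming a 32-element Python list, and indices are converted lazily.
    while True:
        draw = torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator)
        yield from map(int, draw.numpy())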

@kit1980 added the ciflow/periodic and ciflow/inductor labels Jun 13, 2023
@facebook-github-bot commented:
@kit1980 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@kit1980 (Member) commented Jun 13, 2023

I've triggered more tests and also imported this internally to make sure nothing breaks.

@aviverma01 (Contributor, author) replied:

Thanks @kit1980, would you be able to help fix the "Meta Internal-Only Changes Check"? Also, I think some of the tests may be flaky. Would it be possible to re-kick the failing tests?

@kit1980 (Member) commented Jun 14, 2023

@pytorchbot rebase

@pytorchmergebot (Collaborator):
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):
Successfully rebased patch-1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout patch-1 && git pull --rebase)

@aviverma01 (Contributor, author) commented:
@kit1980 Looks like there may be some remaining flaky tests, and I believe the "Meta Internal-Only Changes Check" is still failing. Any chance you could help, or show me how to fix it?

@facebook-github-bot commented:
@kit1980 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch pytorch deleted a comment from pytorch-bot bot Jun 15, 2023
@kit1980 (Member) commented Jun 15, 2023

@mergebot merge -i

@aviverma01 (Contributor, author) replied:
> @mergebot merge -i

@kit1980 do you know why the pytorchbot didn't pick up the merge request?

@kit1980 (Member) commented Jun 16, 2023

@aviverma01 sorry, I misspelled the bot name.

@kit1980 (Member) commented Jun 16, 2023

@pytorchmergebot merge

@pytorch-bot bot added the ciflow/trunk label Jun 16, 2023
@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@qqaatw (Collaborator) commented Oct 25, 2023

Hi @aviverma01 @kit1980, the change this PR introduces seems to break manually specifying a non-CPU generator for DataLoader, because of the .numpy() calls:

import torch
from torch.utils import data

# dataset defined elsewhere
gen = torch.Generator(device=torch.device("mps:0"))
data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True, generator=gen)

The error:

TypeError: can't convert mps:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Was this intended? Or do you have any idea about it?

Thanks :)
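Stripped of the DataLoader, the failure reduces to calling .numpy() on a non-CPU tensor. A minimal sketch of the failure mode (requires a build with the MPS backend; the sampler's randperm output is just one example of a tensor that can end up on that device):

import torch

# Requires an Apple-silicon machine with the MPS backend available.
t = torch.arange(10, device="mps")

# Any .numpy() call on a non-CPU tensor raises the TypeError quoted above;
# the tensor would have to be moved to host memory first, e.g. t.cpu().numpy().
t.numpy()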

@kit1980 (Member) commented Oct 26, 2023

@pytorchbot revert -m "Cause issues on MPS, and also fails without numpy" -c nosignal

@kit1980 (Member) commented Oct 26, 2023

I'm reverting this.

I've realized there is another issue with the PR: it fails without numpy, which is actually an optional dependency.

@pytorchmergebot (Collaborator):
@pytorchbot successfully started a revert job. Check the current status here.

@pytorchmergebot (Collaborator):
@aviverma01 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Oct 26, 2023
This reverts commit d80174e.

Reverted #103339 on behalf of https://github.com/kit1980 due to "Cause issues on MPS, and also fails without numpy".
The revert commit (and revert PR #112187) was subsequently referenced or cherry-picked in other branches and forks, each carrying the same revert message: by kit1980 and andreigh (Oct 26, 2023), huydhn (Oct 26, 2023), xuhancn (Nov 7, 2023), Skylion007 (Nov 14, 2023), and Halmoni100 (Nov 25, 2023).
Labels
ciflow/inductor, ciflow/periodic, ciflow/trunk, Merged, open source, release notes: dataloader, Reverted, triaged