
[FSDP] Limit all gather after pre-unshard #89057

Closed · wants to merge 2 commits

Conversation

@awgu (Contributor) commented Nov 15, 2022

Stack from ghstack:

To reuse memory when allocating the unsharded `FlatParameter` in the unshard stream, we only need to block the CPU thread on the preceding free event (i.e. `event.synchronize()`) before allocating the unsharded memory, which happens in `handle.unshard()`. Notably, this can be done after the pre-unshard logic, which at most performs _sharded_ allocations (low precision shard or H2D sharded `FlatParameter` copy) in its own pre-unshard stream. This enables the pre-unshard to overlap with any pending ops.
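
To illustrate the ordering described above, here is a minimal sketch (not FSDP's actual internals; the stream objects, tensor sizes, and the `unshard_step` helper are illustrative stand-ins for FSDP's pre-unshard stream, unshard stream, and `handle.unshard()`):

```python
import torch

# Stand-ins for FSDP's pre-unshard and unshard streams (names are illustrative).
pre_unshard_stream = torch.cuda.Stream()
unshard_stream = torch.cuda.Stream()

def unshard_step(prev_free_event: torch.cuda.Event) -> torch.Tensor:
    """`prev_free_event` is recorded when the previously unsharded memory was freed."""
    # 1) Pre-unshard: only sharded-size allocations (e.g. low precision shard or
    #    H2D copy of the sharded FlatParameter), issued without blocking the CPU
    #    thread, so it can overlap with any pending GPU work.
    with torch.cuda.stream(pre_unshard_stream):
        sharded = torch.empty(1 << 20, device="cuda")  # stand-in for sharded data

    # 2) Only now block the CPU thread on the preceding free event, so the caching
    #    allocator can reuse the freed block for the unsharded allocation below.
    prev_free_event.synchronize()

    # 3) Unshard: allocate the unsharded FlatParameter and all-gather into it.
    with torch.cuda.stream(unshard_stream):
        unsharded = torch.empty(8 << 20, device="cuda")  # stand-in for unsharded memory
        # torch.distributed.all_gather_into_tensor(unsharded, sharded) would run here.
    return unsharded
```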

With this change, I believe that we should use `limit_all_gathers=True` all the time to stay true to FSDP's proposed memory semantics.

If a user wants to set `limit_all_gathers=False`, that would mean they want to overlap ops issued after the unshard logic's all-gather with ops that are pending at the time when FSDP _would_ block the CPU thread via `event.synchronize()`.

- If the user is willing to forgo reusing memory for that all-gather, then they may as well have applied `NO_SHARD` and optionally ZeRO-1 (if this niche is important, then maybe we should consider hardening ZeRO-1). This is because the unsharded memory for that all-gather now contributes additionally to peak memory, since it cannot reuse previously freed memory.
- If the user wants to reuse memory for that all-gather, then we need to block the CPU thread. There is no way around that given the caching allocator's semantics.
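
For reference, a minimal usage sketch of the `limit_all_gathers=True` recommendation above (assuming a process group has already been initialized, e.g. via `torchrun`, and the wrapped module is just a toy placeholder):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Keep the CPU-thread rate limiting on so that unsharded memory can be reused,
# staying true to FSDP's proposed memory semantics.
model = FSDP(
    nn.Linear(1024, 1024).cuda(),
    limit_all_gathers=True,
)
```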

pytorch-bot bot commented Nov 15, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89057

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 47af6d0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Nov 15, 2022
@awgu awgu added the topic: improvements topic category label Nov 15, 2022
awgu added a commit to awgu/pytorch that referenced this pull request Nov 23, 2022
ghstack-source-id: 72e1b5f3d523e6fde6ca183af492cee947adbb43
Pull Request resolved: pytorch#89057
awgu added a commit to awgu/pytorch that referenced this pull request Nov 29, 2022
ghstack-source-id: 72e1b5f3d523e6fde6ca183af492cee947adbb43
Pull Request resolved: pytorch#89057
@mrshenli (Contributor) left a comment

LGTM

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 29, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#89057
Approved by: https://github.com/mrshenli
@facebook-github-bot facebook-github-bot deleted the gh/awgu/200/head branch June 8, 2023 15:27
Labels
- ciflow/trunk
- release notes: distributed (fsdp)
- topic: improvements