[FSDP][3/N] Refactor `summon_full_params` unit tests #92298
ghstack-source-id: bb336659bb08f715f5578e55aeb96eed4dda17f3
Pull Request resolved: #92298
**Overview**

- This PR refactors the `summon_full_params()` unit tests to prepare for `unshard_params()` by consolidating redundant tests and improving others.
- This PR enables `CPUOffload(offload_params=True)` + `NO_SHARD` + `writeback=True`.
- This PR provides an improved error message when calling `summon_full_params()` from an invalid context (i.e. from forward, from backward, or from within `summon_full_params()` itself).

**Details**

<details>
<summary>Existing Unit Tests</summary>

`test_summon_full_param_writeback()` with `world_size=1`
`test_summon_full_param_writeback()` with `world_size=2`
- Tests that `writeback=True` persists a write and that `writeback=False` does not when modifying a root FSDP instance's `flat_param` (`modify_outer=True`) or a non-root FSDP instance's `flat_param` (`modify_outer=False`); additionally configures `mixed_precision` and `use_orig_params`. See the usage sketch after this list.
- `CPUOffload(offload_params=True)` + `world_size=1` is not tested because it is not supported.
- The write inside `summon_full_params()` is on the `flat_param` itself, which is not the expected usage.

`test_summon_full_param_shard_value()`
- Tests that reconstructing the `flat_param` (by re-flattening and chunking parameters) inside `summon_full_params()` yields the same result as the originally constructed `flat_param` when using a single FSDP instance.
- This test exercises the FSDP sharding algorithm, not the specification of `summon_full_params()`. The only relevant behavior implicitly tested is that `model.parameters()` order is preserved.
- This test assumes the current FSDP sharding algorithm.

`test_summon_full_param_recursive()`
- Tests that `recurse=True` recursively applies to all FSDP instances and that `recurse=False` does not.
- This test assumes the current FSDP sharding algorithm.

`test_cannot_summon_full_params_from_forward()`
`test_cannot_summon_full_params_from_backward()`
- Tests that calling `summon_full_params()` from inside the forward or backward raises an error.
- The error message leaks `FlatParamHandle` to the user. This PR provides a better error message.

`test_summon_full_params_respects_reshard_after_forward()`
- Tests that calling `summon_full_params()` after forward preserves whether the padded unsharded `flat_param` data is freed or not (following `reshard_after_forward`).
- This test depends on FSDP internals (`flat_param._full_param_padded.storage().size()`).

`test_summon_single_param()`
- Tests that writing to padding with `writeback=True` does not persist those writes (using a singleton `(1, 1)` parameter that gets flattened and padded to `(2,)`).
- This test name is misleading.

`test_summon_full_params_equivalence()`
- Tests `writeback`, `rank0_only`, and `offload_to_cpu` with `writeback=not rank0_only`, using `CPUOffload(offload_params=True)` and including a `torch.cuda._sleep(int(1e6))` *after* the write in `summon_full_params()`.
- The PR introducing this test said that the `torch.cuda._sleep(int(1e6))` exercised the stream synchronization in `summon_full_params()`, namely that the current stream waits for the all-gather stream after all-gathering the parameters. I did not follow conceptually how that works, since the `torch.cuda._sleep()` call happens after both the all-gather and the write and runs in the default stream, which seems to be after the relevant ops. If we clarify this, I can re-incorporate it into the unit tests. Doing so is not a high priority since `summon_full_params()` now unshards in the default stream and does not require stream synchronization.
- This unit test overlaps with `test_summon_full_param_writeback()` and can be coalesced.

`test_summon_from_non_fsdp()`
- Tests that calling `summon_full_params()` with default args on a non-FSDP root module exposes the original parameters correctly.
- This test actually covers much of the specification, since checking for original parameter equivalence includes shape, value, and device checks.

`test_reshard_outside_forward_backward_iteration()`
- Tests that calling `summon_full_params()` after forward preserves whether the padded unsharded `flat_param` data is freed or not (following `reshard_after_forward`) and that calling `summon_full_params()` after backward preserves that the padded unsharded `flat_param` data is freed; additionally configures `mixed_precision`.
- This test strictly dominates `test_summon_full_params_respects_reshard_after_forward()` since it also includes the check after backward.

`test_params_are_unflattenned()`
- Tests that original parameters are exposed with their unflattened shapes, factoring in `rank0_only` (e.g. that nonzero ranks reshard early when `rank0_only=True`), and that with `offload_to_cpu=True`, the `flat_param`s are moved back to GPU after exiting the context; additionally configures `mixed_precision`.

`test_params_count_and_value()`
- Tests that all original parameters are exposed with the correct values, factoring in `rank0_only` (e.g. that nonzero ranks do not expose the original parameters when `rank0_only=True`), and that with `offload_to_cpu=True`, the `flat_param`s are moved back to GPU after exiting the context; additionally configures `mixed_precision`.

`test_raises_rank0_with_writeback()`
- Tests that `rank0_only=True` + `writeback=True` raises an error.

`test_named_parameters_buffers()`
- Tests that `named_parameters()` and `named_buffers()` return clean names (without FSDP prefixes) inside `summon_full_params()`.

`test_with_grads_core()`
- Tests `with_grads=True` by comparing against DDP.

`test_with_grads_none_grads()`
- Tests `with_grads=True` when ranks' `FlatParameter`s have `None` gradient.

</details>
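The writeback tests above all revolve around one usage pattern. As a point of reference, here is a minimal sketch of that pattern (my illustration, not code from this PR; it assumes an already-initialized process group and an FSDP-wrapped model, and `check_writeback_semantics` is a hypothetical helper):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def check_writeback_semantics(model: FSDP, writeback: bool) -> None:
    """Write to the unsharded parameters inside the context; afterward, the
    local shards should reflect the write iff ``writeback=True``."""
    with FSDP.summon_full_params(model, writeback=writeback):
        # Inside the context, the original parameters are unsharded and
        # exposed with their full, unflattened shapes.
        with torch.no_grad():
            for param in model.parameters():
                param.fill_(42.0)
    # On exit, FSDP reshards; the write persists into each rank's local
    # shard only if writeback=True.
```

The expected usage writes to the exposed original parameters, not to the `flat_param` directly, which is why the direct `flat_param` write in the existing writeback test is noted above as non-idiomatic.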
<details>
<summary>New Unit Tests</summary>

`test_unshard_params_writeback_no_shard()` (with `world_size=1`)
`test_unshard_params_writeback()` (with `world_size=2`)
- Tests the `writeback` argument (using the default value for all others).

`test_unshard_params_param_data_no_shard()` (with `world_size=1`)
`test_unshard_params_param_data()` (with `world_size=2`)
- Tests that parameters are exposed correctly for `recurse=True` and all other argument configs for a non-FSDP root module; see the sweep sketch after this list.

`test_unshard_singleton_param_writeback()`
- Tests `writeback=True` for a singleton parameter, which includes testing that writing to padding does not persist.

`test_unshard_params_respects_reshard()`
- Tests that unsharding parameters respects the expected reshard behavior between forward and backward as well as after backward.

`test_unshard_params_recurse()`
- Tests the `recurse` argument (using the default for all others).

`test_offload_to_cpu_no_shard_raises()`
- Tests that `offload_to_cpu=True` with `NO_SHARD` raises an error.

</details>
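As a rough illustration of how the consolidated `test_unshard_params_param_data()` can cover the argument space, here is a hedged sketch (`run_one_config` is a hypothetical callback; the actual test may parametrize differently):

```python
from itertools import product

def sweep_unshard_configs(run_one_config) -> None:
    # Enumerate every combination of the four boolean arguments.
    for recurse, writeback, rank0_only, offload_to_cpu in product(
        (True, False), repeat=4
    ):
        if rank0_only and writeback:
            # Invalid by design: writing back from only rank 0 would leave
            # the other ranks' shards stale, so this combination raises
            # instead of being swept here.
            continue
        run_one_config(
            recurse=recurse,
            writeback=writeback,
            rank0_only=rank0_only,
            offload_to_cpu=offload_to_cpu,
        )
```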
<details>
<summary>Summary of Unit Test Changes</summary>

- `test_summon_full_param_writeback()` -> `test_unshard_params_writeback()`
- `test_summon_full_params_equivalence()`, `test_params_are_unflattenned()`, `test_params_count_and_value()` -> `test_unshard_params_param_data()`
- `test_summon_full_params_respects_reshard_after_forward()`, `test_reshard_outside_forward_backward_iteration()` -> `test_unshard_params_respects_reshard()`
- `test_summon_full_param_recursive()` -> `test_unshard_params_recurse()`
- `test_named_parameters_and_buffers()` unchanged
- `test_with_grads_core()` unchanged
- `test_with_grads_none_grads()` unchanged
- `test_cannot_summon_full_params_from_forward()`, `test_cannot_summon_full_params_from_backward()` -> `test_unshard_params_from_forward_raises()`, `test_unshard_params_from_backward_raises()` (see the sketch after this list)
- `test_raises_rank0_with_writeback()` -> `test_rank0_only_with_writeback_raises()`
- `test_offload_to_cpu_no_shard_raises()` new
- `test_summon_full_param_shard_value()` removed

</details>
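The `*_raises` tests near the bottom of the table pin down error behavior. A sketch of their shape (my illustration, not the PR's code; the fixture attributes and the exact exception type are assumptions, since the tests only require that an error is raised):

```python
import unittest
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class UnshardRaisesSketch(unittest.TestCase):
    # Assumed fixtures: `self.model` is FSDP-wrapped with a sharded
    # strategy; `self.model_no_shard` is built with ShardingStrategy.NO_SHARD.

    def test_rank0_only_with_writeback_raises(self) -> None:
        with self.assertRaises(ValueError):  # exception type assumed
            with FSDP.summon_full_params(
                self.model, rank0_only=True, writeback=True
            ):
                pass

    def test_offload_to_cpu_no_shard_raises(self) -> None:
        with self.assertRaises(ValueError):  # exception type assumed
            with FSDP.summon_full_params(
                self.model_no_shard, offload_to_cpu=True
            ):
                pass
```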
ghstack-source-id: 17cc6b38a54e4e803fadc480859434549a3ee2b2
Pull Request resolved: #92298
ghstack-source-id: dcee8e88e877f761188ba38167b7ed46cda88b91
Pull Request resolved: #92298
Thanks for enhancing the testing!
ghstack-source-id: d28beaaacf048fb2b3e45c3db2c158945ddd118f Pull Request resolved: #92298
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
…n-dev-setup

* origin: (898 commits)
  - Move dynamo.optimizations.distributed to backends (pytorch#93408)
  - Remove cuda 11.6 from nightly (pytorch#93979)
  - Refactor dynamo register_backend/BACKENDS (pytorch#93389)
  - Remove cuda 11.6 from CI replace with 11.7 (pytorch#93406)
  - [Dynamo] Rename `GuardBuilder.guarded_code` -> `check_fn_manager` (pytorch#93934)
  - Revert "Remove CUDA 11.6 from nightly builds (pytorch#93404)"
  - Revert "[inductor] fix crash issue when input is a view tensor (pytorch#90150)"
  - Basic Validation for FSDP `state_dict` transformations of modules with persistent buffers (pytorch#93396)
  - Merge Inductor perf smoke test with other inductor CI tests (pytorch#93395)
  - [inductor] Don't import torchvision (pytorch#93027)
  - [FSDP][3/N] Refactor `summon_full_params` unit tests (pytorch#92298)
  - [FSDP][2/N] `_summon_full_params` -> `_unshard_params` (pytorch#92297)
  - Remove CUDA 11.6 from nightly builds (pytorch#93404)
  - Mark buffers that reuse other buffers (pytorch#93329)
  - Refactor to allow reuse of SchedulerNode.allocate (pytorch#93328)
  - retire sparse_mask_helper (pytorch#91714)
  - update fbgemm third party (pytorch#93907)
  - [inductor] fix crash issue when input is a view tensor (pytorch#90150)
  - [Inductor] add config for weight prepacking (pytorch#93811)
  - Check for none for NNModuleVariable.__module__ (pytorch#93326)
  - ...
Stack from ghstack:
- [FSDP][3/N] Refactor `summon_full_params` unit tests (#92298)
- [FSDP][2/N] `_summon_full_params` -> `_unshard_params` (#92297)