[FSDP][Replicate] tests replicate parity for shared parameters #162836
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162836
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs: There is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit 8e2c2d3 with merge base 77b9aac.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ters" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #162836
Approved by: https://github.com/mori360
ghstack dependencies: #162830
…obatching (#162839)

**Summary:** To ensure that replicate acts as intended (a specialized version of HSDP), it needs to pass the same training tests that fully_shard does. The first test verifies that replicate works properly with gradient accumulation. The second verifies that replicate works correctly with a One-Forward-One-Backward (1F1B) pipeline parallelism schedule.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_gradient_accumulation
2. pytest test/distributed/_composable/test_replicate_training.py -k test_1f1b_microbatching

Pull Request resolved: #162839
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836
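The property a gradient accumulation test checks can be illustrated without any framework. The following is a minimal sketch, not the actual test in `test_replicate_training.py`: it assumes a scalar linear model with a mean squared error loss, and the helper names `grad_mse` and `accumulated_grad` are hypothetical. It shows that summing microbatch gradients, each scaled by the number of microbatches, reproduces the full-batch gradient.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. scalar w of the loss 0.5 * mean((w * x - y) ** 2)."""
    n = len(xs)
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, num_microbatches):
    """Accumulate gradients over equally sized microbatches.

    Assumes len(xs) is divisible by num_microbatches (illustrative only).
    """
    size = len(xs) // num_microbatches
    total = 0.0
    for i in range(num_microbatches):
        mb_x = xs[i * size:(i + 1) * size]
        mb_y = ys[i * size:(i + 1) * size]
        # Scale each microbatch gradient so the sum matches the full-batch mean.
        total += grad_mse(w, mb_x, mb_y) / num_microbatches
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = grad_mse(0.5, xs, ys)
acc = accumulated_grad(0.5, xs, ys, num_microbatches=2)
print(abs(full - acc) < 1e-12)
```

In a distributed setting the same equality is expected to hold after gradients are reduced across ranks, which is presumably what parity with fully_shard means here.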
**Summary:** Tests that replicate works when users use custom forward methods.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_register_fsdp_forward_method

Pull Request resolved: #162851
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839
…sion (#162855)

**Summary:** Ensures that replicate behaves the same as fully_shard when mixed precision is used.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k TestReplicateMixedPrecisionTraining

Pull Request resolved: #162855
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853
…s in mixed precision (#162861)

**Summary:** Ensures that replicate handles the same type-casting behavior and edge cases that fully_shard does when mixed precision is used.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_float16_on_one_submodule
2. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_submodules_with_external_inputs
3. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_bf16
4. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_fp16
5. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_clamp_reduce_dtype
6. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_dataclass_input

Pull Request resolved: #162861
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853, #162855
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci