[DTensor] Turn on foreach implementation for clip_grad_norm_ for DTensor by default #126423
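For context, a minimal sketch of the user-facing behavior this PR changes. The toy model below is a stand-in (plain tensors, so the snippet runs standalone), while the PR itself concerns parameters whose gradients are DTensors:

```python
# Minimal sketch (assumed toy setup, not code from this PR). The PR
# concerns parameters whose grads are DTensors; plain tensors are used
# here so the snippet runs standalone.
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)
model(torch.randn(16, 1024)).sum().backward()

# foreach=None (the default) now resolves to the foreach (multi-tensor)
# implementation for DTensor too, per this PR's title; previously the
# foreach path had to be requested explicitly for DTensor.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```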
Conversation
Looks great!
@@ -67,11 +67,11 @@ def _test_clip_grad_norm(
)
comm_mode = CommDebugMode()
with comm_mode:
    # foreach defaults to on, so we don't need to specify it.
nit: haha I would be okay to just not include this comment
Sure. You are so fast! Looks like test_clip_grad_norm_2d errors out on main T_T, and thereby errors out on this PR as well. Looking into it now.
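To make the diff's context concrete, here is a runnable sketch of the test pattern (the fixtures are stand-ins; the real test shards the model with DTensor so the norm reduction issues collectives that `CommDebugMode` records):

```python
import torch
import torch.nn as nn
from torch.distributed._tensor.debug import CommDebugMode

# Stand-in for the test's DTensor-sharded model (assumed fixture).
model = nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()

comm_mode = CommDebugMode()
with comm_mode:
    # foreach now defaults to on, so no foreach=True argument is passed.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# The test can then assert on the collectives recorded, e.g.:
print(comm_mode.get_total_counts())
```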
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…sor by default (pytorch#126423)
Fixes #ISSUE_NUMBER
Pull Request resolved: pytorch#126423
Approved by: https://github.com/awgu
With these three PRs landed, we can now support the option `fused=True` in torchtitan for the Adam and AdamW optimizers:

- pytorch/pytorch#125369
- pytorch/pytorch#126423
- pytorch/pytorch#126750

We ran a performance evaluation on 8 A100 DevGPUs: 1000 steps on 1D DP with the default [llama_8b.toml](https://github.com/pytorch/torchtitan/blob/main/train_configs/llama3_8b.toml).

Observation: `fused = True` and `fused = False` show similar loss curves and memory usage, while wps is ~100 higher and mfu is 1.5-2% higher with `fused = True`. Below are the logs for the last 100 steps of both runs.

```
**Fused = False**
[rank0]:2024-06-05 12:45:06,227 - root - INFO - Finished dumping traces in 0.37 seconds
[rank0]:2024-06-05 12:45:37,677 - root - INFO - step: 910 loss: 4.6039 memory: 59.48GiB(75.15%) wps: 2,217 mfu: 41.16%
[rank0]:2024-06-05 12:46:08,843 - root - INFO - step: 920 loss: 4.6427 memory: 59.48GiB(75.15%) wps: 2,632 mfu: 48.85%
[rank0]:2024-06-05 12:46:40,052 - root - INFO - step: 930 loss: 4.6339 memory: 59.48GiB(75.15%) wps: 2,628 mfu: 48.78%
[rank0]:2024-06-05 12:47:11,243 - root - INFO - step: 940 loss: 4.5964 memory: 59.48GiB(75.15%) wps: 2,631 mfu: 48.84%
[rank0]:2024-06-05 12:47:42,655 - root - INFO - step: 950 loss: 4.6477 memory: 59.48GiB(75.15%) wps: 2,611 mfu: 48.47%
[rank0]:2024-06-05 12:48:13,890 - root - INFO - step: 960 loss: 4.8137 memory: 59.48GiB(75.15%) wps: 2,626 mfu: 48.75%
[rank0]:2024-06-05 12:48:45,110 - root - INFO - step: 970 loss: 4.5962 memory: 59.48GiB(75.15%) wps: 2,628 mfu: 48.78%
[rank0]:2024-06-05 12:49:16,333 - root - INFO - step: 980 loss: 4.5450 memory: 59.48GiB(75.15%) wps: 2,627 mfu: 48.76%
[rank0]:2024-06-05 12:49:47,561 - root - INFO - step: 990 loss: 4.5840 memory: 59.48GiB(75.15%) wps: 2,627 mfu: 48.76%
[rank0]:2024-06-05 12:50:18,933 - root - INFO - step: 1000 loss: 4.5351 memory: 59.48GiB(75.15%) wps: 2,615 mfu: 48.53%
[rank0]:2024-06-05 12:50:23,692 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 12:50:24,041 - root - INFO - Finished dumping traces in 0.35 seconds
[rank0]:2024-06-05 12:50:24,422 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 12:50:26,424 - root - INFO - Training completed

**Fused = True**
[rank0]:2024-06-05 14:55:42,894 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 14:56:13,582 - root - INFO - step: 910 loss: 4.6091 memory: 59.48GiB(75.15%) wps: 2,341 mfu: 43.46%
[rank0]:2024-06-05 14:56:43,765 - root - INFO - step: 920 loss: 4.6468 memory: 59.48GiB(75.15%) wps: 2,718 mfu: 50.45%
[rank0]:2024-06-05 14:57:13,971 - root - INFO - step: 930 loss: 4.6365 memory: 59.48GiB(75.15%) wps: 2,715 mfu: 50.40%
[rank0]:2024-06-05 14:57:44,172 - root - INFO - step: 940 loss: 4.6021 memory: 59.48GiB(75.15%) wps: 2,716 mfu: 50.41%
[rank0]:2024-06-05 14:58:14,353 - root - INFO - step: 950 loss: 4.6522 memory: 59.48GiB(75.15%) wps: 2,718 mfu: 50.45%
[rank0]:2024-06-05 14:58:44,536 - root - INFO - step: 960 loss: 4.8163 memory: 59.48GiB(75.15%) wps: 2,717 mfu: 50.44%
[rank0]:2024-06-05 14:59:14,683 - root - INFO - step: 970 loss: 4.6026 memory: 59.48GiB(75.15%) wps: 2,721 mfu: 50.51%
[rank0]:2024-06-05 14:59:44,840 - root - INFO - step: 980 loss: 4.5491 memory: 59.48GiB(75.15%) wps: 2,720 mfu: 50.49%
[rank0]:2024-06-05 15:00:15,009 - root - INFO - step: 990 loss: 4.5859 memory: 59.48GiB(75.15%) wps: 2,719 mfu: 50.47%
[rank0]:2024-06-05 15:00:45,228 - root - INFO - step: 1000 loss: 4.5396 memory: 59.48GiB(75.15%) wps: 2,714 mfu: 50.38%
[rank0]:2024-06-05 15:00:49,455 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 15:00:49,756 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 15:00:50,336 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 15:00:52,339 - root - INFO - Training completed
```
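For reference, `fused=True` here is the standard `torch.optim` flag; a minimal sketch of what torchtitan enables (the toy model and hyperparameters below are assumptions, and the fused Adam/AdamW kernels require CUDA):

```python
import torch
import torch.nn as nn

# Toy stand-in for the sharded Llama model (assumed; CUDA is required
# since the fused Adam/AdamW kernels are GPU-only here).
model = nn.Linear(1024, 1024, device="cuda")
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

model(torch.randn(16, 1024, device="cuda")).sum().backward()
opt.step()       # one fused kernel per device/dtype group, instead of
opt.zero_grad()  # the foreach or per-tensor for-loop implementations
```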
Fixes #ISSUE_NUMBER
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim