add mixed data type support for GroupNorm #81852
Conversation
1. If the user uses amp to run bfloat16 models, `torch.autocast` will keep module parameters in the accumulation dtype, which leaves `gamma` and `beta` in float while input/output will be in bfloat16. 2. If the user explicitly casts the model to bfloat16, the input/output and gamma/beta will all be in bfloat16.
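A minimal sketch of the two scenarios above (assuming a PyTorch build that includes this change; the shapes and group counts are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Case 1 (amp-style mixed dtypes): parameters stay float32 while the
# activations are bfloat16 -- the combination this PR adds support for.
m = nn.GroupNorm(num_groups=4, num_channels=8)
x = torch.randn(2, 8, 16).to(torch.bfloat16)
assert m.weight.dtype == torch.float32      # gamma/beta kept in acc dtype
y = F.group_norm(x, 4, m.weight, m.bias)    # bf16 input, fp32 gamma/beta
assert y.dtype == torch.bfloat16            # output follows the input dtype

# Case 2 (explicit cast): params and activations all bfloat16.
m_bf16 = m.to(torch.bfloat16)
y2 = m_bf16(x)
assert m_bf16.weight.dtype == torch.bfloat16
assert y2.dtype == torch.bfloat16
```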
Seems to be caused by incompatible types in group_norm when we use autocast. Patch group_norm to cast the weights to the same type as the inputs. From what I can understand, all the other repos just switch to full precision instead of addressing this. I think that would make things slower, but I'm not sure, so maybe the patching solution I'm doing is better? pytorch/pytorch#81852
@pytorchbot merge

Merge failed. Reason: Approval needed from one of the following (Rule 'superuser').
Small nit on the perf: we introduce an unnecessary `if (mixed_type)` check for all dtypes but BFloat16. Perhaps one can refactor this code as:

template <typename scalar_t>
void CallGroupNorm(...)

and specialize it for BFloat16 to call mixed_type.
We will add mixed_type for float16 in the near future for the coming new Xeon platform (Granite Rapids). Also I noticed that
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Stack from ghstack:

1. If the user uses amp to run bfloat16 models, `torch.autocast` will keep module parameters in the accumulation dtype, which leaves `gamma` and `beta` in float while input/output will be in bfloat16.
2. If the user explicitly casts the model to bfloat16, the input/output and gamma/beta will all be in bfloat16.

cc @VitalyFedyunin @jgong5 @XiaobingSuper @sanchitintel @ashokei @jingxu10