[decomp] Use var_mean in native_batch_norm decomposition #94140
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94140. ✅ No Failures as of commit c9e2a65. (Comment automatically generated by Dr. CI.)
ghstack-source-id: efcfcf7 Pull Request resolved: pytorch#94140
ghstack-source-id: 9f99f25 Pull Request resolved: pytorch#94140
Did you check that inductor perf didn't change?
```python
(torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
(torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
(torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
(torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
```
why does tolerance change here?
aten.var_mean uses a different algorithm from aten.mean, which ends up being slightly more precise. 5e-7 is still incredibly good for half precision, though. The default rtol for torch.testing is 1e-5.
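For intuition, here is a minimal pure-Python sketch of a single-pass, Welford-style `var_mean`. This is only an illustration of why a fused reduction can be more numerically stable than separate passes; the actual aten kernel may use a different algorithm:

```python
def welford_var_mean(xs):
    """Compute (population variance, mean) of a sequence in one pass.

    Welford's update accumulates the mean and the sum of squared
    deviations (m2) together, avoiding a second pass over the data
    and the cancellation issues of the naive E[x^2] - E[x]^2 formula.
    Illustrative sketch only, not PyTorch's implementation.
    """
    mean = 0.0
    m2 = 0.0
    for n, x in enumerate(xs, start=1):
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    return m2 / len(xs), mean  # variance with correction=0, mean

var, mean = welford_var_mean([1.0, 2.0, 3.0, 4.0])  # → (1.25, 2.5)
```

Like `torch.var_mean`, the sketch returns the variance and mean as a pair from a single traversal of the data.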
This improves perf by removing the duplicate mean calculation. However, the gain is very slight, since the second mean was being fused with the sum of squared deviations in the variance. In the following example, I see a 0.6% speedup, from 366 us to 364 us:

```python
import torch
import torch._dynamo
from torch._inductor import config

config.debug = True

a = torch.nn.BatchNorm3d(10).train().cuda()
b = torch.rand(10, 10, 16, 64, 64, device="cuda")

@torch._dynamo.optimize()
def fn(x):
    return a(x)

_ = fn(b)
%timeit fn(b)
```
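The shape of the decomposition change can be sketched in plain Python. The helper below is hypothetical: the real decomposition operates on tensors and calls `aten.var_mean` once, where the old version called `aten.mean` and then computed the variance with a second reduction:

```python
def batch_norm_normalize(x, eps=1e-5):
    """Normalize a 1-D list of floats the way the updated decomposition
    does: one var_mean-style fused pass yields both statistics, instead
    of computing the mean twice. Hypothetical sketch, not the real decomp.
    """
    n = len(x)
    mean = 0.0
    m2 = 0.0
    for i, v in enumerate(x, start=1):  # single fused pass (Welford-style)
        d = v - mean
        mean += d / i
        m2 += d * (v - mean)
    var = m2 / n                        # biased variance (correction=0)
    inv_std = (var + eps) ** -0.5
    return [(v - mean) * inv_std for v in x]

out = batch_norm_normalize([1.0, 2.0, 3.0, 4.0])
```

The normalized output has zero mean and (up to `eps`) unit variance, matching what batch norm produces in training mode.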
Stack from ghstack (oldest at bottom):