Fix consistency of histc on CPU and CUDA #87832
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87832
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 Pending as of commit e187347. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Great, if you fix lint we can merge it.
test/test_reductions.py (Outdated)
@@ -2843,6 +2843,13 @@ def test_against_np(tensor, bins=100, min=0, max=0):
     expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
     test_against_np(expanded)

+    if torch.cuda.is_available():
Would `test_against_np` catch it instead?
Unfortunately, there is a discrepancy with numpy for this input tensor.
The code

import torch
import numpy as np

def fn(x):
    return torch.histc(x, bins=10, min=0, max=0.99)

def fn_np(x):
    return torch.from_numpy(np.histogram(x.numpy(), bins=10, range=(0, 0.99))[0])

x = torch.linspace(0, 0.99, 1001, dtype=torch.float32)
print(fn(x))
print(fn(x.cuda()))
print(fn_np(x))
Produces:
tensor([101., 99., 101., 99., 101., 99., 100., 100., 100., 101.])
tensor([101., 99., 101., 99., 101., 99., 100., 100., 100., 101.], device='cuda:0')
tensor([101, 99, 100, 100, 100, 100, 100, 100, 100, 101])
Is NumPy's `histogram` more accurate? Making `histc` consistent on CPU and CUDA is nice, but is the `histc` algorithm we're using actually correct?
NumPy might be more accurate because it uses direct comparison (https://github.com/numpy/numpy/blob/13d55a3c2f016a58a6e9d6b8086f338e07c7478f/numpy/lib/histograms.py#L862-L866) to find out whether the value is within the range or not. However, I noticed that it struggles with this case too, as it produces different results for `float32` and `float64`.
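For illustration, here is a rough sketch (my own emulation, not NumPy's actual code) of the technique at the linked lines: compute the bin index with the fast multiplication path, then repair round-off errors by comparing each value directly against its bin's edges. The input values and edge layout mirror the example above:

```python
import numpy as np

# Sketch of uniform-bin indexing with a direct-comparison fix-up,
# loosely modeled on the linked numpy/lib/histograms.py lines.
x = np.linspace(0, 0.99, 1001, dtype=np.float32)
n_bins = 10
edges = np.linspace(0.0, 0.99, n_bins + 1)
norm = n_bins / (edges[-1] - edges[0])

# Fast path: bin index via multiplication (subject to round-off).
indices = ((x - edges[0]) * norm).astype(np.intp)
indices[indices == n_bins] -= 1  # fold the right edge into the last bin

# Correction: compare each value directly against its bin's edges.
indices[x < edges[indices]] -= 1
indices[(indices != n_bins - 1) & (x >= edges[indices + 1])] += 1

counts = np.bincount(indices, minlength=n_bins)
print(counts, counts.sum())
```

The correction step is why the edge-comparison approach is more robust than either multiplication order alone: a value that lands one bin off due to round-off is moved back by an exact comparison.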
Instead of conditioning on CUDA being available, just compare the result to the expected outcome of
tensor([101., 99., 101., 99., 101., 99., 100., 100., 100., 101.])
directly, and add a comment that NumPy produces a different result.
If you change the boundary computation to what the CPU is doing now (instead of to what CUDA is doing), would the results match?
Yes, the results will match exactly if I slightly change the input tensor or the histogram boundaries by a machine epsilon.
One test tweak needed (see comment inline) and then this should be OK. I'd really like to get the CUDA version of histogramdd added, so we can consistently direct people to that. It's more accurate and more consistent with NumPy.
@pytorchbot rebase
@Aidyn-A should we merge this now?
@pytorchbot successfully started a rebase job. Check the current status here
Successfully rebased ad890b0 to e187347
@kit1980 yes, we can merge it now.
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Fixes pytorch#87657 The main reason why `histc` returns slightly different outputs is the difference on how bin position is calculate. The CPU calculates it as: https://github.com/pytorch/pytorch/blob/449778a939f2adc8867c5035b08be4e2d88339d8/aten/src/ATen/native/cpu/HistogramKernel.cpp#L168-L170 which is basically `(i - a) / (b - a) * N`, while cuda code https://github.com/pytorch/pytorch/blob/449778a939f2adc8867c5035b08be4e2d88339d8/aten/src/ATen/native/cuda/SummaryOps.cu#L41 which is `(i - a) * N / (b - a)`. For some cases like in pytorch#87657 the order of arithmetic operations matters due to the floating point round-off. ________________ Not sure where would be the most appropriate place to put the unit test. Hope `test_reductions::test_histc` will do. Pull Request resolved: pytorch#87832 Approved by: https://github.com/soumith
cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10