
Fix consistency of histc on CPU and CUDA #87832

Conversation

@Aidyn-A (Collaborator) commented Oct 27, 2022

Fixes #87657

The main reason why histc returns slightly different outputs is the difference in how the bin position is calculated.
The CPU kernel (aten/src/ATen/native/cpu/HistogramKernel.cpp) computes it as:

pos = static_cast<int64_t>((elt - leftmost_edge[dim])
/ (rightmost_edge[dim] - leftmost_edge[dim])
* (num_bin_edges[dim] - 1));

which is essentially (i - a) / (b - a) * N, while the CUDA kernel (aten/src/ATen/native/cuda/SummaryOps.cu) computes

IndexType bin = (int)(((bVal - minvalue)) * nbins / (maxvalue - minvalue));

which is (i - a) * N / (b - a).

For some inputs, such as the one in #87657, the order of these arithmetic operations matters because of floating-point round-off.
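To make the round-off concrete, here is a small standalone sketch (an illustration only; it emulates both formulas with NumPy float32 rather than running the actual kernels) that applies both orderings to the input from #87657 and counts the elements that land in different bins:

import numpy as np

# Emulate the two bin-index formulas in float32; the real kernels live in
# HistogramKernel.cpp (CPU) and SummaryOps.cu (CUDA).
a, b, nbins = np.float32(0.0), np.float32(0.99), 10
x = np.linspace(0, 0.99, 1001, dtype=np.float32)

cpu_style = ((x - a) / (b - a) * nbins).astype(np.int64)   # (i - a) / (b - a) * N
cuda_style = ((x - a) * nbins / (b - a)).astype(np.int64)  # (i - a) * N / (b - a)

# A nonzero count illustrates the kind of disagreement described above.
print("elements assigned to different bins:", np.count_nonzero(cpu_style != cuda_style))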


Not sure where the most appropriate place for the unit test would be; I hope test_reductions::test_histc will do.

cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot bot commented Oct 27, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87832

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit e187347:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Oct 27, 2022
@soumith (Member) commented Oct 27, 2022

great, if you fix lint we can merge it.

@@ -2843,6 +2843,13 @@ def test_against_np(tensor, bins=100, min=0, max=0):
expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
test_against_np(expanded)

if torch.cuda.is_available():

Collaborator:

would test_against_np catch it instead?

Collaborator Author:

Unfortunately, there is a discrepancy with NumPy for this input tensor. The code
The code

import torch
import numpy as np

def fn(x):
    # torch.histc with 10 bins over [0, 0.99]
    return torch.histc(x, bins=10, min=0, max=0.99)

def fn_np(x):
    # NumPy reference with the same bins and range
    return torch.from_numpy(np.histogram(x.numpy(), bins=10, range=(0, 0.99))[0])

x = torch.linspace(0, 0.99, 1001, dtype=torch.float32)

print(fn(x))         # CPU
print(fn(x.cuda()))  # CUDA
print(fn_np(x))      # NumPy

Produces:

tensor([101.,  99., 101.,  99., 101.,  99., 100., 100., 100., 101.])
tensor([101.,  99., 101.,  99., 101.,  99., 100., 100., 100., 101.], device='cuda:0')
tensor([101,  99, 100, 100, 100, 100, 100, 100, 100, 101])

Collaborator:

Is NumPy's histogram more accurate? Making histc consistent on CPU and CUDA is nice, but is the histc algorithm we're using actually correct?

@Aidyn-A (Collaborator Author) Oct 27, 2022

NumPy might be more accurate because it uses a direct comparison (https://github.com/numpy/numpy/blob/13d55a3c2f016a58a6e9d6b8086f338e07c7478f/numpy/lib/histograms.py#L862-L866) to determine whether a value is within the range or not. However, I noticed that it struggles with this case too, as it produces different results for float32 and float64.
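For reference, the float32/float64 sensitivity mentioned above can be probed with a few lines (the exact counts may vary across NumPy versions and platforms, so this is only a check, not an expected output):

import numpy as np

# Same data and binning, fed to np.histogram in float32 and float64; the comment
# above reports that the two dtypes do not produce identical counts on this input.
x32 = np.linspace(0, 0.99, 1001, dtype=np.float32)
x64 = x32.astype(np.float64)

print(np.histogram(x32, bins=10, range=(0, 0.99))[0])
print(np.histogram(x64, bins=10, range=(0, 0.99))[0])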

Collaborator:

Instead of conditioning on CUDA being available, just compare the result to the expected outcome of

tensor([101.,  99., 101.,  99., 101.,  99., 100., 100., 100., 101.])

directly, and add a comment that NumPy produces a different result.
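A hedged sketch of what such a check could look like (the helper name and exact placement are illustrative, not the final test added to test_reductions.py):

import torch

def check_histc_boundary_rounding(device):
    # Input from #87657: 1001 evenly spaced float32 values over [0, 0.99].
    x = torch.linspace(0, 0.99, 1001, dtype=torch.float32, device=device)
    actual = torch.histc(x, bins=10, min=0, max=0.99)
    # NumPy's np.histogram gives [101, 99, 100, 100, 100, 100, 100, 100, 100, 101]
    # for this input; histc is expected to match the CPU/CUDA output shown
    # earlier in this thread instead.
    expected = torch.tensor(
        [101., 99., 101., 99., 101., 99., 100., 100., 100., 101.], device=device)
    torch.testing.assert_close(actual, expected)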

Collaborator:

If you changed the boundary computation to what the CPU is doing now (instead of to what CUDA is doing), would the results match?

Collaborator Author:

Yes, the results will match exactly if I slightly change the input tensor or the histogram boundaries by a machine epsilon.
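A small way to probe that remark, reusing the float32 emulation from the earlier sketch (illustrative only; which perturbation removes the disagreement depends on the rounding behaviour of the hardware):

import numpy as np

# Nudge the upper boundary by one float32 machine epsilon and recount how often
# the two formula orderings disagree (with the lower boundary a = 0).
x = np.linspace(0, 0.99, 1001, dtype=np.float32)
nbins = 10
eps = np.float32(np.finfo(np.float32).eps)
for b in (np.float32(0.99), np.float32(0.99) + eps):
    divide_first = (x / b * nbins).astype(np.int64)    # old CPU order
    multiply_first = (x * nbins / b).astype(np.int64)  # CUDA order
    print(float(b), np.count_nonzero(divide_first != multiply_first))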

@mruberry (Collaborator)

One test tweak needed (see comment inline) and then this should be OK. I'd really like to get the CUDA version of histogramdd added, so we can consistently direct people to that. It's more accurate and more consistent with NumPy.

@kit1980 (Member) commented Nov 18, 2022

@pytorchbot rebase

@kit1980 (Member) commented Nov 18, 2022

@Aidyn-A should we merge this now?

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased fix_consistency_between_cpu_and_cuda_for_histc onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix_consistency_between_cpu_and_cuda_for_histc && git pull --rebase)

@pytorchmergebot force-pushed the fix_consistency_between_cpu_and_cuda_for_histc branch from ad890b0 to e187347 on November 18, 2022 02:05
@github-actions bot added the module: cpu (CPU specific problem (e.g., perf, algorithm)) label on Nov 18, 2022
@Aidyn-A (Collaborator Author) commented Nov 18, 2022

@kit1980 yes, we can merge it now.
@pytorchbot merge -g

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022

Pull Request resolved: pytorch#87832
Approved by: https://github.com/soumith
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cpu (CPU specific problem (e.g., perf, algorithm)), open source
Development
Successfully merging this pull request may close these issues: histc return inconsistent value on CPU and CUDA (#87657).