
Fix uniform returning end point for BFloat16 and Half #96962

Closed
wants to merge 5 commits

Conversation

peterbell10
Collaborator

@peterbell10 peterbell10 commented Mar 16, 2023

Fixes #96947

If we generate `1.0 - float_eps`, the BFloat16 and Half constructors will round this to 1.0, which is outside of the half-open range. Instead, we delay the bounds change until after the value has been rounded.
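
As a standalone illustration of the failure mode (a rough sketch that emulates the float-to-BFloat16 round-to-nearest-even conversion with bit manipulation; it is not the PyTorch kernel itself):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate float -> bfloat16 with round-to-nearest-even by keeping the top
// 16 bits of the float representation (NaN handling omitted for brevity).
static float to_bfloat16_rne(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);  // ties to even
  bits = (bits + rounding_bias) & 0xFFFF0000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float r = std::nextafterf(1.0f, 0.0f);    // largest float below 1.0, inside [0, 1)
  std::printf("%g\n", to_bfloat16_rne(r));  // prints 1, i.e. the excluded end point
}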

cc @pbelevich @pmeier

Fixes pytorch#96947

If we generate `1.0 - float_eps`, the BFloat16 and Half constructors will
round this to 1.0, which is outside of the half-open range. This
changes the rounding of the last bit in the BFloat16 representation to
never round up. The result is that we never go outside the end point,
and the `from` point is now equally likely, whereas before it was half as
likely.
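
A minimal sketch of the truncating conversion described above, again emulating bfloat16 by keeping only the top 16 bits of a float (illustrative only, not the code in this PR):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncating float -> bfloat16: drop the low 16 bits instead of rounding,
// so the conversion can never round up past the excluded end point.
static float to_bfloat16_trunc(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // round toward zero
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float r = std::nextafterf(1.0f, 0.0f);      // largest float below 1.0
  std::printf("%g\n", to_bfloat16_trunc(r));  // prints 0.996094, still inside [0, 1)
}
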
@pytorch-bot

pytorch-bot bot commented Mar 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96962

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 0bf366a:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Collaborator

@lezcano lezcano left a comment


How nice is it to have `if constexpr`? :D

// Note for BFloat16 and Half, the default constructor does
// round to nearest even, which may return the end point of our
// range. Use truncation rounding instead.
return truncate_to<scalar_t>(reverse_bound_rand * range + from);
Contributor


There is an FP precision issue. Let

float reverse_bound_rand  = 0.99999994;  // FF FF 7F 3F in memory
float range = 1.0;
half from = 100.0;

then reverse_bound_rand * range + from would be 101 (in float type), and truncate_to<scalar_t>(101) is also 101 (in half type). However, the function is not supposed to output the to_ value.
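
The float-precision part of this counterexample can be checked in isolation (a small standalone sketch; to_ is assumed to be from + range = 101):

#include <cstdio>

int main() {
  float reverse_bound_rand = 0.99999994f;  // largest float below 1.0
  float range = 1.0f;
  float from = 100.0f;
  float r = reverse_bound_rand * range + from;
  std::printf("%.9g\n", r);  // prints 101: the sum already rounds up to 101 in float
  // Since 101 is exactly representable in half (and bfloat16), a truncating
  // conversion cannot pull the result back below to_ = 101.
}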

Contributor


I suggest preserving the bounds after truncating: yuantailing@aee8c06
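
One way such a post-conversion bound check could look (a hypothetical sketch reusing the emulated-bfloat16 representation from earlier, where a bfloat16 value is a float with zero low 16 bits; the linked commit may implement this differently):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Step one representable bfloat16 value down from x (assumes x is finite and > 0).
static float bf16_prev(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits = ((bits >> 16) - 1u) << 16;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

// Enforce the half-open bound after conversion: if rounding still landed on
// `to`, step back to the largest representable value below it.
static float clamp_below(float x, float to) {
  return x >= to ? bf16_prev(to) : x;
}

int main() {
  std::printf("%g\n", clamp_below(101.0f, 101.0f));  // prints 100.5, i.e. < to
}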

@peterbell10
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 20, 2023
@pytorch-bot

pytorch-bot bot commented Mar 20, 2023

No ciflow labels are configured for this repo.
For information on how to enable CIFlow bot see this wiki

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job

@peterbell10 peterbell10 added the module: random (Related to random number generation in PyTorch (rng generator)), topic: bug fixes (topic category), and release notes: cuda (release notes category) labels Mar 20, 2023
@peterbell10
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

@pmeier
Collaborator

pmeier commented Mar 20, 2023

Hard to say if the test failure is valid or not:

2023-03-20T18:20:38.1245440Z _ TestSDPA.test_flash_attention_vs_math_ref_grads_batch_size_1_seq_len_q_2048_seq_len_k_128_head_dim_8_is_causal_False_dropout_p_0_22_bfloat16_scale_None _
2023-03-20T18:20:38.1245557Z Traceback (most recent call last):
2023-03-20T18:20:38.1245764Z   File "/var/lib/jenkins/workspace/test/test_transformers.py", line 1909, in test_flash_attention_vs_math_ref_grads
2023-03-20T18:20:38.1245943Z     self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2023-03-20T18:20:38.1246248Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3013, in assertEqual
2023-03-20T18:20:38.1246365Z     raise error_metas[0].to_error(
2023-03-20T18:20:38.1246514Z AssertionError: Tensor-likes are not close!
2023-03-20T18:20:38.1246520Z 
2023-03-20T18:20:38.1246632Z Mismatched elements: 1 / 65536 (0.0%)
2023-03-20T18:20:38.1246803Z Greatest absolute difference: 0.0014754831790924072 at index (0, 2, 2033, 0) (up to 0.0011385045945644379 allowed)
2023-03-20T18:20:38.1247035Z Greatest relative difference: 319.4129032258065 at index (0, 2, 2033, 0) (up to 71.4206771850586 allowed)

The dtype and number of elements fit the error case described in #96947.

This check was added (or at least adapted) in #94009. But this comment doesn't really instill confidence:

# TODO: Investigate why grad_q needs larger tolerances
grad_q_deviation = query_ref.grad - query_ref_lp.grad
grad_q_ref_atol = max(2 * torch.abs(grad_q_deviation).max().item(), default_atol[out.dtype])
grad_q_ref_rtol = max(get_rtol(query_ref.grad, query_ref_lp.grad), default_rtol[out.dtype])

@drisspg has there been any progress on investigating this? What do you propose to move forward?

@ngimel
Collaborator

ngimel commented Mar 20, 2023

In the meantime I think we should increase tolerance for this case

@pmeier
Collaborator

pmeier commented Mar 20, 2023

We would need to increase the rtol fivefold and the atol by roughly 50%. And we are currently stopping after the first failure, so maybe there are more parameters that fail here.

@pmeier pmeier added the keep-going (Don't stop on first failure, keep running tests until the end) label Mar 20, 2023
@ngimel
Collaborator

ngimel commented Mar 20, 2023

atol is 0.0011 vs 0.0014 (that's 30%); rtol doesn't matter if atol is fine.
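
For reference, the ratios worked out from the numbers in the failure log above:

0.0014754831790924072 / 0.0011385045945644379 ≈ 1.30  (greatest absolute difference vs. allowed atol)
319.4129032258065 / 71.4206771850586 ≈ 4.47  (greatest relative difference vs. allowed rtol)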

@drisspg
Contributor

drisspg commented Mar 20, 2023

In the meantime I think we should increase tolerance for this case

I agree with this. I have not had a chance to fully characterize the nature of these FP errors. I think it would be fine to bump the multiplier of 2 to something greater. There are 65k elements in that tensor and the parametrization sweeps 20k tests, so that entry could be an outlier (not that confidence-instilling either).

@pmeier
Collaborator

pmeier commented Mar 20, 2023

I've added the keep-going label so let's see what falls out.

@pmeier
Collaborator

pmeier commented Mar 21, 2023

It seems the only related failure is the one we observed above. In light of #96962 (comment), I think we are good with just upping the tolerance.

@pmeier
Collaborator

pmeier commented Mar 21, 2023

@pytorchbot merge -g

@pytorch-bot

pytorch-bot bot commented Mar 21, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: -g

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@pmeier
Collaborator

pmeier commented Mar 21, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_8-cuda11_7-test / test

Details for Dev Infra team: raised by workflow job

@pmeier
Collaborator

pmeier commented Mar 21, 2023

@pytorchbot merge -f 'unrelated triton version issue'

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 23, 2023
Fixes #96947

If we generate `1.0 - float_eps`, the BFloat16 and Half constructors will round this to 1.0 which is outside of the half-open range. Instead, we delay the bounds change until after the value has been rounded.

Pull Request resolved: pytorch/pytorch#96962
Approved by: https://github.com/lezcano, https://github.com/ngimel
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 27, 2023
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), keep-going (Don't stop on first failure, keep running tests until the end), Merged, module: random (Related to random number generation in PyTorch (rng generator)), open source, release notes: cuda (release notes category), topic: bug fixes (topic category)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torch.rand can sample the upper bound for lower precision floating point dtypes on CUDA
8 participants