
[chalf] reference_testing: low quality test for fast growing ops #78170

Open
kshitij12345 opened this issue May 24, 2022 · 4 comments
Labels
module: complex (Related to complex number support in PyTorch), module: half (Related to float16 half-precision floats), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

kshitij12345 (Collaborator) commented May 24, 2022

🐛 Describe the bug

In PR #77640:
Since the range of chalf is much smaller than that of cfloat, we get infs easily (e.g. with pow, exp), so we cast the cfloat result back to chalf.

However, this might mask an actual issue, as we don't control what fraction of the inputs will be valid. The correct approach would be to sample inputs that are valid given the range of chalf.

One approach would be to add extra metadata to OpInfo.
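
For illustration, a rough sketch of what such metadata-driven sampling could look like; the SAFE_MAGNITUDE table and sample_for_dtype helper below are hypothetical, not existing OpInfo fields:

import torch
from torch.testing import make_tensor

# Hypothetical per-dtype bound that a sample_inputs function could consult so that
# fast-growing ops (pow, exp, ...) stay finite for low-precision dtypes like chalf.
SAFE_MAGNITUDE = {
    torch.complex32: 10.0,   # chalf overflows past ~65504
    torch.float16: 10.0,
    torch.complex64: 1e4,
}

def sample_for_dtype(shape, dtype, device):
    bound = SAFE_MAGNITUDE.get(dtype, 1e4)
    # make_tensor accepts low/high to bound the sampled values
    # (assuming it accepts the dtype; otherwise sample in a wider dtype and cast)
    return make_tensor(shape, dtype=dtype, device=device, low=-bound, high=bound)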

cc: @ngimel @mruberry @anjali411

Versions

master

cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved

kshitij12345 added the module: complex and module: half labels on May 24, 2022
ngimel added the triaged label on May 25, 2022
ngimel (Collaborator) commented May 25, 2022

Shouldn't we already have these problems with at::Half?

kshitij12345 (Collaborator, Author):
The reference test against NumPy doesn't run on half:

pytorch/test/test_ops.py, lines 331 to 332 in 141238a:

@ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
def test_numpy_ref(self, device, dtype, op):

and for UnaryUfuncs we use values based on the dtype:
if dtype.is_floating_point:
if dtype is torch.float16:
# float16 has smaller range.
vals = _large_float16_vals
else:
vals = _large_float_vals
elif dtype.is_complex:

However, in this case we are testing all kinds of functions, so that per-dtype value selection doesn't apply.

Also, the value range of tensors generated by make_tensor is conservative, so this is an issue only for selected functions like pow and rsqrt (for values close to zero).
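
A minimal sketch of how that per-dtype selection could be extended to cover chalf as well; the value tuples below are illustrative, not the ones defined in the test suite:

import torch

# Illustrative extremal probe values; the real test suite defines its own tuples.
_large_float_vals = (1e25, -1e25, 1e-30)
_large_float16_vals = (500.0, -500.0, 1e-3)  # stays within float16's max of 65504

def extreme_vals_for(dtype):
    # chalf (complex32) shares float16's range, so it needs the narrower set too
    if dtype in (torch.float16, torch.complex32):
        return _large_float16_vals
    return _large_float_vals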

anjali411 (Contributor):
Does this issue arise because we go back and forth between cfloat and chalf? If yes, then BFloat16 should also have a similar issue, right?

The correct approach would be to sample inputs that are valid given the range of chalf.

It seems like the permissible range should also depend on the type of function, right?

kshitij12345 (Collaborator, Author):
It happens because of the difference in the range of values supported by cfloat and chalf:

>>> torch.finfo(torch.cfloat)
finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)
>>> torch.finfo(torch.chalf)
finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)

>>> t = torch.tensor([0.00005], dtype=torch.chalf, device='cuda')
# inverse square (t ** -2)
>>> torch.pow(t, -2)
tensor([inf-0.j], device='cuda:0', dtype=torch.complex32) # chalf returns inf
>>> torch.pow(t.to(torch.cfloat), -2)
tensor([3.9987e+08-0.j], device='cuda:0') # cfloat returns valid value

Yes, the permissible range depends on the function. I think only a few functions, like pow and rsqrt, will suffer from this. Even exp will work, as the sample inputs are generated by make_tensor, which by default generates tensors with values only between -9 and 9 for complex types.
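
As a sketch of the constrained sampling suggested above (the bounds 0.05 and 9 are illustrative choices; sampling is done in cfloat and cast, to avoid relying on make_tensor supporting chalf, and a CUDA device is used since chalf kernels were mainly available there):

import torch
from torch.testing import make_tensor

# Keep |t| away from zero so t ** -2 stays well below chalf's max of ~65504.
t = make_tensor((4,), dtype=torch.cfloat, device='cuda', low=0.05, high=9.0).to(torch.chalf)
print(torch.pow(t, -2))                                   # finite in chalf
print(torch.pow(t.to(torch.cfloat), -2).to(torch.chalf))  # reference path, cast back for comparison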
