Add support of variadic length/type argument in result_type
#61168
Conversation
💊 CI failures summary (Dr. CI): as of commit 9d53b03, no failures so far. 💚
Thanks @asi1024! A Windows
@rgommers I found unexpected behavior in numpy.result_type:

    >>> numpy.result_type(numpy.int8, True)
    dtype('int8')
    >>> numpy.result_type(numpy.int8, 10)
    dtype('int8')
    >>> numpy.result_type(numpy.int8, True, 10)
    dtype('int16')

I think we should return int8 in this case. @rgommers?
@rgommers Could you help me resolve the vmap and mypy CI failures?
test/test_type_promotion.py (Outdated)

    inputs = [_convert_to_numpy_input(x) for x in inputs]
    return np.result_type(*inputs)

    for x1 in inputs:
Style nit:

    for x0, x1, x2 in product(inputs, repeat=3):
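For context, a quick sketch of the difference (the dtype-name strings are illustrative stand-ins): `product(inputs, repeat=3)` enumerates every ordered triple, whereas `product(inputs * 3)` iterates a single tripled list and yields 1-tuples.

```python
from itertools import product

# Illustrative inputs; in the real test these would be dtypes/tensors/scalars.
inputs = ["bool", "int8", "float16"]

# Correct form: all ordered triples drawn from `inputs`.
triples = list(product(inputs, repeat=3))
assert len(triples) == len(inputs) ** 3  # 27 combinations

# Single-iterable form: a 9-element iterable, so product() yields 1-tuples.
singles = list(product(inputs * 3))
assert len(singles) == 9 and singles[0] == ("bool",)
```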
    # if inputs of mixed dtypes are given. The tests for floating type inputs are left to
    # `test_result_type` and `test_result_type_tensor_vs_scalar`.
    dtypes = [torch.bool, torch.uint8, torch.int16, torch.int64]
    inputs = [
Instead of comparing with NumPy, what about implementing the reference using the addition operator?

    def result_type_ref(args):
        assert len(args) > 0
        t = torch.tensor(True)
        for arg in args:
            if isinstance(arg, torch.dtype):
                # See comment below for a discussion of what type of object a dtype should be represented by
                t = t + torch.tensor((1,), dtype=arg)
            else:
                t = t + arg
        return t.dtype

I think this would allow for a simpler test structure and more thorough testing, plus it would compare against the ground truth of what PyTorch is doing.
I think the reference implementation using + sometimes returns unexpected results.
>>> a = torch.tensor([1], dtype=torch.int8)
>>> b = 1.
>>> c = torch.tensor([1], dtype=torch.float16)
>>> (a + b + c).dtype
torch.float32
>>> (a + c + b).dtype
torch.float16
Do you have some ideas for this issue?
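The asymmetry above can be reproduced without PyTorch by modeling just the two rules in play: a Python float scalar promotes an integer tensor to the default float dtype (float32) but never widens an existing floating tensor, while tensor-tensor addition takes the wider floating dtype. A minimal pure-Python sketch (the dtype strings and rules are a simplified assumption, not PyTorch's full promotion lattice):

```python
FLOAT_ORDER = ["float16", "float32", "float64"]

def add_tensor(lhs: str, rhs: str) -> str:
    """Tensor + tensor: a floating dtype beats an integer dtype; between
    two floating dtypes, the wider one wins."""
    floats = [d for d in (lhs, rhs) if d.startswith("float")]
    if len(floats) == 2:
        return max(floats, key=FLOAT_ORDER.index)
    return floats[0] if floats else lhs  # int + int not needed for this demo

def add_scalar(lhs: str) -> str:
    """Tensor + Python float scalar: an integer tensor is promoted to the
    default float dtype (float32); a floating tensor keeps its dtype."""
    return lhs if lhs.startswith("float") else "float32"

# (int8 + 1.) + float16: the scalar bumps int8 to float32, which then wins.
left = add_tensor(add_scalar("int8"), "float16")
# (int8 + float16) + 1.: float16 absorbs int8, and the scalar can't widen it.
right = add_scalar(add_tensor("int8", "float16"))
print(left, right)  # float32 float16
```

So folding the arguments left-to-right with `+` bakes the argument order into the result, which is why a reference built this way can disagree with an order-independent result_type.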
That's really cool!
And it's true: given the way our dtypes work, the addition operator isn't commutative. Both JAX's and NumPy's result_type appear to handle this case correctly.
So I think my suggestion was mistaken. Sorry about that @asi1024. I didn't realize that result_type() can't be thought of as a series of binary elementwise operations. You are right and we can't use this as a reference implementation.
Returning to the previous tests could be a reasonable solution. If we want to extend that further we could add some "golden value" tests with manually generated checks.
After thinking it over further, I have an idea for a reference implementation.

    def result_type_ref(args):
        assert len(args) > 0
        tensor = torch.tensor([True])
        tensor_0dim = torch.tensor(True)
        scalar = True
        for arg in args:
            if isinstance(arg, torch.dtype):
                tensor = tensor + torch.tensor([1], dtype=arg)
            elif isinstance(arg, torch.Tensor):
                if arg.ndim == 0:
                    tensor_0dim = tensor_0dim + arg
                else:
                    tensor = tensor + arg
            elif isinstance(arg, Number):
                scalar = scalar + arg
            else:
                assert False, "unknown type"
        return ((tensor + tensor_0dim) + scalar).dtype
It just uses the same logic as the result_type implementation in this PR. Can I use this implementation in my tests?
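The three-accumulator trick can also be sketched without PyTorch. In the sketch below, each dtype is a hypothetical `(kind, width)` pair; dimensioned tensors outrank 0-dim tensors, which outrank Python scalars, and a lower-precedence bucket only contributes when its kind (bool < int < float < complex) is strictly higher. This is a simplified stand-in for the real promotion rules, not PyTorch's implementation:

```python
KINDS = ["bool", "int", "float", "complex"]

def promote(a, b):
    """Promote two (kind, width) dtype stand-ins: higher kind wins;
    same kind -> wider width wins."""
    if a[0] != b[0]:
        return max(a, b, key=lambda d: KINDS.index(d[0]))
    return a if a[1] >= b[1] else b

def result_type_sketch(args):
    """Fold (category, dtype) args into three buckets, promote within each,
    then combine in precedence order: tensor > 0-dim tensor > scalar."""
    buckets = {"tensor": None, "zerodim": None, "scalar": None}
    for cat, dt in args:
        buckets[cat] = dt if buckets[cat] is None else promote(buckets[cat], dt)
    result = None
    for cat in ("tensor", "zerodim", "scalar"):  # highest precedence first
        dt = buckets[cat]
        if dt is None:
            continue
        if result is None:
            result = dt
        elif KINDS.index(dt[0]) > KINDS.index(result[0]):
            # A lower-precedence operand only matters if its kind is higher.
            result = promote(result, dt)
    return result

# A 0-dim int64 tensor cannot widen a dimensioned int16 tensor...
print(result_type_sketch([("tensor", ("int", 16)), ("zerodim", ("int", 64))]))  # ('int', 16)
# ...but a float scalar raises the kind of an int16 tensor.
print(result_type_sketch([("tensor", ("int", 16)), ("scalar", ("float", 32))]))  # ('float', 32)
```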
I think you're right, that's a clever approach, @asi1024. Let's try using this.
Overall this looks great to me, @asi1024; I made a couple of inline comments for your review.
cc @heitorschueroff -- would you take a look at the functional implementation?
cc @bhosmer -- would you or someone else from the composability team like to look at the dispatch extension?
The Python arg parser changes look fine.
@@ -402,6 +402,39 @@ def _test_spot(a, b, res_dtype):

        torch.tensor(1., dtype=torch.complex64, device=device), torch.complex128)
    _test_spot(torch.tensor([1, 1], dtype=torch.bool, device=device), 1., torch.get_default_dtype())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@mruberry Is this skip necessary? I thought we always include NumPy now.
Correct. We require NumPy as a dependency when running the test suite.
Overall this PR looks great! Besides a few nit comments, the main thing we need to discuss is the behavior of scalar tensors. I think we should follow the array API standard and treat scalar (0-dimensional) tensors the same as every other tensor, and treat tensors with higher priority than Python scalars. If that is the case, I think the logic in functional.py has to change a bit. @mruberry and @rgommers, what do you think?
Thanks for this excellent PR.
test/test_type_promotion.py (Outdated)

        if isinstance(x, int):
            return np.int64(x)
        return torch.tensor([1], dtype=x).numpy().dtype
    inputs = [_convert_to_numpy_input(x) for x in inputs]
nit: writing this as a regular for loop is more readable than defining a method just to be able to use a list comprehension.
torch/functional.py (Outdated)

@@ -1589,3 +1590,49 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):

    def align_tensors(*tensors):
        raise RuntimeError('`align_tensors` not yet implemented.')

    def result_type(*arrays_and_dtypes: Union[Tensor, torch.dtype, bool, int, float, complex]) -> torch.dtype:
`complex` can be used to represent any `complex`, `float`, and `int` in Python typing (see https://www.python.org/dev/peps/pep-0484/#the-numeric-tower).
More generically there's Number (https://docs.python.org/3/library/numbers.html), but this seems fine
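As a quick standalone illustration of the numeric-tower point (a hypothetical sketch, not code from this PR): under PEP 484, an argument annotated `complex` also accepts `int` and `float` values, and `numbers.Number` covers all of them at runtime.

```python
from numbers import Number

def magnitude(z: complex) -> float:
    # PEP 484's numeric tower: bool/int/float values satisfy a `complex` annotation.
    return abs(z)

values = [True, 3, 2.5, 1 + 1j]
assert all(isinstance(v, Number) for v in values)  # Number spans the whole tower
print(magnitude(3), magnitude(2.5), magnitude(3 + 4j))  # 3 2.5 5.0
```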
torch/functional.py (Outdated)

@@ -1589,3 +1590,49 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):

    def align_tensors(*tensors):
        raise RuntimeError('`align_tensors` not yet implemented.')

    def result_type(*arrays_and_dtypes: Union[Tensor, torch.dtype, bool, int, float, complex]) -> torch.dtype:
        """result_type(*arrays_and_dtypes: Union[Tensor, dtype, bool, int, float, complex]) -> dtype
I believe the Python array API standard recommends leaving type annotations out of the signature and only including them in the parameter list (https://data-apis.org/array-api/latest/API_specification/data_type_functions.html?highlight=result_type#result-type-arrays-and-dtypes).
    >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1))
    torch.uint8
    >>> torch.result_type(torch.int32, torch.float32)
    torch.float32
How about an example including a tensor, a scalar, and a dtype?
torch/functional.py (Outdated)

    tensors = []
    scalars: List[Union[bool, int, float, complex]] = []
    dtypes = []
nit: add type annotations for the other lists as well
torch/functional.py (Outdated)

            scalars.append(x)
        elif isinstance(x, torch.dtype):
            dtypes.append(x)
        else:
Change this to `elif isinstance(x, torch.Tensor)`, then add an `else` clause that raises an error.
torch/functional.py (Outdated)

    if dtypes:
        return _VF._result_type_dtypes(dtypes)
    else:
        raise TypeError("at least one argument is required.")
Change the message to `result_type(): must provide at least one argument`.
On second thought: if we are giving higher priority to tensors over Python scalars, such that when there is at least one tensor the Python scalars do not affect the result, then is there a use case for wanting to know the result_type between Python scalars alone? Such as
Scalars do affect the choice of computation dtype, because a scalar might have a higher type kind than the tensor. For example, when adding a float scalar to an int tensor, the result is a float tensor.

Most PyTorch operations don't support exclusively scalar arguments, but some do, and it's a nice feature to support.
We should do this, but we don't today, so this PR correctly implements PyTorch's current behavior.
Hey @asi1024! Thank you for your thoughtful responses. There are a few minor comments still inline awaiting your review, but per your analysis this looks good overall.
Would you make a last pass and then ping me when you're happy to merge this?
@mruberry
torch/functional.py (Outdated)

            tensors.append(x)
        else:
            raise TypeError(f"result_type(): cannot interpret '{x}' as a data type")
    if dtypes:
Check the lengths of these lists explicitly instead of just using `if` on them.
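One reason the explicit length check is safer in general (a pure-Python illustration with a hypothetical class, not PR code): truthiness delegates to `__bool__`, which some sequence-like objects define ambiguously, while `len()` only requires `__len__`.

```python
class AmbiguousSeq:
    """Hypothetical container whose truth value is ambiguous, similar in
    spirit to how multi-element tensors refuse bool() in array libraries."""
    def __init__(self, items):
        self.items = list(items)
    def __len__(self):
        return len(self.items)
    def __bool__(self):
        raise TypeError("truth value is ambiguous; use len()")

seq = AmbiguousSeq([1, 2, 3])
assert len(seq) > 0  # the explicit check always works
try:
    checked = bool(seq)  # the implicit truthiness check blows up here
except TypeError as exc:
    checked = str(exc)
print(checked)  # truth value is ambiguous; use len()
```

For plain Python lists both spellings behave the same, so this is a style and robustness preference.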
test/test_type_promotion.py (Outdated)

    10,    # int scalar
    10.0,  # float scalar
    10j,   # complex scalar
    *[torch.tensor([1, 2, 3], dtype=dtype) for dtype in dtypes],  # tensors
This test always runs on the CPU because it ignores the device arg.
Does the current implementation work with CUDA tensors? What if both CUDA and CPU tensors are passed?
I added tests for CUDA/mixed device tensors and confirmed that the current implementation passes the tests!
            raise TypeError(f"result_type(): cannot interpret '{x}' as a data type")
    if dtypes:
        dtype = _VF._result_type_dtypes(dtypes)
        tensors.append(torch.tensor([], dtype=dtype))
Modeling this as creating a tensor seems a little odd, and I'm guessing it won't work when result_type() is called with CUDA tensors (currently CUDA inputs are not tested; see the comment above).
    scalar_dtype = _VF._result_type_scalars(scalars)  # type: ignore[arg-type]
    if tensors:
        tensor_dtype = _VF.result_type(tensors)  # type: ignore[attr-defined]
        return _VF.result_type(_VF.tensor([], dtype=tensor_dtype),  # type: ignore[attr-defined]
At this point it might be preferable to perform an addition like in the reference implementation
        torch.float32
    """

    tensors: List[torch.Tensor] = []
Possible alternative implementation of this, inspired by your proposed reference implementation above:

    assert len(arrays_and_dtypes) > 0
    tensors = [t for t in arrays_and_dtypes if isinstance(t, torch.Tensor)]
    scalars = [s for s in arrays_and_dtypes if isinstance(s, Number)]
    dtypes = [d for d in arrays_and_dtypes if isinstance(d, torch.dtype)]
    tensor_dtype = None if len(tensors) == 0 else _VF._result_type(tensors)
    scalar_dtype = None if len(scalars) == 0 else _VF._result_type_scalars(scalars)
    dtype_dtype = None if len(dtypes) == 0 else _VF._result_type_dtypes(dtypes)
    return (torch.tensor([], dtype=tensor_dtype)
            + torch.tensor([], dtype=dtype_dtype)
            + torch.tensor(0, dtype=scalar_dtype)).dtype

Follow-up question: will this, or an implementation like the current version, work with jit scripting and tracing? Is there a test for that?
If the jit is still an issue, let me know ASAP and I can help modify the PR to fix it.
I have confirmed that the jit scripting/tracing test works fine locally, but I have not added tests yet. In which file should I add the tests?
test_jit.py
Sorry again for the wait, @asi1024. I made some inline comments. Your suggestion for a new reference implementation looks very good. Let me know if you have any issues, especially with scripting or tracing; if you do, we can help modify the PR. Your logic is the important part.
b36a878 to 9d53b03
@mruberry Sorry, I totally missed your review comments for so many days 🙇
Codecov Report
@@ Coverage Diff @@
## master #61168 +/- ##
==========================================
- Coverage 66.38% 66.28% -0.11%
==========================================
Files 727 734 +7
Lines 93573 94079 +506
==========================================
+ Hits 62117 62357 +240
- Misses 31456 31722 +266
And I missed your update! Sorry @asi1024! I'll take another look at the PR now. I don't know the answer to your question offhand.
Had a chance to read through and this looks pretty good, @asi1024! One of our new colleagues, @saketh-are, just started investigating PyTorch's type promotion and I'd like him to take a look, too.
@asi1024 I had a chance to look through this and it looks good to me. Just FYI, I am working on a type promotion change after which 0-dim tensor operands and dimensioned tensor operands will be treated identically. I don't think this PR needs to change anywhere, though, assuming it's merged in first.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Unfortunately it looks like this is hitting an internal merge failure that will need further review. Sorry for the delay, @asi1024.

Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as
Fixes #51284 (cc/ @mruberry, @heitorschueroff, @rgommers, @emcastillo, @kmaehashi)
This PR adds support for variadic-length/type arguments in torch.result_type, for compatibility with NumPy's interface and the Python array API standard. Reference: #51284 (comment)
TODO: