
Add support of variadic length/type argument in result_type #61168

Closed
wants to merge 11 commits

Conversation

asi1024
Contributor

@asi1024 asi1024 commented Jul 2, 2021

Fixes #51284 (cc/ @mruberry, @heitorschueroff, @rgommers, @emcastillo, @kmaehashi)

This PR adds support for variadic length/type arguments in torch.result_type, for compatibility with NumPy’s interface and the Python array API standard.

>>> torch.result_type(torch.int8, torch.tensor([1], dtype=torch.uint8), 10)
torch.int16
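For comparison, NumPy’s n-ary result_type agrees on this example (a quick sketch, assuming NumPy is available; not part of the PR):

```python
import numpy as np

# int8 and uint8 have no common 8-bit type, so they promote to int16;
# the Python int 10 fits in int16 and does not widen the result further.
res = np.result_type(np.int8, np.array([1], dtype=np.uint8), 10)
print(res)  # int16
```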

Reference: #51284 (comment)

TODO:

  • tests

@facebook-github-bot
Contributor

facebook-github-bot commented Jul 2, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 9d53b03 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@rgommers
Collaborator

rgommers commented Jul 4, 2021

Thanks @asi1024!

A Windows vmap test seems unhappy:

RuntimeError: Batching rule not implemented for aten::result_type.TensorList. We could not generate a fallback.

@asi1024
Contributor Author

asi1024 commented Jul 5, 2021

@rgommers I found an unexpected behavior in numpy.result_type (perhaps it is a bug in numpy.result_type?)

>>> numpy.result_type(numpy.int8, True)
dtype('int8')
>>> numpy.result_type(numpy.int8, 10)
dtype('int8')
>>> numpy.result_type(numpy.int8, True, 10)
dtype('int16')

Should torch.result_type return int16 for the (numpy.int8, True, 10) input, or int8?
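For context, the surprise is that the n-ary call is not a left fold of the binary form: folding pairwise never leaves int8 (a sketch, assuming NumPy; the three-argument result quoted above comes from the NumPy release current at the time and may differ under newer promotion rules):

```python
import numpy as np

# Folding result_type pairwise, left to right:
step1 = np.result_type(np.int8, True)  # bool never widens int8 -> int8
step2 = np.result_type(step1, 10)      # 10 fits in int8        -> int8
print(step1, step2)
```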

@ezyang ezyang removed their request for review July 6, 2021 14:06
@mruberry
Collaborator

mruberry commented Jul 7, 2021

@rgommers I found an unexpected behavior in numpy.result_type (perhaps it is a bug in numpy.result_type?)

>>> numpy.result_type(numpy.int8, True)
dtype('int8')
>>> numpy.result_type(numpy.int8, 10)
dtype('int8')
>>> numpy.result_type(numpy.int8, True, 10)
dtype('int16')

Should torch.result_type return int16 for the (numpy.int8, True, 10) input, or int8?

I think we should return int8 in this case. @rgommers?

@iramazanli iramazanli added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jul 8, 2021
@asi1024
Contributor Author

asi1024 commented Jul 19, 2021

@rgommers Could you help me resolve the vmap and mypy CI failures?

@asi1024
Contributor Author

asi1024 commented Jul 20, 2021

@mruberry @rgommers Now all tests have passed! Could you take another look?

inputs = [_convert_to_numpy_input(x) for x in inputs]
return np.result_type(*inputs)

for x1 in inputs:
Collaborator

Style nit:

for x0, x1, x2 in product(inputs, repeat=3):

# if inputs of mixed dtypes are given. The tests for floating type inputs are left to
# `test_result_type` and `test_result_type_tensor_vs_scalar`.
dtypes = [torch.bool, torch.uint8, torch.int16, torch.int64]
inputs = [
Collaborator

Instead of comparing with NumPy what about implementing the reference using the addition operator?

def result_type_ref(args):
  assert len(args) > 0

  t = torch.tensor(True)
  for arg in args:
    if isinstance(arg, torch.dtype):
      # See comment below for a discussion of what type of object dtype should be represented by
      t = t + torch.tensor((1,), dtype=arg)
    else: 
      t = t + arg
  return t.dtype

I think this would allow for a simpler test structure and more thorough testing, plus it would compare with the ground truth of what PyTorch is doing

Contributor Author

I think the reference implementation using + sometimes returns unexpected results.

>>> a = torch.tensor([1], dtype=torch.int8)
>>> b = 1.
>>> c = torch.tensor([1], dtype=torch.float16)
>>> (a + b + c).dtype
torch.float32
>>> (a + c + b).dtype
torch.float16

Do you have some ideas for this issue?

Collaborator

That's really cool!

And it's true: the way our dtypes work, the addition operator isn't commutative. Both JAX's and NumPy's result_type appear to handle this case correctly.

So I think my suggestion was mistaken. Sorry about that @asi1024. I didn't realize that result_type() can't be thought of as a series of binary elementwise operations. You are right and we can't use this as a reference implementation.

Returning to the previous tests could be a reasonable solution. If we want to extend that further we could add some "golden value" tests with manually generated checks.
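The claim that NumPy's result_type is order-independent here can be spot-checked directly (a sketch, assuming NumPy; it mirrors the torch addition example above):

```python
import numpy as np

a = np.array([1], dtype=np.int8)
c = np.array([1], dtype=np.float16)

# Unlike folding binary additions, result_type sees all operands at once,
# so argument order does not change the answer.
left = np.result_type(a, 1.0, c)
right = np.result_type(a, c, 1.0)
print(left, right)  # float16 float16
```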

Contributor Author

After thinking about it further, I have an idea for a reference implementation.

def result_type_ref(args):
  assert len(args) > 0
  tensor = torch.tensor([True])
  tensor_0dim = torch.tensor(True)
  scalar = True

  for arg in args:
    if isinstance(arg, torch.dtype):
      tensor = tensor + torch.tensor([1], dtype=arg)
    elif isinstance(arg, torch.Tensor):
      if arg.ndim == 0:
        tensor_0dim = tensor_0dim + arg
      else:
        tensor = tensor + arg
    elif isinstance(arg, Number):
      scalar = scalar + arg
    else:
      assert False, "unknown type"
  return ((tensor + tensor_0dim) + scalar).dtype

It just uses the same logic as the result_type implementation in this PR. Can I use this implementation in my tests?

Collaborator

I think you're right, that's a clever approach, @asi1024. Let's try using this

Collaborator

@mruberry mruberry left a comment

Overall this looks great to me, @asi1024; I made a couple inline comments for your review

cc @heitorschueroff -- would you take a look at the functional implementation?
cc @bhosmer -- would you or someone else from the composability team like to look at the dispatch extension?

@ezyang ezyang self-requested a review July 22, 2021 14:04
@ezyang
Contributor

ezyang commented Jul 22, 2021

the python arg parser changes look fine

@ezyang ezyang closed this Jul 22, 2021
@ezyang ezyang reopened this Jul 22, 2021
@@ -402,6 +402,39 @@ def _test_spot(a, b, res_dtype):
torch.tensor(1., dtype=torch.complex64, device=device), torch.complex128)
_test_spot(torch.tensor([1, 1], dtype=torch.bool, device=device), 1., torch.get_default_dtype())

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
Contributor

@mruberry Is this skip necessary? I thought we always include NumPy now.

Collaborator

Correct. We require NumPy as a dependency when running the test suite

Contributor

@heitorschueroff heitorschueroff left a comment

Overall this PR looks great! Besides a few nit comments, the main thing we need to discuss is the behavior of scalar tensors. I think we should follow the array api standard and treat scalar tensors (0 dimensions) the same as every other tensor. And treat tensors with higher priority than python scalars. If this is the case, I think the logic in functional.py has to change a bit. @mruberry and @rgommers what do you think?

Thanks for this excellent PR.
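For readers following the thread, the priority ordering under discussion can be sketched without torch at all. The helper below is hypothetical (a simplified dtype table, not the PR's code): dimensioned tensors outrank 0-dim tensors, which outrank Python scalars, and a lower-priority operand participates only when it raises the category (bool < integer < floating).

```python
from functools import reduce

# Simplified stand-ins for torch dtypes. int8/uint8 are omitted because
# their promotion (int8 + uint8 -> int16) does not fit a max-by-rank rule.
CATEGORY = {"bool": 0, "int32": 1, "int64": 1, "float32": 2, "float64": 2}
RANK = {"bool": 0, "int32": 1, "int64": 2, "float32": 1, "float64": 2}

def promote(a, b):
    # Within this restricted set, promotion is just max by (category, rank).
    return a if (CATEGORY[a], RANK[a]) >= (CATEGORY[b], RANK[b]) else b

def result_type_sketch(args):
    # args: (dtype_name, kind) pairs; kind is "tensor", "zerodim", or "scalar".
    buckets = {"tensor": [], "zerodim": [], "scalar": []}
    for dtype, kind in args:
        buckets[kind].append(dtype)
    result = None
    for kind in ("tensor", "zerodim", "scalar"):  # decreasing priority
        if not buckets[kind]:
            continue
        local = reduce(promote, buckets[kind])
        if result is None:
            result = local
        elif CATEGORY[local] > CATEGORY[result]:
            # Lower-priority operands matter only by raising the category.
            result = local
    return result

# A 0-dim int64 does not widen a dimensioned int32 tensor...
print(result_type_sketch([("int32", "tensor"), ("int64", "zerodim")]))  # int32
# ...but a float scalar raises the category.
print(result_type_sketch([("int32", "tensor"), ("float32", "scalar")]))  # float32
```

Python float scalars are modeled here as float32, mirroring how torch falls back to the default dtype when a scalar bumps the category.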

if isinstance(x, int):
return np.int64(x)
return torch.tensor([1], dtype=x).numpy().dtype
inputs = [_convert_to_numpy_input(x) for x in inputs]
Contributor

nit: writing this as a regular for loop instead of defining a method just to be able to use a list comprehension is more readable.

torch/csrc/utils/python_arg_parser.cpp Show resolved Hide resolved
@@ -1589,3 +1590,49 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):

def align_tensors(*tensors):
raise RuntimeError('`align_tensors` not yet implemented.')

def result_type(*arrays_and_dtypes: Union[Tensor, torch.dtype, bool, int, float, complex]) -> torch.dtype:
Contributor

complex can be used to represent any complex, float and int in Python typing (see https://www.python.org/dev/peps/pep-0484/#the-numeric-tower).

Collaborator

More generically there's Number (https://docs.python.org/3/library/numbers.html), but this seems fine

@@ -1589,3 +1590,49 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):

def align_tensors(*tensors):
raise RuntimeError('`align_tensors` not yet implemented.')

def result_type(*arrays_and_dtypes: Union[Tensor, torch.dtype, bool, int, float, complex]) -> torch.dtype:
"""result_type(*arrays_and_dtypes: Union[Tensor, dtype, bool, int, float, complex]) -> dtype
Contributor

I believe the python array API standard mentions leaving out type annotations from signature and only including them in the parameter list (https://data-apis.org/array-api/latest/API_specification/data_type_functions.html?highlight=result_type#result-type-arrays-and-dtypes).

torch/functional.py Show resolved Hide resolved
>>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1))
torch.uint8
>>> torch.result_type(torch.int32, torch.float32)
torch.float32
Contributor

an example including tensor, scalar and dtype?

Comment on lines 1614 to 1616
tensors = []
scalars: List[Union[bool, int, float, complex]] = []
dtypes = []
Contributor

nit: type annotation for the other lists

scalars.append(x)
elif isinstance(x, torch.dtype):
dtypes.append(x)
else:
Contributor

change to elif isinstance(x, torch.Tensor) and add an else clause and raise an Error

if dtypes:
return _VF._result_type_dtypes(dtypes)
else:
raise TypeError("at least one argument is required.")
Contributor

change message to result_type(): must provide at least one argument

torch/functional.py Show resolved Hide resolved
@heitorschueroff
Contributor

On second thought, if we are giving higher priority to tensors over Python scalars such that, if there is at least one tensor, Python scalars do not affect the result, then is there a use case for wanting to know the result_type between Python scalars, such as torch.result_type(1, 2.0)? The parameter name itself says arrays_and_dtypes. I think that unless there has been a request for it, we could drop the support for Python scalars, which would also simplify the logic. @mruberry

@mruberry
Collaborator

On second thought, if we are giving higher priority to tensors over Python scalars such that, if there is at least one tensor, Python scalars do not affect the result.

Scalars do affect the choice of computation dtype because the scalar might have a higher type kind than the tensor. For example, when adding a float scalar to an int tensor the result is a float tensor.

Then is there a use case for wanting to know the result_type between Python scalars, such as torch.result_type(1, 2.0)? The parameter name itself says arrays_and_dtypes. I think that unless there has been a request for it, we could drop the support for Python scalars, which would also simplify the logic. @mruberry

Most PyTorch operations don't support exclusively scalar arguments but some do and it's a nice feature to support.
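Both points can be checked directly (a sketch, assuming a recent PyTorch with the default dtype left at float32):

```python
import torch

# A Python float scalar raises the category of an integer tensor, so the
# result is the default floating dtype, not the tensor's int32.
dt = torch.result_type(torch.tensor([1], dtype=torch.int32), 2.0)
print(dt)  # torch.float32

# result_type also accepts exclusively scalar arguments via its
# Scalar/Scalar overload; the answer is again the default float dtype.
dt2 = torch.result_type(1, 2.0)
print(dt2)  # torch.float32
```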

@mruberry
Collaborator

Overall this PR looks great! Besides a few nit comments, the main thing we need to discuss is the behavior of scalar tensors. I think we should follow the array api standard and treat scalar tensors (0 dimensions) the same as every other tensor. And treat tensors with higher priority than python scalars. If this is the case, I think the logic in functional.py has to change a bit. @mruberry and @rgommers what do you think?

Thanks for this excellent PR.

We should do this but we don't today, so this PR correctly implements PyTorch's current behavior.

Collaborator

@mruberry mruberry left a comment

Hey @asi1024! Thank you for your thoughtful responses. There are a few minor comments still inline awaiting your review, but per your analysis this looks good overall.

Would you make a last pass and then ping me when you're happy to merge this?

@asi1024
Contributor Author

asi1024 commented Aug 3, 2021

@mruberry
I found a bug in my result_type implementation and fixed it in 74bfae6.
The CI has passed except for one timeout. Could you take another look? 😃

tensors.append(x)
else:
raise TypeError(f"result_type(): cannot interpret '{x}' as a data type")
if dtypes:
Collaborator

Check the length of these lists explicitly instead of just using if on them

10, # int scalar
10.0, # float scalar
10j, # complex scalar
*[torch.tensor([1, 2, 3], dtype=dtype) for dtype in dtypes], # tensors
Collaborator

This test is always running on the CPU because it's ignoring the device arg.

Does the current implementation work with CUDA tensors? What if both CUDA and CPU tensors are passed?

Contributor Author

I added tests for CUDA/mixed device tensors and confirmed that the current implementation passes the tests!

raise TypeError(f"result_type(): cannot interpret '{x}' as a data type")
if dtypes:
dtype = _VF._result_type_dtypes(dtypes)
tensors.append(torch.tensor([], dtype=dtype))
Collaborator

Modeling this as creating a tensor seems a little odd and I'm guessing this won't work if calling result_type() with CUDA tensors (currently CUDA inputs are not tested, see above comment).

scalar_dtype = _VF._result_type_scalars(scalars) # type: ignore[arg-type]
if tensors:
tensor_dtype = _VF.result_type(tensors) # type: ignore[attr-defined]
return _VF.result_type(_VF.tensor([], dtype=tensor_dtype), # type: ignore[attr-defined]
Collaborator

At this point it might be preferable to perform an addition like in the reference implementation

torch.float32
"""

tensors: List[torch.Tensor] = []
Collaborator

Possible alternative implementation of this, inspired by your proposed reference implementation above:

assert len(arrays_and_dtypes) > 0

tensors = [t for t in arrays_and_dtypes if isinstance(t, torch.Tensor)]
scalars = [s for s in arrays_and_dtypes if isinstance(s, Number)]
dtypes = [d for d in arrays_and_dtypes if isinstance(d, torch.dtype)]

tensor_dtype = None if len(tensors) == 0 else _VF._result_type(tensors)
scalar_dtype = None if len(scalars) == 0 else _VF._result_type_scalars(scalars)
dtype_dtype = None if len(dtypes) == 0 else _VF._result_type_dtypes(dtypes)

return (torch.tensor([], dtype=tensor_dtype) + torch.tensor([], dtype=dtype_dtype) + torch.tensor(0, dtype=scalar_dtype)).dtype

Follow-up question: will this or an implementation like the current version work with the jit scripting and tracing? Is there a test for that?

If the jit is still an issue let me know ASAP and I can help modify the PR to fix that.

Contributor Author

I have confirmed that the jit scripting/tracing test works fine locally, but I have not added tests yet. In which file should I add the tests?

Collaborator

test_jit.py

Collaborator

@mruberry mruberry left a comment

Sorry again for the wait, @asi1024. I made some inline comments. Your suggestion for a new reference implementation looks very good. Let me know if you have any issues, especially issues with scripting or tracing. If you do we can help modify the PR. Your logic is the important part.

@asi1024
Contributor Author

asi1024 commented Sep 17, 2021

@mruberry Sorry, I totally missed your review comments for so long 🙇
I confirmed that the current implementation also works with CUDA device tensors, but is it still preferable to rewrite it like the reference implementation?

@codecov

codecov bot commented Sep 17, 2021

Codecov Report

Merging #61168 (9d53b03) into master (f69cf3c) will decrease coverage by 0.10%.
The diff coverage is 58.82%.

@@            Coverage Diff             @@
##           master   #61168      +/-   ##
==========================================
- Coverage   66.38%   66.28%   -0.11%     
==========================================
  Files         727      734       +7     
  Lines       93573    94079     +506     
==========================================
+ Hits        62117    62357     +240     
- Misses      31456    31722     +266     

@mruberry
Collaborator

@mruberry Sorry, I totally missed your review comments for so long 🙇 I confirmed that the current implementation also works with CUDA device tensors, but is it still preferable to rewrite it like the reference implementation?

And I missed your update! Sorry @asi1024! I'll take another look at the PR now. I don't know the answer to your question offhand.

Collaborator

@mruberry mruberry left a comment

Had a chance to read through and this looks pretty good, @asi1024! One of our new colleagues, @saketh-are, just started investigating PyTorch's type promotion and I'd like him to take a look, too.

@saketh-are
Contributor

@asi1024 I had a chance to look through this and it looks good to me.

Just FYI, I am working on a type promotion change after which 0-dim tensor operands and dimensioned tensor operands will be treated identically. I don't think this PR needs any changes, though, assuming it's merged in first.

@facebook-github-bot
Contributor

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry
Collaborator

Unfortunately it looks like this is hitting an internal merge failure that will need further review. Sorry for the delay, @asi1024.

@github-actions

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

Labels
cla signed, module: python array api (Issues related to the Python Array API), open source, Stale, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

result_type doesn't take dtypes and doesn't match numpy
9 participants