
[WIP] Adding dtype argument to the Unary Ops for dtype promotion (testing on expm1 function) #33063

Closed
wants to merge 10 commits

Conversation

krshrimali (Contributor)

This PR is currently WIP and is intended for discussion on adding a dtype argument to the unary ops for type-promotion support. This PR:

  1. Adds an expm1 overload with a dtype argument.
  2. Modifies TensorIterator::unary_op to take a promoting argument (defaulting to false, i.e. CommonDTypeStrategy::CHECK); if true, the strategy is changed to PROMOTE.
  3. Modifies the existing helper functions for unary ops (unary_op_impl and unary_op_impl_out).

The objective is to give the unary ops a dtype argument so the internal promotion logic can promote the dtype if required (and allowed). The open question is whether these changes have any unexpected (and unwanted) effects on the current flow of PyTorch.
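For context, a minimal sketch (editor's illustration, not part of the PR) of the promotion rules in question, using the public torch.promote_types and torch.can_cast helpers, which mirror the internal promoteTypes/canCast logic:

import torch

# The promotion lattice picks the wider floating-point type:
print(torch.promote_types(torch.float16, torch.float32))  # torch.float32

# A float -> int cast is refused by the promotion rules; this is
# the same rule the canCast check discussed below enforces:
print(torch.can_cast(torch.float16, torch.int32))  # False
print(torch.can_cast(torch.int32, torch.float16))  # True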

cc: @mcarilli @nairbv

@krshrimali changed the title Feb 6, 2020
dr-ci bot commented Feb 6, 2020

💊 CircleCI build failures summary and remediations

As of commit 7aac2ac:

  • 1/3 broken upstream at merge base b0476dc since Feb 06

Please rebase on the viable/strict branch:

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch origin viable/strict
    git rebase --onto viable/strict $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch origin viable/strict
    git rebase viable/strict
    

    Check out the recency history of this "viable master" tracking branch.

  • 2/3 failures introduced in this PR

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 2 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build caffe2_onnx_py2_gcc5_ubuntu16_04_test (1/2)

Step: "Test"

Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_squeezenet FAILED [ 96%]
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_rsqrt PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_rsub PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_scalar_type PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_scatter PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_select PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_size PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_softmax PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_softmax_dtype PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_sqrt PASSED [ 96%] 
Feb 11 20:13:37 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_squeeze PASSED [ 96%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_squeezenet FAILED [ 96%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_srresnet SKIPPED [ 96%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_std PASSED [ 96%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_std_along_dims PASSED [ 96%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_subconstant PASSED [ 97%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_sum PASSED [ 97%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_super_resolution SKIPPED [ 97%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_tensor_factories PASSED [ 97%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_tensor_factories_script PASSED [ 97%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_tensor_index_1d PASSED [ 97%] 
Feb 11 20:13:38 test/onnx/test_pytorch_onnx_caffe2.py::TestCaffe2Backend_opset9::test_tensor_index_2d_1dconstant PASSED [ 97%] 

See CircleCI build pytorch_linux_xenial_py2_7_9_test (2/2)

Step: "Test"

Feb 11 20:25:17 RuntimeError: test_nn failed!
Feb 11 20:25:17   File "/opt/python/2.7.9/lib/python2.7/site-packages/xmlrunner/runner.py", line 7, in <module> 
Feb 11 20:25:17     from .result import _XMLTestResult 
Feb 11 20:25:17   File "/opt/python/2.7.9/lib/python2.7/site-packages/xmlrunner/result.py", line 42, in <module> 
Feb 11 20:25:17     for (low, high) in _illegal_unichrs 
Feb 11 20:25:17 ValueError: chr() arg not in range(256) 
Feb 11 20:25:17 Traceback (most recent call last): 
Feb 11 20:25:17   File "test/run_test.py", line 487, in <module> 
Feb 11 20:25:17     main() 
Feb 11 20:25:17   File "test/run_test.py", line 480, in main 
Feb 11 20:25:17     raise RuntimeError(message) 
Feb 11 20:25:17 RuntimeError: test_nn failed! 
Feb 11 20:25:17 =================== sccache compilation log =================== 
Feb 11 20:25:17 + cleanup 
Feb 11 20:25:17 + retcode=1 
Feb 11 20:25:17 + set +x 
Feb 11 20:25:17 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 11 20:25:17 Compile requests                 10 
Feb 11 20:25:17 Compile requests executed         8 
Feb 11 20:25:17 Cache hits                        2 
Feb 11 20:25:17 Cache misses                      6 
Feb 11 20:25:17 Cache timeouts                    0 

🚧 1 upstream failure recognized by patterns:

These builds matched patterns, but were probably caused by upstream breakages:



@@ -1210,6 +1210,10 @@
supports_named_tensor: True
variants: function, method

- func: expm1.dtype(Tensor self, ScalarType dtype) -> Tensor
Collaborator

does this work as an additional optional argument on the existing signature? e.g. func: expm1(Tensor self, *, ScalarType? dtype=None)

@krshrimali (Contributor Author) Feb 7, 2020

Unfortunately, no! I have been trying this for a while now, but it always hits a roadblock; eventually we would have to change the function's name (rather than overloading it) and maybe even add more helper functions. This seems messy to me at first glance.

Collaborator

As written, this allows calls like expm1(self, dtype) (positionally), which wouldn't be allowed by the equivalent numpy API.

Contributor Author

This can be made a keyword-only argument (by putting the argument after * in native_functions.yaml) without needing an optional argument on the existing signature. I'll make a commit with the keyword-only approach and ping once done.

I'm still trying the optional-argument route, but it's going to change a lot of signatures in the yaml file. I'll post the exact errors I get once I'm done with the combinations of modifications I have in mind.

@krshrimali (Contributor Author)

Some additional comments on using iter.promote_common_dtype() and its impact on the current flow:

  1. The use of iter.promote_common_dtype() is intended to let the internal dtype-promotion logic take place for the supported dtypes.
  2. Requests to promote to unsupported dtypes still throw an error such as RuntimeError: result type Half can't be cast to the desired output type Int (taking the example of promoting an input tensor of dtype torch.float16 to torch.int in our torch.expm1 function). This comes from TORCH_CHECK(canCast(common_dtype, op.current_dtype), "result type ", common_dtype, " can't be cast to the desired output type ", op.current_dtype), which runs when the strategy is set to either CHECK or PROMOTE.

So far as I can tell, this approach doesn't break the current flow. That said, I may be wrong and am open to any challenges to this approach.
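A small repro of point 2 (a sketch assuming a build of this PR's branch; the expm1 dtype overload never landed upstream):

import torch

x = torch.tensor(2.3, dtype=torch.float16)

# Mirrors the canCast check: Half -> Int is a float-to-integer
# narrowing cast, so promotion is refused:
assert not torch.can_cast(torch.float16, torch.int32)

# On this PR's branch, the following call raised:
# RuntimeError: result type Half can't be cast to the desired output type Int
# torch.expm1(x, dtype=torch.int32)  # hypothetical overload from this PR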

@@ -1210,7 +1210,7 @@
supports_named_tensor: True
variants: function, method

- func: expm1.dtype(Tensor self, ScalarType dtype) -> Tensor
- func: expm1.dtype(Tensor self, *, ScalarType dtype) -> Tensor
Contributor Author

@nairbv - This enables the keyword-only dtype argument for expm1 and tests OK. Would like to hear your comments on this.

  1. This enables dtype handling, but as mentioned before, it does not change any promotion logic for the ops beyond allowing dtype promotion for supported ops with supported dtypes.

For example:

x = torch.tensor(2.3, dtype=torch.float16, device="cuda")
output = torch.expm1(x, dtype=torch.float32)
print(output, output.dtype)

Outputs: tensor(8.9820, device="cuda:0"), torch.float32 (expected)

While:

  • torch.expm1(x, torch.float32) # positional dtype, not keyword-only; gives an error (expected):
TypeError: expm1() received an invalid combination of arguments - got (Tensor, torch.dtype), but expected one of:
 * (Tensor input, Tensor out)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Tensor input, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
  • Trying to promote fp16 to int dtype also throws an error (expected): RuntimeError: result type Half can't be cast to the desired output type Int.
  2. The out=... handling for dtype promotion is absent for now. This can be one of the TODOs, depending on your comments and user requirements.

For example, torch.expm1(x, out=torch.zeros((), dtype=torch.float32)) throws RuntimeError: expected dtype Float but got dtype Half. This is because the function called when out=... is passed does not go through our call chain with the dtype argument.

I'd appreciate your review and comments on this!

Collaborator

This enables the keyword-only dtype argument for expm1 and tests OK. Would like to hear your comments on this.

Yes, I think the , *, syntax is the correct function definition, though I think we'll want ScalarType? dtype=None so that it's optional. This matches torch.mean and prod, as well as numpy.expm1.

This enables dtype handling, but as mentioned before, it does not change any promotion logic for the ops beyond allowing dtype promotion for supported ops with supported dtypes.

I think we'll expect the behavior to conform to how these options are used in other functions and/or in numpy.

In the docs for ops like torch.prod/sum/mean:
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
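For reference, those semantics on the stock torch.sum (standard API, independent of this PR): dtype casts the input before the reduction runs, and the result carries that dtype:

import torch

x = torch.tensor([1.5, 2.5], dtype=torch.float16)

# The input is cast to dtype before the op is performed:
s = torch.sum(x, dtype=torch.float32)
print(s, s.dtype)  # tensor(4.) torch.float32

# Equivalent to casting up front:
print(torch.sum(x.to(torch.float32)).dtype)  # torch.float32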

TypeError: expm1() received an invalid combination of arguments - got (Tensor, torch.dtype)

This looks correct. dtype as a positional arg shouldn't be supported here.

doing: torch.expm1(x, out=torch.zeros((), dtype=torch.float32)) throws a RuntimeError

I think we'll need to support this. We'll have to ensure we've set appropriate flags to make the correct checks in validate_dtype etc in TensorIterator.

Contributor Author

@nairbv - Thanks for the comments. Binary ops use the PROMOTE flag by default; is there any reason why unary ops should not use it by default as well? As far as we can tell, the current solution works fine for tensors on both CPU and CUDA devices.

Collaborator

Binary ops iterate through a list containing their two operands to determine a common output type (using type promotion rules implemented in at::result_type).

You're modifying a unary op (only one operand, no list), so it seems unclear what behavior you're trying to control through use of this flag.
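Concretely (standard public API, not specific to this PR), at::result_type is exposed as torch.result_type, and the binary-op behavior being described looks like:

import torch

a = torch.tensor([1, 2], dtype=torch.int32)
b = torch.tensor([1.0, 1.0], dtype=torch.float16)

# Binary ops compute a common dtype across both operands:
print(torch.result_type(a, b))  # torch.float16
print((a + b).dtype)            # torch.float16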

@@ -95,7 +95,11 @@ Tensor& ceil_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); }
Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); }

Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, expm1_stub); }
Tensor& expm1_out(Tensor& result, const Tensor& self) {
if (result.defined()) return unary_op_impl_out(result, self, expm1_stub, /*strategy_promote=*/ true);
Contributor Author

This enables dtype promotion support for out=... arguments to the expm1 function. Sample outputs:

input: torch.expm1(x /*dtype=torch.float16*/, out=torch.zeros((), dtype=torch.float32, device="cuda")).dtype
output: float32

input: torch.expm1(x /*dtype=torch.float16*/, out=torch.zeros((), dtype=torch.double, device="cuda")).dtype
output: float64

Collaborator

result.defined() is always going to be true at this point, I think.

auto iter = TensorIterator();
iter.set_check_mem_overlap(check_mem_overlap);
iter.add_output(out);
iter.add_input(a);
iter.num_outputs_ = 1;
if(promoting == true) iter.promote_common_dtype();
Collaborator

promote_common_dtype implies we're finding the common dtype across all operands and setting the output dtype based on that. These are unary ops, so this doesn't look right.

Also, a nit: there's no need to compare == true when using a boolean in a condition.

@krshrimali (Contributor Author)

Thank you for your help on this PR, @mcarilli @nairbv - I really appreciate the time, reviews, and patience. Closing this PR; apologies that it couldn't make it upstream.

@krshrimali closed this Jun 14, 2021