
Complex dispatch should be disabled on min/max functions #50064

Closed
ngimel opened this issue Jan 4, 2021 · 5 comments
Labels
good first issue · module: complex (Related to complex number support in PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ngimel
Collaborator

ngimel commented Jan 4, 2021

As of #36377, min/max functions are disabled for complex inputs (via dtype checks); however, the min/max kernels are still compiled and dispatched for complex, see e.g.

AT_DISPATCH_ALL_TYPES_AND_COMPLEX(input.scalar_type(), "min_all", [&] {

This dispatch should be disabled, and we should rely on the errors produced by the dispatch macro to prevent running those ops on complex inputs, instead of doing redundant dtype checks.
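For illustration, a minimal sketch of the requested change (hypothetical, not an actual patch): dropping the _AND_COMPLEX variant makes the macro itself reject complex inputs with a "not implemented for 'ComplexDouble'"-style RuntimeError, so the explicit dtype check becomes redundant.

AT_DISPATCH_ALL_TYPES(input.scalar_type(), "min_all", [&] {
  // kernel body unchanged; a complex input now fails inside the macro itself,
  // so no separate dtype check is needed on this path
});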

cc @ezyang @anjali411 @dylanbespalko @mruberry
cc @t-vi, thanks for reporting!

@ngimel ngimel added the module: complex and triaged labels Jan 4, 2021
@rishabhvarshney14

Can I work on it @anjali411?

@anjali411
Contributor

@rishabhvarshney14 Yes, absolutely! Please feel free to add me as a reviewer for your PR.

@ngimel
Collaborator Author

ngimel commented Jan 12, 2021

Reopening. The linked PR only disables dispatch on CPU reductions; it does not handle pointwise operations, and it does not handle CUDA operations.
I also asked for the redundant type checks to be removed (because dispatch would produce the necessary error messages), and they are still in place.

@imaginary-person
Contributor

imaginary-person commented Jan 13, 2021

> Reopening. The linked PR only disables dispatch on CPU reductions; it does not handle pointwise operations, and it does not handle CUDA operations.
> I also asked for the redundant type checks to be removed (because dispatch would produce the necessary error messages), and they are still in place.

Hello @ngimel, sorry about that! After Richard suggested writing a test, I realized that I hadn't disabled the type checks introduced in #36377, so I recently submitted another PR (#50465). For the CPU min/max ops in TensorCompareKernel.cpp, I had skipped them by mistake due to some confusion between the names of max_stub & maximum_stub, so I changed them in this PR as well.

As for the CUDA operations, their kernels do not seem to be compiled & dispatched for complex types anyway, so I think that no further changes are required. Please confirm if my rationale is correct. Thanks!

Basically, the dispatch macros currently being used don't have cases (in their switch statements) for complex types, so they default to an error for complex inputs. For example (a simplified model is sketched after this list):

  1. The reduce CUDA ops use AT_DISPATCH_ALL_TYPES_AND2 in ReduceMinMaxKernel.cu, and that macro doesn't allow complex types.

  2. In MinMaxElementwiseKernel.cu, the CUDA pointwise ops use AT_DISPATCH_FLOATING_TYPES_AND2 for non-integral & non-boolean types, and this macro doesn't have a case for complex types either.
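For illustration, here is a simplified stand-alone model (not the real ATen macros, and the listed dtypes are assumptions) of how such a dispatch macro behaves: a switch over the runtime dtype that throws for any type without a case, which is why complex inputs error out on these paths without any extra checks.

#include <stdexcept>
#include <string>

enum class ScalarType { Float, Double, Half, BFloat16, ComplexFloat, ComplexDouble };

// Simplified stand-in for an AT_DISPATCH_FLOATING_TYPES_AND2-style macro:
// only the listed dtypes get a case; anything else (e.g. the complex types)
// falls through to the default branch and throws, mirroring the
// "not implemented for ..." RuntimeError raised by the real macros.
template <typename Func>
void dispatch_floating_types_and2(ScalarType t, const std::string& name, Func&& body) {
  switch (t) {
    case ScalarType::Float:
    case ScalarType::Double:
    case ScalarType::Half:
    case ScalarType::BFloat16:
      body();
      break;
    default:
      throw std::runtime_error("\"" + name + "\" not implemented for this dtype");
  }
}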

@imaginary-person
Contributor

This comment is duplicated from #50465 for the convenience of future readers; please skip it if you have already read PR #50465.

There are a few cases in which the methods corresponding to min_stub() or max_stub() are not called, so the dispatch macros don't get invoked and no exception is raised. Hence, dtype checks are necessary in 3 places (a sketch follows the list):

  1. if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
  2. if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
  3. if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min") &&
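As a rough sketch (simplified names and argument order, not the exact code in ReduceOps.cpp), the kind of check that has to stay looks like this:

// The check has to stay because the trivial-reduction path below returns
// without calling max_stub, so the dispatch macro never runs and cannot
// raise the error itself.
TORCH_CHECK(!self.is_complex(), "max does not support complex inputs");
if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
  return std::tie(max, max_indices);  // early return: no kernel dispatch happens here
}
// only this path reaches a dispatch macro, via max_stub
max_stub(self.device().type(), max, max_indices, self, dim, keepdim);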

The first dtype check requirement can be verified from the following example Python code:

import unittest 
import torch

class MyTestCase(unittest.TestCase): 
  
   def test_1(self):
      t = torch.tensor((1 + 1j), device='cpu', dtype=torch.complex128) 
      with self.assertRaises(RuntimeError): 
         torch.max(t, dim=0)
  
if __name__ == '__main__':  
    unittest.main()
