
Check for boolean values as argument on pow function. #114133

Closed

Conversation

pdrocaldeira
Contributor

@pdrocaldeira pdrocaldeira commented Nov 20, 2023

Hello everyone! 😄
Also @lezcano, nice to meet you! :)

Sorry if I missed anything, this is my first time around here. 🙃

This PR makes CUDA behave the same way as CPU for torch.pow: Python treats True as 1 and False as 0, and I just added that check to the pow function. From what I understood, when I call .equal on a Scalar that is boolean, the types are guaranteed to match, so that shouldn't cause further trouble.
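To make the idea concrete, here is a minimal sketch of that kind of normalisation — not the literal diff in this PR, and the wrapper function name is made up:

```
// Minimal sketch only (not the literal diff in this PR): normalise a boolean
// Scalar exponent to the matching integer before dispatch, so True behaves
// like 1 and False like 0 on every backend.
#include <ATen/ATen.h>

at::Tensor pow_with_bool_exponent(const at::Tensor& base, const c10::Scalar& exp) {
  if (exp.isBoolean()) {
    return at::pow(base, c10::Scalar(exp.toBool() ? 1 : 0));
  }
  return at::pow(base, exp);
}
```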

I know the issue suggests disabling this case, but that would be a little more complicated, in my humble opinion, and it could create some compatibility problems too, I guess.

My argument is that the code below is valid native Python, so I think it makes sense to allow booleans as the Scalar exponent.

>>> x = True
>>> x + x
2

This was my first test:

Python 3.12.0 | packaged by Anaconda, Inc. | (main, Oct  2 2023, 17:29:18) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.pow(torch.tensor([1, 2], device='cuda'), True)
tensor([1, 2], device='cuda:0')
>>> torch.pow(torch.tensor([1, 2]), True)
tensor([1, 2])
>>> torch.pow(torch.tensor([1, 2]), False)
tensor([1, 1])
>>> torch.pow(torch.tensor([1, 2], device='cuda'), False)
tensor([1, 1], device='cuda:0')

I've run test_torch.py and got the following results, so my guess is that I didn't break anything. I was just looking for a test that uses linear regression, as suggested.

Ran 1619 tests in 52.363s

OK (skipped=111)
[TORCH_VITAL] Dataloader.enabled		 True
[TORCH_VITAL] Dataloader.basic_unit_test		 TEST_VALUE_STRING
[TORCH_VITAL] CUDA.used		 true

(I can paste whole log, if necessary)

If this is a bad idea overall, don't worry about it. It's not a big deal, it's actually a two-line change 😅, so we can talk about how to do things with a different strategy.

For the record, I've signed the agreement already. And I didn't run the linter because it's not working 😞. Looks like PyYAML 6.0 is broken and there's a 6.0.1 fix already, but I have no idea how to update that 😅

Fixes #113198

…the same as 1 and False should be the same as 0 as exponent argument

pytorch-bot bot commented Nov 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114133

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0d4a89c with merge base 7d5e8c1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@lezcano
Collaborator

lezcano commented Nov 20, 2023

This PR needs tests

The tests you want to run are in test/test_ops.py. Some will fail, you'll need to change those. Those tests are autogenerated. You may want to grep for OpInfo to find how they are autogenerated. If nothing fails, you'll need to add a sample to the relevant OpInfo that fails without this PR and passes with this PR.
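For reference, a rough sketch of what such a sample could look like; the helper name is made up and the import path follows torch.testing internals, so treat it as illustrative rather than the exact change:

```
# Illustrative only: an extra OpInfo-style sample that exercises a boolean
# exponent (True should act like exponent 1, False like exponent 0).
import torch
from torch.testing._internal.opinfo.core import SampleInput

def sample_inputs_pow_bool_exponent(op_info, device, dtype, requires_grad, **kwargs):
    base = torch.tensor([1, 2, 3], device=device, dtype=dtype,
                        requires_grad=requires_grad)
    yield SampleInput(base, args=(True,))
    yield SampleInput(base, args=(False,))
```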

@lezcano
Collaborator

lezcano commented Nov 20, 2023

Also note that this PR is doing the opposite of what we proposed to do in the issue, which is always fail on x ** bool...

@pdrocaldeira
Contributor Author

pdrocaldeira commented Nov 20, 2023

Also note that this PR is doing the opposite of what we proposed to do in the issue, which is always fail on x ** bool...

Yeah, you're absolutely correct. I did try to do what you asked, but it was way more complicated, so I wanted to bring this to the table because it was an easy change that makes the behaviour consistent everywhere, following standard Python behaviour.

But, of course, if you don't think this is a good idea I won't mind just closing this and moving forward with your suggestion. You just have to say the word :) [I know you might have said that already, but I'm not sure.]


If I understood the code correctly, to disable bool I would have to change native_functions.yaml so that, instead of taking a Scalar argument, the function is overloaded for every case that isn't bool, because from what I'm seeing a Scalar can be a bool. The consequence is that when we run torch.pow(torch.tensor([1, 2], device='cuda'), True) this overload wouldn't even exist, right?

EDIT: Or I could test for the types inside the C++ function and throw an exception/error for the [int, bool] case. If you have an example in mind, it would be helpful here. 💪🏾

I just found this:

  TORCH_CHECK(!(isIntegralType(base.scalar_type(), true) &&
              exp.isIntegral(true) && exp.toLong() < 0),
              "Integers to negative integer powers are not allowed.");

Perhaps I could adapt this in pow and disable [int, bool] for everyone. What do you think?
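For the disable-it route, a purely hypothetical sketch of how that quoted check could be adapted (the helper name is invented; this is not what the PR does):

```
// Hypothetical adaptation (not what this PR does): reject a boolean exponent
// outright, in the same spirit as the negative-integer-power check above.
#include <ATen/ATen.h>

void check_pow_exponent(const at::Tensor& base, const c10::Scalar& exp) {
  TORCH_CHECK(!exp.isBoolean(),
              "pow() with a boolean exponent is not supported on ",
              base.device(), "; use an integer or float exponent instead.");
}
```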


The tests you want to run are in test/test_ops.py. Some will fail, you'll need to change those. Those tests are autogenerated. You may want to grep for OpInfo to find how they are autogenerated. If nothing fails, you'll need to add a sample to the relevant OpInfo that fails without this PR and passes with this PR.

Noted! Regardless of any strategy I think that's a good test, thank you!


Thanks for such fast response, I really appreciate your input! 😄

@lezcano
Collaborator

lezcano commented Nov 20, 2023

You just need to TORCH_CHECK that the input is not boolean, no?

At any rate, I'm fine with making the behaviour equal to the current behaviour of CPU or CUDA, whatever works as long as it's tested.

@bdhirsh bdhirsh requested a review from lezcano November 21, 2023 22:53
@bdhirsh bdhirsh added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Nov 21, 2023
@pdrocaldeira
Contributor Author

pdrocaldeira commented Nov 22, 2023

======================================================================
ERROR: test_python_ref_meta__refs_linalg_svd_cpu_complex128 (__main__.TestCommonCPU.test_python_ref_meta__refs_linalg_svd_cpu_complex128)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 908, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 1120, in only_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 1299, in wrapper
    fn(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/test/test_ops.py", line 347, in test_python_ref_meta
    prims.utils.compare_tensor_meta(a, b)
  File "/home/pedro/workspace/pytorch/torch/_prims_common/__init__.py", line 166, in compare_tensor_meta
    raise RuntimeError(
RuntimeError: Conj mismatch! is_conj is set to False and True

To execute this test, run the following from the base repo dir:
     python test/test_ops.py -k test_python_ref_meta__refs_linalg_svd_cpu_complex128

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_python_ref_meta__refs_linalg_svd_cpu_complex64 (__main__.TestCommonCPU.test_python_ref_meta__refs_linalg_svd_cpu_complex64)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 908, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 1120, in only_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 1299, in wrapper
    fn(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/test/test_ops.py", line 347, in test_python_ref_meta
    prims.utils.compare_tensor_meta(a, b)
  File "/home/pedro/workspace/pytorch/torch/_prims_common/__init__.py", line 166, in compare_tensor_meta
    raise RuntimeError(
RuntimeError: Conj mismatch! is_conj is set to False and True

To execute this test, run the following from the base repo dir:
     python test/test_ops.py -k test_python_ref_meta__refs_linalg_svd_cpu_complex64

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
FAIL: test_complex_half_reference_testing_fft_hfft2_cuda_complex32 (__main__.TestCommonCUDA.test_complex_half_reference_testing_fft_hfft2_cuda_complex32)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 945, in dep_fn
    return fn(slf, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 908, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/test/test_ops.py", line 1232, in test_complex_half_reference_testing
    self.assertEqual(actual, expected, exact_dtype=False)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 3478, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 256 (0.4%)
Greatest absolute difference: 0.07051944732666016 at index (0, 4, 10) (up to 0.04 allowed)
Greatest relative difference: 0.36067742109298706 at index (0, 4, 10) (up to 0.04 allowed)

To execute this test, run the following from the base repo dir:
     python test/test_ops.py -k test_complex_half_reference_testing_fft_hfft2_cuda_complex32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 60396 tests in 6626.686s

FAILED (failures=1, errors=2, skipped=10863, expected failures=831)

I tested both my branch and main, and both show these same three failures in test/test_ops.py.


Three questions:

  • Do I need to rebase?
  • What should I write for release notes?
  • Is there something else to test?

Thanks a lot, see you around!


(It took me a while because I didn't have scipy installed and more tests were failing because of that [and also because my CPU is slow 😅])

@lezcano
Collaborator

lezcano commented Nov 22, 2023

You need to add a sample to the op info that tests this case. One that fails in main and passes with this PR

Co-authored-by: Gary Yao <garyyaoresearch@gmail.com>
@pdrocaldeira
Contributor Author

You need to add a sample to the op info that tests this case. One that fails in main and passes with this PR

Duh! That was obvious, sorry about that 😅

Yes, about OpInfo I went ahead and read the comments/instructions about it to understand how they work. But I knew there should be something written already and then I found it 😄

But it was a good tip to look at OpInfo; only after that could I find out about binary_ufunc 😉

I just changed the already existing pow test in test/test_binary_ufuncs.py and added both True and False to the list of exponents being tested.
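To make that concrete, here is a minimal standalone version of the check those booleans go through; the real test uses its own helpers, so the names here are illustrative:

```
import math
import torch

# Illustrative only: booleans added to the exponent list and compared against
# math.pow the same way as the other scalar exponents (True acts as 1, False as 0).
exponents = [0, 1, 2, True, False]
base = torch.arange(1, 5, dtype=torch.float32)
for exp in exponents:
    expected = torch.tensor([math.pow(b.item(), exp) for b in base])
    torch.testing.assert_close(torch.pow(base, exp), expected)
```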

Then, when I run it on the main branch, I get:

======================================================================
ERROR: test_pow_cuda_int16 (__main__.TestBinaryUfuncsCUDA.test_pow_cuda_int16)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1362, in test_pow
    self._do_pow_for_exponents(m1, exponents, math.pow, None)
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1310, in _do_pow_for_exponents
    res1 = torch.pow(m1[4], num)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "/home/pedro/workspace/pytorch/aten/src/ATen/native/cuda/PowKernel.cu":199, please report a bug to PyTorch. invalid combination of type in Pow function, common dtype:Shortexp is integral?0

To execute this test, run the following from the base repo dir:
     python test/test_binary_ufuncs.py -k test_pow_cuda_int16

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_pow_cuda_int32 (__main__.TestBinaryUfuncsCUDA.test_pow_cuda_int32)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1362, in test_pow
    self._do_pow_for_exponents(m1, exponents, math.pow, None)
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1310, in _do_pow_for_exponents
    res1 = torch.pow(m1[4], num)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "/home/pedro/workspace/pytorch/aten/src/ATen/native/cuda/PowKernel.cu":199, please report a bug to PyTorch. invalid combination of type in Pow function, common dtype:Intexp is integral?0

To execute this test, run the following from the base repo dir:
     python test/test_binary_ufuncs.py -k test_pow_cuda_int32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_pow_cuda_int64 (__main__.TestBinaryUfuncsCUDA.test_pow_cuda_int64)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1362, in test_pow
    self._do_pow_for_exponents(m1, exponents, math.pow, None)
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1310, in _do_pow_for_exponents
    res1 = torch.pow(m1[4], num)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "/home/pedro/workspace/pytorch/aten/src/ATen/native/cuda/PowKernel.cu":199, please report a bug to PyTorch. invalid combination of type in Pow function, common dtype:Longexp is integral?0

To execute this test, run the following from the base repo dir:
     python test/test_binary_ufuncs.py -k test_pow_cuda_int64

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_pow_cuda_int8 (__main__.TestBinaryUfuncsCUDA.test_pow_cuda_int8)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1362, in test_pow
    self._do_pow_for_exponents(m1, exponents, math.pow, None)
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1310, in _do_pow_for_exponents
    res1 = torch.pow(m1[4], num)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "/home/pedro/workspace/pytorch/aten/src/ATen/native/cuda/PowKernel.cu":199, please report a bug to PyTorch. invalid combination of type in Pow function, common dtype:Charexp is integral?0

To execute this test, run the following from the base repo dir:
     python test/test_binary_ufuncs.py -k test_pow_cuda_int8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_pow_cuda_uint8 (__main__.TestBinaryUfuncsCUDA.test_pow_cuda_uint8)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/home/pedro/workspace/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1362, in test_pow
    self._do_pow_for_exponents(m1, exponents, math.pow, None)
  File "/home/pedro/workspace/pytorch/test/test_binary_ufuncs.py", line 1310, in _do_pow_for_exponents
    res1 = torch.pow(m1[4], num)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "/home/pedro/workspace/pytorch/aten/src/ATen/native/cuda/PowKernel.cu":199, please report a bug to PyTorch. invalid combination of type in Pow function, common dtype:Byteexp is integral?0

To execute this test, run the following from the base repo dir:
     python test/test_binary_ufuncs.py -k test_pow_cuda_uint8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 25675 tests in 180.119s

FAILED (errors=5, skipped=1679, expected failures=66)

But for this branch it runs fine:

Ran 25675 tests in 180.793s

OK (skipped=1679, expected failures=66)


I added @garyyaoresearch as co-author, as he was kind enough to mention that he was working on this too. We've been exchanging emails since then :)

@lezcano lezcano added the topic: not user facing topic category label Nov 22, 2023

@lezcano lezcano left a comment


fair enough

@lezcano
Collaborator

lezcano commented Nov 22, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 22, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

xunsongh pushed a commit to xunsongh/pytorch that referenced this pull request Nov 24, 2023
Pull Request resolved: pytorch#114133
Approved by: https://github.com/lezcano