Remove unnecessary dtype checks for complex types & disable complex dispatch for CPU min/max pointwise ops (#50465)

Summary:
Fixes #50064

**PROBLEM DESCRIPTION:**
1. The previous PR for this issue (#50347) did not remove the dtype checks for complex types.
These checks were added in #36377, but are no longer necessary,
as we now rely on dispatch macros to produce error messages.
2. The dtype checks for complex inputs in `clamp_max()` and `clamp_min()` had not been removed either.
3. For the min/max pointwise ops in TensorCompareKernel.cpp, complex dispatch had not been disabled.

### **FIX DESCRIPTION:**
**FIX SUMMARY:**
1. Removed the dtype checks added in #36377, and added 3 new ones in TensorCompare.cpp (see the reasoning below).
2. Removed the dtype checks for complex inputs in `clamp_max()` and `clamp_min()`.
3. Disabled complex dispatch for the min/max pointwise ops in TensorCompareKernel.cpp.
4. The error messages of the exceptions raised when min/max ops are unimplemented are now checked for containing either the text _not support_ (which is also a substring of _not supported_) or _not implemented_; one of the two must appear for the message to count as informative. See the sketch below.
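
For illustration, a minimal standalone sketch of that matching rule (the sample messages are modeled on the removed TORCH_CHECK text and on the dispatch-macro error format, and are not taken verbatim from the test suite):
```python
import re

# The pattern used by the updated tests: it accepts either phrasing.
pattern = '(.*not support.*)|(.*not implemented.*)'

# Sample messages: the first mirrors the removed TORCH_CHECK text, the
# second mirrors the format of dispatch-macro errors.
messages = [
    "maximum does not support complex inputs.",
    '"max_cpu" not implemented for \'ComplexDouble\'',
]

for msg in messages:
    # assertRaisesRegex uses re.search under the hood.
    assert re.search(pattern, msg) is not None
```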

**REASON FOR NOT CHANGING DISPATCH FOR CUDA AND CLAMP OPS**:

The CUDA min/max kernels are not compiled or dispatched for complex types in the first place, so no further changes are required there: the dispatch macros they currently use have no cases for complex types.

For example,

1. the reduce CUDA ops use [`AT_DISPATCH_ALL_TYPES_AND2`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L548-L575) in [ReduceMinMaxKernel.cu](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu), and that macro doesn't allow complex types.

2. In [MaxMinElementwiseKernel.cu](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu), the CUDA pointwise ops use [`AT_DISPATCH_FLOATING_TYPES_AND2`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L240-L263) for non-integral & non-boolean types, and this macro doesn't have a case for complex types either.

3. [clamp CUDA ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/UnaryOpsKernel.cu#L170-L211) use `AT_DISPATCH_ALL_TYPES_AND2`, which doesn't have a case for complex types.

Similarly, [CPU clamp min/max ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp#L428-L458) use the `AT_DISPATCH_ALL_TYPES_AND` dispatch macro, which doesn't have a case for complex types.
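
Sketched from Python, this is the practical effect of relying on the dispatch macros (the exact exception text comes from the macro's op name and the tensor's dtype, so the printed message is illustrative rather than guaranteed):
```python
import torch

t = torch.ones(2, dtype=torch.complex64)  # complex CPU tensor

# With the explicit dtype checks removed, these ops now fail inside the
# dispatch macro, which has no case for complex types.
for fn in (lambda: torch.maximum(t, t), lambda: torch.clamp(t, min=0)):
    try:
        fn()
    except RuntimeError as e:
        print(e)  # e.g. "... not implemented for 'ComplexFloat'"
```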

**REASON FOR ADDING 3 dtype CHECKS:**
There are a few cases in which the methods corresponding to `min_stub()` or `max_stub()` are not called, so the dispatch macros are never invoked and no exception is raised. Hence, `dtype` checks are necessary at 3 places:

1. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L342
2. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L422
3. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L389

The need for the first dtype check can be verified with the following example, based on `test_complex_unsupported()`:
```python
import unittest
import torch

class MyTestCase(unittest.TestCase):

    def test_1(self):
        t = torch.tensor((1 + 1j), device='cpu', dtype=torch.complex128)
        with self.assertRaises(Exception):
            torch.max(t, dim=0)

if __name__ == '__main__':
    unittest.main()
```
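
The second and third dtype checks can be exercised the same way; a minimal sketch mirroring the assertions added to `test_complex_unsupported()` in test_torch.py (`_aminmax` is a private API, so its availability may vary across versions):
```python
import unittest
import torch

class MyTestCase(unittest.TestCase):

    # Covers the second dtype check (min_out_impl).
    def test_min_dim(self):
        t = torch.tensor((1 + 1j), device='cpu', dtype=torch.complex128)
        with self.assertRaisesRegex(RuntimeError,
                                    '(.*not support.*)|(.*not implemented.*)'):
            torch.min(t, dim=0)

    # Covers the third dtype check (_aminmax_out_impl).
    def test_aminmax_dim(self):
        t = torch.tensor((1 + 1j), device='cpu', dtype=torch.complex128)
        with self.assertRaisesRegex(RuntimeError,
                                    '(.*not support.*)|(.*not implemented.*)'):
            torch._aminmax(t, dim=0)

if __name__ == '__main__':
    unittest.main()
```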

Pull Request resolved: #50465

Reviewed By: mruberry

Differential Revision: D25938106

Pulled By: ngimel

fbshipit-source-id: 95e2df02ba8583fa3ce87d4a2fdcd60b912dda46
imaginary-person authored and facebook-github-bot committed Jan 18, 2021
1 parent 1fdc35d commit 3f052ba
Showing 7 changed files with 48 additions and 52 deletions.
8 changes: 0 additions & 8 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -826,16 +826,12 @@ Tensor logical_xor(const Tensor& self, Scalar other) { return comparison_op(self
 Tensor& logical_xor_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); }

 Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");
-
   auto iter = TensorIterator::binary_op(result, self, other);
   maximum_stub(iter.device_type(), iter);
   return result;
 }

 Tensor maximum(const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");
-
   Tensor result;
   auto iter = TensorIterator::binary_op(result, self, other);
   maximum_stub(iter.device_type(), iter);
@@ -852,16 +848,12 @@ Tensor max(const Tensor& self, const Tensor& other) {
 }

 Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs.");
-
   auto iter = TensorIterator::binary_op(result, self, other);
   minimum_stub(iter.device_type(), iter);
   return result;
 }

 Tensor minimum(const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs.");
-
   Tensor result;
   auto iter = TensorIterator::binary_op(result, self, other);
   minimum_stub(iter.device_type(), iter);
3 changes: 0 additions & 3 deletions aten/src/ATen/native/ReduceAllOps.cpp
@@ -11,23 +11,20 @@ DEFINE_DISPATCH(max_all_stub);
 DEFINE_DISPATCH(_aminmax_all_stub);

 Tensor min(const Tensor &self) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
   Tensor result = at::empty({}, self.options());
   min_all_stub(self.device().type(), result, self.contiguous());
   return result;
 }

 Tensor max(const Tensor &self) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
   Tensor result = at::empty({}, self.options());
   max_all_stub(self.device().type(), result, self.contiguous());
   return result;
 }

 std::tuple<Tensor, Tensor> _aminmax_all(const Tensor &self) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
   Tensor min_result = at::empty({}, self.options());
   Tensor max_result = at::empty({}, self.options());
15 changes: 3 additions & 12 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -314,7 +314,6 @@ std::tuple<Tensor &,Tensor &> mode_out(Tensor& values, Tensor& indices,
 }

 std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   Tensor max_indices = at::empty({0}, self.options().dtype(kLong));
   if (self.is_quantized()) {
     Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
@@ -329,7 +328,6 @@ std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {

 static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indices,
                                                   const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "max only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -342,6 +340,7 @@ static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indic
               max_indices.device(), " for indices output");
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
+    TORCH_CHECK(!self.is_complex(), "max does not support complex inputs.");
     AT_ASSERT(max.dim() == 0);
     max_indices.resize_({}).fill_(0);
     return std::forward_as_tuple(max, max_indices);
@@ -353,7 +352,6 @@ static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indic

 std::tuple<Tensor&,Tensor&> max_out(Tensor& max, Tensor& max_indices,
                                     const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   auto result = [&]() {
     NoNamesGuard guard;
     return max_out_impl(max, max_indices, self, dim, keepdim);
@@ -364,7 +362,6 @@ std::tuple<Tensor&,Tensor&> max_out(Tensor& max, Tensor& max_indices,
 }

 std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
   if (self.is_quantized()) {
     Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
@@ -378,7 +375,6 @@ std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {

 static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max,
                                                         const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "min_max_val only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -392,6 +388,7 @@ static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min") &&
       _dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
+    TORCH_CHECK(!self.is_complex(), "min_max does not support complex inputs.");
     return std::forward_as_tuple(min, max);
   } else {
     _aminmax_stub(self.device().type(), min, max, self, dim, keepdim);
@@ -400,7 +397,6 @@ static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max
 }

 std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min_max is not yet implemented for complex tensors.");
   TORCH_CHECK(!self.is_quantized(), "min is not yet implemented for quantized tensors.");

   Tensor min = at::empty({0}, self.options());
@@ -412,7 +408,6 @@ std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdi

 static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
                                                   const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "min only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -425,6 +420,7 @@ static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indic
               min_indices.device(), " for indices output");
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
+    TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
     AT_ASSERT(min.dim() == 0);
     min_indices.resize_({}).fill_(0);
     return std::forward_as_tuple(min, min_indices);
@@ -436,7 +432,6 @@ static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indic

 std::tuple<Tensor&,Tensor&> min_out(Tensor& min, Tensor& min_indices,
                                     const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   auto result = [&]() {
     NoNamesGuard guard;
     return min_out_impl(min, min_indices, self, dim, keepdim);
@@ -450,21 +445,17 @@ std::tuple<Tensor&,Tensor&> min_out(Tensor& min, Tensor& min_indices,

 // Named tensor overloads

 std::tuple<Tensor, Tensor> min(const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   return at::min(self, dimname_to_position(self, dim), keepdim);
 }
 std::tuple<Tensor &,Tensor &> min_out(Tensor& min, Tensor& min_indices,
                                       const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   return at::min_out(min, min_indices, self, dimname_to_position(self, dim), keepdim);
 }
 std::tuple<Tensor, Tensor> max(const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   return at::max(self, dimname_to_position(self, dim), keepdim);
 }
 std::tuple<Tensor &,Tensor &> max_out(Tensor& max, Tensor& max_indices,
                                       const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim);
 }
 Tensor argmax(const Tensor& self, Dimname dim, bool keepdim) {
3 changes: 0 additions & 3 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -549,7 +549,6 @@ Tensor signbit(const Tensor& self) {
 }

 Tensor& clamp_out(Tensor& result, const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
   if (min && max) {
     TORCH_CHECK(self.layout() == Layout::Strided,
                 "clamp only supports strided layout, got: ", self.layout());
@@ -575,7 +574,6 @@ Tensor& clamp_(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
 }

 Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) {
-  TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
   TORCH_CHECK(self.layout() == Layout::Strided,
               "clamp_max only supports strided layout, got: ", self.layout());
   auto iter = TensorIterator::unary_op(result, self);
@@ -593,7 +591,6 @@ Tensor& clamp_max_(Tensor& self, Scalar max) {
 }

 Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) {
-  TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
   TORCH_CHECK(self.layout() == Layout::Strided,
               "clamp_min only supports strided layout, got: ", self.layout());
   auto iter = TensorIterator::unary_op(result, self);
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -81,7 +81,7 @@ static void min_kernel_impl(
   TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
               "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());

-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
       scalar_t* result_data, int64_t* indice_data,
       const scalar_t* self_data, auto self_dim_stride) {
@@ -118,7 +118,7 @@ static void max_kernel_impl(
   TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
               "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());

-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
       scalar_t* result_data, int64_t* indice_data,
       const scalar_t* self_data, auto self_dim_stride) {
4 changes: 2 additions & 2 deletions test/test_binary_ufuncs.py
@@ -1067,11 +1067,11 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
     @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
     def test_maximum_minimum_complex(self, device, dtypes):
         for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min):
-            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+            with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
                 torch_op(torch.ones(1, device=device, dtype=dtypes[0]),
                          torch.ones(1, device=device, dtype=dtypes[1]))

-            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+            with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
                 torch_op(torch.ones(1, device=device, dtype=dtypes[1]),
                          torch.ones(1, device=device, dtype=dtypes[0]))
63 changes: 41 additions & 22 deletions test/test_torch.py
@@ -5984,60 +5984,79 @@ def test_complex_unsupported(self, device, dtype):
         # Note: whether PyTorch should support min and max on complex
         # tensors is an open question.
         # See https://github.com/pytorch/pytorch/issues/36374
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.min(t)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             t.min()
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.min(t, dim=0)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.min(t, t)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.min(t, t, out=t)

-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.max(t)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             t.max()
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.max(t, dim=0)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.max(t, t)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.max(t, t, out=t)

-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.amin(t)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             t.amin()
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.amin(t, dim=0)

-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.amax(t)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             t.amax()
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.amax(t, dim=0)

+        # Tests _aminmax() variants with complex inputs,
+        # which are currently not supported due to min & max being unsupported
+        # for complex inputs, as per https://github.com/pytorch/pytorch/issues/36374
+        # Test with a single-element tensor t, as well as a multi-element tensor x
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
+            min_val, max_val = torch._aminmax(t)
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
+            min_val = torch._aminmax(t, dim=0)[0]
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
+            max_val = torch._aminmax(t, dim=0)[1]
+        # Test _aminmax() with a multi-element tensor
+        x = torch.tensor([(1 + 1j), (2 + 3j)], device=device, dtype=dtype)
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
+            min_val, max_val = torch._aminmax(x)
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
+            min_val = torch._aminmax(x, dim=0)[0]
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
+            max_val = torch._aminmax(x, dim=0)[1]

         # Tests clamp variants with complex inputs
         # Note: whether PyTorch should support clamp on complex
         # tensors is an open question.
         # See https://github.com/pytorch/pytorch/issues/33568
         min_val = 1 + 1j
         max_val = 4 + 4j
         out = torch.empty((0,), device=device, dtype=dtype)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.clamp(t, min=min_val)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.clamp(t, max=max_val)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.clamp(t, min_val, max_val)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.clamp(t, min=min_val, out=out)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.clamp(t, max=max_val, out=out)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
             torch.clamp(t, min_val, max_val, out=out)

     def test_pickle_gradscaler(self, device):
