Skip to content

Commit

Permalink
Add dtype checks for min/max where stubs are not called
Browse files Browse the repository at this point in the history
  • Loading branch information
imaginary-person committed Jan 14, 2021
1 parent cef471a commit 52dcc72
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/native/TensorCompare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indic
max_indices.device(), " for indices output");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
TORCH_CHECK(!self.is_complex(), "max does not support complex inputs.");
AT_ASSERT(max.dim() == 0);
max_indices.resize_({}).fill_(0);
return std::forward_as_tuple(max, max_indices);
Expand Down Expand Up @@ -387,6 +388,7 @@ static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min") &&
_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
TORCH_CHECK(!self.is_complex(), "min_max does not support complex inputs.");
return std::forward_as_tuple(min, max);
} else {
_aminmax_stub(self.device().type(), min, max, self, dim, keepdim);
Expand Down Expand Up @@ -418,6 +420,7 @@ static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indic
min_indices.device(), " for indices output");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
AT_ASSERT(min.dim() == 0);
min_indices.resize_({}).fill_(0);
return std::forward_as_tuple(min, min_indices);
Expand Down

0 comments on commit 52dcc72

Please sign in to comment.