
Update on "Actually run backward criterion tests."
These were (apparently) accidentally turned off two years ago in #9287.

Differential Revision: [D23408987](https://our.internmc.facebook.com/intern/diff/D23408987)

[ghstack-poisoned]
gchanan committed Aug 31, 2020
2 parents bb2c7f4 + 79a8c14 commit 5937951
Showing 38 changed files with 898 additions and 291 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/lint.yml
@@ -152,6 +152,9 @@ jobs:
--verbose \
--paths torch/csrc/ \
--diff "$MERGE_BASE" \
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp"\
-g"-torch/csrc/jit/serialization/onnx.cpp" \
-g"-torch/csrc/jit/serialization/export.cpp" \
-g"-torch/csrc/jit/serialization/import.cpp" \
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \
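The new `-g` arguments exclude the ONNX serialization files from this clang-tidy job: a pattern whose quoted body starts with `-` removes matching paths from the set being linted. A minimal sketch of that leading-`-` exclusion semantics, assuming fnmatch-style patterns (illustrative only, not the actual tools/clang_tidy.py logic):

import fnmatch

def filter_paths(paths, patterns):
    # Positive patterns select paths; "-"-prefixed patterns then exclude them.
    selected = set()
    for pat in patterns:
        if not pat.startswith("-"):
            selected.update(p for p in paths if fnmatch.fnmatch(p, pat))
    for pat in patterns:
        if pat.startswith("-"):
            selected -= {p for p in paths if fnmatch.fnmatch(p, pat[1:])}
    return sorted(selected)

paths = ["torch/csrc/jit/serialization/onnx.cpp", "torch/csrc/jit/ir/ir.cpp"]
print(filter_paths(paths, ["torch/csrc/*", "-torch/csrc/jit/serialization/onnx.cpp"]))
# ['torch/csrc/jit/ir/ir.cpp']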
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -283,6 +283,7 @@ namespace c10 {
_(onnx, SequenceConstruct) \
_(onnx, SequenceEmpty) \
_(onnx, SequenceInsert) \
_(onnx, SequenceErase) \
_(onnx, ConcatFromSequence) \
_(onnx, Identity) \
_(onnx, SoftmaxCrossEntropyLoss) \
20 changes: 4 additions & 16 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -167,7 +167,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

#ifndef __HIP_PLATFORM_HCC__
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
globalContext().alertCuBLASConfigNotDeterministic();
@@ -183,7 +183,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
}
#endif

#ifndef __HIP_PLATFORM_HCC__
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
globalContext().alertCuBLASConfigNotDeterministic();
@@ -340,7 +340,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
} while (0)

#ifndef __HIP_PLATFORM_HCC__
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
globalContext().alertCuBLASConfigNotDeterministic();
@@ -355,7 +355,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
#endif

#ifndef __HIP_PLATFORM_HCC__
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
globalContext().alertCuBLASConfigNotDeterministic();
@@ -492,28 +492,16 @@ void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) {

template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
#ifndef __HIP_PLATFORM_HCC__
TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
reinterpret_cast<cuDoubleComplex*>(result)));
#else
TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const rocblas_double_complex*>(x),
incx, reinterpret_cast<const rocblas_double_complex*>(y), incy,
reinterpret_cast<rocblas_double_complex*>(result)));
#endif
}

template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
#ifndef __HIP_PLATFORM_HCC__
TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x),
incx, reinterpret_cast<const cuComplex*>(y), incy,
reinterpret_cast<cuComplex*>(result)));
#else
TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const rocblas_float_complex*>(x),
incx, reinterpret_cast<const rocblas_float_complex*>(y), incy,
reinterpret_cast<rocblas_float_complex*>(result)));
#endif
}

template <>
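With the relaxed guards, the complex gemm/gemv specializations are compiled not only for CUDA but also for ROCm builds with HIP_VERSION >= 210, and the complex dot products now route through the hipified cublas calls on both platforms. A quick sanity check, assuming a GPU build where complex BLAS is enabled:

import torch

if torch.cuda.is_available():
    a = torch.randn(4, 4, dtype=torch.complex64, device="cuda")
    b = torch.randn(4, 4, dtype=torch.complex64, device="cuda")
    c = a @ b                    # gemm<c10::complex<float>>
    x = torch.randn(8, dtype=torch.complex64, device="cuda")
    y = torch.randn(8, dtype=torch.complex64, device="cuda")
    d = torch.dot(x, y)          # cublasCdotu, the unconjugated dot product
    assert (d.cpu() - torch.dot(x.cpu(), y.cpu())).abs() < 1e-3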
48 changes: 19 additions & 29 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -424,7 +424,7 @@ static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
// can't take max of empty tensor
if (self.numel() != 0) {
auto maxes = at::max_values(self, dims, true);
auto maxes = at::amax(self, dims, true);
auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
at::sum_out(result, at::exp(self - maxes), dims, keepdim);
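The switch from at::max_values to at::amax keeps the usual max-subtraction trick intact: logsumexp(x) = m + log(sum(exp(x - m))) with m = amax(x, dims), so exp never sees large positive arguments. For example:

import torch

x = torch.tensor([[1000.0, 1001.0, 1002.0]])
naive = x.exp().sum(dim=1).log()          # tensor([inf]), overflow
m = torch.amax(x, dim=1, keepdim=True)
stable = m.squeeze(1) + (x - m).exp().sum(dim=1).log()
print(stable)                             # tensor([1002.4076])
print(torch.logsumexp(x, dim=1))          # same value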
@@ -631,40 +631,30 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
}
}

Tensor min_values(const Tensor& self, IntArrayRef dims, bool keepdim) {
if (dims.size() == 1) {
return std::get<0>(self.min(dims[0], keepdim));
} else {
Tensor result = at::empty({0}, self.options());
ScalarType dtype = get_dtype(result, self, {}, true);
auto iter = make_reduction("min_values", result, self, dims, keepdim, dtype);
TORCH_CHECK(iter.numel() > 0, "min_values on a tensor with no elements is not defined.");
min_values_stub(iter.device_type(), iter);
return result;
}
Tensor &amin_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(self.scalar_type() == result.scalar_type(), "Illegal dtype for self, and out:", self.scalar_type(), result.scalar_type());
auto iter = make_reduction("amin", result, self, dim, keepdim, self.scalar_type());
TORCH_CHECK(iter.numel() > 0, "operation does not have an identity");
min_values_stub(iter.device_type(), iter);
return result;
}

Tensor max_values(const Tensor& self, IntArrayRef dims, bool keepdim) {
if (dims.size() == 1) {
return std::get<0>(self.max(dims[0], keepdim));
} else {
Tensor result = at::empty({0}, self.options());
ScalarType dtype = get_dtype(result, self, {}, true);
auto iter = make_reduction("max_values", result, self, dims, keepdim, dtype);
TORCH_CHECK(iter.numel() > 0, "max_values on a tensor with no elements is not defined.");
max_values_stub(iter.device_type(), iter);
return result;
}
Tensor amin(const Tensor& self, IntArrayRef dim, bool keepdim) {
Tensor result = at::empty({0}, self.options());
return at::amin_out(result, self, dim, keepdim);
}

Tensor min_values(const Tensor& self, DimnameList dims, bool keepdim) {
TORCH_CHECK(false, "NYI: min_values with names");
return at::min_values(self, dimnames_to_positions(self, dims), keepdim);
Tensor &amax_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(self.scalar_type() == result.scalar_type(), "Illegal dtype for self, and out:", self.scalar_type(), result.scalar_type());
auto iter = make_reduction("amax", result, self, dim, keepdim, self.scalar_type());
TORCH_CHECK(iter.numel() > 0, "operation does not have an identity");
max_values_stub(iter.device_type(), iter);
return result;
}

Tensor max_values(const Tensor& self, DimnameList dims, bool keepdim) {
TORCH_CHECK(false, "NYI: max_values with names");
return at::max_values(self, dimnames_to_positions(self, dims), keepdim);
Tensor amax(const Tensor& self, IntArrayRef dim, bool keepdim) {
Tensor result = at::empty({0}, self.options());
return at::amax_out(result, self, dim, keepdim);
}

Tensor& argmax_out(Tensor& result, const Tensor& self, c10::optional<int64_t> dim, bool keepdim) {
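Unlike max.dim/min.dim, the new amax/amin reductions return only values (no indices) and accept several dimensions at once; the checks above also require the out variant to match the input dtype and the reduction to cover at least one element. For example:

import torch

x = torch.arange(24.0).reshape(2, 3, 4)
print(torch.amax(x, dim=(0, 2)))                  # tensor([15., 19., 23.])
print(torch.amin(x, dim=1, keepdim=True).shape)   # torch.Size([2, 1, 4])
values, indices = torch.max(x, dim=1)             # max.dim still returns indices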
29 changes: 25 additions & 4 deletions aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -269,21 +269,42 @@ static void or_kernel_impl(TensorIterator& iter) {
/*ident=*/false);
}

template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};

static void min_values_kernel_impl(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf, iter.dtype(), "min_values_cpu", [&iter] {
if (iter.dtype() == kLong) {
// This case is special because Vec256<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
using scalar_t = int64_t;
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
return;
}
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { return minimum(a, b); });
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { return minimum(a, b); },
upper_bound<scalar_t>());
});
}

static void max_values_kernel_impl(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf, iter.dtype(), "max_values_cpu", [&iter] {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { return maximum(a, b); });
[](Vec256<scalar_t> a, Vec256<scalar_t> b) { return maximum(a, b); },
lower_bound<scalar_t>());
});
}

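The explicit identity arguments (upper_bound for min, lower_bound for max) give the vectorized reduction a well-defined starting value, and int64 falls back to the scalar binary_kernel_reduce path because Vec256<int64_t> cannot represent upper_bound<int64_t>() (issue #43254 above). The identity is just the value that leaves the reduction unchanged; for min:

import functools

INT64_MAX = 2**63 - 1          # upper_bound<int64_t>, the identity for min

def reduce_min(values, identity=INT64_MAX):
    # Folding from the identity makes partially filled vector lanes harmless.
    return functools.reduce(min, values, identity)

print(reduce_min([5, -3, 9]))  # -3
print(reduce_min([]))          # identity; the real kernel rejects empty reductions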
20 changes: 6 additions & 14 deletions aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu
@@ -30,23 +30,15 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
}

void max_values_kernel_cuda(TensorIterator& iter) {
if (iter.dtype(1) == kHalf) {
max_values_kernel_cuda_impl<at::Half, float>(iter);
} else {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() {
max_values_kernel_cuda_impl<scalar_t>(iter);
});
}
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() {
max_values_kernel_cuda_impl<scalar_t>(iter);
});
}

void min_values_kernel_cuda(TensorIterator& iter) {
if (iter.dtype(1) == kHalf) {
min_values_kernel_cuda_impl<at::Half, float>(iter);
} else {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
});
}
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
});
}

template <typename scalar_t, typename acc_t=scalar_t>
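The AT_DISPATCH_ALL_TYPES_AND2 rewrite drops the special-cased <at::Half, float> instantiation and adds kBool coverage, mirroring the CPU kernels above. Assuming a CUDA device, both new paths can be exercised from Python:

import torch

if torch.cuda.is_available():
    h = torch.randn(8, 8, device="cuda", dtype=torch.half)
    print(torch.amax(h, dim=0).dtype)    # torch.float16
    b = torch.tensor([False, True, False], device="cuda")
    print(torch.amax(b, dim=0))          # tensor(True, device='cuda:0')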
18 changes: 8 additions & 10 deletions aten/src/ATen/native/native_functions.yaml
@@ -1999,18 +1999,17 @@

- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)

- func: max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
use_c10_dispatcher: full
variants: function, method

- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
variants: function, method

- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)

- func: max_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
use_c10_dispatcher: full
variants: function, method

- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)

# Return: (Tensor output, Tensor indices)
- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
use_c10_dispatcher: full
@@ -2082,18 +2081,17 @@

- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)

- func: min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
use_c10_dispatcher: full
variants: function, method

- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
variants: function, method

- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)

- func: min_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
use_c10_dispatcher: full
variants: function, method

- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)

- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
use_c10_dispatcher: full

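Note the dim=[] default on both schemas: as with other TensorIterator reductions, an empty dim list reduces over every dimension, so the no-argument form should agree with the full max/min reduction:

import torch

t = torch.randn(3, 5)
assert torch.equal(torch.amax(t), t.max())
assert torch.equal(torch.amin(t), t.min())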
4 changes: 2 additions & 2 deletions aten/src/ATen/test/reduce_ops_test.cpp
@@ -12,11 +12,11 @@ TEST(ReduceOpsTest, MaxValuesAndMinValues) {
for (const auto dtype : {kHalf, kFloat, kDouble, kShort, kInt, kLong}) {
auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(at::kHalf));
ASSERT_FLOAT_EQ(
a.max_values(c10::IntArrayRef{0, 1}).item<double>(),
a.amax(c10::IntArrayRef{0, 1}).item<double>(),
a.max().item<double>()
);
ASSERT_FLOAT_EQ(
a.min_values(c10::IntArrayRef{0, 1}).item<double>(),
a.amin(c10::IntArrayRef{0, 1}).item<double>(),
a.min().item<double>()
);
}
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -459,6 +459,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
if(NOT INTERN_BUILD_MOBILE)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/import_legacy.cpp
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -194,6 +194,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: addr
.. automethod:: addr_
.. automethod:: allclose
.. automethod:: amax
.. automethod:: amin
.. automethod:: angle
.. automethod:: apply_
.. automethod:: argmax
6 changes: 4 additions & 2 deletions docs/source/torch.rst
@@ -340,6 +340,10 @@ Reduction Ops

argmax
argmin
amax
amin
max
min
dist
logsumexp
mean
@@ -380,9 +384,7 @@ Comparison Ops
kthvalue
le
lt
max
maximum
min
minimum
ne
sort
4 changes: 4 additions & 0 deletions test/backward_compatibility/check_backward_compatibility.py
@@ -36,6 +36,10 @@
("aten::append*", datetime.date(2020, 4, 15)),
("aten::_min", datetime.date(2020, 9, 9)),
("aten::_max", datetime.date(2020, 9, 9)),
("aten::amax", datetime.date(2020, 10, 9)),
("aten::amin", datetime.date(2020, 10, 9)),
("aten::min_values", datetime.date(2020, 10, 9)),
("aten::max_values", datetime.date(2020, 10, 9)),
("aten::split_with_sizes", datetime.date(2020, 7, 29)),
("aten::eq", datetime.date(2020, 7, 30)),
("aten::log", datetime.date(2020, 7, 30)),
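Each allow-list entry pairs a schema-name pattern with an expiry date; matching schemas are exempt from the backward-compatibility check until that date, which gives the removed min_values/max_values and the renamed amax/amin a grace period. A hedged sketch of that gating (the real matching lives in check_backward_compatibility.py):

import datetime
import fnmatch

ALLOW_LIST = [("aten::amax", datetime.date(2020, 10, 9))]

def is_exempt(schema_name, today):
    return any(fnmatch.fnmatch(schema_name, pattern) and today < expiry
               for pattern, expiry in ALLOW_LIST)

print(is_exempt("aten::amax", datetime.date(2020, 9, 1)))   # True, still exempt
print(is_exempt("aten::amax", datetime.date(2020, 11, 1)))  # False, BC enforced again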
9 changes: 9 additions & 0 deletions test/onnx/test_pytorch_common.py
@@ -86,5 +86,14 @@ def wrapper(self):
return wrapper
return skip_dec

def skipIfONNXShapeInference(onnx_shape_inference):
    def skip_dec(func):
        def wrapper(self):
            if self.onnx_shape_inference is onnx_shape_inference:
                raise unittest.SkipTest("Skip test with onnx_shape_inference={}".format(onnx_shape_inference))
            return func(self)
        return wrapper
    return skip_dec
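The decorator skips a test whenever the test class's onnx_shape_inference flag equals the given value. A usage sketch (the class name here is hypothetical; the ONNX runtime test classes set onnx_shape_inference themselves):

import unittest

class TestExportSketch(unittest.TestCase):
    onnx_shape_inference = True    # normally set by the ONNX test harness

    @skipIfONNXShapeInference(True)
    def test_not_yet_supported_with_shape_inference(self):
        ...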

def flatten(x):
    return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
