Standardized clamp kernels to Numpy-like implementation #43288

Closed
wants to merge 3 commits
33 changes: 1 addition & 32 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -615,23 +615,12 @@ inline T minimum(const T& a, const T& b) {
return c;
}

// To save BC, it will not propagate NaN based on IEEE 754 201X
template <class T,
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec256<T> &max_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = a[i] < min_vec[i] ? min_vec[i] : (a[i] > max_vec[i] ? max_vec[i] : a[i]);
}
return c;
}

template <class T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec256<T> &max_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : (std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]);
c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
}
return c;
}
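
For reference, a minimal standalone sketch (not part of the diff) of the NumPy-like scalar rule the unified loop above now mirrors; clamp_ref is an illustrative name:

#include <algorithm>
#include <cassert>
#include <cmath>

// Scalar reference for the element-wise rule above: min(max(a, lo), hi).
template <class T>
T clamp_ref(T a, T lo, T hi) {
  return std::min(std::max(a, lo), hi);
}

int main() {
  assert(clamp_ref(5.0, 0.0, 1.0) == 1.0);   // above the range -> hi
  assert(clamp_ref(-2.0, 0.0, 1.0) == 0.0);  // below the range -> lo
  assert(clamp_ref(0.5, 0.0, 1.0) == 0.5);   // inside the range -> unchanged
  // With this operand order a NaN input falls through both comparisons
  // and is returned unchanged.
  assert(std::isnan(clamp_ref(std::nan(""), 0.0, 1.0)));
}
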
@@ -646,16 +635,6 @@ Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
return c;
}

template <class T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i];
}
return c;
}

template <class T,
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
@@ -666,16 +645,6 @@ Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
return c;
}

template <class T,
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : a[i];
}
return c;
}
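
The deleted complex overloads above ordered values by magnitude. A hedged sketch of that removed rule on plain std::complex, included only to illustrate why complex clamp was dropped (complex numbers have no natural ordering, so any such rule is a convention); old_complex_clamp is an illustrative name:

#include <complex>
#include <iostream>

// Scalar restatement of the removed rule: substitute the bound whose
// magnitude the value falls outside of, otherwise keep the value.
std::complex<float> old_complex_clamp(std::complex<float> a,
                                      std::complex<float> lo,
                                      std::complex<float> hi) {
  return std::abs(a) < std::abs(lo) ? lo : (std::abs(a) > std::abs(hi) ? hi : a);
}

int main() {
  // |1+1i| ~= 1.41 is below |2+0i| = 2, so the value is replaced wholesale
  // by the lower bound, which is surprising for an order-based operation.
  std::cout << old_complex_clamp({1.f, 1.f}, {2.f, 0.f}, {5.f, 0.f}) << "\n";  // prints (2,0)
}
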

struct Vec256i;

#ifdef CPU_CAPABILITY_AVX2
26 changes: 0 additions & 26 deletions aten/src/ATen/cpu/vec256/vec256_complex_double.h
@@ -416,32 +416,6 @@ Vec256<c10::complex<double>> inline minimum(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& b) {
return _mm256_or_pd(min, isnan);
}

template <>
Vec256<c10::complex<double>> inline clamp(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& min, const Vec256<c10::complex<double>>& max) {
auto abs_a = a.abs_2_();
auto abs_min = min.abs_2_();
auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ);
auto abs_max = max.abs_2_();
auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ);
return _mm256_blendv_pd(_mm256_blendv_pd(a, min, max_mask), max, min_mask);
}

template <>
Vec256<c10::complex<double>> inline clamp_min(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& min) {
auto abs_a = a.abs_2_();
auto abs_min = min.abs_2_();
auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ);
return _mm256_blendv_pd(a, min, max_mask);
}

template <>
Vec256<c10::complex<double>> inline clamp_max(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& max) {
auto abs_a = a.abs_2_();
auto abs_max = max.abs_2_();
auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ);
return _mm256_blendv_pd(a, max, min_mask);
}
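
For context, the removed specializations above all use the standard AVX compare-and-blend idiom as a branchless select. A minimal standalone sketch of that idiom (not part of the diff; needs an AVX-enabled build, e.g. -mavx; values are illustrative):

#include <immintrin.h>
#include <cstdio>

int main() {
  __m256d a    = _mm256_set_pd(3.0, 1.0, 4.0, 0.5);   // lanes, high to low
  __m256d lo   = _mm256_set1_pd(1.0);
  __m256d mask = _mm256_cmp_pd(a, lo, _CMP_LT_OQ);     // all-ones lane where a < lo
  __m256d out  = _mm256_blendv_pd(a, lo, mask);        // take lo in those lanes
  double r[4];
  _mm256_storeu_pd(r, out);
  std::printf("%g %g %g %g\n", r[0], r[1], r[2], r[3]); // 0.5 is raised to 1
}
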

template <>
Vec256<c10::complex<double>> inline operator&(const Vec256<c10::complex<double>>& a, const Vec256<c10::complex<double>>& b) {
return _mm256_and_pd(a, b);
26 changes: 0 additions & 26 deletions aten/src/ATen/cpu/vec256/vec256_complex_float.h
@@ -456,32 +456,6 @@ Vec256<c10::complex<float>> inline minimum(const Vec256<c10::complex<float>>& a, const Vec256<c10::complex<float>>& b) {
return _mm256_or_ps(min, isnan);
}

template <>
Vec256<c10::complex<float>> inline clamp(const Vec256<c10::complex<float>>& a, const Vec256<c10::complex<float>>& min, const Vec256<c10::complex<float>>& max) {
auto abs_a = a.abs_2_();
auto abs_min = min.abs_2_();
auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ);
auto abs_max = max.abs_2_();
auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ);
return _mm256_blendv_ps(_mm256_blendv_ps(a, min, max_mask), max, min_mask);
}

template <>
Vec256<c10::complex<float>> inline clamp_min(const Vec256<c10::complex<float>>& a, const Vec256<c10::complex<float>>& min) {
auto abs_a = a.abs_2_();
auto abs_min = min.abs_2_();
auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ);
return _mm256_blendv_ps(a, min, max_mask);
}

template <>
Vec256<c10::complex<float>> inline clamp_max(const Vec256<c10::complex<float>>& a, const Vec256<c10::complex<float>>& max) {
auto abs_a = a.abs_2_();
auto abs_max = max.abs_2_();
auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ);
return _mm256_blendv_ps(a, max, min_mask);
}

template <>
Vec256<c10::complex<float>> inline operator&(const Vec256<c10::complex<float>>& a, const Vec256<c10::complex<float>>& b) {
return _mm256_and_ps(a, b);
8 changes: 4 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -501,7 +501,7 @@ Tensor signbit(const Tensor& self) {
}

Tensor& clamp_out(Tensor& result, const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors.");
TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
if (min && max) {
TORCH_CHECK(self.layout() == Layout::Strided,
"clamp only supports strided layout, got: ", self.layout());
@@ -512,7 +512,7 @@ Tensor& clamp_out(Tensor& result, const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
} else if (min) {
at::clamp_min_out(result, self, *min);
} else {
AT_ERROR("At least one of 'min' or 'max' must not be None");
TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
}
return result;
}
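
For reference, TORCH_CHECK(cond, ...) throws a c10::Error when the condition is false, which is why TORCH_CHECK(false, ...) replaces the older AT_ERROR macro here. A hedged caller-side sketch of the complex-input check (not part of the diff; assumes the ATen C++ API and a libtorch build):

#include <ATen/ATen.h>

int main() {
  auto t = at::ones({4}, at::kComplexFloat);
  try {
    at::clamp(t, 0.0, 1.0);  // rejected up front by TORCH_CHECK(!self.is_complex(), ...)
  } catch (const c10::Error&) {
    // message: "clamp does not support complex inputs."
  }
}
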
@@ -527,7 +527,7 @@ Tensor& clamp_(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
}

Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) {
TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors.");
TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
TORCH_CHECK(self.layout() == Layout::Strided,
"clamp_max only supports strided layout, got: ", self.layout());
auto iter = TensorIterator::unary_op(result, self);
@@ -545,7 +545,7 @@ Tensor& clamp_max_(Tensor& self, Scalar max) {
}

Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) {
TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors.");
TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
TORCH_CHECK(self.layout() == Layout::Strided,
"clamp_min only supports strided layout, got: ", self.layout());
auto iter = TensorIterator::unary_op(result, self);
15 changes: 6 additions & 9 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -411,36 +411,33 @@ static void nan_to_num_kernel(
}

static void clamp_kernel(TensorIterator& iter, Scalar min_scalar, Scalar max_scalar) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() {
c10::scalar_value_type<scalar_t>::type (*zabs_)(scalar_t) = zabs;
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() {
auto min = min_scalar.to<scalar_t>();
auto max = max_scalar.to<scalar_t>();
auto min_vec = Vec256<scalar_t>(min);
auto max_vec = Vec256<scalar_t>(max);
cpu_kernel_vec(iter,
[=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : (zabs_(a) > zabs_(max) ? max : a); },
[=](scalar_t a) -> scalar_t { return std::min(std::max(a, min), max); },
[=](Vec256<scalar_t> a) { return vec256::clamp(a, min_vec, max_vec); });
});
}
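
cpu_kernel_vec takes a scalar lambda and a Vec256 lambda that must compute the same function: the vectorized lambda handles full SIMD-width blocks while the scalar lambda covers the leftover tail and any non-vectorized fallback, which is why the diff updates both in lockstep. A hedged, ATen-free sketch of that contract; all names here are illustrative only:

#include <cstddef>
#include <vector>

template <class T, class ScalarOp, class BlockOp>
void apply_like_cpu_kernel_vec(std::vector<T>& data, std::size_t width,
                               ScalarOp scalar_op, BlockOp block_op) {
  std::size_t i = 0;
  for (; i + width <= data.size(); i += width)
    block_op(&data[i]);                 // full "Vec256"-sized blocks
  for (; i < data.size(); ++i)
    data[i] = scalar_op(data[i]);       // leftover tail uses the scalar rule
}
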

static void clamp_max_kernel(TensorIterator& iter, Scalar max_scalar) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_max_cpu", [&]() {
c10::scalar_value_type<scalar_t>::type (*zabs_)(scalar_t) = zabs;
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_max_cpu", [&]() {
auto max = max_scalar.to<scalar_t>();
auto max_vec = Vec256<scalar_t>(max);
cpu_kernel_vec(iter,
[=](scalar_t a) -> scalar_t { return zabs_(a) > zabs_(max) ? max : a; },
[=](scalar_t a) -> scalar_t { return std::min(a, max); },
[=](Vec256<scalar_t> a) { return vec256::clamp_max(a, max_vec); });
});
}

static void clamp_min_kernel(TensorIterator& iter, Scalar min_scalar) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_min_cpu", [&]() {
c10::scalar_value_type<scalar_t>::type (*zabs_)(scalar_t) = zabs;
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_min_cpu", [&]() {
auto min = min_scalar.to<scalar_t>();
auto min_vec = Vec256<scalar_t>(min);
cpu_kernel_vec(iter,
[=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : a; },
[=](scalar_t a) -> scalar_t { return std::max(a, min); },
[=](Vec256<scalar_t> a) { return vec256::clamp_min(a, min_vec); });
});
}
22 changes: 19 additions & 3 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -12,6 +12,7 @@
#include <ATen/native/cuda/Math.cuh>
#include <ATen/NumericUtils.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>

namespace at {
@@ -158,7 +159,12 @@ void clamp_kernel_cuda(TensorIterator& iter, Scalar min_value, Scalar max_value)
auto lower = min_value.to<scalar_t>();
auto upper = max_value.to<scalar_t>();
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
return (v < lower) ? lower : (v > upper ? upper : v);
// Propagate nan, which doesn't propagate automatically for ROCm
if (_isnan(v)) {
return v;
} else {
return ::min(::max(v, lower), upper);
}
});
});
}
@@ -167,7 +173,12 @@ void clamp_min_kernel_cuda(TensorIterator& iter, Scalar min_value) {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_min_cuda", [&]() {
auto lower = min_value.to<scalar_t>();
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
return v < lower ? lower : v;
// Propagate nan, which doesn't propagate automatically for ROCm
if (_isnan(v)) {
return v;
} else {
return ::max(v, lower);
}
});
});
}
@@ -176,7 +187,12 @@ void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_max_cuda", [&]() {
auto upper = max_value.to<scalar_t>();
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
return v > upper ? upper : v;
// Propagate nan, which doesn't propagate automatically for ROCm
if (_isnan(v)) {
return v;
} else {
return ::min(v, upper);
}
});
});
}
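
A host-side reference (not part of the diff) for the NaN behaviour the three CUDA lambdas above enforce: a NaN input is returned unchanged instead of being compared against the bounds, since the diff's own comments note that device min/max does not propagate NaN automatically on ROCm. clamp_nan_ref is an illustrative name:

#include <algorithm>
#include <cassert>
#include <cmath>

template <class T>
T clamp_nan_ref(T v, T lo, T hi) {
  if (std::isnan(v)) return v;            // explicit NaN pass-through
  return std::min(std::max(v, lo), hi);
}

int main() {
  assert(std::isnan(clamp_nan_ref(std::nan(""), 0.0, 1.0)));
  assert(clamp_nan_ref(2.0, 0.0, 1.0) == 1.0);
  assert(clamp_nan_ref(-1.0, 0.0, 1.0) == 0.0);
}
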