Update on "Update gather documentation to allow index.shape[k] <= input.shape[k] rather than ==."

Differential Revision: [D22680014](https://our.internmc.facebook.com/intern/diff/D22680014)
gchanan committed Dec 28, 2020
2 parents 919da0f + fc559bd commit 275d795
Showing 137 changed files with 760 additions and 719 deletions.
10 changes: 10 additions & 0 deletions CONTRIBUTING.md
@@ -903,6 +903,16 @@ You'll need to install an appropriately configured flake8; see
[Lint as you type](https://github.com/pytorch/pytorch/wiki/Lint-as-you-type)
for documentation on how to do this.

If you haven't set up the pre-commit hook and have already committed files and
CI reports `flake8` errors, you can run the check locally in your PR branch with:
```bash
flake8 $(git diff --name-only $(git merge-base --fork-point master))
```
Fix the code so that no errors are reported when you re-run the above check, and
then commit the fix.

## Building PyTorch with ASAN
[ASAN](https://github.com/google/sanitizers/wiki/AddressSanitizer) is very
1 change: 0 additions & 1 deletion aten/src/ATen/LegacyTHFunctionsCUDA.h
@@ -75,7 +75,6 @@ Tensor & _thnn_log_sigmoid_backward_out(Tensor & grad_input, const Tensor & grad
Tensor _thnn_log_sigmoid_backward(const Tensor & grad_output, const Tensor & self, const Tensor & buffer);
Tensor & _thnn_rrelu_with_noise_forward_out(Tensor & output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training);
Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training);
Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
std::tuple<Tensor &,Tensor &,Tensor &> _thnn_conv2d_forward_out(Tensor & output, Tensor & columns, Tensor & ones, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding);
78 changes: 0 additions & 78 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -2498,84 +2498,6 @@ Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise,
}
return output;
}
Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training) {
const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaDoubleRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Float: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Half: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaHalfRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
default:
AT_ERROR("_thnn_rrelu_with_noise_backward_out not supported on CUDAType for ", dispatch_scalar_type);
}
return grad_input;
}
Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training) {
const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);
auto grad_input_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto grad_input = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(grad_input_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
THNN_CudaDoubleRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Float: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
THNN_CudaRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Half: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
THNN_CudaHalfRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
default:
AT_ERROR("_thnn_rrelu_with_noise_backward not supported on CUDAType for ", dispatch_scalar_type);
}
return grad_input;
}
Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator) {
const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);
22 changes: 16 additions & 6 deletions aten/src/ATen/native/Math.h
@@ -277,15 +277,20 @@ static inline float trigamma(float x) {
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static inline double calc_digamma(double x) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
static double PSI_10 = 2.25175258906672110764;
if (x == 0) {
return INFINITY;
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(INFINITY, -x);
}

int x_is_integer = x == floor(x);
bool x_is_integer = x == trunc(x);
if (x < 0) {
if (x_is_integer) {
return INFINITY;
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
return NAN;
}
return calc_digamma(1 - x) - M_PI / tan(M_PI * x);
}
@@ -324,15 +329,20 @@ static inline double calc_digamma(double x) {
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static inline float calc_digamma(float x) {
// See [C++ Standard Reference: Gamma Function]
static float PSI_10 = 2.25175258906672110764f;
if (x == 0) {
return INFINITY;
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(INFINITY, -x);
}

int x_is_integer = x == floorf(x);
bool x_is_integer = x == truncf(x);
if (x < 0) {
if (x_is_integer) {
return INFINITY;
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
return NAN;
}
// Avoid rounding errors for `tan`'s input.
// Those make a big difference at extreme values.
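The `calc_digamma` hunks above (mirrored for CUDA in `aten/src/ATen/native/cuda/Math.cuh` below) change two edge cases: an argument of ±0 now returns ∓∞, and negative integers now return NaN, matching the cited C++ gamma-function conventions and SciPy. A minimal standalone sketch of just that branch logic, with the series and reflection computation elided (an illustration, not the PyTorch kernel):

```cpp
#include <cmath>

// Edge-case handling only; the general case continues through the
// reflection formula and asymptotic series shown in the diff above.
double digamma_edge_cases(double x) {
  if (x == 0) {
    // psi(+0) -> -INFINITY, psi(-0) -> +INFINITY: the sign follows -x.
    return std::copysign(INFINITY, -x);
  }
  if (x < 0 && x == std::trunc(x)) {
    // Negative integers are poles of digamma, so NaN is returned.
    return NAN;
  }
  return 0.0;  // placeholder for the general case (omitted here)
}
```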
9 changes: 6 additions & 3 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -276,6 +276,10 @@ Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_floa
Tensor erfc(const Tensor& self) { return unary_op_impl_float(self, erfc_stub); }
Tensor& erfc_(Tensor& self) { return unary_op_impl_(self, at::erfc_out); }

Tensor& erfinv_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erfinv_stub); }
Tensor erfinv(const Tensor& self) { return unary_op_impl_float(self, erfinv_stub); }
Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); }

Tensor& frac_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, frac_stub); }
Tensor frac(const Tensor& self) { return unary_op_impl(self, at::frac_out); }
Tensor& frac_(Tensor& self) { return unary_op_impl_(self, at::frac_out); }
@@ -314,8 +318,8 @@ Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }

Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); }
Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, digamma_stub); }
Tensor digamma(const Tensor& self) { return unary_op_impl_float(self, digamma_stub); }
Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }

Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); }
@@ -683,7 +687,6 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU) \
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA)

IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv)
IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)

DEFINE_DISPATCH(abs_stub);
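Routing `erfinv` and `digamma` through the `unary_op_impl_float*` helpers (together with the `common_dtype()` dispatch in the kernel files below) means integral inputs are promoted to a floating dtype rather than dispatched on the integral type. A rough usage sketch under that assumption; the tensor values and shapes here are made up for illustration:

```cpp
#include <ATen/ATen.h>

int main() {
  // Integral input: the float-op helpers compute and return the result
  // in the promoted (default floating) dtype instead of dispatching on int64.
  auto t = at::arange(1, 5, at::kLong);                     // int64 tensor
  auto d = at::digamma(t);                                  // floating-point result
  auto e = at::erfinv(at::scalar_tensor(0.5, at::kFloat));  // unchanged float path
  (void)d;
  (void)e;
  return 0;
}
```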
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -360,7 +360,7 @@ static void atanh_kernel(TensorIterator& iter) {
}

static void digamma_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "digamma", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return calc_digamma(a); });
11 changes: 8 additions & 3 deletions aten/src/ATen/native/cuda/Math.cuh
@@ -93,6 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
*/
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
@@ -108,14 +109,18 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {

accscalar_t x = static_cast<accscalar_t>(in);
if (x == 0) {
return static_cast<scalar_t>(INFINITY);
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(static_cast<scalar_t>(INFINITY), -x);
}

bool x_is_integer = x == ::floor(x);
bool x_is_integer = x == ::trunc(x);
accscalar_t result = 0;
if (x < 0) {
if (x_is_integer) {
return static_cast<scalar_t>(INFINITY);
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
return static_cast<scalar_t>(NAN);
}
// Rounding errors in tan's input can really affect the output
// for extreme values, so we always perform this computation in double.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -11,7 +11,7 @@
namespace at { namespace native {

void digamma_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "digamma_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_digamma(a);
});
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -160,7 +160,7 @@ void erfc_kernel_cuda(TensorIterator& iter) {
}

void erfinv_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "erfinv_cuda", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::erfinv(a);
});
6 changes: 2 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -6937,14 +6937,12 @@
use_c10_dispatcher: full
variants: method
dispatch:
CPU: _erfinv__cpu
CUDA: _erfinv__cuda
CPU, CUDA: erfinv_

- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
dispatch:
CPU: _erfinv_out_cpu
CUDA: _erfinv_out_cuda
CPU, CUDA: erfinv_out

- func: i0(Tensor self) -> Tensor
use_c10_dispatcher: full
2 changes: 1 addition & 1 deletion aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -746,7 +746,7 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
run_status == pytorch_qnnp_status_success,
"failed to run quantized::conv2d (qnnpack) operator");

return output;
return output.contiguous(act.suggest_memory_format());
}

template at::Tensor PackedConvWeightsQnnp<2>::apply(
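The `qconv.cpp` change above makes the QNNPACK conv output follow the activation's suggested memory format (e.g. channels-last for NHWC activations) rather than whatever layout the internal buffer happened to have. A small sketch of the same two ATen calls on stand-in tensors (shapes and values here are hypothetical):

```cpp
#include <ATen/ATen.h>

int main() {
  // Stand-in activation in channels-last layout and a contiguous "output".
  auto act = at::rand({1, 3, 8, 8}).contiguous(at::MemoryFormat::ChannelsLast);
  auto output = at::rand({1, 4, 8, 8});

  // Same pattern as the diff: re-lay-out the output to match the activation.
  output = output.contiguous(act.suggest_memory_format());

  // output now uses the channels-last layout suggested by act.
  return 0;
}
```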
18 changes: 10 additions & 8 deletions aten/src/ATen/native/quantized/cuda/affine_quantizer.cu
@@ -25,14 +25,16 @@ void quantize_tensor_per_tensor_affine_cuda(
.add_input(qtensor)
.build();

gpu_kernel(iter,
[=] GPU_LAMBDA (float raw_val, scalar_t quantized_val) -> scalar_t {
int64_t qvalue = static_cast<int64_t>(nearbyint(raw_val / scale + zero_point));
qvalue = std::max<int64_t>(qvalue, qmin);
qvalue = std::min<int64_t>(qvalue, qmax);
quantized_val.val_ = qvalue;
return quantized_val;
});
gpu_kernel(
iter,
[=] GPU_LAMBDA(float raw_val, scalar_t quantized_val) -> scalar_t {
int64_t qvalue =
static_cast<int64_t>(nearbyint(raw_val / scale) + zero_point);
qvalue = std::max<int64_t>(qvalue, qmin);
qvalue = std::min<int64_t>(qvalue, qmax);
quantized_val.val_ = qvalue;
return quantized_val;
});
});
}

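The quantization hunk above, and the fake-quantize kernels in the next file, move the integer `zero_point` outside the call to `nearbyint`: the scaled value is rounded first and the zero point is added afterwards. With the default round-half-to-even mode the two orderings can produce different integers; a minimal host-side illustration (ordinary C++, not the CUDA kernel):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float scaled = 2.5f;   // raw_val / scale lands exactly on a .5 boundary
  long zero_point = 1;

  // Old ordering: add zero_point before rounding (half-to-even).
  long old_q = static_cast<long>(std::nearbyint(scaled + zero_point));  // nearbyint(3.5f) == 4
  // New ordering: round the scaled value, then add the integer zero_point.
  long new_q = static_cast<long>(std::nearbyint(scaled)) + zero_point;  // nearbyint(2.5f) == 2 -> 3

  std::printf("old=%ld new=%ld\n", old_q, new_q);  // prints: old=4 new=3
  return 0;
}
```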
39 changes: 19 additions & 20 deletions aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -34,17 +34,16 @@ void fake_quantize_tensor_kernel_cuda(
.add_output(output)
.add_input(input)
.build();
gpu_kernel(iter,
[=] GPU_LAMBDA (float input_val) -> float {
return (fminf(
gpu_kernel(iter, [=] GPU_LAMBDA(float input_val) -> float {
return (fminf(
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(std::nearbyint(
input_val * inv_scale + zero_point)))) -
static_cast<int64_t>(
std::nearbyint(input_val * inv_scale) + zero_point))) -
zero_point) *
scale;
});
scale;
});
}

void fake_quantize_grad_tensor_kernel_cuda(
@@ -63,11 +62,10 @@ void fake_quantize_grad_tensor_kernel_cuda(
.add_input(output_grad)
.add_input(input)
.build();
gpu_kernel(iter,
[=] GPU_LAMBDA (float dy, float x) -> float {
int64_t Xq = std::nearbyint(x * inv_scale + zero_point);
return (Xq >= quant_min && Xq <= quant_max) * dy;
});
gpu_kernel(iter, [=] GPU_LAMBDA(float dy, float x) -> float {
int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
return (Xq >= quant_min && Xq <= quant_max) * dy;
});
}

void _fake_quantize_grad_learnable_tensor_kernel_cuda(
@@ -82,7 +80,7 @@ void _fake_quantize_grad_learnable_tensor_kernel_cuda(
gpu_kernel_multiple_outputs(
iter, [=] GPU_LAMBDA (float XInput, float dYInput) -> thrust::tuple<float, float, float> {
float dXOutput, dZeroPointOutput, dScaleOutput;
int64_t xq = std::nearbyint(zero_point + XInput * inv_scale);
int64_t xq = std::nearbyint(XInput * inv_scale) + zero_point;
dXOutput = dYInput * (xq >= quant_min && xq <= quant_max);
xq = std::max(std::min(xq, quant_max), quant_min);
float xfq = static_cast<float>((xq - zero_point) * scale);
@@ -108,12 +106,13 @@ void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_
[=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
return (fminf(
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(std::nearbyint(
input_val * inv_scale + zero_point)))) -
zero_point) *
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(
std::nearbyint(input_val * inv_scale) +
zero_point))) -
zero_point) *
scale;
});
}
@@ -122,7 +121,7 @@ void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, i
gpu_kernel(iter,
[=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
int64_t Xq = std::nearbyint(x * inv_scale + zero_point);
int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
return (Xq >= quant_min && Xq <= quant_max) * dy;
});
}
