Update on "Fix max_pool2d with ceil_mode bug"
This PR fixes a bug in how the pooling output shape was computed.

## BC Breaking Notes
Previously, a bug in the pooling code allowed a sliding window to be entirely out of bounds. Now, a sliding window must start inside the input or the left padding (not the right padding; see #46929) and may go out of bounds only if ceil_mode=True.
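As a rough illustration of the resulting output-size rule, here is a Python sketch of the standard pooling shape formula plus the ceil_mode correction described above (the helper name is made up; this is not the literal C++ implementation):

```python
import math

def pool_output_size(in_size, kernel, stride, pad, dilation, ceil_mode):
    # Standard pooling output-size formula.
    out = (in_size + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
    out = math.ceil(out) if ceil_mode else math.floor(out)
    # With ceil_mode, the extra window is kept only if it starts inside the
    # input or the left padding; a window that would start in the right
    # padding (or beyond) is dropped.
    if ceil_mode and (out - 1) * stride >= in_size + pad:
        out -= 1
    return out

# e.g. in=5, kernel=2, stride=2, pad=1, ceil_mode=True -> 3
# (a 4th window would start in the right padding, so it is not counted)
print(pool_output_size(5, 2, 2, 1, 1, True))
```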

fixes #45357

TODO

- [x] Ensure existing tests are checking for the correct output size

[ghstack-poisoned]
heitorschueroff committed Oct 29, 2020
2 parents 4009646 + 4a581ba commit a4f0899
Showing 121 changed files with 2,629 additions and 943 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -969,7 +969,7 @@ jobs:

- run:
name: Build
no_output_timeout: "1h"
no_output_timeout: "90m"
command: |
# Do not set -u here; there is some problem with CircleCI
# variable expansion with PROMPT_COMMAND
38 changes: 34 additions & 4 deletions .circleci/docker/common/install_cache.sh
@@ -2,6 +2,28 @@

set -ex

install_ubuntu() {
echo "Preparing to build sccache from source"
apt-get update
apt-get install -y cargo pkg-config libssl-dev
echo "Checking out sccache repo"
git clone https://github.com/pytorch/sccache
cd sccache
echo "Building sccache"
cargo build --release
cp target/release/sccache /opt/cache/bin
echo "Cleaning up"
cd ..
rm -rf sccache
apt-get remove -y cargo rustc
apt-get autoclean && apt-get clean
}

install_binary() {
echo "Downloading sccache binary from S3 repo"
curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /opt/cache/bin/sccache
}

mkdir -p /opt/cache/bin
mkdir -p /opt/cache/lib
sed -e 's|PATH="\(.*\)"|PATH="/opt/cache/bin:\1"|g' -i /etc/environment
@@ -11,12 +33,20 @@ export PATH="/opt/cache/bin:$PATH"
if [ -n "$ROCM_VERSION" ]; then
curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache
else
curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /opt/cache/bin/sccache
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
case "$ID" in
ubuntu)
install_ubuntu
;;
*)
install_binary
;;
esac
fi
chmod a+x /opt/cache/bin/sccache

function write_sccache_stub() {
printf "#!/bin/sh\nexec sccache $(which $1) \"\$@\"" > "/opt/cache/bin/$1"
printf "#!/bin/sh\nif [ \$(ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/opt/cache/bin/$1"
chmod a+x "/opt/cache/bin/$1"
}

Expand All @@ -38,8 +68,8 @@ if [ -n "$CUDA_VERSION" ]; then
# where CUDA is installed. Instead, we install an nvcc symlink outside
# of the PATH, and set CUDA_NVCC_EXECUTABLE so that we make use of it.

printf "#!/bin/sh\nexec sccache $(which nvcc) \"\$@\"" > /opt/cache/lib/nvcc
chmod a+x /opt/cache/lib/nvcc
write_sccache_stub nvcc
mv /opt/cache/bin/nvcc /opt/cache/lib/
fi

if [ -n "$ROCM_VERSION" ]; then
1 change: 1 addition & 0 deletions .circleci/docker/common/install_gcc.sh
@@ -15,6 +15,7 @@ if [ -n "$GCC_VERSION" ]; then

update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50

# Cleanup package manager
apt-get autoclean && apt-get clean
2 changes: 1 addition & 1 deletion .circleci/verbatim-sources/job-specs/binary-job-specs.yml
@@ -174,7 +174,7 @@

- run:
name: Build
no_output_timeout: "1h"
no_output_timeout: "90m"
command: |
# Do not set -u here; there is some problem with CircleCI
# variable expansion with PROMPT_COMMAND
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
@@ -42,6 +42,10 @@ jobs:
run: |
sudo apt-get install -y doxygen && pip install -r requirements.txt
cd docs/cpp/source && ./check-doxygen.sh
- name: CUDA kernel launch check
run: |
set -eux
python torch/testing/check_kernel_launches.py |& tee ${GITHUB_WORKSPACE}/cuda_kernel_launch_checks.txt
flake8-py3:
runs-on: ubuntu-latest
8 changes: 4 additions & 4 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -388,9 +388,9 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(A
}
if (guard.needs_inputs) {
torch::jit::Stack stack = impl::BoxedKernelWrapper<Return(Args...)>::boxArgs(args...);
guard.before(op.schema().name(), stack, seq_num);
guard.before(op, stack, seq_num);
} else {
guard.before(op.schema().name(), seq_num);
guard.before(op, seq_num);
}
}
}
@@ -438,9 +438,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
seq_num = at::sequence_number::peek();
}
if (guard.needs_inputs) {
guard.before(op.schema().name(), *stack, seq_num);
guard.before(op, *stack, seq_num);
} else {
guard.before(op.schema().name(), seq_num);
guard.before(op, seq_num);
}
}
}
4 changes: 2 additions & 2 deletions aten/src/ATen/core/interned_strings.h
@@ -56,7 +56,7 @@ namespace c10 {
_(prim, ReturnStmt) \
_(prim, BreakStmt) \
_(prim, ContinueStmt) \
_(prim, LocalVariableScope) \
_(prim, ListComprehensionScope) \
_(prim, Store) \
_(prim, AutogradZero) \
_(prim, AutogradAnyNonZero) \
@@ -129,7 +129,7 @@ namespace c10 {
_(prim, fork) \
_(prim, forkClosure) \
_(prim, RaiseException) \
_(prim, Function) \
_(prim, Closure) \
_(prim, CreateObject) \
_(prim, SetAttr) \
_(prim, GetAttr) \
4 changes: 3 additions & 1 deletion aten/src/ATen/core/op_registration/op_whitelist.h
@@ -36,7 +36,9 @@ namespace impl {
// returns true iff whitelist contains item
// op_whitelist_contains("a;bc;d", "bc") == true
constexpr bool op_whitelist_contains(string_view whitelist, string_view item) {
size_t next = -1;
//Choose a really big value for next so that if something goes wrong
//this code will blow up in a hopefully detectable way.
size_t next = std::numeric_limits<size_t>::max();
for (size_t cur = 0; cur <= whitelist.size(); cur = next) {
next = whitelist.find(';', cur);
if (next != string_view::npos) {
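The scan above walks the ';'-separated whitelist segment by segment; a plain-Python restatement of what it checks (a sketch of the semantics only, not of the constexpr implementation):

```python
def op_whitelist_contains(whitelist: str, item: str) -> bool:
    # True iff `item` appears as one complete ';'-separated entry.
    return item in whitelist.split(';')

assert op_whitelist_contains("a;bc;d", "bc")
assert not op_whitelist_contains("a;bc;d", "b")
```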
5 changes: 5 additions & 0 deletions aten/src/ATen/cuda/Exceptions.h
@@ -79,6 +79,11 @@ const char *cusparseGetErrorString(cusparseStatus_t status);

#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)

// This should be used directly after every kernel launch to ensure
// the launch happened correctly and provide an early, close-to-source
// diagnostic if it didn't.
#define TORCH_CUDA_KERNEL_LAUNCH_CHECK() AT_CUDA_CHECK(cudaGetLastError())

// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
5 changes: 5 additions & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -5,6 +5,7 @@
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
@@ -131,7 +132,11 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
}

if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
#ifdef USE_VULKAN_API
return vulkan::ops::copy_(self, src);
#else
return at::vulkan::vulkan_copy_(self, src);
#endif
}

if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
8 changes: 5 additions & 3 deletions aten/src/ATen/native/MaxPooling.cpp
@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/MaxPooling.h>
#include <ATen/native/Pool.h>
@@ -98,10 +99,11 @@ Tensor max_pool1d(
IntArrayRef dilation,
bool ceil_mode) {
if (self.is_quantized()) {
return at::quantized_max_pool1d(self, kernel_size, stride, padding,
dilation, ceil_mode);
return at::quantized_max_pool1d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
if (self.requires_grad() || !self.device().is_cpu()) {
if ((self.requires_grad() && at::GradMode::is_enabled()) ||
!self.device().is_cpu()) {
// Needs indices for grad and with_indices defines CUDA dispatch
return std::get<0>(at::max_pool1d_with_indices(
self, kernel_size, stride, padding, dilation, ceil_mode));
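The dispatch condition above now also consults grad mode, so merely having requires_grad set no longer forces the indices-computing overload. A small usage sketch (standard PyTorch API; the results are identical, the kernel choice is the point):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 10, requires_grad=True)

# Grad mode on and requires_grad set: indices are needed for backward,
# so this still routes through max_pool1d_with_indices.
y_train = F.max_pool1d(x, kernel_size=2)

# Under no_grad(), requires_grad alone no longer triggers the indices path,
# so the cheaper indices-free CPU kernel can be used.
with torch.no_grad():
    y_eval = F.max_pool1d(x, kernel_size=2)
```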
11 changes: 5 additions & 6 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -381,7 +381,8 @@ Tensor new_empty(
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor eye(int64_t n, const TensorOptions& options) {
return native::eye(n, -1, options);
// the default value of `m` equals to `n`
return native::eye(n, n, options);
}

Tensor eye(int64_t n, int64_t m, const TensorOptions& options) {
@@ -390,15 +391,13 @@ Tensor eye(int64_t n, int64_t m, const TensorOptions& options) {
}

Tensor& eye_out_cpu(Tensor& result, int64_t n) {
return native::eye_out_cpu(result, n, -1);
// the default value of `m` equals to `n`
return native::eye_out_cpu(result, n, n);
}

Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) {
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);

if(m < 0) {
m = n;
}
TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

result.resize_({n, m});
result.zero_();
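The user-visible default is unchanged (torch.eye(n) is torch.eye(n, n)); the -1 sentinel is simply gone, so a negative m is now rejected outright. A quick illustration, assuming the new TORCH_CHECK surfaces as a RuntimeError in Python:

```python
import torch

# Omitting m still defaults it to n.
assert torch.equal(torch.eye(3), torch.eye(3, 3))

# A negative m is no longer silently treated as "same as n".
try:
    torch.eye(3, -1)
except RuntimeError as err:
    print(err)  # m must be greater or equal to 0, got -1
```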
20 changes: 10 additions & 10 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -284,20 +284,20 @@ Tensor& i0_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(re
Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); }
Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); }

Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log_stub); }
Tensor log(const Tensor& self) { return unary_op_impl(self, at::log_out); }
Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log_stub); }
Tensor log(const Tensor& self) { return unary_op_impl_float(self, log_stub); }
Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); }

Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log10_stub); }
Tensor log10(const Tensor& self) { return unary_op_impl(self, at::log10_out); }
Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log10_stub); }
Tensor log10(const Tensor& self) { return unary_op_impl_float(self, log10_stub); }
Tensor& log10_(Tensor& self) { return unary_op_impl_(self, at::log10_out); }

Tensor& log1p_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log1p_stub); }
Tensor log1p(const Tensor& self) { return unary_op_impl(self, at::log1p_out); }
Tensor& log1p_(Tensor& self) { return unary_op_impl_(self, at::log1p_out); }

Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log2_stub); }
Tensor log2(const Tensor& self) { return unary_op_impl(self, at::log2_out); }
Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log2_stub); }
Tensor log2(const Tensor& self) { return unary_op_impl_float(self, log2_stub); }
Tensor& log2_(Tensor& self) { return unary_op_impl_(self, at::log2_out); }

Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); }
@@ -339,8 +339,8 @@ Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float
Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); }
Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); }

Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); }
Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); }
Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); }
Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); }
Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); }

Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); }
@@ -452,8 +452,8 @@ Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }

Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); }
Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); }
Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); }
Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); }
Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); }

Tensor& trunc_out(Tensor& result, const Tensor& self) {
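These ops now go through the same unary_op_impl_float path that sin already used, which (assuming that path promotes integral inputs the way it does for sin) means integer tensors produce floating-point results. A short illustration:

```python
import torch

x = torch.arange(1, 5)        # dtype=torch.int64

print(torch.sin(x).dtype)     # torch.float32 -- sin already used the _float path
print(torch.log(x).dtype)     # torch.float32 -- log/log10/log2/cos/tan now match
print(torch.log(x))           # tensor([0.0000, 0.6931, 1.0986, 1.3863])
```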
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -102,7 +102,7 @@ void nextafter_kernel_cuda(TensorIterator& iter) {

void heaviside_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a == 0 ? b : static_cast<scalar_t>(a > 0);
});
});
8 changes: 3 additions & 5 deletions aten/src/ATen/native/cuda/TensorFactories.cu
@@ -22,15 +22,13 @@ namespace at {
namespace native {

Tensor& eye_out_cuda(Tensor& result, int64_t n) {
return at::native::eye_out_cuda(result, n, /*m=*/-1);
// the default value of `m` equals to `n`
return at::native::eye_out_cuda(result, n, n);
}

Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);

if(m < 0) {
m = n;
}
TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

result.resize_({n, m});
result.zero_();
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) {
}

void cos_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::cos(a);
});
@@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) {
}

void tan_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::tan(a);
});
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/UnaryLogKernels.cu
@@ -11,15 +11,15 @@
namespace at { namespace native {

void log_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::log(a);
});
});
}

void log10_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log10_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log10_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::log10(a);
});
@@ -35,7 +35,7 @@ void log1p_kernel_cuda(TensorIterator& iter) {
}

void log2_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log2_cuda", [&]() {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log2_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::log2(a);
});
11 changes: 11 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -243,6 +243,15 @@ Tensor qadd_scalar(Tensor qa, Scalar b) {
return _add_scalar_out<ReLUFused>(qc, qa, b);
}

template <bool ReLUFused = false>
Tensor qadd_scalar2(Scalar b, Tensor qa) {
TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
qa.qscheme() == kPerTensorSymmetric,
"Only per tensor quantization is supported in Add.");
auto qc = at::empty_like(qa, qa.suggest_memory_format());
return _add_scalar_out<ReLUFused>(qc, qa, b);
}

template <bool ReLUFused = false>
Tensor qadd_scalar_out(Tensor qa, Scalar b, Tensor out) {
check_inputs(qa, out);
@@ -269,10 +278,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar"), TORCH_FN(qadd_scalar</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar2"), TORCH_FN(qadd_scalar2</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar_out"), TORCH_FN(qadd_scalar_out</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu"), TORCH_FN(qadd</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.out"), TORCH_FN(qadd_out</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar"), TORCH_FN(qadd_scalar</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar2"), TORCH_FN(qadd_scalar2</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar_out"), TORCH_FN(qadd_scalar_out</*ReLUFused=*/true>));
// deprecated functions, kept for backward compatibility
m.impl(TORCH_SELECTIVE_NAME("quantized::add_out"), TORCH_FN(qadd_out</*ReLUFused=*/false>));
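The new Scalar2 registrations mirror add.Scalar with the arguments flipped (scalar first, quantized tensor second). A hedged sketch of how that might be exercised, assuming the matching quantized::add.Scalar2 schema is defined elsewhere in this commit and is reachable via torch.ops:

```python
import torch

qx = torch.quantize_per_tensor(
    torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)

# Existing overload: quantized tensor first, scalar second (add.Scalar).
y1 = torch.ops.quantized.add(qx, 2.0)

# New in this commit: scalar first, quantized tensor second (add.Scalar2).
y2 = torch.ops.quantized.add(2.0, qx)
```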