Update on "fx quant: hook up ConvTranspose{n}d"
Summary:

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds support for FX graph mode.

Note: this currently only works with the `qnnpack` backend because per-channel
weights are not supported by quantized conv transpose. Until that is fixed, a
future PR should raise an error when someone tries to quantize a ConvTranspose
model with per-channel weight observers.
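
As an illustration only (not part of this PR), here is a minimal sketch of what an FX graph
mode quantization flow for a model containing `ConvTranspose2d` could look like with the
`qnnpack` backend; the toy module and the use of `prepare_fx`/`convert_fx` with a default
`qnnpack` qconfig are assumptions for illustration based on the current quantization APIs:

```
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Hypothetical toy model with a single ConvTranspose2d layer.
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(1, 1, kernel_size=3)

    def forward(self, x):
        return self.deconv(x)

model = M().eval()
torch.backends.quantized.engine = 'qnnpack'

# qnnpack uses per-tensor weight observers; per-channel weights are not
# supported by quantized conv transpose, hence this backend is required.
qconfig_dict = {"": get_default_qconfig('qnnpack')}

prepared = prepare_fx(model, qconfig_dict)  # insert observers
prepared(torch.randn(1, 1, 8, 8))           # calibrate with sample data
quantized = convert_fx(prepared)            # swap in quantized conv transpose
```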

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25674636](https://our.internmc.facebook.com/intern/diff/D25674636)

[ghstack-poisoned]
vkuzo committed Dec 28, 2020
2 parents 1b60e51 + 8ff59fa commit 99fb85c
Showing 159 changed files with 2,914 additions and 1,019 deletions.
2 changes: 1 addition & 1 deletion .circleci/cimodel/data/dimensions.py
@@ -8,8 +8,8 @@
]

ROCM_VERSIONS = [
"3.9",
"3.10",
"4.0",
]

ROCM_VERSION_LABELS = ["rocm" + v for v in ROCM_VERSIONS]
208 changes: 104 additions & 104 deletions .circleci/config.yml

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion BUILD.bazel
@@ -373,7 +373,6 @@ filegroup(
filegroup(
name = "thc_srcs_cu",
srcs = [
"aten/src/THC/THCBlas.cu.cc",
"aten/src/THC/THCReduceApplyUtils.cu.cc",
"aten/src/THC/THCSleep.cu.cc",
"aten/src/THC/THCSortUtils.cu.cc",
10 changes: 10 additions & 0 deletions CONTRIBUTING.md
@@ -903,6 +903,16 @@ You'll need to install an appropriately configured flake8; see
[Lint as you type](https://github.com/pytorch/pytorch/wiki/Lint-as-you-type)
for documentation on how to do this.

If you haven't set up the pre-commit hook and have already committed files and
CI reports `flake8` errors, you can run the check locally in your PR branch with:
```bash
flake8 $(git diff --name-only $(git merge-base --fork-point master))
```
fix the code so that no errors are reported when you re-run the check,
and then commit the fix.
## Building PyTorch with ASAN
[ASAN](https://github.com/google/sanitizers/wiki/AddressSanitizer) is very
1 change: 0 additions & 1 deletion android/gradle/android_tasks.gradle
@@ -1,4 +1,3 @@

import java.nio.file.Files
import java.nio.file.Paths
import java.io.FileOutputStream
Expand Down
1 change: 0 additions & 1 deletion android/pytorch_android/host/build.gradle
@@ -38,4 +38,3 @@ dependencies {
}

apply from: rootProject.file('gradle/release.gradle')

1 change: 0 additions & 1 deletion android/settings.gradle
@@ -4,4 +4,3 @@ project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torch

project(':pytorch_host').projectDir = file('pytorch_android/host')
project(':test_app').projectDir = file('test_app/app')

1 change: 0 additions & 1 deletion aten/src/ATen/LegacyTHFunctionsCUDA.h
@@ -75,7 +75,6 @@ Tensor & _thnn_log_sigmoid_backward_out(Tensor & grad_input, const Tensor & grad
Tensor _thnn_log_sigmoid_backward(const Tensor & grad_output, const Tensor & self, const Tensor & buffer);
Tensor & _thnn_rrelu_with_noise_forward_out(Tensor & output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training);
Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training);
Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
std::tuple<Tensor &,Tensor &,Tensor &> _thnn_conv2d_forward_out(Tensor & output, Tensor & columns, Tensor & ones, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding);
3 changes: 3 additions & 0 deletions aten/src/ATen/MemoryOverlap.cpp
@@ -48,6 +48,9 @@ MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) {
if (!a->is_contiguous() || !b->is_contiguous()) {
return MemOverlapStatus::TOO_HARD;
}
if (!a->has_storage() || !b->has_storage()) {
return MemOverlapStatus::NO;
}
if (a->storage().data() == b->storage().data()) {
const auto a_begin = static_cast<char*>(a->data());
const auto a_end = a_begin + a->numel() * a->itemsize();
5 changes: 5 additions & 0 deletions aten/src/ATen/core/Formatting.cpp
@@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
stream << ", axis: " << tensor_.q_per_channel_axis();
}
}

auto& fw_grad = tensor.fw_grad(/* level */ 0);
if (fw_grad.defined()) {
stream << ", tangent:" << std::endl << fw_grad;
}
stream << " ]";
}
return stream;
1 change: 1 addition & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -510,4 +510,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("_version", CppFunction::makeFallthrough());
m.impl("requires_grad_", CppFunction::makeFallthrough());
m.impl("retain_grad", CppFunction::makeFallthrough());
m.impl("_fw_primal", CppFunction::makeFallthrough());
}
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -436,6 +436,7 @@ _(aten, logdet) \
_(aten, logit) \
_(aten, logspace) \
_(aten, logsumexp) \
_(aten, xlogy) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lstsq) \
@@ -552,6 +553,7 @@ _(aten, permute) \
_(aten, pin_memory) \
_(aten, pinverse) \
_(aten, pixel_shuffle) \
_(aten, pixel_unshuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_base.h
@@ -251,7 +251,7 @@ struct Vec256 {
Vec256<T> angle() const {
// other_t_angle is for SFINAE and clarity. Make sure it is not changed.
static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
return Vec256(0);
return map(at::native::angle_impl<T>); // compiler is unable to resolve the overload without <T>
}
template <typename complex_t_angle = T,
typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
18 changes: 17 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -203,7 +203,23 @@ template <> class Vec256<BFloat16> {
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> angle() const {
return _mm256_set1_epi16(0);
__m256 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto angle_lambda = [](__m256 values) {
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
};
auto o1 = angle_lambda(lo);
auto o2 = angle_lambda(hi);
return cvtfp32_bf16(o1, o2);
}
Vec256<BFloat16> real() const {
return *this;
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_double.h
@@ -108,7 +108,16 @@ template <> class Vec256<double> {
return _mm256_andnot_pd(mask, values);
}
Vec256<double> angle() const {
return _mm256_set1_pd(0);
const auto zero_vec = _mm256_set1_pd(0.f);
const auto nan_vec = _mm256_set1_pd(NAN);
const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_pd(M_PI);

const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
angle = _mm256_blendv_pd(angle, nan_vec, nan_mask);
return angle;
}
Vec256<double> real() const {
return *this;
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_float.h
@@ -115,7 +115,16 @@ template <> class Vec256<float> {
return _mm256_andnot_ps(mask, values);
}
Vec256<float> angle() const {
return _mm256_set1_ps(0);
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(M_PI);

const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
}
Vec256<float> real() const {
return *this;
12 changes: 0 additions & 12 deletions aten/src/ATen/cpu/vec256/vec256_int.h
@@ -121,9 +121,6 @@ class Vec256<int64_t> : public Vec256i {
auto inverse = _mm256_xor_si256(values, is_larger);
return _mm256_sub_epi64(inverse, is_larger);
}
Vec256<int64_t> angle() const {
return _mm256_set1_epi64x(0);
}
Vec256<int64_t> real() const {
return *this;
}
@@ -250,9 +247,6 @@ class Vec256<int32_t> : public Vec256i {
Vec256<int32_t> abs() const {
return _mm256_abs_epi32(values);
}
Vec256<int32_t> angle() const {
return _mm256_set1_epi32(0);
}
Vec256<int32_t> real() const {
return *this;
}
@@ -467,9 +461,6 @@ class Vec256<int16_t> : public Vec256i {
Vec256<int16_t> abs() const {
return _mm256_abs_epi16(values);
}
Vec256<int16_t> angle() const {
return _mm256_set1_epi16(0);
}
Vec256<int16_t> real() const {
return *this;
}
@@ -719,9 +710,6 @@ class Vec256<int8_t> : public Vec256i {
Vec256<int8_t> abs() const {
return _mm256_abs_epi8(values);
}
Vec256<int8_t> angle() const {
return _mm256_set1_epi8(0);
}
Vec256<int8_t> real() const {
return *this;
}
78 changes: 0 additions & 78 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -2498,84 +2498,6 @@ Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise,
}
return output;
}
Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training) {
const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaDoubleRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Float: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Half: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaHalfRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
default:
AT_ERROR("_thnn_rrelu_with_noise_backward_out not supported on CUDAType for ", dispatch_scalar_type);
}
return grad_input;
}
Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training) {
const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);
auto grad_input_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto grad_input = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(grad_input_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
THNN_CudaDoubleRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Float: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
THNN_CudaRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
case ScalarType::Half: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto lower_ = lower.toDouble();
auto upper_ = upper.toDouble();
THNN_CudaHalfRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false);
break;
}
default:
AT_ERROR("_thnn_rrelu_with_noise_backward not supported on CUDAType for ", dispatch_scalar_type);
}
return grad_input;
}
Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator) {
const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);
27 changes: 27 additions & 0 deletions aten/src/ATen/native/AutogradComposite.cpp
@@ -0,0 +1,27 @@
#include <ATen/ATen.h>

namespace at {
namespace native {

/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
/// This function is backward differentiable.
at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
"already has a forward gradient at the same level ", level, " is not supported.");

auto dual_tensor = primal.view(primal.sizes());
dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
return dual_tensor;
}

/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
}

} // namespace native

} // namespace at
38 changes: 38 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -62,6 +62,7 @@ DEFINE_DISPATCH(igammac_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);
DEFINE_DISPATCH(xlogy_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -1101,5 +1102,42 @@ Tensor& ldexp_(Tensor& self, const Tensor& other) {
return at::ldexp_out(self, self, other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
xlogy_stub(iter.device_type(), iter);
return result;
}

Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
}

Tensor xlogy(const Tensor& x, const Tensor& y) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, x, y);
xlogy_stub(iter.device_type(), iter);
return iter.output();
}

Tensor xlogy(Scalar x, const Tensor& y) {
return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
}

Tensor xlogy(const Tensor& x, Scalar y) {
return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
}

Tensor& xlogy_(Tensor& x, const Tensor& y) {
return at::xlogy_out(x, x, y);
}

Tensor& xlogy_(Tensor& x, Scalar y) {
return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
}

} // namespace native
} // namespace at
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -74,5 +74,6 @@ DECLARE_DISPATCH(binary_fn, igammac_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);

}} // namespace at::native
