rebase on "eager quant: remove fake_quant after add/mul nodes during …
Browse files Browse the repository at this point in the history
…QAT"


Summary:

Changes the behavior of Eager mode quantization to remove observation after `add_scalar`/`mul_scalar` nodes.
The observed output is not used there, and removing it eliminates one difference between the Eager and FX modes.

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_quantized_add_qat
python test/test_quantization.py TestQuantizeFxOps.test_quantized_mul_qat
python test/test_quantization.py TestQuantizationAwareTraining.test_add_scalar_uses_input_qparams
python test/test_quantization.py TestQuantizationAwareTraining.test_mul_scalar_uses_input_qparams
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25486276](https://our.internmc.facebook.com/intern/diff/D25486276)

[ghstack-poisoned]
vkuzo committed Dec 16, 2020
2 parents 2c58920 + 76d09ec commit a620fe7
Showing 37 changed files with 527 additions and 276 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/VmapModeRegistrations.cpp
@@ -79,15 +79,15 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
 
   m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
   m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>);
-  m.impl_UNBOXED("rand.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, const TensorOptions&>);
-  m.impl_UNBOXED("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, const TensorOptions&>);
+  m.impl("rand.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>);
+  m.impl("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>);
   m.impl("rand.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
   m.impl("rand.generator_out", unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>);
 
   m.impl("randn", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
   m.impl("randn.generator", unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>);
-  m.impl_UNBOXED("randn.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, const TensorOptions&>);
-  m.impl_UNBOXED("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, const TensorOptions&>);
+  m.impl("randn.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>);
+  m.impl("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>);
   m.impl("randn.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
   m.impl("randn.generator_out", unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>);
 
21 changes: 7 additions & 14 deletions aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
@@ -265,20 +265,13 @@ namespace impl {
       return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(std::move(v));
     }
   };
-  template<bool AllowDeprecatedTypes>
-  struct ivalue_to_arg<optional<ArrayRef<int64_t>>, AllowDeprecatedTypes> final {
-    // If an argument is optional<ArrayRef<int64_t>>, convert the IValue to a optional<std::vector<int64_t>> and pass that
-    // to the operator.
-    static OptionalArray<int64_t> call(IValue&& v) {
-      return std::move(v).toOptionalIntArray();
-    }
-  };
-  template<bool AllowDeprecatedTypes>
-  struct ivalue_to_arg<optional<ArrayRef<double>>, AllowDeprecatedTypes> final {
-    // If an argument is optional<ArrayRef<T>>, convert the IValue to a optional<std::vector<T>> and pass that
-    // to the operator.
-    static OptionalArray<double> call(IValue&& v) {
-      return std::move(v).toOptionalDoubleArray();
+  template<class T, bool AllowDeprecatedTypes>
+  struct ivalue_to_arg<optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
+    // If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
+    // to the operator. OptionalArray<T> is basically an optional<std::vector<T>> but implicitly convertible
+    // to optional<ArrayRef<T>>.
+    static OptionalArray<T> call(IValue&& v) {
+      return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(std::move(v));
     }
   };
 
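
For intuition, here is a minimal, self-contained sketch of the pattern this hunk leans on. `ArrayRef` and `OptionalArray` below are simplified stand-ins rather than the real c10 definitions: the point is that a wrapper owning an `optional<std::vector<T>>` that converts implicitly to `optional<ArrayRef<T>>` lets one template replace the per-type `int64_t`/`double` specializations.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for c10::ArrayRef: a non-owning view over contiguous data.
template <class T>
class ArrayRef {
 public:
  /* implicit */ ArrayRef(const std::vector<T>& v)
      : data_(v.data()), size_(v.size()) {}
  std::size_t size() const { return size_; }

 private:
  const T* data_;
  std::size_t size_;
};

// Stand-in for c10::OptionalArray: owns the vector, but converts implicitly
// to the view type that the operator expects.
template <class T>
struct OptionalArray {
  std::optional<std::vector<T>> list;
  OptionalArray() = default;
  OptionalArray(std::vector<T> v) : list(std::move(v)) {}
  operator std::optional<ArrayRef<T>>() {
    if (!list) {
      return std::nullopt;
    }
    return ArrayRef<T>(*list);
  }
};

// One "operator" signature serves every element type T.
template <class T>
std::size_t arg_size(std::optional<ArrayRef<T>> arg) {
  return arg ? arg->size() : 0;
}

int main() {
  OptionalArray<std::int64_t> ints(std::vector<std::int64_t>{1, 2, 3});
  OptionalArray<double> none;  // stays nullopt
  // Both calls go through the same template; the implicit conversion
  // supplies the optional<ArrayRef<T>> view.
  return static_cast<int>(
      arg_size<std::int64_t>(ints) + arg_size<double>(none)) - 3;  // 0
}
```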
8 changes: 0 additions & 8 deletions aten/src/ATen/core/ivalue.h
@@ -705,14 +705,6 @@ struct CAFFE2_API IValue final {
   template <typename T>
   optional<T> toOptional();
 
-  /// @private [doxygen private]
-  /// Only for use in generated code.
-  OptionalArray<int64_t> toOptionalIntArray();
-
-  /// @private [doxygen private]
-  /// Only for use in generated code.
-  OptionalArray<double> toOptionalDoubleArray();
-
   /// @private [doxygen private]
   /// this is a shallow comparison of two IValues to test the object identity
   bool isSameIdentity(const IValue& rhs) const;
54 changes: 30 additions & 24 deletions aten/src/ATen/core/ivalue_inl.h
@@ -861,6 +861,36 @@ c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
   return impl::toTypedList<Elem>(std::move(ivalue).toList());
 }
 
+template <typename T>
+static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
+  std::vector<T> result;
+  result.reserve(impl->list.size());
+  for (size_t i = 0, N = impl->list.size(); i < N; ++i) {
+    result.push_back(impl->list[i].to<T>());
+  }
+  return result;
+}
+
+template <typename T>
+static std::vector<T> createVectorFromList(const c10::List<T>& impl) {
+  std::vector<T> result;
+  result.reserve(impl.size());
+  for (size_t i = 0, N = impl.size(); i < N; ++i) {
+    result.push_back(impl[i]);
+  }
+  return result;
+}
+
+template <typename T>
+OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
+  if (ivalue.isNone()) {
+    return {};
+  }
+  return createVectorFromList<T>(
+    std::move(ivalue).to<c10::List<T>>()
+  );
+}
+
 namespace detail {
 template <typename Elem, size_t... I>
 std::array<Elem, sizeof...(I)> generic_to_array(
@@ -952,16 +982,6 @@ inline T IValue::to() const& {
   return generic_to(*this, _fake_type<T>{});
 }
 
-template <typename T>
-static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
-  std::vector<T> result;
-  result.reserve(impl->list.size());
-  for (size_t i = 0, N = impl->list.size(); i < N; ++i) {
-    result.push_back(impl->list[i].to<T>());
-  }
-  return result;
-}
-
 inline c10::List<int64_t> IValue::toIntList() && {
   AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
   return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
@@ -1211,20 +1231,6 @@ inline optional<T> IValue::toOptional() {
   return this->to<T>();
 }
 
-inline OptionalArray<int64_t> IValue::toOptionalIntArray() {
-  if (this->isNone()) {
-    return {};
-  }
-  return this->toIntVector();
-}
-
-inline OptionalArray<double> IValue::toOptionalDoubleArray() {
-  if (this->isNone()) {
-    return {};
-  }
-  return this->toDoubleVector();
-}
-
 inline bool IValue::isCustomClass() const {
   return torch::isCustomClass(*this);
 }
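
A hedged usage sketch (assumes a libtorch build, and that `c10::OptionalArray<T>` exposes its `list` member as declared in `ivalue.h`): with the `generic_to` overload added above, both the `None` case and the populated case now flow through `IValue::to<OptionalArray<T>>()`, replacing the removed `toOptionalIntArray()`/`toOptionalDoubleArray()` helpers.

```cpp
#include <cstdint>
#include <ATen/core/ivalue.h>

int main() {
  // A default-constructed IValue is None -> empty OptionalArray.
  c10::IValue none;
  auto a = none.to<c10::OptionalArray<std::int64_t>>();

  // A list payload -> OptionalArray holding the copied vector.
  c10::IValue ints(c10::List<std::int64_t>({1, 2, 3}));
  auto b = ints.to<c10::OptionalArray<std::int64_t>>();

  return (a.list.has_value() ? 1 : 0) + (b.list.has_value() ? 0 : 1);  // 0
}
```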
21 changes: 21 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -290,6 +290,27 @@ Tensor index(const Tensor & self, TensorList indices) {
   return iter.output();
 }
 
+Tensor quantized_index(const Tensor & self, TensorList indices) {
+  TORCH_INTERNAL_ASSERT(
+      self.qscheme() == c10::kPerTensorAffine ||
+      self.qscheme() == c10::kPerTensorSymmetric,
+      "Indexing is only supported for per-Tensor quantized Tensors.");
+
+  // For now, this is a naive implementation which does dq -> index -> q.
+  // TODO(future PR): improve performance by removing the copies.
+  const auto& self_dq = self.dequantize();
+
+  TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
+
+  auto info = make_info(self_dq, indices);
+  auto iter = make_index_iterator(info);
+  index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides);
+  at::Tensor res = iter.output();
+
+  return at::quantize_per_tensor(
+      res, self.q_scale(), self.q_zero_point(), self.scalar_type());
+}
+
 Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   at::assert_no_internal_overlap(result);
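
To illustrate what the new kernel enables (a sketch assuming a libtorch build, not part of the diff): together with the `QuantizedCPU` dispatch entry added to `native_functions.yaml` below, advanced indexing on a per-tensor quantized tensor takes the dq -> index -> q path and keeps the input's quantization parameters.

```cpp
#include <cstdint>
#include <vector>

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::rand({4, 3});
  at::Tensor xq = at::quantize_per_tensor(
      x, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);

  std::vector<std::int64_t> rows = {0, 2};
  at::Tensor idx = at::tensor(rows, at::kLong);

  // Equivalent to xq[idx] in Python; the result stays quantized instead of
  // this op being unsupported for quantized tensors.
  at::Tensor out = at::index(xq, {idx});

  // quantized_index re-quantizes with the input's qparams.
  return (out.q_scale() == xq.q_scale() &&
          out.q_zero_point() == xq.q_zero_point()) ? 0 : 1;
}
```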
27 changes: 22 additions & 5 deletions aten/src/ATen/native/native_functions.yaml
@@ -106,9 +106,11 @@
   variants: method
 
 - func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
+  use_c10_dispatcher: full
   variants: method
 
 - func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
+  use_c10_dispatcher: full
   variants: method
 
 - func: align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
@@ -1738,6 +1740,7 @@
   use_c10_dispatcher: full
 
 - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
@@ -1942,6 +1945,7 @@
   variants: function, method
 
 - func: unflatten.int(Tensor(a) self, int dim, int[] sizes, Dimname[]? names=None) -> Tensor(a)
+  use_c10_dispatcher: full
   variants: method
 
 - func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a)
@@ -2023,6 +2027,7 @@
     CPU, CUDA: frac_out
 
 - func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -2200,6 +2205,7 @@
   variants: function, method
   dispatch:
     CPU, CUDA: index
+    QuantizedCPU: quantized_index
 # NB: This function is special-cased in tools/autograd/gen_variable_type.py
 # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
 # - Tensor Tensor::index(ArrayRef<TensorIndex> indices)
@@ -3164,6 +3170,7 @@
   variants: function
 
 - func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -3310,9 +3317,11 @@
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
 
 - func: rand.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: rand.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -3367,9 +3376,11 @@
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
 
 - func: randn.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: randn.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: randn.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
@@ -4442,6 +4453,7 @@
   variants: function
 
 - func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
 
 - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -9565,7 +9577,8 @@
     CPU: slow_conv_transpose2d_cpu
     CUDA: slow_conv_transpose2d_cuda
 
-- func: slow_conv_transpose2d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+- func: slow_conv_transpose2d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:
     CPU: slow_conv_transpose2d_backward_out_cpu
@@ -9592,7 +9605,8 @@
     CPU: slow_conv_transpose3d_cpu
     CUDA: slow_conv_transpose3d_cuda
 
-- func: slow_conv_transpose3d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+- func: slow_conv_transpose3d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:
    CPU: slow_conv_transpose3d_backward_out_cpu
@@ -9627,7 +9641,8 @@
     CPU: slow_conv2d_forward_cpu
     CUDA: legacy::cuda::_thnn_conv2d_forward
 
-- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:
     CPU: slow_conv2d_backward_out_cpu
@@ -9660,7 +9675,8 @@
   dispatch:
     CUDA: legacy::cuda::_thnn_conv_depthwise2d_forward
 
-- func: thnn_conv_depthwise2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight) -> (Tensor(a!), Tensor(b!))
+- func: thnn_conv_depthwise2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!) grad_input, Tensor(b!) grad_weight) -> (Tensor(a!), Tensor(b!))
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:
     CUDA: thnn_conv_depthwise2d_backward_out
@@ -9691,7 +9707,8 @@
   dispatch:
     CPU: slow_conv3d_forward_cpu
 
-- func: slow_conv3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+- func: slow_conv3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:
     CPU: slow_conv3d_backward_out_cpu
6 changes: 4 additions & 2 deletions aten/src/ATen/record_function.cpp
@@ -277,10 +277,12 @@ class CallbackManager {
       bool is_start) {
     try {
       if (is_start) {
-        ctx = rfcb.start()(rf);
+        ctx = rfcb.start() ? rfcb.start()(rf) : nullptr;
       }
       else {
-        rfcb.end()(rf, ctx.get());
+        if (rfcb.end()) {
+          rfcb.end()(rf, ctx.get());
+        }
       }
       return true;
     } catch (const std::exception &e) {
20 changes: 11 additions & 9 deletions aten/src/ATen/record_function.h
@@ -305,14 +305,16 @@ struct TORCH_API RecordFunction {
  */
 class TORCH_API RecordFunctionCallback {
  public:
+  using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
+  using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
+
   // This interface supports observers that require passing an ObserverContext
   // between start and end callbacks.
   explicit RecordFunctionCallback(
-      std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
-      std::function<void(const RecordFunction&, ObserverContext*)> end =
-        [](const RecordFunction&, ObserverContext*) {}):
-      start_(std::move(start)),
-      end_(std::move(end)) {
+      StartCallback start,
+      EndCallback end = nullptr) :
+      start_(start),
+      end_(end) {
     scopes_.fill(true);
   }
 
@@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback {
     return scopes_[(size_t)sc];
   }
 
-  inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
+  inline StartCallback start() const {
     return start_;
   }
 
-  inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
+  inline EndCallback end() const {
     return end_;
   }
 
  private:
  friend class CallbackManager;
-  std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
-  std::function<void(const RecordFunction&, ObserverContext*)> end_;
+  StartCallback start_;
+  EndCallback end_;
  bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
  double sampling_prob_ = 1.0;
  std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
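
For context, a hedged sketch of observer registration against the new interface (assumes a libtorch build and the `at::addGlobalCallback` entry point from `record_function.h`): callbacks are now plain function pointers, and an omitted end callback is a `nullptr` that `CallbackManager` checks before invoking, per the `record_function.cpp` hunk above.

```cpp
#include <iostream>
#include <memory>
#include <vector>

#include <ATen/record_function.h>

std::unique_ptr<at::ObserverContext> onStart(const at::RecordFunction& fn) {
  std::cout << "enter: " << fn.name().str() << std::endl;
  return nullptr;  // no per-invocation state carried to the end callback
}

void onEnd(const at::RecordFunction&, at::ObserverContext* /*ctx*/) {
  std::cout << "exit" << std::endl;
}

int main() {
  // Both callbacks supplied as raw function pointers.
  at::addGlobalCallback(at::RecordFunctionCallback(onStart, onEnd));
  // End callback omitted: it defaults to nullptr rather than an empty lambda.
  at::addGlobalCallback(at::RecordFunctionCallback(onStart));

  RECORD_FUNCTION("example_op", std::vector<c10::IValue>());
  return 0;
}
```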
10 changes: 5 additions & 5 deletions binaries/record_function_benchmark.cc
@@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001;
 
 void addTestCallback(
     double sampling_prob = 1.0,
-    std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn =
-      [](const at::RecordFunction&) { return nullptr; }) {
+    at::RecordFunctionCallback::StartCallback fn =
+      [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
   auto cb = at::RecordFunctionCallback(
-      std::move(fn),
+      fn,
       [](const at::RecordFunction&, at::ObserverContext*) {})
     .needsInputs(false);
   if (sampling_prob < 1.0) {
@@ -106,10 +106,10 @@ int main(int argc, char** argv) {
   at::clearCallbacks();
 
   std::cout << "Checking number of sampled observer invocations" << std::endl;
-  int cb_count = 0;
+  static int cb_count = 0;
   addTestCallback(
       kLowSamplingProb,
-      [&](const at::RecordFunction& fn) {
+      [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
         ++cb_count;
         return nullptr;
       }
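
Why the benchmark's counter becomes `static` and the `[&]` capture is dropped: a lambda converts to a raw function pointer, which `StartCallback` now is, only if its capture list is empty. A minimal, standalone illustration (not the benchmark code):

```cpp
using Counter = int (*)();

int main() {
  static int count = 0;  // static storage: usable without being captured
  Counter tick = []() { return ++count; };  // captureless lambda converts
  // int local = 0;
  // Counter bad = [&local]() { return ++local; };  // error: has a capture
  tick();
  return count - 1;  // 0 on success
}
```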
2 changes: 2 additions & 0 deletions c10/core/TensorOptions.h
@@ -691,6 +691,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
     return DeviceType::Vulkan;
   } else if (tid == DispatchKey::Metal) {
     return DeviceType::Metal;
+  } else if (tid == DispatchKey::QuantizedCPU) {
+    return DeviceType::CPU;
   } else {
     AT_ASSERTM(false, "Unknown DispatchKey: ", tid);
   }
