Skip to content

Commit

Permalink
[wip] quantization: store safe_on_fbgemm flag on quantized conv
Browse files Browse the repository at this point in the history
Summary:

This is a start of fixing the problems surfaced in #46749.
This particular PR only fixes a small part of this:
1. if a conv module is unsafe to run in fbgemm, we now persist this
information.
2. if we are in an fbgemm kernel and we detect that the current conv
packed params are tagged as unsafe, we throw an error.

For now, this PR is a WIP intended to gather early feedback on whether this
is the right direction, since the iteration cost on this change is high.
In particular, the things still missing here are:
* better unit testing
* serialization, verifying that this change is backward compatible (BC) with previously serialized models
* getting all the conv callsites (currently just module + conv2d is handled)

Note: there were some potential improvements discussed around dynamically
dispatching to qnnpack when it is available and the flag is set.  This PR
does not attempt to solve that issue - it can be addressed in future PRs.

Test Plan:

```
python test/test_quantization.py TestQuantizedOps.test_conv_reduce_range
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 823b16eb0f825f40e1d90228526deb3cef637b21
Pull Request resolved: #59984
  • Loading branch information
vkuzo committed Jul 15, 2021
1 parent 968a01a commit 7b55e17
Show file tree
Hide file tree
Showing 13 changed files with 254 additions and 81 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/native/quantized/cpu/conv_packed_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
virtual torch::List<int64_t> dilation() const = 0;
virtual int64_t groups() const = 0;
virtual bool transpose() const = 0;
virtual bool input_qrange_le_128() const = 0;
};
30 changes: 23 additions & 7 deletions aten/src/ATen/native/quantized/cpu/conv_serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
* - groups
* - flags (bitmask)
* - (1 << 0) transpose (1 = yes)
* - (1 << 1) input_qrange_le_128 (1 = yes)
* 2. list of optional tensors
* 0: None (helps with type inference)
* 1: weight (this must be present)
Expand Down Expand Up @@ -91,6 +92,8 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
// inputs
if (version_str == "2") {
version = 2;
} else if (version_str == "3") {
version = 3;
}
} else if (firstElement.isInt()) {
auto raw_version = firstElement.toInt();
Expand Down Expand Up @@ -138,7 +141,11 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
}
config_vals.push_back(groups[0].item<int16_t>());
// transpose does not exist in v1, so we fill in a default value
config_vals.push_back(0);
int64_t flags_transpose = (0 << 0);
// input_qrange_le_128 does not exist in v1, so we default to true
int64_t flags_input_qrange_le_128 = (1 << 1);
int64_t flags = flags_transpose | flags_input_qrange_le_128;
config_vals.push_back(flags);

std::vector<c10::optional<at::Tensor>> tensors;
tensors.emplace_back();
Expand Down Expand Up @@ -169,6 +176,8 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
for (const auto i : c10::irange(config_a.size(0))) {
config_vals.emplace_back(config_a[i]);
}
// set default for input_qrange_le_128
config_vals[config_vals.size() - 1] = config_vals[config_vals.size() - 1] | (1 << 1);

auto weight = non_optional[1];
auto bias = optional[0];
Expand All @@ -188,7 +197,7 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
}
}

#define QCONV_SERIALIZATION_VERSION 2
#define QCONV_SERIALIZATION_VERSION 3

#if QCONV_SERIALIZATION_VERSION == 2
using ConvParamsSerializationType = ConvParamsSerializationTypeV2;
Expand All @@ -197,7 +206,7 @@ template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV2 serialize_conv(
const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {

std::string version = "2";
std::string version = "3";
std::vector<at::Tensor> non_optional;
std::vector<c10::optional<at::Tensor>> optional;

Expand All @@ -215,6 +224,7 @@ ConvParamsSerializationTypeV2 serialize_conv(
output_padding.end());
params_vec.push_back(params->groups());
params_vec.push_back(params->transpose());
params_vec.push_back(params->input_qrange_le_128());
int64_t vec_size = params_vec.size();
at::Tensor params_tensor = at::from_blob(
params_vec.data(), {vec_size},
Expand Down Expand Up @@ -251,7 +261,10 @@ ConvParamsSerializationTypeV3 serialize_conv(
config_vals.insert(config_vals.end(), output_padding.begin(),
output_padding.end());
config_vals.push_back(params->groups());
config_vals.push_back(params->transpose());
int64_t flags_transpose = (params->transpose() << 0);
int64_t flags_input_qrange_le_128 = (params->input_qrange_le_128() << 1);
int64_t flags = (flags_transpose | flags_input_qrange_le_128);
config_vals.push_back(flags);

at::Tensor weight;
c10::optional<at::Tensor> bias;
Expand Down Expand Up @@ -320,8 +333,9 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
config_vals.size());

bool transpose = flags & (1 << 0);
bool input_qrange_le_128 = flags & (1 << 1);

int64_t other_flags = flags & ~(1 << 0);
int64_t other_flags = flags & ~((1 << 0) | (1 << 1));
TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, ".");

auto& ctx = at::globalContext();
Expand All @@ -336,7 +350,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
output_padding,
dilation,
groups,
transpose
transpose,
input_qrange_le_128
);
}
#endif // USE_FBGEMM
Expand All @@ -354,7 +369,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
output_padding,
dilation,
groups,
transpose
transpose,
input_qrange_le_128
);
}
#endif // USE_PYTORCH_QNNPACK
Expand Down
10 changes: 9 additions & 1 deletion aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
torch::List<int64_t> dilation,
int64_t groups,
uint8_t transpose,
bool input_qrange_le_128,
std::vector<int32_t> col_offsets,
std::vector<int64_t> kernel,
std::vector<float> w_scale,
Expand All @@ -153,6 +154,7 @@ struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
dilation_(std::move(dilation)),
groups_(groups),
transpose_(transpose),
input_qrange_le_128_(input_qrange_le_128),
col_offsets(std::move(col_offsets)),
kernel(std::move(kernel)),
w_scale(std::move(w_scale)),
Expand All @@ -167,6 +169,7 @@ struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
torch::List<int64_t> dilation_;
int64_t groups_;
uint8_t transpose_;
uint8_t input_qrange_le_128_;
std::vector<int32_t> col_offsets;
std::vector<int64_t> kernel;
std::vector<float> w_scale;
Expand All @@ -193,7 +196,8 @@ struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);
bool transpose,
bool input_qrange_le_128);

const float* GetBiasData(at::Tensor* bias);

Expand Down Expand Up @@ -227,6 +231,10 @@ struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
return (bool)transpose_;
}

// Returns the stored input_qrange_le_128 flag as a bool.
// Per this commit, the flag marks whether the conv's input quantized range
// is small enough to be safe on fbgemm (see the saturation issue referenced
// in the commit message); it is persisted as a uint8_t member.
// Uses static_cast instead of a C-style cast per modern C++ guidelines.
bool input_qrange_le_128() const override {
  return static_cast<bool>(input_qrange_le_128_);
}

private:
template <bool ReluFused>
at::Tensor apply_impl(
Expand Down
7 changes: 7 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,13 @@ at::Tensor PackedConvWeight<kSpatialDim>::apply_impl(
ConvDimChecks<kSpatialDim>(
act.ndimension(), stride().size(), padding().size(),
output_padding().size(), dilation().size(), func_name, transpose());
// TODO(before land): figure out if this check needs to be architecture
// specific and fix it, instead of always failing
// TODO(before land): create a separate issue to describe the problem
// and the workarounds, and link to it from the error message
if (!input_qrange_le_128()) {
TORCH_WARN_ONCE("This module has a potential to saturate, TODO link to issue");
}

const int N = act.size(0);
const int C = act.size(1);
Expand Down
47 changes: 30 additions & 17 deletions aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
bool transpose,
bool input_qrange_le_128) {
TORCH_CHECK(
weight.ndimension() == kSpatialDim + 2,
"Weights are expected to have ",
Expand Down Expand Up @@ -166,6 +167,7 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
dilation,
groups,
transpose,
input_qrange_le_128,
col_offsets,
kSpatialDim == 2 ? std::vector<int64_t>{kernel_h, kernel_w}
: std::vector<int64_t>{kernel_d, kernel_h, kernel_w},
Expand All @@ -192,7 +194,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
bool transpose,
bool input_qrange_le_128) {
TORCH_CHECK(
kSpatialDim == 2 || kSpatialDim == 3, // 1D is packed as 2d, hence we don't need other checks
"QNNPACK packing only supports 2D / 3D convolution.");
Expand Down Expand Up @@ -288,6 +291,7 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
dilation,
groups,
transpose,
input_qrange_le_128,
c10::nullopt, /* input_scale */
{kernel_h, kernel_w},
w_scales,
Expand All @@ -308,7 +312,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightsQnnp<
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);
bool transpose,
bool input_qrange_le_128);
#endif // USE_PYTORCH_QNNPACK

namespace at {
Expand All @@ -324,26 +329,29 @@ class QConvPackWeightInt8 final {
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups) {
int64_t groups,
bool input_qrange_le_128) {
torch::List<int64_t> output_padding;
output_padding.reserve(kSpatialDim);
for (int idx = 0; idx < kSpatialDim; ++idx) {
output_padding.push_back((int64_t)0);
}
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
/*transpose=*/false);
/*transpose=*/false, input_qrange_le_128);
}

// NOTE(review): run_deconv below already takes input_qrange_le_128 — this
// TODO looks stale; confirm (e.g. against the op schema registration) and remove.
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_deconv(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups) {
int64_t groups,
bool input_qrange_le_128) {
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
/*transpose=*/true);
/*transpose=*/true, input_qrange_le_128);
}

private:
Expand All @@ -355,13 +363,14 @@ class QConvPackWeightInt8 final {
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
bool transpose,
bool input_qrange_le_128) {
auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
if (ctx.qEngine() == at::QEngine::FBGEMM) {
return PackedConvWeight<kSpatialDim>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
transpose, input_qrange_le_128);
}
#endif

Expand All @@ -373,7 +382,7 @@ class QConvPackWeightInt8 final {
"and Conv2d now.");
return PackedConvWeightsQnnp<kSpatialDim>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
transpose, input_qrange_le_128);
}
#endif

Expand All @@ -394,10 +403,11 @@ class QConv1dPackWeightInt8 final {
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups) {
int64_t groups,
bool input_qrange_le_128) {
const torch::List<int64_t> output_padding({0});
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
/*transpose=*/false);
/*transpose=*/false, input_qrange_le_128);
}

static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_deconv(
Expand All @@ -407,9 +417,10 @@ class QConv1dPackWeightInt8 final {
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups) {
int64_t groups,
bool input_qrange_le_128) {
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
/*transpose=*/true);
/*transpose=*/true, input_qrange_le_128);
}

private:
Expand All @@ -421,7 +432,9 @@ class QConv1dPackWeightInt8 final {
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
bool transpose,
bool input_qrange_le_128) {
// TODO: use the flag
auto& ctx = at::globalContext();
if (weight.dim() == 3) {
weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
Expand All @@ -434,7 +447,7 @@ class QConv1dPackWeightInt8 final {
if (ctx.qEngine() == at::QEngine::FBGEMM) {
return PackedConvWeight<2>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
transpose, input_qrange_le_128);
}
#endif

Expand All @@ -444,7 +457,7 @@ class QConv1dPackWeightInt8 final {
if (ctx.qEngine() == at::QEngine::QNNPACK) {
return PackedConvWeightsQnnp<2>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
transpose, input_qrange_le_128);
}
#endif
TORCH_CHECK(
Expand Down
10 changes: 9 additions & 1 deletion aten/src/ATen/native/quantized/cpu/qnnpack_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
torch::List<int64_t> dilation,
int64_t groups,
bool transpose,
bool input_qrange_le_128,
c10::optional<double> input_scale,
std::vector<int64_t> kernel,
at::Tensor w_scale,
Expand All @@ -111,6 +112,7 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
dilation_(std::move(dilation)),
groups_(groups),
transpose_(transpose),
input_qrange_le_128_(input_qrange_le_128),
input_scale(input_scale),
kernel_(std::move(kernel)),
w_scales(w_scale),
Expand Down Expand Up @@ -237,6 +239,7 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
torch::List<int64_t> dilation_;
int64_t groups_;
bool transpose_;
bool input_qrange_le_128_;
c10::optional<double> input_scale;
std::vector<int64_t> kernel_;
at::Tensor w_scales;
Expand Down Expand Up @@ -265,7 +268,8 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);
bool transpose,
bool input_qrange_le_128);

torch::List<int64_t> stride() const override {
return stride_;
Expand All @@ -291,6 +295,10 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
return transpose_;
}

// Accessor for the input_qrange_le_128 flag stored on this packed weight
// (see ConvPackedParamsBase for the interface this overrides).
bool input_qrange_le_128() const override {
  const bool qrange_flag = input_qrange_le_128_;
  return qrange_flag;
}

private:
template <bool ReluFused>
at::Tensor apply_impl(
Expand Down

0 comments on commit 7b55e17

Please sign in to comment.