Skip to content

Commit

Permalink
[Quant] onednn backend switch to ideep new api without affacting perf…
Browse files Browse the repository at this point in the history
…ormance (#91056)

> Reopen of #90354

**Summary**
Onednn quantization backend switch to new API in `third_party/ideep`.
- `struct forward_params` for conv/deconv are changed. Modify primitive cache accordingly.
- Use new versions of `prepare` and `compute` API. Fp32 and int8 paths separated. The old ones will be deprecated.
- Now `ideep::tensor::reorder_if_differ_in` supports block-to-block reorder. Use it instead of defining a util function `onednn_utils::try_reorder`.
- For new API of transposed convolution, we can use a flag to keep weight desc align with oneDNN thus needless to transpose it explicitly in PyTorch.
- Use `is_channels_last` flag to specify layout of src/dst when querying expected weight desc.

It won't impact correctness. Performance should be unaffected or slightly better.
FBGEMM and QNNPACK backends are not affected.

Performance results are given below.
1. End-to-end performance of static quantized models (from torchvision)
(throughput: fps, higher is better)
![image](https://user-images.githubusercontent.com/12522207/206105879-45c59996-9804-4531-aa1f-dc962e6db5ab.png)

2. Op benchmark of dynamic quantized linear
(Latency: ms, lower is better)
![image](https://user-images.githubusercontent.com/12522207/206124949-77352991-0fda-4285-a484-e20a5797262b.png)

Test method & env:
- Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz
- Run multi-instances on a single node. Use one core for each instance.
- Use Jemalloc and Intel OpenMP

**Test plan**
python test/test_quantization.py

Pull Request resolved: #91056
Approved by: https://github.com/jgong5
  • Loading branch information
Xia-Weiwen authored and pytorchmergebot committed Jan 18, 2023
1 parent fb50a4b commit 5a2ae88
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 106 deletions.
77 changes: 26 additions & 51 deletions aten/src/ATen/native/quantized/cpu/OnednnUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,78 +91,53 @@ struct LinearPrimitiveCache : PrimitiveCache {
struct ConvPrimitiveCache : PrimitiveCache {
ConvPrimitiveCache() {}

ConvPrimitiveCache(const PrimitiveCacheKey& key,
const ConvDesc& conv_desc,
const ideep::tensor& bias,
const ideep::attr_t bias_attr) {
ConvPrimitiveCache(
const PrimitiveCacheKey& key,
const ConvParams& params,
const ideep::tensor& bias) {
this->key = key;
this->primitive_desc = conv_desc;
this->primitive = Conv(this->primitive_desc);
// Construct tensor of input zero point
ideep::tensor::desc input_zp_desc = {{1}, ideep::data_type::s32, {1}};
this->input_zp_tensor.init(input_zp_desc, ideep::engine::cpu_engine());
auto zp_data_ptr = reinterpret_cast<int32_t *>(this->input_zp_tensor.get_data_handle());
zp_data_ptr[0] = std::get<InputZeroPoint>(key);
// Construct expected bias
this->expected_bias = bias.reorder_if_differ_in(conv_desc.bias_desc(), bias_attr);
this->params = params;
if (!bias.is_empty()) {
this->expected_bias =
bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
}
}

ConvDesc primitive_desc;
Conv primitive;
ideep::tensor input_zp_tensor;
ideep::tensor expected_bias;
ConvParams params;

inline ConvDesc& get_primitive_desc() {
return primitive_desc;
}

inline Conv& get_primitive() {
return primitive;
}

inline ideep::tensor& get_src_zp_tensor() {
return input_zp_tensor;
ConvParams& get_params() {
return params;
}

inline ideep::tensor& get_bias() {
ideep::tensor& get_bias() {
return expected_bias;
}
};

struct DeconvPrimitiveCache : PrimitiveCache {
DeconvPrimitiveCache() {}

DeconvPrimitiveCache(const PrimitiveCacheKey& key,
const DeconvDesc& deconv_desc,
const ideep::tensor& bias,
const ideep::attr_t bias_attr,
const ideep::tensor& input_zero_point) {
DeconvPrimitiveCache(
const PrimitiveCacheKey& key,
const DeconvParams& params,
const ideep::tensor& bias) {
this->key = key;
this->primitive_desc = deconv_desc;
this->primitive = Deconv(this->primitive_desc);
this->input_zp_tensor = std::move(input_zero_point);
// Construct expected bias
this->expected_bias = bias.reorder_if_differ_in(deconv_desc.bias_desc(), bias_attr);
this->params = params;
if (!bias.is_empty()) {
this->expected_bias =
bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
}
}

DeconvDesc primitive_desc;
Deconv primitive;
ideep::tensor input_zp_tensor;
DeconvParams params;
ideep::tensor expected_bias;

inline DeconvDesc& get_primitive_desc() {
return primitive_desc;
}

inline Deconv& get_primitive() {
return primitive;
}

inline ideep::tensor& get_src_zp_tensor() {
return input_zp_tensor;
DeconvParams& get_params() {
return params;
}

inline ideep::tensor& get_bias() {
ideep::tensor& get_bias() {
return expected_bias;
}
};
Expand Down
53 changes: 19 additions & 34 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,6 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
// Scales of ONEDNN and PyTorch are reciprocal
const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input_scale);
const ideep::scale_t& weights_scales = weights.get_scale();
int64_t scale_size = weights_scales.size();
double inv_output_scale = 1.0/output_scale;
const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, input_zp);
const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point);
Expand All @@ -1274,29 +1273,25 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
ideep::convolution_transpose_forward::prepare(
params, src, weights, b, dst_dims, dst,
strides, padding_l, padding_r, dilates, groups(),
src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale),
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
src_zero_points, dst_zero_points, op_attr,
dnnl::algorithm::deconvolution_direct,
dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
get_deconv_cache() = DeconvPrimitiveCache(
cache_key, params.pd, b, params.bias_attr, params.input_zero_point);
onednn_utils::try_reorder(
weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
get_deconv_cache() = DeconvPrimitiveCache(cache_key, params, b);
weights = weights.reorder_if_differ_in(params.pd.weights_desc());
});
if (get_deconv_cache().hit(cache_key)) {
Deconv& primitive = get_deconv_cache().get_primitive();
DeconvDesc& pd = get_deconv_cache().get_primitive_desc();
auto& src_zp_tensor = get_deconv_cache().get_src_zp_tensor();
DeconvParams& params = get_deconv_cache().get_params();
auto& expected_bias = get_deconv_cache().get_bias();
ideep::convolution_transpose_forward::compute(
pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups());
ideep::convolution_transpose_forward::compute<false, false>(
params, src, weights, expected_bias, dst);
} else {
ideep::convolution_transpose_forward::compute_v2(
ideep::convolution_transpose_forward::compute(
src, weights, b, dst_dims, dst,
strides, padding_l, padding_r, dilates,
groups(), src_scales, weights_scales,
ideep::scale_t(scale_size, inv_output_scale),
ideep::scale_t(1, inv_output_scale),
src_zero_points, dst_zero_points, op_attr,
dnnl::algorithm::deconvolution_direct,
dnnl::prop_kind::forward_inference,
Expand All @@ -1306,42 +1301,32 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
PrimitiveCacheKey cache_key = std::make_tuple(
input_scale, input_zp, src_dims, output_scale, output_zero_point, num_threads);
c10::call_once(*cache_initialized_flag, [&](){
src.set_zero_point(src_zero_points);
dst.set_zero_point(dst_zero_points);
ConvParams params;
ideep::convolution_forward::prepare(
params, src, weights, b, dst_dims, dst,
strides, dilates, padding_l, padding_r, groups(),
src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale),
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::convolution_direct,
dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
get_conv_cache() = ConvPrimitiveCache(cache_key, params.pd, b, params.bias_attr);
onednn_utils::try_reorder(
weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
get_conv_cache() = ConvPrimitiveCache(cache_key, params, b);
weights = weights.reorder_if_differ_in(params.pd.weights_desc());
});
// If hit, use cached data. If miss, fall back to normal path.
if (get_conv_cache().hit(cache_key)) {
ConvDesc& pd = get_conv_cache().get_primitive_desc();
Conv& primitive = get_conv_cache().get_primitive();
auto& src_zp_tensor = get_conv_cache().get_src_zp_tensor();
auto& params = get_conv_cache().get_params();
auto& expected_bias = get_conv_cache().get_bias();
ideep::convolution_forward::compute(
pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups());
ideep::convolution_forward::compute<false, false>(params, src, weights, expected_bias, dst);
} else {
src.set_zero_point(src_zero_points);
dst.set_zero_point(dst_zero_points);
ConvParams params;
ideep::convolution_forward::prepare(
params, src, weights, b, dst_dims, dst,
ideep::convolution_forward::compute(
src, weights, b, dst_dims, dst,
strides, dilates, padding_l, padding_r, groups(),
src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale),
op_attr, dnnl::algorithm::convolution_direct,
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
src_zero_points, dst_zero_points, op_attr,
dnnl::algorithm::convolution_direct,
dnnl::prop_kind::forward_inference,
ideep::u8s8, ideep::engine::cpu_engine());
onednn_utils::try_reorder(
weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
ideep::convolution_forward::compute(params, src, weights, b, dst);
}
}
return output;
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
ideep::tag w_tag = ideep::tag::any;
const bool with_groups = groups > 1;
if (transpose) {
w_desc = ideep::convolution_transpose_forward::expected_weights_desc(
// template args: <(src/dst) is_channels_last, transposed>
w_desc = ideep::convolution_transpose_forward::expected_weights_desc<true, false>(
dims, dnnl::memory::data_type::s8,
strides, padding_l, padding_r, dilates, groups,
dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
Expand All @@ -419,15 +420,14 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw;
std::vector<int64_t> perms(dims_giohw.size(), 0); // for permutation of weight
std::iota(perms.begin(), perms.end(), 0);
w_desc = w_desc.transpose(with_groups, with_groups + 1);
std::swap(perms[with_groups], perms[with_groups + 1]);
weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone();
} else {
w_desc = ideep::convolution_forward::expected_weights_desc(
dims, dnnl::memory::data_type::s8,
strides, padding_l, padding_r, dilates, groups,
dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
dnnl::memory::data_type::u8, ideep::dims(), op_attr);
dnnl::memory::data_type::u8, ideep::dims(), op_attr, /*is_channels_last=*/true);
weight_copy = weight.clone();
}
if (with_groups) {
Expand Down
17 changes: 9 additions & 8 deletions aten/src/ATen/native/quantized/cpu/qlinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,19 +844,20 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
c10::call_once(*cache_initialized_flag, [&](){
LinearParams params;
ideep::matmul_forward::prepare</*is_dynamic=*/false>(
params, x, w, b, y, 1.0f, 1.0f,
params, x, w, b, y,
src_scales, weights_scales, dst_scales,
src_zero_point, dst_zero_point, op_attr);
get_cache() = LinearPrimitiveCache(cache_key, params);
onednn_utils::try_reorder(
w, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
get_cache() = LinearPrimitiveCache(cache_key, params, b);
w = w.reorder_if_differ_in(params.pd.weights_desc());
});
if (get_cache().hit(cache_key)) {
LinearParams& params = get_cache().get_param();
ideep::matmul_forward::compute(params, x, w, b, y);
auto& expected_bias = get_cache().get_expected_bias();
ideep::matmul_forward::compute<false, false>(params, x, w, expected_bias, y);
} else {
ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales,
dst_scales, src_zero_point, dst_zero_point, op_attr);
ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
dst_scales, src_zero_point, dst_zero_point,
1.0f, 1.0f, op_attr);
}
auto out_sizes = input.sizes().vec();
out_sizes.back() = N;
Expand Down
18 changes: 8 additions & 10 deletions aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,22 +567,20 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
c10::call_once(*cache_initialized_flag, [&](){
LinearParams params;
ideep::matmul_forward::prepare</*is_dynamic=*/true>(
params, x, w, b, y, 1.0f, 1.0f,
params, x, w, b, y,
src_scales, weights_scales, ideep::scale_t(),
src_zero_point, ideep::zero_point_t(), op_attr);
src_zero_point, ideep::zero_point_t(), 1.0f, 1.0f, op_attr);
get_cache() = LinearPrimitiveCache(cache_key, params);
onednn_utils::try_reorder(
w, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
w = w.reorder_if_differ_in(params.pd.weights_desc());
});
if (get_cache().hit_dynamic(cache_key)) {
LinearParams& params = get_cache().get_param();
ideep::matmul_forward::compute_dynamic(
params, x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales,
ideep::scale_t(), src_zero_point, ideep::zero_point_t());
ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point);
} else {
ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f,
src_scales, weights_scales, ideep::scale_t(),
src_zero_point, ideep::zero_point_t(), op_attr);
ideep::matmul_forward::compute(x, w, b, y,
src_scales, weights_scales, ideep::scale_t(),
src_zero_point, ideep::zero_point_t(),
1.0f, 1.0f, op_attr);
}
auto out_sizes = input.sizes().vec();
out_sizes.back() = w.get_dim(1);
Expand Down

0 comments on commit 5a2ae88

Please sign in to comment.