[Quant][Inductor] Enable the lowering of quantized maxpool2d (#105906)
**Summary**
Enable the `dq-maxpool2d-q` pattern match and lower it into `torch.ops.quantized.max_pool2d`.
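
For illustration, here is a minimal eager-mode sketch of the pattern and its lowered form. This is not code from the PR: the actual PT2E graph carries decomposed `quantize_per_tensor`/`dequantize_per_tensor` nodes on plain integer tensors, and the shapes and quantization parameters below are made-up example values.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# The pattern before lowering: dequantize -> max_pool2d -> quantize.
y_ref = torch.quantize_per_tensor(
    F.max_pool2d(qx.dequantize(), kernel_size=3, stride=2),
    scale=0.1, zero_point=128, dtype=torch.quint8,
)

# What it lowers to: a single quantized op, no float round trip.
y_lowered = torch.ops.quantized.max_pool2d(
    qx, kernel_size=[3, 3], stride=[2, 2], padding=[0, 0],
    dilation=[1, 1], ceil_mode=False,
)

# max is monotonic and the q/dq parameters are unchanged, so the two
# results should match exactly.
torch.testing.assert_close(y_ref.int_repr(), y_lowered.int_repr())
```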

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qmaxpool2d
python -m pytest test_quantized_op.py -k test_max_pool2d_pt2e
```
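
For context, a rough sketch of the PT2E-plus-Inductor flow these tests exercise. The capture and quantization entry points shifted across releases in this era, so treat the import paths below as assumptions rather than a stable API; the model and input shapes are made up.

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, padding=3)
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        return self.maxpool(self.relu(self.conv(x)))

model = M().eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

exported = capture_pre_autograd_graph(model, example_inputs)
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

# Inductor compiles the converted graph; with this PR, the dq-maxpool2d-q
# subgraph around the pooling op is lowered to torch.ops.quantized.max_pool2d.
compiled = torch.compile(converted)
compiled(*example_inputs)
```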

Pull Request resolved: #105906
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456, #105639
leslie-fang-intel authored and pytorchmergebot committed Aug 26, 2023
1 parent 70ca18f commit 9319dd1
Showing 9 changed files with 345 additions and 67 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/Dispatch.h
@@ -406,6 +406,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
146 changes: 90 additions & 56 deletions aten/src/ATen/native/quantized/cpu/Pooling.cpp
@@ -16,6 +16,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/quantized_max_pool1d.h>
#include <ATen/ops/quantized_max_pool1d_native.h>
@@ -254,67 +255,92 @@ Tensor q_maxpool_2d(
// In this case, we can preserve the data layout in memory
// as well as use a loop nest that is more amenable to
// vectorization.
Tensor qy = at::_empty_affine_quantized(
Tensor qy;
if constexpr(std::is_same_v<Q, uint8_t>) {
qy = at::empty(
oSizes,
qx.options()
.dtype(toQIntType(qx.scalar_type()))
.memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
c10::nullopt);
.device(c10::kCPU)
.dtype(qx.scalar_type())
.memory_format(c10::MemoryFormat::ChannelsLast));
} else {
qy = at::_empty_affine_quantized(
oSizes,
qx.options()
.dtype(toQIntType(qx.scalar_type()))
.memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
c10::nullopt);
}
qmaxpool_2d_nhwc_stub(qx.device().type(), qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
return qy;
} else {
Tensor qy = at::_empty_affine_quantized(
oSizes,
qx.options().dtype(toQIntType(qx.scalar_type())),
qx.q_scale(),
qx.q_zero_point());
auto qx_contig = qx.contiguous();
auto qxd = qx_contig.data_ptr<Q>();
auto qyd = qy.data_ptr<Q>();
if (ndim == 3 || nbatch == 1) {
auto* iData = qxd;
auto* oData = qyd;
spatial_dilated_max_pooling<Q>(
iData,
iC,
iH,
iW,
oH,
oW,
kH,
kW,
sH,
sW,
pH,
pW,
dH,
dW,
oData);
Tensor qy;
if constexpr(!std::is_same_v<Q, uint8_t>) {
qy = at::_empty_affine_quantized(
oSizes,
qx.options().dtype(toQIntType(qx.scalar_type())),
qx.q_scale(),
qx.q_zero_point());
auto qx_contig = qx.contiguous();
auto qxd = qx_contig.data_ptr<Q>();
auto qyd = qy.data_ptr<Q>();
if (ndim == 3 || nbatch == 1) {
auto* iData = qxd;
auto* oData = qyd;
spatial_dilated_max_pooling<Q>(
iData,
iC,
iH,
iW,
oH,
oW,
kH,
kW,
sH,
sW,
pH,
pW,
dH,
dW,
oData);
} else {
at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
for (const auto p : c10::irange(start, end)) {
auto* iData = qxd + p * iC * iW * iH;
auto* oData = qyd + p * oC * oW * oH;
spatial_dilated_max_pooling<Q>(
iData,
iC,
iH,
iW,
oH,
oW,
kH,
kW,
sH,
sW,
pH,
pW,
dH,
dW,
oData);
}
});
}
} else {
at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
for (const auto p : c10::irange(start, end)) {
auto* iData = qxd + p * iC * iW * iH;
auto* oData = qyd + p * oC * oW * oH;
spatial_dilated_max_pooling<Q>(
iData,
iC,
iH,
iW,
oH,
oW,
kH,
kW,
sH,
sW,
pH,
pW,
dH,
dW,
oData);
}
});
// If qx is uint8 and in contiguous memory format,
// use the channels_last implementation and convert qy back to contiguous.
qy = at::empty(
oSizes,
qx.options()
.device(c10::kCPU)
.dtype(qx.scalar_type())
.memory_format(c10::MemoryFormat::ChannelsLast));
auto qx_nhwc = qx.contiguous(c10::MemoryFormat::ChannelsLast);
qmaxpool_2d_nhwc_stub(qx_nhwc.device().type(), qx_nhwc, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
qy = qy.contiguous();
}
return qy;
}
@@ -611,7 +637,7 @@ Tensor quantized_max_pool2d(
}
#endif
Tensor qy;
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool2d", [&]() {
AT_DISPATCH_QINT_TYPES_AND(ScalarType::Byte, qx.scalar_type(), "max_pool2d", [&]() {
qy = q_maxpool_2d<scalar_t>(
qx,
kernel_size[0],
@@ -706,6 +732,10 @@ class QMaxPool_arr_args final {
std::vector<int64_t> padding,
std::vector<int64_t> dilation,
bool ceil_mode) {
if (!qx.is_quantized() && kSpatialDim == 2 && qx.scalar_type() == c10::ScalarType::Byte){
return at::native::quantized_max_pool2d(qx, kernel_size, stride, padding,
dilation, ceil_mode);
}
if (kSpatialDim == 1) {
return at::quantized_max_pool1d(qx, kernel_size, stride, padding,
dilation, ceil_mode);
@@ -722,6 +752,10 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run));
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run));
}

} // namespace
} // namespace native
} // namespace at
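
Outside the diff proper, the net effect of the new `TORCH_LIBRARY_IMPL(quantized, CPU, m)` registration is that `quantized::max_pool2d` now accepts a plain (non-quantized) `uint8` CPU tensor, which is what the Inductor lowering emits. A minimal sketch of what that enables, with illustrative shapes and pooling parameters:

```python
import torch

# A plain uint8 CPU tensor -- not a quantized tensor.
x = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)

# With the CPU-dispatch registration above, this routes into
# quantized_max_pool2d instead of rejecting the non-quantized input.
y = torch.ops.quantized.max_pool2d(
    x, kernel_size=[3, 3], stride=[2, 2], padding=[0, 0],
    dilation=[1, 1], ceil_mode=False,
)
assert y.dtype == torch.uint8

# Max pooling raw uint8 values matches a float reference exactly:
# max is monotonic and no rescaling is involved.
y_ref = torch.nn.functional.max_pool2d(x.float(), 3, stride=2).to(torch.uint8)
torch.testing.assert_close(y, y_ref)
```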
42 changes: 34 additions & 8 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -1460,7 +1460,8 @@ void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
});
}

void qmaxpool_2d_nhwc_kernel(
template <typename scalar_t, typename scalar_t_underlying>
void _qmaxpool_2d_nhwc_kernel(
const Tensor& qx,
int64_t iC, // input/output channels
int64_t iH,
@@ -1476,7 +1477,6 @@ void qmaxpool_2d_nhwc_kernel(
int64_t dH,
int64_t dW, // dilation
Tensor& qy) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());

@@ -1486,8 +1486,8 @@ void qmaxpool_2d_nhwc_kernel(
data_index_init(begin, b, nBatch, row, oH, col, oW);

for (const auto i : c10::irange(begin, end)) {
auto* i_p = reinterpret_cast<scalar_t::underlying*>(idata + b * iW * iH * iC);
auto* o_p = reinterpret_cast<scalar_t::underlying*>(odata + i * iC);
auto* i_p = reinterpret_cast<scalar_t_underlying*>(idata + b * iW * iH * iC);
auto* o_p = reinterpret_cast<scalar_t_underlying*>(odata + i * iC);

// Loop over reduction block
int64_t h_start = row * sH - pH;
@@ -1505,7 +1505,7 @@ void qmaxpool_2d_nhwc_kernel(
constexpr auto vec_width = Vectorized<scalar_t>::size();
for (; c + 4 * vec_width <= iC; c += 4 * vec_width) {
Vectorized<scalar_t> acc{
scalar_t(std::numeric_limits<scalar_t::underlying>::lowest())};
scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
Vectorized<scalar_t> accs[4] = {acc, acc, acc, acc};
int64_t tcntr = 0;
@@ -1528,7 +1528,7 @@ void qmaxpool_2d_nhwc_kernel(
// Vector loop
for (; c + vec_width <= iC; c += vec_width) {
Vectorized<scalar_t> acc{
scalar_t(std::numeric_limits<scalar_t::underlying>::lowest())};
scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
int64_t tcntr = 0;
int64_t x, y;
for (y = h_start; y < h_end; y += dH) {
@@ -1542,7 +1542,7 @@ void qmaxpool_2d_nhwc_kernel(
} // for c

for (; c < iC; ++c) {
auto max_val = std::numeric_limits<scalar_t::underlying>::lowest();
auto max_val = std::numeric_limits<scalar_t_underlying>::lowest();
int64_t tcntr = 0;
int64_t x, y;
for (y = h_start; y < h_end; y += dH) {
@@ -1559,7 +1559,33 @@ void qmaxpool_2d_nhwc_kernel(
data_index_step(b, nBatch, row, oH, col, oW);
}
});
});
}

void qmaxpool_2d_nhwc_kernel(
const Tensor& qx,
int64_t iC, // input/output channels
int64_t iH,
int64_t iW, // input sizes
int64_t oH,
int64_t oW, // output sizes
int64_t kH,
int64_t kW, // kernel size
int64_t sH,
int64_t sW, // strides
int64_t pH,
int64_t pW, // padding
int64_t dH,
int64_t dW, // dilation
Tensor& qy) {
if (qx.scalar_type() == ScalarType::Byte) {
AT_DISPATCH_INTEGRAL_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
_qmaxpool_2d_nhwc_kernel<scalar_t, scalar_t>(qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
});
} else {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
_qmaxpool_2d_nhwc_kernel<scalar_t, scalar_t::underlying>(qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
});
}
}

void qmaxpool_3d_nthwc_kernel(
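
As a sanity check on the kernel refactor above: with channels-last inputs, both the `quint8` path and the plain-`uint8` path should reach the same templated NHWC kernel and agree bitwise on the integer payload. A hedged sketch, with made-up shapes and pooling parameters:

```python
import torch

# Channels-last inputs so both dtypes take the vectorized NHWC path.
x = torch.randint(0, 256, (2, 4, 16, 16), dtype=torch.uint8)
x_cl = x.contiguous(memory_format=torch.channels_last)
qx_cl = torch.quantize_per_tensor(
    x.float(), scale=1.0, zero_point=0, dtype=torch.quint8
).contiguous(memory_format=torch.channels_last)

# kernel_size, stride, padding, dilation, ceil_mode
pool_args = ([2, 2], [2, 2], [0, 0], [1, 1], False)
y_q = torch.ops.quantized.max_pool2d(qx_cl, *pool_args)
y_u8 = torch.ops.quantized.max_pool2d(x_cl, *pool_args)

# scale=1.0 / zero_point=0 make int_repr() the raw payload, so the two
# paths should agree element for element.
torch.testing.assert_close(y_q.int_repr(), y_u8)
```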
52 changes: 51 additions & 1 deletion test/inductor/test_mkldnn_pattern_matcher.py
@@ -19,7 +19,7 @@
skipIfNoDynamoSupport,
skipIfNoONEDNN,
)
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU

# The dict value is match_nodes(computation_op+unary_op)
@@ -399,6 +399,7 @@ def forward(self, x):

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qconv2d_binary(self):
class M(torch.nn.Module):
def __init__(
@@ -455,6 +456,7 @@ def forward(self, x):

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qconv2d_unary(self):
class M(torch.nn.Module):
def __init__(
@@ -513,6 +515,7 @@ def forward(self, x):

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_dequant_promotion(self):
class M(torch.nn.Module):
def __init__(
@@ -552,6 +555,53 @@ def forward(self, x):
check_quantization=True,
)

@skipIfNoDynamoSupport
@skipIfRocm
def test_qmaxpool2d(self):
class M(torch.nn.Module):
def __init__(
self,
kwargs,
):
super().__init__()
self.conv = torch.nn.Conv2d(
3, 64, 7, bias=True, stride=2, padding=3, dilation=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(3, **kwargs)

def forward(self, x):
return self.maxpool(self.relu(self.conv(x)))

kwargs_list = [
{"stride": 2},
{"stride": 2, "padding": 1},
{"stride": 2, "padding": 1, "dilation": 1},
{"stride": 2, "padding": 1, "dilation": 1, "ceil_mode": False},
]
for kwargs in kwargs_list:
mod = M(kwargs).eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
# In total: 6 pattern_matcher_count, 31 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 * 3, matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. Dequant-conv pattern matched in quantization weight prepack * 1
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. qconv2d_relu fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, relu, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 4. qmaxpool2d * 1
# [convert_element_type_3, sub_1, mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2,
# clamp_min_2, clamp_max_2, convert_element_type_4]
self._test_common(
mod,
(v,),
6,
31,
check_quantization=True,
)

# https://github.com/pytorch/pytorch/issues/99841.
def test_hardtanh_pattern_fallback(self):
class Model(torch.nn.Module):