[Quant][Inductor] Enable quantization conv pattern fusion inside inductor (#104588)

**Summary**
Enable the `dequant-conv2d-quant` pattern fusion and lowering inside inductor.
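
For orientation, here is a rough, framework-free sketch of the decomposed arithmetic that the matched sequence computes before fusion (the function name, tensor arguments, and the `[0, 127]` output range are illustrative assumptions taken from the pattern docstrings in `quantization.py`, not a public API); after this change, inductor lowers the whole sequence to a single `onednn.qconv2d_pointwise` call:

```python
import torch
import torch.nn.functional as F

# Illustrative reference of the decomposed int8 conv flow that the new
# patterns match; names and the [0, 127] output range are assumptions.
def dequant_conv_quant_reference(x_int8, x_scale, x_zp, w_fp32, bias,
                                 o_inv_scale, o_zp):
    # dequantize activation: x = (x.to(fp32) - zero_point) * scale
    x_fp32 = (x_int8.to(torch.float32) - x_zp) * x_scale
    # fp32 convolution (the weight is dequantized per-channel in the real graph)
    out = F.conv2d(x_fp32, w_fp32, bias)
    # quantize output: round(out * o_inv_scale) + zp, clamp, cast to uint8
    out = torch.round(out * o_inv_scale) + o_zp
    out = torch.clamp(out, 0, 127)
    return out.to(torch.uint8)
```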

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_unary
```
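
The lowering is hooked up through inductor's pattern matcher. As a minimal sketch of that API shape (the toy `round(x * scale)` pattern below is hypothetical and only mirrors the structure used in `quantization.py`; it is not meant to be installed), a `CallFunction`/`KeywordArg` pattern is passed to `register_lowering_pattern`, and the decorated function receives the captured arguments and emits the fused op:

```python
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg, Match
from torch._inductor.fx_passes.post_grad import register_lowering_pattern

aten = torch.ops.aten

# Hypothetical toy pattern: round(x * scale). The real pattern nests the
# full dequant/qconv/quant sequence instead.
_toy_pattern = CallFunction(
    aten.round.default,
    CallFunction(aten.mul.Tensor, KeywordArg("x"), KeywordArg("scale")),
)

@register_lowering_pattern(_toy_pattern, pass_number=2)
def _toy_lowering(match: Match, *args, **kwargs):
    # Placeholder body. In the real qconv lowering, kwargs carry the captured
    # KeywordArgs (x, x_scale, packed_weight, ...) and the body calls
    # L[torch.ops.onednn.qconv2d_pointwise](*computation_args).
    ...
```

In the actual change, `generate_pattern_with_output_quant(dequantize_qconv_pt2e_pattern)` produces the full pattern and `_register_quantized_conv_lowering` registers it at `pass_number=2`, lowering the match to `torch.ops.onednn.qconv2d_pointwise`.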

Pull Request resolved: #104588
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581
leslie-fang-intel authored and voznesenskym committed Aug 27, 2023
1 parent 052d311 commit 42e8de5
Showing 7 changed files with 250 additions and 167 deletions.
9 changes: 6 additions & 3 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -408,18 +408,21 @@ def forward(self, x):
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)
# Totally 3 pattern_matcher_count, 10 pattern_matcher_nodes

# Totally 4 pattern_matcher_count, 17 pattern_matcher_nodes
# 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-conv pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
# 4. Quantization fusion in post-grad fusion pass
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
self._test_common(
mod,
(v,),
3,
10,
4,
17,
check_quantization=True,
)

6 changes: 5 additions & 1 deletion torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -20,7 +20,10 @@
from ..virtualized import ops
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
from .quantization import _register_quantization_weight_pack_pass
from .quantization import (
_register_quantization_lowerings,
_register_quantization_weight_pack_pass,
)

if torch._C._has_mkldnn:
aten = torch.ops.aten
@@ -1067,6 +1070,7 @@ def _mkldnn_fusion_init():
_register_inplace_fusion()
_register_binary_unary_fusion()
_register_binary_fusion()
_register_quantization_lowerings()

@functools.lru_cache(None)
def _mkldnn_weight_pack_init():
4 changes: 0 additions & 4 deletions torch/_inductor/fx_passes/post_grad.py
@@ -84,10 +84,6 @@ def lazy_init():

_mkldnn_fusion_init()

from .quantization import register_quantization_lowerings

register_quantization_lowerings()


def reorder_for_locality(graph: torch.fx.Graph):
def visit(other_node):
201 changes: 113 additions & 88 deletions torch/_inductor/fx_passes/quantization.py
@@ -1,23 +1,22 @@
import functools

import torch
from ..ir import QConv
from ..lowering import lowerings as L
from ..pattern_matcher import Arg, CallFunction, KeywordArg, Match
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
dequantize_per_channel = quantized_decomposed.dequantize_per_channel.default

"""
dequantize activation:
x = x.to(fp32)
x = x - zero_point
x = x * scale
"""
dequantize_activation_pattern = CallFunction(
dequantize_per_tensor_activation_pattern = CallFunction(
aten.mul.Tensor,
CallFunction(
aten.sub.Tensor,
@@ -31,139 +30,163 @@
KeywordArg("x_scale"),
)

dequantize_weight_pattern = CallFunction(
dequantize_per_channel,
KeywordArg("w"),
dequantize_per_channel_weight_pattern = CallFunction(
quantized_decomposed.dequantize_per_channel.default,
KeywordArg("q_weight"),
KeywordArg("w_scale"),
KeywordArg("w_zp"),
KeywordArg("w_axis"), # axis for quantization
KeywordArg("w_qmin"), # quant clamp min
KeywordArg("w_qmax"), # quant clamp max
KeywordArg("qw_dtype"), # dtype=torch.int8
KeywordArg("w_axis"),
KeywordArg("w_quant_min"),
KeywordArg("w_quant_max"),
KeywordArg("w_dtype"),
)

dequantize_per_channel_clone_weight_pattern = CallFunction(
aten.clone.default,
dequantize_per_channel_weight_pattern,
memory_format=KeywordArg("memory_format"),
)

aten_conv_pattern = CallFunction(
aten.convolution.default,
dequantize_activation_pattern,
dequantize_weight_pattern,
dequantize_qconv_pt2e_pattern = CallFunction(
torch.ops.onednn.qconv2d_pointwise.default,
KeywordArg("x"),
KeywordArg("x_scale"), # x_scale
KeywordArg("x_zp"), # x_zp
KeywordArg("packed_weight"), # packed_weight
KeywordArg("w_scale"), # w_scale
KeywordArg("w_zp"), # w_zp
KeywordArg("b"), # bias
KeywordArg("stride"),
KeywordArg("padding"),
KeywordArg("dilation"),
KeywordArg("transposed"),
KeywordArg("o_padding"),
KeywordArg("groups"),
KeywordArg("inv_output_scale"), # inv_output_scale = 1.0
KeywordArg("output_zero_point"), # output_zero_point = 0
KeywordArg("fp32_output"), # fp32_output = True
KeywordArg("attr"), # attr = "none"
Arg(), # scalars
Arg(), # algorithm
)

"""
quantize output:
scale = 1 / scale
scale = 1.0 * scale
output = round(output * scale)
output = output + zero_point
output = clamp_min(output, 0)
output = clamp_max(output, 127)
output = output.to(uint8)
"""
quantize_conv_output_pattern = CallFunction(
prims.convert_element_type.default,
CallFunction(
aten.clamp_max.default,

def generate_pattern_with_output_quant(computation_call):
"""
quantize output:
output = round(output * o_inv_scale)
output = output + zero_point
output = clamp_min(output, 0)
output = clamp_max(output, 127)
output = output.to(uint8)
"""
quantize_conv_output_pattern_pt2e = CallFunction(
prims.convert_element_type.default,
CallFunction(
aten.clamp_min.default,
aten.clamp_max.default,
CallFunction(
aten.add.Tensor,
aten.clamp_min.default,
CallFunction(
aten.round.default,
aten.add.Tensor,
CallFunction(
aten.mul.Tensor,
aten_conv_pattern, # output of conv
aten.round.default,
CallFunction(
aten.mul.Tensor,
CallFunction(
aten.reciprocal.default, KeywordArg("o_scale")
),
Arg(), # 1.0
computation_call,
KeywordArg("o_inv_scale"),
),
),
KeywordArg("o_zp"),
),
KeywordArg("o_zp"),
KeywordArg("o_qmin"),
),
KeywordArg("o_qmin"), # 0
KeywordArg("o_qmax"),
),
KeywordArg("o_qmax"), # 127
),
KeywordArg("o_dtype"), # dtype=torch.uint8
)


dequant_per_channel_pattern = CallFunction(
quantized_decomposed.dequantize_per_channel.default, # dequant_per_channel node
KeywordArg("q_weight"),
KeywordArg("w_scale"),
KeywordArg("w_zp"),
KeywordArg("w_axis"),
KeywordArg("w_quant_min"),
KeywordArg("w_quant_max"),
KeywordArg("w_dtype"),
)


dequant_per_channel_clone_to_channel_last_pattern = CallFunction(
aten.clone.default,
dequant_per_channel_pattern,
memory_format=KeywordArg("memory_format"),
)
KeywordArg("o_dtype"),
)
return quantize_conv_output_pattern_pt2e


def _register_quantized_conv_lowering(pattern):
@register_lowering_pattern(pattern)
def _register_quantized_conv_lowering(
pattern,
pass_number,
computation_op,
fp32_output,
unary_attr,
):
@register_lowering_pattern(pattern, pass_number=pass_number)
def qconv(match: Match, *args, **kwargs):
x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
w, w_scale, w_zp, w_axis = (
kwargs["w"],
# Activation QParams
x, x_scale, x_zp = (
kwargs["x"],
kwargs["x_scale"],
kwargs["x_zp"],
)
# Weight QParams
packed_weight, w_scale, w_zp = (
kwargs["packed_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
kwargs["w_axis"],
)
b, stride, padding, dilation = (
# Conv Params
b, stride, padding, dilation, groups = (
kwargs["b"],
kwargs["stride"],
kwargs["padding"],
kwargs["dilation"],
)
groups, o_scale, o_zero_point, o_dtype = (
kwargs["groups"],
kwargs["o_scale"],
)
# Output QParams
o_inv_scale, o_zero_point = (
kwargs["o_inv_scale"],
kwargs["o_zp"],
kwargs["o_dtype"],
)
weight_shape = w.get_size()
dim = len(weight_shape) - 2
return QConv.create(
dim,
assert (
kwargs["fp32_output"] is True
) # Expected int8-in fp32-out qconv in weight prepack phase
assert (
kwargs["attr"] == "none"
) # Expected no post op fused in weight prepack phase
computation_args = (
x,
x_scale,
x_zp,
w,
packed_weight,
w_scale,
w_zp,
w_axis,
b,
stride,
padding,
dilation,
groups,
o_scale,
o_inv_scale,
o_zero_point,
o_dtype,
fp32_output,
unary_attr.op_name,
unary_attr.scalars_attr,
unary_attr.algorithm_attr,
)
return L[computation_op](*computation_args)

return qconv


def register_quantization_lowerings():
_register_quantized_conv_lowering(quantize_conv_output_pattern)
def _register_quantization_lowerings():
class UnaryAttr:
def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
self.op_name = op_name
self.scalars_attr = scalars_attr if scalars_attr else []
self.algorithm_attr = algorithm_attr if algorithm_attr else ""

# Register dq-conv2d-q pattern for ExternKernel Lowering
quantize_conv_output_pattern_pt2e = generate_pattern_with_output_quant(
dequantize_qconv_pt2e_pattern
)
_register_quantized_conv_lowering(
quantize_conv_output_pattern_pt2e,
2, # pass_number
torch.ops.onednn.qconv2d_pointwise, # computation_op
False, # fp32_output
UnaryAttr("none", [], ""), # unary_attr
)


def _is_valid_dequant_conv2d_pattern(match):
@@ -327,7 +350,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
def _generate_dequant_convolution_node_pattern(_dequant_per_channel_pattern):
dequant_convolution_node_pattern = CallFunction(
aten.convolution.default,
dequantize_activation_pattern,
dequantize_per_tensor_activation_pattern,
_dequant_per_channel_pattern,
KeywordArg("b"),
KeywordArg("stride"),
@@ -342,13 +365,15 @@ def _generate_dequant_convolution_node_pattern(_dequant_per_channel_pattern):

def _generate_qconv_weight_prepack_patterns():
return (
_generate_dequant_convolution_node_pattern(dequant_per_channel_pattern),
_generate_dequant_convolution_node_pattern(
dequantize_per_channel_weight_pattern
),
# There is another pattern due to the pass of convert_conv_weights_to_channels_last
# https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
# Depending on some heuristics, it may or may not insert a to(channel_last) node
# between the convolution and dequant_per_channel nodes
_generate_dequant_convolution_node_pattern(
dequant_per_channel_clone_to_channel_last_pattern
dequantize_per_channel_clone_weight_pattern
),
)

1 change: 1 addition & 0 deletions torch/_inductor/graph.py
@@ -788,6 +788,7 @@ def run_node(self, n: torch.fx.Node):
torch.ops.mkldnn._linear_pointwise.default,
torch.ops.mkldnn._linear_pointwise.binary,
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.onednn.qconv2d_pointwise.default,
]
if torch._C.has_mkl:
need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
