diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index ad59678767..520b5fbdfb 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -93,9 +93,6 @@ skipIfNoFloat8Support = unittest.skipIf( not torch_version_at_least("2.9.0"), "Float8 requires torch 2.9+" ) -skipIfNoQConvFp8Support = unittest.skipIf( - not torch_version_at_least("2.10.0.dev"), "QConv fp8 requires torch 2.10+" -) def get_default_quantizer(is_qat, is_dynamic): @@ -141,61 +138,6 @@ def forward(self, input): return out -class FP8QDQConv2d(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - super().__init__() - self.qtype = torch.float8_e4m3fn - self.weight = torch.randn( - (out_channels, in_channels // groups, *kernel_size) - ).to(self.qtype) - self.weight_scale = 2.0 - self.scale = 2.0 - self.bias = None - if bias: - self.bias = torch.randn((out_channels,)) - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - - def forward(self, input): - weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( - tensor=self.weight.data, - scale=torch.tensor([self.weight_scale]), - output_dtype=torch.float, - ) - q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( - tensor=input, - scale=torch.tensor([self.scale]), - float8_dtype=self.qtype, - ) - dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( - tensor=q_input, - scale=torch.tensor([self.scale]), - output_dtype=torch.float, - ) - - return torch.nn.functional.conv2d( - dq_input, - weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - def qdq(input, scale): dtype = input.dtype q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( @@ -229,7 +171,9 @@ def create_mod_info_recursion(parent): parent_child_mod_dict = generate_model_info(model) for name, mod in model.named_modules(): mod_type_str = mod.__class__.__name__ - if mod_type_str not in ["Linear", "Conv2d"]: + if mod_type_str not in [ + "Linear", + ]: continue param = mod.weight xmax = torch.max(param) @@ -246,20 +190,6 @@ def create_mod_info_recursion(parent): patched_mod.bias = mod.bias patched_mod.weight_scale = weight_scale.item() patched_mod.weight.data = q_param - elif mod_type_str in ["Conv2d"]: - patched_mod = FP8QDQConv2d( - mod.in_channels, - mod.out_channels, - mod.kernel_size, - mod.stride, - mod.padding, - mod.dilation, - mod.groups, - False, - ) - patched_mod.bias = mod.bias - patched_mod.weight_scale = weight_scale.item() - patched_mod.weight.data = q_param parent = parent_child_mod_dict[mod].parent name = parent_child_mod_dict[mod].name @@ -452,7 +382,7 @@ def _test_code_common( @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+") class TestPatternMatcher(TestPatternMatcherBase): - def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False): + def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): class M(torch.nn.Module): def __init__( self, @@ -478,14 +408,14 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] - # mixed_bf16: [dequant_node, optional(convert_element_type_4), + # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_nodes"], - 18 if mixed_bf16 else 12, + 18 if int8_mixed_bf16 else 12, ) self.assertEqual( counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 @@ -496,8 +426,7 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, - is_fp8=is_fp8, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, ) @skipIfNoDynamoSupport @@ -509,16 +438,6 @@ def test_qconv2d_cpu(self): """ self._qconv2d_test_helper("cpu") - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") - @skipIfNoQConvFp8Support - def test_qconv2d_fp8_cpu(self): - r""" - This testcase will quantize a single Conv2d module. - """ - self._qconv2d_test_helper("cpu", is_fp8=True) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -527,26 +446,14 @@ def test_qconv2d_int8_mixed_bf16(self): r""" This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. """ - self._qconv2d_test_helper(mixed_bf16=True) - - @skipIfNoDynamoSupport - @skipIfNoONEDNNBF16 - @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") - @skipIfNoQConvFp8Support - def test_qconv2d_fp8_mixed_bf16(self): - r""" - This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. - """ - self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True) + self._qconv2d_test_helper(int8_mixed_bf16=True) def _qconv2d_unary_test_helper( self, device="cpu", - mixed_bf16=False, + int8_mixed_bf16=False, unary_op=torch.nn.ReLU(), qconv_unary_matcher_nodes=None, - is_fp8=False, ): class M(torch.nn.Module): def __init__( @@ -595,9 +502,8 @@ def matcher_check_fn(): mod, (v,), check_quantization=True, - check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, matcher_check_fn=matcher_check_fn, - is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -608,15 +514,6 @@ def test_qconv2d_relu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu") - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_relu_fp8_cpu(self): - r""" - This testcase will quantize Conv2d->ReLU pattern. - """ - self._qconv2d_unary_test_helper(device="cpu", is_fp8=True) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -624,7 +521,7 @@ def test_qconv2d_relu_int8_mixed_bf16_xpu(self): r""" This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization. """ - self._qconv2d_unary_test_helper(mixed_bf16=True) + self._qconv2d_unary_test_helper(int8_mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -634,17 +531,6 @@ def test_qconv2d_relu6_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6()) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_relu6_fp8_cpu(self): - r""" - This testcase will quantize Conv2d->ReLU6 pattern. - """ - self._qconv2d_unary_test_helper( - device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True - ) - @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_hardtanh_cpu(self): @@ -653,17 +539,6 @@ def test_qconv2d_hardtanh_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh()) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_hardtanh_fp8_cpu(self): - r""" - This testcase will quantize Conv2d->Hardtanh pattern. - """ - self._qconv2d_unary_test_helper( - device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True - ) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -676,26 +551,8 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), - mixed_bf16=True, - qconv_unary_matcher_nodes=11, - ) - - @skipIfNoDynamoSupport - @skipIfNoONEDNNBF16 - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self): - r""" - This testcase will quantize Conv2d->Hardtanh pattern. - Match.nodes: - [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] - [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] - """ - self._qconv2d_unary_test_helper( - unary_op=torch.nn.Hardtanh(), - mixed_bf16=True, + int8_mixed_bf16=True, qconv_unary_matcher_nodes=11, - is_fp8=True, ) @skipIfNoDynamoSupport @@ -706,17 +563,6 @@ def test_qconv2d_hardswish_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish()) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_hardswish_fp8_cpu(self): - r""" - This testcase will quantize Conv2d->Hardswish pattern. - """ - self._qconv2d_unary_test_helper( - device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True - ) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -730,29 +576,10 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), - mixed_bf16=True, + int8_mixed_bf16=True, qconv_unary_matcher_nodes=17, ) - @skipIfNoDynamoSupport - @skipIfNoONEDNNBF16 - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self): - r""" - This testcase will quantize Conv2d->Hardswish pattern. - Match.nodes: - [qconv2d_pointwise_default, convert_element_type, add, clamp_min, - clamp_max, mul, div, convert_element_type, quantize_per_tensor] - [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] - """ - self._qconv2d_unary_test_helper( - unary_op=torch.nn.Hardswish(), - mixed_bf16=True, - qconv_unary_matcher_nodes=17, - is_fp8=True, - ) - @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_silu_cpu(self): @@ -761,17 +588,6 @@ def test_qconv2d_silu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU()) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_silu_fp8_cpu(self): - r""" - This testcase will quantize Conv2d->SiLU pattern. - """ - self._qconv2d_unary_test_helper( - device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True - ) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -785,31 +601,12 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.SiLU(), - mixed_bf16=True, - qconv_unary_matcher_nodes=11, - ) - - @skipIfNoDynamoSupport - @skipIfNoONEDNNBF16 - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_silu_fp8_mixed_bf16_cpu(self): - r""" - This testcase will quantize Conv2d->SiLU pattern. - Match.nodes: - [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, - convert_element_type, quantize_per_tensor] - [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] - """ - self._qconv2d_unary_test_helper( - unary_op=torch.nn.SiLU(), - mixed_bf16=True, + int8_mixed_bf16=True, qconv_unary_matcher_nodes=11, - is_fp8=True, ) def _qconv2d_add_test_helper( - self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False + self, device="cpu", use_relu=False, int8_mixed_bf16=False ): r""" This testcase will quantize a Conv2d->Add pattern as: @@ -883,12 +680,11 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, - is_fp8=is_fp8, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, ) def _qconv2d_add_test_helper2( - self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False + self, device="cpu", use_relu=False, int8_mixed_bf16=False ): r""" This testcase will quantize two Conv2d->Add patterns as: @@ -947,10 +743,9 @@ def forward(self, x, x2, x3): res = self.relu2(res) return res - add_fn_list = quantization_add_fn_list - if not is_fp8: - add_fn_list = add_fn_list + quantization_inplace_add_fn_list - for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]): + for add_fn, swap_inputs in itertools.product( + quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True] + ): mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device) x = torch.randn( (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device @@ -982,8 +777,7 @@ def matcher_check_fn(): (x, x2, x3), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, - is_fp8=is_fp8, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, ) @skipIfNoDynamoSupport @@ -992,27 +786,12 @@ def test_qconv2d_add_cpu(self): self._qconv2d_add_test_helper() self._qconv2d_add_test_helper2() - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_add_fp8_cpu(self): - self._qconv2d_add_test_helper(is_fp8=True) - self._qconv2d_add_test_helper2(is_fp8=True) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN def test_qconv2d_add_int8_mixed_bf16(self): - self._qconv2d_add_test_helper(mixed_bf16=True) - self._qconv2d_add_test_helper2(mixed_bf16=True) - - @skipIfNoDynamoSupport - @skipIfNoONEDNNBF16 - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_add_fp8_mixed_bf16(self): - self._qconv2d_add_test_helper(mixed_bf16=True, is_fp8=True) - self._qconv2d_add_test_helper2(mixed_bf16=True, is_fp8=True) + self._qconv2d_add_test_helper(int8_mixed_bf16=True) + self._qconv2d_add_test_helper2(int8_mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1020,27 +799,12 @@ def test_qconv2d_add_relu_cpu(self): self._qconv2d_add_test_helper(use_relu=True) self._qconv2d_add_test_helper2(use_relu=True) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_add_relu_fp8_cpu(self): - self._qconv2d_add_test_helper(use_relu=True, is_fp8=True) - self._qconv2d_add_test_helper2(use_relu=True, is_fp8=True) - @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN def test_qconv2d_add_relu_int8_mixed_bf16(self): - self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True) - self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True) - - @skipIfNoDynamoSupport - @skipIfNoONEDNNBF16 - @skipIfNoONEDNN - @skipIfNoQConvFp8Support - def test_qconv2d_add_relu_fp8_mixed_bf16(self): - self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True, is_fp8=True) - self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True, is_fp8=True) + self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True) + self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index c5280b9db0..a0aef11541 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -167,43 +167,29 @@ def get_dequantize_per_tensor_activation_pattern( KeywordArg("w_dtype"), ) -dequantize_fp8_weight_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - KeywordArg("q_weight"), - KeywordArg("w_scale"), - output_dtype=KeywordArg("w_dtype"), -) - - -def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern): - return _may_generate_pattern_with_dtype_convert( - dequant_wgt_pattern, +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, KeywordArg("autocast_wgt_dtype"), ) +) +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) -def get_dequantize_clone_weight_pattern(dequant_wgt_pattern): - return CallFunction( - aten.clone.default, - dequant_wgt_pattern, - memory_format=KeywordArg("memory_format"), - ) - - -def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern): - return get_dequantize_clone_weight_pattern( - get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern) - ) +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) -def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): - qconv_op = ( - torch.ops.onednn.qconv_pointwise.tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qconv_pointwise.default - ) +def get_qconv_pt2e_pattern(users=1): return CallFunction( - qconv_op, + torch.ops.onednn.qconv_pointwise.default, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -225,6 +211,35 @@ def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): ) +def get_qconv2d_binary_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.binary, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("accum"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("accum_scale"), + KeywordArg("accum_zero_point"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): qlinear_op = ( torch.ops.onednn.qlinear_pointwise.tensor @@ -446,7 +461,6 @@ def fn(match): return False binary_node_inputs = next(iter(compute_node.users)).args assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" - is_fp8 = match.kwargs["x"].meta["val"].dtype is torch.float8_e4m3fn if output_dtype in [torch.float32, torch.bfloat16]: extra_input_of_binary_node = None for arg in binary_node_inputs: @@ -455,18 +469,14 @@ def fn(match): break assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern - if ( - not is_fp8 - and extra_input_from_dequant - and ( - (not isinstance(extra_input_of_binary_node, torch.fx.Node)) - or ( - extra_input_of_binary_node.target - not in [ - quantized_decomposed.dequantize_per_tensor.default, - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - ] - ) + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + not in [ + quantized_decomposed.dequantize_per_tensor.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + ] ) ): return False @@ -701,9 +711,7 @@ def _inner(match): return _inner -def _register_qconv_weight_prepack_pass( - pattern, pass_number, dtype=torch.float32, is_fp8=False -): +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype), @@ -716,7 +724,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -739,7 +747,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) if dtype == torch.float32: - dequant = ( + dequant_per_channel = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] @@ -750,9 +758,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) - dequant = weight_to_bf16_node.args[0] # type: ignore[union-attr] + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert dequant.target in [ # type: ignore[union-attr] + assert dequant_per_channel.target in [ # type: ignore[union-attr] quantized_decomposed.dequantize_per_channel.default, torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] @@ -760,7 +768,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], - kwargs["x_zp"] if "x_zp" in kwargs else None, + kwargs["x_zp"], kwargs["x_scale"], ) @@ -768,7 +776,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], - kwargs["w_zp"] if "w_zp" in kwargs else None, + kwargs["w_zp"], ) # Conv Params @@ -784,25 +792,14 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. x_shape = None - if is_fp8: - # For float8, we assume the scales are from aten.full.default instead of - # a constant buffer to avoid constant folding of q/dq before fusion passes. - assert ( - w_scale.target is torch.ops.aten.full.default - and x_scale.target is torch.ops.aten.full.default - ) - with torch.utils._python_dispatch._disable_current_modes(): - w_scale_tensor = torch.tensor([w_scale.args[1]]) - match.graph.owning_module.register_buffer("w_scale", w_scale_tensor) - w_scale = match.graph.create_node("get_attr", "w_scale") graph = match.graph with graph.inserting_before(conv_node): # Insert weight prepack node and the QConv node packed_weight_inputs = ( qw, w_scale, - x_scale.args[1] if is_fp8 else x_scale, - 0, + x_scale, + x_zp, stride, padding, dilation, @@ -833,16 +830,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): [], # scalars "", # algorithm ) - Node = torch.fx.node.Node - # fp8 not need zp - if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8): - new_conv_node = graph.call_function( - torch.ops.onednn.qconv_pointwise.tensor, args=new_args - ) - else: - new_conv_node = graph.call_function( - torch.ops.onednn.qconv_pointwise.default, args=new_args - ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) @@ -857,7 +847,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): graph.erase_node(clone_node) # type: ignore[arg-type] if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] - graph.erase_node(dequant) # type: ignore[arg-type] + graph.erase_node(dequant_per_channel) # type: ignore[arg-type] counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes @@ -865,17 +855,17 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): def _generate_dequant_convolution_node_pattern( - _dequant_pattern, dtype=torch.float32, is_fp8=False + _dequant_per_channel_pattern, dtype=torch.float32 ): assert dtype in [torch.float32, torch.bfloat16] dequant_convolution_node_pattern = CallFunction( aten.convolution.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8), + get_dequantize_per_tensor_activation_pattern(), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), - _dequant_pattern, + _dequant_per_channel_pattern, KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), @@ -887,30 +877,24 @@ def _generate_dequant_convolution_node_pattern( return dequant_convolution_node_pattern -def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False): +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): assert dtype in [torch.float32, torch.bfloat16] - if is_fp8: - dequant_wgt_pattern = dequantize_fp8_weight_pattern - else: - dequant_wgt_pattern = dequantize_per_channel_weight_pattern return ( _generate_dequant_convolution_node_pattern( - dequant_wgt_pattern + dequantize_per_channel_weight_pattern if dtype == torch.float32 - else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern), + else dequantize_per_channel_to_bf16_weight_pattern, dtype, - is_fp8=is_fp8, ), # There is another pattern due to the pass of convert_conv_weights_to_channels_last # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. # Depend on some heuristics, it may or may not insert to(channel_last) node - # between convolution and dequant node + # between convolution and dequant_per_channel node _generate_dequant_convolution_node_pattern( - get_dequantize_clone_weight_pattern(dequant_wgt_pattern) + dequantize_per_channel_clone_weight_pattern if dtype == torch.float32 - else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern), + else dequantize_per_channel_to_bf16_clone_weight_pattern, dtype, - is_fp8=is_fp8, ), ) @@ -1318,7 +1302,12 @@ def _generate_qlinear_weight_prepack_patterns( is_fp8=False, ): if is_fp8: - dequant_wgt_pattern = dequantize_fp8_weight_pattern + dequant_wgt_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + output_dtype=KeywordArg("w_dtype"), + ) else: dequant_wgt_pattern = dequantize_per_channel_weight_pattern if input_dim_exceeds_two and not input_contiguous: @@ -1460,16 +1449,12 @@ def _register_dequant_promotion(): def _register_qconv_weight_prepack(): - for dtype, is_fp8 in itertools.product( - [torch.float32, torch.bfloat16], [True, False] - ): - weight_prepack_patterns = _generate_qconv_weight_prepack_patterns( - dtype, is_fp8=is_fp8 - ) + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. _register_qconv_weight_prepack_pass( - weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8 + weight_prepack_pattern, pass_number=1, dtype=dtype ) @@ -2068,25 +2053,13 @@ def qconv(match: Match, *args, **kwargs): kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) - assert output_dtype in [ - torch.int8, - torch.uint8, - torch.float8_e4m3fn, - torch.float32, - torch.bfloat16, - ] + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] # Output QParams - if output_dtype == torch.float8_e4m3fn: - # For float8, we assume the scale is from aten.full.default instead of - # a constant buffer to avoid constant folding of q/dq before fusion passes. - assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default - o_inv_scale = kwargs["o_inv_scale"].args[1] - else: - o_inv_scale = ( - kwargs["o_inv_scale"] - if (output_dtype == torch.uint8 or output_dtype == torch.int8) - else 1.0 - ) + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) o_zero_point = ( kwargs["o_zp"] if (output_dtype == torch.uint8 or output_dtype == torch.int8) @@ -2192,69 +2165,56 @@ def _register_qconv_unary_fusion(): _silu_fusion, ) - combinations = itertools.product( - [torch.float32, torch.bfloat16], [False, True], [False, True] - ) - for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations: + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: # Priority 1 to match: QConv2d Unary pattern with int8 output # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant is_bf16 = original_pattern_output_dtype == torch.bfloat16 - computation_op = ( - torch.ops.onednn.qconv_pointwise.tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qconv_pointwise.default - ) conv_unary_replace_patterns = { PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), - is_fp8=is_fp8, + get_qconv_pt2e_pattern(1), ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default + get_qconv_pt2e_pattern(1), aten.relu.default ), - is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardtanh", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + get_qconv_pt2e_pattern(1), 1, is_bf16, ), with_dtype_convert=is_bf16, - is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardswish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, - is_fp8=is_fp8, ), PostOpAttr( "none", None, "swish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, - is_fp8=is_fp8, ), } @@ -2263,21 +2223,21 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - computation_op, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default + get_qconv_pt2e_pattern(1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + get_qconv_pt2e_pattern(1), 1, is_bf16, ), @@ -2289,7 +2249,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2301,7 +2261,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2315,26 +2275,17 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - computation_op, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) def _register_qconv_binary_fusion(): - for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product( - [False, True], [False, True] - ): - qconv_binary_op = ( - torch.ops.onednn.qconv2d_pointwise.binary_tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qconv2d_pointwise.binary - ) + for int8_mixed_bf16_with_inplace_add in [False, True]: # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output swap_binary_inputs_list = [False, True] binary_replace_patterns = {} - for swap_inputs, is_fp8 in itertools.product( - swap_binary_inputs_list, [False, True] - ): + for swap_inputs in swap_binary_inputs_list: binary_replace_patterns.update( { PostOpAttr( @@ -2342,12 +2293,11 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), - is_fp8=is_fp8, ), PostOpAttr( "sum", 1.0, "relu", [], "" @@ -2355,14 +2305,13 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), aten.relu.default, ), - is_fp8=is_fp8, ), } ) @@ -2371,7 +2320,7 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - qconv_binary_op, # computation_op + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) @@ -2383,7 +2332,7 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2401,14 +2350,14 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - qconv_binary_op, # computation_op + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) else: _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - qconv_binary_op, # computation_op + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) @@ -2421,7 +2370,7 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2436,7 +2385,7 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number - qconv_binary_op, # computation_op + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) @@ -2478,8 +2427,8 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs): # Output QParams if output_dtype == torch.float8_e4m3fn: - # For float8, we assume the scale is from aten.full.default instead of - # a constant buffer to avoid constant folding of q/dq before fusion passes. + # For float8, torchao.quantize_affine_float8 requires tensor as scale + # Support scale node is full firstly assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default o_inv_scale = kwargs["o_inv_scale"].args[1] else: