diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index e570451523..a52032dc47 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -139,10 +139,22 @@ def forward(self, input):
 
 
 class FP8QDQConv2d(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
         super().__init__()
         self.qtype = torch.float8_e4m3fn
-        self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
+        self.weight = torch.randn(
+            (out_channels, in_channels // groups, *kernel_size)
+        ).to(self.qtype)
         self.weight_scale = 2.0
         self.scale = 2.0
         self.bias = None
@@ -170,7 +182,16 @@ def forward(self, input):
             output_dtype=torch.float,
         )
 
-        return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return torch.nn.functional.conv2d(
+            dq_input,
+            weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
 
 def qdq(input, scale):
     dtype = input.dtype
@@ -205,9 +226,7 @@ def create_mod_info_recursion(parent):
     parent_child_mod_dict = generate_model_info(model)
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
-        if mod_type_str not in [
-            "Linear", "Conv2d"
-        ]:
+        if mod_type_str not in ["Linear", "Conv2d"]:
             continue
         param = mod.weight
         xmax = torch.max(param)
@@ -225,7 +244,16 @@ def create_mod_info_recursion(parent):
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
         elif mod_type_str in ["Conv2d"]:
-            patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
+            patched_mod = FP8QDQConv2d(
+                mod.in_channels,
+                mod.out_channels,
+                mod.kernel_size,
+                mod.stride,
+                mod.padding,
+                mod.dilation,
+                mod.groups,
+                False,
+            )
             patched_mod.bias = mod.bias
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
@@ -610,7 +638,9 @@ def test_qconv2d_relu6_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU6 pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -627,7 +657,9 @@ def test_qconv2d_hardtanh_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardtanh pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -678,7 +710,9 @@ def test_qconv2d_hardswish_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardswish pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -731,7 +765,9 @@ def test_qconv2d_silu_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->SiLU pattern.
         """
-        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
+        )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -911,9 +947,7 @@ def forward(self, x, x2, x3):
         add_fn_list = quantization_add_fn_list
         if not is_fp8:
             add_fn_list = add_fn_list + quantization_inplace_add_fn_list
-        for add_fn, swap_inputs in itertools.product(
-            add_fn_list, [False, True]
-        ):
+        for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
                 (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index 419561ba92..c5280b9db0 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -174,12 +174,14 @@ def get_dequantize_per_tensor_activation_pattern(
         output_dtype=KeywordArg("w_dtype"),
     )
 
+
 def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
     return _may_generate_pattern_with_dtype_convert(
         dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
 
+
 def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
         aten.clone.default,
@@ -187,8 +189,11 @@ def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
         memory_format=KeywordArg("memory_format"),
     )
 
+
 def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
-    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
+    return get_dequantize_clone_weight_pattern(
+        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
+    )
 
 
 def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
@@ -450,14 +455,18 @@ def fn(match):
                 break
         assert extra_input_of_binary_node is not None
         # Extra input of binary node comes from dequant pattern
-        if not is_fp8 and extra_input_from_dequant and (
-            (not isinstance(extra_input_of_binary_node, torch.fx.Node))
-            or (
-                extra_input_of_binary_node.target
-                not in [
-                    quantized_decomposed.dequantize_per_tensor.default,
-                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-                ]
+        if (
+            not is_fp8
+            and extra_input_from_dequant
+            and (
+                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+                or (
+                    extra_input_of_binary_node.target
+                    not in [
+                        quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                    ]
+                )
             )
         ):
             return False
@@ -692,7 +701,9 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
+def _register_qconv_weight_prepack_pass(
+    pattern, pass_number, dtype=torch.float32, is_fp8=False
+):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -776,7 +787,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if is_fp8:
             # For float8, we assume the scales are from aten.full.default instead of
             # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
+            assert (
+                w_scale.target is torch.ops.aten.full.default
+                and x_scale.target is torch.ops.aten.full.default
+            )
             with torch.utils._python_dispatch._disable_current_modes():
                 w_scale_tensor = torch.tensor([w_scale.args[1]])
                 match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
@@ -1446,8 +1460,12 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
+    for dtype, is_fp8 in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, is_fp8=is_fp8
+        )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
@@ -2050,7 +2068,13 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
+        assert output_dtype in [
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+            torch.float32,
+            torch.bfloat16,
+        ]
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
             # For float8, we assume the scale is from aten.full.default instead of
@@ -2297,7 +2321,9 @@ def _register_qconv_unary_fusion():
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
+        [False, True], [False, True]
+    ):
         qconv_binary_op = (
             torch.ops.onednn.qconv2d_pointwise.binary_tensor
             if x_scale_zp_are_tensors
@@ -2306,7 +2332,9 @@ def _register_qconv_binary_fusion():
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
+        for swap_inputs, is_fp8 in itertools.product(
+            swap_binary_inputs_list, [False, True]
+        ):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(