From f1f35a5ed08b3968ed9a32d9173fab577d9b64ac Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 2 Dec 2025 10:22:37 -0800
Subject: [PATCH] Revert "Fix style after https://github.com/pytorch/ao/pull/3261 (#3397)"

This reverts commit 316ef03cc0b9846632e77b5210b58c73d6d0b084.
---
 .../pt2e/test_x86inductor_fusion.py           | 62 +++++--------------
 .../quantization/pt2e/inductor_passes/x86.py  | 60 +++++-------------
 2 files changed, 30 insertions(+), 92 deletions(-)

diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index a52032dc47..e570451523 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -139,22 +139,10 @@ def forward(self, input):


 class FP8QDQConv2d(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups=1,
-        bias=True,
-    ):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
         super().__init__()
         self.qtype = torch.float8_e4m3fn
-        self.weight = torch.randn(
-            (out_channels, in_channels // groups, *kernel_size)
-        ).to(self.qtype)
+        self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
         self.weight_scale = 2.0
         self.scale = 2.0
         self.bias = None
@@ -182,16 +170,7 @@ def forward(self, input):
             output_dtype=torch.float,
         )

-        return torch.nn.functional.conv2d(
-            dq_input,
-            weight,
-            self.bias,
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
-
+        return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

 def qdq(input, scale):
     dtype = input.dtype
@@ -226,7 +205,9 @@ def create_mod_info_recursion(parent):
     parent_child_mod_dict = generate_model_info(model)
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
-        if mod_type_str not in ["Linear", "Conv2d"]:
+        if mod_type_str not in [
+            "Linear", "Conv2d"
+        ]:
             continue
         param = mod.weight
         xmax = torch.max(param)
@@ -244,16 +225,7 @@ def create_mod_info_recursion(parent):
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
         elif mod_type_str in ["Conv2d"]:
-            patched_mod = FP8QDQConv2d(
-                mod.in_channels,
-                mod.out_channels,
-                mod.kernel_size,
-                mod.stride,
-                mod.padding,
-                mod.dilation,
-                mod.groups,
-                False,
-            )
+            patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
             patched_mod.bias = mod.bias
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
@@ -638,9 +610,7 @@ def test_qconv2d_relu6_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU6 pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -657,9 +627,7 @@ def test_qconv2d_hardtanh_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardtanh pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -710,9 +678,7 @@ def test_qconv2d_hardswish_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->Hardswish pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -765,9 +731,7 @@ def test_qconv2d_silu_fp8_cpu(self):
         r"""
         This testcase will quantize Conv2d->SiLU pattern.
         """
-        self._qconv2d_unary_test_helper(
-            device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
-        )
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
@@ -947,7 +911,9 @@ def forward(self, x, x2, x3):
         add_fn_list = quantization_add_fn_list
         if not is_fp8:
             add_fn_list = add_fn_list + quantization_inplace_add_fn_list
-        for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]):
+        for add_fn, swap_inputs in itertools.product(
+            add_fn_list, [False, True]
+        ):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
                 (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index c5280b9db0..419561ba92 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -174,14 +174,12 @@ def get_dequantize_per_tensor_activation_pattern(
         output_dtype=KeywordArg("w_dtype"),
     )

-
 def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
     return _may_generate_pattern_with_dtype_convert(
         dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )

-
 def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
         aten.clone.default,
@@ -189,11 +187,8 @@ def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
         memory_format=KeywordArg("memory_format"),
     )

-
 def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
-    return get_dequantize_clone_weight_pattern(
-        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
-    )
+    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))


 def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
@@ -455,18 +450,14 @@ def fn(match):
                 break
         assert extra_input_of_binary_node is not None
         # Extra input of binary node comes from dequant pattern
-        if (
-            not is_fp8
-            and extra_input_from_dequant
-            and (
-                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
-                or (
-                    extra_input_of_binary_node.target
-                    not in [
-                        quantized_decomposed.dequantize_per_tensor.default,
-                        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-                    ]
-                )
+        if not is_fp8 and extra_input_from_dequant and (
+            (not isinstance(extra_input_of_binary_node, torch.fx.Node))
+            or (
+                extra_input_of_binary_node.target
+                not in [
+                    quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+                ]
             )
         ):
             return False
@@ -701,9 +692,7 @@ def _inner(match):
     return _inner


-def _register_qconv_weight_prepack_pass(
-    pattern, pass_number, dtype=torch.float32, is_fp8=False
-):
+def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -787,10 +776,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if is_fp8:
             # For float8, we assume the scales are from aten.full.default instead of
             # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert (
-                w_scale.target is torch.ops.aten.full.default
-                and x_scale.target is torch.ops.aten.full.default
-            )
+            assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
             with torch.utils._python_dispatch._disable_current_modes():
                 w_scale_tensor = torch.tensor([w_scale.args[1]])
                 match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
@@ -1460,12 +1446,8 @@ def _register_dequant_promotion():


 def _register_qconv_weight_prepack():
-    for dtype, is_fp8 in itertools.product(
-        [torch.float32, torch.bfloat16], [True, False]
-    ):
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
-            dtype, is_fp8=is_fp8
-        )
+    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
@@ -2068,13 +2050,7 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [
-            torch.int8,
-            torch.uint8,
-            torch.float8_e4m3fn,
-            torch.float32,
-            torch.bfloat16,
-        ]
+        assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
             # For float8, we assume the scale is from aten.full.default instead of
@@ -2321,9 +2297,7 @@ def _register_qconv_unary_fusion():


 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
-        [False, True], [False, True]
-    ):
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
         qconv_binary_op = (
             torch.ops.onednn.qconv2d_pointwise.binary_tensor
             if x_scale_zp_are_tensors
@@ -2332,9 +2306,7 @@ def _register_qconv_binary_fusion():
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs, is_fp8 in itertools.product(
-            swap_binary_inputs_list, [False, True]
-        ):
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(