From c18f1a7c2f50e9a849427bc30d8d547c43d53e4a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 22 Sep 2023 13:18:33 -0700 Subject: [PATCH] [quant][pt2] Support cudnn_batch_norm in QAT fusion Summary: Today, we get different batch norm ops depending on the device the model is placed on at export time. Exporting `model.cpu()` gives `_native_batch_norm_legit`, while exporting `model.cuda()` gives `cudnn_batch_norm`. QAT fusion currently only supports the former and silently ignores the latter. This commit fixes this by additionally matching on the latter op during QAT fusion. Test Plan: python test/test_quantization.py TestQuantizePT2EQAT.test_qat_conv_bn_fusion python test/test_quantization.py TestQuantizePT2EQAT.test_qat_conv_bn_relu_fusion Reviewers: jerryzh168, kimishpatel Subscribers: jerryzh168, kimishpatel, supriyar ghstack-source-id: b2903e80497bc9696de8a93d3177d153f00ecfac Pull Request resolved: https://github.com/pytorch/pytorch/pull/109908 --- .../pt2e/test_quantize_pt2e_qat.py | 34 ++++++++++++++++--- torch/ao/quantization/pt2e/qat_utils.py | 30 ++++++++++++---- torch/ao/quantization/pt2e/utils.py | 18 +++++++++- 3 files changed, 71 insertions(+), 11 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 42d6fe205a515..b68c0885d6185 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: quantization"] import copy import operator +import unittest from typing import Any, Optional, Tuple import torch @@ -26,6 +27,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import ( QuantizationTestCase, skip_if_no_torchvision, @@ -122,6 +124,7 @@ def _verify_symmetric_xnnpack_qat_graph( example_inputs: Tuple[Any, ...], has_relu: bool, has_bias: bool = True, + is_cuda: bool = False, expected_conv_literal_args: Optional[Tuple[Any, ...]] = None, ): self._verify_symmetric_xnnpack_qat_graph_helper( @@ -130,6 +133,7 @@ def _verify_symmetric_xnnpack_qat_graph( is_per_channel=True, has_relu=has_relu, has_bias=has_bias, + is_cuda=is_cuda, expected_conv_literal_args=expected_conv_literal_args, ) self._verify_symmetric_xnnpack_qat_graph_helper( @@ -138,6 +142,7 @@ def _verify_symmetric_xnnpack_qat_graph( is_per_channel=False, has_relu=has_relu, has_bias=has_bias, + is_cuda=is_cuda, expected_conv_literal_args=expected_conv_literal_args, ) @@ -148,6 +153,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper( is_per_channel: bool, has_relu: bool, has_bias: bool = True, + is_cuda: bool = False, expected_conv_literal_args: Optional[Tuple[Any, ...]] = None, ): """ @@ -189,10 +195,12 @@ def _verify_symmetric_xnnpack_qat_graph_helper( relu_node = None getitem_node = output_fq_node.args[0] bn_node = getitem_node.args[0] + if is_cuda: + expected_bn_op = torch.ops.aten.cudnn_batch_norm.default + else: + expected_bn_op = torch.ops.aten._native_batch_norm_legit.default self.assertEqual(getitem_node.target, operator.getitem) - self.assertEqual( - bn_node.target, torch.ops.aten._native_batch_norm_legit.default - ) + self.assertEqual(bn_node.target, expected_bn_op) # Verify: conv / scale_factor.reshape [+ bias.reshape] if has_bias: @@ -303,11 +311,20 @@ def forward(self, x): self._verify_symmetric_xnnpack_qat_numerics(M(has_relu=True), example_inputs) def test_qat_conv_bn_fusion(self): - example_inputs = (torch.randn(1, 3, 5, 
5),) m = TestHelperModules.ConvWithBNRelu(relu=False) + example_inputs = (torch.randn(1, 3, 5, 5),) self._verify_symmetric_xnnpack_qat_graph(m, example_inputs, has_relu=False) self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_qat_conv_bn_fusion_cuda(self): + m = TestHelperModules.ConvWithBNRelu(relu=False).cuda() + example_inputs = (torch.randn(1, 3, 5, 5).cuda(),) + self._verify_symmetric_xnnpack_qat_graph( + m, example_inputs, has_relu=False, is_cuda=True, + ) + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + def test_qat_conv_bn_fusion_literal_args(self): class M(torch.nn.Module): def __init__(self): @@ -368,6 +385,15 @@ def test_qat_conv_bn_relu_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, example_inputs, has_relu=True) self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_qat_conv_bn_relu_fusion_cuda(self): + m = TestHelperModules.ConvWithBNRelu(relu=True).cuda() + example_inputs = (torch.randn(1, 3, 5, 5).cuda(),) + self._verify_symmetric_xnnpack_qat_graph( + m, example_inputs, has_relu=True, is_cuda=True, + ) + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + def test_qat_conv_bn_relu_fusion_no_conv_bias(self): m = TestHelperModules.ConvWithBNRelu(relu=True, bias=False) example_inputs = (torch.randn(3, 3, 5, 5),) diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index e463fec166aa9..25cd5ff572bc6 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -15,6 +15,7 @@ QuantizationSpecBase, ) from .utils import ( + _is_supported_batch_norm_for_training, fold_bn_weights_into_conv_node, get_aten_graph_module, ) @@ -43,6 +44,7 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs( is_per_channel: bool, has_bias: bool, + is_cuda: bool, ) -> Dict[str, Any]: """ Optional example inputs for both `_quantized_qat_conv2d_bn_pattern` @@ -59,6 +61,10 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs( kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int) if has_bias: kwargs["conv_bias"] = torch.randn(1) + if is_cuda: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.cuda() return kwargs def _conv2d_bn_pattern( @@ -367,7 +373,7 @@ def _get_conv_bn_getitem_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]: if n.target == torch.ops.aten.conv2d.default: assert conv_node is None conv_node = n - elif n.target == torch.ops.aten._native_batch_norm_legit.default: + elif _is_supported_batch_norm_for_training(n): assert bn_node is None bn_node = n elif n.target == operator.getitem: @@ -490,6 +496,11 @@ def _get_new_qspec(qspec: QuantizationSpecBase): annotation.output_qspec = _get_new_qspec(annotation.output_qspec) def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + m = _fuse_conv_bn_qat_helper(m, is_cuda=False) + m = _fuse_conv_bn_qat_helper(m, is_cuda=True) + return m + +def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule: """ Given a graph of decomposed aten ops, replace the (conv + bn) pattern with the fused QAT subgraph equivalent. The input graph should already be annotated. 
@@ -501,7 +512,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     m.graph.eliminate_dead_code()
     m.recompile()
     example_inputs = _conv2d_bn_pattern_example_inputs
-    match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
+    match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs, is_cuda)
 
     # Step (1): Replace patterns with conv bias
     #
@@ -512,6 +523,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     replacement_pattern_with_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern,
         example_inputs,
+        is_cuda,
     )
     replacements_with_conv_bias = replace_pattern_with_filters(
         m,
@@ -527,6 +539,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     replacement_pattern_no_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern_no_conv_bias,
         example_inputs,
+        is_cuda,
     )
     replacements_no_conv_bias = replace_pattern_with_filters(
         m,
@@ -569,7 +582,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
                 _copy_over_literal_conv_args(original_node, replacement_conv_node)
                 # Step (3c): Update old references in the conv node's input_qspec_map
                 _update_conv_input_qspec_map_after_replacement(original_node, replacement_conv_node)
-            if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
+            if _is_supported_batch_norm_for_training(original_node):
                 replacement_bn_node.meta = original_node.meta
                 original_to_replacement_node[original_node] = replacement_bn_node
             if original_node.target == operator.getitem:
@@ -632,6 +645,11 @@ def _remove_extra_dequantize(m: GraphModule):
     m.recompile()
 
 def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
+    m = _fold_conv_bn_qat_helper(m, is_cuda=False)
+    m = _fold_conv_bn_qat_helper(m, is_cuda=True)
+    return m
+
+def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
     """
     Replace the quantized (conv + bn) pattern with conv with bn weights folded
     into the weights of conv.
     """
@@ -653,15 +671,15 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         if not has_relu and relu_is_inplace:
             continue
         example_inputs = _quantized_conv2d_bn_pattern_example_inputs
-        kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias)
+        kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
         match_pattern = _get_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
+        match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
+        replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
         replacements.extend(
             replace_pattern_with_filters(
                 m,
diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
index a228629a657e2..e900b124a6532 100644
--- a/torch/ao/quantization/pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -124,6 +124,19 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
             all_args.append(schema.default_value)
     return all_args
 
+def _is_supported_batch_norm_for_training(node: Node):
+    """
+    Return True if the given node refers to an aten batch norm op QAT supports.
+    """
+    supported_ops = [
+        torch.ops.aten._native_batch_norm_legit.default,
+        # Note: we won't need this op anymore after batch norm consolidation
+        # For now, we need to continue to support it because it gives better
+        # training numerics than `_native_batch_norm_legit`
+        torch.ops.aten.cudnn_batch_norm.default,
+    ]
+    return node.target in supported_ops
+
 def fold_bn_weights_into_conv_node(
     conv_node: Node,
     conv_weight_node: Node,
@@ -148,7 +161,7 @@ def fold_bn_weights_into_conv_node(
     bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
     if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
         eps_arg_index = 6
-    elif bn_node.target == torch.ops.aten._native_batch_norm_legit.default:
+    elif _is_supported_batch_norm_for_training(bn_node):
         eps_arg_index = 7
     else:
         raise ValueError("BN node target is unexpected ", bn_node.target)
@@ -229,11 +242,14 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
 def get_aten_graph_module(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
+    is_cuda: bool = False,
     **kwargs,
 ) -> GraphModule:
     """
     Convert the pattern to an FX graph with decomposed aten ops.
     """
+    if is_cuda:
+        example_inputs = tuple([x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs])
     aten_pattern = capture_pre_autograd_graph(
         pattern,
         example_inputs,
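
Note (not part of the patch): a minimal sketch of the device-dependent behavior the summary describes. It captures the same conv + bn module on CPU and on CUDA and prints which aten batch norm op shows up in the exported graph; per the summary, CPU is expected to surface `_native_batch_norm_legit` and CUDA `cudnn_batch_norm`, which is why QAT fusion now matches both. The `ConvBN` module and `batch_norm_ops` helper are made up for illustration, and the `torch._export` import path is an assumption based on builds from this era.

# Illustrative only: inspect which aten batch norm op is captured per device.
import torch
from torch._export import capture_pre_autograd_graph  # assumed import path

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

def batch_norm_ops(model, example_inputs):
    # Capture the pre-autograd aten graph and collect every batch norm target.
    gm = capture_pre_autograd_graph(model, example_inputs)
    return {
        node.target
        for node in gm.graph.nodes
        if node.op == "call_function" and "batch_norm" in str(node.target)
    }

# Expected per the summary: _native_batch_norm_legit on CPU ...
print(batch_norm_ops(ConvBN(), (torch.randn(1, 3, 5, 5),)))

# ... and cudnn_batch_norm on CUDA, which QAT fusion now also matches.
if torch.cuda.is_available():
    print(batch_norm_ops(ConvBN().cuda(), (torch.randn(1, 3, 5, 5).cuda(),)))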