[quant][pt2] Support cudnn_batch_norm in QAT fusion
Summary: Today, we get different batch norm ops depending on
the device the model is placed on at export time. Exporting
`model.cpu()` gives `_native_batch_norm_legit`, while exporting
`model.cuda()` gives `cudnn_batch_norm`. QAT fusion currently
only supports the former and silently ignores the latter. This
commit fixes the issue by additionally matching `cudnn_batch_norm`
during QAT fusion.
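
For reference, a minimal sketch (not part of this commit) of the device-dependent export behavior described above. The `capture_pre_autograd_graph` import path, the `ConvBN` module, and the `batch_norm_targets` helper are assumptions for illustration only:

import torch
from torch._export import capture_pre_autograd_graph

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

def batch_norm_targets(model, example_inputs):
    # Export the model and collect the aten batch norm ops left in the graph.
    gm = capture_pre_autograd_graph(model, example_inputs)
    return [n.target for n in gm.graph.nodes if "batch_norm" in str(n.target)]

example_inputs = (torch.randn(1, 3, 5, 5),)
# CPU export decomposes batch norm to aten._native_batch_norm_legit.
print(batch_norm_targets(ConvBN().cpu().train(), example_inputs))
# CUDA export produces aten.cudnn_batch_norm instead, which QAT fusion
# previously failed to match.
if torch.cuda.is_available():
    cuda_inputs = (example_inputs[0].cuda(),)
    print(batch_norm_targets(ConvBN().cuda().train(), cuda_inputs))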

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT.test_qat_conv_bn_fusion
python test/test_quantization.py TestQuantizePT2EQAT.test_qat_conv_bn_relu_fusion

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel, supriyar

ghstack-source-id: b2903e80497bc9696de8a93d3177d153f00ecfac
Pull Request resolved: #109908
andrewor14 committed Sep 22, 2023
1 parent 629a628 commit c18f1a7
Showing 3 changed files with 71 additions and 11 deletions.
34 changes: 30 additions & 4 deletions test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: quantization"]
import copy
import operator
import unittest
from typing import Any, Optional, Tuple

import torch
@@ -26,6 +27,7 @@
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skip_if_no_torchvision,
@@ -122,6 +124,7 @@ def _verify_symmetric_xnnpack_qat_graph(
example_inputs: Tuple[Any, ...],
has_relu: bool,
has_bias: bool = True,
is_cuda: bool = False,
expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
):
self._verify_symmetric_xnnpack_qat_graph_helper(
@@ -130,6 +133,7 @@
is_per_channel=True,
has_relu=has_relu,
has_bias=has_bias,
is_cuda=is_cuda,
expected_conv_literal_args=expected_conv_literal_args,
)
self._verify_symmetric_xnnpack_qat_graph_helper(
@@ -138,6 +142,7 @@
is_per_channel=False,
has_relu=has_relu,
has_bias=has_bias,
is_cuda=is_cuda,
expected_conv_literal_args=expected_conv_literal_args,
)

@@ -148,6 +153,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper(
is_per_channel: bool,
has_relu: bool,
has_bias: bool = True,
is_cuda: bool = False,
expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
):
"""
@@ -189,10 +195,12 @@ def _verify_symmetric_xnnpack_qat_graph_helper(
relu_node = None
getitem_node = output_fq_node.args[0]
bn_node = getitem_node.args[0]
if is_cuda:
expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
else:
expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
self.assertEqual(getitem_node.target, operator.getitem)
self.assertEqual(
bn_node.target, torch.ops.aten._native_batch_norm_legit.default
)
self.assertEqual(bn_node.target, expected_bn_op)

# Verify: conv / scale_factor.reshape [+ bias.reshape]
if has_bias:
@@ -303,11 +311,20 @@ def forward(self, x):
self._verify_symmetric_xnnpack_qat_numerics(M(has_relu=True), example_inputs)

def test_qat_conv_bn_fusion(self):
example_inputs = (torch.randn(1, 3, 5, 5),)
m = TestHelperModules.ConvWithBNRelu(relu=False)
example_inputs = (torch.randn(1, 3, 5, 5),)
self._verify_symmetric_xnnpack_qat_graph(m, example_inputs, has_relu=False)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_qat_conv_bn_fusion_cuda(self):
m = TestHelperModules.ConvWithBNRelu(relu=False).cuda()
example_inputs = (torch.randn(1, 3, 5, 5).cuda(),)
self._verify_symmetric_xnnpack_qat_graph(
m, example_inputs, has_relu=False, is_cuda=True,
)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

def test_qat_conv_bn_fusion_literal_args(self):
class M(torch.nn.Module):
def __init__(self):
@@ -368,6 +385,15 @@ def test_qat_conv_bn_relu_fusion(self):
self._verify_symmetric_xnnpack_qat_graph(m, example_inputs, has_relu=True)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_qat_conv_bn_relu_fusion_cuda(self):
m = TestHelperModules.ConvWithBNRelu(relu=True).cuda()
example_inputs = (torch.randn(1, 3, 5, 5).cuda(),)
self._verify_symmetric_xnnpack_qat_graph(
m, example_inputs, has_relu=True, is_cuda=True,
)
self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
m = TestHelperModules.ConvWithBNRelu(relu=True, bias=False)
example_inputs = (torch.randn(3, 3, 5, 5),)
30 changes: 24 additions & 6 deletions torch/ao/quantization/pt2e/qat_utils.py
@@ -15,6 +15,7 @@
QuantizationSpecBase,
)
from .utils import (
_is_supported_batch_norm_for_training,
fold_bn_weights_into_conv_node,
get_aten_graph_module,
)
@@ -43,6 +44,7 @@
def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
is_cuda: bool,
) -> Dict[str, Any]:
"""
Optional example inputs for both `_quantized_qat_conv2d_bn_pattern`
@@ -59,6 +61,10 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
if is_cuda:
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
kwargs[k] = v.cuda()
return kwargs

def _conv2d_bn_pattern(
@@ -367,7 +373,7 @@ def _get_conv_bn_getitem_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
if n.target == torch.ops.aten.conv2d.default:
assert conv_node is None
conv_node = n
elif n.target == torch.ops.aten._native_batch_norm_legit.default:
elif _is_supported_batch_norm_for_training(n):
assert bn_node is None
bn_node = n
elif n.target == operator.getitem:
@@ -490,6 +496,11 @@ def _get_new_qspec(qspec: QuantizationSpecBase):
annotation.output_qspec = _get_new_qspec(annotation.output_qspec)

def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
m = _fuse_conv_bn_qat_helper(m, is_cuda=False)
m = _fuse_conv_bn_qat_helper(m, is_cuda=True)
return m

def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
"""
Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
the fused QAT subgraph equivalent. The input graph should already be annotated.
@@ -501,7 +512,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
m.graph.eliminate_dead_code()
m.recompile()
example_inputs = _conv2d_bn_pattern_example_inputs
match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs, is_cuda)

# Step (1): Replace patterns with conv bias
#
@@ -512,6 +523,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
replacement_pattern_with_conv_bias = get_aten_graph_module(
_qat_conv2d_bn_pattern,
example_inputs,
is_cuda,
)
replacements_with_conv_bias = replace_pattern_with_filters(
m,
@@ -527,6 +539,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
replacement_pattern_no_conv_bias = get_aten_graph_module(
_qat_conv2d_bn_pattern_no_conv_bias,
example_inputs,
is_cuda,
)
replacements_no_conv_bias = replace_pattern_with_filters(
m,
@@ -569,7 +582,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
_copy_over_literal_conv_args(original_node, replacement_conv_node)
# Step (3c): Update old references in the conv node's input_qspec_map
_update_conv_input_qspec_map_after_replacement(original_node, replacement_conv_node)
if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
if _is_supported_batch_norm_for_training(original_node):
replacement_bn_node.meta = original_node.meta
original_to_replacement_node[original_node] = replacement_bn_node
if original_node.target == operator.getitem:
@@ -632,6 +645,11 @@ def _remove_extra_dequantize(m: GraphModule):
m.recompile()

def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
m = _fold_conv_bn_qat_helper(m, is_cuda=False)
m = _fold_conv_bn_qat_helper(m, is_cuda=True)
return m

def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
"""
Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
"""
@@ -653,15 +671,15 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
if not has_relu and relu_is_inplace:
continue
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias)
kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
match_pattern = _get_quantized_qat_conv2d_bn_pattern(
is_per_channel, has_relu, has_bias, relu_is_inplace,
)
match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
is_per_channel, has_relu, has_bias, relu_is_inplace,
)
replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
replacements.extend(
replace_pattern_with_filters(
m,
18 changes: 17 additions & 1 deletion torch/ao/quantization/pt2e/utils.py
@@ -124,6 +124,19 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
all_args.append(schema.default_value)
return all_args

def _is_supported_batch_norm_for_training(node: Node):
"""
Return True if the given node refers to an aten batch norm op QAT supports.
"""
supported_ops = [
torch.ops.aten._native_batch_norm_legit.default,
# Note: we won't need this op anymore after batch norm consolidation
# For now, we need to continue to support it because it gives better
# training numerics than `_native_batch_norm_legit`
torch.ops.aten.cudnn_batch_norm.default,
]
return node.target in supported_ops

def fold_bn_weights_into_conv_node(
conv_node: Node,
conv_weight_node: Node,
@@ -148,7 +161,7 @@ def fold_bn_weights_into_conv_node(
bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
eps_arg_index = 6
elif bn_node.target == torch.ops.aten._native_batch_norm_legit.default:
elif _is_supported_batch_norm_for_training(bn_node):
eps_arg_index = 7
else:
raise ValueError("BN node target is unexpected ", bn_node.target)
@@ -229,11 +242,14 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
def get_aten_graph_module(
pattern: Callable,
example_inputs: Tuple[Any, ...],
is_cuda: bool = False,
**kwargs,
) -> GraphModule:
"""
Convert the pattern to an FX graph with decomposed aten ops.
"""
if is_cuda:
example_inputs = tuple([x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs])
aten_pattern = capture_pre_autograd_graph(
pattern,
example_inputs,
Expand Down
