[quant][pt2e] Support conv bn fusion in convert step for QAT flow #100442

Closed
wants to merge 1 commit into from
38 changes: 31 additions & 7 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -21,6 +21,7 @@
)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
_convert_to_reference_decomposed_fx,
prepare_pt2e_quantizer,
prepare_qat_pt2e_quantizer,
)
@@ -32,9 +33,9 @@
default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import (
convert_to_reference_fx,
prepare_fx,
prepare_qat_fx,
convert_to_reference_fx,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -44,9 +45,6 @@
)
from torch.testing._internal.common_quantized import override_quantized_engine


from torch.ao.quantization.quantize_fx import _convert_to_reference_decomposed_fx

@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
def test_simple_quantizer(self):
@@ -599,6 +597,7 @@ def _verify_symmetric_qnnpack_qat_numerics(
model: torch.nn.Module,
example_inputs: Tuple[Any, ...],
is_per_channel: bool,
verify_convert: bool = False,
):
"""
Helper method to verify that the QAT numerics for PT2E quantization match those of
@@ -615,7 +614,7 @@
aten_graph=True,
)
model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
result_pt2e = model_pt2e(*example_inputs)
after_prepare_result_pt2e = model_pt2e(*example_inputs)

# FX
# Note: In order to match the PT2E numerics exactly, we need to feed the
@@ -632,11 +631,36 @@
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
backend_config = get_qnnpack_backend_config()
model_fx = prepare_qat_fx(model_fx, qconfig_mapping, example_inputs, backend_config=backend_config)
result_fx = model_fx(*example_inputs)
after_prepare_result_fx = model_fx(*example_inputs)

# Verify that numerics match
self.assertEqual(result_pt2e, result_fx)
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

if verify_convert:
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)

model_fx = _convert_to_reference_decomposed_fx(model_fx, backend_config=backend_config)
quant_result_fx = model_fx(*example_inputs)
self.assertEqual(quant_result_pt2e, quant_result_fx)


def test_convert_qat_conv_bn_numerics(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x

example_inputs = (torch.randn(1, 3, 5, 5),)
self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=False)
# TODO: enable in a separate PR
# self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=True)
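
For orientation, a minimal sketch of the flow that `_verify_symmetric_qnnpack_qat_numerics` drives end to end for this model. The `QNNPackQuantizer` / `get_symmetric_quantization_config` names and the quantizer configuration are assumptions based on the quantizer setup used elsewhere in this test file (not visible in this diff); treat them as illustrative, not authoritative.

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_qat_pt2e_quantizer
# assumed import locations for the quantizer used by this test file
from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import get_symmetric_quantization_config

model = M().train()
example_inputs = (torch.randn(1, 3, 5, 5),)

# Export to an aten-level graph (export returns the graph module plus guards),
# then insert fake quantize ops for QAT
model, _ = torchdynamo.export(model, *example_inputs, aten_graph=True)
quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
model = prepare_qat_pt2e_quantizer(model, quantizer)

model(*example_inputs)  # train / calibrate here

# Convert: with this PR, this step also folds conv + bn back into a single conv
model = convert_pt2e(model)
quantized_output = model(*example_inputs)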

class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision
1 change: 1 addition & 0 deletions torch/_dynamo/utils.py
@@ -1284,6 +1284,7 @@ def visit(n: torch.fx.Node):
unimplemented("guard on data-dependent symbolic int/float")
elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError):
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, e.args[0]) from e
# why don't we print the exception here?
raise TorchRuntimeError() from e


173 changes: 169 additions & 4 deletions torch/ao/quantization/_pt2e/qat_utils.py
@@ -6,9 +6,10 @@
from torch.fx import GraphModule, Node
from torch.fx.subgraph_rewriter import _replace_pattern
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from .utils import _fold_bn_weights_into_conv_node


# Example inputs for both `_conv2d_bn_pattern` and `_fused_qat_conv2d_bn_pattern`
# Example inputs for both `_conv2d_bn_pattern` and `_qat_conv2d_bn_pattern`
_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
@@ -19,6 +20,23 @@
torch.randn(1), # bn_running_var
)

# Example inputs for both `_quantized_qat_conv2d_bn_pattern` and `_folded_quantized_qat_conv2d_bn_pattern`
_quantized_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3).to(torch.int8), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
torch.tensor([1], dtype=torch.float), # input_scale
torch.tensor([0], dtype=torch.int), # input_zero_point
torch.tensor([1], dtype=torch.float), # weight_scale
torch.tensor([0], dtype=torch.int), # weight_zero_point
torch.tensor([1], dtype=torch.float), # output_scale
torch.tensor([0], dtype=torch.int), # output_zero_point
)

def _conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
@@ -32,7 +50,7 @@ def _conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
return x

def _fused_qat_conv2d_bn_pattern(
def _qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
@@ -62,6 +80,96 @@ def _fused_qat_conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
return x

def _quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
output_scale: torch.Tensor,
output_zero_point: torch.Tensor,
) -> torch.Tensor:
"""
Quantized version of the QAT conv-bn pattern,
based on `nniqat.ConvBn2d._forward_approximate`.
Used in the QAT convert step: we first match this pattern, then replace it with
the normal conv-bn pattern, and finally fold the bn weights into conv.
"""
# TODO: allow setting eps
bn_eps = 1e-5
weight_quant_min = -127
weight_quant_max = 127
input_quant_min = -128
input_quant_max = 127
output_quant_min = -128
output_quant_max = 127

running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
x = F.conv2d(x, scaled_weight, zero_bias)
x = x / scale_factor.reshape(bias_shape)
x = x + conv_bias.reshape(bias_shape)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
return x
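
As the docstring notes, this pattern mirrors `nniqat.ConvBn2d._forward_approximate`: the conv weight is pre-scaled by `bn_weight / running_std` so that fake quantization observes the folded weight, and the conv output is divided by the same factor before the conv bias and batch norm are applied. A small self-contained check (illustrative only) of why the scale/unscale round trip leaves the conv output numerically unchanged:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 3, 3)
conv_weight = torch.randn(2, 1, 1, 1)
bn_weight = torch.rand(2) + 0.5
running_std = torch.sqrt(torch.rand(2) + 1.0)

scale_factor = bn_weight / running_std
scaled_weight = conv_weight * scale_factor.reshape(-1, 1, 1, 1)

# conv with the scaled weight, then undo the scaling on the output channels
out_scaled = F.conv2d(x, scaled_weight) / scale_factor.reshape(1, -1, 1, 1)
out_plain = F.conv2d(x, conv_weight)
assert torch.allclose(out_scaled, out_plain, atol=1e-5)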

def _folded_quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
output_scale: torch.Tensor,
output_zero_point: torch.Tensor,
) -> torch.Tensor:
""" Quantized QAT conv - bn pattern with bn weights being folded into conv
"""
# TODO: allow setting eps
bn_eps = 1e-5
weight_quant_min = -127
weight_quant_max = 127
input_quant_min = -128
input_quant_max = 127
output_quant_min = -128
output_quant_max = 127

x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
x = F.conv2d(x, conv_weight, conv_bias)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
return x

def _get_aten_graph_module(
pattern: Callable,
example_inputs: Tuple[Any, ...],
@@ -94,7 +202,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
m.recompile()
example_inputs = _conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
replacement_pattern = _get_aten_graph_module(_fused_qat_conv2d_bn_pattern, example_inputs)
replacement_pattern = _get_aten_graph_module(_qat_conv2d_bn_pattern, example_inputs)
# TODO: use the public replace_pattern API once it also returns replacement nodes
match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern)
m.recompile()
@@ -127,3 +235,60 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
if original_node.target == operator.getitem:
replacement_getitem_node.meta = original_node.meta
return m
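
The TODO above refers to the public FX rewrite API; the private `_replace_pattern` is used because it also returns the replacement nodes. For context, a minimal standalone sketch of the public `torch.fx.subgraph_rewriter.replace_pattern`, which only returns the matches:

import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern

class TwoRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.relu(x)

def pattern(x):
    return torch.relu(x)

def replacement(x):
    return torch.sigmoid(x)

gm = symbolic_trace(TwoRelu())
matches = replace_pattern(gm, pattern, replacement)  # list of Match objects, no replacement nodes
print(gm.code)  # both relu calls have been rewritten to sigmoid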

def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
"""
Replace the quantized (conv + bn) pattern with a conv node that has the bn weights folded into its own weights.
"""
m.graph.eliminate_dead_code()
m.recompile()
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_quantized_qat_conv2d_bn_pattern, example_inputs)

# Workaround: the current convert step does not produce q/dq ops with a specific overload,
# so we strip the overload from the match pattern here, since we do not want to break BC
for n in match_pattern.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor:
n.target = torch.ops.quantized_decomposed.quantize_per_tensor
if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor:
n.target = torch.ops.quantized_decomposed.dequantize_per_tensor

replacement_pattern = _get_aten_graph_module(_folded_quantized_qat_conv2d_bn_pattern, example_inputs)

# TODO: use the public replace_pattern API once it also returns replacement nodes
match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern, ignore_literals=True)
m.recompile()

for mr in match_and_replacement:
# Find replacement conv and bn nodes by climbing upwards from anchor node
assert len(mr.replacements) == 1, "expected only one replacement node"

# find conv, bn, weight, bias nodes in the graph
replacement_quantize_node = mr.replacements[0]
assert replacement_quantize_node.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
n = replacement_quantize_node
conv_node = None
bn_node = None
while conv_node is None or bn_node is None:
if n.target == torch.ops.aten.convolution.default:
conv_node = n
if n.target == torch.ops.aten._native_batch_norm_legit.default:
bn_node = n
assert isinstance(n.args[0], Node)
n = n.args[0]
assert conv_node is not None and bn_node is not None

conv_weight_dq = conv_node.args[1]
assert conv_weight_dq.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
conv_weight_q = conv_weight_dq.args[0]
assert conv_weight_q.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
conv_weight = conv_weight_q.args[0]
assert conv_weight.op == "get_attr"
conv_bias = conv_node.args[2]

# fold bn weights into conv
_fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

m.graph.eliminate_dead_code()
m.recompile()
return m
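
For reference, the folding performed by `_fold_bn_weights_into_conv_node` (imported from `.utils`) corresponds to the standard eval-mode conv-bn fusion. A hypothetical standalone helper sketching that math on plain tensors, assumed to mirror what the graph-level pass computes:

import torch
import torch.nn.functional as F

def fold_bn_into_conv(conv_w, conv_b, bn_w, bn_b, bn_rm, bn_rv, eps=1e-5):
    # scale each output channel by gamma / sqrt(running_var + eps)
    scale = bn_w / torch.sqrt(bn_rv + eps)
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b

# Sanity check against eval-mode conv followed by batch norm
conv = torch.nn.Conv2d(3, 3, 3)
bn = torch.nn.BatchNorm2d(3).eval()
x = torch.randn(1, 3, 5, 5)
fused_w, fused_b = fold_bn_into_conv(
    conv.weight, conv.bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps)
assert torch.allclose(bn(conv(x)), F.conv2d(x, fused_w, fused_b), atol=1e-6)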