[quant][pt2e] Support conv bn fusion in convert step for QAT flow (#100442)

Summary:
Pull Request resolved: #100442

This PR adds support for folding bn weights into conv for the QAT flow. This is equivalent
to the QAT branch of `from_float` in the eager mode quantized conv module: https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/modules/conv.py#L223
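
For reference, a minimal sketch of the folding math (tensor names here are illustrative; the actual fold in this PR operates on graph nodes via `_fold_bn_weights_into_conv_node`):

import torch

def fold_bn_into_conv(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps=1e-5):
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn_w / torch.sqrt(bn_rv + eps)
    # Scale each output channel of the conv weight
    folded_w = conv_w * scale.reshape([-1] + [1] * (conv_w.dim() - 1))
    # Fold the running stats and bn bias into the conv bias
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    folded_b = (conv_b - bn_rm) * scale + bn_b
    return folded_w, folded_b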

Items that need follow-up:
* there is a workaround that removes the overload (`.tensor`) from the q/dq ops in the match pattern graph that we get from torchdynamo export; we can remove it after we change the quantized model representation
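
For context, a minimal sketch of where this convert step sits in the PT2E QAT flow, based on the test helper changed below (the quantizer configuration is assumed to be set up by the caller, e.g. a QNNPACK quantizer; these underscore-prefixed APIs were experimental at the time of this commit):

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    prepare_qat_pt2e_quantizer,
)

def qat_prepare_and_convert(model: torch.nn.Module, example_inputs, quantizer):
    # Export to an aten-level graph
    model, _ = torchdynamo.export(model, *example_inputs, aten_graph=True)
    # Insert fake quants and rewrite conv + bn into the QAT pattern
    model = prepare_qat_pt2e_quantizer(model, quantizer)
    model(*example_inputs)  # QAT training / calibration step(s) would go here
    # Convert: the bn weights are folded into conv during this step (this PR)
    model = convert_pt2e(model)
    return model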

Test Plan: buck2 test @//mode/opt //caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_convert_qat_conv_bn_fusion (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'

Reviewed By: kimishpatel

Differential Revision: D45344281

fbshipit-source-id: 40c9600220a811a140c3bf1c23851aa3d00766de
jerryzh168 authored and facebook-github-bot committed May 9, 2023
1 parent 4447dfa commit 25f6957
Showing 6 changed files with 295 additions and 70 deletions.
38 changes: 31 additions & 7 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -21,6 +21,7 @@
)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
_convert_to_reference_decomposed_fx,
prepare_pt2e_quantizer,
prepare_qat_pt2e_quantizer,
)
@@ -32,9 +33,9 @@
default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import (
convert_to_reference_fx,
prepare_fx,
prepare_qat_fx,
convert_to_reference_fx,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -44,9 +45,6 @@
)
from torch.testing._internal.common_quantized import override_quantized_engine


from torch.ao.quantization.quantize_fx import _convert_to_reference_decomposed_fx

@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
def test_simple_quantizer(self):
@@ -599,6 +597,7 @@ def _verify_symmetric_qnnpack_qat_numerics(
model: torch.nn.Module,
example_inputs: Tuple[Any, ...],
is_per_channel: bool,
verify_convert: bool = False,
):
"""
Helper method to verify that the QAT numerics for PT2E quantization match those of
@@ -615,7 +614,7 @@
aten_graph=True,
)
model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
result_pt2e = model_pt2e(*example_inputs)
after_prepare_result_pt2e = model_pt2e(*example_inputs)

# FX
# Note: In order to match the PT2E numerics exactly, we need to feed the
@@ -632,11 +631,36 @@
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
backend_config = get_qnnpack_backend_config()
model_fx = prepare_qat_fx(model_fx, qconfig_mapping, example_inputs, backend_config=backend_config)
result_fx = model_fx(*example_inputs)
after_prepare_result_fx = model_fx(*example_inputs)

# Verify that numerics match
self.assertEqual(result_pt2e, result_fx)
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

if verify_convert:
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)

model_fx = _convert_to_reference_decomposed_fx(model_fx, backend_config=backend_config)
quant_result_fx = model_fx(*example_inputs)
            self.assertEqual(quant_result_pt2e, quant_result_fx)


def test_convert_qat_conv_bn_numerics(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x

example_inputs = (torch.randn(1, 3, 5, 5),)
self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=False)
# TODO: enable in a separate PR
# self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=True)

class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision
1 change: 1 addition & 0 deletions torch/_dynamo/utils.py
@@ -1284,6 +1284,7 @@ def visit(n: torch.fx.Node):
unimplemented("guard on data-dependent symbolic int/float")
elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError):
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, e.args[0]) from e
# why don't we print the exception here?
raise TorchRuntimeError() from e


173 changes: 169 additions & 4 deletions torch/ao/quantization/_pt2e/qat_utils.py
@@ -6,9 +6,10 @@
from torch.fx import GraphModule, Node
from torch.fx.subgraph_rewriter import _replace_pattern
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from .utils import _fold_bn_weights_into_conv_node


# Example inputs for both `_conv2d_bn_pattern` and `_fused_qat_conv2d_bn_pattern`
# Example inputs for both `_conv2d_bn_pattern` and `_qat_conv2d_bn_pattern`
_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
@@ -19,6 +20,23 @@
torch.randn(1), # bn_running_var
)

# Example inputs for both `_quantized_qat_conv2d_bn_pattern` and `_folded_quantized_qat_conv2d_bn_pattern`
_quantized_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3).to(torch.int8), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
torch.tensor([1], dtype=torch.float), # input_scale
torch.tensor([0], dtype=torch.int), # input_zero_point
torch.tensor([1], dtype=torch.float), # weight_scale
torch.tensor([0], dtype=torch.int), # weight_zero_point
torch.tensor([1], dtype=torch.float), # output_scale
torch.tensor([0], dtype=torch.int), # output_zero_point
)

def _conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
@@ -32,7 +50,7 @@ def _conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
return x

def _fused_qat_conv2d_bn_pattern(
def _qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
@@ -62,6 +80,96 @@ def _fused_qat_conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
return x

def _quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
output_scale: torch.Tensor,
output_zero_point: torch.Tensor,
) -> torch.Tensor:
"""
Quantized version of qat conv bn pattern,
This is based on `nniqat.ConvBn2d._forward_approximate`.
used in qat convert, we first match this pattern and then replace it with
normal conv - bn pattern and then fold the weights of bn into conv
"""
# TODO: allow setting eps
bn_eps = 1e-5
weight_quant_min = -127
weight_quant_max = 127
input_quant_min = -128
input_quant_max = 127
output_quant_min = -128
output_quant_max = 127

running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
x = F.conv2d(x, scaled_weight, zero_bias)
x = x / scale_factor.reshape(bias_shape)
x = x + conv_bias.reshape(bias_shape)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
return x

def _folded_quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
output_scale: torch.Tensor,
output_zero_point: torch.Tensor,
) -> torch.Tensor:
""" Quantized QAT conv - bn pattern with bn weights being folded into conv
"""
# TODO: allow setting eps
bn_eps = 1e-5
weight_quant_min = -127
weight_quant_max = 127
input_quant_min = -128
input_quant_max = 127
output_quant_min = -128
output_quant_max = 127

x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
x = F.conv2d(x, conv_weight, conv_bias)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
return x

def _get_aten_graph_module(
pattern: Callable,
example_inputs: Tuple[Any, ...],
@@ -94,7 +202,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
m.recompile()
example_inputs = _conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
replacement_pattern = _get_aten_graph_module(_fused_qat_conv2d_bn_pattern, example_inputs)
replacement_pattern = _get_aten_graph_module(_qat_conv2d_bn_pattern, example_inputs)
# TODO: use the public replace_pattern API once it also returns replacement nodes
match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern)
m.recompile()
@@ -127,3 +235,60 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
if original_node.target == operator.getitem:
replacement_getitem_node.meta = original_node.meta
return m

def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
"""
Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
"""
m.graph.eliminate_dead_code()
m.recompile()
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_quantized_qat_conv2d_bn_pattern, example_inputs)

    # Workaround: the current convert does not produce q/dq ops with a specific
    # overload, so we remove the overload from the match pattern here, since we
    # do not want to break BC.
for n in match_pattern.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor:
n.target = torch.ops.quantized_decomposed.quantize_per_tensor
if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor:
n.target = torch.ops.quantized_decomposed.dequantize_per_tensor

replacement_pattern = _get_aten_graph_module(_folded_quantized_qat_conv2d_bn_pattern, example_inputs)

# TODO: use the public replace_pattern API once it also returns replacement nodes
match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern, ignore_literals=True)
m.recompile()

for mr in match_and_replacement:
# Find replacement conv and bn nodes by climbing upwards from anchor node
assert len(mr.replacements) == 1, "expected only one replacement node"

# find conv, bn, weight, bias nodes in the graph
replacement_quantize_node = mr.replacements[0]
assert replacement_quantize_node.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
n = replacement_quantize_node
conv_node = None
bn_node = None
while conv_node is None or bn_node is None:
if n.target == torch.ops.aten.convolution.default:
conv_node = n
if n.target == torch.ops.aten._native_batch_norm_legit.default:
bn_node = n
assert isinstance(n.args[0], Node)
n = n.args[0]
assert conv_node is not None and bn_node is not None

conv_weight_dq = conv_node.args[1]
assert conv_weight_dq.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
conv_weight_q = conv_weight_dq.args[0]
assert conv_weight_q.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
conv_weight = conv_weight_q.args[0]
assert conv_weight.op == "get_attr"
conv_bias = conv_node.args[2]

# fold bn weights into conv
_fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

m.graph.eliminate_dead_code()
m.recompile()
return m
